diff --git a/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/context/StreamHandler.java b/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/context/StreamHandler.java index fb8564f6d..92c60dad7 100644 --- a/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/context/StreamHandler.java +++ b/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/context/StreamHandler.java @@ -1,7 +1,7 @@ package com.yomahub.liteflow.ai.context; import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; -import io.reactivex.rxjava3.core.Flowable; +import org.reactivestreams.Publisher; /** * 响应式流处理器 @@ -37,7 +37,7 @@ public interface StreamHandler { * @param eventStream 原始的 ChunkEvent 流 * @return 处理后的 ChunkEvent 流 */ - Flowable handle(Flowable eventStream); + Publisher handle(Publisher eventStream); /** * 创建一个 pass-through 处理器,不做任何转换直接返回原始流 @@ -56,7 +56,7 @@ public interface StreamHandler { */ static StreamHandler composite(StreamHandler... handlers) { return eventStream -> { - Flowable result = eventStream; + Publisher result = eventStream; for (StreamHandler handler : handlers) { result = handler.handle(result); } diff --git a/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/proxy/invocation/ChatAIInvocationHandler.java b/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/proxy/invocation/ChatAIInvocationHandler.java index 49088cbf8..5602b0c52 100644 --- a/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/proxy/invocation/ChatAIInvocationHandler.java +++ b/liteflow-ai/liteflow-ai-core/src/main/java/com/yomahub/liteflow/ai/proxy/invocation/ChatAIInvocationHandler.java @@ -9,6 +9,7 @@ import com.yomahub.liteflow.ai.model.ModelFactory; import com.yomahub.liteflow.ai.parse.context.ProcessorContext; import com.yomahub.liteflow.ai.proxy.wrap.ChatProxyWrapBean; import io.reactivex.rxjava3.core.Flowable; +import org.reactivestreams.Publisher; import java.util.Objects; @@ -51,12 +52,13 @@ public class ChatAIInvocationHandler extends AbstractAIInvocationHandler eventStream = chatModel.stream(chatRequest); + Publisher eventStream = chatModel.stream(chatRequest); // 应用用户的 StreamHandler 进行响应式转换 - Flowable handledStream = streamHandler.handle(eventStream); + Publisher handledStream = streamHandler.handle(eventStream); - return handledStream.blockingLast() + return Flowable.fromPublisher(handledStream) + .blockingLast() .getFinalResponse(); } } diff --git a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/InteractClient.java b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/InteractClient.java index 26215112a..b3ce48fcc 100644 --- a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/InteractClient.java +++ b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/InteractClient.java @@ -4,7 +4,7 @@ import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse; -import io.reactivex.rxjava3.core.Flowable; +import org.reactivestreams.Publisher; import java.util.concurrent.CompletableFuture; @@ -17,7 +17,7 @@ import java.util.concurrent.CompletableFuture; public interface InteractClient { - Flowable stream(ChatConfig config, ChatRequest request); + Publisher stream(ChatConfig config, ChatRequest request); ChatResponse chat(ChatConfig config, ChatRequest request); diff --git a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/LlmInteractClient.java b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/LlmInteractClient.java index 7ac8eaf1e..507ceb8e3 100644 --- a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/LlmInteractClient.java +++ b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/LlmInteractClient.java @@ -2,9 +2,9 @@ package com.yomahub.liteflow.ai.engine.interact; import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk; import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer; import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformerFactory; -import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk; import com.yomahub.liteflow.ai.engine.interact.transport.Transport; import com.yomahub.liteflow.ai.engine.log.EngineLog; import com.yomahub.liteflow.ai.engine.log.EngineLogManager; @@ -18,6 +18,7 @@ import io.reactivex.rxjava3.core.BackpressureStrategy; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; +import org.reactivestreams.Publisher; import java.util.List; import java.util.Objects; @@ -93,7 +94,7 @@ public class LlmInteractClient implements InteractClient { * @param request 聊天请求 * @return 包含 ChunkEvent 的流 */ - public Flowable stream(ChatConfig config, ChatRequest request) { + public Publisher stream(ChatConfig config, ChatRequest request) { InteractContext context = new InteractContext(); ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider()); diff --git a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/chunk/StreamingProtocolType.java b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/chunk/StreamingProtocolType.java index 02b8a6188..9fc768233 100644 --- a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/chunk/StreamingProtocolType.java +++ b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/interact/chunk/StreamingProtocolType.java @@ -29,6 +29,8 @@ public enum StreamingProtocolType { BASE64_IMAGE(7, "base64_image"), // 未知类型数据 DATA(8, "未知数据"), + // workflow 数据(coze 和 dashscope 返回值) + WORKFLOW_DATA(9, "workflow_data") ; private final Integer code; private final String desc; diff --git a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/AbstractChatModel.java b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/AbstractChatModel.java index bc3a412dd..40c527401 100644 --- a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/AbstractChatModel.java +++ b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/AbstractChatModel.java @@ -6,7 +6,7 @@ import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse; -import io.reactivex.rxjava3.core.Flowable; +import org.reactivestreams.Publisher; import java.time.Duration; import java.util.Map; @@ -45,7 +45,7 @@ public abstract class AbstractChatModel implements ChatModel { } @Override - public Flowable stream(ChatRequest request) { + public Publisher stream(ChatRequest request) { return interactClient.stream(config, request); } diff --git a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/ChatModel.java b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/ChatModel.java index 1037cb7ec..319d2a9f1 100644 --- a/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/ChatModel.java +++ b/liteflow-ai/liteflow-ai-engine/src/main/java/com/yomahub/liteflow/ai/engine/model/chat/ChatModel.java @@ -5,7 +5,7 @@ import com.yomahub.liteflow.ai.engine.model.BaseModel; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse; -import io.reactivex.rxjava3.core.Flowable; +import org.reactivestreams.Publisher; import java.util.concurrent.CompletableFuture; @@ -46,5 +46,5 @@ public interface ChatModel extends BaseModel { * @param request 聊天请求 * @return 包含所有流式事件的 Flowable 流 */ - Flowable stream(ChatRequest request); + Publisher stream(ChatRequest request); } diff --git a/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowChatInvocationHandler.java b/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowChatInvocationHandler.java index d54ab17d5..0b505171b 100644 --- a/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowChatInvocationHandler.java +++ b/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowChatInvocationHandler.java @@ -1,8 +1,29 @@ package com.yomahub.liteflow.ai.workflow.coze.invocation; +import cn.hutool.core.util.StrUtil; +import com.coze.openapi.client.chat.model.ChatEvent; +import com.coze.openapi.client.workflows.chat.WorkflowChatReq; +import com.coze.openapi.service.auth.TokenAuth; +import com.coze.openapi.service.config.Consts; +import com.coze.openapi.service.service.CozeAPI; +import com.yomahub.liteflow.ai.context.StreamHandler; +import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; +import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType; import com.yomahub.liteflow.ai.parse.context.ProcessorContext; import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler; +import com.yomahub.liteflow.ai.util.SpringUtil; +import com.yomahub.liteflow.ai.workflow.coze.annotation.CozeWorkflowChat; +import com.yomahub.liteflow.ai.workflow.coze.config.CozeWorkflowProperty; +import com.yomahub.liteflow.ai.workflow.coze.util.KeyValueUtil; import com.yomahub.liteflow.ai.workflow.coze.wrap.CozeWorkflowChatProxyWrapBean; +import io.reactivex.Flowable; + +import java.util.List; +import java.util.Map; + +import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent; /** * Coze 对话流 调用处理器 @@ -18,61 +39,69 @@ public class CozeWorkflowChatInvocationHandler extends AbstractAIInvocationHandl @Override protected Object doExecuteAIProcess(ProcessorContext processorContext, Object[] args) { + CozeWorkflowChat annotation = wrapBean.getAnnotation(); + + CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class); + + CozeAPI api = new CozeAPI.Builder() + .auth(new TokenAuth(property.getApiKey())) + .baseURL( + StrUtil.isBlank(property.getBaseUrl()) ? + Consts.COZE_CN_BASE_URL : + property.getBaseUrl() + ) + .build(); + + WorkflowChatReq.WorkflowChatReqBuilder builder = WorkflowChatReq.builder(); + + setIfPresent(builder::workflowID, annotation.workflowId()); + + setIfPresent(builder::additionalMessages, annotation.additionalMessages(), processorContext, List.class); + + if (annotation.parameters().length > 0) { + Map paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext); + if (!paramMap.isEmpty()) { + builder.parameters(paramMap); + } + } + + setIfPresent(builder::appID, annotation.appId()); + + setIfPresent(builder::botID, annotation.botId()); + + setIfPresent(builder::conversationID, annotation.conversationId()); + + if (annotation.ext().length > 0) { + Map extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext); + if (!extMap.isEmpty()) { + builder.ext(extMap); + } + } + + setIfPresent(builder::connectTimeout, annotation.connectTimeout()); + + setIfPresent(builder::readTimeout, annotation.readTimeout()); + + setIfPresent(builder::writeTimeout, annotation.writeTimeout()); + + setIfPresent(builder::customerToken, annotation.customerToken()); + + Flowable res = api.workflows().chat().stream(builder.build()); + InteractContext context = new InteractContext(); + StreamHandler streamHandler = processorContext.getChatContext().getStreamHandler(); + io.reactivex.rxjava3.core.Flowable.fromPublisher( + streamHandler.handle( + res.map(workflowEvent -> { + StreamingProtocolChunk chunk = new StreamingProtocolChunk(); + chunk.setData(workflowEvent); + chunk.setType(StreamingProtocolType.WORKFLOW_DATA); + chunk.setId(context.getChatId()); + return ChunkEvent.chunk(workflowEvent.getMessage().getContent(), + chunk, context); + }) + ) + ).blockingLast(); return null; -// CozeWorkflowChat annotation = wrapBean.getAnnotation(); -// -// CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class); -// -// CozeAPI api = new CozeAPI.Builder() -// .auth(new TokenAuth(property.getApiKey())) -// .baseURL( -// StrUtil.isBlank(property.getBaseUrl()) ? -// Consts.COZE_CN_BASE_URL : -// property.getBaseUrl() -// ) -// .build(); -// -// WorkflowChatReq.WorkflowChatReqBuilder builder = WorkflowChatReq.builder(); -// -// setIfPresent(builder::workflowID, annotation.workflowId()); -// -// setIfPresent(builder::additionalMessages, annotation.additionalMessages(), processorContext, List.class); -// -// if (annotation.parameters().length > 0) { -// Map paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext); -// if (!paramMap.isEmpty()) { -// builder.parameters(paramMap); -// } -// } -// -// setIfPresent(builder::appID, annotation.appId()); -// -// setIfPresent(builder::botID, annotation.botId()); -// -// setIfPresent(builder::conversationID, annotation.conversationId()); -// -// if (annotation.ext().length > 0) { -// Map extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext); -// if (!extMap.isEmpty()) { -// builder.ext(extMap); -// } -// } -// -// setIfPresent(builder::connectTimeout, annotation.connectTimeout()); -// -// setIfPresent(builder::readTimeout, annotation.readTimeout()); -// -// setIfPresent(builder::writeTimeout, annotation.writeTimeout()); -// -// setIfPresent(builder::customerToken, annotation.customerToken()); -// -// Flowable res = api.workflows().chat().stream(builder.build()); -// InteractContext context = new InteractContext(); -// res.blockingForEach(chunk -> { -// ChatContext chatContext = processorContext.getChatContext(); -// chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context); -// }); -// return null; } @Override diff --git a/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowRunInvocationHandler.java b/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowRunInvocationHandler.java index 4ca01e8c0..90fb8e0e0 100644 --- a/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowRunInvocationHandler.java +++ b/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-coze/src/main/java/com/yomahub/liteflow/ai/workflow/coze/invocation/CozeWorkflowRunInvocationHandler.java @@ -1,8 +1,29 @@ package com.yomahub.liteflow.ai.workflow.coze.invocation; +import cn.hutool.core.util.StrUtil; +import com.coze.openapi.client.workflows.run.RunWorkflowReq; +import com.coze.openapi.client.workflows.run.model.WorkflowEvent; +import com.coze.openapi.service.auth.TokenAuth; +import com.coze.openapi.service.config.Consts; +import com.coze.openapi.service.service.CozeAPI; +import com.coze.openapi.service.service.workflow.WorkflowRunService; +import com.yomahub.liteflow.ai.context.StreamHandler; +import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; +import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType; import com.yomahub.liteflow.ai.parse.context.ProcessorContext; import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler; +import com.yomahub.liteflow.ai.util.SpringUtil; +import com.yomahub.liteflow.ai.workflow.coze.annotation.CozeWorkflowRun; +import com.yomahub.liteflow.ai.workflow.coze.config.CozeWorkflowProperty; +import com.yomahub.liteflow.ai.workflow.coze.util.KeyValueUtil; import com.yomahub.liteflow.ai.workflow.coze.wrap.CozeWorkflowRunProxyWrapBean; +import io.reactivex.Flowable; + +import java.util.Map; + +import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent; /** * Coze 工作流 调用处理器 @@ -19,63 +40,71 @@ public class CozeWorkflowRunInvocationHandler extends AbstractAIInvocationHandle @Override protected Object doExecuteAIProcess(ProcessorContext processorContext, Object[] args) { - return null; -// CozeWorkflowRun annotation = wrapBean.getAnnotation(); -// -// CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class); -// -// CozeAPI api = new CozeAPI.Builder() -// .auth(new TokenAuth(property.getApiKey())) -// .baseURL( -// StrUtil.isBlank(property.getBaseUrl()) ? -// Consts.COZE_CN_BASE_URL : -// property.getBaseUrl() -// ) -// .build(); -// -// RunWorkflowReq.RunWorkflowReqBuilder builder = RunWorkflowReq.builder(); -// -// setIfPresent(builder::workflowID, annotation.workflowId()); -// -// if (annotation.parameters().length > 0) { -// Map paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext); -// if (!paramMap.isEmpty()) { -// builder.parameters(paramMap); -// } -// } -// -// setIfPresent(builder::appID, annotation.appId()); -// -// setIfPresent(builder::botID, annotation.botId()); -// -// if (annotation.ext().length > 0) { -// Map extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext); -// if (!extMap.isEmpty()) { -// builder.ext(extMap); -// } -// } -// -// setIfPresent(builder::connectTimeout, annotation.connectTimeout()); -// -// setIfPresent(builder::readTimeout, annotation.readTimeout()); -// -// setIfPresent(builder::writeTimeout, annotation.writeTimeout()); -// -// setIfPresent(builder::customerToken, annotation.customerToken()); -// -// WorkflowRunService runs = api.workflows().runs(); -// -// if (annotation.stream()) { -// Flowable res = runs.stream(builder.build()); -// InteractContext context = new InteractContext(); -// res.blockingForEach(chunk -> { -// ChatContext chatContext = processorContext.getChatContext(); -// chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context); -// }); -// return null; -// } else { -// return runs.create(builder.build()); -// } + CozeWorkflowRun annotation = wrapBean.getAnnotation(); + + CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class); + + CozeAPI api = new CozeAPI.Builder() + .auth(new TokenAuth(property.getApiKey())) + .baseURL( + StrUtil.isBlank(property.getBaseUrl()) ? + Consts.COZE_CN_BASE_URL : + property.getBaseUrl() + ) + .build(); + + RunWorkflowReq.RunWorkflowReqBuilder builder = RunWorkflowReq.builder(); + + setIfPresent(builder::workflowID, annotation.workflowId()); + + if (annotation.parameters().length > 0) { + Map paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext); + if (!paramMap.isEmpty()) { + builder.parameters(paramMap); + } + } + + setIfPresent(builder::appID, annotation.appId()); + + setIfPresent(builder::botID, annotation.botId()); + + if (annotation.ext().length > 0) { + Map extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext); + if (!extMap.isEmpty()) { + builder.ext(extMap); + } + } + + setIfPresent(builder::connectTimeout, annotation.connectTimeout()); + + setIfPresent(builder::readTimeout, annotation.readTimeout()); + + setIfPresent(builder::writeTimeout, annotation.writeTimeout()); + + setIfPresent(builder::customerToken, annotation.customerToken()); + + WorkflowRunService runs = api.workflows().runs(); + + if (annotation.stream()) { + Flowable res = runs.stream(builder.build()); + InteractContext context = new InteractContext(); + StreamHandler streamHandler = processorContext.getChatContext().getStreamHandler(); + io.reactivex.rxjava3.core.Flowable.fromPublisher( + streamHandler.handle( + res.map(workflowEvent -> { + StreamingProtocolChunk chunk = new StreamingProtocolChunk(); + chunk.setData(workflowEvent); + chunk.setType(StreamingProtocolType.WORKFLOW_DATA); + chunk.setId(context.getChatId()); + return ChunkEvent.chunk(workflowEvent.getMessage().getContent(), + chunk, context); + }) + ) + ).blockingLast(); + return null; + } else { + return runs.create(builder.build()); + } } @Override diff --git a/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-dashscope/src/main/java/com/yomahub/liteflow/ai/workflow/dashscope/invocation/DashScopeWorkflowInvocationHandler.java b/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-dashscope/src/main/java/com/yomahub/liteflow/ai/workflow/dashscope/invocation/DashScopeWorkflowInvocationHandler.java index 6faf0fb7f..0f0a5b9d7 100644 --- a/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-dashscope/src/main/java/com/yomahub/liteflow/ai/workflow/dashscope/invocation/DashScopeWorkflowInvocationHandler.java +++ b/liteflow-ai/liteflow-ai-workflow/liteflow-ai-workflow-dashscope/src/main/java/com/yomahub/liteflow/ai/workflow/dashscope/invocation/DashScopeWorkflowInvocationHandler.java @@ -1,11 +1,29 @@ package com.yomahub.liteflow.ai.workflow.dashscope.invocation; +import cn.hutool.core.util.StrUtil; +import com.alibaba.dashscope.app.Application; +import com.alibaba.dashscope.app.ApplicationParam; +import com.alibaba.dashscope.app.ApplicationResult; import com.alibaba.dashscope.app.RagOptions; +import com.alibaba.dashscope.exception.InputRequiredException; +import com.alibaba.dashscope.exception.NoApiKeyException; +import com.alibaba.dashscope.utils.JsonUtils; import com.google.gson.JsonObject; +import com.yomahub.liteflow.ai.context.StreamHandler; +import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent; +import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType; +import com.yomahub.liteflow.ai.exception.LiteFlowAIException; +import com.yomahub.liteflow.ai.parse.context.ContextAccessor; import com.yomahub.liteflow.ai.parse.context.ProcessorContext; import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler; import com.yomahub.liteflow.ai.util.KeyValue; +import com.yomahub.liteflow.ai.util.SpringUtil; +import com.yomahub.liteflow.ai.workflow.dashscope.annotation.DashScopeWorkflow; +import com.yomahub.liteflow.ai.workflow.dashscope.config.DashScopeWorkflowProperty; import com.yomahub.liteflow.ai.workflow.dashscope.wrap.DashScopeWorkflowProxyWrapBean; +import io.reactivex.Flowable; import java.util.Arrays; import java.util.List; @@ -31,96 +49,103 @@ public class DashScopeWorkflowInvocationHandler extends AbstractAIInvocationHand @Override protected Object doExecuteAIProcess(ProcessorContext processorContext, Object[] args) { - return null; -// DashScopeWorkflow dashScopeWorkflow = wrapBean.getAnnotation(); -// -// DashScopeWorkflowProperty property = SpringUtil.getBean(DashScopeWorkflowProperty.class); -// -// ApplicationParam.ApplicationParamBuilder builder = ApplicationParam.builder(); -// -// builder.appId(dashScopeWorkflow.appId()); -// builder.apiKey(property.getApiKey()); -// -// setIfPresent(builder::prompt, dashScopeWorkflow.prompt(), processorContext); -// -// String history = dashScopeWorkflow.history(); -// setIfPresent(builder::history, StrUtil.isNotBlank(history) ? -// ContextAccessor.searchContextByExpression(history, processorContext) : null); -// -// String messages = dashScopeWorkflow.messages(); -// setIfPresent(builder::messages, StrUtil.isNotBlank(messages) ? -// ContextAccessor.searchContextByExpression(messages, processorContext) : null); -// -// setIfPresent(builder::sessionId, dashScopeWorkflow.sessionId()); -// -// setIfPresent(builder::hasThoughts, dashScopeWorkflow.hasThoughts()); -// -// setIfPresent(builder::bizParams, dashScopeWorkflow.bizParams(), -// processorContext, JsonObject.class, JsonUtils::parse); -// -// setIfPresent(builder::topP, dashScopeWorkflow.topP()); -// -// setIfPresent(builder::topK, dashScopeWorkflow.topK()); -// -// setIfPresent(builder::seed, dashScopeWorkflow.seed()); -// -// setIfPresent(builder::temperature, dashScopeWorkflow.temperature()); -// -// setIfPresent(builder::incrementalOutput, dashScopeWorkflow.incrementalOutput()); -// -// setIfPresent(builder::memoryId, dashScopeWorkflow.memoryId()); -// -// setIfPresent(builder::images, dashScopeWorkflow.images(), processorContext, List.class, -// s -> Arrays.stream(s.split(COMMA_SPLITTER)) -// .map(String::trim) -// .collect(Collectors.toList())); -// -// setIfPresent(builder::ragOptions, buildRagOptions(dashScopeWorkflow.ragOptions(), processorContext)); -// -// setIfPresent(builder::mcpServers, dashScopeWorkflow.mcpServers(), processorContext, List.class, -// s -> Arrays.stream(s.split(COMMA_SPLITTER)) -// .map(String::trim) -// .collect(Collectors.toList())); -// -// setIfPresent(builder::enableWebSearch, dashScopeWorkflow.enableWebSearch()); -// -// setIfPresent(builder::enableSystemTime, dashScopeWorkflow.enableSystemTime()); -// -// setIfPresent(builder::enablePremium, dashScopeWorkflow.enablePremium()); -// -// setIfPresent(builder::dialogRound, dashScopeWorkflow.dialogRound()); -// -// setIfPresent(builder::modelId, dashScopeWorkflow.modelId()); -// -// setIfPresent(builder::flowStreamMode, dashScopeWorkflow.flowStreamMode()); -// -// setIfPresent(builder::enableThinking, dashScopeWorkflow.enableThinking()); -// -// ApplicationParam applicationParam = builder.build(); -// Application application = StrUtil.isNotBlank(property.getApiUrl()) ? -// new Application(property.getApiUrl()) : -// new Application(); -// -// if (dashScopeWorkflow.stream()) { -// try { -// Flowable res = application.streamCall(applicationParam); -// InteractContext context = new InteractContext(); -// res.blockingForEach(chunk -> { -// ChatContext chatContext = processorContext.getChatContext(); -// chatContext.getStreamHandler().onText(chunk.getOutput().getText(), context); -// context.addText(chunk.getOutput().getText()); -// }); -// return null; -// } catch (NoApiKeyException | InputRequiredException e) { -// throw new LiteFlowAIException("DashScope stream call failed", e); -// } -// } else { -// try { -// return application.call(applicationParam); -// } catch (NoApiKeyException | InputRequiredException e) { -// throw new LiteFlowAIException("DashScope call failed", e); -// } -// } + DashScopeWorkflow dashScopeWorkflow = wrapBean.getAnnotation(); + + DashScopeWorkflowProperty property = SpringUtil.getBean(DashScopeWorkflowProperty.class); + + ApplicationParam.ApplicationParamBuilder builder = ApplicationParam.builder(); + + builder.appId(dashScopeWorkflow.appId()); + builder.apiKey(property.getApiKey()); + + setIfPresent(builder::prompt, dashScopeWorkflow.prompt(), processorContext); + + String history = dashScopeWorkflow.history(); + setIfPresent(builder::history, StrUtil.isNotBlank(history) ? + ContextAccessor.searchContextByExpression(history, processorContext) : null); + + String messages = dashScopeWorkflow.messages(); + setIfPresent(builder::messages, StrUtil.isNotBlank(messages) ? + ContextAccessor.searchContextByExpression(messages, processorContext) : null); + + setIfPresent(builder::sessionId, dashScopeWorkflow.sessionId()); + + setIfPresent(builder::hasThoughts, dashScopeWorkflow.hasThoughts()); + + setIfPresent(builder::bizParams, dashScopeWorkflow.bizParams(), + processorContext, JsonObject.class, JsonUtils::parse); + + setIfPresent(builder::topP, dashScopeWorkflow.topP()); + + setIfPresent(builder::topK, dashScopeWorkflow.topK()); + + setIfPresent(builder::seed, dashScopeWorkflow.seed()); + + setIfPresent(builder::temperature, dashScopeWorkflow.temperature()); + + setIfPresent(builder::incrementalOutput, dashScopeWorkflow.incrementalOutput()); + + setIfPresent(builder::memoryId, dashScopeWorkflow.memoryId()); + + setIfPresent(builder::images, dashScopeWorkflow.images(), processorContext, List.class, + s -> Arrays.stream(s.split(COMMA_SPLITTER)) + .map(String::trim) + .collect(Collectors.toList())); + + setIfPresent(builder::ragOptions, buildRagOptions(dashScopeWorkflow.ragOptions(), processorContext)); + + setIfPresent(builder::mcpServers, dashScopeWorkflow.mcpServers(), processorContext, List.class, + s -> Arrays.stream(s.split(COMMA_SPLITTER)) + .map(String::trim) + .collect(Collectors.toList())); + + setIfPresent(builder::enableWebSearch, dashScopeWorkflow.enableWebSearch()); + + setIfPresent(builder::enableSystemTime, dashScopeWorkflow.enableSystemTime()); + + setIfPresent(builder::enablePremium, dashScopeWorkflow.enablePremium()); + + setIfPresent(builder::dialogRound, dashScopeWorkflow.dialogRound()); + + setIfPresent(builder::modelId, dashScopeWorkflow.modelId()); + + setIfPresent(builder::flowStreamMode, dashScopeWorkflow.flowStreamMode()); + + setIfPresent(builder::enableThinking, dashScopeWorkflow.enableThinking()); + + ApplicationParam applicationParam = builder.build(); + Application application = StrUtil.isNotBlank(property.getApiUrl()) ? + new Application(property.getApiUrl()) : + new Application(); + + if (dashScopeWorkflow.stream()) { + try { + Flowable res = application.streamCall(applicationParam); + InteractContext context = new InteractContext(); + StreamHandler streamHandler = processorContext.getChatContext().getStreamHandler(); + io.reactivex.rxjava3.core.Flowable.fromPublisher( + streamHandler.handle( + res.map(workflowEvent -> { + StreamingProtocolChunk chunk = new StreamingProtocolChunk(); + chunk.setData(workflowEvent); + chunk.setType(StreamingProtocolType.WORKFLOW_DATA); + chunk.setId(context.getChatId()); + return ChunkEvent.chunk(workflowEvent.getOutput().getText(), + chunk, context); + }) + ) + ).blockingLast(); + return null; + } catch (NoApiKeyException | InputRequiredException e) { + throw new LiteFlowAIException("DashScope stream call failed", e); + } + } else { + try { + return application.call(applicationParam); + } catch (NoApiKeyException | InputRequiredException e) { + throw new LiteFlowAIException("DashScope call failed", e); + } + } } @SuppressWarnings("unchecked") diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/chat/ChatTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/chat/ChatTest.java index 8a0e4a29f..e239a5b35 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/chat/ChatTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/chat/ChatTest.java @@ -11,6 +11,7 @@ import com.yomahub.liteflow.core.FlowExecutor; import com.yomahub.liteflow.flow.LiteflowResponse; import com.yomahub.liteflow.test.ai.mock.MockAITest; import com.yomahub.liteflow.test.ai.mock.TestDataReader; +import io.reactivex.rxjava3.core.Flowable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -75,47 +76,48 @@ public class ChatTest extends MockAITest { } private StreamHandler getStreamHandler() { - return eventStream -> eventStream.doOnNext(chunkEvent -> { - if (chunkEvent.isStart()) { - System.out.println("chat start"); - } - if (chunkEvent.isChunk()) { - StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk(); - switch (chunk.getType()) { - case TEXT: - System.out.println(chunk.getData()); - break; - case THINKING: - System.out.println("[Thinking] " + chunk.getData()); - } - } - if (chunkEvent.isComplete()) { - ChatResponse response = chunkEvent.getFinalResponse(); - AssistantMessage message = response.getOutput(); - if (message.getContent() != null && !message.getContent().trim().isEmpty()) { - System.out.println("内容长度: " + message.getContent().length()); - if (message.getContent().length() > 200) { - System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); - } else { - System.out.println("内容: " + message.getContent()); + return eventStream -> Flowable.fromPublisher(eventStream) + .doOnNext(chunkEvent -> { + if (chunkEvent.isStart()) { + System.out.println("chat start"); } - } - if (response.hasToolCalls()) { - System.out.println("工具调用数量: " + message.getToolCalls().size()); - for (int i = 0; i < message.getToolCalls().size(); i++) { - ToolCall toolCall = message.getToolCalls().get(i); - System.out.println("工具调用 " + (i + 1) + ":"); - System.out.println(" ID: " + toolCall.getId()); - System.out.println(" 名称: " + toolCall.getName()); - System.out.println(" 类型: " + toolCall.getType()); - System.out.println(" 参数: " + toolCall.getArguments()); + if (chunkEvent.isChunk()) { + StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk(); + switch (chunk.getType()) { + case TEXT: + System.out.println(chunk.getData()); + break; + case THINKING: + System.out.println("[Thinking] " + chunk.getData()); + } } - } - System.out.println("Token使用情况: " + response.getTokenUsage()); - System.out.println("完成原因: " + response.getFinishReason()); - System.out.println("chat close"); - } - }); + if (chunkEvent.isComplete()) { + ChatResponse response = chunkEvent.getFinalResponse(); + AssistantMessage message = response.getOutput(); + if (message.getContent() != null && !message.getContent().trim().isEmpty()) { + System.out.println("内容长度: " + message.getContent().length()); + if (message.getContent().length() > 200) { + System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); + } else { + System.out.println("内容: " + message.getContent()); + } + } + if (response.hasToolCalls()) { + System.out.println("工具调用数量: " + message.getToolCalls().size()); + for (int i = 0; i < message.getToolCalls().size(); i++) { + ToolCall toolCall = message.getToolCalls().get(i); + System.out.println("工具调用 " + (i + 1) + ":"); + System.out.println(" ID: " + toolCall.getId()); + System.out.println(" 名称: " + toolCall.getName()); + System.out.println(" 类型: " + toolCall.getType()); + System.out.println(" 参数: " + toolCall.getArguments()); + } + } + System.out.println("Token使用情况: " + response.getTokenUsage()); + System.out.println("完成原因: " + response.getFinishReason()); + System.out.println("chat close"); + } + }); } } diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/structure/StructureTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/structure/StructureTest.java index 5633d4388..3e117cfba 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/structure/StructureTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/structure/StructureTest.java @@ -11,6 +11,7 @@ import com.yomahub.liteflow.core.FlowExecutor; import com.yomahub.liteflow.flow.LiteflowResponse; import com.yomahub.liteflow.test.ai.mock.MockAITest; import com.yomahub.liteflow.test.ai.mock.TestDataReader; +import io.reactivex.rxjava3.core.Flowable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -82,46 +83,47 @@ public class StructureTest extends MockAITest { } private StreamHandler getStreamHandler() { - return eventStream -> eventStream.doOnNext(chunkEvent -> { - if (chunkEvent.isStart()) { - System.out.println("chat start"); - } - if (chunkEvent.isChunk()) { - StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk(); - switch (chunk.getType()) { - case TEXT: - System.out.println(chunk.getData()); - break; - case THINKING: - System.out.println("[Thinking] " + chunk.getData()); - } - } - if (chunkEvent.isComplete()) { - ChatResponse response = chunkEvent.getFinalResponse(); - AssistantMessage message = response.getOutput(); - if (message.getContent() != null && !message.getContent().trim().isEmpty()) { - System.out.println("内容长度: " + message.getContent().length()); - if (message.getContent().length() > 200) { - System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); - } else { - System.out.println("内容: " + message.getContent()); + return eventStream -> Flowable.fromPublisher(eventStream) + .doOnNext(chunkEvent -> { + if (chunkEvent.isStart()) { + System.out.println("chat start"); } - } - if (response.hasToolCalls()) { - System.out.println("工具调用数量: " + message.getToolCalls().size()); - for (int i = 0; i < message.getToolCalls().size(); i++) { - ToolCall toolCall = message.getToolCalls().get(i); - System.out.println("工具调用 " + (i + 1) + ":"); - System.out.println(" ID: " + toolCall.getId()); - System.out.println(" 名称: " + toolCall.getName()); - System.out.println(" 类型: " + toolCall.getType()); - System.out.println(" 参数: " + toolCall.getArguments()); + if (chunkEvent.isChunk()) { + StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk(); + switch (chunk.getType()) { + case TEXT: + System.out.println(chunk.getData()); + break; + case THINKING: + System.out.println("[Thinking] " + chunk.getData()); + } } - } - System.out.println("Token使用情况: " + response.getTokenUsage()); - System.out.println("完成原因: " + response.getFinishReason()); - System.out.println("chat close"); - } - }); + if (chunkEvent.isComplete()) { + ChatResponse response = chunkEvent.getFinalResponse(); + AssistantMessage message = response.getOutput(); + if (message.getContent() != null && !message.getContent().trim().isEmpty()) { + System.out.println("内容长度: " + message.getContent().length()); + if (message.getContent().length() > 200) { + System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); + } else { + System.out.println("内容: " + message.getContent()); + } + } + if (response.hasToolCalls()) { + System.out.println("工具调用数量: " + message.getToolCalls().size()); + for (int i = 0; i < message.getToolCalls().size(); i++) { + ToolCall toolCall = message.getToolCalls().get(i); + System.out.println("工具调用 " + (i + 1) + ":"); + System.out.println(" ID: " + toolCall.getId()); + System.out.println(" 名称: " + toolCall.getName()); + System.out.println(" 类型: " + toolCall.getType()); + System.out.println(" 参数: " + toolCall.getArguments()); + } + } + System.out.println("Token使用情况: " + response.getTokenUsage()); + System.out.println("完成原因: " + response.getFinishReason()); + System.out.println("chat close"); + } + }); } } diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/tool/ToolTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/tool/ToolTest.java index fe1da4c32..525d51ba3 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/tool/ToolTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/core/tool/ToolTest.java @@ -14,6 +14,7 @@ import com.yomahub.liteflow.flow.LiteflowResponse; import com.yomahub.liteflow.test.ai.core.tool.tools.ToolConfig; import com.yomahub.liteflow.test.ai.mock.MockAITest; import com.yomahub.liteflow.test.ai.mock.TestDataReader; +import io.reactivex.rxjava3.core.Flowable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -110,46 +111,47 @@ public class ToolTest extends MockAITest { } private StreamHandler getStreamHandler() { - return eventStream -> eventStream.doOnNext(chunkEvent -> { - if (chunkEvent.isStart()) { - System.out.println("chat start"); - } - if (chunkEvent.isChunk()) { - StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk(); - switch (chunk.getType()) { - case TEXT: - System.out.println(chunk.getData()); - break; - case THINKING: - System.out.println("[Thinking] " + chunk.getData()); - } - } - if (chunkEvent.isComplete()) { - ChatResponse response = chunkEvent.getFinalResponse(); - AssistantMessage message = response.getOutput(); - if (message.getContent() != null && !message.getContent().trim().isEmpty()) { - System.out.println("内容长度: " + message.getContent().length()); - if (message.getContent().length() > 200) { - System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); - } else { - System.out.println("内容: " + message.getContent()); + return eventStream -> Flowable.fromPublisher(eventStream) + .doOnNext(chunkEvent -> { + if (chunkEvent.isStart()) { + System.out.println("chat start"); } - } - if (response.hasToolCalls()) { - System.out.println("工具调用数量: " + message.getToolCalls().size()); - for (int i = 0; i < message.getToolCalls().size(); i++) { - ToolCall toolCall = message.getToolCalls().get(i); - System.out.println("工具调用 " + (i + 1) + ":"); - System.out.println(" ID: " + toolCall.getId()); - System.out.println(" 名称: " + toolCall.getName()); - System.out.println(" 类型: " + toolCall.getType()); - System.out.println(" 参数: " + toolCall.getArguments()); + if (chunkEvent.isChunk()) { + StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk(); + switch (chunk.getType()) { + case TEXT: + System.out.println(chunk.getData()); + break; + case THINKING: + System.out.println("[Thinking] " + chunk.getData()); + } } - } - System.out.println("Token使用情况: " + response.getTokenUsage()); - System.out.println("完成原因: " + response.getFinishReason()); - System.out.println("chat close"); - } - }); + if (chunkEvent.isComplete()) { + ChatResponse response = chunkEvent.getFinalResponse(); + AssistantMessage message = response.getOutput(); + if (message.getContent() != null && !message.getContent().trim().isEmpty()) { + System.out.println("内容长度: " + message.getContent().length()); + if (message.getContent().length() > 200) { + System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); + } else { + System.out.println("内容: " + message.getContent()); + } + } + if (response.hasToolCalls()) { + System.out.println("工具调用数量: " + message.getToolCalls().size()); + for (int i = 0; i < message.getToolCalls().size(); i++) { + ToolCall toolCall = message.getToolCalls().get(i); + System.out.println("工具调用 " + (i + 1) + ":"); + System.out.println(" ID: " + toolCall.getId()); + System.out.println(" 名称: " + toolCall.getName()); + System.out.println(" 类型: " + toolCall.getType()); + System.out.println(" 参数: " + toolCall.getArguments()); + } + } + System.out.println("Token使用情况: " + response.getTokenUsage()); + System.out.println("完成原因: " + response.getFinishReason()); + System.out.println("chat close"); + } + }); } } diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/engine/interact/InteractTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/engine/interact/InteractTest.java index 000acc6d3..56c1398cf 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/engine/interact/InteractTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/engine/interact/InteractTest.java @@ -1,5 +1,28 @@ package com.yomahub.liteflow.test.ai.engine.interact; +import com.yomahub.liteflow.ai.domain.enums.ProviderEnum; +import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext; +import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk; +import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer; +import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse; +import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage; +import com.yomahub.liteflow.ai.engine.tool.ToolCall; +import com.yomahub.liteflow.ai.model.dashscope.interact.DashScopeProtocolTransformer; +import com.yomahub.liteflow.ai.model.ollama.interact.OllamaProtocolTransformer; +import com.yomahub.liteflow.ai.model.openai.interact.OpenAIProtocolTransformer; +import com.yomahub.liteflow.test.ai.mock.TestDataReader; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.util.List; +import java.util.Objects; + +import static org.junit.jupiter.api.Assertions.*; + /** * 交互处理器测试类 * 测试各种协议转换器和流式处理管道的功能 @@ -8,327 +31,349 @@ package com.yomahub.liteflow.test.ai.engine.interact; * @since 2.16.0 */ public class InteractTest { -// -// private ProtocolTransformer protocolTransformer; -// private InteractContext context; -// -// /** -// * 创建协议转换器的工厂方法 -// * -// * @param provider 提供者枚举 -// * @return 对应的协议转换器实例 -// */ -// private ProtocolTransformer createProtocolTransformer(ProviderEnum provider) { -// switch (provider) { -// case OPENAI: -// return new OpenAIProtocolTransformer(); -// case DASHSCOPE: -// return new DashScopeProtocolTransformer(); -// case OLLAMA: -// return new OllamaProtocolTransformer(); -// default: -// throw new IllegalArgumentException("不支持的提供者: " + provider); -// } -// } -// -// @Nested -// @DisplayName("OpenAI协议转换器测试") -// class OpenAITests { -// -// @BeforeEach -// void setUp() { -// context = new InteractContext(); -// protocolTransformer = createProtocolTransformer(ProviderEnum.OPENAI); -// } -// -// @Test -// @DisplayName("测试OpenAI流式文本响应转换") -// void testStreamingTextResponse() { -// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT, false); -// } -// -// @Test -// @DisplayName("测试OpenAI流式工具调用响应转换") -// void testStreamingToolCallResponse() { -// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL, true); -// } -// -// @Test -// @DisplayName("测试OpenAI流式结构化响应转换") -// void testStreamingStructuredResponse() { -// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED, false); -// } -// -// @Test -// @DisplayName("测试OpenAI阻塞式文本响应转换") -// void testBlockingTextResponse() { -// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TEXT, false); -// } -// -// @Test -// @DisplayName("测试OpenAI阻塞式工具调用响应转换") -// void testBlockingToolCallResponse() { -// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true); -// } -// -// @Test -// @DisplayName("测试OpenAI阻塞式结构化响应转换") -// void testBlockingStructuredResponse() { -// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_STRUCTURED, false); -// } -// } -// -// @Nested -// @DisplayName("DashScope协议转换器测试") -// class DashScopeTests { -// -// @BeforeEach -// void setUp() { -// context = new InteractContext(); -// protocolTransformer = createProtocolTransformer(ProviderEnum.DASHSCOPE); -// } -// -// @Test -// @DisplayName("测试DashScope流式文本响应转换") -// void testStreamingTextResponse() { -// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT, false); -// } -// -// @Test -// @DisplayName("测试DashScope流式工具调用响应转换") -// void testStreamingToolCallResponse() { -// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, true); -// } -// -// @Test -// @DisplayName("测试DashScope流式结构化响应转换") -// void testStreamingStructuredResponse() { -// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED, false); -// } -// -// @Test -// @DisplayName("测试DashScope阻塞式文本响应转换") -// void testBlockingTextResponse() { -// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TEXT, false); -// } -// -// @Test -// @DisplayName("测试DashScope阻塞式工具调用响应转换") -// void testBlockingToolCallResponse() { -// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true); -// } -// -// @Test -// @DisplayName("测试DashScope阻塞式结构化响应转换") -// void testBlockingStructuredResponse() { -// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_STRUCTURED, false); -// } -// } -// -// @Nested -// @DisplayName("Ollama协议转换器测试") -// class OllamaTests { -// -// @BeforeEach -// void setUp() { -// context = new InteractContext(); -// protocolTransformer = createProtocolTransformer(ProviderEnum.OLLAMA); -// } -// -// @Test -// @DisplayName("测试Ollama流式文本响应转换") -// void testStreamingTextResponse() { -// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT, false); -// } -// -// @Test -// @DisplayName("测试Ollama流式工具调用响应转换") -// void testStreamingToolCallResponse() { -// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL, true); -// } -// -// @Test -// @DisplayName("测试Ollama流式结构化响应转换") -// void testStreamingStructuredResponse() { -// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED, false); -// } -// -// @Test -// @DisplayName("测试Ollama阻塞式文本响应转换") -// void testBlockingTextResponse() { -// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TEXT, false); -// } -// -// @Test -// @DisplayName("测试Ollama阻塞式工具调用响应转换") -// void testBlockingToolCallResponse() { -// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true); -// } -// -// @Test -// @DisplayName("测试Ollama阻塞式结构化响应转换") -// void testBlockingStructuredResponse() { -// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_STRUCTURED, false); -// } -// } -// -// @Nested -// @DisplayName("参数化测试 - 所有协议转换器") -// class ParameterizedTests { -// -// @ParameterizedTest -// @EnumSource(value = ProviderEnum.class, names = { "OPENAI", "DASHSCOPE", "OLLAMA" }) -// @DisplayName("参数化测试协议转换器提供者名称") -// void testProtocolTransformerProviderEnumName(ProviderEnum provider) { -// ProtocolTransformer transformer = createProtocolTransformer(provider); -// assertEquals(provider.getProviderName(), transformer.getProviderName(), -// provider.getProviderName() + "协议转换器的提供者名称应为'" + provider.getProviderName() + "'"); -// } -// } -// -// /** -// * 测试流式响应处理的通用方法 -// * -// * @param provider 提供者 -// * @param requestType 响应类型 -// * @param hasToolCalls 是否包含工具调用 -// */ -// private void testStreamingResponse(ProviderEnum provider, TestDataReader.RequestType requestType, boolean hasToolCalls) { -// // 重新构建pipeline -// context = new InteractContext(); -// protocolTransformer = createProtocolTransformer(provider); -// pipeline = ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, chunkCallbackTransformer); -// -// // 读取测试数据 -// List streamChunks = TestDataReader.getStreamingChunks(provider, requestType); -// assertFalse(streamChunks.isEmpty(), provider.getProviderName() + "流式" + requestType.getRequestName() + "响应数据不应为空"); -// -// // 处理流式数据 -// for (String streamChunk : streamChunks) { -// pipeline.processStreaming(streamChunk); -// // 注意:这里不强制要求最后一次返回true,因为不同模型的结束标志可能不同 -// } -// -// // 获取最终响应 -// ChatResponse chatResponse = pipeline.buildFinalStreamingResponse(); -// -// // 验证响应 -// assertNotNull(chatResponse, "聊天响应不应为null"); -// assertNotNull(chatResponse.getOutput(), "响应输出不应为null"); -// assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型"); -// -// AssistantMessage message = (AssistantMessage) chatResponse.getOutput(); -// -// // 验证工具调用 -// if (hasToolCalls) { -// assertTrue(chatResponse.hasToolCalls(), requestType.getRequestName() + "响应应包含工具调用"); -// assertNotNull(message.getToolCalls(), "工具调用列表不应为null"); -// assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空"); -// -// // 验证工具调用内容 -// for (ToolCall toolCall : message.getToolCalls()) { -// if (!provider.equals(ProviderEnum.OLLAMA)) { -// assertNotNull(toolCall.getId(), "工具调用ID不应为null"); -// } -// assertNotNull(toolCall.getName(), "工具调用名称不应为null"); -// assertNotNull(toolCall.getType(), "工具调用类型不应为null"); -// } -// } else { -// // 对于非工具调用响应,验证内容存在 -// assertNotNull(message.getContent(), "消息内容不应为null"); -// // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空 -// } -// -// // 打印测试结果 -// System.out.println("=== " + provider.getProviderName() + " 流式" + requestType.getRequestName() + "响应测试结果 ==="); -// if (message.getContent() != null && !message.getContent().trim().isEmpty()) { -// System.out.println( -// "内容: " + (message.getContent().length() > 100 ? message.getContent().substring(0, 100) + "..." -// : message.getContent())); -// } -// if (hasToolCalls && message.getToolCalls() != null) { -// System.out.println("工具调用数量: " + message.getToolCalls().size()); -// } -// System.out.println("Token使用情况: " + chatResponse.getTokenUsage()); -// System.out.println("完成原因: " + chatResponse.getFinishReason()); -// } -// -// /** -// * 测试阻塞式响应处理的通用方法 -// * -// * @param provider 提供者 -// * @param responseType 响应类型 -// * @param hasToolCalls 是否包含工具调用 -// */ -// private void testBlockingResponse(ProviderEnum provider, TestDataReader.RequestType responseType, boolean hasToolCalls) { -// // 重新构建pipeline -// context = new InteractContext(); -// protocolTransformer = createProtocolTransformer(provider); -// pipeline = ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer); -// -// // 读取测试数据 -// String blockingResponse = TestDataReader.getBlockingResponse(provider, responseType); -// assertNotNull(blockingResponse, provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为null"); -// assertFalse(blockingResponse.trim().isEmpty(), provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为空"); -// -// // 处理阻塞式响应 -// ChatResponse chatResponse = pipeline.processBlocking(blockingResponse); -// -// // 验证响应 -// assertNotNull(chatResponse, "聊天响应不应为null"); -// assertNotNull(chatResponse.getOutput(), "响应输出不应为null"); -// assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型"); -// -// AssistantMessage message = (AssistantMessage) chatResponse.getOutput(); -// -// // 验证工具调用 -// if (hasToolCalls) { -// assertTrue(chatResponse.hasToolCalls(), responseType.getRequestName() + "响应应包含工具调用"); -// assertNotNull(message.getToolCalls(), "工具调用列表不应为null"); -// assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空"); -// -// // 验证每个工具调用 -// for (ToolCall toolCall : message.getToolCalls()) { -// if (!provider.equals(ProviderEnum.OLLAMA)) { -// assertNotNull(toolCall.getId(), "工具调用ID不应为null"); -// } -// assertNotNull(toolCall.getName(), "工具调用名称不应为null"); -// assertNotNull(toolCall.getType(), "工具调用类型不应为null"); -// assertEquals("function", toolCall.getType(), "工具调用类型应为function"); -// } -// } else { -// // 对于非工具调用响应,验证内容存在 -// assertNotNull(message.getContent(), "消息内容不应为null"); -// // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空 -// } -// -// // 验证token使用情况(如果存在) -// if (chatResponse.getTokenUsage() != null) { -// assertTrue(chatResponse.getTokenUsage().getTotalTokenCount() > 0, "总token数应大于0"); -// } -// -// // 打印测试结果 -// System.out.println("=== " + provider.getProviderName() + " 阻塞式" + responseType.getRequestName() + "响应测试结果 ==="); -// if (message.getContent() != null && !message.getContent().trim().isEmpty()) { -// System.out.println("内容长度: " + message.getContent().length()); -// if (message.getContent().length() > 200) { -// System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); -// } else { -// System.out.println("内容: " + message.getContent()); -// } -// } -// if (hasToolCalls && message.getToolCalls() != null) { -// System.out.println("工具调用数量: " + message.getToolCalls().size()); -// for (int i = 0; i < message.getToolCalls().size(); i++) { -// ToolCall toolCall = message.getToolCalls().get(i); -// System.out.println("工具调用 " + (i + 1) + ":"); -// System.out.println(" ID: " + toolCall.getId()); -// System.out.println(" 名称: " + toolCall.getName()); -// System.out.println(" 类型: " + toolCall.getType()); -// System.out.println(" 参数: " + toolCall.getArguments()); -// } -// } -// System.out.println("Token使用情况: " + chatResponse.getTokenUsage()); -// System.out.println("完成原因: " + chatResponse.getFinishReason()); -// } + + private ProtocolTransformer protocolTransformer; + private InteractContext context; + + /** + * 创建协议转换器的工厂方法 + * + * @param provider 提供者枚举 + * @return 对应的协议转换器实例 + */ + private ProtocolTransformer createProtocolTransformer(ProviderEnum provider) { + switch (provider) { + case OPENAI: + return new OpenAIProtocolTransformer(); + case DASHSCOPE: + return new DashScopeProtocolTransformer(); + case OLLAMA: + return new OllamaProtocolTransformer(); + default: + throw new IllegalArgumentException("不支持的提供者: " + provider); + } + } + + @Nested + @DisplayName("OpenAI协议转换器测试") + class OpenAITests { + + @BeforeEach + void setUp() { + context = new InteractContext(); + protocolTransformer = createProtocolTransformer(ProviderEnum.OPENAI); + } + + @Test + @DisplayName("测试OpenAI流式文本响应转换") + void testStreamingTextResponse() { + testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT, false); + } + + @Test + @DisplayName("测试OpenAI流式工具调用响应转换") + void testStreamingToolCallResponse() { + testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL, true); + } + + @Test + @DisplayName("测试OpenAI流式结构化响应转换") + void testStreamingStructuredResponse() { + testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED, false); + } + + @Test + @DisplayName("测试OpenAI阻塞式文本响应转换") + void testBlockingTextResponse() { + testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TEXT, false); + } + + @Test + @DisplayName("测试OpenAI阻塞式工具调用响应转换") + void testBlockingToolCallResponse() { + testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true); + } + + @Test + @DisplayName("测试OpenAI阻塞式结构化响应转换") + void testBlockingStructuredResponse() { + testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_STRUCTURED, false); + } + } + + @Nested + @DisplayName("DashScope协议转换器测试") + class DashScopeTests { + + @BeforeEach + void setUp() { + context = new InteractContext(); + protocolTransformer = createProtocolTransformer(ProviderEnum.DASHSCOPE); + } + + @Test + @DisplayName("测试DashScope流式文本响应转换") + void testStreamingTextResponse() { + testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT, false); + } + + @Test + @DisplayName("测试DashScope流式工具调用响应转换") + void testStreamingToolCallResponse() { + testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, true); + } + + @Test + @DisplayName("测试DashScope流式结构化响应转换") + void testStreamingStructuredResponse() { + testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED, false); + } + + @Test + @DisplayName("测试DashScope阻塞式文本响应转换") + void testBlockingTextResponse() { + testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TEXT, false); + } + + @Test + @DisplayName("测试DashScope阻塞式工具调用响应转换") + void testBlockingToolCallResponse() { + testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true); + } + + @Test + @DisplayName("测试DashScope阻塞式结构化响应转换") + void testBlockingStructuredResponse() { + testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_STRUCTURED, false); + } + } + + @Nested + @DisplayName("Ollama协议转换器测试") + class OllamaTests { + + @BeforeEach + void setUp() { + context = new InteractContext(); + protocolTransformer = createProtocolTransformer(ProviderEnum.OLLAMA); + } + + @Test + @DisplayName("测试Ollama流式文本响应转换") + void testStreamingTextResponse() { + testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT, false); + } + + @Test + @DisplayName("测试Ollama流式工具调用响应转换") + void testStreamingToolCallResponse() { + testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL, true); + } + + @Test + @DisplayName("测试Ollama流式结构化响应转换") + void testStreamingStructuredResponse() { + testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED, false); + } + + @Test + @DisplayName("测试Ollama阻塞式文本响应转换") + void testBlockingTextResponse() { + testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TEXT, false); + } + + @Test + @DisplayName("测试Ollama阻塞式工具调用响应转换") + void testBlockingToolCallResponse() { + testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true); + } + + @Test + @DisplayName("测试Ollama阻塞式结构化响应转换") + void testBlockingStructuredResponse() { + testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_STRUCTURED, false); + } + } + + @Nested + @DisplayName("参数化测试 - 所有协议转换器") + class ParameterizedTests { + + @ParameterizedTest + @EnumSource(value = ProviderEnum.class, names = { "OPENAI", "DASHSCOPE", "OLLAMA" }) + @DisplayName("参数化测试协议转换器提供者名称") + void testProtocolTransformerProviderEnumName(ProviderEnum provider) { + ProtocolTransformer transformer = createProtocolTransformer(provider); + assertEquals(provider.getProviderName(), transformer.getProviderName(), + provider.getProviderName() + "协议转换器的提供者名称应为'" + provider.getProviderName() + "'"); + } + } + + /** + * 测试流式响应处理的通用方法 + * + * @param provider 提供者 + * @param requestType 响应类型 + * @param hasToolCalls 是否包含工具调用 + */ + private void testStreamingResponse(ProviderEnum provider, TestDataReader.RequestType requestType, boolean hasToolCalls) { + // 重新构建pipeline + context = new InteractContext(); + protocolTransformer = createProtocolTransformer(provider); + + // 读取测试数据 + List streamChunks = TestDataReader.getStreamingChunks(provider, requestType); + assertFalse(streamChunks.isEmpty(), provider.getProviderName() + "流式" + requestType.getRequestName() + "响应数据不应为空"); + + // 处理流式数据 + for (String streamChunk : streamChunks) { + StreamingProtocolChunk chunk = protocolTransformer.transformStreamingChunk(streamChunk, context); + updateContextFromChunk(context, chunk); + } + + // 获取最终响应 + ChatResponse chatResponse = protocolTransformer.transformStreamingResponse(context); + + // 验证响应 + assertNotNull(chatResponse, "聊天响应不应为null"); + assertNotNull(chatResponse.getOutput(), "响应输出不应为null"); + assertNotNull(chatResponse.getOutput(), "输出应为AssistantMessage类型"); + + AssistantMessage message = chatResponse.getOutput(); + + // 验证工具调用 + if (hasToolCalls) { + assertTrue(chatResponse.hasToolCalls(), requestType.getRequestName() + "响应应包含工具调用"); + assertNotNull(message.getToolCalls(), "工具调用列表不应为null"); + assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空"); + + // 验证工具调用内容 + for (ToolCall toolCall : message.getToolCalls()) { + if (!provider.equals(ProviderEnum.OLLAMA)) { + assertNotNull(toolCall.getId(), "工具调用ID不应为null"); + } + assertNotNull(toolCall.getName(), "工具调用名称不应为null"); + assertNotNull(toolCall.getType(), "工具调用类型不应为null"); + } + } else { + // 对于非工具调用响应,验证内容存在 + assertNotNull(message.getContent(), "消息内容不应为null"); + // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空 + } + + // 打印测试结果 + System.out.println("=== " + provider.getProviderName() + " 流式" + requestType.getRequestName() + "响应测试结果 ==="); + if (message.getContent() != null && !message.getContent().trim().isEmpty()) { + System.out.println( + "内容: " + (message.getContent().length() > 100 ? message.getContent().substring(0, 100) + "..." + : message.getContent())); + } + if (hasToolCalls && message.getToolCalls() != null) { + System.out.println("工具调用数量: " + message.getToolCalls().size()); + } + System.out.println("Token使用情况: " + chatResponse.getTokenUsage()); + System.out.println("完成原因: " + chatResponse.getFinishReason()); + } + + /** + * 测试阻塞式响应处理的通用方法 + * + * @param provider 提供者 + * @param responseType 响应类型 + * @param hasToolCalls 是否包含工具调用 + */ + private void testBlockingResponse(ProviderEnum provider, TestDataReader.RequestType responseType, boolean hasToolCalls) { + // 重新构建pipeline + context = new InteractContext(); + protocolTransformer = createProtocolTransformer(provider); + + // 读取测试数据 + String blockingResponse = TestDataReader.getBlockingResponse(provider, responseType); + assertNotNull(blockingResponse, provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为null"); + assertFalse(blockingResponse.trim().isEmpty(), provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为空"); + + // 处理阻塞式响应 + ChatResponse chatResponse = protocolTransformer.transformBlockingResponse(blockingResponse, context); + + // 验证响应 + assertNotNull(chatResponse, "聊天响应不应为null"); + assertNotNull(chatResponse.getOutput(), "响应输出不应为null"); + assertNotNull(chatResponse.getOutput(), "输出应为AssistantMessage类型"); + + AssistantMessage message = chatResponse.getOutput(); + + // 验证工具调用 + if (hasToolCalls) { + assertTrue(chatResponse.hasToolCalls(), responseType.getRequestName() + "响应应包含工具调用"); + assertNotNull(message.getToolCalls(), "工具调用列表不应为null"); + assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空"); + + // 验证每个工具调用 + for (ToolCall toolCall : message.getToolCalls()) { + if (!provider.equals(ProviderEnum.OLLAMA)) { + assertNotNull(toolCall.getId(), "工具调用ID不应为null"); + } + assertNotNull(toolCall.getName(), "工具调用名称不应为null"); + assertNotNull(toolCall.getType(), "工具调用类型不应为null"); + assertEquals("function", toolCall.getType(), "工具调用类型应为function"); + } + } else { + // 对于非工具调用响应,验证内容存在 + assertNotNull(message.getContent(), "消息内容不应为null"); + // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空 + } + + // 验证token使用情况(如果存在) + if (chatResponse.getTokenUsage() != null) { + assertTrue(chatResponse.getTokenUsage().getTotalTokenCount() > 0, "总token数应大于0"); + } + + // 打印测试结果 + System.out.println("=== " + provider.getProviderName() + " 阻塞式" + responseType.getRequestName() + "响应测试结果 ==="); + if (message.getContent() != null && !message.getContent().trim().isEmpty()) { + System.out.println("内容长度: " + message.getContent().length()); + if (message.getContent().length() > 200) { + System.out.println("内容预览: " + message.getContent().substring(0, 200) + "..."); + } else { + System.out.println("内容: " + message.getContent()); + } + } + if (hasToolCalls && message.getToolCalls() != null) { + System.out.println("工具调用数量: " + message.getToolCalls().size()); + for (int i = 0; i < message.getToolCalls().size(); i++) { + ToolCall toolCall = message.getToolCalls().get(i); + System.out.println("工具调用 " + (i + 1) + ":"); + System.out.println(" ID: " + toolCall.getId()); + System.out.println(" 名称: " + toolCall.getName()); + System.out.println(" 类型: " + toolCall.getType()); + System.out.println(" 参数: " + toolCall.getArguments()); + } + } + System.out.println("Token使用情况: " + chatResponse.getTokenUsage()); + System.out.println("完成原因: " + chatResponse.getFinishReason()); + } + + /** + * 根据分块信息更新交互上下文 + * + * @param context 交互上下文 + * @param chunk 协议分块 + */ + private void updateContextFromChunk(InteractContext context, StreamingProtocolChunk chunk) { + if (Objects.isNull(chunk)) { + return; + } + + switch (chunk.getType()) { + case TEXT: + context.addText((String) chunk.getData()); + break; + case THINKING: + context.addThinking((String) chunk.getData()); + break; + default: + // 其他类型暂不处理 + break; + } + } } diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/mock/mockbean/MockInteractClient.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/mock/mockbean/MockInteractClient.java index 04a5d3d1b..bb0931089 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/mock/mockbean/MockInteractClient.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/mock/mockbean/MockInteractClient.java @@ -6,8 +6,8 @@ import com.yomahub.liteflow.ai.engine.interact.transport.TransportType; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest; import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse; -import io.reactivex.rxjava3.core.Flowable; import org.mockito.Mockito; +import org.reactivestreams.Publisher; import java.util.concurrent.CompletableFuture; @@ -53,7 +53,7 @@ public class MockInteractClient extends LlmInteractClient { * 重写 stream 方法 */ @Override - public Flowable stream(ChatConfig config, ChatRequest request) { + public Publisher stream(ChatConfig config, ChatRequest request) { // 创建一个 spied request,它在内部被配置为使用 MockTransport ChatRequest spiedRequest = createSpiedRequest(request); diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/dashscope/DashScopeChatTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/dashscope/DashScopeChatTest.java index 080810888..63aa1db05 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/dashscope/DashScopeChatTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/dashscope/DashScopeChatTest.java @@ -69,13 +69,13 @@ public class DashScopeChatTest extends MockAITest { new UserMessage("请给我讲一个关于宇宙探索的短故事") ); - Flowable stream = chatModel.stream( + Flowable stream = Flowable.fromPublisher(chatModel.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) .messages(messages) .build() - ); + )); ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer()) .blockingLast() diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/ollama/OllamaChatTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/ollama/OllamaChatTest.java index 441f26f92..3ffbb03b4 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/ollama/OllamaChatTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/ollama/OllamaChatTest.java @@ -68,13 +68,13 @@ public class OllamaChatTest extends MockAITest { new UserMessage("why is the sky blue?") ); - Flowable stream = chatModel.stream( + Flowable stream = Flowable.fromPublisher(chatModel.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.DN_JSON) .messages(messages) .build() - ); + )); ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer()) .blockingLast() diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/openai/OpenAIChatTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/openai/OpenAIChatTest.java index 93ce7eebb..977ef461d 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/openai/OpenAIChatTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/chat/openai/OpenAIChatTest.java @@ -22,7 +22,6 @@ import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutionException; /** * OpenAI chat 测试 @@ -58,20 +57,20 @@ public class OpenAIChatTest extends MockAITest { } @Test - public void testStreaming() throws ExecutionException, InterruptedException { + public void testStreaming() { setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT); List messages = Arrays.asList( new SystemMessage("You are a helpful assistant."), new UserMessage("请给我讲一个关于未来城市的短故事")); - Flowable stream = chatModel.stream( + Flowable stream = Flowable.fromPublisher(chatModel.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) .messages(messages) .build() - ); + )); ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer()) .blockingLast() diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/dashscope/DashScopeStructureTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/dashscope/DashScopeStructureTest.java index 13d361ac1..f44211338 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/dashscope/DashScopeStructureTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/dashscope/DashScopeStructureTest.java @@ -81,7 +81,7 @@ public class DashScopeStructureTest extends MockAITest { new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1") ); - Flowable stream = chatModel.stream( + Flowable stream = Flowable.fromPublisher(chatModel.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -90,7 +90,8 @@ public class DashScopeStructureTest extends MockAITest { .responseType(ResponseType.JSON) .targetType(MathReasoning.class) // 结构化输出相关配置 - .build()); + .build()) + ); ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer()) .blockingLast() diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/ollama/OllamaStructureTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/ollama/OllamaStructureTest.java index 662eba16b..b3b53cde8 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/ollama/OllamaStructureTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/ollama/OllamaStructureTest.java @@ -78,7 +78,7 @@ public class OllamaStructureTest extends MockAITest { new SystemMessage("你是一位数学辅导老师"), new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1")); - Flowable stream = chatModel.stream( + Flowable stream = Flowable.fromPublisher(chatModel.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.DN_JSON) @@ -87,7 +87,8 @@ public class OllamaStructureTest extends MockAITest { .responseType(ResponseType.JSON) .targetType(MathReasoning.class) // 结构化输出相关配置 - .build()); + .build()) + ); ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer()) .blockingLast() diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/openai/OpenAIStructureTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/openai/OpenAIStructureTest.java index 06989be63..12461c307 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/openai/OpenAIStructureTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/structure/openai/OpenAIStructureTest.java @@ -24,7 +24,6 @@ import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutionException; /** * OpenAI 结构化输出 测试 @@ -71,14 +70,14 @@ public class OpenAIStructureTest extends MockAITest { } @Test - public void testStructureStreaming() throws ExecutionException, InterruptedException { + public void testStructureStreaming() { setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED); List messages = Arrays.asList( new SystemMessage("你是一位数学辅导老师"), new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1")); - Flowable stream = chatModel.stream( + Flowable stream = Flowable.fromPublisher(chatModel.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -87,7 +86,8 @@ public class OpenAIStructureTest extends MockAITest { .responseType(ResponseType.JSON) .targetType(MathReasoning.class) // 结构化输出相关配置 - .build()); + .build()) + ); ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer()) .blockingLast() diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/dashscope/DashScopeToolTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/dashscope/DashScopeToolTest.java index 55a87bd20..e4c72f092 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/dashscope/DashScopeToolTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/dashscope/DashScopeToolTest.java @@ -88,7 +88,7 @@ public class DashScopeToolTest extends MockAITest { Collections.singletonList(toolRegistry.getTool("assemble_tool")) ); - Flowable stream = chatModelWithAutoToolCall.stream( + Flowable stream = Flowable.fromPublisher(chatModelWithAutoToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -96,7 +96,7 @@ public class DashScopeToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); ChatResponse response = stream.blockingLast() .getFinalResponse(); @@ -191,7 +191,7 @@ public class DashScopeToolTest extends MockAITest { Collections.singletonList(toolRegistry.getTool("assemble_tool")) ); - Flowable stream = chatModelWithManualToolCall.stream( + Flowable stream = Flowable.fromPublisher(chatModelWithManualToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -199,7 +199,7 @@ public class DashScopeToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); ChatResponse response = stream.blockingLast() .getFinalResponse(); @@ -224,7 +224,7 @@ public class DashScopeToolTest extends MockAITest { MockConfigHolder.clear(); setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL_2); - stream = chatModelWithManualToolCall.stream( + stream = Flowable.fromPublisher(chatModelWithManualToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -232,7 +232,7 @@ public class DashScopeToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); response = stream.blockingLast() .getFinalResponse(); diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/ollama/OllamaToolTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/ollama/OllamaToolTest.java index c52bb8a95..dd074bf4e 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/ollama/OllamaToolTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/ollama/OllamaToolTest.java @@ -84,7 +84,7 @@ public class OllamaToolTest extends MockAITest { ToolRegistry assembleTool = new StaticToolRegistry( Collections.singletonList(toolRegistry.getTool("assemble_tool"))); - Flowable stream = chatModelWithAutoToolCall.stream( + Flowable stream = Flowable.fromPublisher(chatModelWithAutoToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.DN_JSON) @@ -92,7 +92,7 @@ public class OllamaToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); ChatResponse response = stream.blockingLast() .getFinalResponse(); @@ -168,7 +168,7 @@ public class OllamaToolTest extends MockAITest { ToolRegistry assembleTool = new StaticToolRegistry( Collections.singletonList(toolRegistry.getTool("assemble_tool"))); - Flowable stream = chatModelWithManualToolCall.stream( + Flowable stream = Flowable.fromPublisher(chatModelWithManualToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.DN_JSON) @@ -176,7 +176,7 @@ public class OllamaToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); ChatResponse response = stream.blockingLast() .getFinalResponse(); @@ -202,7 +202,7 @@ public class OllamaToolTest extends MockAITest { MockConfigHolder.clear(); setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL_2); - stream = chatModelWithManualToolCall.stream( + stream = Flowable.fromPublisher(chatModelWithManualToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.DN_JSON) @@ -210,7 +210,7 @@ public class OllamaToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); response = stream.blockingLast() .getFinalResponse(); diff --git a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/openai/OpenAIToolTest.java b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/openai/OpenAIToolTest.java index 834506357..6b50cd805 100644 --- a/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/openai/OpenAIToolTest.java +++ b/liteflow-testcase-el/liteflow-testcase-el-ai/src/test/java/com/yomahub/liteflow/test/ai/model/tool/openai/OpenAIToolTest.java @@ -84,7 +84,7 @@ public class OpenAIToolTest extends MockAITest { ToolRegistry assembleTool = new StaticToolRegistry( Collections.singletonList(toolRegistry.getTool("assemble_tool"))); - Flowable stream = chatModelWithAutoToolCall.stream( + Flowable stream = Flowable.fromPublisher(chatModelWithAutoToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -92,7 +92,7 @@ public class OpenAIToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); ChatResponse response = stream.blockingLast() .getFinalResponse(); @@ -169,7 +169,7 @@ public class OpenAIToolTest extends MockAITest { ToolRegistry assembleTool = new StaticToolRegistry( Collections.singletonList(toolRegistry.getTool("assemble_tool"))); - Flowable stream = chatModelWithManualToolCall.stream( + Flowable stream = Flowable.fromPublisher(chatModelWithManualToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -177,7 +177,7 @@ public class OpenAIToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); ChatResponse response = stream.blockingLast() .getFinalResponse(); @@ -202,7 +202,7 @@ public class OpenAIToolTest extends MockAITest { MockConfigHolder.clear(); setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL_2); - stream = chatModelWithManualToolCall.stream( + stream = Flowable.fromPublisher(chatModelWithManualToolCall.stream( chatRequestBuilder .streaming(true) .transportType(TransportType.SSE) @@ -210,7 +210,7 @@ public class OpenAIToolTest extends MockAITest { // 工具调用配置 .toolRegistry(assembleTool) .build() - ); + )); response = stream.blockingLast() .getFinalResponse();