Feat & Refactor: 引入 rxjava,将流式回调逻辑重构为响应式逻辑

This commit is contained in:
LuanY77
2025-11-14 18:34:06 +08:00
parent 036489e80b
commit a75ff4c3bd
55 changed files with 1619 additions and 2733 deletions

View File

@@ -42,7 +42,7 @@ public class LiteFlowAIAutoConfiguration {
@Bean
public StreamHandler streamHandler() {
return StreamHandler.builder().build();
return StreamHandler.passThrough();
}
@Bean

View File

@@ -1,290 +1,66 @@
package com.yomahub.liteflow.ai.context;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import io.reactivex.rxjava3.core.Flowable;
/**
* 流式输出处理器
* 响应式流处理器
* <p>
* 用户通过实现此接口,直接对 ChunkEvent 流进行响应式编程,
* 对流进行任意的变换、过滤、聚合等操作,最后返回处理后的 Flowable 流。
*
* @author 苍镜月
* @since 2.16.0
*/
@FunctionalInterface
public interface StreamHandler {
static Builder builder() {
return new Builder();
/**
* 对 ChunkEvent 流进行响应式处理
* <p>
* 用户可以在此方法中对事件流进行任意的响应式操作,如:
* - 过滤特定类型的事件
* - 转换事件数据
* - 聚合多个事件
* - 添加副作用(如日志、监控)
* - 处理错误和超时
* <p>
* 示例:
* <pre>
* StreamHandler handler = eventStream -> eventStream
* .filter(ChunkEvent::isChunk)
* .doOnNext(event -> System.out.println("Received chunk: " + event))
* .onErrorRetry(3);
* </pre>
*
* @param eventStream 原始的 ChunkEvent 流
* @return 处理后的 ChunkEvent 流
*/
Flowable<ChunkEvent> handle(Flowable<ChunkEvent> eventStream);
/**
* 创建一个 pass-through 处理器,不做任何转换直接返回原始流
*
* @return pass-through StreamHandler
*/
static StreamHandler passThrough() {
return eventStream -> eventStream;
}
/**
* 请求开始时的回调方法
* 创建一个组合多个处理器的处理器
*
* @param context 聊天上下文,包含处理过程中的状态和信息
* @param handlers 处理器数组
* @return 组合后的 StreamHandler
*/
void onStart(InteractContext context);
/**
* 请求关闭时的回调方法
*
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
void onClose(InteractContext context);
/**
* 处理过程中发生错误的回调方法。
*
* @param context 聊天上下文,包含处理过程中的状态和信息
* @param t
*/
void onError(InteractContext context, Throwable t);
/**
* 处理文本消息的回调方法。需要启用流式调用
*
* @param content 文本内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
String onText(String content, InteractContext context);
/**
* 处理思考消息的回调方法。需要启用流式调用
*
* @param content 思考内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
String onThinking(String content, InteractContext context);
/**
* 处理工具调用消息的回调方法。需要启用流式调用
*
* @param toolCalls 工具调用内容,可能是工具调用的结果或相关信息
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context);
/**
* 处理 Token 统计信息的回调方法。需要启用流式调用
*
* @param content Token 统计信息内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
Object onUsage(Object content, InteractContext context);
/**
* 处理基础信息/搜索结果的回调方法。需要启用流式调用
*
* @param content 基础信息或搜索结果内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
// TODO args
Object onGrounding(Object content, InteractContext context);
/**
* 消息处理完成的回调方法。(流式传输/阻塞式传输均会调用)
*
* @param response 处理后的聊天响应结果
* @param context 聊天上下文,包含处理过程中的状态和信息
* @return 处理后的结果
*/
ChatResponse onCompletion(ChatResponse response, InteractContext context);
/**
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。(流式传输/阻塞式传输均会调用)
*
* @param response 最终的聊天响应结果
* @param context 聊天上下文,包含处理过程中的状态和信息
* @return 处理后的结果
*/
ChatResponse onFinal(ChatResponse response, InteractContext context);
class Builder {
Consumer<InteractContext> onStart = context -> {
static StreamHandler composite(StreamHandler... handlers) {
return eventStream -> {
Flowable<ChunkEvent> result = eventStream;
for (StreamHandler handler : handlers) {
result = handler.handle(result);
}
return result;
};
Consumer<InteractContext> onClose = context -> {
};
BiConsumer<InteractContext, Throwable> onError = (context, t) -> {
throw new LiteFlowAIEngineException(t.getMessage(), t);
};
BiFunction<String, InteractContext, String> onText = (content, context) -> content;
BiFunction<String, InteractContext, String> onThinking = (content, context) -> content;
BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling = (toolCalls, context) -> toolCalls;
BiFunction<Object, InteractContext, Object> onUsage = (content, context) -> content;
BiFunction<Object, InteractContext, Object> onGrounding = (content, context) -> content;
BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion = (response, context) -> response;
BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal = (response, context) -> response;
/**
* 请求开始时的回调方法
*
* @param onStart 请求开始时的回调函数
* @see TransportListener#onStart(InteractContext)
*/
public Builder onStart(Consumer<InteractContext> onStart) {
this.onStart = onStart;
return this;
}
/**
* 请求结束时的回调方法
*
* @param onClose 请求结束时的回调函数
* @see TransportListener#onClose(InteractContext)
*/
public Builder onClose(Consumer<InteractContext> onClose) {
this.onClose = onClose;
return this;
}
/**
* 请求发生错误时的回调方法
*
* @param onError 请求发生错误时的回调函数
* @see TransportListener#onError(InteractContext, Throwable)
*/
public Builder onError(BiConsumer<InteractContext, Throwable> onError) {
this.onError = onError;
return this;
}
/**
* 文本消息的回调方法
*
* @param onText 文本消息的回调函数
* @see ChunkCallbackTransformer#onText(String, InteractContext)
*/
public Builder onText(BiFunction<String, InteractContext, String> onText) {
this.onText = onText;
return this;
}
/**
* 思考消息的回调方法
*
* @param onThinking 思考消息的回调函数
* @see ChunkCallbackTransformer#onThinking(String, InteractContext)
*/
public Builder onThinking(BiFunction<String, InteractContext, String> onThinking) {
this.onThinking = onThinking;
return this;
}
/**
* 工具调用消息的回调方法
*
* @param onToolsCalling 工具调用消息的回调函数
* @see ChunkCallbackTransformer#onToolsCalling(List, InteractContext)
*/
public Builder onToolsCalling(BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling) {
this.onToolsCalling = onToolsCalling;
return this;
}
/**
* Token 统计信息的回调方法
*
* @param onUsage Token 统计信息的回调函数
* @see ChunkCallbackTransformer#onUsage(Object, InteractContext)
*/
public Builder onUsage(BiFunction<Object, InteractContext, Object> onUsage) {
this.onUsage = onUsage;
return this;
}
/**
* 基础信息/搜索结果的回调方法
*
* @param onGrounding 基础信息/搜索结果的回调函数
* @see ChunkCallbackTransformer#onGrounding(Object, InteractContext)
*/
public Builder onGrounding(BiFunction<Object, InteractContext, Object> onGrounding) {
this.onGrounding = onGrounding;
return this;
}
/**
* 请求完成时的回调方法(流式传输/阻塞式传输均会调用)
*
* @param onCompletion 请求完成时的回调函数
* @see ResultHandler#onCompletion(ChatResponse, InteractContext)
*/
public Builder onCompletion(BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion) {
this.onCompletion = onCompletion;
return this;
}
/**
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。(流式传输/阻塞式传输均会调用)
*
* @param onFinal 请求最终结果的回调函数
* @see ResultHandler#onFinal(ChatResponse, InteractContext)
*/
public Builder onFinal(BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal) {
this.onFinal = onFinal;
return this;
}
public StreamHandler build() {
return new StreamHandler() {
@Override
public void onStart(InteractContext context) {
onStart.accept(context);
}
@Override
public void onClose(InteractContext context) {
onClose.accept(context);
}
@Override
public void onError(InteractContext context, Throwable t) {
onError.accept(context, t);
}
@Override
public String onText(String content, InteractContext context) {
return onText.apply(content, context);
}
@Override
public String onThinking(String content, InteractContext context) {
return onThinking.apply(content, context);
}
@Override
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
return onToolsCalling.apply(toolCalls, context);
}
@Override
public Object onUsage(Object content, InteractContext context) {
return onUsage.apply(content, context);
}
@Override
public Object onGrounding(Object content, InteractContext context) {
return onGrounding.apply(content, context);
}
@Override
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
return onCompletion.apply(response, context);
}
@Override
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
return onFinal.apply(response, context);
}
};
}
}
}

View File

@@ -1,7 +1,6 @@
package com.yomahub.liteflow.ai.parse.assemble;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
import com.yomahub.liteflow.ai.domain.dto.ParsedChatAnnotationConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
@@ -32,23 +31,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ParsedChatAnn
protected ChatRequest doAssemble(ParsedChatAnnotationConfig annotationConfig, ModelConfigAggregator config, ChatContext context) {
ChatRequest.Builder<?> builder = ModelFactory.getChatRequestBuilder(config.getProvider());
// 1. 连接 StreamHandler 回调
StreamHandler streamHandler = context.getStreamHandler();
if (Objects.nonNull(streamHandler)) {
LOG.info("Connecting StreamHandler to ChatRequest");
builder.onStart(streamHandler::onStart)
.onClose(streamHandler::onClose)
.onError(streamHandler::onError)
.onText(streamHandler::onText)
.onThinking(streamHandler::onThinking)
.onToolsCalling(streamHandler::onToolsCalling)
.onUsage(streamHandler::onUsage)
.onGrounding(streamHandler::onGrounding)
.onCompletion(streamHandler::onCompletion)
.onFinal(streamHandler::onFinal);
}
// 2. ChatOptions
// 1. ChatOptions
ChatOptions.Builder<?> optionsBuilder = ChatOptions.builder();
setIfPresent(optionsBuilder::temperature, config.getTemperature());
setIfPresent(optionsBuilder::topP, config.getTopP());
@@ -58,7 +41,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ParsedChatAnn
setIfPresent(optionsBuilder::enableThinking, config.getEnableThinking());
builder.options(optionsBuilder.build());
// 3. Message
// 2. Message
List<Message> messages = Optional.ofNullable(annotationConfig.getHistory())
.orElse(new ArrayList<>());
if (messages.isEmpty()) {
@@ -68,7 +51,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ParsedChatAnn
builder.messages(messages);
// 4. streaming 相关参数
// 3. streaming 相关参数
builder.streaming(annotationConfig.isStreaming());
builder.transportType(annotationConfig.getTransportType());

View File

@@ -1,7 +1,6 @@
package com.yomahub.liteflow.ai.parse.assemble;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
import com.yomahub.liteflow.ai.domain.dto.ParsedClassifyAnnotationConfig;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
@@ -35,23 +34,7 @@ public class ClassifyRequestAssembler extends AbstractRequestAssembler<ParsedCla
protected ChatRequest doAssemble(ParsedClassifyAnnotationConfig annotationConfig, ModelConfigAggregator config, ChatContext context) {
ChatRequest.Builder<?> builder = ModelFactory.getChatRequestBuilder(config.getProvider());
// 1. 连接 StreamHandler 回调
StreamHandler streamHandler = context.getStreamHandler();
if (Objects.nonNull(streamHandler)) {
LOG.info("Connecting StreamHandler to ChatRequest");
builder.onStart(streamHandler::onStart)
.onClose(streamHandler::onClose)
.onError(streamHandler::onError)
.onText(streamHandler::onText)
.onThinking(streamHandler::onThinking)
.onToolsCalling(streamHandler::onToolsCalling)
.onUsage(streamHandler::onUsage)
.onGrounding(streamHandler::onGrounding)
.onCompletion(streamHandler::onCompletion)
.onFinal(streamHandler::onFinal);
}
// 2. ChatOptions
// 1. ChatOptions
ChatOptions.Builder<?> optionsBuilder = ChatOptions.builder();
setIfPresent(optionsBuilder::temperature, config.getTemperature());
setIfPresent(optionsBuilder::topP, config.getTopP());
@@ -62,7 +45,7 @@ public class ClassifyRequestAssembler extends AbstractRequestAssembler<ParsedCla
builder.options(optionsBuilder.build());
// 3. Message
// 2. Message
List<Message> messages = Optional.ofNullable(annotationConfig.getHistory())
.orElse(new ArrayList<>());
if (messages.isEmpty()) {
@@ -77,7 +60,7 @@ public class ClassifyRequestAssembler extends AbstractRequestAssembler<ParsedCla
builder.messages(messages);
// 4. streaming 相关参数
// 3. streaming 相关参数
// 定死使用阻塞式传输
builder.streaming(false);
builder.transportType(TransportType.HTTP);

View File

@@ -1,18 +1,16 @@
package com.yomahub.liteflow.ai.proxy.invocation;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
import com.yomahub.liteflow.ai.model.ModelFactory;
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
import com.yomahub.liteflow.ai.proxy.wrap.ChatProxyWrapBean;
import io.reactivex.rxjava3.core.Flowable;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* 聊天组件的调用处理器
@@ -39,53 +37,26 @@ public class ChatAIInvocationHandler extends AbstractAIInvocationHandler<ChatPro
ChatRequest chatRequest = processorContext.getModelRequest().toChatRequest();
if (chatRequest.isStreaming()) {
return processStreaming(chatModel, chatRequest);
return processStreaming(chatModel, chatRequest, processorContext.getChatContext().getStreamHandler());
} else {
return chatModel.chat(chatRequest);
}
}
private ChatResponse processStreaming(ChatModel chatModel, ChatRequest chatRequest) {
// 创建一个 CompletableFuture 用于异步处理,他将持有最终的 ChatResponse
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
// 获取用户传入的 ResultHandler
final ResultHandler externalResultHandler = chatRequest.getResultHandler();
// 定义内部 ResultHandler包装用户的处理逻辑
ResultHandler internalResultHandler = new ResultHandler() {
@Override
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
if (Objects.nonNull(externalResultHandler)) {
return externalResultHandler.onCompletion(response, context);
}
return response;
}
@Override
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
ChatResponse finalResponse = response;
try {
if (Objects.nonNull(externalResultHandler)) {
finalResponse = externalResultHandler.onFinal(response, context);
}
} finally {
future.complete(finalResponse);
}
return finalResponse;
}
};
// 设置内部 ResultHandler 到请求中
chatRequest.setResultHandler(internalResultHandler);
// 执行流式聊天请求
chatModel.stream(chatRequest);
// 阻塞等待 CompletableFuture 完成,并返回最终的 ChatResponse
try {
return future.get();
} catch (InterruptedException | ExecutionException e) {
Thread.currentThread().interrupt();
throw new LiteFlowAIException("error while processing streaming chat request", e);
private ChatResponse processStreaming(ChatModel chatModel, ChatRequest chatRequest, StreamHandler streamHandler) {
// 获取用户定义的 StreamHandler如果有
if (Objects.isNull(streamHandler)) {
// 如果未提供,使用 pass-through 处理器
streamHandler = StreamHandler.passThrough();
}
// 获取原始的事件流
Flowable<ChunkEvent> eventStream = chatModel.stream(chatRequest);
// 应用用户的 StreamHandler 进行响应式转换
Flowable<ChunkEvent> handledStream = streamHandler.handle(eventStream);
return handledStream.blockingLast()
.getFinalResponse();
}
}

View File

@@ -57,6 +57,11 @@
<groupId>com.github.victools</groupId>
<artifactId>jsonschema-module-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
</dependency>
</dependencies>
</project>

View File

@@ -1,8 +1,10 @@
package com.yomahub.liteflow.ai.engine.interact;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.Flowable;
import java.util.concurrent.CompletableFuture;
@@ -15,7 +17,7 @@ import java.util.concurrent.CompletableFuture;
public interface InteractClient {
void stream(ChatConfig config, ChatRequest request);
Flowable<ChunkEvent> stream(ChatConfig config, ChatRequest request);
ChatResponse chat(ChatConfig config, ChatRequest request);

View File

@@ -1,13 +1,11 @@
package com.yomahub.liteflow.ai.engine.interact;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformerFactory;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.log.EngineLog;
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
@@ -16,226 +14,241 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.disposables.CompositeDisposable;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
/**
* 大模型交互客户端,统筹消息传输、协议转换等功能。
* 大模型交互客户端
*
* @author 苍镜月
* @since 2.16.0
*/
public class LlmInteractClient implements InteractClient {
private static final EngineLog LOG = EngineLogManager.getLogger(LlmInteractClient.class);
@Override
public void stream(ChatConfig config, ChatRequest request) {
InteractManager manager = new InteractManager(config, request);
manager.executeStreaming();
}
/**
* 同步调用
*
* @param config 聊天配置
* @param request 聊天请求
* @return 聊天响应
*/
@Override
public ChatResponse chat(ChatConfig config, ChatRequest request) {
InteractManager interactManager = new InteractManager(config, request);
return interactManager.executeBlocking();
}
ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider());
Transport transport = request.getTransportType().getTransportInstance();
@Override
public CompletableFuture<ChatResponse> chatAsync(ChatConfig config, ChatRequest request) {
CompletableFuture<ChatResponse> future = new CompletableFuture<>();
ChatResponse response; // 用于存储最终响应
CompletableFuture.runAsync(() -> {
try {
future.complete(chat(config, request));
} catch (Exception e) {
future.completeExceptionally(new LiteFlowAIEngineException("异步调用大模型失败", e));
while (true) {
InteractContext context = new InteractContext();
// 1. 执行单轮对话
String responseBody = transport.startBlocking(config, request);
response = protocolTransformer.transformBlockingResponse(responseBody, context);
// 2. 检查循环的 "退出条件"
// 如果 AI 没有要求工具调用,或者配置禁用了自动调用,
// 那么这就是最终答案,跳出循环。
if (!response.hasToolCalls() || !config.isAutoToolCallEnabled()) {
break;
}
});
return future;
// 3. 获取并执行工具调用
List<ToolCall> toolCalls = response.getOutput().getToolCalls();
// 目前只执行单轮单次的工具调用
ToolMessage toolMessage = request.getToolRegistry().executeToolCall(toolCalls.get(0));
// 4. 构建下一轮对话的上下文
buildNextRoundMessages(request, response, toolMessage);
}
// 5. 返回循环中断时的最后一个响应
return response;
}
/**
* 内部执行器
* 异步调用
*
* @param config 聊天配置
* @param request 聊天请求
* @return 异步聊天响应
*/
private static class InteractManager {
private final ChatConfig config;
private final ChatRequest request;
private final InteractContext context;
private final ChunkProcessPipeline pipeline;
private final TransportListener externalTransportListener;
private final ResultHandler resultHandler;
private final Transport transport;
private final InternalTransportListener internalTransportListener;
@Override
public CompletableFuture<ChatResponse> chatAsync(ChatConfig config, ChatRequest request) {
return CompletableFuture.supplyAsync(() -> chat(config, request));
}
public InteractManager(ChatConfig config, ChatRequest request) {
this.config = config;
this.request = request;
this.context = new InteractContext();
ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider());
this.pipeline = request.isStreaming()
? ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, request.getChunkCallbackTransformer())
: ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer);
this.transport = request.getTransportType().getTransportInstance();
this.externalTransportListener = request.getTransportListener();
this.resultHandler = request.getResultHandler();
this.internalTransportListener = new InternalTransportListener();
}
/**
* 创建响应式流事件管道
*
* @param config 聊天配置
* @param request 聊天请求
* @return 包含 ChunkEvent 的流
*/
public Flowable<ChunkEvent> stream(ChatConfig config, ChatRequest request) {
InteractContext context = new InteractContext();
ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider());
/**
* 流式调用
*/
public void executeStreaming() {
// 启动传输,使用内部监听器
transport.start(config, request, pipeline, internalTransportListener);
}
return Flowable.just(ChunkEvent.start(context))
.concatWith(
streamRecursive(config, request, context, protocolTransformer)
);
}
/**
* 内部传输监听器,用于处理流式调用的各种事件
*/
private class InternalTransportListener implements TransportListener {
/**
* 内部递归流逻辑
*
* @param config 聊天配置
* @param request 聊天请求
* @param context 当前轮次的交互上下文
* @param protocolTransformer 协议转换器
* @return 包含 ChunkEvent 的流
*/
private Flowable<ChunkEvent> streamRecursive(ChatConfig config, ChatRequest request,
InteractContext context, ProtocolTransformer protocolTransformer) {
return Flowable.create(emitter -> {
CompositeDisposable compositeDisposable = new CompositeDisposable();
emitter.setDisposable(compositeDisposable);
@Override
public void onStart(InteractContext context) {
externalTransportListener.onStart(context);
}
Transport transport = request.getTransportType().getTransportInstance();
@Override
public void onClose(InteractContext context) {
ChatResponse finalResponse = null;
// 判断是否需要继续进行工具调用
boolean isContinuingWithToolCall = false;
try {
// 构造最终响应
finalResponse = pipeline.buildFinalStreamingResponse();
try {
// 获取响应式流
Disposable transportDisposable = transport.startStreaming(config, request)
.subscribe(
// onNext: 处理每个原始 JSON 分块
rawChunk -> {
try {
// 使用协议转换器转换为框架标准格式
StreamingProtocolChunk protocolChunk = protocolTransformer.transformStreamingChunk(rawChunk, context);
// 调用结果处理器的完成回调
finalResponse = resultHandler.onCompletion(finalResponse, context);
// 根据分块信息更新上下文
updateContextFromChunk(context, protocolChunk);
// 工具调用
if (finalResponse.hasToolCalls() && config.isAutoToolCallEnabled()) {
// 1. 获取并执行工具调用
List<ToolCall> toolCalls = finalResponse.getOutput().getToolCalls();
// 目前只执行单轮单次的工具调用
ToolMessage toolMessage = executeToolCall(toolCalls.get(0), request.getToolRegistry());
// 发送分块事件
if (!emitter.isCancelled()) {
emitter.onNext(ChunkEvent.chunk(rawChunk, protocolChunk, context));
}
} catch (Exception e) {
if (!emitter.isCancelled()) {
emitter.onError(e); // 转发解析错误
}
}
},
// onError: 转发下游错误
error -> {
if (!emitter.isCancelled()) {
emitter.onError(error);
}
try {
transport.close(); // 确保关闭当前轮次的 transport
} catch (Exception e) {
LOG.warn("Error closing transport after onError", e);
}
},
// onComplete: 流完成处理
() -> {
try {
// 构建最终响应
ChatResponse finalResponse = protocolTransformer.transformStreamingResponse(context);
// 2. 构建下一轮对话的上下文
buildNextRoundMessages(request, finalResponse, toolMessage);
// 检查是否需要工具调用
if (finalResponse.hasToolCalls() && config.isAutoToolCallEnabled()) {
if (emitter.isCancelled()) {
return;
}
// 3. 递归调用下一轮消息
// 设置标志位为 true避免关闭 transport
isContinuingWithToolCall = true;
new InteractManager(config, request).executeStreaming();
}
// 1. 执行工具调用
List<ToolCall> toolCalls = finalResponse.getOutput().getToolCalls();
ToolMessage toolMessage = request.getToolRegistry().executeToolCall(toolCalls.get(0));
} catch (Exception e) {
onError(context, e);
} finally {
// 不需要进行工具调用时才进行清理操作
if (!isContinuingWithToolCall) {
try {
// 调用外部监听器的关闭事件
externalTransportListener.onClose(context);
} catch (Exception e) {
onError(context, e);
} finally {
// 清理资源
cleanup(finalResponse);
}
}
// 如果需要进行工具调用,将清理的责任委托给下一轮调用
// 2. 构建下一轮消息
buildNextRoundMessages(request, finalResponse, toolMessage);
// 3. 为下一轮创建新的上下文
InteractContext nextContext = new InteractContext();
// 4. 递归调用 "循环体",并将事件转发给当前 emitter
Disposable recursiveDisposable = streamRecursive(config, request, nextContext, protocolTransformer)
.subscribe(
emitter::onNext,
emitter::onError,
emitter::onComplete
);
compositeDisposable.add(recursiveDisposable);
} else {
if (!emitter.isCancelled()) {
// 没有工具调用,发送 complete 事件并完成流
emitter.onNext(ChunkEvent.complete(context, finalResponse));
emitter.onComplete();
}
}
} catch (Exception e) {
if (!emitter.isCancelled()) {
emitter.onError(e); // 转发 onComplete 逻辑中的错误
}
} finally {
try {
transport.close(); // 确保关闭当前轮次的 transport
} catch (Exception e) {
LOG.warn("Error closing transport", e);
}
}
}
);
compositeDisposable.add(transportDisposable);
} catch (Exception e) {
if (!emitter.isCancelled()) {
emitter.onError(e);
}
}
}, BackpressureStrategy.BUFFER);
}
@Override
public void onError(InteractContext context, Throwable t) {
externalTransportListener.onError(context, t);
}
/**
* 根据分块信息更新交互上下文
*
* @param context 交互上下文
* @param chunk 协议分块
*/
private void updateContextFromChunk(InteractContext context, StreamingProtocolChunk chunk) {
if (Objects.isNull(chunk)) {
return;
}
public ChatResponse executeBlocking() {
ChatResponse response = null;
try {
externalTransportListener.onStart(context);
response = transport.startBlocking(config, request, pipeline);
response = resultHandler.onCompletion(response, context);
// 处理工具调用
if (response.hasToolCalls() && config.isAutoToolCallEnabled()) {
// 1. 获取并执行工具调用
List<ToolCall> toolCalls = response.getOutput().getToolCalls();
// 目前只执行单轮单次的工具调用
ToolMessage toolMessage = executeToolCall(toolCalls.get(0), request.getToolRegistry());
// 2. 构建下一轮对话的上下文
buildNextRoundMessages(request, response, toolMessage);
// 3. 递归调用下一轮消息
response = new InteractManager(config, request).executeBlocking();
}
return response;
} catch (Exception e) {
this.externalTransportListener.onError(context, e);
return response;
} finally {
cleanup(response);
}
}
/**
* 执行工具调用
*
* @param toolCall 工具调用信息
* @param toolRegistry 工具注册中心
* @return 工具调用结果
*/
private ToolMessage executeToolCall(ToolCall toolCall, ToolRegistry toolRegistry) {
// 找到对应的工具回调
ToolCallBack toolCallBack = toolRegistry.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new LiteFlowAIEngineException(
"Unable to find target tool with tool name: " + toolCall.getName()));
// 调用工具
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
// 返回工具调用结果
return new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
}
/**
* 工具调用之后构建下一轮的消息列表
*
* @param request 原始请求
* @param response 大模型响应(要求工具调用)
* @param toolMessage 工具调用结果
*/
private void buildNextRoundMessages(ChatRequest request, ChatResponse response, ToolMessage toolMessage) {
List<Message> messagesHistory = request.getMessages();
messagesHistory.add(response.getOutput());
messagesHistory.add(toolMessage);
}
/**
* 清理资源
*/
private void cleanup(ChatResponse response) {
try {
resultHandler.onFinal(response, context);
} catch (Exception e) {
LOG.error("ResultHandler.onFinal 执行失败: {}", e.getMessage());
} finally {
transport.close();
}
switch (chunk.getType()) {
case TEXT:
context.addText((String) chunk.getData());
break;
case THINKING:
context.addThinking((String) chunk.getData());
break;
default:
// 其他类型暂不处理
break;
}
}
/**
* 为下一轮对话构建消息列表
*
* @param request 聊天请求
* @param response 聊天响应
* @param toolMessage 工具调用结果
*/
private void buildNextRoundMessages(ChatRequest request, ChatResponse response, ToolMessage toolMessage) {
List<Message> messages = request.getMessages();
messages.add(response.getOutput());
messages.add(toolMessage);
}
}

View File

@@ -1,129 +0,0 @@
package com.yomahub.liteflow.ai.engine.interact.callbacks;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkTransformer;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import java.util.List;
/**
* 流式消息处理管道的回调接口。根据块数据的类型进行具体回调
*
* @author 苍镜月
* @since 2.16.0
*/
public interface ChunkCallbackTransformer extends ChunkTransformer {
static ChunkCallbackTransformer getDefault() {
return new ChunkCallbackTransformer() {
@Override
public String onText(String content, InteractContext context) {
return content;
}
@Override
public String onThinking(String content, InteractContext context) {
return content;
}
@Override
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
return toolCalls;
}
@Override
public Object onUsage(Object content, InteractContext context) {
return content;
}
@Override
public Object onGrounding(Object content, InteractContext context) {
return content;
}
};
}
@SuppressWarnings("unchecked")
default void transform(StreamingProtocolChunk transformedChunk, InteractContext context) {
switch (transformedChunk.getType()) {
case TEXT:
String textContent = (String) transformedChunk.getData();
String callbackText = this.onText(textContent, context);
transformedChunk.setData(callbackText);
context.addText(callbackText);
break;
case THINKING:
String thinkingContent = (String) transformedChunk.getData();
String callbackThinking = this.onThinking(thinkingContent, context);
transformedChunk.setData(callbackThinking);
context.addThinking(callbackThinking);
break;
case TOOL_CALLS:
List<ToolCall> toolCallsContent = (List<ToolCall>) transformedChunk.getData();
List<ToolCall> responseToolCalls = this.onToolsCalling(toolCallsContent, context);
transformedChunk.setData(responseToolCalls);
break;
case USAGE:
Object usageContent = transformedChunk.getData();
Object responseUsage = this.onUsage(usageContent, context);
transformedChunk.setData(responseUsage);
break;
case BASE64_IMAGE:
case DATA:
case ERROR:
// TODO 异常处理
default:
// 其他类型暂不处理
break;
}
}
/**
* 处理文本消息的回调方法。
*
* @param content 文本内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
String onText(String content, InteractContext context);
/**
* 处理思考消息的回调方法。
*
* @param content 思考内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
String onThinking(String content, InteractContext context);
/**
* 处理工具调用消息的回调方法。
*
* @param toolCalls 工具调用内容,可能是工具调用的结果或相关信息
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context);
/**
* 处理 Token 统计信息的回调方法
*
* @param content Token 统计信息内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
// TODO args
Object onUsage(Object content, InteractContext context);
/**
* 处理基础信息/搜索结果的回调方法
*
* @param content 基础信息或搜索结果内容
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
// TODO args
Object onGrounding(Object content, InteractContext context);
@Override
default String getTransformerType() {
return "ChunkCallback";
}
}

View File

@@ -1,47 +0,0 @@
package com.yomahub.liteflow.ai.engine.interact.callbacks;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
/**
* 对消息全部发送完毕并转换后的结果进行处理的接口。
*
* @author 苍镜月
* @since 2.16.0
*/
public interface ResultHandler {
static ResultHandler getDefault() {
return new ResultHandler() {
@Override
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
return response;
}
@Override
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
return response;
}
};
}
/**
* 消息处理完成的回调方法。
*
* @param response 处理后的聊天响应结果
* @param context 聊天上下文,包含处理过程中的状态和信息
* @return 处理后的结果
*/
ChatResponse onCompletion(ChatResponse response, InteractContext context);
/**
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。
*
* @param response 最终的聊天响应结果
* @param context 聊天上下文,包含处理过程中的状态和信息
* @return 处理后的结果
*/
ChatResponse onFinal(ChatResponse response, InteractContext context);
}

View File

@@ -0,0 +1,220 @@
package com.yomahub.liteflow.ai.engine.interact.chunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import java.util.Objects;
/**
* 流式事件封装
* <p>
* 该类封装了整个流式交互过程中产生的各种事件,包括:
* - 请求开始事件 (START)
* - 数据分块事件 (CHUNK)
* - 请求完成事件 (COMPLETE)
* - 错误事件 (ERROR)
* <p>
* 每个事件都包含相关的上下文信息,允许用户在处理流时访问中间状态。
*
* @author 苍镜月
* @since 2.16.0
*/
public class ChunkEvent {
/**
* 事件类型枚举
*/
public enum EventType {
/**
* 流开始事件 - 在流开始处理时触发
*/
START,
/**
* 数据分块事件 - 接收到新的数据分块时触发
*/
CHUNK,
/**
* 流完成事件 - 所有数据处理完毕时触发
*/
COMPLETE,
/**
* 错误事件 - 处理过程中发生错误时触发
*/
ERROR
}
/**
* 事件类型
*/
private final EventType eventType;
/**
* 原始的 JSON 响应块(协议相关的原始格式)
* 在 CHUNK 事件中可能有值,其他事件类型为 null
*/
private final String rawChunk;
/**
* 转换后的框架标准响应块
* 在 CHUNK 事件中可能有值,其他事件类型为 null
*/
private final StreamingProtocolChunk transformedChunk;
/**
* 交互上下文 - 包含当前的交互状态
* 在 CHUNK 和 COMPLETE 事件中有值START 事件为初始化的上下文ERROR 事件可能为 null
*/
private final InteractContext context;
/**
* 错误信息 - 仅在 ERROR 事件中有值
*/
private final Throwable error;
/**
* 最终响应 - 仅在 COMPLETE 事件中有值
*/
private final ChatResponse finalResponse;
/**
* 私有构造函数
*/
private ChunkEvent(EventType eventType, String rawChunk, StreamingProtocolChunk transformedChunk,
InteractContext context, Throwable error, ChatResponse finalResponse) {
this.eventType = Objects.requireNonNull(eventType, "eventType cannot be null");
this.rawChunk = rawChunk;
this.transformedChunk = transformedChunk;
this.context = context;
this.error = error;
this.finalResponse = finalResponse;
}
/**
* 创建开始事件
*
* @param context 交互上下文
* @return 开始事件
*/
public static ChunkEvent start(InteractContext context) {
return new ChunkEvent(EventType.START, null, null, context, null, null);
}
/**
* 创建数据分块事件
*
* @param rawChunk 原始 JSON 分块
* @param transformedChunk 转换后的框架标准分块
* @param context 交互上下文
* @return 数据分块事件
*/
public static ChunkEvent chunk(String rawChunk, StreamingProtocolChunk transformedChunk, InteractContext context) {
return new ChunkEvent(EventType.CHUNK, rawChunk, transformedChunk, context, null, null);
}
/**
* 创建完成事件
*
* @param context 交互上下文
* @param finalResponse 最终聊天响应
* @return 完成事件
*/
public static ChunkEvent complete(InteractContext context, ChatResponse finalResponse) {
return new ChunkEvent(EventType.COMPLETE, null, null, context, null, finalResponse);
}
/**
* 创建错误事件
*
* @param error 错误信息
* @return 错误事件
*/
public static ChunkEvent error(Throwable error) {
return new ChunkEvent(EventType.ERROR, null, null, null, error, null);
}
/**
* 创建错误事件 - 带上下文
*
* @param error 错误信息
* @param context 交互上下文
* @return 错误事件
*/
public static ChunkEvent error(Throwable error, InteractContext context) {
return new ChunkEvent(EventType.ERROR, null, null, context, error, null);
}
// ==================== Getters ====================
public EventType getEventType() {
return eventType;
}
public String getRawChunk() {
return rawChunk;
}
public StreamingProtocolChunk getTransformedChunk() {
return transformedChunk;
}
public InteractContext getContext() {
return context;
}
public Throwable getError() {
return error;
}
public ChatResponse getFinalResponse() {
return finalResponse;
}
// ==================== Convenience Methods ====================
/**
* 是否为开始事件
*
* @return true 如果是开始事件
*/
public boolean isStart() {
return eventType == EventType.START;
}
/**
* 是否为分块事件
*
* @return true 如果是分块事件
*/
public boolean isChunk() {
return eventType == EventType.CHUNK;
}
/**
* 是否为完成事件
*
* @return true 如果是完成事件
*/
public boolean isComplete() {
return eventType == EventType.COMPLETE;
}
/**
* 是否为错误事件
*
* @return true 如果是错误事件
*/
public boolean isError() {
return eventType == EventType.ERROR;
}
@Override
public String toString() {
return "ChunkEvent{" +
"eventType=" + eventType +
", rawChunk=" + (rawChunk != null ? rawChunk.substring(0, Math.min(50, rawChunk.length())) + "..." : "null") +
", transformedChunk=" + transformedChunk +
", context=" + context +
", error=" + error +
", finalResponse=" + finalResponse +
'}';
}
}

View File

@@ -1,4 +1,4 @@
package com.yomahub.liteflow.ai.engine.interact.pipeline;
package com.yomahub.liteflow.ai.engine.interact.chunk;
import cn.hutool.core.collection.CollectionUtil;
import com.yomahub.liteflow.ai.engine.model.output.TokenUsage;

View File

@@ -1,4 +1,4 @@
package com.yomahub.liteflow.ai.engine.interact.protocol;
package com.yomahub.liteflow.ai.engine.interact.chunk;
/**
* 流式消息块

View File

@@ -1,4 +1,4 @@
package com.yomahub.liteflow.ai.engine.interact.protocol;
package com.yomahub.liteflow.ai.engine.interact.chunk;
import java.util.Arrays;
import java.util.Map;

View File

@@ -1,91 +0,0 @@
package com.yomahub.liteflow.ai.engine.interact.pipeline;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
/**
* 流式消息处理管道
*
* @author 苍镜月
* @since 2.16.0
*/
public class ChunkProcessPipeline {
private final InteractContext context;
private final ProtocolTransformer protocolTransformer;
private final ChunkCallbackTransformer chunkCallbackTransformer;
private ChunkProcessPipeline(
InteractContext context,
ProtocolTransformer protocolTransformer,
ChunkCallbackTransformer chunkCallbackTransformer) {
this.context = context;
this.protocolTransformer = protocolTransformer;
this.chunkCallbackTransformer = chunkCallbackTransformer;
}
/**
* 创建流式消息处理管道实例。
*
* @param context 聊天上下文,包含会话状态等
* @param protocolTransformer 协议转换器,将不同厂商大模型响应转换为 LiteFlow-AI 支持的统一格式
* @param chunkCallbackTransformer 消息处理管道的回调接口,根据块数据的类型进行具体回调
* @return 流式消息处理管道实例
*/
public static ChunkProcessPipeline createStreamingPipeline(InteractContext context, ProtocolTransformer protocolTransformer, ChunkCallbackTransformer chunkCallbackTransformer) {
return new ChunkProcessPipeline(context, protocolTransformer, chunkCallbackTransformer);
}
/**
* 创建阻塞式调用的消息处理管道实例。
*
* @param context 聊天上下文,包含会话状态等
* @param protocolTransformer 协议转换器,将不同厂商大模型响应转换为 LiteFlow-AI 支持的统一格式
* @return 阻塞式调用的消息处理管道实例
*/
public static ChunkProcessPipeline createBlockingPipeline(InteractContext context, ProtocolTransformer protocolTransformer) {
return new ChunkProcessPipeline(context, protocolTransformer, null);
}
/**
* 将流式响应的 chunk 转换为 LiteFlow-AI 支持的格式。
*
* @param chunk 流式响应的 chunk
* @return 当前流式响应是否结束, true 表示结束false 表示未结束
*/
public boolean processStreaming(String chunk) {
// 协议转换 chunk
StreamingProtocolChunk transformedChunk = protocolTransformer.transformStreamingChunk(chunk, context);
// 回调处理器进行回调
chunkCallbackTransformer.transform(transformedChunk, context);
return context.isFinished();
}
/**
* 将 阻塞式调用 的 Response 转换为 LiteFlow-AI 支持的 AssistantMessage。
*
* @param blockingResponse OkHttp 的 Response 对象
* @return 转换后的 ChatResponse
*/
public ChatResponse processBlocking(String blockingResponse) {
return protocolTransformer.transformBlockingResponse(blockingResponse, context);
}
/**
* 构造流式调用的最终响应
*
* @return 最终的 ChatResponse
*/
public ChatResponse buildFinalStreamingResponse() {
return protocolTransformer.transformStreamingResponse(context);
}
public InteractContext getContext() {
return context;
}
}

View File

@@ -1,14 +0,0 @@
package com.yomahub.liteflow.ai.engine.interact.pipeline;
/**
* 消息转换器接口
*
* @author 苍镜月
* @since 2.16.0
*/
public interface ChunkTransformer {
String getTransformerType();
}

View File

@@ -4,7 +4,9 @@ import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;

View File

@@ -1,7 +1,7 @@
package com.yomahub.liteflow.ai.engine.interact.protocol;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkTransformer;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
/**
@@ -12,7 +12,7 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
* @since 2.16.0
*/
public interface ProtocolTransformer extends ChunkTransformer {
public interface ProtocolTransformer {
/**
* 将流式响应的 chunk 转换为 LiteFlow-AI 支持的格式。
@@ -41,9 +41,4 @@ public interface ProtocolTransformer extends ChunkTransformer {
ChatResponse transformBlockingResponse(String blockingResponse, InteractContext context);
String getProviderName();
@Override
default String getTransformerType() {
return "ProtocolTransformer";
}
}

View File

@@ -1,9 +1,8 @@
package com.yomahub.liteflow.ai.engine.interact.transport;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.Flowable;
import java.util.Map;
@@ -19,22 +18,20 @@ public interface Transport {
/**
* 启动流式传输
*
* @param config 聊天配置
* @param request 聊天请求
* @param pipeline 处理管道
* @param listener 传输监听器
* @param config 聊天配置
* @param request 聊天请求
* @return 流式数据流
*/
void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener);
Flowable<String> startStreaming(ChatConfig config, ChatRequest request);
/**
* 启动阻塞式传输
*
* @param config 聊天配置
* @param request 聊天请求
* @param pipeline 处理管道
* @param config 聊天配置
* @param request 聊天请求
* @return 聊天响应
*/
ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline);
String startBlocking(ChatConfig config, ChatRequest request);
/**
* 关闭传输

View File

@@ -1,53 +0,0 @@
package com.yomahub.liteflow.ai.engine.interact.transport;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
/**
* 传输监听器接口
*
* @author 苍镜月
* @since 2.16.0
*/
public interface TransportListener {
static TransportListener getDefault() {
return new TransportListener() {
@Override
public void onStart(InteractContext context) {
}
@Override
public void onClose(InteractContext context) {
}
@Override
public void onError(InteractContext context, Throwable t) {
throw new LiteFlowAIEngineException(t.getMessage(), t);
}
};
}
/**
* 请求开始时的回调方法
*
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
void onStart(InteractContext context);
/**
* 请求关闭时的回调方法
*
* @param context 聊天上下文,包含处理过程中的状态和信息
*/
void onClose(InteractContext context);
/**
* 处理过程中发生错误的回调方法。
*
* @param context 聊天上下文,包含处理过程中的状态和信息
* @param t
*/
void onError(InteractContext context, Throwable t);
}

View File

@@ -2,14 +2,14 @@ package com.yomahub.liteflow.ai.engine.interact.transport.impl;
import cn.hutool.core.util.StrUtil;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.log.EngineLog;
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.disposables.Disposable;
import okhttp3.*;
import org.jetbrains.annotations.NotNull;
@@ -18,6 +18,8 @@ import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
* 一个用于处理"换行符分隔的 JSON"(Delimiter-Newline JSON)流的 LLM 客户端。(Ollama流式输出应使用该传输方式)
@@ -34,92 +36,102 @@ import java.util.Objects;
* @since 2.16.0
*/
public class DnJsonTransport implements Transport, Callback {
public class DnJsonTransport implements Transport {
private static final EngineLog LOG = EngineLogManager.getLogger(DnJsonTransport.class);
private ChunkProcessPipeline pipeline;
private TransportListener listener;
private OkHttpClient client;
private boolean isStop = false;
private boolean isLogResponse = false;
private final AtomicReference<OkHttpClient> clientRef = new AtomicReference<>();
private final AtomicBoolean isStop = new AtomicBoolean(false);
@Override
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
this.pipeline = pipeline;
this.listener = listener;
this.isLogResponse = config.isLogResponse();
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
return Flowable.create(emitter -> {
Request dnJsonRequest = buildDnJsonRequest(config, request);
Request dnJsonRequest = buildDnJsonRequest(config, request);
OkHttpClient client = new okhttp3.OkHttpClient.Builder()
.connectTimeout(config.getConnectTimeout())
.readTimeout(config.getReadTimeout())
.build();
client = new okhttp3.OkHttpClient.Builder()
.connectTimeout(config.getConnectTimeout())
.readTimeout(config.getReadTimeout())
.build();
this.clientRef.set(client);
this.listener.onStart(pipeline.getContext());
// 异步执行请求this 作为回调处理器
this.client.newCall(dnJsonRequest).enqueue(this);
client.newCall(dnJsonRequest).enqueue(new Callback() {
@Override
public void onFailure(@NotNull Call call, @NotNull IOException e) {
emitter.onError(e);
}
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try {
if (!response.isSuccessful()) {
emitter.onError(new LiteFlowAIEngineException("error response when calling LLM: " + response.message()));
return;
}
ResponseBody body = response.body();
if (Objects.isNull(body)) {
emitter.onError(new LiteFlowAIEngineException("response body is null when calling LLM"));
return;
}
// 逐行读取响应体,响应体为换行符分隔的 JSON
try (BufferedReader br = new BufferedReader(new InputStreamReader(body.byteStream()))) {
String line = br.readLine();
while (StrUtil.isNotBlank(line)) {
if (config.isLogResponse()) {
LOG.info("DN-JSON Response: {}", line);
}
try {
emitter.onNext(line);
} catch (Exception e) {
emitter.onError(e);
return;
}
line = br.readLine();
}
emitter.onComplete();
}
} finally {
close();
}
}
});
// 设置取消订阅时的清理逻辑
emitter.setDisposable(new Disposable() {
@Override
public void dispose() {
close();
}
@Override
public boolean isDisposed() {
return isStop.get();
}
});
}, BackpressureStrategy.BUFFER);
}
@Override
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
throw new UnsupportedOperationException("SSE传输不支持阻塞式调用请使用HTTP传输或调用start方法");
public String startBlocking(ChatConfig config, ChatRequest request) {
throw new UnsupportedOperationException("DN-JSON传输不支持阻塞式调用请使用HTTP传输");
}
@Override
public void close() {
if (!this.isStop) {
try {
this.isStop = true;
this.listener.onClose(pipeline.getContext());
} finally {
if (Objects.nonNull(client)) {
client.dispatcher().executorService().shutdown();
client.connectionPool().evictAll();
}
if (this.isStop.compareAndSet(false, true)) {
OkHttpClient client = clientRef.get();
if (Objects.nonNull(client)) {
client.dispatcher().executorService().shutdown();
client.connectionPool().evictAll();
}
}
}
@Override
public void onFailure(@NotNull Call call, @NotNull IOException e) {
this.listener.onError(pipeline.getContext(), e);
close();
}
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
if (!response.isSuccessful()) {
this.listener.onError(pipeline.getContext(), new LiteFlowAIEngineException("error response when calling LLM: " + response.message()));
close();
return;
}
ResponseBody body = response.body();
if (Objects.isNull(body)) {
this.listener.onError(pipeline.getContext(), new LiteFlowAIEngineException("response body is null when calling LLM"));
close();
return;
}
// 逐行读取响应体,响应体为换行符分隔的 JSON
try (BufferedReader br = new BufferedReader(new InputStreamReader(body.byteStream()))) {
String line = br.readLine();
while (StrUtil.isNotBlank(line)) {
if (this.isLogResponse) {
LOG.info("DN-JSON Response: {}", line);
}
pipeline.processStreaming(line);
line = br.readLine();
}
} finally {
// 确保在读取完毕后关闭资源
close();
}
}
private Request buildDnJsonRequest(ChatConfig config, ChatRequest request) {
String requestBody = buildRequestBody(config, request);

View File

@@ -1,15 +1,13 @@
package com.yomahub.liteflow.ai.engine.interact.transport.impl;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.log.EngineLog;
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.util.HttpUtil;
import io.reactivex.rxjava3.core.Flowable;
import java.io.IOException;
import java.util.Map;
@@ -26,12 +24,12 @@ public class HttpTransport implements Transport {
private static final EngineLog LOG = EngineLogManager.getLogger(HttpTransport.class);
@Override
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
throw new UnsupportedOperationException("HTTP传输不支持流式调用请使用SSE传输或调用startBlocking方法");
}
@Override
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
public String startBlocking(ChatConfig config, ChatRequest request) {
try (HttpUtil httpUtil = HttpUtil
.builder()
.connectTimeout(config.getConnectTimeout())
@@ -61,8 +59,8 @@ public class HttpTransport implements Transport {
LOG.info("======= HTTP Response End =======");
}
// 处理响应
return pipeline.processBlocking(responseBody);
// 返回响应
return responseBody;
} catch (IOException e) {
throw new LiteFlowAIEngineException("阻塞调用大模型失败", e);
}

View File

@@ -1,13 +1,13 @@
package com.yomahub.liteflow.ai.engine.interact.transport.impl;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.log.EngineLog;
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.BackpressureStrategy;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.disposables.Disposable;
import okhttp3.Headers;
import okhttp3.OkHttpClient;
import okhttp3.Request;
@@ -20,6 +20,8 @@ import org.jetbrains.annotations.Nullable;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
* Sse传输实现基于Server-Sent Events的非阻塞式传输
@@ -28,96 +30,114 @@ import java.util.Objects;
* @since 2.16.0
*/
public class SseTransport extends EventSourceListener implements Transport {
public class SseTransport implements Transport {
private static final EngineLog LOG = EngineLogManager.getLogger(SseTransport.class);
private ChunkProcessPipeline pipeline;
private TransportListener listener;
private OkHttpClient client;
private EventSource eventSource;
private boolean isStop = false;
private boolean isLogResponse = false;
private final AtomicReference<EventSource> eventSourceRef = new AtomicReference<>();
private final AtomicReference<OkHttpClient> clientRef = new AtomicReference<>();
private final AtomicBoolean isStopped = new AtomicBoolean(false);
@Override
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
this.pipeline = pipeline;
this.listener = listener;
this.isLogResponse = config.isLogResponse();
try {
// 构建SSE请求
Request sseRequest = buildSseRequest(config, request);
// 创建EventSource实例
client = new okhttp3.OkHttpClient.Builder()
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
return Flowable.create(emitter -> {
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(config.getConnectTimeout())
.readTimeout(config.getReadTimeout())
.build();
this.eventSource = EventSources.createFactory(client)
.newEventSource(sseRequest, this);
this.clientRef.set(client);
listener.onStart(pipeline.getContext());
} catch (Exception e) {
onFailure(eventSource, e, null);
}
try {
// 构建SSE请求
Request sseRequest = buildSseRequest(config, request);
// 创建响应式的 EventSourceListener
EventSourceListener listener = new EventSourceListener() {
@Override
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
}
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
if (config.isLogResponse()) {
LOG.info("SSE Response Event - id: {}, type: {}, data: {}", id, type, data);
}
// 发送数据到 Flowable
try {
emitter.onNext(data);
} catch (Exception e) {
emitter.onError(e);
}
}
@Override
public void onClosed(@NotNull EventSource eventSource) {
try {
if (!emitter.isCancelled()) {
emitter.onComplete();
}
} catch (Exception e) {
LOG.warn("Error completing stream", e);
}
}
@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
if (!emitter.isCancelled()) {
if (t != null) {
emitter.onError(t);
} else {
emitter.onError(new RuntimeException("SSE connection failed"));
}
}
}
};
// 创建 EventSource
EventSource eventSource = EventSources.createFactory(client)
.newEventSource(sseRequest, listener);
eventSourceRef.set(eventSource);
// 设置取消订阅时的清理逻辑
emitter.setDisposable(new Disposable() {
@Override
public void dispose() {
close();
}
@Override
public boolean isDisposed() {
return isStopped.get();
}
});
} catch (Exception e) {
emitter.onError(e);
}
}, BackpressureStrategy.BUFFER);
}
@Override
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
public String startBlocking(ChatConfig config, ChatRequest request) {
throw new UnsupportedOperationException("SSE传输不支持阻塞式调用请使用HTTP传输或调用start方法");
}
@Override
public void close() {
if (!this.isStop) {
try {
this.isStop = true;
this.listener.onClose(pipeline.getContext());
} finally {
if (Objects.nonNull(eventSource)) {
eventSource.cancel();
}
if (Objects.nonNull(client)) {
client.dispatcher().executorService().shutdown();
client.connectionPool().evictAll();
}
if (this.isStopped.compareAndSet(false, true)) {
EventSource eventSource = eventSourceRef.get();
if (Objects.nonNull(eventSource)) {
eventSource.cancel();
}
OkHttpClient client = clientRef.get();
if (Objects.nonNull(client)) {
client.dispatcher().executorService().shutdown();
client.connectionPool().evictAll();
}
}
}
@Override
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
super.onOpen(eventSource, response);
}
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
super.onEvent(eventSource, id, type, data);
if (this.isLogResponse) {
LOG.info("SSE Response Event - id: {}, type: {}, data: {}", id, type, data);
}
// 如果返回 true 表示流式响应结束,关闭连接, 有一些模型不会主动关闭连接,需要在这里判断
if (pipeline.processStreaming(data)) {
close();
}
}
@Override
public void onClosed(@NotNull EventSource eventSource) {
super.onClosed(eventSource);
}
@Override
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
super.onFailure(eventSource, t, response);
close();
this.listener.onError(pipeline.getContext(), t);
}
private Request buildSseRequest(ChatConfig config, ChatRequest request) {
String requestBody = buildRequestBody(config, request);

View File

@@ -2,9 +2,11 @@ package com.yomahub.liteflow.ai.engine.model.chat;
import com.yomahub.liteflow.ai.engine.interact.InteractClient;
import com.yomahub.liteflow.ai.engine.interact.LlmInteractClient;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.Flowable;
import java.time.Duration;
import java.util.Map;
@@ -43,8 +45,8 @@ public abstract class AbstractChatModel implements ChatModel {
}
@Override
public void stream(ChatRequest request) {
interactClient.stream(config, request);
public Flowable<ChunkEvent> stream(ChatRequest request) {
return interactClient.stream(config, request);
}
@Override

View File

@@ -1,9 +1,11 @@
package com.yomahub.liteflow.ai.engine.model.chat;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.model.BaseModel;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.Flowable;
import java.util.concurrent.CompletableFuture;
@@ -33,10 +35,16 @@ public interface ChatModel extends BaseModel<ChatConfig> {
CompletableFuture<ChatResponse> chatAsync(ChatRequest request);
/**
* 执行聊天请求(异步流式)
* 执行聊天请求(响应式流式)
* <p>
* 返回一个 Flowable 流,用户可以订阅和处理每个 ChunkEvent 事件,包括:
* - START流开始事件包含初始化的 InteractContext
* - CHUNK数据分块事件包含原始 JSON、转换后的分块和更新的 InteractContext
* - COMPLETE完成事件包含最终的 ChatResponse 和完整的 InteractContext
* - ERROR错误事件包含异常信息
*
* @param request 聊天请求
* @return 包含所有流式事件的 Flowable 流
*/
void stream(ChatRequest request);
Flowable<ChunkEvent> stream(ChatRequest request);
}

View File

@@ -1,11 +1,6 @@
package com.yomahub.liteflow.ai.engine.model.chat.entity;
import com.fasterxml.jackson.databind.JsonNode;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.ModelRequest;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
@@ -13,7 +8,6 @@ import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.JsonSchemaParser;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
@@ -24,9 +18,6 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
/**
@@ -57,21 +48,6 @@ public class ChatRequest implements ModelRequest {
*/
protected final TransportType transportType;
/**
* 传输监听器,用于处理请求的开始和结束事件
*/
protected final TransportListener transportListener;
/**
* 结果处理器,用于处理消息全部发送完毕后的结果
*/
protected ResultHandler resultHandler;
/**
* 分块回调,用于处理消息分块的各种事件
*/
protected final ChunkCallbackTransformer chunkCallbackTransformer;
/**
* 响应类型,默认为文本类型
*/
@@ -114,9 +90,6 @@ public class ChatRequest implements ModelRequest {
this.options = ChatOptions.DEFAULT;
this.streaming = true; // 默认启用流式输出
this.transportType = TransportType.SSE; // 默认使用 SSE 传输
this.transportListener = TransportListener.getDefault();
this.resultHandler = ResultHandler.getDefault();
this.chunkCallbackTransformer = ChunkCallbackTransformer.getDefault();
this.responseType = ResponseType.TEXT; // 默认响应类型为文本
TypeReference<String> targetType = new TypeReference<String>() {
}; // 默认目标类型为 String
@@ -130,9 +103,6 @@ public class ChatRequest implements ModelRequest {
ChatOptions options,
boolean streaming,
TransportType transportType,
TransportListener transportListener,
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType,
boolean strict,
@@ -142,9 +112,6 @@ public class ChatRequest implements ModelRequest {
this.options = options;
this.streaming = streaming;
this.transportType = transportType;
this.transportListener = transportListener;
this.resultHandler = resultHandler;
this.chunkCallbackTransformer = chunkCallbackTransformer;
this.responseType = responseType;
this.strict = strict;
this.outputParser = JsonSchemaParser.fromTypeReference(targetType, strict);
@@ -163,9 +130,6 @@ public class ChatRequest implements ModelRequest {
this.options = builder.options;
this.streaming = builder.streaming;
this.transportType = builder.transportType;
this.transportListener = builder.transportListener;
this.resultHandler = builder.resultHandler;
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
this.responseType = builder.responseType;
this.strict = builder.strict;
this.outputParser = JsonSchemaParser.fromType(builder.targetType, builder.strict);
@@ -243,18 +207,6 @@ public class ChatRequest implements ModelRequest {
return transportType;
}
public TransportListener getTransportListener() {
return transportListener;
}
public ResultHandler getResultHandler() {
return resultHandler;
}
public ChunkCallbackTransformer getChunkCallbackTransformer() {
return chunkCallbackTransformer;
}
public ResponseType getResponseType() {
return responseType;
}
@@ -271,10 +223,6 @@ public class ChatRequest implements ModelRequest {
return strict;
}
public void setResultHandler(ResultHandler resultHandler) {
this.resultHandler = resultHandler;
}
public ToolRegistry getToolRegistry() {
return toolRegistry;
}
@@ -288,10 +236,6 @@ public class ChatRequest implements ModelRequest {
protected ChatOptions options;
protected boolean streaming = true; // 默认启用流式输出
protected TransportType transportType = TransportType.SSE; // 默认使用 SSE 传输
protected final LlmListenerAggregator listenerAggregator = new LlmListenerAggregator();
protected TransportListener transportListener;
protected ResultHandler resultHandler;
protected ChunkCallbackTransformer chunkCallbackTransformer;
protected ResponseType responseType = ResponseType.TEXT;
protected Type targetType = String.class;
protected boolean strict = true;
@@ -308,9 +252,6 @@ public class ChatRequest implements ModelRequest {
if (Objects.isNull(options)) {
options = ChatOptions.DEFAULT;
}
this.transportListener = listenerAggregator.toTransportListener();
this.resultHandler = listenerAggregator.toResultHandler();
this.chunkCallbackTransformer = listenerAggregator.toChunkCallbackTransformer();
}
/**
@@ -356,116 +297,6 @@ public class ChatRequest implements ModelRequest {
return self();
}
/**
* 请求开始时的回调方法
*
* @param onStart 请求开始时的回调函数
* @see TransportListener#onStart(InteractContext)
*/
public B onStart(Consumer<InteractContext> onStart) {
listenerAggregator.onStart = onStart;
return self();
}
/**
* 请求结束时的回调方法
*
* @param onClose 请求结束时的回调函数
* @see TransportListener#onClose(InteractContext)
*/
public B onClose(Consumer<InteractContext> onClose) {
listenerAggregator.onClose = onClose;
return self();
}
/**
* 请求发生错误时的回调方法
*
* @param onError 请求发生错误时的回调函数
* @see TransportListener#onError(InteractContext, Throwable)
*/
public B onError(BiConsumer<InteractContext, Throwable> onError) {
listenerAggregator.onError = onError;
return self();
}
/**
* 文本消息的回调方法
*
* @param onText 文本消息的回调函数
* @see ChunkCallbackTransformer#onText(String, InteractContext)
*/
public B onText(BiFunction<String, InteractContext, String> onText) {
listenerAggregator.onText = onText;
return self();
}
/**
* 思考消息的回调方法
*
* @param onThinking 思考消息的回调函数
* @see ChunkCallbackTransformer#onThinking(String, InteractContext)
*/
public B onThinking(BiFunction<String, InteractContext, String> onThinking) {
listenerAggregator.onThinking = onThinking;
return self();
}
/**
* 工具调用消息的回调方法
*
* @param onToolsCalling 工具调用消息的回调函数
* @see ChunkCallbackTransformer#onToolsCalling(List, InteractContext)
*/
public B onToolsCalling(BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling) {
listenerAggregator.onToolsCalling = onToolsCalling;
return self();
}
/**
* Token 统计信息的回调方法
*
* @param onUsage Token 统计信息的回调函数
* @see ChunkCallbackTransformer#onUsage(Object, InteractContext)
*/
public B onUsage(BiFunction<Object, InteractContext, Object> onUsage) {
listenerAggregator.onUsage = onUsage;
return self();
}
/**
* 基础信息/搜索结果的回调方法
*
* @param onGrounding 基础信息/搜索结果的回调函数
* @see ChunkCallbackTransformer#onGrounding(Object, InteractContext)
*/
public B onGrounding(BiFunction<Object, InteractContext, Object> onGrounding) {
listenerAggregator.onGrounding = onGrounding;
return self();
}
/**
* 请求完成时的回调方法
*
* @param onCompletion 请求完成时的回调函数
* @see ResultHandler#onCompletion(ChatResponse, InteractContext)
*/
public B onCompletion(BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion) {
listenerAggregator.onCompletion = onCompletion;
return self();
}
/**
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。
*
* @param onFinal 请求最终结果的回调函数
* @see ResultHandler#onFinal(ChatResponse, InteractContext)
*/
public B onFinal(BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal) {
listenerAggregator.onFinal = onFinal;
return self();
}
/**
* 设置响应类型
*
@@ -524,98 +355,6 @@ public class ChatRequest implements ModelRequest {
return self();
}
/**
* 内部聚合类
*/
protected static class LlmListenerAggregator {
Consumer<InteractContext> onStart = context -> {
};
Consumer<InteractContext> onClose = context -> {
};
BiConsumer<InteractContext, Throwable> onError = (context, t) -> {
throw new LiteFlowAIEngineException(t.getMessage(), t);
};
BiFunction<String, InteractContext, String> onText = (content, context) -> content;
BiFunction<String, InteractContext, String> onThinking = (content, context) -> content;
BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling = (toolCalls, context) -> toolCalls;
BiFunction<Object, InteractContext, Object> onUsage = (content, context) -> content;
BiFunction<Object, InteractContext, Object> onGrounding = (content, context) -> content;
BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion = (response, context) -> response;
BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal = (response, context) -> response;
TransportListener toTransportListener() {
Objects.requireNonNull(onStart, "onStart cannot be null");
Objects.requireNonNull(onClose, "onClose cannot be null");
Objects.requireNonNull(onError, "onError cannot be null");
return new TransportListener() {
@Override
public void onStart(InteractContext context) {
onStart.accept(context);
}
@Override
public void onClose(InteractContext context) {
onClose.accept(context);
}
@Override
public void onError(InteractContext context, Throwable t) {
onError.accept(context, t);
}
};
}
ResultHandler toResultHandler() {
Objects.requireNonNull(onCompletion, "onCompletion cannot be null");
Objects.requireNonNull(onFinal, "onFinal cannot be null");
return new ResultHandler() {
@Override
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
return onCompletion.apply(response, context);
}
@Override
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
return onFinal.apply(response, context);
}
};
}
ChunkCallbackTransformer toChunkCallbackTransformer() {
Objects.requireNonNull(onText, "onText cannot be null");
Objects.requireNonNull(onThinking, "onThinking cannot be null");
Objects.requireNonNull(onToolsCalling, "onToolsCalling cannot be null");
Objects.requireNonNull(onUsage, "onUsage cannot be null");
Objects.requireNonNull(onGrounding, "onGrounding cannot be null");
return new ChunkCallbackTransformer() {
@Override
public String onText(String content, InteractContext context) {
return onText.apply(content, context);
}
@Override
public String onThinking(String content, InteractContext context) {
return onThinking.apply(content, context);
}
@Override
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
return onToolsCalling.apply(toolCalls, context);
}
@Override
public Object onUsage(Object content, InteractContext context) {
return onUsage.apply(content, context);
}
@Override
public Object onGrounding(Object content, InteractContext context) {
return onGrounding.apply(content, context);
}
};
}
}
private static class BuilderImpl extends Builder<BuilderImpl> {
@Override

View File

@@ -1,8 +1,12 @@
package com.yomahub.liteflow.ai.engine.tool.registry;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import java.util.Collection;
import java.util.Objects;
/**
* 工具调用注册接口
@@ -27,4 +31,19 @@ public interface ToolRegistry {
* @return 所有工具的集合
*/
Collection<ToolCallBack> getAllTools();
/**
* 执行工具调用
*
* @param toolCall 工具调用信息
* @return 工具调用结果消息
*/
default ToolMessage executeToolCall(ToolCall toolCall) {
ToolCallBack toolCallBack = getTool(toolCall.getName());
if (Objects.isNull(toolCallBack)) {
throw new LiteFlowAIEngineException("Unable to find target tool with tool name: " + toolCall.getName());
}
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
return new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
}
}

View File

@@ -4,10 +4,10 @@ import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;

View File

@@ -2,9 +2,6 @@ package com.yomahub.liteflow.ai.model.dashscope.model.chat;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
@@ -56,16 +53,12 @@ public class DashScopeChatRequest extends ChatRequest {
ChatOptions options,
boolean streaming,
TransportType transportType,
TransportListener transportListener,
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType,
boolean strict,
ToolRegistry toolRegistry
) {
super(messages, options, streaming, transportType,
transportListener, resultHandler, chunkCallbackTransformer,
responseType, targetType, strict, toolRegistry);
}

View File

@@ -4,10 +4,10 @@ import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;

View File

@@ -1,8 +1,5 @@
package com.yomahub.liteflow.ai.model.ollama.model.chat;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
@@ -39,16 +36,12 @@ public class OllamaChatRequest extends ChatRequest {
ChatOptions options,
boolean streaming,
TransportType transportType,
TransportListener transportListener,
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType,
boolean strict,
ToolRegistry toolRegistry
) {
super(messages, options, streaming, transportType,
transportListener, resultHandler, chunkCallbackTransformer,
responseType, targetType, strict, toolRegistry);
}

View File

@@ -4,10 +4,10 @@ import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;

View File

@@ -2,9 +2,6 @@ package com.yomahub.liteflow.ai.model.openai.model.chat;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
@@ -43,16 +40,12 @@ public class OpenAIChatRequest extends ChatRequest {
ChatOptions options,
boolean streaming,
TransportType transportType,
TransportListener transportListener,
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType,
boolean strict,
ToolRegistry toolRegistry
) {
super(messages, options, streaming, transportType,
transportListener, resultHandler, chunkCallbackTransformer,
responseType, targetType, strict, toolRegistry);
}

View File

@@ -1,26 +1,8 @@
package com.yomahub.liteflow.ai.workflow.coze.invocation;
import cn.hutool.core.util.StrUtil;
import com.coze.openapi.client.chat.model.ChatEvent;
import com.coze.openapi.client.workflows.chat.WorkflowChatReq;
import com.coze.openapi.service.auth.TokenAuth;
import com.coze.openapi.service.config.Consts;
import com.coze.openapi.service.service.CozeAPI;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler;
import com.yomahub.liteflow.ai.util.SpringUtil;
import com.yomahub.liteflow.ai.workflow.coze.annotation.CozeWorkflowChat;
import com.yomahub.liteflow.ai.workflow.coze.config.CozeWorkflowProperty;
import com.yomahub.liteflow.ai.workflow.coze.util.KeyValueUtil;
import com.yomahub.liteflow.ai.workflow.coze.wrap.CozeWorkflowChatProxyWrapBean;
import io.reactivex.Flowable;
import java.util.List;
import java.util.Map;
import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent;
/**
* Coze 对话流 调用处理器
@@ -36,60 +18,61 @@ public class CozeWorkflowChatInvocationHandler extends AbstractAIInvocationHandl
@Override
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
CozeWorkflowChat annotation = wrapBean.getAnnotation();
CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
CozeAPI api = new CozeAPI.Builder()
.auth(new TokenAuth(property.getApiKey()))
.baseURL(
StrUtil.isBlank(property.getBaseUrl()) ?
Consts.COZE_CN_BASE_URL :
property.getBaseUrl()
)
.build();
WorkflowChatReq.WorkflowChatReqBuilder<?, ?> builder = WorkflowChatReq.builder();
setIfPresent(builder::workflowID, annotation.workflowId());
setIfPresent(builder::additionalMessages, annotation.additionalMessages(), processorContext, List.class);
if (annotation.parameters().length > 0) {
Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
if (!paramMap.isEmpty()) {
builder.parameters(paramMap);
}
}
setIfPresent(builder::appID, annotation.appId());
setIfPresent(builder::botID, annotation.botId());
setIfPresent(builder::conversationID, annotation.conversationId());
if (annotation.ext().length > 0) {
Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
if (!extMap.isEmpty()) {
builder.ext(extMap);
}
}
setIfPresent(builder::connectTimeout, annotation.connectTimeout());
setIfPresent(builder::readTimeout, annotation.readTimeout());
setIfPresent(builder::writeTimeout, annotation.writeTimeout());
setIfPresent(builder::customerToken, annotation.customerToken());
Flowable<ChatEvent> res = api.workflows().chat().stream(builder.build());
InteractContext context = new InteractContext();
res.blockingForEach(chunk -> {
ChatContext chatContext = processorContext.getChatContext();
chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
});
return null;
// CozeWorkflowChat annotation = wrapBean.getAnnotation();
//
// CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
//
// CozeAPI api = new CozeAPI.Builder()
// .auth(new TokenAuth(property.getApiKey()))
// .baseURL(
// StrUtil.isBlank(property.getBaseUrl()) ?
// Consts.COZE_CN_BASE_URL :
// property.getBaseUrl()
// )
// .build();
//
// WorkflowChatReq.WorkflowChatReqBuilder<?, ?> builder = WorkflowChatReq.builder();
//
// setIfPresent(builder::workflowID, annotation.workflowId());
//
// setIfPresent(builder::additionalMessages, annotation.additionalMessages(), processorContext, List.class);
//
// if (annotation.parameters().length > 0) {
// Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
// if (!paramMap.isEmpty()) {
// builder.parameters(paramMap);
// }
// }
//
// setIfPresent(builder::appID, annotation.appId());
//
// setIfPresent(builder::botID, annotation.botId());
//
// setIfPresent(builder::conversationID, annotation.conversationId());
//
// if (annotation.ext().length > 0) {
// Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
// if (!extMap.isEmpty()) {
// builder.ext(extMap);
// }
// }
//
// setIfPresent(builder::connectTimeout, annotation.connectTimeout());
//
// setIfPresent(builder::readTimeout, annotation.readTimeout());
//
// setIfPresent(builder::writeTimeout, annotation.writeTimeout());
//
// setIfPresent(builder::customerToken, annotation.customerToken());
//
// Flowable<ChatEvent> res = api.workflows().chat().stream(builder.build());
// InteractContext context = new InteractContext();
// res.blockingForEach(chunk -> {
// ChatContext chatContext = processorContext.getChatContext();
// chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
// });
// return null;
}
@Override

View File

@@ -1,26 +1,8 @@
package com.yomahub.liteflow.ai.workflow.coze.invocation;
import cn.hutool.core.util.StrUtil;
import com.coze.openapi.client.workflows.run.RunWorkflowReq;
import com.coze.openapi.client.workflows.run.model.WorkflowEvent;
import com.coze.openapi.service.auth.TokenAuth;
import com.coze.openapi.service.config.Consts;
import com.coze.openapi.service.service.CozeAPI;
import com.coze.openapi.service.service.workflow.WorkflowRunService;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler;
import com.yomahub.liteflow.ai.util.SpringUtil;
import com.yomahub.liteflow.ai.workflow.coze.annotation.CozeWorkflowRun;
import com.yomahub.liteflow.ai.workflow.coze.config.CozeWorkflowProperty;
import com.yomahub.liteflow.ai.workflow.coze.util.KeyValueUtil;
import com.yomahub.liteflow.ai.workflow.coze.wrap.CozeWorkflowRunProxyWrapBean;
import io.reactivex.Flowable;
import java.util.Map;
import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent;
/**
* Coze 工作流 调用处理器
@@ -37,62 +19,63 @@ public class CozeWorkflowRunInvocationHandler extends AbstractAIInvocationHandle
@Override
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
CozeWorkflowRun annotation = wrapBean.getAnnotation();
CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
CozeAPI api = new CozeAPI.Builder()
.auth(new TokenAuth(property.getApiKey()))
.baseURL(
StrUtil.isBlank(property.getBaseUrl()) ?
Consts.COZE_CN_BASE_URL :
property.getBaseUrl()
)
.build();
RunWorkflowReq.RunWorkflowReqBuilder<?, ?> builder = RunWorkflowReq.builder();
setIfPresent(builder::workflowID, annotation.workflowId());
if (annotation.parameters().length > 0) {
Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
if (!paramMap.isEmpty()) {
builder.parameters(paramMap);
}
}
setIfPresent(builder::appID, annotation.appId());
setIfPresent(builder::botID, annotation.botId());
if (annotation.ext().length > 0) {
Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
if (!extMap.isEmpty()) {
builder.ext(extMap);
}
}
setIfPresent(builder::connectTimeout, annotation.connectTimeout());
setIfPresent(builder::readTimeout, annotation.readTimeout());
setIfPresent(builder::writeTimeout, annotation.writeTimeout());
setIfPresent(builder::customerToken, annotation.customerToken());
WorkflowRunService runs = api.workflows().runs();
if (annotation.stream()) {
Flowable<WorkflowEvent> res = runs.stream(builder.build());
InteractContext context = new InteractContext();
res.blockingForEach(chunk -> {
ChatContext chatContext = processorContext.getChatContext();
chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
});
return null;
} else {
return runs.create(builder.build());
}
return null;
// CozeWorkflowRun annotation = wrapBean.getAnnotation();
//
// CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
//
// CozeAPI api = new CozeAPI.Builder()
// .auth(new TokenAuth(property.getApiKey()))
// .baseURL(
// StrUtil.isBlank(property.getBaseUrl()) ?
// Consts.COZE_CN_BASE_URL :
// property.getBaseUrl()
// )
// .build();
//
// RunWorkflowReq.RunWorkflowReqBuilder<?, ?> builder = RunWorkflowReq.builder();
//
// setIfPresent(builder::workflowID, annotation.workflowId());
//
// if (annotation.parameters().length > 0) {
// Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
// if (!paramMap.isEmpty()) {
// builder.parameters(paramMap);
// }
// }
//
// setIfPresent(builder::appID, annotation.appId());
//
// setIfPresent(builder::botID, annotation.botId());
//
// if (annotation.ext().length > 0) {
// Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
// if (!extMap.isEmpty()) {
// builder.ext(extMap);
// }
// }
//
// setIfPresent(builder::connectTimeout, annotation.connectTimeout());
//
// setIfPresent(builder::readTimeout, annotation.readTimeout());
//
// setIfPresent(builder::writeTimeout, annotation.writeTimeout());
//
// setIfPresent(builder::customerToken, annotation.customerToken());
//
// WorkflowRunService runs = api.workflows().runs();
//
// if (annotation.stream()) {
// Flowable<WorkflowEvent> res = runs.stream(builder.build());
// InteractContext context = new InteractContext();
// res.blockingForEach(chunk -> {
// ChatContext chatContext = processorContext.getChatContext();
// chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
// });
// return null;
// } else {
// return runs.create(builder.build());
// }
}
@Override

View File

@@ -4,7 +4,7 @@ import com.alibaba.dashscope.app.FlowStreamMode;
import com.google.gson.JsonObject;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
import java.lang.annotation.*;
import java.util.List;

View File

@@ -1,26 +1,11 @@
package com.yomahub.liteflow.ai.workflow.dashscope.invocation;
import cn.hutool.core.util.StrUtil;
import com.alibaba.dashscope.app.Application;
import com.alibaba.dashscope.app.ApplicationParam;
import com.alibaba.dashscope.app.ApplicationResult;
import com.alibaba.dashscope.app.RagOptions;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.JsonUtils;
import com.google.gson.JsonObject;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
import com.yomahub.liteflow.ai.parse.context.ContextAccessor;
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler;
import com.yomahub.liteflow.ai.util.KeyValue;
import com.yomahub.liteflow.ai.util.SpringUtil;
import com.yomahub.liteflow.ai.workflow.dashscope.annotation.DashScopeWorkflow;
import com.yomahub.liteflow.ai.workflow.dashscope.config.DashScopeWorkflowProperty;
import com.yomahub.liteflow.ai.workflow.dashscope.wrap.DashScopeWorkflowProxyWrapBean;
import io.reactivex.Flowable;
import java.util.Arrays;
import java.util.List;
@@ -46,95 +31,96 @@ public class DashScopeWorkflowInvocationHandler extends AbstractAIInvocationHand
@Override
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
DashScopeWorkflow dashScopeWorkflow = wrapBean.getAnnotation();
DashScopeWorkflowProperty property = SpringUtil.getBean(DashScopeWorkflowProperty.class);
ApplicationParam.ApplicationParamBuilder<?, ?> builder = ApplicationParam.builder();
builder.appId(dashScopeWorkflow.appId());
builder.apiKey(property.getApiKey());
setIfPresent(builder::prompt, dashScopeWorkflow.prompt(), processorContext);
String history = dashScopeWorkflow.history();
setIfPresent(builder::history, StrUtil.isNotBlank(history) ?
ContextAccessor.searchContextByExpression(history, processorContext) : null);
String messages = dashScopeWorkflow.messages();
setIfPresent(builder::messages, StrUtil.isNotBlank(messages) ?
ContextAccessor.searchContextByExpression(messages, processorContext) : null);
setIfPresent(builder::sessionId, dashScopeWorkflow.sessionId());
setIfPresent(builder::hasThoughts, dashScopeWorkflow.hasThoughts());
setIfPresent(builder::bizParams, dashScopeWorkflow.bizParams(),
processorContext, JsonObject.class, JsonUtils::parse);
setIfPresent(builder::topP, dashScopeWorkflow.topP());
setIfPresent(builder::topK, dashScopeWorkflow.topK());
setIfPresent(builder::seed, dashScopeWorkflow.seed());
setIfPresent(builder::temperature, dashScopeWorkflow.temperature());
setIfPresent(builder::incrementalOutput, dashScopeWorkflow.incrementalOutput());
setIfPresent(builder::memoryId, dashScopeWorkflow.memoryId());
setIfPresent(builder::images, dashScopeWorkflow.images(), processorContext, List.class,
s -> Arrays.stream(s.split(COMMA_SPLITTER))
.map(String::trim)
.collect(Collectors.toList()));
setIfPresent(builder::ragOptions, buildRagOptions(dashScopeWorkflow.ragOptions(), processorContext));
setIfPresent(builder::mcpServers, dashScopeWorkflow.mcpServers(), processorContext, List.class,
s -> Arrays.stream(s.split(COMMA_SPLITTER))
.map(String::trim)
.collect(Collectors.toList()));
setIfPresent(builder::enableWebSearch, dashScopeWorkflow.enableWebSearch());
setIfPresent(builder::enableSystemTime, dashScopeWorkflow.enableSystemTime());
setIfPresent(builder::enablePremium, dashScopeWorkflow.enablePremium());
setIfPresent(builder::dialogRound, dashScopeWorkflow.dialogRound());
setIfPresent(builder::modelId, dashScopeWorkflow.modelId());
setIfPresent(builder::flowStreamMode, dashScopeWorkflow.flowStreamMode());
setIfPresent(builder::enableThinking, dashScopeWorkflow.enableThinking());
ApplicationParam applicationParam = builder.build();
Application application = StrUtil.isNotBlank(property.getApiUrl()) ?
new Application(property.getApiUrl()) :
new Application();
if (dashScopeWorkflow.stream()) {
try {
Flowable<ApplicationResult> res = application.streamCall(applicationParam);
InteractContext context = new InteractContext();
res.blockingForEach(chunk -> {
ChatContext chatContext = processorContext.getChatContext();
chatContext.getStreamHandler().onText(chunk.getOutput().getText(), context);
context.addText(chunk.getOutput().getText());
});
return null;
} catch (NoApiKeyException | InputRequiredException e) {
throw new LiteFlowAIException("DashScope stream call failed", e);
}
} else {
try {
return application.call(applicationParam);
} catch (NoApiKeyException | InputRequiredException e) {
throw new LiteFlowAIException("DashScope call failed", e);
}
}
return null;
// DashScopeWorkflow dashScopeWorkflow = wrapBean.getAnnotation();
//
// DashScopeWorkflowProperty property = SpringUtil.getBean(DashScopeWorkflowProperty.class);
//
// ApplicationParam.ApplicationParamBuilder<?, ?> builder = ApplicationParam.builder();
//
// builder.appId(dashScopeWorkflow.appId());
// builder.apiKey(property.getApiKey());
//
// setIfPresent(builder::prompt, dashScopeWorkflow.prompt(), processorContext);
//
// String history = dashScopeWorkflow.history();
// setIfPresent(builder::history, StrUtil.isNotBlank(history) ?
// ContextAccessor.searchContextByExpression(history, processorContext) : null);
//
// String messages = dashScopeWorkflow.messages();
// setIfPresent(builder::messages, StrUtil.isNotBlank(messages) ?
// ContextAccessor.searchContextByExpression(messages, processorContext) : null);
//
// setIfPresent(builder::sessionId, dashScopeWorkflow.sessionId());
//
// setIfPresent(builder::hasThoughts, dashScopeWorkflow.hasThoughts());
//
// setIfPresent(builder::bizParams, dashScopeWorkflow.bizParams(),
// processorContext, JsonObject.class, JsonUtils::parse);
//
// setIfPresent(builder::topP, dashScopeWorkflow.topP());
//
// setIfPresent(builder::topK, dashScopeWorkflow.topK());
//
// setIfPresent(builder::seed, dashScopeWorkflow.seed());
//
// setIfPresent(builder::temperature, dashScopeWorkflow.temperature());
//
// setIfPresent(builder::incrementalOutput, dashScopeWorkflow.incrementalOutput());
//
// setIfPresent(builder::memoryId, dashScopeWorkflow.memoryId());
//
// setIfPresent(builder::images, dashScopeWorkflow.images(), processorContext, List.class,
// s -> Arrays.stream(s.split(COMMA_SPLITTER))
// .map(String::trim)
// .collect(Collectors.toList()));
//
// setIfPresent(builder::ragOptions, buildRagOptions(dashScopeWorkflow.ragOptions(), processorContext));
//
// setIfPresent(builder::mcpServers, dashScopeWorkflow.mcpServers(), processorContext, List.class,
// s -> Arrays.stream(s.split(COMMA_SPLITTER))
// .map(String::trim)
// .collect(Collectors.toList()));
//
// setIfPresent(builder::enableWebSearch, dashScopeWorkflow.enableWebSearch());
//
// setIfPresent(builder::enableSystemTime, dashScopeWorkflow.enableSystemTime());
//
// setIfPresent(builder::enablePremium, dashScopeWorkflow.enablePremium());
//
// setIfPresent(builder::dialogRound, dashScopeWorkflow.dialogRound());
//
// setIfPresent(builder::modelId, dashScopeWorkflow.modelId());
//
// setIfPresent(builder::flowStreamMode, dashScopeWorkflow.flowStreamMode());
//
// setIfPresent(builder::enableThinking, dashScopeWorkflow.enableThinking());
//
// ApplicationParam applicationParam = builder.build();
// Application application = StrUtil.isNotBlank(property.getApiUrl()) ?
// new Application(property.getApiUrl()) :
// new Application();
//
// if (dashScopeWorkflow.stream()) {
// try {
// Flowable<ApplicationResult> res = application.streamCall(applicationParam);
// InteractContext context = new InteractContext();
// res.blockingForEach(chunk -> {
// ChatContext chatContext = processorContext.getChatContext();
// chatContext.getStreamHandler().onText(chunk.getOutput().getText(), context);
// context.addText(chunk.getOutput().getText());
// });
// return null;
// } catch (NoApiKeyException | InputRequiredException e) {
// throw new LiteFlowAIException("DashScope stream call failed", e);
// }
// } else {
// try {
// return application.call(applicationParam);
// } catch (NoApiKeyException | InputRequiredException e) {
// throw new LiteFlowAIException("DashScope call failed", e);
// }
// }
}
@SuppressWarnings("unchecked")

View File

@@ -27,6 +27,7 @@
<victools.version>4.36.0</victools.version>
<dashscope-sdk.version>2.21.5</dashscope-sdk.version>
<coze-api.version>0.4.2</coze-api.version>
<rxjava.version>3.1.8</rxjava.version>
</properties>
<dependencyManagement>
@@ -72,6 +73,12 @@
<artifactId>coze-api</artifactId>
<version>${coze-api.version}</version>
</dependency>
<dependency>
<groupId>io.reactivex.rxjava3</groupId>
<artifactId>rxjava</artifactId>
<version>${rxjava.version}</version>
</dependency>
</dependencies>
</dependencyManagement>
</project>

View File

@@ -3,6 +3,8 @@ package com.yomahub.liteflow.test.ai.core.chat;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.core.FlowExecutor;
@@ -73,46 +75,47 @@ public class ChatTest extends MockAITest {
}
private StreamHandler getStreamHandler() {
return StreamHandler.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
return eventStream -> eventStream.doOnNext(chunkEvent -> {
if (chunkEvent.isStart()) {
System.out.println("chat start");
}
if (chunkEvent.isChunk()) {
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
switch (chunk.getType()) {
case TEXT:
System.out.println(chunk.getData());
break;
case THINKING:
System.out.println("[Thinking] " + chunk.getData());
}
}
if (chunkEvent.isComplete()) {
ChatResponse response = chunkEvent.getFinalResponse();
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
})
.build();
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
System.out.println("chat close");
}
});
}
}

View File

@@ -3,6 +3,8 @@ package com.yomahub.liteflow.test.ai.core.structure;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.core.FlowExecutor;
@@ -80,45 +82,46 @@ public class StructureTest extends MockAITest {
}
private StreamHandler getStreamHandler() {
return StreamHandler.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
return eventStream -> eventStream.doOnNext(chunkEvent -> {
if (chunkEvent.isStart()) {
System.out.println("chat start");
}
if (chunkEvent.isChunk()) {
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
switch (chunk.getType()) {
case TEXT:
System.out.println(chunk.getData());
break;
case THINKING:
System.out.println("[Thinking] " + chunk.getData());
}
}
if (chunkEvent.isComplete()) {
ChatResponse response = chunkEvent.getFinalResponse();
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
})
.build();
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
System.out.println("chat close");
}
});
}
}

View File

@@ -3,6 +3,8 @@ package com.yomahub.liteflow.test.ai.core.tool;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
@@ -98,78 +100,56 @@ public class ToolTest extends MockAITest {
private ChatContext buildBlockingChatContext() {
return ChatContext.builder()
.streamHandler(StreamHandler.builder()
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
})
.build())
.build();
}
private ChatContext buildStreamingChatContext() {
return ChatContext.builder()
.streamHandler(StreamHandler.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
})
.build())
.streamHandler(getStreamHandler())
.build();
}
private StreamHandler getStreamHandler() {
return eventStream -> eventStream.doOnNext(chunkEvent -> {
if (chunkEvent.isStart()) {
System.out.println("chat start");
}
if (chunkEvent.isChunk()) {
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
switch (chunk.getType()) {
case TEXT:
System.out.println(chunk.getData());
break;
case THINKING:
System.out.println("[Thinking] " + chunk.getData());
}
}
if (chunkEvent.isComplete()) {
ChatResponse response = chunkEvent.getFinalResponse();
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
System.out.println("chat close");
}
});
}
}

View File

@@ -1,28 +1,5 @@
package com.yomahub.liteflow.test.ai.engine.interact;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.dashscope.interact.DashScopeProtocolTransformer;
import com.yomahub.liteflow.ai.model.ollama.interact.OllamaProtocolTransformer;
import com.yomahub.liteflow.ai.model.openai.interact.OpenAIProtocolTransformer;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
* 交互处理器测试类
* 测试各种协议转换器和流式处理管道的功能
@@ -31,368 +8,327 @@ import static org.junit.jupiter.api.Assertions.*;
* @since 2.16.0
*/
public class InteractTest {
private ChunkProcessPipeline pipeline;
private ProtocolTransformer protocolTransformer;
private ChunkCallbackTransformer chunkCallbackTransformer;
private InteractContext context;
/**
* 创建协议转换器的工厂方法
*
* @param provider 提供者枚举
* @return 对应的协议转换器实例
*/
private ProtocolTransformer createProtocolTransformer(ProviderEnum provider) {
switch (provider) {
case OPENAI:
return new OpenAIProtocolTransformer();
case DASHSCOPE:
return new DashScopeProtocolTransformer();
case OLLAMA:
return new OllamaProtocolTransformer();
default:
throw new IllegalArgumentException("不支持的提供者: " + provider);
}
}
/**
* 简单的回调处理器,用于测试
*/
private static class TestChunkCallbackTransformer implements ChunkCallbackTransformer {
@Override
public String onText(String content, InteractContext context) {
System.out.println("onText: " + content);
return content;
}
@Override
public String onThinking(String content, InteractContext context) {
System.out.println("onThinking: " + content);
return content;
}
@Override
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
// 测试用的简单实现,直接返回原工具调用
return toolCalls;
}
@Override
public Object onUsage(Object content, InteractContext context) {
// 测试用的简单实现,直接返回原使用统计
return content;
}
@Override
public Object onGrounding(Object content, InteractContext context) {
// 测试用的简单实现,直接返回原基础信息
return content;
}
}
@Nested
@DisplayName("OpenAI协议转换器测试")
class OpenAITests {
@BeforeEach
void setUp() {
context = new InteractContext();
protocolTransformer = createProtocolTransformer(ProviderEnum.OPENAI);
chunkCallbackTransformer = new TestChunkCallbackTransformer();
}
@Test
@DisplayName("测试OpenAI流式文本响应转换")
void testStreamingTextResponse() {
testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT, false);
}
@Test
@DisplayName("测试OpenAI流式工具调用响应转换")
void testStreamingToolCallResponse() {
testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
}
@Test
@DisplayName("测试OpenAI流式结构化响应转换")
void testStreamingStructuredResponse() {
testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
}
@Test
@DisplayName("测试OpenAI阻塞式文本响应转换")
void testBlockingTextResponse() {
testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TEXT, false);
}
@Test
@DisplayName("测试OpenAI阻塞式工具调用响应转换")
void testBlockingToolCallResponse() {
testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
}
@Test
@DisplayName("测试OpenAI阻塞式结构化响应转换")
void testBlockingStructuredResponse() {
testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
}
}
@Nested
@DisplayName("DashScope协议转换器测试")
class DashScopeTests {
@BeforeEach
void setUp() {
context = new InteractContext();
protocolTransformer = createProtocolTransformer(ProviderEnum.DASHSCOPE);
chunkCallbackTransformer = new TestChunkCallbackTransformer();
}
@Test
@DisplayName("测试DashScope流式文本响应转换")
void testStreamingTextResponse() {
testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT, false);
}
@Test
@DisplayName("测试DashScope流式工具调用响应转换")
void testStreamingToolCallResponse() {
testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
}
@Test
@DisplayName("测试DashScope流式结构化响应转换")
void testStreamingStructuredResponse() {
testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
}
@Test
@DisplayName("测试DashScope阻塞式文本响应转换")
void testBlockingTextResponse() {
testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TEXT, false);
}
@Test
@DisplayName("测试DashScope阻塞式工具调用响应转换")
void testBlockingToolCallResponse() {
testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
}
@Test
@DisplayName("测试DashScope阻塞式结构化响应转换")
void testBlockingStructuredResponse() {
testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
}
}
@Nested
@DisplayName("Ollama协议转换器测试")
class OllamaTests {
@BeforeEach
void setUp() {
context = new InteractContext();
protocolTransformer = createProtocolTransformer(ProviderEnum.OLLAMA);
chunkCallbackTransformer = new TestChunkCallbackTransformer();
}
@Test
@DisplayName("测试Ollama流式文本响应转换")
void testStreamingTextResponse() {
testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT, false);
}
@Test
@DisplayName("测试Ollama流式工具调用响应转换")
void testStreamingToolCallResponse() {
testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
}
@Test
@DisplayName("测试Ollama流式结构化响应转换")
void testStreamingStructuredResponse() {
testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
}
@Test
@DisplayName("测试Ollama阻塞式文本响应转换")
void testBlockingTextResponse() {
testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TEXT, false);
}
@Test
@DisplayName("测试Ollama阻塞式工具调用响应转换")
void testBlockingToolCallResponse() {
testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
}
@Test
@DisplayName("测试Ollama阻塞式结构化响应转换")
void testBlockingStructuredResponse() {
testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
}
}
@Nested
@DisplayName("参数化测试 - 所有协议转换器")
class ParameterizedTests {
@ParameterizedTest
@EnumSource(value = ProviderEnum.class, names = { "OPENAI", "DASHSCOPE", "OLLAMA" })
@DisplayName("参数化测试协议转换器提供者名称")
void testProtocolTransformerProviderEnumName(ProviderEnum provider) {
ProtocolTransformer transformer = createProtocolTransformer(provider);
assertEquals(provider.getProviderName(), transformer.getProviderName(),
provider.getProviderName() + "协议转换器的提供者名称应为'" + provider.getProviderName() + "'");
}
}
/**
* 测试流式响应处理的通用方法
*
* @param provider 提供者
* @param requestType 响应类型
* @param hasToolCalls 是否包含工具调用
*/
private void testStreamingResponse(ProviderEnum provider, TestDataReader.RequestType requestType, boolean hasToolCalls) {
// 重新构建pipeline
context = new InteractContext();
protocolTransformer = createProtocolTransformer(provider);
chunkCallbackTransformer = new TestChunkCallbackTransformer();
pipeline = ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, chunkCallbackTransformer);
// 读取测试数据
List<String> streamChunks = TestDataReader.getStreamingChunks(provider, requestType);
assertFalse(streamChunks.isEmpty(), provider.getProviderName() + "流式" + requestType.getRequestName() + "响应数据不应为空");
// 处理流式数据
for (String streamChunk : streamChunks) {
pipeline.processStreaming(streamChunk);
// 注意这里不强制要求最后一次返回true因为不同模型的结束标志可能不同
}
// 获取最终响应
ChatResponse chatResponse = pipeline.buildFinalStreamingResponse();
// 验证响应
assertNotNull(chatResponse, "聊天响应不应为null");
assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
// 验证工具调用
if (hasToolCalls) {
assertTrue(chatResponse.hasToolCalls(), requestType.getRequestName() + "响应应包含工具调用");
assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
// 验证工具调用内容
for (ToolCall toolCall : message.getToolCalls()) {
if (!provider.equals(ProviderEnum.OLLAMA)) {
assertNotNull(toolCall.getId(), "工具调用ID不应为null");
}
assertNotNull(toolCall.getName(), "工具调用名称不应为null");
assertNotNull(toolCall.getType(), "工具调用类型不应为null");
}
} else {
// 对于非工具调用响应,验证内容存在
assertNotNull(message.getContent(), "消息内容不应为null");
// 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
}
// 打印测试结果
System.out.println("=== " + provider.getProviderName() + " 流式" + requestType.getRequestName() + "响应测试结果 ===");
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println(
"内容: " + (message.getContent().length() > 100 ? message.getContent().substring(0, 100) + "..."
: message.getContent()));
}
if (hasToolCalls && message.getToolCalls() != null) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
}
System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
System.out.println("完成原因: " + chatResponse.getFinishReason());
}
/**
* 测试阻塞式响应处理的通用方法
*
* @param provider 提供者
* @param responseType 响应类型
* @param hasToolCalls 是否包含工具调用
*/
private void testBlockingResponse(ProviderEnum provider, TestDataReader.RequestType responseType, boolean hasToolCalls) {
// 重新构建pipeline
context = new InteractContext();
protocolTransformer = createProtocolTransformer(provider);
pipeline = ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer);
// 读取测试数据
String blockingResponse = TestDataReader.getBlockingResponse(provider, responseType);
assertNotNull(blockingResponse, provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为null");
assertFalse(blockingResponse.trim().isEmpty(), provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为空");
// 处理阻塞式响应
ChatResponse chatResponse = pipeline.processBlocking(blockingResponse);
// 验证响应
assertNotNull(chatResponse, "聊天响应不应为null");
assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
// 验证工具调用
if (hasToolCalls) {
assertTrue(chatResponse.hasToolCalls(), responseType.getRequestName() + "响应应包含工具调用");
assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
// 验证每个工具调用
for (ToolCall toolCall : message.getToolCalls()) {
if (!provider.equals(ProviderEnum.OLLAMA)) {
assertNotNull(toolCall.getId(), "工具调用ID不应为null");
}
assertNotNull(toolCall.getName(), "工具调用名称不应为null");
assertNotNull(toolCall.getType(), "工具调用类型不应为null");
assertEquals("function", toolCall.getType(), "工具调用类型应为function");
}
} else {
// 对于非工具调用响应,验证内容存在
assertNotNull(message.getContent(), "消息内容不应为null");
// 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
}
// 验证token使用情况如果存在
if (chatResponse.getTokenUsage() != null) {
assertTrue(chatResponse.getTokenUsage().getTotalTokenCount() > 0, "总token数应大于0");
}
// 打印测试结果
System.out.println("=== " + provider.getProviderName() + " 阻塞式" + responseType.getRequestName() + "响应测试结果 ===");
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (hasToolCalls && message.getToolCalls() != null) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
System.out.println("完成原因: " + chatResponse.getFinishReason());
}
//
// private ProtocolTransformer protocolTransformer;
// private InteractContext context;
//
// /**
// * 创建协议转换器的工厂方法
// *
// * @param provider 提供者枚举
// * @return 对应的协议转换器实例
// */
// private ProtocolTransformer createProtocolTransformer(ProviderEnum provider) {
// switch (provider) {
// case OPENAI:
// return new OpenAIProtocolTransformer();
// case DASHSCOPE:
// return new DashScopeProtocolTransformer();
// case OLLAMA:
// return new OllamaProtocolTransformer();
// default:
// throw new IllegalArgumentException("不支持的提供者: " + provider);
// }
// }
//
// @Nested
// @DisplayName("OpenAI协议转换器测试")
// class OpenAITests {
//
// @BeforeEach
// void setUp() {
// context = new InteractContext();
// protocolTransformer = createProtocolTransformer(ProviderEnum.OPENAI);
// }
//
// @Test
// @DisplayName("测试OpenAI流式文本响应转换")
// void testStreamingTextResponse() {
// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT, false);
// }
//
// @Test
// @DisplayName("测试OpenAI流式工具调用响应转换")
// void testStreamingToolCallResponse() {
// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
// }
//
// @Test
// @DisplayName("测试OpenAI流式结构化响应转换")
// void testStreamingStructuredResponse() {
// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
// }
//
// @Test
// @DisplayName("测试OpenAI阻塞式文本响应转换")
// void testBlockingTextResponse() {
// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TEXT, false);
// }
//
// @Test
// @DisplayName("测试OpenAI阻塞式工具调用响应转换")
// void testBlockingToolCallResponse() {
// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
// }
//
// @Test
// @DisplayName("测试OpenAI阻塞式结构化响应转换")
// void testBlockingStructuredResponse() {
// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
// }
// }
//
// @Nested
// @DisplayName("DashScope协议转换器测试")
// class DashScopeTests {
//
// @BeforeEach
// void setUp() {
// context = new InteractContext();
// protocolTransformer = createProtocolTransformer(ProviderEnum.DASHSCOPE);
// }
//
// @Test
// @DisplayName("测试DashScope流式文本响应转换")
// void testStreamingTextResponse() {
// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT, false);
// }
//
// @Test
// @DisplayName("测试DashScope流式工具调用响应转换")
// void testStreamingToolCallResponse() {
// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
// }
//
// @Test
// @DisplayName("测试DashScope流式结构化响应转换")
// void testStreamingStructuredResponse() {
// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
// }
//
// @Test
// @DisplayName("测试DashScope阻塞式文本响应转换")
// void testBlockingTextResponse() {
// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TEXT, false);
// }
//
// @Test
// @DisplayName("测试DashScope阻塞式工具调用响应转换")
// void testBlockingToolCallResponse() {
// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
// }
//
// @Test
// @DisplayName("测试DashScope阻塞式结构化响应转换")
// void testBlockingStructuredResponse() {
// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
// }
// }
//
// @Nested
// @DisplayName("Ollama协议转换器测试")
// class OllamaTests {
//
// @BeforeEach
// void setUp() {
// context = new InteractContext();
// protocolTransformer = createProtocolTransformer(ProviderEnum.OLLAMA);
// }
//
// @Test
// @DisplayName("测试Ollama流式文本响应转换")
// void testStreamingTextResponse() {
// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT, false);
// }
//
// @Test
// @DisplayName("测试Ollama流式工具调用响应转换")
// void testStreamingToolCallResponse() {
// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
// }
//
// @Test
// @DisplayName("测试Ollama流式结构化响应转换")
// void testStreamingStructuredResponse() {
// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
// }
//
// @Test
// @DisplayName("测试Ollama阻塞式文本响应转换")
// void testBlockingTextResponse() {
// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TEXT, false);
// }
//
// @Test
// @DisplayName("测试Ollama阻塞式工具调用响应转换")
// void testBlockingToolCallResponse() {
// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
// }
//
// @Test
// @DisplayName("测试Ollama阻塞式结构化响应转换")
// void testBlockingStructuredResponse() {
// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
// }
// }
//
// @Nested
// @DisplayName("参数化测试 - 所有协议转换器")
// class ParameterizedTests {
//
// @ParameterizedTest
// @EnumSource(value = ProviderEnum.class, names = { "OPENAI", "DASHSCOPE", "OLLAMA" })
// @DisplayName("参数化测试协议转换器提供者名称")
// void testProtocolTransformerProviderEnumName(ProviderEnum provider) {
// ProtocolTransformer transformer = createProtocolTransformer(provider);
// assertEquals(provider.getProviderName(), transformer.getProviderName(),
// provider.getProviderName() + "协议转换器的提供者名称应为'" + provider.getProviderName() + "'");
// }
// }
//
// /**
// * 测试流式响应处理的通用方法
// *
// * @param provider 提供者
// * @param requestType 响应类型
// * @param hasToolCalls 是否包含工具调用
// */
// private void testStreamingResponse(ProviderEnum provider, TestDataReader.RequestType requestType, boolean hasToolCalls) {
// // 重新构建pipeline
// context = new InteractContext();
// protocolTransformer = createProtocolTransformer(provider);
// pipeline = ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, chunkCallbackTransformer);
//
// // 读取测试数据
// List<String> streamChunks = TestDataReader.getStreamingChunks(provider, requestType);
// assertFalse(streamChunks.isEmpty(), provider.getProviderName() + "流式" + requestType.getRequestName() + "响应数据不应为空");
//
// // 处理流式数据
// for (String streamChunk : streamChunks) {
// pipeline.processStreaming(streamChunk);
// // 注意这里不强制要求最后一次返回true因为不同模型的结束标志可能不同
// }
//
// // 获取最终响应
// ChatResponse chatResponse = pipeline.buildFinalStreamingResponse();
//
// // 验证响应
// assertNotNull(chatResponse, "聊天响应不应为null");
// assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
// assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
//
// AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
//
// // 验证工具调用
// if (hasToolCalls) {
// assertTrue(chatResponse.hasToolCalls(), requestType.getRequestName() + "响应应包含工具调用");
// assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
// assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
//
// // 验证工具调用内容
// for (ToolCall toolCall : message.getToolCalls()) {
// if (!provider.equals(ProviderEnum.OLLAMA)) {
// assertNotNull(toolCall.getId(), "工具调用ID不应为null");
// }
// assertNotNull(toolCall.getName(), "工具调用名称不应为null");
// assertNotNull(toolCall.getType(), "工具调用类型不应为null");
// }
// } else {
// // 对于非工具调用响应,验证内容存在
// assertNotNull(message.getContent(), "消息内容不应为null");
// // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
// }
//
// // 打印测试结果
// System.out.println("=== " + provider.getProviderName() + " 流式" + requestType.getRequestName() + "响应测试结果 ===");
// if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
// System.out.println(
// "内容: " + (message.getContent().length() > 100 ? message.getContent().substring(0, 100) + "..."
// : message.getContent()));
// }
// if (hasToolCalls && message.getToolCalls() != null) {
// System.out.println("工具调用数量: " + message.getToolCalls().size());
// }
// System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
// System.out.println("完成原因: " + chatResponse.getFinishReason());
// }
//
// /**
// * 测试阻塞式响应处理的通用方法
// *
// * @param provider 提供者
// * @param responseType 响应类型
// * @param hasToolCalls 是否包含工具调用
// */
// private void testBlockingResponse(ProviderEnum provider, TestDataReader.RequestType responseType, boolean hasToolCalls) {
// // 重新构建pipeline
// context = new InteractContext();
// protocolTransformer = createProtocolTransformer(provider);
// pipeline = ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer);
//
// // 读取测试数据
// String blockingResponse = TestDataReader.getBlockingResponse(provider, responseType);
// assertNotNull(blockingResponse, provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为null");
// assertFalse(blockingResponse.trim().isEmpty(), provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为空");
//
// // 处理阻塞式响应
// ChatResponse chatResponse = pipeline.processBlocking(blockingResponse);
//
// // 验证响应
// assertNotNull(chatResponse, "聊天响应不应为null");
// assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
// assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
//
// AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
//
// // 验证工具调用
// if (hasToolCalls) {
// assertTrue(chatResponse.hasToolCalls(), responseType.getRequestName() + "响应应包含工具调用");
// assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
// assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
//
// // 验证每个工具调用
// for (ToolCall toolCall : message.getToolCalls()) {
// if (!provider.equals(ProviderEnum.OLLAMA)) {
// assertNotNull(toolCall.getId(), "工具调用ID不应为null");
// }
// assertNotNull(toolCall.getName(), "工具调用名称不应为null");
// assertNotNull(toolCall.getType(), "工具调用类型不应为null");
// assertEquals("function", toolCall.getType(), "工具调用类型应为function");
// }
// } else {
// // 对于非工具调用响应,验证内容存在
// assertNotNull(message.getContent(), "消息内容不应为null");
// // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
// }
//
// // 验证token使用情况如果存在
// if (chatResponse.getTokenUsage() != null) {
// assertTrue(chatResponse.getTokenUsage().getTotalTokenCount() > 0, "总token数应大于0");
// }
//
// // 打印测试结果
// System.out.println("=== " + provider.getProviderName() + " 阻塞式" + responseType.getRequestName() + "响应测试结果 ===");
// if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
// System.out.println("内容长度: " + message.getContent().length());
// if (message.getContent().length() > 200) {
// System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
// } else {
// System.out.println("内容: " + message.getContent());
// }
// }
// if (hasToolCalls && message.getToolCalls() != null) {
// System.out.println("工具调用数量: " + message.getToolCalls().size());
// for (int i = 0; i < message.getToolCalls().size(); i++) {
// ToolCall toolCall = message.getToolCalls().get(i);
// System.out.println("工具调用 " + (i + 1) + ":");
// System.out.println(" ID: " + toolCall.getId());
// System.out.println(" 名称: " + toolCall.getName());
// System.out.println(" 类型: " + toolCall.getType());
// System.out.println(" 参数: " + toolCall.getArguments());
// }
// }
// System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
// System.out.println("完成原因: " + chatResponse.getFinishReason());
// }
}

View File

@@ -1,10 +1,12 @@
package com.yomahub.liteflow.test.ai.mock.mockbean;
import com.yomahub.liteflow.ai.engine.interact.LlmInteractClient;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import io.reactivex.rxjava3.core.Flowable;
import org.mockito.Mockito;
import java.util.concurrent.CompletableFuture;
@@ -51,14 +53,14 @@ public class MockInteractClient extends LlmInteractClient {
* 重写 stream 方法
*/
@Override
public void stream(ChatConfig config, ChatRequest request) {
public Flowable<ChunkEvent> stream(ChatConfig config, ChatRequest request) {
// 创建一个 spied request它在内部被配置为使用 MockTransport
ChatRequest spiedRequest = createSpiedRequest(request);
// 调用父类的 stream 方法。
// 当父类的 InteractManager 构造函数被调用时,
// 它将使用我们的 spiedRequest并最终获取到 MockTransport。
super.stream(config, spiedRequest);
return super.stream(config, spiedRequest);
}
/**

View File

@@ -1,12 +1,10 @@
package com.yomahub.liteflow.test.ai.mock.mockbean;
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import io.reactivex.rxjava3.core.Flowable;
import java.util.List;
@@ -19,43 +17,29 @@ import java.util.List;
public class MockTransport implements Transport {
private final String blockingResponse;
private final List<String> streamingChunks;
private final MockConfig config;
private ChunkProcessPipeline pipeline;
private TransportListener listener = TransportListener.getDefault();
private boolean isStop = false;
public MockTransport(MockConfig config) {
TestDataReader.RequestType curRequestType = config.getRequestType();
this.blockingResponse = TestDataReader.getBlockingResponse(config.getProvider(), curRequestType);
this.streamingChunks = TestDataReader.getStreamingChunks(config.getProvider(), curRequestType);
this.config = config;
}
@Override
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
this.listener = listener;
this.pipeline = pipeline;
listener.onStart(pipeline.getContext());
for (String streamingChunk : streamingChunks) {
pipeline.processStreaming(streamingChunk);
}
close();
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
List<String> streamingChunks = TestDataReader.getStreamingChunks(this.config.getProvider(), this.config.getRequestType());
return Flowable.fromIterable(streamingChunks);
}
@Override
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
this.pipeline = pipeline;
return pipeline.processBlocking(blockingResponse);
public String startBlocking(ChatConfig config, ChatRequest request) {
return TestDataReader.getBlockingResponse(this.config.getProvider(), this.config.getRequestType());
}
@Override
public void close() {
if (!this.isStop) {
this.isStop = true;
this.listener.onClose(this.pipeline.getContext());
}
}
}

View File

@@ -1,24 +1,27 @@
package com.yomahub.liteflow.test.ai.model.chat.dashscope;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatModel;
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatRequest;
import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* 阿里百炼 chat 测试
@@ -58,7 +61,7 @@ public class DashScopeChatTest extends MockAITest {
}
@Test
public void testStreaming() throws ExecutionException, InterruptedException {
public void testStreaming() {
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT);
List<Message> messages = Arrays.asList(
@@ -66,21 +69,17 @@ public class DashScopeChatTest extends MockAITest {
new UserMessage("请给我讲一个关于宇宙探索的短故事")
);
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModel.stream(
Flowable<ChunkEvent> stream = chatModel.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -99,44 +98,6 @@ public class DashScopeChatTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = DashScopeChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = DashScopeChatRequest.builder();
}
}

View File

@@ -1,24 +1,27 @@
package com.yomahub.liteflow.test.ai.model.chat.ollama;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatModel;
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatRequest;
import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* Ollama chat 测试
@@ -57,7 +60,7 @@ public class OllamaChatTest extends MockAITest {
}
@Test
public void testStreaming() throws ExecutionException, InterruptedException {
public void testStreaming() {
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT);
List<Message> messages = Arrays.asList(
@@ -65,21 +68,17 @@ public class OllamaChatTest extends MockAITest {
new UserMessage("why is the sky blue?")
);
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModel.stream(
Flowable<ChunkEvent> stream = chatModel.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.DN_JSON)
.messages(messages)
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -96,44 +95,6 @@ public class OllamaChatTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = OllamaChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = OllamaChatRequest.builder();
}
}

View File

@@ -1,23 +1,27 @@
package com.yomahub.liteflow.test.ai.model.chat.openai;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatModel;
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatRequest;
import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
@@ -61,20 +65,17 @@ public class OpenAIChatTest extends MockAITest {
new SystemMessage("You are a helpful assistant."),
new UserMessage("请给我讲一个关于未来城市的短故事"));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModel.stream(
Flowable<ChunkEvent> stream = chatModel.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -92,44 +93,6 @@ public class OpenAIChatTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = OpenAIChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = OpenAIChatRequest.builder();
}
}

View File

@@ -1,26 +1,29 @@
package com.yomahub.liteflow.test.ai.model.structure.dashscope;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatModel;
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatRequest;
import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import com.yomahub.liteflow.test.ai.model.structure.output.MathReasoning;
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* 阿里百炼 结构化输出 测试
@@ -70,7 +73,7 @@ public class DashScopeStructureTest extends MockAITest {
}
@Test
public void testStructureStreaming() throws ExecutionException, InterruptedException {
public void testStructureStreaming() {
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED);
List<Message> messages = Arrays.asList(
@@ -78,9 +81,7 @@ public class DashScopeStructureTest extends MockAITest {
new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1")
);
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModel.stream(
Flowable<ChunkEvent> stream = chatModel.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
@@ -89,14 +90,11 @@ public class DashScopeStructureTest extends MockAITest {
.responseType(ResponseType.JSON)
.targetType(MathReasoning.class)
// 结构化输出相关配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build()
);
.build());
ChatResponse response = future.get();
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
.blockingLast()
.getFinalResponse();
// 将响应转换为结构化结果对象
MathReasoning result = response.as(MathReasoning.class);
@@ -121,44 +119,6 @@ public class DashScopeStructureTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = DashScopeChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = DashScopeChatRequest.builder();
}
}

View File

@@ -1,26 +1,29 @@
package com.yomahub.liteflow.test.ai.model.structure.ollama;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatModel;
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatRequest;
import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import com.yomahub.liteflow.test.ai.model.structure.output.MathReasoning;
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* Ollama 结构化输出 测试
@@ -68,16 +71,14 @@ public class OllamaStructureTest extends MockAITest {
}
@Test
public void testStructureStreaming() throws ExecutionException, InterruptedException {
public void testStructureStreaming() {
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED);
List<Message> messages = Arrays.asList(
new SystemMessage("你是一位数学辅导老师"),
new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1"));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModel.stream(
Flowable<ChunkEvent> stream = chatModel.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.DN_JSON)
@@ -86,15 +87,12 @@ public class OllamaStructureTest extends MockAITest {
.responseType(ResponseType.JSON)
.targetType(MathReasoning.class)
// 结构化输出相关配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
ChatResponse response = future.get();
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
.blockingLast()
.getFinalResponse();
// 将响应转换为结构化结果对象
MathReasoning result = response.as(MathReasoning.class);
Assertions.assertNotNull(result);
@@ -116,44 +114,6 @@ public class OllamaStructureTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = OllamaChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = OllamaChatRequest.builder();
}
}

View File

@@ -1,25 +1,29 @@
package com.yomahub.liteflow.test.ai.model.structure.openai;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatModel;
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatRequest;
import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import com.yomahub.liteflow.test.ai.model.structure.output.MathReasoning;
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
@@ -74,9 +78,7 @@ public class OpenAIStructureTest extends MockAITest {
new SystemMessage("你是一位数学辅导老师"),
new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1"));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModel.stream(
Flowable<ChunkEvent> stream = chatModel.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
@@ -85,13 +87,11 @@ public class OpenAIStructureTest extends MockAITest {
.responseType(ResponseType.JSON)
.targetType(MathReasoning.class)
// 结构化输出相关配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
ChatResponse response = future.get();
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
.blockingLast()
.getFinalResponse();
// 将响应转换为结构化结果对象
MathReasoning result = response.as(MathReasoning.class);
@@ -116,44 +116,6 @@ public class OpenAIStructureTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = OpenAIChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = OpenAIChatRequest.builder();
}
}

View File

@@ -1,12 +1,15 @@
package com.yomahub.liteflow.test.ai.model.tool.dashscope;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
@@ -16,6 +19,7 @@ import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockConfigHolder;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -23,9 +27,6 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* 阿里百炼 工具调用 测试
@@ -75,7 +76,7 @@ public class DashScopeToolTest extends MockAITest {
}
@Test
public void testToolCallStreamingWithAutoToolCall() throws ExecutionException, InterruptedException {
public void testToolCallStreamingWithAutoToolCall() {
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
List<Message> messages = new ArrayList<>();
@@ -87,24 +88,18 @@ public class DashScopeToolTest extends MockAITest {
Collections.singletonList(toolRegistry.getTool("assemble_tool"))
);
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModelWithAutoToolCall.stream(
Flowable<ChunkEvent> stream = chatModelWithAutoToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -158,13 +153,7 @@ public class DashScopeToolTest extends MockAITest {
Assertions.assertEquals("weather_tool", toolCall.getName());
// 执行 ToolCall
ToolCallBack toolCallBack = weatherTool.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
// 和 AI 消息一起添加回上下文
messages.add(response.getOutput());
@@ -190,7 +179,7 @@ public class DashScopeToolTest extends MockAITest {
}
@Test
public void testToolCallStreamingWithManualToolCall() throws ExecutionException, InterruptedException {
public void testToolCallStreamingWithManualToolCall() {
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL);
List<Message> messages = new ArrayList<>();
@@ -202,24 +191,18 @@ public class DashScopeToolTest extends MockAITest {
Collections.singletonList(toolRegistry.getTool("assemble_tool"))
);
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModelWithManualToolCall.stream(
Flowable<ChunkEvent> stream = chatModelWithManualToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.TOOL_CALL, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -231,13 +214,7 @@ public class DashScopeToolTest extends MockAITest {
Assertions.assertEquals("assemble_tool", toolCall.getName());
// 执行 ToolCall
ToolCallBack toolCallBack = assembleTool.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
// 和 AI 消息一起添加回上下文
messages.add(response.getOutput());
@@ -247,24 +224,18 @@ public class DashScopeToolTest extends MockAITest {
MockConfigHolder.clear();
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
final CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
chatModelWithManualToolCall.stream(
stream = chatModelWithManualToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future2.complete(chatResponse);
return chatResponse;
})
.build()
);
response = future2.get();
response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -311,44 +282,6 @@ public class DashScopeToolTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = DashScopeChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = DashScopeChatRequest.builder();
}
}

View File

@@ -1,12 +1,15 @@
package com.yomahub.liteflow.test.ai.model.tool.ollama;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
@@ -16,6 +19,7 @@ import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockConfigHolder;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -23,9 +27,6 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* Ollama 工具调用 测试
@@ -72,7 +73,7 @@ public class OllamaToolTest extends MockAITest {
}
@Test
public void testToolCallStreamingWithAutoToolCall() throws ExecutionException, InterruptedException {
public void testToolCallStreamingWithAutoToolCall() {
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL,
TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
@@ -83,23 +84,18 @@ public class OllamaToolTest extends MockAITest {
ToolRegistry assembleTool = new StaticToolRegistry(
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModelWithAutoToolCall.stream(
Flowable<ChunkEvent> stream = chatModelWithAutoToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.DN_JSON)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -137,13 +133,7 @@ public class OllamaToolTest extends MockAITest {
Assertions.assertEquals("weather_tool", toolCall.getName());
// 执行 ToolCall
ToolCallBack toolCallBack = weatherTool.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
// 和 AI 消息一起添加回上下文
messages.add(response.getOutput());
@@ -168,7 +158,7 @@ public class OllamaToolTest extends MockAITest {
}
@Test
public void testToolCallStreamingWithManualToolCall() throws ExecutionException, InterruptedException {
public void testToolCallStreamingWithManualToolCall() {
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL);
List<Message> messages = new ArrayList<>();
@@ -178,23 +168,18 @@ public class OllamaToolTest extends MockAITest {
ToolRegistry assembleTool = new StaticToolRegistry(
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModelWithManualToolCall.stream(
Flowable<ChunkEvent> stream = chatModelWithManualToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.DN_JSON)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.TOOL_CALL, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -206,13 +191,8 @@ public class OllamaToolTest extends MockAITest {
Assertions.assertEquals("assemble_tool", toolCall.getName());
// 执行 ToolCall
ToolCallBack toolCallBack = assembleTool.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
// 和 AI 消息一起添加回上下文
messages.add(response.getOutput());
@@ -222,23 +202,18 @@ public class OllamaToolTest extends MockAITest {
MockConfigHolder.clear();
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
final CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
chatModelWithManualToolCall.stream(
stream = chatModelWithManualToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.DN_JSON)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future2.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
response = future2.get();
response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -269,44 +244,6 @@ public class OllamaToolTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = OllamaChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = OllamaChatRequest.builder();
}
}

View File

@@ -1,12 +1,15 @@
package com.yomahub.liteflow.test.ai.model.tool.openai;
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
@@ -16,6 +19,7 @@ import com.yomahub.liteflow.test.ai.mock.MockAITest;
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockConfigHolder;
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
import io.reactivex.rxjava3.core.Flowable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -23,9 +27,6 @@ import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
/**
* OpenAI 工具调用 测试
@@ -72,7 +73,7 @@ public class OpenAIToolTest extends MockAITest {
}
@Test
public void testToolCallStreamingWithAutoToolCall() throws ExecutionException, InterruptedException {
public void testToolCallStreamingWithAutoToolCall() {
setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL,
TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
@@ -83,23 +84,18 @@ public class OpenAIToolTest extends MockAITest {
ToolRegistry assembleTool = new StaticToolRegistry(
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModelWithAutoToolCall.stream(
Flowable<ChunkEvent> stream = chatModelWithAutoToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -137,13 +133,8 @@ public class OpenAIToolTest extends MockAITest {
Assertions.assertEquals("weather_tool", toolCall.getName());
// 执行 ToolCall
ToolCallBack toolCallBack = weatherTool.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
// 和 AI 消息一起添加回上下文
messages.add(response.getOutput());
@@ -168,7 +159,7 @@ public class OpenAIToolTest extends MockAITest {
}
@Test
public void testToolCallStreamingWithManualToolCall() throws ExecutionException, InterruptedException {
public void testToolCallStreamingWithManualToolCall() {
setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL);
List<Message> messages = new ArrayList<>();
@@ -178,23 +169,18 @@ public class OpenAIToolTest extends MockAITest {
ToolRegistry assembleTool = new StaticToolRegistry(
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
chatModelWithManualToolCall.stream(
Flowable<ChunkEvent> stream = chatModelWithManualToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
ChatResponse response = future.get();
ChatResponse response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.TOOL_CALL, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -206,13 +192,7 @@ public class OpenAIToolTest extends MockAITest {
Assertions.assertEquals("assemble_tool", toolCall.getName());
// 执行 ToolCall
ToolCallBack toolCallBack = assembleTool.getAllTools()
.stream()
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
.findFirst()
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
// 和 AI 消息一起添加回上下文
messages.add(response.getOutput());
@@ -222,23 +202,18 @@ public class OpenAIToolTest extends MockAITest {
MockConfigHolder.clear();
setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
final CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
chatModelWithManualToolCall.stream(
stream = chatModelWithManualToolCall.stream(
chatRequestBuilder
.streaming(true)
.transportType(TransportType.SSE)
.messages(messages)
// 工具调用配置
.toolRegistry(assembleTool)
// 工具调用配置
.onFinal((chatResponse, context) -> {
future2.complete(chatResponse);
return chatResponse;
})
.build());
.build()
);
response = future2.get();
response = stream.blockingLast()
.getFinalResponse();
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
@@ -271,44 +246,6 @@ public class OpenAIToolTest extends MockAITest {
.interactClient(new MockInteractClient())
.build();
chatRequestBuilder = OpenAIChatRequest.builder()
.onStart(context -> System.out.println("chat start"))
.onClose(context -> System.out.println("chat close"))
.onError((context, t) -> {
throw new RuntimeException(t);
})
.onText((content, context) -> {
System.out.println("Received text: " + content);
return content;
})
.onThinking((content, context) -> {
System.out.println("Received thinking: " + content);
return content;
})
.onCompletion((response, context) -> {
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
return response;
});
chatRequestBuilder = OpenAIChatRequest.builder();
}
}

View File

@@ -0,0 +1,55 @@
package com.yomahub.liteflow.test.ai.model.util;
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
import io.reactivex.rxjava3.functions.Consumer;
public class StreamUtil {
public static Consumer<ChunkEvent> getChunkEventConsumer() {
return chunkEvent -> {
if (chunkEvent.isStart()) {
System.out.println("chat start");
}
if (chunkEvent.isChunk()) {
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
switch (chunk.getType()) {
case TEXT:
System.out.println(chunk.getData());
break;
case THINKING:
System.out.println("[Thinking] " + chunk.getData());
}
}
if (chunkEvent.isComplete()) {
ChatResponse response = chunkEvent.getFinalResponse();
AssistantMessage message = response.getOutput();
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
System.out.println("内容长度: " + message.getContent().length());
if (message.getContent().length() > 200) {
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
} else {
System.out.println("内容: " + message.getContent());
}
}
if (response.hasToolCalls()) {
System.out.println("工具调用数量: " + message.getToolCalls().size());
for (int i = 0; i < message.getToolCalls().size(); i++) {
ToolCall toolCall = message.getToolCalls().get(i);
System.out.println("工具调用 " + (i + 1) + ":");
System.out.println(" ID: " + toolCall.getId());
System.out.println(" 名称: " + toolCall.getName());
System.out.println(" 类型: " + toolCall.getType());
System.out.println(" 参数: " + toolCall.getArguments());
}
}
System.out.println("Token使用情况: " + response.getTokenUsage());
System.out.println("完成原因: " + response.getFinishReason());
System.out.println("chat close");
}
};
}
}

View File

@@ -6,4 +6,4 @@ liteflow:
openai:
api-key: mock-api-key # 测试代码不需要配置,将使用 mock 数据,不会调用 AI 服务接口
dashscope:
api-key: mock-api-key # 测试代码不需要配置,将使用 mock 数据,不会调用 AI 服务接口
api-key: sk-70c4b25e826e48868b470ff114419441 # 测试代码不需要配置,将使用 mock 数据,不会调用 AI 服务接口