mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-13 19:18:16 +08:00
Feat & Refactor: 引入 rxjava,将流式回调逻辑重构为响应式逻辑
This commit is contained in:
@@ -42,7 +42,7 @@ public class LiteFlowAIAutoConfiguration {
|
||||
|
||||
@Bean
|
||||
public StreamHandler streamHandler() {
|
||||
return StreamHandler.builder().build();
|
||||
return StreamHandler.passThrough();
|
||||
}
|
||||
|
||||
@Bean
|
||||
|
||||
@@ -1,290 +1,66 @@
|
||||
package com.yomahub.liteflow.ai.context;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Consumer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
/**
|
||||
* 流式输出处理器
|
||||
* 响应式流处理器
|
||||
* <p>
|
||||
* 用户通过实现此接口,直接对 ChunkEvent 流进行响应式编程,
|
||||
* 对流进行任意的变换、过滤、聚合等操作,最后返回处理后的 Flowable 流。
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
@FunctionalInterface
|
||||
public interface StreamHandler {
|
||||
|
||||
static Builder builder() {
|
||||
return new Builder();
|
||||
/**
|
||||
* 对 ChunkEvent 流进行响应式处理
|
||||
* <p>
|
||||
* 用户可以在此方法中对事件流进行任意的响应式操作,如:
|
||||
* - 过滤特定类型的事件
|
||||
* - 转换事件数据
|
||||
* - 聚合多个事件
|
||||
* - 添加副作用(如日志、监控)
|
||||
* - 处理错误和超时
|
||||
* <p>
|
||||
* 示例:
|
||||
* <pre>
|
||||
* StreamHandler handler = eventStream -> eventStream
|
||||
* .filter(ChunkEvent::isChunk)
|
||||
* .doOnNext(event -> System.out.println("Received chunk: " + event))
|
||||
* .onErrorRetry(3);
|
||||
* </pre>
|
||||
*
|
||||
* @param eventStream 原始的 ChunkEvent 流
|
||||
* @return 处理后的 ChunkEvent 流
|
||||
*/
|
||||
Flowable<ChunkEvent> handle(Flowable<ChunkEvent> eventStream);
|
||||
|
||||
/**
|
||||
* 创建一个 pass-through 处理器,不做任何转换直接返回原始流
|
||||
*
|
||||
* @return pass-through StreamHandler
|
||||
*/
|
||||
static StreamHandler passThrough() {
|
||||
return eventStream -> eventStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求开始时的回调方法
|
||||
* 创建一个组合多个处理器的处理器
|
||||
*
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @param handlers 处理器数组
|
||||
* @return 组合后的 StreamHandler
|
||||
*/
|
||||
void onStart(InteractContext context);
|
||||
|
||||
/**
|
||||
* 请求关闭时的回调方法
|
||||
*
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
void onClose(InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理过程中发生错误的回调方法。
|
||||
*
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @param t
|
||||
*/
|
||||
void onError(InteractContext context, Throwable t);
|
||||
|
||||
/**
|
||||
* 处理文本消息的回调方法。需要启用流式调用
|
||||
*
|
||||
* @param content 文本内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
String onText(String content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理思考消息的回调方法。需要启用流式调用
|
||||
*
|
||||
* @param content 思考内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
String onThinking(String content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理工具调用消息的回调方法。需要启用流式调用
|
||||
*
|
||||
* @param toolCalls 工具调用内容,可能是工具调用的结果或相关信息
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理 Token 统计信息的回调方法。需要启用流式调用
|
||||
*
|
||||
* @param content Token 统计信息内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
Object onUsage(Object content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理基础信息/搜索结果的回调方法。需要启用流式调用
|
||||
*
|
||||
* @param content 基础信息或搜索结果内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
// TODO args
|
||||
Object onGrounding(Object content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 消息处理完成的回调方法。(流式传输/阻塞式传输均会调用)
|
||||
*
|
||||
* @param response 处理后的聊天响应结果
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @return 处理后的结果
|
||||
*/
|
||||
ChatResponse onCompletion(ChatResponse response, InteractContext context);
|
||||
|
||||
/**
|
||||
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。(流式传输/阻塞式传输均会调用)
|
||||
*
|
||||
* @param response 最终的聊天响应结果
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @return 处理后的结果
|
||||
*/
|
||||
ChatResponse onFinal(ChatResponse response, InteractContext context);
|
||||
|
||||
class Builder {
|
||||
Consumer<InteractContext> onStart = context -> {
|
||||
static StreamHandler composite(StreamHandler... handlers) {
|
||||
return eventStream -> {
|
||||
Flowable<ChunkEvent> result = eventStream;
|
||||
for (StreamHandler handler : handlers) {
|
||||
result = handler.handle(result);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
Consumer<InteractContext> onClose = context -> {
|
||||
};
|
||||
BiConsumer<InteractContext, Throwable> onError = (context, t) -> {
|
||||
throw new LiteFlowAIEngineException(t.getMessage(), t);
|
||||
};
|
||||
BiFunction<String, InteractContext, String> onText = (content, context) -> content;
|
||||
BiFunction<String, InteractContext, String> onThinking = (content, context) -> content;
|
||||
BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling = (toolCalls, context) -> toolCalls;
|
||||
BiFunction<Object, InteractContext, Object> onUsage = (content, context) -> content;
|
||||
BiFunction<Object, InteractContext, Object> onGrounding = (content, context) -> content;
|
||||
BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion = (response, context) -> response;
|
||||
BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal = (response, context) -> response;
|
||||
|
||||
/**
|
||||
* 请求开始时的回调方法
|
||||
*
|
||||
* @param onStart 请求开始时的回调函数
|
||||
* @see TransportListener#onStart(InteractContext)
|
||||
*/
|
||||
public Builder onStart(Consumer<InteractContext> onStart) {
|
||||
this.onStart = onStart;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求结束时的回调方法
|
||||
*
|
||||
* @param onClose 请求结束时的回调函数
|
||||
* @see TransportListener#onClose(InteractContext)
|
||||
*/
|
||||
public Builder onClose(Consumer<InteractContext> onClose) {
|
||||
this.onClose = onClose;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求发生错误时的回调方法
|
||||
*
|
||||
* @param onError 请求发生错误时的回调函数
|
||||
* @see TransportListener#onError(InteractContext, Throwable)
|
||||
*/
|
||||
public Builder onError(BiConsumer<InteractContext, Throwable> onError) {
|
||||
this.onError = onError;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 文本消息的回调方法
|
||||
*
|
||||
* @param onText 文本消息的回调函数
|
||||
* @see ChunkCallbackTransformer#onText(String, InteractContext)
|
||||
*/
|
||||
public Builder onText(BiFunction<String, InteractContext, String> onText) {
|
||||
this.onText = onText;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 思考消息的回调方法
|
||||
*
|
||||
* @param onThinking 思考消息的回调函数
|
||||
* @see ChunkCallbackTransformer#onThinking(String, InteractContext)
|
||||
*/
|
||||
public Builder onThinking(BiFunction<String, InteractContext, String> onThinking) {
|
||||
this.onThinking = onThinking;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具调用消息的回调方法
|
||||
*
|
||||
* @param onToolsCalling 工具调用消息的回调函数
|
||||
* @see ChunkCallbackTransformer#onToolsCalling(List, InteractContext)
|
||||
*/
|
||||
public Builder onToolsCalling(BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling) {
|
||||
this.onToolsCalling = onToolsCalling;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Token 统计信息的回调方法
|
||||
*
|
||||
* @param onUsage Token 统计信息的回调函数
|
||||
* @see ChunkCallbackTransformer#onUsage(Object, InteractContext)
|
||||
*/
|
||||
public Builder onUsage(BiFunction<Object, InteractContext, Object> onUsage) {
|
||||
this.onUsage = onUsage;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 基础信息/搜索结果的回调方法
|
||||
*
|
||||
* @param onGrounding 基础信息/搜索结果的回调函数
|
||||
* @see ChunkCallbackTransformer#onGrounding(Object, InteractContext)
|
||||
*/
|
||||
public Builder onGrounding(BiFunction<Object, InteractContext, Object> onGrounding) {
|
||||
this.onGrounding = onGrounding;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求完成时的回调方法(流式传输/阻塞式传输均会调用)
|
||||
*
|
||||
* @param onCompletion 请求完成时的回调函数
|
||||
* @see ResultHandler#onCompletion(ChatResponse, InteractContext)
|
||||
*/
|
||||
public Builder onCompletion(BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion) {
|
||||
this.onCompletion = onCompletion;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。(流式传输/阻塞式传输均会调用)
|
||||
*
|
||||
* @param onFinal 请求最终结果的回调函数
|
||||
* @see ResultHandler#onFinal(ChatResponse, InteractContext)
|
||||
*/
|
||||
public Builder onFinal(BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal) {
|
||||
this.onFinal = onFinal;
|
||||
return this;
|
||||
}
|
||||
|
||||
public StreamHandler build() {
|
||||
return new StreamHandler() {
|
||||
@Override
|
||||
public void onStart(InteractContext context) {
|
||||
onStart.accept(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(InteractContext context) {
|
||||
onClose.accept(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(InteractContext context, Throwable t) {
|
||||
onError.accept(context, t);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onText(String content, InteractContext context) {
|
||||
return onText.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onThinking(String content, InteractContext context) {
|
||||
return onThinking.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
|
||||
return onToolsCalling.apply(toolCalls, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onUsage(Object content, InteractContext context) {
|
||||
return onUsage.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onGrounding(Object content, InteractContext context) {
|
||||
return onGrounding.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
|
||||
return onCompletion.apply(response, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
|
||||
return onFinal.apply(response, context);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.yomahub.liteflow.ai.parse.assemble;
|
||||
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
|
||||
import com.yomahub.liteflow.ai.domain.dto.ParsedChatAnnotationConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
|
||||
@@ -32,23 +31,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ParsedChatAnn
|
||||
protected ChatRequest doAssemble(ParsedChatAnnotationConfig annotationConfig, ModelConfigAggregator config, ChatContext context) {
|
||||
ChatRequest.Builder<?> builder = ModelFactory.getChatRequestBuilder(config.getProvider());
|
||||
|
||||
// 1. 连接 StreamHandler 回调
|
||||
StreamHandler streamHandler = context.getStreamHandler();
|
||||
if (Objects.nonNull(streamHandler)) {
|
||||
LOG.info("Connecting StreamHandler to ChatRequest");
|
||||
builder.onStart(streamHandler::onStart)
|
||||
.onClose(streamHandler::onClose)
|
||||
.onError(streamHandler::onError)
|
||||
.onText(streamHandler::onText)
|
||||
.onThinking(streamHandler::onThinking)
|
||||
.onToolsCalling(streamHandler::onToolsCalling)
|
||||
.onUsage(streamHandler::onUsage)
|
||||
.onGrounding(streamHandler::onGrounding)
|
||||
.onCompletion(streamHandler::onCompletion)
|
||||
.onFinal(streamHandler::onFinal);
|
||||
}
|
||||
|
||||
// 2. ChatOptions
|
||||
// 1. ChatOptions
|
||||
ChatOptions.Builder<?> optionsBuilder = ChatOptions.builder();
|
||||
setIfPresent(optionsBuilder::temperature, config.getTemperature());
|
||||
setIfPresent(optionsBuilder::topP, config.getTopP());
|
||||
@@ -58,7 +41,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ParsedChatAnn
|
||||
setIfPresent(optionsBuilder::enableThinking, config.getEnableThinking());
|
||||
builder.options(optionsBuilder.build());
|
||||
|
||||
// 3. Message
|
||||
// 2. Message
|
||||
List<Message> messages = Optional.ofNullable(annotationConfig.getHistory())
|
||||
.orElse(new ArrayList<>());
|
||||
if (messages.isEmpty()) {
|
||||
@@ -68,7 +51,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ParsedChatAnn
|
||||
|
||||
builder.messages(messages);
|
||||
|
||||
// 4. streaming 相关参数
|
||||
// 3. streaming 相关参数
|
||||
builder.streaming(annotationConfig.isStreaming());
|
||||
builder.transportType(annotationConfig.getTransportType());
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.yomahub.liteflow.ai.parse.assemble;
|
||||
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
|
||||
import com.yomahub.liteflow.ai.domain.dto.ParsedClassifyAnnotationConfig;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
@@ -35,23 +34,7 @@ public class ClassifyRequestAssembler extends AbstractRequestAssembler<ParsedCla
|
||||
protected ChatRequest doAssemble(ParsedClassifyAnnotationConfig annotationConfig, ModelConfigAggregator config, ChatContext context) {
|
||||
ChatRequest.Builder<?> builder = ModelFactory.getChatRequestBuilder(config.getProvider());
|
||||
|
||||
// 1. 连接 StreamHandler 回调
|
||||
StreamHandler streamHandler = context.getStreamHandler();
|
||||
if (Objects.nonNull(streamHandler)) {
|
||||
LOG.info("Connecting StreamHandler to ChatRequest");
|
||||
builder.onStart(streamHandler::onStart)
|
||||
.onClose(streamHandler::onClose)
|
||||
.onError(streamHandler::onError)
|
||||
.onText(streamHandler::onText)
|
||||
.onThinking(streamHandler::onThinking)
|
||||
.onToolsCalling(streamHandler::onToolsCalling)
|
||||
.onUsage(streamHandler::onUsage)
|
||||
.onGrounding(streamHandler::onGrounding)
|
||||
.onCompletion(streamHandler::onCompletion)
|
||||
.onFinal(streamHandler::onFinal);
|
||||
}
|
||||
|
||||
// 2. ChatOptions
|
||||
// 1. ChatOptions
|
||||
ChatOptions.Builder<?> optionsBuilder = ChatOptions.builder();
|
||||
setIfPresent(optionsBuilder::temperature, config.getTemperature());
|
||||
setIfPresent(optionsBuilder::topP, config.getTopP());
|
||||
@@ -62,7 +45,7 @@ public class ClassifyRequestAssembler extends AbstractRequestAssembler<ParsedCla
|
||||
builder.options(optionsBuilder.build());
|
||||
|
||||
|
||||
// 3. Message
|
||||
// 2. Message
|
||||
List<Message> messages = Optional.ofNullable(annotationConfig.getHistory())
|
||||
.orElse(new ArrayList<>());
|
||||
if (messages.isEmpty()) {
|
||||
@@ -77,7 +60,7 @@ public class ClassifyRequestAssembler extends AbstractRequestAssembler<ParsedCla
|
||||
|
||||
builder.messages(messages);
|
||||
|
||||
// 4. streaming 相关参数
|
||||
// 3. streaming 相关参数
|
||||
// 定死使用阻塞式传输
|
||||
builder.streaming(false);
|
||||
builder.transportType(TransportType.HTTP);
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
package com.yomahub.liteflow.ai.proxy.invocation;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
|
||||
import com.yomahub.liteflow.ai.model.ModelFactory;
|
||||
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
|
||||
import com.yomahub.liteflow.ai.proxy.wrap.ChatProxyWrapBean;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* 聊天组件的调用处理器
|
||||
@@ -39,53 +37,26 @@ public class ChatAIInvocationHandler extends AbstractAIInvocationHandler<ChatPro
|
||||
ChatRequest chatRequest = processorContext.getModelRequest().toChatRequest();
|
||||
|
||||
if (chatRequest.isStreaming()) {
|
||||
return processStreaming(chatModel, chatRequest);
|
||||
return processStreaming(chatModel, chatRequest, processorContext.getChatContext().getStreamHandler());
|
||||
} else {
|
||||
return chatModel.chat(chatRequest);
|
||||
}
|
||||
}
|
||||
|
||||
private ChatResponse processStreaming(ChatModel chatModel, ChatRequest chatRequest) {
|
||||
// 创建一个 CompletableFuture 用于异步处理,他将持有最终的 ChatResponse
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
// 获取用户传入的 ResultHandler
|
||||
final ResultHandler externalResultHandler = chatRequest.getResultHandler();
|
||||
|
||||
// 定义内部 ResultHandler,包装用户的处理逻辑
|
||||
ResultHandler internalResultHandler = new ResultHandler() {
|
||||
@Override
|
||||
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
|
||||
if (Objects.nonNull(externalResultHandler)) {
|
||||
return externalResultHandler.onCompletion(response, context);
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
|
||||
ChatResponse finalResponse = response;
|
||||
try {
|
||||
if (Objects.nonNull(externalResultHandler)) {
|
||||
finalResponse = externalResultHandler.onFinal(response, context);
|
||||
}
|
||||
} finally {
|
||||
future.complete(finalResponse);
|
||||
}
|
||||
return finalResponse;
|
||||
}
|
||||
};
|
||||
|
||||
// 设置内部 ResultHandler 到请求中
|
||||
chatRequest.setResultHandler(internalResultHandler);
|
||||
// 执行流式聊天请求
|
||||
chatModel.stream(chatRequest);
|
||||
|
||||
// 阻塞等待 CompletableFuture 完成,并返回最终的 ChatResponse
|
||||
try {
|
||||
return future.get();
|
||||
} catch (InterruptedException | ExecutionException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new LiteFlowAIException("error while processing streaming chat request", e);
|
||||
private ChatResponse processStreaming(ChatModel chatModel, ChatRequest chatRequest, StreamHandler streamHandler) {
|
||||
// 获取用户定义的 StreamHandler(如果有)
|
||||
if (Objects.isNull(streamHandler)) {
|
||||
// 如果未提供,使用 pass-through 处理器
|
||||
streamHandler = StreamHandler.passThrough();
|
||||
}
|
||||
|
||||
// 获取原始的事件流
|
||||
Flowable<ChunkEvent> eventStream = chatModel.stream(chatRequest);
|
||||
|
||||
// 应用用户的 StreamHandler 进行响应式转换
|
||||
Flowable<ChunkEvent> handledStream = streamHandler.handle(eventStream);
|
||||
|
||||
return handledStream.blockingLast()
|
||||
.getFinalResponse();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,6 +57,11 @@
|
||||
<groupId>com.github.victools</groupId>
|
||||
<artifactId>jsonschema-module-jackson</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.reactivex.rxjava3</groupId>
|
||||
<artifactId>rxjava</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -1,8 +1,10 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
@@ -15,7 +17,7 @@ import java.util.concurrent.CompletableFuture;
|
||||
|
||||
public interface InteractClient {
|
||||
|
||||
void stream(ChatConfig config, ChatRequest request);
|
||||
Flowable<ChunkEvent> stream(ChatConfig config, ChatRequest request);
|
||||
|
||||
ChatResponse chat(ChatConfig config, ChatRequest request);
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformerFactory;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLog;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
@@ -16,226 +14,241 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
import io.reactivex.rxjava3.core.BackpressureStrategy;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import io.reactivex.rxjava3.disposables.CompositeDisposable;
|
||||
import io.reactivex.rxjava3.disposables.Disposable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
/**
|
||||
* 大模型交互客户端,统筹消息传输、协议转换等功能。
|
||||
* 大模型交互客户端
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public class LlmInteractClient implements InteractClient {
|
||||
|
||||
private static final EngineLog LOG = EngineLogManager.getLogger(LlmInteractClient.class);
|
||||
|
||||
@Override
|
||||
public void stream(ChatConfig config, ChatRequest request) {
|
||||
InteractManager manager = new InteractManager(config, request);
|
||||
manager.executeStreaming();
|
||||
}
|
||||
|
||||
/**
|
||||
* 同步调用
|
||||
*
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @return 聊天响应
|
||||
*/
|
||||
@Override
|
||||
public ChatResponse chat(ChatConfig config, ChatRequest request) {
|
||||
InteractManager interactManager = new InteractManager(config, request);
|
||||
return interactManager.executeBlocking();
|
||||
}
|
||||
ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider());
|
||||
Transport transport = request.getTransportType().getTransportInstance();
|
||||
|
||||
@Override
|
||||
public CompletableFuture<ChatResponse> chatAsync(ChatConfig config, ChatRequest request) {
|
||||
CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
ChatResponse response; // 用于存储最终响应
|
||||
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
future.complete(chat(config, request));
|
||||
} catch (Exception e) {
|
||||
future.completeExceptionally(new LiteFlowAIEngineException("异步调用大模型失败", e));
|
||||
while (true) {
|
||||
InteractContext context = new InteractContext();
|
||||
|
||||
// 1. 执行单轮对话
|
||||
String responseBody = transport.startBlocking(config, request);
|
||||
response = protocolTransformer.transformBlockingResponse(responseBody, context);
|
||||
|
||||
// 2. 检查循环的 "退出条件"
|
||||
// 如果 AI 没有要求工具调用,或者配置禁用了自动调用,
|
||||
// 那么这就是最终答案,跳出循环。
|
||||
if (!response.hasToolCalls() || !config.isAutoToolCallEnabled()) {
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
return future;
|
||||
// 3. 获取并执行工具调用
|
||||
List<ToolCall> toolCalls = response.getOutput().getToolCalls();
|
||||
// 目前只执行单轮单次的工具调用
|
||||
ToolMessage toolMessage = request.getToolRegistry().executeToolCall(toolCalls.get(0));
|
||||
|
||||
// 4. 构建下一轮对话的上下文
|
||||
buildNextRoundMessages(request, response, toolMessage);
|
||||
}
|
||||
|
||||
// 5. 返回循环中断时的最后一个响应
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
|
||||
* 内部执行器
|
||||
* 异步调用
|
||||
*
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @return 异步聊天响应
|
||||
*/
|
||||
private static class InteractManager {
|
||||
private final ChatConfig config;
|
||||
private final ChatRequest request;
|
||||
private final InteractContext context;
|
||||
private final ChunkProcessPipeline pipeline;
|
||||
private final TransportListener externalTransportListener;
|
||||
private final ResultHandler resultHandler;
|
||||
private final Transport transport;
|
||||
private final InternalTransportListener internalTransportListener;
|
||||
@Override
|
||||
public CompletableFuture<ChatResponse> chatAsync(ChatConfig config, ChatRequest request) {
|
||||
return CompletableFuture.supplyAsync(() -> chat(config, request));
|
||||
}
|
||||
|
||||
public InteractManager(ChatConfig config, ChatRequest request) {
|
||||
this.config = config;
|
||||
this.request = request;
|
||||
this.context = new InteractContext();
|
||||
ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider());
|
||||
this.pipeline = request.isStreaming()
|
||||
? ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, request.getChunkCallbackTransformer())
|
||||
: ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer);
|
||||
this.transport = request.getTransportType().getTransportInstance();
|
||||
this.externalTransportListener = request.getTransportListener();
|
||||
this.resultHandler = request.getResultHandler();
|
||||
this.internalTransportListener = new InternalTransportListener();
|
||||
}
|
||||
/**
|
||||
* 创建响应式流事件管道
|
||||
*
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @return 包含 ChunkEvent 的流
|
||||
*/
|
||||
public Flowable<ChunkEvent> stream(ChatConfig config, ChatRequest request) {
|
||||
InteractContext context = new InteractContext();
|
||||
ProtocolTransformer protocolTransformer = ProtocolTransformerFactory.getTransformer(config.getProvider());
|
||||
|
||||
/**
|
||||
* 流式调用
|
||||
*/
|
||||
public void executeStreaming() {
|
||||
// 启动传输,使用内部监听器
|
||||
transport.start(config, request, pipeline, internalTransportListener);
|
||||
}
|
||||
return Flowable.just(ChunkEvent.start(context))
|
||||
.concatWith(
|
||||
streamRecursive(config, request, context, protocolTransformer)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* 内部传输监听器,用于处理流式调用的各种事件
|
||||
*/
|
||||
private class InternalTransportListener implements TransportListener {
|
||||
/**
|
||||
* 内部递归流逻辑
|
||||
*
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @param context 当前轮次的交互上下文
|
||||
* @param protocolTransformer 协议转换器
|
||||
* @return 包含 ChunkEvent 的流
|
||||
*/
|
||||
private Flowable<ChunkEvent> streamRecursive(ChatConfig config, ChatRequest request,
|
||||
InteractContext context, ProtocolTransformer protocolTransformer) {
|
||||
return Flowable.create(emitter -> {
|
||||
CompositeDisposable compositeDisposable = new CompositeDisposable();
|
||||
emitter.setDisposable(compositeDisposable);
|
||||
|
||||
@Override
|
||||
public void onStart(InteractContext context) {
|
||||
externalTransportListener.onStart(context);
|
||||
}
|
||||
Transport transport = request.getTransportType().getTransportInstance();
|
||||
|
||||
@Override
|
||||
public void onClose(InteractContext context) {
|
||||
ChatResponse finalResponse = null;
|
||||
// 判断是否需要继续进行工具调用
|
||||
boolean isContinuingWithToolCall = false;
|
||||
try {
|
||||
// 构造最终响应
|
||||
finalResponse = pipeline.buildFinalStreamingResponse();
|
||||
try {
|
||||
// 获取响应式流
|
||||
Disposable transportDisposable = transport.startStreaming(config, request)
|
||||
.subscribe(
|
||||
// onNext: 处理每个原始 JSON 分块
|
||||
rawChunk -> {
|
||||
try {
|
||||
// 使用协议转换器转换为框架标准格式
|
||||
StreamingProtocolChunk protocolChunk = protocolTransformer.transformStreamingChunk(rawChunk, context);
|
||||
|
||||
// 调用结果处理器的完成回调
|
||||
finalResponse = resultHandler.onCompletion(finalResponse, context);
|
||||
// 根据分块信息更新上下文
|
||||
updateContextFromChunk(context, protocolChunk);
|
||||
|
||||
// 工具调用
|
||||
if (finalResponse.hasToolCalls() && config.isAutoToolCallEnabled()) {
|
||||
// 1. 获取并执行工具调用
|
||||
List<ToolCall> toolCalls = finalResponse.getOutput().getToolCalls();
|
||||
// 目前只执行单轮单次的工具调用
|
||||
ToolMessage toolMessage = executeToolCall(toolCalls.get(0), request.getToolRegistry());
|
||||
// 发送分块事件
|
||||
if (!emitter.isCancelled()) {
|
||||
emitter.onNext(ChunkEvent.chunk(rawChunk, protocolChunk, context));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
if (!emitter.isCancelled()) {
|
||||
emitter.onError(e); // 转发解析错误
|
||||
}
|
||||
}
|
||||
},
|
||||
// onError: 转发下游错误
|
||||
error -> {
|
||||
if (!emitter.isCancelled()) {
|
||||
emitter.onError(error);
|
||||
}
|
||||
try {
|
||||
transport.close(); // 确保关闭当前轮次的 transport
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Error closing transport after onError", e);
|
||||
}
|
||||
},
|
||||
// onComplete: 流完成处理
|
||||
() -> {
|
||||
try {
|
||||
// 构建最终响应
|
||||
ChatResponse finalResponse = protocolTransformer.transformStreamingResponse(context);
|
||||
|
||||
// 2. 构建下一轮对话的上下文
|
||||
buildNextRoundMessages(request, finalResponse, toolMessage);
|
||||
// 检查是否需要工具调用
|
||||
if (finalResponse.hasToolCalls() && config.isAutoToolCallEnabled()) {
|
||||
if (emitter.isCancelled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 3. 递归调用下一轮消息
|
||||
// 设置标志位为 true,避免关闭 transport
|
||||
isContinuingWithToolCall = true;
|
||||
new InteractManager(config, request).executeStreaming();
|
||||
}
|
||||
// 1. 执行工具调用
|
||||
List<ToolCall> toolCalls = finalResponse.getOutput().getToolCalls();
|
||||
ToolMessage toolMessage = request.getToolRegistry().executeToolCall(toolCalls.get(0));
|
||||
|
||||
} catch (Exception e) {
|
||||
onError(context, e);
|
||||
} finally {
|
||||
// 不需要进行工具调用时才进行清理操作
|
||||
if (!isContinuingWithToolCall) {
|
||||
try {
|
||||
// 调用外部监听器的关闭事件
|
||||
externalTransportListener.onClose(context);
|
||||
} catch (Exception e) {
|
||||
onError(context, e);
|
||||
} finally {
|
||||
// 清理资源
|
||||
cleanup(finalResponse);
|
||||
}
|
||||
}
|
||||
// 如果需要进行工具调用,将清理的责任委托给下一轮调用
|
||||
// 2. 构建下一轮消息
|
||||
buildNextRoundMessages(request, finalResponse, toolMessage);
|
||||
|
||||
// 3. 为下一轮创建新的上下文
|
||||
InteractContext nextContext = new InteractContext();
|
||||
|
||||
// 4. 递归调用 "循环体",并将事件转发给当前 emitter
|
||||
Disposable recursiveDisposable = streamRecursive(config, request, nextContext, protocolTransformer)
|
||||
.subscribe(
|
||||
emitter::onNext,
|
||||
emitter::onError,
|
||||
emitter::onComplete
|
||||
);
|
||||
|
||||
compositeDisposable.add(recursiveDisposable);
|
||||
} else {
|
||||
if (!emitter.isCancelled()) {
|
||||
// 没有工具调用,发送 complete 事件并完成流
|
||||
emitter.onNext(ChunkEvent.complete(context, finalResponse));
|
||||
emitter.onComplete();
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
if (!emitter.isCancelled()) {
|
||||
emitter.onError(e); // 转发 onComplete 逻辑中的错误
|
||||
}
|
||||
} finally {
|
||||
try {
|
||||
transport.close(); // 确保关闭当前轮次的 transport
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Error closing transport", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
compositeDisposable.add(transportDisposable);
|
||||
} catch (Exception e) {
|
||||
if (!emitter.isCancelled()) {
|
||||
emitter.onError(e);
|
||||
}
|
||||
}
|
||||
}, BackpressureStrategy.BUFFER);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(InteractContext context, Throwable t) {
|
||||
externalTransportListener.onError(context, t);
|
||||
}
|
||||
/**
|
||||
* 根据分块信息更新交互上下文
|
||||
*
|
||||
* @param context 交互上下文
|
||||
* @param chunk 协议分块
|
||||
*/
|
||||
private void updateContextFromChunk(InteractContext context, StreamingProtocolChunk chunk) {
|
||||
if (Objects.isNull(chunk)) {
|
||||
return;
|
||||
}
|
||||
|
||||
public ChatResponse executeBlocking() {
|
||||
ChatResponse response = null;
|
||||
try {
|
||||
externalTransportListener.onStart(context);
|
||||
|
||||
response = transport.startBlocking(config, request, pipeline);
|
||||
|
||||
response = resultHandler.onCompletion(response, context);
|
||||
|
||||
// 处理工具调用
|
||||
if (response.hasToolCalls() && config.isAutoToolCallEnabled()) {
|
||||
// 1. 获取并执行工具调用
|
||||
List<ToolCall> toolCalls = response.getOutput().getToolCalls();
|
||||
// 目前只执行单轮单次的工具调用
|
||||
ToolMessage toolMessage = executeToolCall(toolCalls.get(0), request.getToolRegistry());
|
||||
|
||||
// 2. 构建下一轮对话的上下文
|
||||
buildNextRoundMessages(request, response, toolMessage);
|
||||
|
||||
// 3. 递归调用下一轮消息
|
||||
response = new InteractManager(config, request).executeBlocking();
|
||||
}
|
||||
|
||||
return response;
|
||||
} catch (Exception e) {
|
||||
this.externalTransportListener.onError(context, e);
|
||||
return response;
|
||||
} finally {
|
||||
cleanup(response);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具调用
|
||||
*
|
||||
* @param toolCall 工具调用信息
|
||||
* @param toolRegistry 工具注册中心
|
||||
* @return 工具调用结果
|
||||
*/
|
||||
private ToolMessage executeToolCall(ToolCall toolCall, ToolRegistry toolRegistry) {
|
||||
// 找到对应的工具回调
|
||||
ToolCallBack toolCallBack = toolRegistry.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new LiteFlowAIEngineException(
|
||||
"Unable to find target tool with tool name: " + toolCall.getName()));
|
||||
// 调用工具
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
// 返回工具调用结果
|
||||
return new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具调用之后构建下一轮的消息列表
|
||||
*
|
||||
* @param request 原始请求
|
||||
* @param response 大模型响应(要求工具调用)
|
||||
* @param toolMessage 工具调用结果
|
||||
*/
|
||||
private void buildNextRoundMessages(ChatRequest request, ChatResponse response, ToolMessage toolMessage) {
|
||||
List<Message> messagesHistory = request.getMessages();
|
||||
messagesHistory.add(response.getOutput());
|
||||
messagesHistory.add(toolMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理资源
|
||||
*/
|
||||
private void cleanup(ChatResponse response) {
|
||||
try {
|
||||
resultHandler.onFinal(response, context);
|
||||
} catch (Exception e) {
|
||||
LOG.error("ResultHandler.onFinal 执行失败: {}", e.getMessage());
|
||||
} finally {
|
||||
transport.close();
|
||||
}
|
||||
switch (chunk.getType()) {
|
||||
case TEXT:
|
||||
context.addText((String) chunk.getData());
|
||||
break;
|
||||
case THINKING:
|
||||
context.addThinking((String) chunk.getData());
|
||||
break;
|
||||
default:
|
||||
// 其他类型暂不处理
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 为下一轮对话构建消息列表
|
||||
*
|
||||
* @param request 聊天请求
|
||||
* @param response 聊天响应
|
||||
* @param toolMessage 工具调用结果
|
||||
*/
|
||||
private void buildNextRoundMessages(ChatRequest request, ChatResponse response, ToolMessage toolMessage) {
|
||||
List<Message> messages = request.getMessages();
|
||||
messages.add(response.getOutput());
|
||||
messages.add(toolMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.callbacks;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 流式消息处理管道的回调接口。根据块数据的类型进行具体回调
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public interface ChunkCallbackTransformer extends ChunkTransformer {
|
||||
|
||||
static ChunkCallbackTransformer getDefault() {
|
||||
return new ChunkCallbackTransformer() {
|
||||
@Override
|
||||
public String onText(String content, InteractContext context) {
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onThinking(String content, InteractContext context) {
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
|
||||
return toolCalls;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onUsage(Object content, InteractContext context) {
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onGrounding(Object content, InteractContext context) {
|
||||
return content;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
default void transform(StreamingProtocolChunk transformedChunk, InteractContext context) {
|
||||
switch (transformedChunk.getType()) {
|
||||
case TEXT:
|
||||
String textContent = (String) transformedChunk.getData();
|
||||
String callbackText = this.onText(textContent, context);
|
||||
transformedChunk.setData(callbackText);
|
||||
context.addText(callbackText);
|
||||
break;
|
||||
case THINKING:
|
||||
String thinkingContent = (String) transformedChunk.getData();
|
||||
String callbackThinking = this.onThinking(thinkingContent, context);
|
||||
transformedChunk.setData(callbackThinking);
|
||||
context.addThinking(callbackThinking);
|
||||
break;
|
||||
case TOOL_CALLS:
|
||||
List<ToolCall> toolCallsContent = (List<ToolCall>) transformedChunk.getData();
|
||||
List<ToolCall> responseToolCalls = this.onToolsCalling(toolCallsContent, context);
|
||||
transformedChunk.setData(responseToolCalls);
|
||||
break;
|
||||
case USAGE:
|
||||
Object usageContent = transformedChunk.getData();
|
||||
Object responseUsage = this.onUsage(usageContent, context);
|
||||
transformedChunk.setData(responseUsage);
|
||||
break;
|
||||
case BASE64_IMAGE:
|
||||
case DATA:
|
||||
case ERROR:
|
||||
// TODO 异常处理
|
||||
default:
|
||||
// 其他类型暂不处理
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本消息的回调方法。
|
||||
*
|
||||
* @param content 文本内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
String onText(String content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理思考消息的回调方法。
|
||||
*
|
||||
* @param content 思考内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
String onThinking(String content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理工具调用消息的回调方法。
|
||||
*
|
||||
* @param toolCalls 工具调用内容,可能是工具调用的结果或相关信息
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理 Token 统计信息的回调方法
|
||||
*
|
||||
* @param content Token 统计信息内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
// TODO args
|
||||
Object onUsage(Object content, InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理基础信息/搜索结果的回调方法
|
||||
*
|
||||
* @param content 基础信息或搜索结果内容
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
// TODO args
|
||||
Object onGrounding(Object content, InteractContext context);
|
||||
|
||||
@Override
|
||||
default String getTransformerType() {
|
||||
return "ChunkCallback";
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.callbacks;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
|
||||
/**
|
||||
* 对消息全部发送完毕并转换后的结果进行处理的接口。
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public interface ResultHandler {
|
||||
|
||||
static ResultHandler getDefault() {
|
||||
return new ResultHandler() {
|
||||
@Override
|
||||
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
|
||||
return response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
|
||||
return response;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 消息处理完成的回调方法。
|
||||
*
|
||||
* @param response 处理后的聊天响应结果
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @return 处理后的结果
|
||||
*/
|
||||
ChatResponse onCompletion(ChatResponse response, InteractContext context);
|
||||
|
||||
/**
|
||||
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。
|
||||
*
|
||||
* @param response 最终的聊天响应结果
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @return 处理后的结果
|
||||
*/
|
||||
ChatResponse onFinal(ChatResponse response, InteractContext context);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,220 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.chunk;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 流式事件封装
|
||||
* <p>
|
||||
* 该类封装了整个流式交互过程中产生的各种事件,包括:
|
||||
* - 请求开始事件 (START)
|
||||
* - 数据分块事件 (CHUNK)
|
||||
* - 请求完成事件 (COMPLETE)
|
||||
* - 错误事件 (ERROR)
|
||||
* <p>
|
||||
* 每个事件都包含相关的上下文信息,允许用户在处理流时访问中间状态。
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
public class ChunkEvent {
|
||||
|
||||
/**
|
||||
* 事件类型枚举
|
||||
*/
|
||||
public enum EventType {
|
||||
/**
|
||||
* 流开始事件 - 在流开始处理时触发
|
||||
*/
|
||||
START,
|
||||
/**
|
||||
* 数据分块事件 - 接收到新的数据分块时触发
|
||||
*/
|
||||
CHUNK,
|
||||
/**
|
||||
* 流完成事件 - 所有数据处理完毕时触发
|
||||
*/
|
||||
COMPLETE,
|
||||
/**
|
||||
* 错误事件 - 处理过程中发生错误时触发
|
||||
*/
|
||||
ERROR
|
||||
}
|
||||
|
||||
/**
|
||||
* 事件类型
|
||||
*/
|
||||
private final EventType eventType;
|
||||
|
||||
/**
|
||||
* 原始的 JSON 响应块(协议相关的原始格式)
|
||||
* 在 CHUNK 事件中可能有值,其他事件类型为 null
|
||||
*/
|
||||
private final String rawChunk;
|
||||
|
||||
/**
|
||||
* 转换后的框架标准响应块
|
||||
* 在 CHUNK 事件中可能有值,其他事件类型为 null
|
||||
*/
|
||||
private final StreamingProtocolChunk transformedChunk;
|
||||
|
||||
/**
|
||||
* 交互上下文 - 包含当前的交互状态
|
||||
* 在 CHUNK 和 COMPLETE 事件中有值,START 事件为初始化的上下文,ERROR 事件可能为 null
|
||||
*/
|
||||
private final InteractContext context;
|
||||
|
||||
/**
|
||||
* 错误信息 - 仅在 ERROR 事件中有值
|
||||
*/
|
||||
private final Throwable error;
|
||||
|
||||
/**
|
||||
* 最终响应 - 仅在 COMPLETE 事件中有值
|
||||
*/
|
||||
private final ChatResponse finalResponse;
|
||||
|
||||
/**
|
||||
* 私有构造函数
|
||||
*/
|
||||
private ChunkEvent(EventType eventType, String rawChunk, StreamingProtocolChunk transformedChunk,
|
||||
InteractContext context, Throwable error, ChatResponse finalResponse) {
|
||||
this.eventType = Objects.requireNonNull(eventType, "eventType cannot be null");
|
||||
this.rawChunk = rawChunk;
|
||||
this.transformedChunk = transformedChunk;
|
||||
this.context = context;
|
||||
this.error = error;
|
||||
this.finalResponse = finalResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建开始事件
|
||||
*
|
||||
* @param context 交互上下文
|
||||
* @return 开始事件
|
||||
*/
|
||||
public static ChunkEvent start(InteractContext context) {
|
||||
return new ChunkEvent(EventType.START, null, null, context, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建数据分块事件
|
||||
*
|
||||
* @param rawChunk 原始 JSON 分块
|
||||
* @param transformedChunk 转换后的框架标准分块
|
||||
* @param context 交互上下文
|
||||
* @return 数据分块事件
|
||||
*/
|
||||
public static ChunkEvent chunk(String rawChunk, StreamingProtocolChunk transformedChunk, InteractContext context) {
|
||||
return new ChunkEvent(EventType.CHUNK, rawChunk, transformedChunk, context, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建完成事件
|
||||
*
|
||||
* @param context 交互上下文
|
||||
* @param finalResponse 最终聊天响应
|
||||
* @return 完成事件
|
||||
*/
|
||||
public static ChunkEvent complete(InteractContext context, ChatResponse finalResponse) {
|
||||
return new ChunkEvent(EventType.COMPLETE, null, null, context, null, finalResponse);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建错误事件
|
||||
*
|
||||
* @param error 错误信息
|
||||
* @return 错误事件
|
||||
*/
|
||||
public static ChunkEvent error(Throwable error) {
|
||||
return new ChunkEvent(EventType.ERROR, null, null, null, error, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建错误事件 - 带上下文
|
||||
*
|
||||
* @param error 错误信息
|
||||
* @param context 交互上下文
|
||||
* @return 错误事件
|
||||
*/
|
||||
public static ChunkEvent error(Throwable error, InteractContext context) {
|
||||
return new ChunkEvent(EventType.ERROR, null, null, context, error, null);
|
||||
}
|
||||
|
||||
// ==================== Getters ====================
|
||||
|
||||
public EventType getEventType() {
|
||||
return eventType;
|
||||
}
|
||||
|
||||
public String getRawChunk() {
|
||||
return rawChunk;
|
||||
}
|
||||
|
||||
public StreamingProtocolChunk getTransformedChunk() {
|
||||
return transformedChunk;
|
||||
}
|
||||
|
||||
public InteractContext getContext() {
|
||||
return context;
|
||||
}
|
||||
|
||||
public Throwable getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
public ChatResponse getFinalResponse() {
|
||||
return finalResponse;
|
||||
}
|
||||
|
||||
// ==================== Convenience Methods ====================
|
||||
|
||||
/**
|
||||
* 是否为开始事件
|
||||
*
|
||||
* @return true 如果是开始事件
|
||||
*/
|
||||
public boolean isStart() {
|
||||
return eventType == EventType.START;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否为分块事件
|
||||
*
|
||||
* @return true 如果是分块事件
|
||||
*/
|
||||
public boolean isChunk() {
|
||||
return eventType == EventType.CHUNK;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否为完成事件
|
||||
*
|
||||
* @return true 如果是完成事件
|
||||
*/
|
||||
public boolean isComplete() {
|
||||
return eventType == EventType.COMPLETE;
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否为错误事件
|
||||
*
|
||||
* @return true 如果是错误事件
|
||||
*/
|
||||
public boolean isError() {
|
||||
return eventType == EventType.ERROR;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ChunkEvent{" +
|
||||
"eventType=" + eventType +
|
||||
", rawChunk=" + (rawChunk != null ? rawChunk.substring(0, Math.min(50, rawChunk.length())) + "..." : "null") +
|
||||
", transformedChunk=" + transformedChunk +
|
||||
", context=" + context +
|
||||
", error=" + error +
|
||||
", finalResponse=" + finalResponse +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.pipeline;
|
||||
package com.yomahub.liteflow.ai.engine.interact.chunk;
|
||||
|
||||
import cn.hutool.core.collection.CollectionUtil;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.TokenUsage;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.protocol;
|
||||
package com.yomahub.liteflow.ai.engine.interact.chunk;
|
||||
|
||||
/**
|
||||
* 流式消息块
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.protocol;
|
||||
package com.yomahub.liteflow.ai.engine.interact.chunk;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
@@ -1,91 +0,0 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.pipeline;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
|
||||
/**
|
||||
* 流式消息处理管道
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public class ChunkProcessPipeline {
|
||||
|
||||
private final InteractContext context;
|
||||
|
||||
private final ProtocolTransformer protocolTransformer;
|
||||
|
||||
private final ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
|
||||
private ChunkProcessPipeline(
|
||||
InteractContext context,
|
||||
ProtocolTransformer protocolTransformer,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer) {
|
||||
this.context = context;
|
||||
this.protocolTransformer = protocolTransformer;
|
||||
this.chunkCallbackTransformer = chunkCallbackTransformer;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建流式消息处理管道实例。
|
||||
*
|
||||
* @param context 聊天上下文,包含会话状态等
|
||||
* @param protocolTransformer 协议转换器,将不同厂商大模型响应转换为 LiteFlow-AI 支持的统一格式
|
||||
* @param chunkCallbackTransformer 消息处理管道的回调接口,根据块数据的类型进行具体回调
|
||||
* @return 流式消息处理管道实例
|
||||
*/
|
||||
public static ChunkProcessPipeline createStreamingPipeline(InteractContext context, ProtocolTransformer protocolTransformer, ChunkCallbackTransformer chunkCallbackTransformer) {
|
||||
return new ChunkProcessPipeline(context, protocolTransformer, chunkCallbackTransformer);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建阻塞式调用的消息处理管道实例。
|
||||
*
|
||||
* @param context 聊天上下文,包含会话状态等
|
||||
* @param protocolTransformer 协议转换器,将不同厂商大模型响应转换为 LiteFlow-AI 支持的统一格式
|
||||
* @return 阻塞式调用的消息处理管道实例
|
||||
*/
|
||||
public static ChunkProcessPipeline createBlockingPipeline(InteractContext context, ProtocolTransformer protocolTransformer) {
|
||||
return new ChunkProcessPipeline(context, protocolTransformer, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将流式响应的 chunk 转换为 LiteFlow-AI 支持的格式。
|
||||
*
|
||||
* @param chunk 流式响应的 chunk
|
||||
* @return 当前流式响应是否结束, true 表示结束,false 表示未结束
|
||||
*/
|
||||
public boolean processStreaming(String chunk) {
|
||||
// 协议转换 chunk
|
||||
StreamingProtocolChunk transformedChunk = protocolTransformer.transformStreamingChunk(chunk, context);
|
||||
// 回调处理器进行回调
|
||||
chunkCallbackTransformer.transform(transformedChunk, context);
|
||||
return context.isFinished();
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 阻塞式调用 的 Response 转换为 LiteFlow-AI 支持的 AssistantMessage。
|
||||
*
|
||||
* @param blockingResponse OkHttp 的 Response 对象
|
||||
* @return 转换后的 ChatResponse
|
||||
*/
|
||||
public ChatResponse processBlocking(String blockingResponse) {
|
||||
return protocolTransformer.transformBlockingResponse(blockingResponse, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* 构造流式调用的最终响应
|
||||
*
|
||||
* @return 最终的 ChatResponse
|
||||
*/
|
||||
public ChatResponse buildFinalStreamingResponse() {
|
||||
return protocolTransformer.transformStreamingResponse(context);
|
||||
}
|
||||
|
||||
public InteractContext getContext() {
|
||||
return context;
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.pipeline;
|
||||
|
||||
/**
|
||||
* 消息转换器接口
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public interface ChunkTransformer {
|
||||
|
||||
String getTransformerType();
|
||||
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.protocol;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
|
||||
/**
|
||||
@@ -12,7 +12,7 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public interface ProtocolTransformer extends ChunkTransformer {
|
||||
public interface ProtocolTransformer {
|
||||
|
||||
/**
|
||||
* 将流式响应的 chunk 转换为 LiteFlow-AI 支持的格式。
|
||||
@@ -41,9 +41,4 @@ public interface ProtocolTransformer extends ChunkTransformer {
|
||||
ChatResponse transformBlockingResponse(String blockingResponse, InteractContext context);
|
||||
|
||||
String getProviderName();
|
||||
|
||||
@Override
|
||||
default String getTransformerType() {
|
||||
return "ProtocolTransformer";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.transport;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@@ -19,22 +18,20 @@ public interface Transport {
|
||||
/**
|
||||
* 启动流式传输
|
||||
*
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @param pipeline 处理管道
|
||||
* @param listener 传输监听器
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @return 流式数据流
|
||||
*/
|
||||
void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener);
|
||||
Flowable<String> startStreaming(ChatConfig config, ChatRequest request);
|
||||
|
||||
/**
|
||||
* 启动阻塞式传输
|
||||
*
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @param pipeline 处理管道
|
||||
* @param config 聊天配置
|
||||
* @param request 聊天请求
|
||||
* @return 聊天响应
|
||||
*/
|
||||
ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline);
|
||||
String startBlocking(ChatConfig config, ChatRequest request);
|
||||
|
||||
/**
|
||||
* 关闭传输
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.transport;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
|
||||
/**
|
||||
* 传输监听器接口
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public interface TransportListener {
|
||||
|
||||
static TransportListener getDefault() {
|
||||
return new TransportListener() {
|
||||
@Override
|
||||
public void onStart(InteractContext context) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(InteractContext context) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(InteractContext context, Throwable t) {
|
||||
throw new LiteFlowAIEngineException(t.getMessage(), t);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求开始时的回调方法
|
||||
*
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
void onStart(InteractContext context);
|
||||
|
||||
/**
|
||||
* 请求关闭时的回调方法
|
||||
*
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
*/
|
||||
void onClose(InteractContext context);
|
||||
|
||||
/**
|
||||
* 处理过程中发生错误的回调方法。
|
||||
*
|
||||
* @param context 聊天上下文,包含处理过程中的状态和信息
|
||||
* @param t
|
||||
*/
|
||||
void onError(InteractContext context, Throwable t);
|
||||
}
|
||||
@@ -2,14 +2,14 @@ package com.yomahub.liteflow.ai.engine.interact.transport.impl;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLog;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.BackpressureStrategy;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import io.reactivex.rxjava3.disposables.Disposable;
|
||||
import okhttp3.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
@@ -18,6 +18,8 @@ import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* 一个用于处理"换行符分隔的 JSON"(Delimiter-Newline JSON)流的 LLM 客户端。(Ollama流式输出应使用该传输方式)
|
||||
@@ -34,92 +36,102 @@ import java.util.Objects;
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public class DnJsonTransport implements Transport, Callback {
|
||||
public class DnJsonTransport implements Transport {
|
||||
|
||||
private static final EngineLog LOG = EngineLogManager.getLogger(DnJsonTransport.class);
|
||||
|
||||
private ChunkProcessPipeline pipeline;
|
||||
private TransportListener listener;
|
||||
private OkHttpClient client;
|
||||
private boolean isStop = false;
|
||||
private boolean isLogResponse = false;
|
||||
private final AtomicReference<OkHttpClient> clientRef = new AtomicReference<>();
|
||||
private final AtomicBoolean isStop = new AtomicBoolean(false);
|
||||
|
||||
@Override
|
||||
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
|
||||
this.pipeline = pipeline;
|
||||
this.listener = listener;
|
||||
this.isLogResponse = config.isLogResponse();
|
||||
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
|
||||
return Flowable.create(emitter -> {
|
||||
Request dnJsonRequest = buildDnJsonRequest(config, request);
|
||||
|
||||
Request dnJsonRequest = buildDnJsonRequest(config, request);
|
||||
OkHttpClient client = new okhttp3.OkHttpClient.Builder()
|
||||
.connectTimeout(config.getConnectTimeout())
|
||||
.readTimeout(config.getReadTimeout())
|
||||
.build();
|
||||
|
||||
client = new okhttp3.OkHttpClient.Builder()
|
||||
.connectTimeout(config.getConnectTimeout())
|
||||
.readTimeout(config.getReadTimeout())
|
||||
.build();
|
||||
this.clientRef.set(client);
|
||||
|
||||
this.listener.onStart(pipeline.getContext());
|
||||
// 异步执行请求,this 作为回调处理器
|
||||
this.client.newCall(dnJsonRequest).enqueue(this);
|
||||
client.newCall(dnJsonRequest).enqueue(new Callback() {
|
||||
@Override
|
||||
public void onFailure(@NotNull Call call, @NotNull IOException e) {
|
||||
emitter.onError(e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
|
||||
try {
|
||||
if (!response.isSuccessful()) {
|
||||
emitter.onError(new LiteFlowAIEngineException("error response when calling LLM: " + response.message()));
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
if (Objects.isNull(body)) {
|
||||
emitter.onError(new LiteFlowAIEngineException("response body is null when calling LLM"));
|
||||
return;
|
||||
}
|
||||
|
||||
// 逐行读取响应体,响应体为换行符分隔的 JSON
|
||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(body.byteStream()))) {
|
||||
String line = br.readLine();
|
||||
|
||||
while (StrUtil.isNotBlank(line)) {
|
||||
if (config.isLogResponse()) {
|
||||
LOG.info("DN-JSON Response: {}", line);
|
||||
}
|
||||
|
||||
try {
|
||||
emitter.onNext(line);
|
||||
} catch (Exception e) {
|
||||
emitter.onError(e);
|
||||
return;
|
||||
}
|
||||
|
||||
line = br.readLine();
|
||||
}
|
||||
|
||||
emitter.onComplete();
|
||||
}
|
||||
} finally {
|
||||
close();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// 设置取消订阅时的清理逻辑
|
||||
emitter.setDisposable(new Disposable() {
|
||||
@Override
|
||||
public void dispose() {
|
||||
close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDisposed() {
|
||||
return isStop.get();
|
||||
}
|
||||
});
|
||||
}, BackpressureStrategy.BUFFER);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
|
||||
throw new UnsupportedOperationException("SSE传输不支持阻塞式调用,请使用HTTP传输或调用start方法");
|
||||
public String startBlocking(ChatConfig config, ChatRequest request) {
|
||||
throw new UnsupportedOperationException("DN-JSON传输不支持阻塞式调用,请使用HTTP传输");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (!this.isStop) {
|
||||
try {
|
||||
this.isStop = true;
|
||||
this.listener.onClose(pipeline.getContext());
|
||||
} finally {
|
||||
if (Objects.nonNull(client)) {
|
||||
client.dispatcher().executorService().shutdown();
|
||||
client.connectionPool().evictAll();
|
||||
}
|
||||
if (this.isStop.compareAndSet(false, true)) {
|
||||
OkHttpClient client = clientRef.get();
|
||||
if (Objects.nonNull(client)) {
|
||||
client.dispatcher().executorService().shutdown();
|
||||
client.connectionPool().evictAll();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(@NotNull Call call, @NotNull IOException e) {
|
||||
this.listener.onError(pipeline.getContext(), e);
|
||||
close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
|
||||
if (!response.isSuccessful()) {
|
||||
this.listener.onError(pipeline.getContext(), new LiteFlowAIEngineException("error response when calling LLM: " + response.message()));
|
||||
close();
|
||||
return;
|
||||
}
|
||||
ResponseBody body = response.body();
|
||||
if (Objects.isNull(body)) {
|
||||
this.listener.onError(pipeline.getContext(), new LiteFlowAIEngineException("response body is null when calling LLM"));
|
||||
close();
|
||||
return;
|
||||
}
|
||||
|
||||
// 逐行读取响应体,响应体为换行符分隔的 JSON
|
||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(body.byteStream()))) {
|
||||
String line = br.readLine();
|
||||
|
||||
while (StrUtil.isNotBlank(line)) {
|
||||
if (this.isLogResponse) {
|
||||
LOG.info("DN-JSON Response: {}", line);
|
||||
}
|
||||
|
||||
pipeline.processStreaming(line);
|
||||
line = br.readLine();
|
||||
}
|
||||
} finally {
|
||||
// 确保在读取完毕后关闭资源
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
private Request buildDnJsonRequest(ChatConfig config, ChatRequest request) {
|
||||
String requestBody = buildRequestBody(config, request);
|
||||
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.transport.impl;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLog;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.util.HttpUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
@@ -26,12 +24,12 @@ public class HttpTransport implements Transport {
|
||||
private static final EngineLog LOG = EngineLogManager.getLogger(HttpTransport.class);
|
||||
|
||||
@Override
|
||||
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
|
||||
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
|
||||
throw new UnsupportedOperationException("HTTP传输不支持流式调用,请使用SSE传输或调用startBlocking方法");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
|
||||
public String startBlocking(ChatConfig config, ChatRequest request) {
|
||||
try (HttpUtil httpUtil = HttpUtil
|
||||
.builder()
|
||||
.connectTimeout(config.getConnectTimeout())
|
||||
@@ -61,8 +59,8 @@ public class HttpTransport implements Transport {
|
||||
LOG.info("======= HTTP Response End =======");
|
||||
}
|
||||
|
||||
// 处理响应
|
||||
return pipeline.processBlocking(responseBody);
|
||||
// 返回响应体
|
||||
return responseBody;
|
||||
} catch (IOException e) {
|
||||
throw new LiteFlowAIEngineException("阻塞调用大模型失败", e);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.yomahub.liteflow.ai.engine.interact.transport.impl;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLog;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.BackpressureStrategy;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import io.reactivex.rxjava3.disposables.Disposable;
|
||||
import okhttp3.Headers;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
@@ -20,6 +20,8 @@ import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
/**
|
||||
* Sse传输实现,基于Server-Sent Events的非阻塞式传输
|
||||
@@ -28,96 +30,114 @@ import java.util.Objects;
|
||||
* @since 2.16.0
|
||||
*/
|
||||
|
||||
public class SseTransport extends EventSourceListener implements Transport {
|
||||
public class SseTransport implements Transport {
|
||||
|
||||
private static final EngineLog LOG = EngineLogManager.getLogger(SseTransport.class);
|
||||
|
||||
private ChunkProcessPipeline pipeline;
|
||||
private TransportListener listener;
|
||||
private OkHttpClient client;
|
||||
private EventSource eventSource;
|
||||
private boolean isStop = false;
|
||||
private boolean isLogResponse = false;
|
||||
private final AtomicReference<EventSource> eventSourceRef = new AtomicReference<>();
|
||||
private final AtomicReference<OkHttpClient> clientRef = new AtomicReference<>();
|
||||
private final AtomicBoolean isStopped = new AtomicBoolean(false);
|
||||
|
||||
@Override
|
||||
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
|
||||
this.pipeline = pipeline;
|
||||
this.listener = listener;
|
||||
this.isLogResponse = config.isLogResponse();
|
||||
|
||||
try {
|
||||
// 构建SSE请求
|
||||
Request sseRequest = buildSseRequest(config, request);
|
||||
|
||||
// 创建EventSource实例
|
||||
client = new okhttp3.OkHttpClient.Builder()
|
||||
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
|
||||
return Flowable.create(emitter -> {
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.connectTimeout(config.getConnectTimeout())
|
||||
.readTimeout(config.getReadTimeout())
|
||||
.build();
|
||||
|
||||
this.eventSource = EventSources.createFactory(client)
|
||||
.newEventSource(sseRequest, this);
|
||||
this.clientRef.set(client);
|
||||
|
||||
listener.onStart(pipeline.getContext());
|
||||
} catch (Exception e) {
|
||||
onFailure(eventSource, e, null);
|
||||
}
|
||||
try {
|
||||
// 构建SSE请求
|
||||
Request sseRequest = buildSseRequest(config, request);
|
||||
|
||||
// 创建响应式的 EventSourceListener
|
||||
EventSourceListener listener = new EventSourceListener() {
|
||||
@Override
|
||||
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
|
||||
if (config.isLogResponse()) {
|
||||
LOG.info("SSE Response Event - id: {}, type: {}, data: {}", id, type, data);
|
||||
}
|
||||
|
||||
// 发送数据到 Flowable
|
||||
try {
|
||||
emitter.onNext(data);
|
||||
} catch (Exception e) {
|
||||
emitter.onError(e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(@NotNull EventSource eventSource) {
|
||||
try {
|
||||
if (!emitter.isCancelled()) {
|
||||
emitter.onComplete();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
LOG.warn("Error completing stream", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
|
||||
if (!emitter.isCancelled()) {
|
||||
if (t != null) {
|
||||
emitter.onError(t);
|
||||
} else {
|
||||
emitter.onError(new RuntimeException("SSE connection failed"));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 创建 EventSource
|
||||
EventSource eventSource = EventSources.createFactory(client)
|
||||
.newEventSource(sseRequest, listener);
|
||||
eventSourceRef.set(eventSource);
|
||||
|
||||
// 设置取消订阅时的清理逻辑
|
||||
emitter.setDisposable(new Disposable() {
|
||||
@Override
|
||||
public void dispose() {
|
||||
close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isDisposed() {
|
||||
return isStopped.get();
|
||||
}
|
||||
});
|
||||
} catch (Exception e) {
|
||||
emitter.onError(e);
|
||||
}
|
||||
}, BackpressureStrategy.BUFFER);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
|
||||
public String startBlocking(ChatConfig config, ChatRequest request) {
|
||||
throw new UnsupportedOperationException("SSE传输不支持阻塞式调用,请使用HTTP传输或调用start方法");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (!this.isStop) {
|
||||
try {
|
||||
this.isStop = true;
|
||||
this.listener.onClose(pipeline.getContext());
|
||||
} finally {
|
||||
if (Objects.nonNull(eventSource)) {
|
||||
eventSource.cancel();
|
||||
}
|
||||
if (Objects.nonNull(client)) {
|
||||
client.dispatcher().executorService().shutdown();
|
||||
client.connectionPool().evictAll();
|
||||
}
|
||||
if (this.isStopped.compareAndSet(false, true)) {
|
||||
EventSource eventSource = eventSourceRef.get();
|
||||
if (Objects.nonNull(eventSource)) {
|
||||
eventSource.cancel();
|
||||
}
|
||||
OkHttpClient client = clientRef.get();
|
||||
if (Objects.nonNull(client)) {
|
||||
client.dispatcher().executorService().shutdown();
|
||||
client.connectionPool().evictAll();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(@NotNull EventSource eventSource, @NotNull Response response) {
|
||||
super.onOpen(eventSource, response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
|
||||
super.onEvent(eventSource, id, type, data);
|
||||
|
||||
if (this.isLogResponse) {
|
||||
LOG.info("SSE Response Event - id: {}, type: {}, data: {}", id, type, data);
|
||||
}
|
||||
|
||||
// 如果返回 true 表示流式响应结束,关闭连接, 有一些模型不会主动关闭连接,需要在这里判断
|
||||
if (pipeline.processStreaming(data)) {
|
||||
close();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(@NotNull EventSource eventSource) {
|
||||
super.onClosed(eventSource);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(@NotNull EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
|
||||
super.onFailure(eventSource, t, response);
|
||||
close();
|
||||
this.listener.onError(pipeline.getContext(), t);
|
||||
}
|
||||
|
||||
private Request buildSseRequest(ChatConfig config, ChatRequest request) {
|
||||
String requestBody = buildRequestBody(config, request);
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@ package com.yomahub.liteflow.ai.engine.model.chat;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.InteractClient;
|
||||
import com.yomahub.liteflow.ai.engine.interact.LlmInteractClient;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.Map;
|
||||
@@ -43,8 +45,8 @@ public abstract class AbstractChatModel implements ChatModel {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stream(ChatRequest request) {
|
||||
interactClient.stream(config, request);
|
||||
public Flowable<ChunkEvent> stream(ChatRequest request) {
|
||||
return interactClient.stream(config, request);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package com.yomahub.liteflow.ai.engine.model.chat;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.model.BaseModel;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
@@ -33,10 +35,16 @@ public interface ChatModel extends BaseModel<ChatConfig> {
|
||||
CompletableFuture<ChatResponse> chatAsync(ChatRequest request);
|
||||
|
||||
/**
|
||||
* 执行聊天请求(异步流式)
|
||||
* 执行聊天请求(响应式流式)
|
||||
* <p>
|
||||
* 返回一个 Flowable 流,用户可以订阅和处理每个 ChunkEvent 事件,包括:
|
||||
* - START:流开始事件,包含初始化的 InteractContext
|
||||
* - CHUNK:数据分块事件,包含原始 JSON、转换后的分块和更新的 InteractContext
|
||||
* - COMPLETE:完成事件,包含最终的 ChatResponse 和完整的 InteractContext
|
||||
* - ERROR:错误事件,包含异常信息
|
||||
*
|
||||
* @param request 聊天请求
|
||||
* @return 包含所有流式事件的 Flowable 流
|
||||
*/
|
||||
void stream(ChatRequest request);
|
||||
|
||||
Flowable<ChunkEvent> stream(ChatRequest request);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
package com.yomahub.liteflow.ai.engine.model.chat.entity;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.ModelRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
@@ -13,7 +8,6 @@ import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.JsonSchemaParser;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
@@ -24,9 +18,6 @@ import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.BiConsumer;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
@@ -57,21 +48,6 @@ public class ChatRequest implements ModelRequest {
|
||||
*/
|
||||
protected final TransportType transportType;
|
||||
|
||||
/**
|
||||
* 传输监听器,用于处理请求的开始和结束事件
|
||||
*/
|
||||
protected final TransportListener transportListener;
|
||||
|
||||
/**
|
||||
* 结果处理器,用于处理消息全部发送完毕后的结果
|
||||
*/
|
||||
protected ResultHandler resultHandler;
|
||||
|
||||
/**
|
||||
* 分块回调,用于处理消息分块的各种事件
|
||||
*/
|
||||
protected final ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
|
||||
/**
|
||||
* 响应类型,默认为文本类型
|
||||
*/
|
||||
@@ -114,9 +90,6 @@ public class ChatRequest implements ModelRequest {
|
||||
this.options = ChatOptions.DEFAULT;
|
||||
this.streaming = true; // 默认启用流式输出
|
||||
this.transportType = TransportType.SSE; // 默认使用 SSE 传输
|
||||
this.transportListener = TransportListener.getDefault();
|
||||
this.resultHandler = ResultHandler.getDefault();
|
||||
this.chunkCallbackTransformer = ChunkCallbackTransformer.getDefault();
|
||||
this.responseType = ResponseType.TEXT; // 默认响应类型为文本
|
||||
TypeReference<String> targetType = new TypeReference<String>() {
|
||||
}; // 默认目标类型为 String
|
||||
@@ -130,9 +103,6 @@ public class ChatRequest implements ModelRequest {
|
||||
ChatOptions options,
|
||||
boolean streaming,
|
||||
TransportType transportType,
|
||||
TransportListener transportListener,
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType,
|
||||
boolean strict,
|
||||
@@ -142,9 +112,6 @@ public class ChatRequest implements ModelRequest {
|
||||
this.options = options;
|
||||
this.streaming = streaming;
|
||||
this.transportType = transportType;
|
||||
this.transportListener = transportListener;
|
||||
this.resultHandler = resultHandler;
|
||||
this.chunkCallbackTransformer = chunkCallbackTransformer;
|
||||
this.responseType = responseType;
|
||||
this.strict = strict;
|
||||
this.outputParser = JsonSchemaParser.fromTypeReference(targetType, strict);
|
||||
@@ -163,9 +130,6 @@ public class ChatRequest implements ModelRequest {
|
||||
this.options = builder.options;
|
||||
this.streaming = builder.streaming;
|
||||
this.transportType = builder.transportType;
|
||||
this.transportListener = builder.transportListener;
|
||||
this.resultHandler = builder.resultHandler;
|
||||
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
|
||||
this.responseType = builder.responseType;
|
||||
this.strict = builder.strict;
|
||||
this.outputParser = JsonSchemaParser.fromType(builder.targetType, builder.strict);
|
||||
@@ -243,18 +207,6 @@ public class ChatRequest implements ModelRequest {
|
||||
return transportType;
|
||||
}
|
||||
|
||||
public TransportListener getTransportListener() {
|
||||
return transportListener;
|
||||
}
|
||||
|
||||
public ResultHandler getResultHandler() {
|
||||
return resultHandler;
|
||||
}
|
||||
|
||||
public ChunkCallbackTransformer getChunkCallbackTransformer() {
|
||||
return chunkCallbackTransformer;
|
||||
}
|
||||
|
||||
public ResponseType getResponseType() {
|
||||
return responseType;
|
||||
}
|
||||
@@ -271,10 +223,6 @@ public class ChatRequest implements ModelRequest {
|
||||
return strict;
|
||||
}
|
||||
|
||||
public void setResultHandler(ResultHandler resultHandler) {
|
||||
this.resultHandler = resultHandler;
|
||||
}
|
||||
|
||||
public ToolRegistry getToolRegistry() {
|
||||
return toolRegistry;
|
||||
}
|
||||
@@ -288,10 +236,6 @@ public class ChatRequest implements ModelRequest {
|
||||
protected ChatOptions options;
|
||||
protected boolean streaming = true; // 默认启用流式输出
|
||||
protected TransportType transportType = TransportType.SSE; // 默认使用 SSE 传输
|
||||
protected final LlmListenerAggregator listenerAggregator = new LlmListenerAggregator();
|
||||
protected TransportListener transportListener;
|
||||
protected ResultHandler resultHandler;
|
||||
protected ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
protected ResponseType responseType = ResponseType.TEXT;
|
||||
protected Type targetType = String.class;
|
||||
protected boolean strict = true;
|
||||
@@ -308,9 +252,6 @@ public class ChatRequest implements ModelRequest {
|
||||
if (Objects.isNull(options)) {
|
||||
options = ChatOptions.DEFAULT;
|
||||
}
|
||||
this.transportListener = listenerAggregator.toTransportListener();
|
||||
this.resultHandler = listenerAggregator.toResultHandler();
|
||||
this.chunkCallbackTransformer = listenerAggregator.toChunkCallbackTransformer();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -356,116 +297,6 @@ public class ChatRequest implements ModelRequest {
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求开始时的回调方法
|
||||
*
|
||||
* @param onStart 请求开始时的回调函数
|
||||
* @see TransportListener#onStart(InteractContext)
|
||||
*/
|
||||
public B onStart(Consumer<InteractContext> onStart) {
|
||||
listenerAggregator.onStart = onStart;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求结束时的回调方法
|
||||
*
|
||||
* @param onClose 请求结束时的回调函数
|
||||
* @see TransportListener#onClose(InteractContext)
|
||||
*/
|
||||
public B onClose(Consumer<InteractContext> onClose) {
|
||||
listenerAggregator.onClose = onClose;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求发生错误时的回调方法
|
||||
*
|
||||
* @param onError 请求发生错误时的回调函数
|
||||
* @see TransportListener#onError(InteractContext, Throwable)
|
||||
*/
|
||||
public B onError(BiConsumer<InteractContext, Throwable> onError) {
|
||||
listenerAggregator.onError = onError;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 文本消息的回调方法
|
||||
*
|
||||
* @param onText 文本消息的回调函数
|
||||
* @see ChunkCallbackTransformer#onText(String, InteractContext)
|
||||
*/
|
||||
public B onText(BiFunction<String, InteractContext, String> onText) {
|
||||
listenerAggregator.onText = onText;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 思考消息的回调方法
|
||||
*
|
||||
* @param onThinking 思考消息的回调函数
|
||||
* @see ChunkCallbackTransformer#onThinking(String, InteractContext)
|
||||
*/
|
||||
public B onThinking(BiFunction<String, InteractContext, String> onThinking) {
|
||||
listenerAggregator.onThinking = onThinking;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具调用消息的回调方法
|
||||
*
|
||||
* @param onToolsCalling 工具调用消息的回调函数
|
||||
* @see ChunkCallbackTransformer#onToolsCalling(List, InteractContext)
|
||||
*/
|
||||
public B onToolsCalling(BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling) {
|
||||
listenerAggregator.onToolsCalling = onToolsCalling;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Token 统计信息的回调方法
|
||||
*
|
||||
* @param onUsage Token 统计信息的回调函数
|
||||
* @see ChunkCallbackTransformer#onUsage(Object, InteractContext)
|
||||
*/
|
||||
public B onUsage(BiFunction<Object, InteractContext, Object> onUsage) {
|
||||
listenerAggregator.onUsage = onUsage;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 基础信息/搜索结果的回调方法
|
||||
*
|
||||
* @param onGrounding 基础信息/搜索结果的回调函数
|
||||
* @see ChunkCallbackTransformer#onGrounding(Object, InteractContext)
|
||||
*/
|
||||
public B onGrounding(BiFunction<Object, InteractContext, Object> onGrounding) {
|
||||
listenerAggregator.onGrounding = onGrounding;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求完成时的回调方法
|
||||
*
|
||||
* @param onCompletion 请求完成时的回调函数
|
||||
* @see ResultHandler#onCompletion(ChatResponse, InteractContext)
|
||||
*/
|
||||
public B onCompletion(BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion) {
|
||||
listenerAggregator.onCompletion = onCompletion;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 最终结果处理的回调方法。无论是否发生错误均会调用此方法。
|
||||
*
|
||||
* @param onFinal 请求最终结果的回调函数
|
||||
* @see ResultHandler#onFinal(ChatResponse, InteractContext)
|
||||
*/
|
||||
public B onFinal(BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal) {
|
||||
listenerAggregator.onFinal = onFinal;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置响应类型
|
||||
*
|
||||
@@ -524,98 +355,6 @@ public class ChatRequest implements ModelRequest {
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 内部聚合类
|
||||
*/
|
||||
protected static class LlmListenerAggregator {
|
||||
Consumer<InteractContext> onStart = context -> {
|
||||
};
|
||||
Consumer<InteractContext> onClose = context -> {
|
||||
};
|
||||
BiConsumer<InteractContext, Throwable> onError = (context, t) -> {
|
||||
throw new LiteFlowAIEngineException(t.getMessage(), t);
|
||||
};
|
||||
BiFunction<String, InteractContext, String> onText = (content, context) -> content;
|
||||
BiFunction<String, InteractContext, String> onThinking = (content, context) -> content;
|
||||
BiFunction<List<ToolCall>, InteractContext, List<ToolCall>> onToolsCalling = (toolCalls, context) -> toolCalls;
|
||||
BiFunction<Object, InteractContext, Object> onUsage = (content, context) -> content;
|
||||
BiFunction<Object, InteractContext, Object> onGrounding = (content, context) -> content;
|
||||
BiFunction<ChatResponse, InteractContext, ChatResponse> onCompletion = (response, context) -> response;
|
||||
BiFunction<ChatResponse, InteractContext, ChatResponse> onFinal = (response, context) -> response;
|
||||
|
||||
TransportListener toTransportListener() {
|
||||
Objects.requireNonNull(onStart, "onStart cannot be null");
|
||||
Objects.requireNonNull(onClose, "onClose cannot be null");
|
||||
Objects.requireNonNull(onError, "onError cannot be null");
|
||||
return new TransportListener() {
|
||||
@Override
|
||||
public void onStart(InteractContext context) {
|
||||
onStart.accept(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(InteractContext context) {
|
||||
onClose.accept(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(InteractContext context, Throwable t) {
|
||||
onError.accept(context, t);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
ResultHandler toResultHandler() {
|
||||
Objects.requireNonNull(onCompletion, "onCompletion cannot be null");
|
||||
Objects.requireNonNull(onFinal, "onFinal cannot be null");
|
||||
return new ResultHandler() {
|
||||
@Override
|
||||
public ChatResponse onCompletion(ChatResponse response, InteractContext context) {
|
||||
return onCompletion.apply(response, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse onFinal(ChatResponse response, InteractContext context) {
|
||||
return onFinal.apply(response, context);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
ChunkCallbackTransformer toChunkCallbackTransformer() {
|
||||
Objects.requireNonNull(onText, "onText cannot be null");
|
||||
Objects.requireNonNull(onThinking, "onThinking cannot be null");
|
||||
Objects.requireNonNull(onToolsCalling, "onToolsCalling cannot be null");
|
||||
Objects.requireNonNull(onUsage, "onUsage cannot be null");
|
||||
Objects.requireNonNull(onGrounding, "onGrounding cannot be null");
|
||||
return new ChunkCallbackTransformer() {
|
||||
@Override
|
||||
public String onText(String content, InteractContext context) {
|
||||
return onText.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onThinking(String content, InteractContext context) {
|
||||
return onThinking.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
|
||||
return onToolsCalling.apply(toolCalls, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onUsage(Object content, InteractContext context) {
|
||||
return onUsage.apply(content, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onGrounding(Object content, InteractContext context) {
|
||||
return onGrounding.apply(content, context);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class BuilderImpl extends Builder<BuilderImpl> {
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package com.yomahub.liteflow.ai.engine.tool.registry;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 工具调用注册接口
|
||||
@@ -27,4 +31,19 @@ public interface ToolRegistry {
|
||||
* @return 所有工具的集合
|
||||
*/
|
||||
Collection<ToolCallBack> getAllTools();
|
||||
|
||||
/**
|
||||
* 执行工具调用
|
||||
*
|
||||
* @param toolCall 工具调用信息
|
||||
* @return 工具调用结果消息
|
||||
*/
|
||||
default ToolMessage executeToolCall(ToolCall toolCall) {
|
||||
ToolCallBack toolCallBack = getTool(toolCall.getName());
|
||||
if (Objects.isNull(toolCallBack)) {
|
||||
throw new LiteFlowAIEngineException("Unable to find target tool with tool name: " + toolCall.getName());
|
||||
}
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
return new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
|
||||
@@ -2,9 +2,6 @@ package com.yomahub.liteflow.ai.model.dashscope.model.chat;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
@@ -56,16 +53,12 @@ public class DashScopeChatRequest extends ChatRequest {
|
||||
ChatOptions options,
|
||||
boolean streaming,
|
||||
TransportType transportType,
|
||||
TransportListener transportListener,
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType,
|
||||
boolean strict,
|
||||
ToolRegistry toolRegistry
|
||||
) {
|
||||
super(messages, options, streaming, transportType,
|
||||
transportListener, resultHandler, chunkCallbackTransformer,
|
||||
responseType, targetType, strict, toolRegistry);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package com.yomahub.liteflow.ai.model.ollama.model.chat;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
@@ -39,16 +36,12 @@ public class OllamaChatRequest extends ChatRequest {
|
||||
ChatOptions options,
|
||||
boolean streaming,
|
||||
TransportType transportType,
|
||||
TransportListener transportListener,
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType,
|
||||
boolean strict,
|
||||
ToolRegistry toolRegistry
|
||||
) {
|
||||
super(messages, options, streaming, transportType,
|
||||
transportListener, resultHandler, chunkCallbackTransformer,
|
||||
responseType, targetType, strict, toolRegistry);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,10 @@ import cn.hutool.core.collection.CollectionUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
|
||||
@@ -2,9 +2,6 @@ package com.yomahub.liteflow.ai.model.openai.model.chat;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ResultHandler;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
@@ -43,16 +40,12 @@ public class OpenAIChatRequest extends ChatRequest {
|
||||
ChatOptions options,
|
||||
boolean streaming,
|
||||
TransportType transportType,
|
||||
TransportListener transportListener,
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType,
|
||||
boolean strict,
|
||||
ToolRegistry toolRegistry
|
||||
) {
|
||||
super(messages, options, streaming, transportType,
|
||||
transportListener, resultHandler, chunkCallbackTransformer,
|
||||
responseType, targetType, strict, toolRegistry);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,26 +1,8 @@
|
||||
package com.yomahub.liteflow.ai.workflow.coze.invocation;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.coze.openapi.client.chat.model.ChatEvent;
|
||||
import com.coze.openapi.client.workflows.chat.WorkflowChatReq;
|
||||
import com.coze.openapi.service.auth.TokenAuth;
|
||||
import com.coze.openapi.service.config.Consts;
|
||||
import com.coze.openapi.service.service.CozeAPI;
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
|
||||
import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler;
|
||||
import com.yomahub.liteflow.ai.util.SpringUtil;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.annotation.CozeWorkflowChat;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.config.CozeWorkflowProperty;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.util.KeyValueUtil;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.wrap.CozeWorkflowChatProxyWrapBean;
|
||||
import io.reactivex.Flowable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent;
|
||||
|
||||
/**
|
||||
* Coze 对话流 调用处理器
|
||||
@@ -36,60 +18,61 @@ public class CozeWorkflowChatInvocationHandler extends AbstractAIInvocationHandl
|
||||
|
||||
@Override
|
||||
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
|
||||
CozeWorkflowChat annotation = wrapBean.getAnnotation();
|
||||
|
||||
CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
|
||||
|
||||
CozeAPI api = new CozeAPI.Builder()
|
||||
.auth(new TokenAuth(property.getApiKey()))
|
||||
.baseURL(
|
||||
StrUtil.isBlank(property.getBaseUrl()) ?
|
||||
Consts.COZE_CN_BASE_URL :
|
||||
property.getBaseUrl()
|
||||
)
|
||||
.build();
|
||||
|
||||
WorkflowChatReq.WorkflowChatReqBuilder<?, ?> builder = WorkflowChatReq.builder();
|
||||
|
||||
setIfPresent(builder::workflowID, annotation.workflowId());
|
||||
|
||||
setIfPresent(builder::additionalMessages, annotation.additionalMessages(), processorContext, List.class);
|
||||
|
||||
if (annotation.parameters().length > 0) {
|
||||
Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
|
||||
if (!paramMap.isEmpty()) {
|
||||
builder.parameters(paramMap);
|
||||
}
|
||||
}
|
||||
|
||||
setIfPresent(builder::appID, annotation.appId());
|
||||
|
||||
setIfPresent(builder::botID, annotation.botId());
|
||||
|
||||
setIfPresent(builder::conversationID, annotation.conversationId());
|
||||
|
||||
if (annotation.ext().length > 0) {
|
||||
Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
|
||||
if (!extMap.isEmpty()) {
|
||||
builder.ext(extMap);
|
||||
}
|
||||
}
|
||||
|
||||
setIfPresent(builder::connectTimeout, annotation.connectTimeout());
|
||||
|
||||
setIfPresent(builder::readTimeout, annotation.readTimeout());
|
||||
|
||||
setIfPresent(builder::writeTimeout, annotation.writeTimeout());
|
||||
|
||||
setIfPresent(builder::customerToken, annotation.customerToken());
|
||||
|
||||
Flowable<ChatEvent> res = api.workflows().chat().stream(builder.build());
|
||||
InteractContext context = new InteractContext();
|
||||
res.blockingForEach(chunk -> {
|
||||
ChatContext chatContext = processorContext.getChatContext();
|
||||
chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
|
||||
});
|
||||
return null;
|
||||
// CozeWorkflowChat annotation = wrapBean.getAnnotation();
|
||||
//
|
||||
// CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
|
||||
//
|
||||
// CozeAPI api = new CozeAPI.Builder()
|
||||
// .auth(new TokenAuth(property.getApiKey()))
|
||||
// .baseURL(
|
||||
// StrUtil.isBlank(property.getBaseUrl()) ?
|
||||
// Consts.COZE_CN_BASE_URL :
|
||||
// property.getBaseUrl()
|
||||
// )
|
||||
// .build();
|
||||
//
|
||||
// WorkflowChatReq.WorkflowChatReqBuilder<?, ?> builder = WorkflowChatReq.builder();
|
||||
//
|
||||
// setIfPresent(builder::workflowID, annotation.workflowId());
|
||||
//
|
||||
// setIfPresent(builder::additionalMessages, annotation.additionalMessages(), processorContext, List.class);
|
||||
//
|
||||
// if (annotation.parameters().length > 0) {
|
||||
// Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
|
||||
// if (!paramMap.isEmpty()) {
|
||||
// builder.parameters(paramMap);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// setIfPresent(builder::appID, annotation.appId());
|
||||
//
|
||||
// setIfPresent(builder::botID, annotation.botId());
|
||||
//
|
||||
// setIfPresent(builder::conversationID, annotation.conversationId());
|
||||
//
|
||||
// if (annotation.ext().length > 0) {
|
||||
// Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
|
||||
// if (!extMap.isEmpty()) {
|
||||
// builder.ext(extMap);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// setIfPresent(builder::connectTimeout, annotation.connectTimeout());
|
||||
//
|
||||
// setIfPresent(builder::readTimeout, annotation.readTimeout());
|
||||
//
|
||||
// setIfPresent(builder::writeTimeout, annotation.writeTimeout());
|
||||
//
|
||||
// setIfPresent(builder::customerToken, annotation.customerToken());
|
||||
//
|
||||
// Flowable<ChatEvent> res = api.workflows().chat().stream(builder.build());
|
||||
// InteractContext context = new InteractContext();
|
||||
// res.blockingForEach(chunk -> {
|
||||
// ChatContext chatContext = processorContext.getChatContext();
|
||||
// chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
|
||||
// });
|
||||
// return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,26 +1,8 @@
|
||||
package com.yomahub.liteflow.ai.workflow.coze.invocation;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.coze.openapi.client.workflows.run.RunWorkflowReq;
|
||||
import com.coze.openapi.client.workflows.run.model.WorkflowEvent;
|
||||
import com.coze.openapi.service.auth.TokenAuth;
|
||||
import com.coze.openapi.service.config.Consts;
|
||||
import com.coze.openapi.service.service.CozeAPI;
|
||||
import com.coze.openapi.service.service.workflow.WorkflowRunService;
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
|
||||
import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler;
|
||||
import com.yomahub.liteflow.ai.util.SpringUtil;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.annotation.CozeWorkflowRun;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.config.CozeWorkflowProperty;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.util.KeyValueUtil;
|
||||
import com.yomahub.liteflow.ai.workflow.coze.wrap.CozeWorkflowRunProxyWrapBean;
|
||||
import io.reactivex.Flowable;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent;
|
||||
|
||||
/**
|
||||
* Coze 工作流 调用处理器
|
||||
@@ -37,62 +19,63 @@ public class CozeWorkflowRunInvocationHandler extends AbstractAIInvocationHandle
|
||||
|
||||
@Override
|
||||
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
|
||||
CozeWorkflowRun annotation = wrapBean.getAnnotation();
|
||||
|
||||
CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
|
||||
|
||||
CozeAPI api = new CozeAPI.Builder()
|
||||
.auth(new TokenAuth(property.getApiKey()))
|
||||
.baseURL(
|
||||
StrUtil.isBlank(property.getBaseUrl()) ?
|
||||
Consts.COZE_CN_BASE_URL :
|
||||
property.getBaseUrl()
|
||||
)
|
||||
.build();
|
||||
|
||||
RunWorkflowReq.RunWorkflowReqBuilder<?, ?> builder = RunWorkflowReq.builder();
|
||||
|
||||
setIfPresent(builder::workflowID, annotation.workflowId());
|
||||
|
||||
if (annotation.parameters().length > 0) {
|
||||
Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
|
||||
if (!paramMap.isEmpty()) {
|
||||
builder.parameters(paramMap);
|
||||
}
|
||||
}
|
||||
|
||||
setIfPresent(builder::appID, annotation.appId());
|
||||
|
||||
setIfPresent(builder::botID, annotation.botId());
|
||||
|
||||
if (annotation.ext().length > 0) {
|
||||
Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
|
||||
if (!extMap.isEmpty()) {
|
||||
builder.ext(extMap);
|
||||
}
|
||||
}
|
||||
|
||||
setIfPresent(builder::connectTimeout, annotation.connectTimeout());
|
||||
|
||||
setIfPresent(builder::readTimeout, annotation.readTimeout());
|
||||
|
||||
setIfPresent(builder::writeTimeout, annotation.writeTimeout());
|
||||
|
||||
setIfPresent(builder::customerToken, annotation.customerToken());
|
||||
|
||||
WorkflowRunService runs = api.workflows().runs();
|
||||
|
||||
if (annotation.stream()) {
|
||||
Flowable<WorkflowEvent> res = runs.stream(builder.build());
|
||||
InteractContext context = new InteractContext();
|
||||
res.blockingForEach(chunk -> {
|
||||
ChatContext chatContext = processorContext.getChatContext();
|
||||
chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
|
||||
});
|
||||
return null;
|
||||
} else {
|
||||
return runs.create(builder.build());
|
||||
}
|
||||
return null;
|
||||
// CozeWorkflowRun annotation = wrapBean.getAnnotation();
|
||||
//
|
||||
// CozeWorkflowProperty property = SpringUtil.getBean(CozeWorkflowProperty.class);
|
||||
//
|
||||
// CozeAPI api = new CozeAPI.Builder()
|
||||
// .auth(new TokenAuth(property.getApiKey()))
|
||||
// .baseURL(
|
||||
// StrUtil.isBlank(property.getBaseUrl()) ?
|
||||
// Consts.COZE_CN_BASE_URL :
|
||||
// property.getBaseUrl()
|
||||
// )
|
||||
// .build();
|
||||
//
|
||||
// RunWorkflowReq.RunWorkflowReqBuilder<?, ?> builder = RunWorkflowReq.builder();
|
||||
//
|
||||
// setIfPresent(builder::workflowID, annotation.workflowId());
|
||||
//
|
||||
// if (annotation.parameters().length > 0) {
|
||||
// Map<String, Object> paramMap = KeyValueUtil.buildObjectMapFromKeyValue(annotation.parameters(), processorContext);
|
||||
// if (!paramMap.isEmpty()) {
|
||||
// builder.parameters(paramMap);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// setIfPresent(builder::appID, annotation.appId());
|
||||
//
|
||||
// setIfPresent(builder::botID, annotation.botId());
|
||||
//
|
||||
// if (annotation.ext().length > 0) {
|
||||
// Map<String, String> extMap = KeyValueUtil.buildStringMapFromKeyValue(annotation.ext(), processorContext);
|
||||
// if (!extMap.isEmpty()) {
|
||||
// builder.ext(extMap);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// setIfPresent(builder::connectTimeout, annotation.connectTimeout());
|
||||
//
|
||||
// setIfPresent(builder::readTimeout, annotation.readTimeout());
|
||||
//
|
||||
// setIfPresent(builder::writeTimeout, annotation.writeTimeout());
|
||||
//
|
||||
// setIfPresent(builder::customerToken, annotation.customerToken());
|
||||
//
|
||||
// WorkflowRunService runs = api.workflows().runs();
|
||||
//
|
||||
// if (annotation.stream()) {
|
||||
// Flowable<WorkflowEvent> res = runs.stream(builder.build());
|
||||
// InteractContext context = new InteractContext();
|
||||
// res.blockingForEach(chunk -> {
|
||||
// ChatContext chatContext = processorContext.getChatContext();
|
||||
// chatContext.getStreamHandler().onText(chunk.getMessage().getContent(), context);
|
||||
// });
|
||||
// return null;
|
||||
// } else {
|
||||
// return runs.create(builder.build());
|
||||
// }
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.alibaba.dashscope.app.FlowStreamMode;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.yomahub.liteflow.ai.annotation.AIComponent;
|
||||
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.InteractContext;
|
||||
|
||||
import java.lang.annotation.*;
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,26 +1,11 @@
|
||||
package com.yomahub.liteflow.ai.workflow.dashscope.invocation;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.alibaba.dashscope.app.Application;
|
||||
import com.alibaba.dashscope.app.ApplicationParam;
|
||||
import com.alibaba.dashscope.app.ApplicationResult;
|
||||
import com.alibaba.dashscope.app.RagOptions;
|
||||
import com.alibaba.dashscope.exception.InputRequiredException;
|
||||
import com.alibaba.dashscope.exception.NoApiKeyException;
|
||||
import com.alibaba.dashscope.utils.JsonUtils;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
|
||||
import com.yomahub.liteflow.ai.parse.context.ContextAccessor;
|
||||
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
|
||||
import com.yomahub.liteflow.ai.proxy.invocation.AbstractAIInvocationHandler;
|
||||
import com.yomahub.liteflow.ai.util.KeyValue;
|
||||
import com.yomahub.liteflow.ai.util.SpringUtil;
|
||||
import com.yomahub.liteflow.ai.workflow.dashscope.annotation.DashScopeWorkflow;
|
||||
import com.yomahub.liteflow.ai.workflow.dashscope.config.DashScopeWorkflowProperty;
|
||||
import com.yomahub.liteflow.ai.workflow.dashscope.wrap.DashScopeWorkflowProxyWrapBean;
|
||||
import io.reactivex.Flowable;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
@@ -46,95 +31,96 @@ public class DashScopeWorkflowInvocationHandler extends AbstractAIInvocationHand
|
||||
|
||||
@Override
|
||||
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
|
||||
DashScopeWorkflow dashScopeWorkflow = wrapBean.getAnnotation();
|
||||
|
||||
DashScopeWorkflowProperty property = SpringUtil.getBean(DashScopeWorkflowProperty.class);
|
||||
|
||||
ApplicationParam.ApplicationParamBuilder<?, ?> builder = ApplicationParam.builder();
|
||||
|
||||
builder.appId(dashScopeWorkflow.appId());
|
||||
builder.apiKey(property.getApiKey());
|
||||
|
||||
setIfPresent(builder::prompt, dashScopeWorkflow.prompt(), processorContext);
|
||||
|
||||
String history = dashScopeWorkflow.history();
|
||||
setIfPresent(builder::history, StrUtil.isNotBlank(history) ?
|
||||
ContextAccessor.searchContextByExpression(history, processorContext) : null);
|
||||
|
||||
String messages = dashScopeWorkflow.messages();
|
||||
setIfPresent(builder::messages, StrUtil.isNotBlank(messages) ?
|
||||
ContextAccessor.searchContextByExpression(messages, processorContext) : null);
|
||||
|
||||
setIfPresent(builder::sessionId, dashScopeWorkflow.sessionId());
|
||||
|
||||
setIfPresent(builder::hasThoughts, dashScopeWorkflow.hasThoughts());
|
||||
|
||||
setIfPresent(builder::bizParams, dashScopeWorkflow.bizParams(),
|
||||
processorContext, JsonObject.class, JsonUtils::parse);
|
||||
|
||||
setIfPresent(builder::topP, dashScopeWorkflow.topP());
|
||||
|
||||
setIfPresent(builder::topK, dashScopeWorkflow.topK());
|
||||
|
||||
setIfPresent(builder::seed, dashScopeWorkflow.seed());
|
||||
|
||||
setIfPresent(builder::temperature, dashScopeWorkflow.temperature());
|
||||
|
||||
setIfPresent(builder::incrementalOutput, dashScopeWorkflow.incrementalOutput());
|
||||
|
||||
setIfPresent(builder::memoryId, dashScopeWorkflow.memoryId());
|
||||
|
||||
setIfPresent(builder::images, dashScopeWorkflow.images(), processorContext, List.class,
|
||||
s -> Arrays.stream(s.split(COMMA_SPLITTER))
|
||||
.map(String::trim)
|
||||
.collect(Collectors.toList()));
|
||||
|
||||
setIfPresent(builder::ragOptions, buildRagOptions(dashScopeWorkflow.ragOptions(), processorContext));
|
||||
|
||||
setIfPresent(builder::mcpServers, dashScopeWorkflow.mcpServers(), processorContext, List.class,
|
||||
s -> Arrays.stream(s.split(COMMA_SPLITTER))
|
||||
.map(String::trim)
|
||||
.collect(Collectors.toList()));
|
||||
|
||||
setIfPresent(builder::enableWebSearch, dashScopeWorkflow.enableWebSearch());
|
||||
|
||||
setIfPresent(builder::enableSystemTime, dashScopeWorkflow.enableSystemTime());
|
||||
|
||||
setIfPresent(builder::enablePremium, dashScopeWorkflow.enablePremium());
|
||||
|
||||
setIfPresent(builder::dialogRound, dashScopeWorkflow.dialogRound());
|
||||
|
||||
setIfPresent(builder::modelId, dashScopeWorkflow.modelId());
|
||||
|
||||
setIfPresent(builder::flowStreamMode, dashScopeWorkflow.flowStreamMode());
|
||||
|
||||
setIfPresent(builder::enableThinking, dashScopeWorkflow.enableThinking());
|
||||
|
||||
ApplicationParam applicationParam = builder.build();
|
||||
Application application = StrUtil.isNotBlank(property.getApiUrl()) ?
|
||||
new Application(property.getApiUrl()) :
|
||||
new Application();
|
||||
|
||||
if (dashScopeWorkflow.stream()) {
|
||||
try {
|
||||
Flowable<ApplicationResult> res = application.streamCall(applicationParam);
|
||||
InteractContext context = new InteractContext();
|
||||
res.blockingForEach(chunk -> {
|
||||
ChatContext chatContext = processorContext.getChatContext();
|
||||
chatContext.getStreamHandler().onText(chunk.getOutput().getText(), context);
|
||||
context.addText(chunk.getOutput().getText());
|
||||
});
|
||||
return null;
|
||||
} catch (NoApiKeyException | InputRequiredException e) {
|
||||
throw new LiteFlowAIException("DashScope stream call failed", e);
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
return application.call(applicationParam);
|
||||
} catch (NoApiKeyException | InputRequiredException e) {
|
||||
throw new LiteFlowAIException("DashScope call failed", e);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
// DashScopeWorkflow dashScopeWorkflow = wrapBean.getAnnotation();
|
||||
//
|
||||
// DashScopeWorkflowProperty property = SpringUtil.getBean(DashScopeWorkflowProperty.class);
|
||||
//
|
||||
// ApplicationParam.ApplicationParamBuilder<?, ?> builder = ApplicationParam.builder();
|
||||
//
|
||||
// builder.appId(dashScopeWorkflow.appId());
|
||||
// builder.apiKey(property.getApiKey());
|
||||
//
|
||||
// setIfPresent(builder::prompt, dashScopeWorkflow.prompt(), processorContext);
|
||||
//
|
||||
// String history = dashScopeWorkflow.history();
|
||||
// setIfPresent(builder::history, StrUtil.isNotBlank(history) ?
|
||||
// ContextAccessor.searchContextByExpression(history, processorContext) : null);
|
||||
//
|
||||
// String messages = dashScopeWorkflow.messages();
|
||||
// setIfPresent(builder::messages, StrUtil.isNotBlank(messages) ?
|
||||
// ContextAccessor.searchContextByExpression(messages, processorContext) : null);
|
||||
//
|
||||
// setIfPresent(builder::sessionId, dashScopeWorkflow.sessionId());
|
||||
//
|
||||
// setIfPresent(builder::hasThoughts, dashScopeWorkflow.hasThoughts());
|
||||
//
|
||||
// setIfPresent(builder::bizParams, dashScopeWorkflow.bizParams(),
|
||||
// processorContext, JsonObject.class, JsonUtils::parse);
|
||||
//
|
||||
// setIfPresent(builder::topP, dashScopeWorkflow.topP());
|
||||
//
|
||||
// setIfPresent(builder::topK, dashScopeWorkflow.topK());
|
||||
//
|
||||
// setIfPresent(builder::seed, dashScopeWorkflow.seed());
|
||||
//
|
||||
// setIfPresent(builder::temperature, dashScopeWorkflow.temperature());
|
||||
//
|
||||
// setIfPresent(builder::incrementalOutput, dashScopeWorkflow.incrementalOutput());
|
||||
//
|
||||
// setIfPresent(builder::memoryId, dashScopeWorkflow.memoryId());
|
||||
//
|
||||
// setIfPresent(builder::images, dashScopeWorkflow.images(), processorContext, List.class,
|
||||
// s -> Arrays.stream(s.split(COMMA_SPLITTER))
|
||||
// .map(String::trim)
|
||||
// .collect(Collectors.toList()));
|
||||
//
|
||||
// setIfPresent(builder::ragOptions, buildRagOptions(dashScopeWorkflow.ragOptions(), processorContext));
|
||||
//
|
||||
// setIfPresent(builder::mcpServers, dashScopeWorkflow.mcpServers(), processorContext, List.class,
|
||||
// s -> Arrays.stream(s.split(COMMA_SPLITTER))
|
||||
// .map(String::trim)
|
||||
// .collect(Collectors.toList()));
|
||||
//
|
||||
// setIfPresent(builder::enableWebSearch, dashScopeWorkflow.enableWebSearch());
|
||||
//
|
||||
// setIfPresent(builder::enableSystemTime, dashScopeWorkflow.enableSystemTime());
|
||||
//
|
||||
// setIfPresent(builder::enablePremium, dashScopeWorkflow.enablePremium());
|
||||
//
|
||||
// setIfPresent(builder::dialogRound, dashScopeWorkflow.dialogRound());
|
||||
//
|
||||
// setIfPresent(builder::modelId, dashScopeWorkflow.modelId());
|
||||
//
|
||||
// setIfPresent(builder::flowStreamMode, dashScopeWorkflow.flowStreamMode());
|
||||
//
|
||||
// setIfPresent(builder::enableThinking, dashScopeWorkflow.enableThinking());
|
||||
//
|
||||
// ApplicationParam applicationParam = builder.build();
|
||||
// Application application = StrUtil.isNotBlank(property.getApiUrl()) ?
|
||||
// new Application(property.getApiUrl()) :
|
||||
// new Application();
|
||||
//
|
||||
// if (dashScopeWorkflow.stream()) {
|
||||
// try {
|
||||
// Flowable<ApplicationResult> res = application.streamCall(applicationParam);
|
||||
// InteractContext context = new InteractContext();
|
||||
// res.blockingForEach(chunk -> {
|
||||
// ChatContext chatContext = processorContext.getChatContext();
|
||||
// chatContext.getStreamHandler().onText(chunk.getOutput().getText(), context);
|
||||
// context.addText(chunk.getOutput().getText());
|
||||
// });
|
||||
// return null;
|
||||
// } catch (NoApiKeyException | InputRequiredException e) {
|
||||
// throw new LiteFlowAIException("DashScope stream call failed", e);
|
||||
// }
|
||||
// } else {
|
||||
// try {
|
||||
// return application.call(applicationParam);
|
||||
// } catch (NoApiKeyException | InputRequiredException e) {
|
||||
// throw new LiteFlowAIException("DashScope call failed", e);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
<victools.version>4.36.0</victools.version>
|
||||
<dashscope-sdk.version>2.21.5</dashscope-sdk.version>
|
||||
<coze-api.version>0.4.2</coze-api.version>
|
||||
<rxjava.version>3.1.8</rxjava.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
@@ -72,6 +73,12 @@
|
||||
<artifactId>coze-api</artifactId>
|
||||
<version>${coze-api.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>io.reactivex.rxjava3</groupId>
|
||||
<artifactId>rxjava</artifactId>
|
||||
<version>${rxjava.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
</project>
|
||||
@@ -3,6 +3,8 @@ package com.yomahub.liteflow.test.ai.core.chat;
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.core.FlowExecutor;
|
||||
@@ -73,46 +75,47 @@ public class ChatTest extends MockAITest {
|
||||
}
|
||||
|
||||
private StreamHandler getStreamHandler() {
|
||||
return StreamHandler.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
return eventStream -> eventStream.doOnNext(chunkEvent -> {
|
||||
if (chunkEvent.isStart()) {
|
||||
System.out.println("chat start");
|
||||
}
|
||||
if (chunkEvent.isChunk()) {
|
||||
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
|
||||
switch (chunk.getType()) {
|
||||
case TEXT:
|
||||
System.out.println(chunk.getData());
|
||||
break;
|
||||
case THINKING:
|
||||
System.out.println("[Thinking] " + chunk.getData());
|
||||
}
|
||||
}
|
||||
if (chunkEvent.isComplete()) {
|
||||
ChatResponse response = chunkEvent.getFinalResponse();
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
})
|
||||
.build();
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
System.out.println("chat close");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package com.yomahub.liteflow.test.ai.core.structure;
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.core.FlowExecutor;
|
||||
@@ -80,45 +82,46 @@ public class StructureTest extends MockAITest {
|
||||
}
|
||||
|
||||
private StreamHandler getStreamHandler() {
|
||||
return StreamHandler.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
return eventStream -> eventStream.doOnNext(chunkEvent -> {
|
||||
if (chunkEvent.isStart()) {
|
||||
System.out.println("chat start");
|
||||
}
|
||||
if (chunkEvent.isChunk()) {
|
||||
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
|
||||
switch (chunk.getType()) {
|
||||
case TEXT:
|
||||
System.out.println(chunk.getData());
|
||||
break;
|
||||
case THINKING:
|
||||
System.out.println("[Thinking] " + chunk.getData());
|
||||
}
|
||||
}
|
||||
if (chunkEvent.isComplete()) {
|
||||
ChatResponse response = chunkEvent.getFinalResponse();
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
})
|
||||
.build();
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
System.out.println("chat close");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package com.yomahub.liteflow.test.ai.core.tool;
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
@@ -98,78 +100,56 @@ public class ToolTest extends MockAITest {
|
||||
|
||||
private ChatContext buildBlockingChatContext() {
|
||||
return ChatContext.builder()
|
||||
.streamHandler(StreamHandler.builder()
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
})
|
||||
.build())
|
||||
.build();
|
||||
}
|
||||
|
||||
private ChatContext buildStreamingChatContext() {
|
||||
return ChatContext.builder()
|
||||
.streamHandler(StreamHandler.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
})
|
||||
.build())
|
||||
.streamHandler(getStreamHandler())
|
||||
.build();
|
||||
}
|
||||
|
||||
private StreamHandler getStreamHandler() {
|
||||
return eventStream -> eventStream.doOnNext(chunkEvent -> {
|
||||
if (chunkEvent.isStart()) {
|
||||
System.out.println("chat start");
|
||||
}
|
||||
if (chunkEvent.isChunk()) {
|
||||
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
|
||||
switch (chunk.getType()) {
|
||||
case TEXT:
|
||||
System.out.println(chunk.getData());
|
||||
break;
|
||||
case THINKING:
|
||||
System.out.println("[Thinking] " + chunk.getData());
|
||||
}
|
||||
}
|
||||
if (chunkEvent.isComplete()) {
|
||||
ChatResponse response = chunkEvent.getFinalResponse();
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
System.out.println("chat close");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,28 +1,5 @@
|
||||
package com.yomahub.liteflow.test.ai.engine.interact;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.callbacks.ChunkCallbackTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.InteractContext;
|
||||
import com.yomahub.liteflow.ai.engine.interact.protocol.ProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.dashscope.interact.DashScopeProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.model.ollama.interact.OllamaProtocolTransformer;
|
||||
import com.yomahub.liteflow.ai.model.openai.interact.OpenAIProtocolTransformer;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* 交互处理器测试类
|
||||
* 测试各种协议转换器和流式处理管道的功能
|
||||
@@ -31,368 +8,327 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||
* @since 2.16.0
|
||||
*/
|
||||
public class InteractTest {
|
||||
|
||||
private ChunkProcessPipeline pipeline;
|
||||
private ProtocolTransformer protocolTransformer;
|
||||
private ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
private InteractContext context;
|
||||
|
||||
/**
|
||||
* 创建协议转换器的工厂方法
|
||||
*
|
||||
* @param provider 提供者枚举
|
||||
* @return 对应的协议转换器实例
|
||||
*/
|
||||
private ProtocolTransformer createProtocolTransformer(ProviderEnum provider) {
|
||||
switch (provider) {
|
||||
case OPENAI:
|
||||
return new OpenAIProtocolTransformer();
|
||||
case DASHSCOPE:
|
||||
return new DashScopeProtocolTransformer();
|
||||
case OLLAMA:
|
||||
return new OllamaProtocolTransformer();
|
||||
default:
|
||||
throw new IllegalArgumentException("不支持的提供者: " + provider);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 简单的回调处理器,用于测试
|
||||
*/
|
||||
private static class TestChunkCallbackTransformer implements ChunkCallbackTransformer {
|
||||
@Override
|
||||
public String onText(String content, InteractContext context) {
|
||||
System.out.println("onText: " + content);
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onThinking(String content, InteractContext context) {
|
||||
System.out.println("onThinking: " + content);
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ToolCall> onToolsCalling(List<ToolCall> toolCalls, InteractContext context) {
|
||||
// 测试用的简单实现,直接返回原工具调用
|
||||
return toolCalls;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onUsage(Object content, InteractContext context) {
|
||||
// 测试用的简单实现,直接返回原使用统计
|
||||
return content;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object onGrounding(Object content, InteractContext context) {
|
||||
// 测试用的简单实现,直接返回原基础信息
|
||||
return content;
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("OpenAI协议转换器测试")
|
||||
class OpenAITests {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
context = new InteractContext();
|
||||
protocolTransformer = createProtocolTransformer(ProviderEnum.OPENAI);
|
||||
chunkCallbackTransformer = new TestChunkCallbackTransformer();
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试OpenAI流式文本响应转换")
|
||||
void testStreamingTextResponse() {
|
||||
testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试OpenAI流式工具调用响应转换")
|
||||
void testStreamingToolCallResponse() {
|
||||
testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试OpenAI流式结构化响应转换")
|
||||
void testStreamingStructuredResponse() {
|
||||
testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试OpenAI阻塞式文本响应转换")
|
||||
void testBlockingTextResponse() {
|
||||
testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TEXT, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试OpenAI阻塞式工具调用响应转换")
|
||||
void testBlockingToolCallResponse() {
|
||||
testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试OpenAI阻塞式结构化响应转换")
|
||||
void testBlockingStructuredResponse() {
|
||||
testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("DashScope协议转换器测试")
|
||||
class DashScopeTests {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
context = new InteractContext();
|
||||
protocolTransformer = createProtocolTransformer(ProviderEnum.DASHSCOPE);
|
||||
chunkCallbackTransformer = new TestChunkCallbackTransformer();
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试DashScope流式文本响应转换")
|
||||
void testStreamingTextResponse() {
|
||||
testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试DashScope流式工具调用响应转换")
|
||||
void testStreamingToolCallResponse() {
|
||||
testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试DashScope流式结构化响应转换")
|
||||
void testStreamingStructuredResponse() {
|
||||
testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试DashScope阻塞式文本响应转换")
|
||||
void testBlockingTextResponse() {
|
||||
testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TEXT, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试DashScope阻塞式工具调用响应转换")
|
||||
void testBlockingToolCallResponse() {
|
||||
testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试DashScope阻塞式结构化响应转换")
|
||||
void testBlockingStructuredResponse() {
|
||||
testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Ollama协议转换器测试")
|
||||
class OllamaTests {
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
context = new InteractContext();
|
||||
protocolTransformer = createProtocolTransformer(ProviderEnum.OLLAMA);
|
||||
chunkCallbackTransformer = new TestChunkCallbackTransformer();
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试Ollama流式文本响应转换")
|
||||
void testStreamingTextResponse() {
|
||||
testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试Ollama流式工具调用响应转换")
|
||||
void testStreamingToolCallResponse() {
|
||||
testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试Ollama流式结构化响应转换")
|
||||
void testStreamingStructuredResponse() {
|
||||
testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试Ollama阻塞式文本响应转换")
|
||||
void testBlockingTextResponse() {
|
||||
testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TEXT, false);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试Ollama阻塞式工具调用响应转换")
|
||||
void testBlockingToolCallResponse() {
|
||||
testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("测试Ollama阻塞式结构化响应转换")
|
||||
void testBlockingStructuredResponse() {
|
||||
testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("参数化测试 - 所有协议转换器")
|
||||
class ParameterizedTests {
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = ProviderEnum.class, names = { "OPENAI", "DASHSCOPE", "OLLAMA" })
|
||||
@DisplayName("参数化测试协议转换器提供者名称")
|
||||
void testProtocolTransformerProviderEnumName(ProviderEnum provider) {
|
||||
ProtocolTransformer transformer = createProtocolTransformer(provider);
|
||||
assertEquals(provider.getProviderName(), transformer.getProviderName(),
|
||||
provider.getProviderName() + "协议转换器的提供者名称应为'" + provider.getProviderName() + "'");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试流式响应处理的通用方法
|
||||
*
|
||||
* @param provider 提供者
|
||||
* @param requestType 响应类型
|
||||
* @param hasToolCalls 是否包含工具调用
|
||||
*/
|
||||
private void testStreamingResponse(ProviderEnum provider, TestDataReader.RequestType requestType, boolean hasToolCalls) {
|
||||
// 重新构建pipeline
|
||||
context = new InteractContext();
|
||||
protocolTransformer = createProtocolTransformer(provider);
|
||||
chunkCallbackTransformer = new TestChunkCallbackTransformer();
|
||||
pipeline = ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, chunkCallbackTransformer);
|
||||
|
||||
// 读取测试数据
|
||||
List<String> streamChunks = TestDataReader.getStreamingChunks(provider, requestType);
|
||||
assertFalse(streamChunks.isEmpty(), provider.getProviderName() + "流式" + requestType.getRequestName() + "响应数据不应为空");
|
||||
|
||||
// 处理流式数据
|
||||
for (String streamChunk : streamChunks) {
|
||||
pipeline.processStreaming(streamChunk);
|
||||
// 注意:这里不强制要求最后一次返回true,因为不同模型的结束标志可能不同
|
||||
}
|
||||
|
||||
// 获取最终响应
|
||||
ChatResponse chatResponse = pipeline.buildFinalStreamingResponse();
|
||||
|
||||
// 验证响应
|
||||
assertNotNull(chatResponse, "聊天响应不应为null");
|
||||
assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
|
||||
assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
|
||||
|
||||
AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
|
||||
|
||||
// 验证工具调用
|
||||
if (hasToolCalls) {
|
||||
assertTrue(chatResponse.hasToolCalls(), requestType.getRequestName() + "响应应包含工具调用");
|
||||
assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
|
||||
assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
|
||||
|
||||
// 验证工具调用内容
|
||||
for (ToolCall toolCall : message.getToolCalls()) {
|
||||
if (!provider.equals(ProviderEnum.OLLAMA)) {
|
||||
assertNotNull(toolCall.getId(), "工具调用ID不应为null");
|
||||
}
|
||||
assertNotNull(toolCall.getName(), "工具调用名称不应为null");
|
||||
assertNotNull(toolCall.getType(), "工具调用类型不应为null");
|
||||
}
|
||||
} else {
|
||||
// 对于非工具调用响应,验证内容存在
|
||||
assertNotNull(message.getContent(), "消息内容不应为null");
|
||||
// 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
|
||||
}
|
||||
|
||||
// 打印测试结果
|
||||
System.out.println("=== " + provider.getProviderName() + " 流式" + requestType.getRequestName() + "响应测试结果 ===");
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println(
|
||||
"内容: " + (message.getContent().length() > 100 ? message.getContent().substring(0, 100) + "..."
|
||||
: message.getContent()));
|
||||
}
|
||||
if (hasToolCalls && message.getToolCalls() != null) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
}
|
||||
System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
|
||||
System.out.println("完成原因: " + chatResponse.getFinishReason());
|
||||
}
|
||||
|
||||
/**
|
||||
* 测试阻塞式响应处理的通用方法
|
||||
*
|
||||
* @param provider 提供者
|
||||
* @param responseType 响应类型
|
||||
* @param hasToolCalls 是否包含工具调用
|
||||
*/
|
||||
private void testBlockingResponse(ProviderEnum provider, TestDataReader.RequestType responseType, boolean hasToolCalls) {
|
||||
// 重新构建pipeline
|
||||
context = new InteractContext();
|
||||
protocolTransformer = createProtocolTransformer(provider);
|
||||
pipeline = ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer);
|
||||
|
||||
// 读取测试数据
|
||||
String blockingResponse = TestDataReader.getBlockingResponse(provider, responseType);
|
||||
assertNotNull(blockingResponse, provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为null");
|
||||
assertFalse(blockingResponse.trim().isEmpty(), provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为空");
|
||||
|
||||
// 处理阻塞式响应
|
||||
ChatResponse chatResponse = pipeline.processBlocking(blockingResponse);
|
||||
|
||||
// 验证响应
|
||||
assertNotNull(chatResponse, "聊天响应不应为null");
|
||||
assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
|
||||
assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
|
||||
|
||||
AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
|
||||
|
||||
// 验证工具调用
|
||||
if (hasToolCalls) {
|
||||
assertTrue(chatResponse.hasToolCalls(), responseType.getRequestName() + "响应应包含工具调用");
|
||||
assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
|
||||
assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
|
||||
|
||||
// 验证每个工具调用
|
||||
for (ToolCall toolCall : message.getToolCalls()) {
|
||||
if (!provider.equals(ProviderEnum.OLLAMA)) {
|
||||
assertNotNull(toolCall.getId(), "工具调用ID不应为null");
|
||||
}
|
||||
assertNotNull(toolCall.getName(), "工具调用名称不应为null");
|
||||
assertNotNull(toolCall.getType(), "工具调用类型不应为null");
|
||||
assertEquals("function", toolCall.getType(), "工具调用类型应为function");
|
||||
}
|
||||
} else {
|
||||
// 对于非工具调用响应,验证内容存在
|
||||
assertNotNull(message.getContent(), "消息内容不应为null");
|
||||
// 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
|
||||
}
|
||||
|
||||
// 验证token使用情况(如果存在)
|
||||
if (chatResponse.getTokenUsage() != null) {
|
||||
assertTrue(chatResponse.getTokenUsage().getTotalTokenCount() > 0, "总token数应大于0");
|
||||
}
|
||||
|
||||
// 打印测试结果
|
||||
System.out.println("=== " + provider.getProviderName() + " 阻塞式" + responseType.getRequestName() + "响应测试结果 ===");
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (hasToolCalls && message.getToolCalls() != null) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
|
||||
System.out.println("完成原因: " + chatResponse.getFinishReason());
|
||||
}
|
||||
//
|
||||
// private ProtocolTransformer protocolTransformer;
|
||||
// private InteractContext context;
|
||||
//
|
||||
// /**
|
||||
// * 创建协议转换器的工厂方法
|
||||
// *
|
||||
// * @param provider 提供者枚举
|
||||
// * @return 对应的协议转换器实例
|
||||
// */
|
||||
// private ProtocolTransformer createProtocolTransformer(ProviderEnum provider) {
|
||||
// switch (provider) {
|
||||
// case OPENAI:
|
||||
// return new OpenAIProtocolTransformer();
|
||||
// case DASHSCOPE:
|
||||
// return new DashScopeProtocolTransformer();
|
||||
// case OLLAMA:
|
||||
// return new OllamaProtocolTransformer();
|
||||
// default:
|
||||
// throw new IllegalArgumentException("不支持的提供者: " + provider);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// @Nested
|
||||
// @DisplayName("OpenAI协议转换器测试")
|
||||
// class OpenAITests {
|
||||
//
|
||||
// @BeforeEach
|
||||
// void setUp() {
|
||||
// context = new InteractContext();
|
||||
// protocolTransformer = createProtocolTransformer(ProviderEnum.OPENAI);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试OpenAI流式文本响应转换")
|
||||
// void testStreamingTextResponse() {
|
||||
// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TEXT, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试OpenAI流式工具调用响应转换")
|
||||
// void testStreamingToolCallResponse() {
|
||||
// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试OpenAI流式结构化响应转换")
|
||||
// void testStreamingStructuredResponse() {
|
||||
// testStreamingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试OpenAI阻塞式文本响应转换")
|
||||
// void testBlockingTextResponse() {
|
||||
// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TEXT, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试OpenAI阻塞式工具调用响应转换")
|
||||
// void testBlockingToolCallResponse() {
|
||||
// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试OpenAI阻塞式结构化响应转换")
|
||||
// void testBlockingStructuredResponse() {
|
||||
// testBlockingResponse(ProviderEnum.OPENAI, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// @Nested
|
||||
// @DisplayName("DashScope协议转换器测试")
|
||||
// class DashScopeTests {
|
||||
//
|
||||
// @BeforeEach
|
||||
// void setUp() {
|
||||
// context = new InteractContext();
|
||||
// protocolTransformer = createProtocolTransformer(ProviderEnum.DASHSCOPE);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试DashScope流式文本响应转换")
|
||||
// void testStreamingTextResponse() {
|
||||
// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试DashScope流式工具调用响应转换")
|
||||
// void testStreamingToolCallResponse() {
|
||||
// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试DashScope流式结构化响应转换")
|
||||
// void testStreamingStructuredResponse() {
|
||||
// testStreamingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试DashScope阻塞式文本响应转换")
|
||||
// void testBlockingTextResponse() {
|
||||
// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TEXT, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试DashScope阻塞式工具调用响应转换")
|
||||
// void testBlockingToolCallResponse() {
|
||||
// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试DashScope阻塞式结构化响应转换")
|
||||
// void testBlockingStructuredResponse() {
|
||||
// testBlockingResponse(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// @Nested
|
||||
// @DisplayName("Ollama协议转换器测试")
|
||||
// class OllamaTests {
|
||||
//
|
||||
// @BeforeEach
|
||||
// void setUp() {
|
||||
// context = new InteractContext();
|
||||
// protocolTransformer = createProtocolTransformer(ProviderEnum.OLLAMA);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试Ollama流式文本响应转换")
|
||||
// void testStreamingTextResponse() {
|
||||
// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试Ollama流式工具调用响应转换")
|
||||
// void testStreamingToolCallResponse() {
|
||||
// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL, true);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试Ollama流式结构化响应转换")
|
||||
// void testStreamingStructuredResponse() {
|
||||
// testStreamingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试Ollama阻塞式文本响应转换")
|
||||
// void testBlockingTextResponse() {
|
||||
// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TEXT, false);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试Ollama阻塞式工具调用响应转换")
|
||||
// void testBlockingToolCallResponse() {
|
||||
// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_TOOL_CALL, true);
|
||||
// }
|
||||
//
|
||||
// @Test
|
||||
// @DisplayName("测试Ollama阻塞式结构化响应转换")
|
||||
// void testBlockingStructuredResponse() {
|
||||
// testBlockingResponse(ProviderEnum.OLLAMA, TestDataReader.RequestType.BLOCKING_STRUCTURED, false);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// @Nested
|
||||
// @DisplayName("参数化测试 - 所有协议转换器")
|
||||
// class ParameterizedTests {
|
||||
//
|
||||
// @ParameterizedTest
|
||||
// @EnumSource(value = ProviderEnum.class, names = { "OPENAI", "DASHSCOPE", "OLLAMA" })
|
||||
// @DisplayName("参数化测试协议转换器提供者名称")
|
||||
// void testProtocolTransformerProviderEnumName(ProviderEnum provider) {
|
||||
// ProtocolTransformer transformer = createProtocolTransformer(provider);
|
||||
// assertEquals(provider.getProviderName(), transformer.getProviderName(),
|
||||
// provider.getProviderName() + "协议转换器的提供者名称应为'" + provider.getProviderName() + "'");
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 测试流式响应处理的通用方法
|
||||
// *
|
||||
// * @param provider 提供者
|
||||
// * @param requestType 响应类型
|
||||
// * @param hasToolCalls 是否包含工具调用
|
||||
// */
|
||||
// private void testStreamingResponse(ProviderEnum provider, TestDataReader.RequestType requestType, boolean hasToolCalls) {
|
||||
// // 重新构建pipeline
|
||||
// context = new InteractContext();
|
||||
// protocolTransformer = createProtocolTransformer(provider);
|
||||
// pipeline = ChunkProcessPipeline.createStreamingPipeline(context, protocolTransformer, chunkCallbackTransformer);
|
||||
//
|
||||
// // 读取测试数据
|
||||
// List<String> streamChunks = TestDataReader.getStreamingChunks(provider, requestType);
|
||||
// assertFalse(streamChunks.isEmpty(), provider.getProviderName() + "流式" + requestType.getRequestName() + "响应数据不应为空");
|
||||
//
|
||||
// // 处理流式数据
|
||||
// for (String streamChunk : streamChunks) {
|
||||
// pipeline.processStreaming(streamChunk);
|
||||
// // 注意:这里不强制要求最后一次返回true,因为不同模型的结束标志可能不同
|
||||
// }
|
||||
//
|
||||
// // 获取最终响应
|
||||
// ChatResponse chatResponse = pipeline.buildFinalStreamingResponse();
|
||||
//
|
||||
// // 验证响应
|
||||
// assertNotNull(chatResponse, "聊天响应不应为null");
|
||||
// assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
|
||||
// assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
|
||||
//
|
||||
// AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
|
||||
//
|
||||
// // 验证工具调用
|
||||
// if (hasToolCalls) {
|
||||
// assertTrue(chatResponse.hasToolCalls(), requestType.getRequestName() + "响应应包含工具调用");
|
||||
// assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
|
||||
// assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
|
||||
//
|
||||
// // 验证工具调用内容
|
||||
// for (ToolCall toolCall : message.getToolCalls()) {
|
||||
// if (!provider.equals(ProviderEnum.OLLAMA)) {
|
||||
// assertNotNull(toolCall.getId(), "工具调用ID不应为null");
|
||||
// }
|
||||
// assertNotNull(toolCall.getName(), "工具调用名称不应为null");
|
||||
// assertNotNull(toolCall.getType(), "工具调用类型不应为null");
|
||||
// }
|
||||
// } else {
|
||||
// // 对于非工具调用响应,验证内容存在
|
||||
// assertNotNull(message.getContent(), "消息内容不应为null");
|
||||
// // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
|
||||
// }
|
||||
//
|
||||
// // 打印测试结果
|
||||
// System.out.println("=== " + provider.getProviderName() + " 流式" + requestType.getRequestName() + "响应测试结果 ===");
|
||||
// if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
// System.out.println(
|
||||
// "内容: " + (message.getContent().length() > 100 ? message.getContent().substring(0, 100) + "..."
|
||||
// : message.getContent()));
|
||||
// }
|
||||
// if (hasToolCalls && message.getToolCalls() != null) {
|
||||
// System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
// }
|
||||
// System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
|
||||
// System.out.println("完成原因: " + chatResponse.getFinishReason());
|
||||
// }
|
||||
//
|
||||
// /**
|
||||
// * 测试阻塞式响应处理的通用方法
|
||||
// *
|
||||
// * @param provider 提供者
|
||||
// * @param responseType 响应类型
|
||||
// * @param hasToolCalls 是否包含工具调用
|
||||
// */
|
||||
// private void testBlockingResponse(ProviderEnum provider, TestDataReader.RequestType responseType, boolean hasToolCalls) {
|
||||
// // 重新构建pipeline
|
||||
// context = new InteractContext();
|
||||
// protocolTransformer = createProtocolTransformer(provider);
|
||||
// pipeline = ChunkProcessPipeline.createBlockingPipeline(context, protocolTransformer);
|
||||
//
|
||||
// // 读取测试数据
|
||||
// String blockingResponse = TestDataReader.getBlockingResponse(provider, responseType);
|
||||
// assertNotNull(blockingResponse, provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为null");
|
||||
// assertFalse(blockingResponse.trim().isEmpty(), provider.getProviderName() + "阻塞式" + responseType.getRequestName() + "响应数据不应为空");
|
||||
//
|
||||
// // 处理阻塞式响应
|
||||
// ChatResponse chatResponse = pipeline.processBlocking(blockingResponse);
|
||||
//
|
||||
// // 验证响应
|
||||
// assertNotNull(chatResponse, "聊天响应不应为null");
|
||||
// assertNotNull(chatResponse.getOutput(), "响应输出不应为null");
|
||||
// assertTrue(chatResponse.getOutput() instanceof AssistantMessage, "输出应为AssistantMessage类型");
|
||||
//
|
||||
// AssistantMessage message = (AssistantMessage) chatResponse.getOutput();
|
||||
//
|
||||
// // 验证工具调用
|
||||
// if (hasToolCalls) {
|
||||
// assertTrue(chatResponse.hasToolCalls(), responseType.getRequestName() + "响应应包含工具调用");
|
||||
// assertNotNull(message.getToolCalls(), "工具调用列表不应为null");
|
||||
// assertFalse(message.getToolCalls().isEmpty(), "工具调用列表不应为空");
|
||||
//
|
||||
// // 验证每个工具调用
|
||||
// for (ToolCall toolCall : message.getToolCalls()) {
|
||||
// if (!provider.equals(ProviderEnum.OLLAMA)) {
|
||||
// assertNotNull(toolCall.getId(), "工具调用ID不应为null");
|
||||
// }
|
||||
// assertNotNull(toolCall.getName(), "工具调用名称不应为null");
|
||||
// assertNotNull(toolCall.getType(), "工具调用类型不应为null");
|
||||
// assertEquals("function", toolCall.getType(), "工具调用类型应为function");
|
||||
// }
|
||||
// } else {
|
||||
// // 对于非工具调用响应,验证内容存在
|
||||
// assertNotNull(message.getContent(), "消息内容不应为null");
|
||||
// // 注意:某些情况下内容可能为空字符串,所以这里不强制要求非空
|
||||
// }
|
||||
//
|
||||
// // 验证token使用情况(如果存在)
|
||||
// if (chatResponse.getTokenUsage() != null) {
|
||||
// assertTrue(chatResponse.getTokenUsage().getTotalTokenCount() > 0, "总token数应大于0");
|
||||
// }
|
||||
//
|
||||
// // 打印测试结果
|
||||
// System.out.println("=== " + provider.getProviderName() + " 阻塞式" + responseType.getRequestName() + "响应测试结果 ===");
|
||||
// if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
// System.out.println("内容长度: " + message.getContent().length());
|
||||
// if (message.getContent().length() > 200) {
|
||||
// System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
// } else {
|
||||
// System.out.println("内容: " + message.getContent());
|
||||
// }
|
||||
// }
|
||||
// if (hasToolCalls && message.getToolCalls() != null) {
|
||||
// System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
// for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
// ToolCall toolCall = message.getToolCalls().get(i);
|
||||
// System.out.println("工具调用 " + (i + 1) + ":");
|
||||
// System.out.println(" ID: " + toolCall.getId());
|
||||
// System.out.println(" 名称: " + toolCall.getName());
|
||||
// System.out.println(" 类型: " + toolCall.getType());
|
||||
// System.out.println(" 参数: " + toolCall.getArguments());
|
||||
// }
|
||||
// }
|
||||
// System.out.println("Token使用情况: " + chatResponse.getTokenUsage());
|
||||
// System.out.println("完成原因: " + chatResponse.getFinishReason());
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package com.yomahub.liteflow.test.ai.mock.mockbean;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.LlmInteractClient;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
@@ -51,14 +53,14 @@ public class MockInteractClient extends LlmInteractClient {
|
||||
* 重写 stream 方法
|
||||
*/
|
||||
@Override
|
||||
public void stream(ChatConfig config, ChatRequest request) {
|
||||
public Flowable<ChunkEvent> stream(ChatConfig config, ChatRequest request) {
|
||||
// 创建一个 spied request,它在内部被配置为使用 MockTransport
|
||||
ChatRequest spiedRequest = createSpiedRequest(request);
|
||||
|
||||
// 调用父类的 stream 方法。
|
||||
// 当父类的 InteractManager 构造函数被调用时,
|
||||
// 它将使用我们的 spiedRequest,并最终获取到 MockTransport。
|
||||
super.stream(config, spiedRequest);
|
||||
return super.stream(config, spiedRequest);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package com.yomahub.liteflow.test.ai.mock.mockbean;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.pipeline.ChunkProcessPipeline;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.Transport;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatConfig;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -19,43 +17,29 @@ import java.util.List;
|
||||
|
||||
public class MockTransport implements Transport {
|
||||
|
||||
private final String blockingResponse;
|
||||
private final List<String> streamingChunks;
|
||||
private final MockConfig config;
|
||||
|
||||
private ChunkProcessPipeline pipeline;
|
||||
private TransportListener listener = TransportListener.getDefault();
|
||||
private boolean isStop = false;
|
||||
|
||||
public MockTransport(MockConfig config) {
|
||||
TestDataReader.RequestType curRequestType = config.getRequestType();
|
||||
this.blockingResponse = TestDataReader.getBlockingResponse(config.getProvider(), curRequestType);
|
||||
this.streamingChunks = TestDataReader.getStreamingChunks(config.getProvider(), curRequestType);
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline, TransportListener listener) {
|
||||
this.listener = listener;
|
||||
this.pipeline = pipeline;
|
||||
listener.onStart(pipeline.getContext());
|
||||
|
||||
for (String streamingChunk : streamingChunks) {
|
||||
pipeline.processStreaming(streamingChunk);
|
||||
}
|
||||
|
||||
close();
|
||||
public Flowable<String> startStreaming(ChatConfig config, ChatRequest request) {
|
||||
List<String> streamingChunks = TestDataReader.getStreamingChunks(this.config.getProvider(), this.config.getRequestType());
|
||||
return Flowable.fromIterable(streamingChunks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse startBlocking(ChatConfig config, ChatRequest request, ChunkProcessPipeline pipeline) {
|
||||
this.pipeline = pipeline;
|
||||
return pipeline.processBlocking(blockingResponse);
|
||||
public String startBlocking(ChatConfig config, ChatRequest request) {
|
||||
return TestDataReader.getBlockingResponse(this.config.getProvider(), this.config.getRequestType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (!this.isStop) {
|
||||
this.isStop = true;
|
||||
this.listener.onClose(this.pipeline.getContext());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
package com.yomahub.liteflow.test.ai.model.chat.dashscope;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatModel;
|
||||
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatRequest;
|
||||
import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* 阿里百炼 chat 测试
|
||||
@@ -58,7 +61,7 @@ public class DashScopeChatTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStreaming() throws ExecutionException, InterruptedException {
|
||||
public void testStreaming() {
|
||||
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TEXT);
|
||||
|
||||
List<Message> messages = Arrays.asList(
|
||||
@@ -66,21 +69,17 @@ public class DashScopeChatTest extends MockAITest {
|
||||
new UserMessage("请给我讲一个关于宇宙探索的短故事")
|
||||
);
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModel.stream(
|
||||
Flowable<ChunkEvent> stream = chatModel.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
|
||||
.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -99,44 +98,6 @@ public class DashScopeChatTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = DashScopeChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = DashScopeChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
package com.yomahub.liteflow.test.ai.model.chat.ollama;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatModel;
|
||||
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatRequest;
|
||||
import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* Ollama chat 测试
|
||||
@@ -57,7 +60,7 @@ public class OllamaChatTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStreaming() throws ExecutionException, InterruptedException {
|
||||
public void testStreaming() {
|
||||
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TEXT);
|
||||
|
||||
List<Message> messages = Arrays.asList(
|
||||
@@ -65,21 +68,17 @@ public class OllamaChatTest extends MockAITest {
|
||||
new UserMessage("why is the sky blue?")
|
||||
);
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModel.stream(
|
||||
Flowable<ChunkEvent> stream = chatModel.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.DN_JSON)
|
||||
.messages(messages)
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
|
||||
.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -96,44 +95,6 @@ public class OllamaChatTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = OllamaChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = OllamaChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
package com.yomahub.liteflow.test.ai.model.chat.openai;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatModel;
|
||||
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatRequest;
|
||||
import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
@@ -61,20 +65,17 @@ public class OpenAIChatTest extends MockAITest {
|
||||
new SystemMessage("You are a helpful assistant."),
|
||||
new UserMessage("请给我讲一个关于未来城市的短故事"));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModel.stream(
|
||||
Flowable<ChunkEvent> stream = chatModel.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
|
||||
.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -92,44 +93,6 @@ public class OpenAIChatTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = OpenAIChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = OpenAIChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,29 @@
|
||||
package com.yomahub.liteflow.test.ai.model.structure.dashscope;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatModel;
|
||||
import com.yomahub.liteflow.ai.model.dashscope.model.chat.DashScopeChatRequest;
|
||||
import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import com.yomahub.liteflow.test.ai.model.structure.output.MathReasoning;
|
||||
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* 阿里百炼 结构化输出 测试
|
||||
@@ -70,7 +73,7 @@ public class DashScopeStructureTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStructureStreaming() throws ExecutionException, InterruptedException {
|
||||
public void testStructureStreaming() {
|
||||
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_STRUCTURED);
|
||||
|
||||
List<Message> messages = Arrays.asList(
|
||||
@@ -78,9 +81,7 @@ public class DashScopeStructureTest extends MockAITest {
|
||||
new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1")
|
||||
);
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModel.stream(
|
||||
Flowable<ChunkEvent> stream = chatModel.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
@@ -89,14 +90,11 @@ public class DashScopeStructureTest extends MockAITest {
|
||||
.responseType(ResponseType.JSON)
|
||||
.targetType(MathReasoning.class)
|
||||
// 结构化输出相关配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build()
|
||||
);
|
||||
.build());
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
|
||||
.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
// 将响应转换为结构化结果对象
|
||||
MathReasoning result = response.as(MathReasoning.class);
|
||||
@@ -121,44 +119,6 @@ public class DashScopeStructureTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = DashScopeChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = DashScopeChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,29 @@
|
||||
package com.yomahub.liteflow.test.ai.model.structure.ollama;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatModel;
|
||||
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatRequest;
|
||||
import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import com.yomahub.liteflow.test.ai.model.structure.output.MathReasoning;
|
||||
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* Ollama 结构化输出 测试
|
||||
@@ -68,16 +71,14 @@ public class OllamaStructureTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStructureStreaming() throws ExecutionException, InterruptedException {
|
||||
public void testStructureStreaming() {
|
||||
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_STRUCTURED);
|
||||
|
||||
List<Message> messages = Arrays.asList(
|
||||
new SystemMessage("你是一位数学辅导老师"),
|
||||
new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1"));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModel.stream(
|
||||
Flowable<ChunkEvent> stream = chatModel.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.DN_JSON)
|
||||
@@ -86,15 +87,12 @@ public class OllamaStructureTest extends MockAITest {
|
||||
.responseType(ResponseType.JSON)
|
||||
.targetType(MathReasoning.class)
|
||||
// 结构化输出相关配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
|
||||
.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
// 将响应转换为结构化结果对象
|
||||
MathReasoning result = response.as(MathReasoning.class);
|
||||
|
||||
Assertions.assertNotNull(result);
|
||||
@@ -116,44 +114,6 @@ public class OllamaStructureTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = OllamaChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = OllamaChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
package com.yomahub.liteflow.test.ai.model.structure.openai;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatModel;
|
||||
import com.yomahub.liteflow.ai.model.openai.model.chat.OpenAIChatRequest;
|
||||
import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import com.yomahub.liteflow.test.ai.model.structure.output.MathReasoning;
|
||||
import com.yomahub.liteflow.test.ai.model.util.StreamUtil;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
@@ -74,9 +78,7 @@ public class OpenAIStructureTest extends MockAITest {
|
||||
new SystemMessage("你是一位数学辅导老师"),
|
||||
new UserMessage("使用中文解题: 8x + 9 = 32 and x + y = 1"));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModel.stream(
|
||||
Flowable<ChunkEvent> stream = chatModel.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
@@ -85,13 +87,11 @@ public class OpenAIStructureTest extends MockAITest {
|
||||
.responseType(ResponseType.JSON)
|
||||
.targetType(MathReasoning.class)
|
||||
// 结构化输出相关配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.doOnNext(StreamUtil.getChunkEventConsumer())
|
||||
.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
// 将响应转换为结构化结果对象
|
||||
MathReasoning result = response.as(MathReasoning.class);
|
||||
@@ -116,44 +116,6 @@ public class OpenAIStructureTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = OpenAIChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = OpenAIChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package com.yomahub.liteflow.test.ai.model.tool.dashscope;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
@@ -16,6 +19,7 @@ import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockConfigHolder;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -23,9 +27,6 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* 阿里百炼 工具调用 测试
|
||||
@@ -75,7 +76,7 @@ public class DashScopeToolTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolCallStreamingWithAutoToolCall() throws ExecutionException, InterruptedException {
|
||||
public void testToolCallStreamingWithAutoToolCall() {
|
||||
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
|
||||
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -87,24 +88,18 @@ public class DashScopeToolTest extends MockAITest {
|
||||
Collections.singletonList(toolRegistry.getTool("assemble_tool"))
|
||||
);
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModelWithAutoToolCall.stream(
|
||||
Flowable<ChunkEvent> stream = chatModelWithAutoToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -158,13 +153,7 @@ public class DashScopeToolTest extends MockAITest {
|
||||
Assertions.assertEquals("weather_tool", toolCall.getName());
|
||||
|
||||
// 执行 ToolCall
|
||||
ToolCallBack toolCallBack = weatherTool.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
|
||||
|
||||
// 和 AI 消息一起添加回上下文
|
||||
messages.add(response.getOutput());
|
||||
@@ -190,7 +179,7 @@ public class DashScopeToolTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolCallStreamingWithManualToolCall() throws ExecutionException, InterruptedException {
|
||||
public void testToolCallStreamingWithManualToolCall() {
|
||||
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL);
|
||||
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -202,24 +191,18 @@ public class DashScopeToolTest extends MockAITest {
|
||||
Collections.singletonList(toolRegistry.getTool("assemble_tool"))
|
||||
);
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModelWithManualToolCall.stream(
|
||||
Flowable<ChunkEvent> stream = chatModelWithManualToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.TOOL_CALL, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -231,13 +214,7 @@ public class DashScopeToolTest extends MockAITest {
|
||||
Assertions.assertEquals("assemble_tool", toolCall.getName());
|
||||
|
||||
// 执行 ToolCall
|
||||
ToolCallBack toolCallBack = assembleTool.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
|
||||
|
||||
// 和 AI 消息一起添加回上下文
|
||||
messages.add(response.getOutput());
|
||||
@@ -247,24 +224,18 @@ public class DashScopeToolTest extends MockAITest {
|
||||
MockConfigHolder.clear();
|
||||
setupChatMock(ProviderEnum.DASHSCOPE, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
|
||||
|
||||
final CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
|
||||
|
||||
chatModelWithManualToolCall.stream(
|
||||
stream = chatModelWithManualToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future2.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build()
|
||||
);
|
||||
|
||||
response = future2.get();
|
||||
response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -311,44 +282,6 @@ public class DashScopeToolTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = DashScopeChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = DashScopeChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package com.yomahub.liteflow.test.ai.model.tool.ollama;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
@@ -16,6 +19,7 @@ import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockConfigHolder;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -23,9 +27,6 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* Ollama 工具调用 测试
|
||||
@@ -72,7 +73,7 @@ public class OllamaToolTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolCallStreamingWithAutoToolCall() throws ExecutionException, InterruptedException {
|
||||
public void testToolCallStreamingWithAutoToolCall() {
|
||||
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL,
|
||||
TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
|
||||
|
||||
@@ -83,23 +84,18 @@ public class OllamaToolTest extends MockAITest {
|
||||
ToolRegistry assembleTool = new StaticToolRegistry(
|
||||
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModelWithAutoToolCall.stream(
|
||||
Flowable<ChunkEvent> stream = chatModelWithAutoToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.DN_JSON)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -137,13 +133,7 @@ public class OllamaToolTest extends MockAITest {
|
||||
Assertions.assertEquals("weather_tool", toolCall.getName());
|
||||
|
||||
// 执行 ToolCall
|
||||
ToolCallBack toolCallBack = weatherTool.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
|
||||
|
||||
// 和 AI 消息一起添加回上下文
|
||||
messages.add(response.getOutput());
|
||||
@@ -168,7 +158,7 @@ public class OllamaToolTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolCallStreamingWithManualToolCall() throws ExecutionException, InterruptedException {
|
||||
public void testToolCallStreamingWithManualToolCall() {
|
||||
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL);
|
||||
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -178,23 +168,18 @@ public class OllamaToolTest extends MockAITest {
|
||||
ToolRegistry assembleTool = new StaticToolRegistry(
|
||||
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModelWithManualToolCall.stream(
|
||||
Flowable<ChunkEvent> stream = chatModelWithManualToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.DN_JSON)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.TOOL_CALL, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -206,13 +191,8 @@ public class OllamaToolTest extends MockAITest {
|
||||
Assertions.assertEquals("assemble_tool", toolCall.getName());
|
||||
|
||||
// 执行 ToolCall
|
||||
ToolCallBack toolCallBack = assembleTool.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
|
||||
|
||||
|
||||
// 和 AI 消息一起添加回上下文
|
||||
messages.add(response.getOutput());
|
||||
@@ -222,23 +202,18 @@ public class OllamaToolTest extends MockAITest {
|
||||
MockConfigHolder.clear();
|
||||
setupChatMock(ProviderEnum.OLLAMA, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
|
||||
|
||||
final CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
|
||||
|
||||
chatModelWithManualToolCall.stream(
|
||||
stream = chatModelWithManualToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.DN_JSON)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future2.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
response = future2.get();
|
||||
response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -269,44 +244,6 @@ public class OllamaToolTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = OllamaChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = OllamaChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package com.yomahub.liteflow.test.ai.model.tool.openai;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.enums.ProviderEnum;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.*;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.MessageType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.ToolMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
@@ -16,6 +19,7 @@ import com.yomahub.liteflow.test.ai.mock.MockAITest;
|
||||
import com.yomahub.liteflow.test.ai.mock.TestDataReader;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockConfigHolder;
|
||||
import com.yomahub.liteflow.test.ai.mock.mockbean.MockInteractClient;
|
||||
import io.reactivex.rxjava3.core.Flowable;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -23,9 +27,6 @@ import org.junit.jupiter.api.Test;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* OpenAI 工具调用 测试
|
||||
@@ -72,7 +73,7 @@ public class OpenAIToolTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolCallStreamingWithAutoToolCall() throws ExecutionException, InterruptedException {
|
||||
public void testToolCallStreamingWithAutoToolCall() {
|
||||
setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL,
|
||||
TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
|
||||
|
||||
@@ -83,23 +84,18 @@ public class OpenAIToolTest extends MockAITest {
|
||||
ToolRegistry assembleTool = new StaticToolRegistry(
|
||||
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModelWithAutoToolCall.stream(
|
||||
Flowable<ChunkEvent> stream = chatModelWithAutoToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -137,13 +133,8 @@ public class OpenAIToolTest extends MockAITest {
|
||||
Assertions.assertEquals("weather_tool", toolCall.getName());
|
||||
|
||||
// 执行 ToolCall
|
||||
ToolCallBack toolCallBack = weatherTool.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
|
||||
|
||||
|
||||
// 和 AI 消息一起添加回上下文
|
||||
messages.add(response.getOutput());
|
||||
@@ -168,7 +159,7 @@ public class OpenAIToolTest extends MockAITest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolCallStreamingWithManualToolCall() throws ExecutionException, InterruptedException {
|
||||
public void testToolCallStreamingWithManualToolCall() {
|
||||
setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL);
|
||||
|
||||
List<Message> messages = new ArrayList<>();
|
||||
@@ -178,23 +169,18 @@ public class OpenAIToolTest extends MockAITest {
|
||||
ToolRegistry assembleTool = new StaticToolRegistry(
|
||||
Collections.singletonList(toolRegistry.getTool("assemble_tool")));
|
||||
|
||||
final CompletableFuture<ChatResponse> future = new CompletableFuture<>();
|
||||
|
||||
chatModelWithManualToolCall.stream(
|
||||
Flowable<ChunkEvent> stream = chatModelWithManualToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
ChatResponse response = future.get();
|
||||
ChatResponse response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.TOOL_CALL, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -206,13 +192,7 @@ public class OpenAIToolTest extends MockAITest {
|
||||
Assertions.assertEquals("assemble_tool", toolCall.getName());
|
||||
|
||||
// 执行 ToolCall
|
||||
ToolCallBack toolCallBack = assembleTool.getAllTools()
|
||||
.stream()
|
||||
.filter(tool -> Objects.equals(tool.getName(), toolCall.getName()))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new RuntimeException("工具未注册: " + toolCall.getName()));
|
||||
String toolResult = toolCallBack.call(toolCall.getArguments().toString());
|
||||
ToolMessage toolMessage = new ToolMessage(toolResult, toolCall.getId(), toolCall.getName());
|
||||
ToolMessage toolMessage = toolRegistry.executeToolCall(toolCall);
|
||||
|
||||
// 和 AI 消息一起添加回上下文
|
||||
messages.add(response.getOutput());
|
||||
@@ -222,23 +202,18 @@ public class OpenAIToolTest extends MockAITest {
|
||||
MockConfigHolder.clear();
|
||||
setupChatMock(ProviderEnum.OPENAI, TestDataReader.RequestType.STREAMING_TOOL_CALL_2);
|
||||
|
||||
final CompletableFuture<ChatResponse> future2 = new CompletableFuture<>();
|
||||
|
||||
chatModelWithManualToolCall.stream(
|
||||
stream = chatModelWithManualToolCall.stream(
|
||||
chatRequestBuilder
|
||||
.streaming(true)
|
||||
.transportType(TransportType.SSE)
|
||||
.messages(messages)
|
||||
// 工具调用配置
|
||||
.toolRegistry(assembleTool)
|
||||
// 工具调用配置
|
||||
.onFinal((chatResponse, context) -> {
|
||||
future2.complete(chatResponse);
|
||||
return chatResponse;
|
||||
})
|
||||
.build());
|
||||
.build()
|
||||
);
|
||||
|
||||
response = future2.get();
|
||||
response = stream.blockingLast()
|
||||
.getFinalResponse();
|
||||
|
||||
Assertions.assertEquals(FinishReason.STOP, response.getFinishReason());
|
||||
Assertions.assertEquals(MessageType.ASSISTANT, response.getOutput().getMessageType());
|
||||
@@ -271,44 +246,6 @@ public class OpenAIToolTest extends MockAITest {
|
||||
.interactClient(new MockInteractClient())
|
||||
.build();
|
||||
|
||||
chatRequestBuilder = OpenAIChatRequest.builder()
|
||||
.onStart(context -> System.out.println("chat start"))
|
||||
.onClose(context -> System.out.println("chat close"))
|
||||
.onError((context, t) -> {
|
||||
throw new RuntimeException(t);
|
||||
})
|
||||
.onText((content, context) -> {
|
||||
System.out.println("Received text: " + content);
|
||||
return content;
|
||||
})
|
||||
.onThinking((content, context) -> {
|
||||
System.out.println("Received thinking: " + content);
|
||||
return content;
|
||||
})
|
||||
.onCompletion((response, context) -> {
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
return response;
|
||||
});
|
||||
chatRequestBuilder = OpenAIChatRequest.builder();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package com.yomahub.liteflow.test.ai.model.util;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.ChunkEvent;
|
||||
import com.yomahub.liteflow.ai.engine.interact.chunk.StreamingProtocolChunk;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
|
||||
import io.reactivex.rxjava3.functions.Consumer;
|
||||
|
||||
public class StreamUtil {
|
||||
|
||||
public static Consumer<ChunkEvent> getChunkEventConsumer() {
|
||||
return chunkEvent -> {
|
||||
if (chunkEvent.isStart()) {
|
||||
System.out.println("chat start");
|
||||
}
|
||||
if (chunkEvent.isChunk()) {
|
||||
StreamingProtocolChunk chunk = chunkEvent.getTransformedChunk();
|
||||
switch (chunk.getType()) {
|
||||
case TEXT:
|
||||
System.out.println(chunk.getData());
|
||||
break;
|
||||
case THINKING:
|
||||
System.out.println("[Thinking] " + chunk.getData());
|
||||
}
|
||||
}
|
||||
if (chunkEvent.isComplete()) {
|
||||
ChatResponse response = chunkEvent.getFinalResponse();
|
||||
AssistantMessage message = response.getOutput();
|
||||
if (message.getContent() != null && !message.getContent().trim().isEmpty()) {
|
||||
System.out.println("内容长度: " + message.getContent().length());
|
||||
if (message.getContent().length() > 200) {
|
||||
System.out.println("内容预览: " + message.getContent().substring(0, 200) + "...");
|
||||
} else {
|
||||
System.out.println("内容: " + message.getContent());
|
||||
}
|
||||
}
|
||||
if (response.hasToolCalls()) {
|
||||
System.out.println("工具调用数量: " + message.getToolCalls().size());
|
||||
for (int i = 0; i < message.getToolCalls().size(); i++) {
|
||||
ToolCall toolCall = message.getToolCalls().get(i);
|
||||
System.out.println("工具调用 " + (i + 1) + ":");
|
||||
System.out.println(" ID: " + toolCall.getId());
|
||||
System.out.println(" 名称: " + toolCall.getName());
|
||||
System.out.println(" 类型: " + toolCall.getType());
|
||||
System.out.println(" 参数: " + toolCall.getArguments());
|
||||
}
|
||||
}
|
||||
System.out.println("Token使用情况: " + response.getTokenUsage());
|
||||
System.out.println("完成原因: " + response.getFinishReason());
|
||||
System.out.println("chat close");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -6,4 +6,4 @@ liteflow:
|
||||
openai:
|
||||
api-key: mock-api-key # 测试代码不需要配置,将使用 mock 数据,不会调用 AI 服务接口
|
||||
dashscope:
|
||||
api-key: mock-api-key # 测试代码不需要配置,将使用 mock 数据,不会调用 AI 服务接口
|
||||
api-key: sk-70c4b25e826e48868b470ff114419441 # 测试代码不需要配置,将使用 mock 数据,不会调用 AI 服务接口
|
||||
Reference in New Issue
Block a user