mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-12 16:21:05 +08:00
Feat(engine): 引入 TypeReference,ChatRequest与 ChatResponse 集成结构化输出能力
This commit is contained in:
@@ -4,6 +4,8 @@ import cn.hutool.core.util.ReflectUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.yomahub.liteflow.ai.annotation.AIOutput;
|
||||
import com.yomahub.liteflow.ai.annotation.OutputField;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.Response;
|
||||
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
|
||||
import com.yomahub.liteflow.ai.util.SetUtil;
|
||||
@@ -59,7 +61,13 @@ public class ContextAccessor {
|
||||
if (!(value instanceof Response)) {
|
||||
throw new LiteFlowAIException("AI node output value must be of type Response.");
|
||||
}
|
||||
value = ((Response<?>) value).getContent();
|
||||
// 如果 request 是 ChatRequest 且 value 是 ChatResponse,则尝试进行结构化转换
|
||||
if (context.getModelRequest() instanceof ChatRequest && value instanceof ChatResponse) {
|
||||
ChatRequest chatRequest = context.getModelRequest().toChatRequest();
|
||||
value = ((ChatResponse) value).as(chatRequest.getOutputParser());
|
||||
} else {
|
||||
value = ((Response<?>) value).getContent();
|
||||
}
|
||||
|
||||
NodeComponent nodeComponent = context.getNodeComponent();
|
||||
AIOutput outputAnno = context.getAiOutputAnno();
|
||||
|
||||
@@ -8,8 +8,12 @@ import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
|
||||
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.ModelRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.OutputParser;
|
||||
import com.yomahub.liteflow.ai.engine.util.request.RequestBody;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -60,6 +64,16 @@ public class ChatRequest implements ModelRequest {
|
||||
*/
|
||||
protected final ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
|
||||
/**
|
||||
* 响应类型,默认为文本类型
|
||||
*/
|
||||
protected final ResponseType responseType;
|
||||
|
||||
/**
|
||||
* 输出解析器,用于解析模型的输出结果,由 targetType 生成
|
||||
*/
|
||||
protected final OutputParser<?> outputParser;
|
||||
|
||||
// ==== RequestBody 相关参数 =====
|
||||
protected static final String MESSAGES_KEY = "messages";
|
||||
protected static final String STREAM_KEY = "stream";
|
||||
@@ -73,6 +87,10 @@ public class ChatRequest implements ModelRequest {
|
||||
this.transportListener = TransportListener.getDefault();
|
||||
this.resultHandler = ResultHandler.getDefault();
|
||||
this.chunkCallbackTransformer = ChunkCallbackTransformer.getDefault();
|
||||
this.responseType = ResponseType.TEXT; // 默认响应类型为文本
|
||||
TypeReference<String> targetType = new TypeReference<String>() {
|
||||
}; // 默认目标类型为 String
|
||||
this.outputParser = OutputParser.fromTypeReference(targetType);
|
||||
}
|
||||
|
||||
public ChatRequest(
|
||||
@@ -82,7 +100,9 @@ public class ChatRequest implements ModelRequest {
|
||||
TransportType transportType,
|
||||
TransportListener transportListener,
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType
|
||||
) {
|
||||
this.messages = messages;
|
||||
this.options = options;
|
||||
@@ -91,7 +111,10 @@ public class ChatRequest implements ModelRequest {
|
||||
this.transportListener = transportListener;
|
||||
this.resultHandler = resultHandler;
|
||||
this.chunkCallbackTransformer = chunkCallbackTransformer;
|
||||
this.responseType = responseType;
|
||||
this.outputParser = OutputParser.fromTypeReference(targetType);
|
||||
checkTransportConsistency();
|
||||
checkResponseTypeConsistency();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -107,7 +130,10 @@ public class ChatRequest implements ModelRequest {
|
||||
this.transportListener = builder.transportListener;
|
||||
this.resultHandler = builder.resultHandler;
|
||||
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
|
||||
this.responseType = builder.responseType;
|
||||
this.outputParser = OutputParser.fromTypeReference(builder.targetType);
|
||||
checkTransportConsistency();
|
||||
checkResponseTypeConsistency();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -117,9 +143,20 @@ public class ChatRequest implements ModelRequest {
|
||||
*/
|
||||
protected void checkTransportConsistency() {
|
||||
if (this.streaming && this.transportType == TransportType.HTTP) {
|
||||
throw new IllegalArgumentException("流式输出模式启用,但不支持HTTP传输。请使用SSE或WebSocket传输。");
|
||||
throw new IllegalArgumentException("Streaming mode is enabled, but HTTP transport is not supported. Please use SSE or WebSocket transport.");
|
||||
} else if (!this.streaming && this.transportType != TransportType.HTTP) {
|
||||
throw new IllegalArgumentException("阻塞式输出模式启用,但传输类型不支持HTTP。请使用HTTP传输。");
|
||||
throw new IllegalArgumentException("Blocking mode is enabled, but the transport type does not support HTTP. Please use HTTP transport.");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查响应类型与目标类型的一致性。
|
||||
* 如果响应类型为文本(TEXT),但目标类型不是 String,则抛出异常。
|
||||
*/
|
||||
protected void checkResponseTypeConsistency() {
|
||||
if (this.responseType == ResponseType.TEXT &&
|
||||
!Objects.equals("java.lang.String", getTargetType().getTypeName())) {
|
||||
throw new IllegalArgumentException("Response type is TEXT, but target type is not String. Please check the targetType setting.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,6 +197,18 @@ public class ChatRequest implements ModelRequest {
|
||||
return chunkCallbackTransformer;
|
||||
}
|
||||
|
||||
public ResponseType getResponseType() {
|
||||
return responseType;
|
||||
}
|
||||
|
||||
public Type getTargetType() {
|
||||
return outputParser.getTargetType();
|
||||
}
|
||||
|
||||
public OutputParser<?> getOutputParser() {
|
||||
return outputParser;
|
||||
}
|
||||
|
||||
public void setResultHandler(ResultHandler resultHandler) {
|
||||
this.resultHandler = resultHandler;
|
||||
}
|
||||
@@ -177,6 +226,9 @@ public class ChatRequest implements ModelRequest {
|
||||
protected TransportListener transportListener;
|
||||
protected ResultHandler resultHandler;
|
||||
protected ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
protected ResponseType responseType = ResponseType.TEXT;
|
||||
protected TypeReference<?> targetType = new TypeReference<String>() {
|
||||
};
|
||||
|
||||
public abstract B self();
|
||||
|
||||
@@ -347,6 +399,28 @@ public class ChatRequest implements ModelRequest {
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置响应类型
|
||||
*
|
||||
* @param responseType 响应类型
|
||||
* @see ResponseType
|
||||
*/
|
||||
public B responseType(ResponseType responseType) {
|
||||
this.responseType = responseType;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置目标类型引用,用于指定响应体的具体类型
|
||||
*
|
||||
* @param targetType 目标类型引用
|
||||
* @see TypeReference
|
||||
*/
|
||||
public B targetType(TypeReference<?> targetType) {
|
||||
this.targetType = targetType;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 内部聚合类
|
||||
*/
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package com.yomahub.liteflow.ai.engine.model.chat.entity;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.Response;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.TokenUsage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.OutputParser;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
@@ -32,6 +36,46 @@ public class ChatResponse extends Response<AssistantMessage> {
|
||||
super(builder);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将响应内容转换为指定类型
|
||||
*
|
||||
* @param parser 目标类型解析器
|
||||
* @param <T> 目标类型
|
||||
* @return 转换后的对象
|
||||
*/
|
||||
public <T> T as(OutputParser<T> parser) {
|
||||
String rawTextContent = this.getContent().getContent();
|
||||
if (StrUtil.isBlank(rawTextContent)) {
|
||||
throw new IllegalStateException("Cannot convert empty content to target type: " + parser.getTargetType());
|
||||
}
|
||||
|
||||
return parser.convert(rawTextContent);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将响应内容转换为指定类型
|
||||
*
|
||||
* @param targetType 目标类型
|
||||
* @param <T> 目标类型
|
||||
* @return 转换后的对象
|
||||
*/
|
||||
public <T> T as(TypeReference<T> targetType) {
|
||||
OutputParser<T> parser = OutputParser.fromTypeReference(targetType);
|
||||
return this.as(parser);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将响应内容转换为指定类型
|
||||
*
|
||||
* @param targetType 目标类型
|
||||
* @param <T> 目标类型
|
||||
* @return 转换后的对象
|
||||
*/
|
||||
public <T> T as(Type targetType) {
|
||||
return this.as(new TypeReference<T>(targetType.getTypeName()) {
|
||||
});
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.yomahub.liteflow.ai.engine.model.output.structure;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
|
||||
import java.lang.reflect.ParameterizedType;
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
/**
|
||||
* 保存泛型信息,绕开 java 泛型擦除
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since TODO
|
||||
*/
|
||||
|
||||
public abstract class TypeReference<T> {
|
||||
|
||||
private final Type type;
|
||||
|
||||
/**
|
||||
* 构造函数,获取当前类的泛型类型
|
||||
*/
|
||||
protected TypeReference() {
|
||||
Type superClass = this.getClass().getGenericSuperclass();
|
||||
if (superClass instanceof Class) {
|
||||
throw new RuntimeException("TypeReference must be a parameterized type");
|
||||
} else {
|
||||
this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构造函数,使用类全限定名表示的类型名称
|
||||
*
|
||||
* @param typeName 类全限定名
|
||||
*/
|
||||
public TypeReference(String typeName) {
|
||||
this.type = JsonSchemaGenerator.typeFromString(typeName);
|
||||
}
|
||||
|
||||
public Type getType() {
|
||||
return type;
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import com.github.victools.jsonschema.module.jackson.JacksonModule;
|
||||
import com.github.victools.jsonschema.module.jackson.JacksonOption;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.Description;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.ParameterizedTypeImpl;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.ArrayList;
|
||||
@@ -70,6 +71,27 @@ public class JsonSchemaGenerator {
|
||||
return configBuilder;
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成指定类型的 JSON Schema (默认严格模式)
|
||||
*
|
||||
* @param typeReference 类型引用
|
||||
* @return 生成的 JSON Schema
|
||||
*/
|
||||
public static JsonNode generate(TypeReference<?> typeReference) {
|
||||
return generate(typeReference.getType(), true);
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成指定类型的 JSON Schema (默认严格模式)
|
||||
*
|
||||
* @param typeReference 类型引用
|
||||
* @param strict 是否为严格模式
|
||||
* @return 生成的 JSON Schema
|
||||
*/
|
||||
public static JsonNode generate(TypeReference<?> typeReference, boolean strict) {
|
||||
return generate(typeReference.getType(), strict);
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成指定类型的 JSON Schema (默认严格模式)
|
||||
*
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.yomahub.liteflow.ai.engine.model.output.structure.parser;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
@@ -20,6 +21,33 @@ public class OutputParser<T> {
|
||||
private final ObjectMapper objectMapper;
|
||||
private final JsonNode jsonSchema;
|
||||
|
||||
/**
|
||||
* 通过 TypeReference 生成输出解析器
|
||||
*
|
||||
* @param reference TypeReference
|
||||
*/
|
||||
public static <T> OutputParser<T> fromTypeReference(TypeReference<T> reference) {
|
||||
return new OutputParser<>(reference.getType());
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 Type 生成输出解析器
|
||||
*
|
||||
* @param type 目标类型
|
||||
*/
|
||||
public static <T> OutputParser<T> fromType(Type type) {
|
||||
return new OutputParser<>(type);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过类型全限定名生成输出解析器
|
||||
*
|
||||
* @param typeName 目标类型的全限定名
|
||||
*/
|
||||
public static <T> OutputParser<T> fromTypeName(String typeName) {
|
||||
return new OutputParser<>(typeName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过类型生成输出解析器
|
||||
*
|
||||
@@ -88,7 +116,7 @@ public class OutputParser<T> {
|
||||
*/
|
||||
public String getOutputInstruction() {
|
||||
String template =
|
||||
"Your response should be in JSON format.\n" +
|
||||
"Your response should be in JSON format.\n" +
|
||||
"Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.\n" +
|
||||
"Do not include markdown code blocks in your response.\n" +
|
||||
"Remove the ```json markdown from the output.\n" +
|
||||
@@ -110,4 +138,13 @@ public class OutputParser<T> {
|
||||
public JsonNode getJsonSchema() {
|
||||
return jsonSchema;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取目标类型
|
||||
*
|
||||
* @return 目标类型
|
||||
*/
|
||||
public Type getTargetType() {
|
||||
return targetType;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,11 +7,13 @@ import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* TODO
|
||||
* Ollama 聊天请求体
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since TODO
|
||||
@@ -30,10 +32,13 @@ public class OllamaChatRequest extends ChatRequest {
|
||||
TransportType transportType,
|
||||
TransportListener transportListener,
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType
|
||||
) {
|
||||
super(messages, options, streaming, transportType,
|
||||
transportListener, resultHandler, chunkCallbackTransformer);
|
||||
transportListener, resultHandler, chunkCallbackTransformer,
|
||||
responseType, targetType);
|
||||
}
|
||||
|
||||
public OllamaChatRequest(Builder builder) {
|
||||
|
||||
@@ -2,12 +2,15 @@ package com.yomahub.liteflow.test.ai.engine.structure;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
import com.yomahub.liteflow.test.ai.engine.structure.param.Output;
|
||||
import com.yomahub.liteflow.test.ai.engine.structure.param.OutputWithRequiredFalse;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* JsonSchemaGeneratorTest
|
||||
*
|
||||
@@ -26,6 +29,43 @@ public class JsonSchemaGeneratorTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTypeReference() {
|
||||
TypeReference<Object> typeReference1 = new TypeReference<Object>("java.util.List<java.lang.String>") {
|
||||
};
|
||||
JsonNode node1 = JsonSchemaGenerator.generate(typeReference1);
|
||||
Assertions.assertEquals("java.util.List<java.lang.String>", typeReference1.getType().getTypeName());
|
||||
Assertions.assertEquals(
|
||||
"{\n" +
|
||||
" \"type\" : \"array\",\n" +
|
||||
" \"items\" : {\n" +
|
||||
" \"type\" : \"string\"\n" +
|
||||
" }\n" +
|
||||
"}",
|
||||
toPrettyJson(node1)
|
||||
);
|
||||
|
||||
TypeReference<List<String>> typeReference2 = new TypeReference<List<String>>() {
|
||||
};
|
||||
JsonNode node2 = JsonSchemaGenerator.generate(typeReference2);
|
||||
Assertions.assertEquals("java.util.List<java.lang.String>", typeReference2.getType().getTypeName());
|
||||
Assertions.assertEquals(
|
||||
"{\n" +
|
||||
" \"type\" : \"array\",\n" +
|
||||
" \"items\" : {\n" +
|
||||
" \"type\" : \"string\"\n" +
|
||||
" }\n" +
|
||||
"}",
|
||||
toPrettyJson(node2)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInvalidTypeReference() {
|
||||
Assertions.assertThrows(RuntimeException.class, () -> new TypeReference() {});
|
||||
Assertions.assertThrows(RuntimeException.class, () -> new TypeReference<Object>("invalid.type.name") {});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPrimitiveType() {
|
||||
JsonNode stringJson = JsonSchemaGenerator.generate(String.class);
|
||||
|
||||
Reference in New Issue
Block a user