Feat(engine): 引入 TypeReference,ChatRequest与 ChatResponse 集成结构化输出能力

This commit is contained in:
LuanY77
2025-08-12 15:27:36 +08:00
parent 5ddede46d7
commit 9b4d6b6890
8 changed files with 281 additions and 8 deletions

View File

@@ -4,6 +4,8 @@ import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.yomahub.liteflow.ai.annotation.AIOutput;
import com.yomahub.liteflow.ai.annotation.OutputField;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatResponse;
import com.yomahub.liteflow.ai.engine.model.output.Response;
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
import com.yomahub.liteflow.ai.util.SetUtil;
@@ -59,7 +61,13 @@ public class ContextAccessor {
if (!(value instanceof Response)) {
throw new LiteFlowAIException("AI node output value must be of type Response.");
}
value = ((Response<?>) value).getContent();
// 如果 request 是 ChatRequest 且 value 是 ChatResponse则尝试进行结构化转换
if (context.getModelRequest() instanceof ChatRequest && value instanceof ChatResponse) {
ChatRequest chatRequest = context.getModelRequest().toChatRequest();
value = ((ChatResponse) value).as(chatRequest.getOutputParser());
} else {
value = ((Response<?>) value).getContent();
}
NodeComponent nodeComponent = context.getNodeComponent();
AIOutput outputAnno = context.getAiOutputAnno();

View File

@@ -8,8 +8,12 @@ import com.yomahub.liteflow.ai.engine.interact.transport.TransportListener;
import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.ModelRequest;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.OutputParser;
import com.yomahub.liteflow.ai.engine.util.request.RequestBody;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
@@ -60,6 +64,16 @@ public class ChatRequest implements ModelRequest {
*/
protected final ChunkCallbackTransformer chunkCallbackTransformer;
/**
* 响应类型,默认为文本类型
*/
protected final ResponseType responseType;
/**
* 输出解析器,用于解析模型的输出结果,由 targetType 生成
*/
protected final OutputParser<?> outputParser;
// ==== RequestBody 相关参数 =====
protected static final String MESSAGES_KEY = "messages";
protected static final String STREAM_KEY = "stream";
@@ -73,6 +87,10 @@ public class ChatRequest implements ModelRequest {
this.transportListener = TransportListener.getDefault();
this.resultHandler = ResultHandler.getDefault();
this.chunkCallbackTransformer = ChunkCallbackTransformer.getDefault();
this.responseType = ResponseType.TEXT; // 默认响应类型为文本
TypeReference<String> targetType = new TypeReference<String>() {
}; // 默认目标类型为 String
this.outputParser = OutputParser.fromTypeReference(targetType);
}
public ChatRequest(
@@ -82,7 +100,9 @@ public class ChatRequest implements ModelRequest {
TransportType transportType,
TransportListener transportListener,
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType
) {
this.messages = messages;
this.options = options;
@@ -91,7 +111,10 @@ public class ChatRequest implements ModelRequest {
this.transportListener = transportListener;
this.resultHandler = resultHandler;
this.chunkCallbackTransformer = chunkCallbackTransformer;
this.responseType = responseType;
this.outputParser = OutputParser.fromTypeReference(targetType);
checkTransportConsistency();
checkResponseTypeConsistency();
}
/**
@@ -107,7 +130,10 @@ public class ChatRequest implements ModelRequest {
this.transportListener = builder.transportListener;
this.resultHandler = builder.resultHandler;
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
this.responseType = builder.responseType;
this.outputParser = OutputParser.fromTypeReference(builder.targetType);
checkTransportConsistency();
checkResponseTypeConsistency();
}
/**
@@ -117,9 +143,20 @@ public class ChatRequest implements ModelRequest {
*/
protected void checkTransportConsistency() {
if (this.streaming && this.transportType == TransportType.HTTP) {
throw new IllegalArgumentException("流式输出模式启用但不支持HTTP传输。请使用SSE或WebSocket传输。");
throw new IllegalArgumentException("Streaming mode is enabled, but HTTP transport is not supported. Please use SSE or WebSocket transport.");
} else if (!this.streaming && this.transportType != TransportType.HTTP) {
throw new IllegalArgumentException("阻塞式输出模式启用但传输类型不支持HTTP。请使用HTTP传输。");
throw new IllegalArgumentException("Blocking mode is enabled, but the transport type does not support HTTP. Please use HTTP transport.");
}
}
/**
* 检查响应类型与目标类型的一致性。
* 如果响应类型为文本TEXT但目标类型不是 String则抛出异常。
*/
protected void checkResponseTypeConsistency() {
if (this.responseType == ResponseType.TEXT &&
!Objects.equals("java.lang.String", getTargetType().getTypeName())) {
throw new IllegalArgumentException("Response type is TEXT, but target type is not String. Please check the targetType setting.");
}
}
@@ -160,6 +197,18 @@ public class ChatRequest implements ModelRequest {
return chunkCallbackTransformer;
}
public ResponseType getResponseType() {
return responseType;
}
public Type getTargetType() {
return outputParser.getTargetType();
}
public OutputParser<?> getOutputParser() {
return outputParser;
}
public void setResultHandler(ResultHandler resultHandler) {
this.resultHandler = resultHandler;
}
@@ -177,6 +226,9 @@ public class ChatRequest implements ModelRequest {
protected TransportListener transportListener;
protected ResultHandler resultHandler;
protected ChunkCallbackTransformer chunkCallbackTransformer;
protected ResponseType responseType = ResponseType.TEXT;
protected TypeReference<?> targetType = new TypeReference<String>() {
};
public abstract B self();
@@ -347,6 +399,28 @@ public class ChatRequest implements ModelRequest {
return self();
}
/**
* 设置响应类型
*
* @param responseType 响应类型
* @see ResponseType
*/
public B responseType(ResponseType responseType) {
this.responseType = responseType;
return self();
}
/**
* 设置目标类型引用,用于指定响应体的具体类型
*
* @param targetType 目标类型引用
* @see TypeReference
*/
public B targetType(TypeReference<?> targetType) {
this.targetType = targetType;
return self();
}
/**
* 内部聚合类
*/

View File

@@ -1,10 +1,14 @@
package com.yomahub.liteflow.ai.engine.model.chat.entity;
import cn.hutool.core.util.StrUtil;
import com.yomahub.liteflow.ai.engine.model.chat.message.AssistantMessage;
import com.yomahub.liteflow.ai.engine.model.output.FinishReason;
import com.yomahub.liteflow.ai.engine.model.output.Response;
import com.yomahub.liteflow.ai.engine.model.output.TokenUsage;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.OutputParser;
import java.lang.reflect.Type;
import java.util.Map;
/**
@@ -32,6 +36,46 @@ public class ChatResponse extends Response<AssistantMessage> {
super(builder);
}
/**
* 将响应内容转换为指定类型
*
* @param parser 目标类型解析器
* @param <T> 目标类型
* @return 转换后的对象
*/
public <T> T as(OutputParser<T> parser) {
String rawTextContent = this.getContent().getContent();
if (StrUtil.isBlank(rawTextContent)) {
throw new IllegalStateException("Cannot convert empty content to target type: " + parser.getTargetType());
}
return parser.convert(rawTextContent);
}
/**
* 将响应内容转换为指定类型
*
* @param targetType 目标类型
* @param <T> 目标类型
* @return 转换后的对象
*/
public <T> T as(TypeReference<T> targetType) {
OutputParser<T> parser = OutputParser.fromTypeReference(targetType);
return this.as(parser);
}
/**
* 将响应内容转换为指定类型
*
* @param targetType 目标类型
* @param <T> 目标类型
* @return 转换后的对象
*/
public <T> T as(Type targetType) {
return this.as(new TypeReference<T>(targetType.getTypeName()) {
});
}
public static Builder builder() {
return new Builder();
}

View File

@@ -0,0 +1,43 @@
package com.yomahub.liteflow.ai.engine.model.output.structure;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
/**
* 保存泛型信息,绕开 java 泛型擦除
*
* @author 苍镜月
* @since TODO
*/
public abstract class TypeReference<T> {
private final Type type;
/**
* 构造函数,获取当前类的泛型类型
*/
protected TypeReference() {
Type superClass = this.getClass().getGenericSuperclass();
if (superClass instanceof Class) {
throw new RuntimeException("TypeReference must be a parameterized type");
} else {
this.type = ((ParameterizedType) superClass).getActualTypeArguments()[0];
}
}
/**
* 构造函数,使用类全限定名表示的类型名称
*
* @param typeName 类全限定名
*/
public TypeReference(String typeName) {
this.type = JsonSchemaGenerator.typeFromString(typeName);
}
public Type getType() {
return type;
}
}

View File

@@ -6,6 +6,7 @@ import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;
import com.yomahub.liteflow.ai.engine.model.output.structure.Description;
import com.yomahub.liteflow.ai.engine.model.output.structure.ParameterizedTypeImpl;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import java.lang.reflect.Type;
import java.util.ArrayList;
@@ -70,6 +71,27 @@ public class JsonSchemaGenerator {
return configBuilder;
}
/**
* 生成指定类型的 JSON Schema (默认严格模式)
*
* @param typeReference 类型引用
* @return 生成的 JSON Schema
*/
public static JsonNode generate(TypeReference<?> typeReference) {
return generate(typeReference.getType(), true);
}
/**
* 生成指定类型的 JSON Schema (默认严格模式)
*
* @param typeReference 类型引用
* @param strict 是否为严格模式
* @return 生成的 JSON Schema
*/
public static JsonNode generate(TypeReference<?> typeReference, boolean strict) {
return generate(typeReference.getType(), strict);
}
/**
* 生成指定类型的 JSON Schema (默认严格模式)
*

View File

@@ -3,6 +3,7 @@ package com.yomahub.liteflow.ai.engine.model.output.structure.parser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import java.lang.reflect.Type;
@@ -20,6 +21,33 @@ public class OutputParser<T> {
private final ObjectMapper objectMapper;
private final JsonNode jsonSchema;
/**
* 通过 TypeReference 生成输出解析器
*
* @param reference TypeReference
*/
public static <T> OutputParser<T> fromTypeReference(TypeReference<T> reference) {
return new OutputParser<>(reference.getType());
}
/**
* 通过 Type 生成输出解析器
*
* @param type 目标类型
*/
public static <T> OutputParser<T> fromType(Type type) {
return new OutputParser<>(type);
}
/**
* 通过类型全限定名生成输出解析器
*
* @param typeName 目标类型的全限定名
*/
public static <T> OutputParser<T> fromTypeName(String typeName) {
return new OutputParser<>(typeName);
}
/**
* 通过类型生成输出解析器
*
@@ -88,7 +116,7 @@ public class OutputParser<T> {
*/
public String getOutputInstruction() {
String template =
"Your response should be in JSON format.\n" +
"Your response should be in JSON format.\n" +
"Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.\n" +
"Do not include markdown code blocks in your response.\n" +
"Remove the ```json markdown from the output.\n" +
@@ -110,4 +138,13 @@ public class OutputParser<T> {
public JsonNode getJsonSchema() {
return jsonSchema;
}
/**
* 获取目标类型
*
* @return 目标类型
*/
public Type getTargetType() {
return targetType;
}
}

View File

@@ -7,11 +7,13 @@ import com.yomahub.liteflow.ai.engine.interact.transport.TransportType;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatOptions;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import java.util.List;
/**
* TODO
* Ollama 聊天请求体
*
* @author 苍镜月
* @since TODO
@@ -30,10 +32,13 @@ public class OllamaChatRequest extends ChatRequest {
TransportType transportType,
TransportListener transportListener,
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType
) {
super(messages, options, streaming, transportType,
transportListener, resultHandler, chunkCallbackTransformer);
transportListener, resultHandler, chunkCallbackTransformer,
responseType, targetType);
}
public OllamaChatRequest(Builder builder) {

View File

@@ -2,12 +2,15 @@ package com.yomahub.liteflow.test.ai.engine.structure;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import com.yomahub.liteflow.test.ai.engine.structure.param.Output;
import com.yomahub.liteflow.test.ai.engine.structure.param.OutputWithRequiredFalse;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
/**
* JsonSchemaGeneratorTest
*
@@ -26,6 +29,43 @@ public class JsonSchemaGeneratorTest {
}
}
@Test
public void testTypeReference() {
TypeReference<Object> typeReference1 = new TypeReference<Object>("java.util.List<java.lang.String>") {
};
JsonNode node1 = JsonSchemaGenerator.generate(typeReference1);
Assertions.assertEquals("java.util.List<java.lang.String>", typeReference1.getType().getTypeName());
Assertions.assertEquals(
"{\n" +
" \"type\" : \"array\",\n" +
" \"items\" : {\n" +
" \"type\" : \"string\"\n" +
" }\n" +
"}",
toPrettyJson(node1)
);
TypeReference<List<String>> typeReference2 = new TypeReference<List<String>>() {
};
JsonNode node2 = JsonSchemaGenerator.generate(typeReference2);
Assertions.assertEquals("java.util.List<java.lang.String>", typeReference2.getType().getTypeName());
Assertions.assertEquals(
"{\n" +
" \"type\" : \"array\",\n" +
" \"items\" : {\n" +
" \"type\" : \"string\"\n" +
" }\n" +
"}",
toPrettyJson(node2)
);
}
@Test
public void testInvalidTypeReference() {
Assertions.assertThrows(RuntimeException.class, () -> new TypeReference() {});
Assertions.assertThrows(RuntimeException.class, () -> new TypeReference<Object>("invalid.type.name") {});
}
@Test
public void testPrimitiveType() {
JsonNode stringJson = JsonSchemaGenerator.generate(String.class);