Feat(core): 结构化输出

This commit is contained in:
LuanY77
2025-08-13 00:37:10 +08:00
parent 4ea8ac6ba1
commit 3528d724b0
14 changed files with 102 additions and 25 deletions

View File

@@ -99,8 +99,12 @@ public @interface AIOutput {
* <p>
* 表示输出的 JSON Schema 定义。如果需要添加描述信息,请使用{@link com.yomahub.liteflow.ai.engine.model.output.structure.Description}
*/
Class<?> entityClass() default String.class;
// String entityClass() default "java.lang.String";
String typeName() default "java.lang.String";
/**
* 是否严格模式(默认为 true输出模式为 JSON 时使用)
*/
boolean strict() default true;
/**
* 如需启用,请设置 {@link AIOutput#responseType()} 为 {@link ResponseType#JSON}

View File

@@ -19,7 +19,7 @@ public @interface OutputField {
/**
* 源字段名称(必需)。
* <p>
* 指定要从结构化输出对象 ({@link AIOutput#entityClass()}) 中读取的字段名。
* 指定要从结构化输出对象 ({@link AIOutput#typeName()}) 中读取的字段名。
*/
String sourceField();

View File

@@ -1,21 +1,28 @@
package com.yomahub.liteflow.ai.context;
import com.yomahub.liteflow.slot.DefaultContext;
import java.util.UUID;
/**
* chat 上下文
* chat 上下文, 用于存储 StreamHandler
* <p>
* 对于 StreamHandler 参数,可以通过流程参数传入,也可以通过 ChatContext 的构造函数传入
*
* @author 苍镜月
* @since TODO
*/
public class ChatContext {
public class ChatContext extends DefaultContext {
private String chatId;
private StreamHandler streamHandler;
public ChatContext() {}
public ChatContext() {
this.chatId = "chat_" + UUID.randomUUID();
this.streamHandler = null;
}
public ChatContext(StreamHandler streamHandler) {
this.chatId = "chat_" + UUID.randomUUID();

View File

@@ -14,7 +14,8 @@ public class ParsedAnnotationConfig {
protected String systemPrompt;
protected String userPrompt;
protected ResponseType responseType = ResponseType.TEXT;
protected Class<?> entityClass = String.class;
protected String typeName = "java.lang.String";
protected boolean strict = true;
public String getSystemPrompt() {
return systemPrompt;
@@ -28,8 +29,12 @@ public class ParsedAnnotationConfig {
return responseType;
}
public Class<?> getEntityClass() {
return entityClass;
public String getTypeName() {
return typeName;
}
public boolean isStrict() {
return strict;
}
public void setSystemPrompt(String systemPrompt) {
@@ -44,7 +49,11 @@ public class ParsedAnnotationConfig {
this.responseType = responseType;
}
public void setEntityClass(Class<?> entityClass) {
this.entityClass = entityClass;
public void setTypeName(String typeName) {
this.typeName = typeName;
}
public void setStrict(boolean strict) {
this.strict = strict;
}
}

View File

@@ -2,6 +2,7 @@ package com.yomahub.liteflow.ai.model;
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.embedding.EmbeddingModel;
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
@@ -47,6 +48,19 @@ public class ModelFactory {
LOG.info("Registered model provider: {}", provider.getProviderName());
}
/**
* 获取指定提供者名称的ChatRequest.Builder
*
* @param providerName 模型提供者名称
* @return ChatRequest.Builder实例
*/
public static ChatRequest.Builder<?> getChatRequestBuilder(String providerName) {
return MODEL_PROVIDER_MAP
.get(providerName)
.createChatRequestBuilder()
.orElseThrow(() -> new LiteFlowAIException("ChatRequest.Builder is not supported for provider: " + providerName));
}
/**
* 获取指定提供者名称的ChatModel实例
*

View File

@@ -2,6 +2,7 @@ package com.yomahub.liteflow.ai.model;
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.embedding.EmbeddingModel;
import java.util.Optional;
@@ -22,6 +23,15 @@ public interface ModelProvider {
*/
String getProviderName();
/**
* 创建ChatRequest构建器
*
* @return ChatRequest 的建造者
*/
default Optional<ChatRequest.Builder<?>> createChatRequestBuilder() {
return Optional.of(ChatRequest.builder());
}
/**
* 创建ChatModel实例
*

View File

@@ -118,12 +118,13 @@ public abstract class AbstractAnnotationProcessor<A extends Annotation, C extend
if (Objects.equals(ResponseType.JSON, outputAnno.responseType())) {
annotationConfig.setResponseType(ResponseType.JSON);
// 设置输出实体类
annotationConfig.setEntityClass(outputAnno.entityClass());
annotationConfig.setTypeName(outputAnno.typeName());
} else {
annotationConfig.setResponseType(ResponseType.TEXT);
// 文本输出,设置 entityClass 为 String
annotationConfig.setEntityClass(String.class);
annotationConfig.setTypeName("java.lang.String");
}
annotationConfig.setStrict(outputAnno.strict());
}
/**

View File

@@ -9,6 +9,8 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.model.ModelFactory;
import com.yomahub.liteflow.ai.util.SetUtil;
import java.util.ArrayList;
@@ -31,7 +33,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ChatRequest,
.map(ChatRequest::getOptions)
.orElse(ChatOptions.builder().build());
ChatRequest.Builder<?> builder = ChatRequest.builder();
ChatRequest.Builder<?> builder = ModelFactory.getChatRequestBuilder(config.getProvider());
// 1. 连接 StreamHandler 回调
StreamHandler streamHandler = context.getStreamHandler();
@@ -76,7 +78,18 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ChatRequest,
merge(() -> Objects.nonNull(contextRequest) ? contextRequest.getTransportType() : null, annotationConfig::getTransportType)
);
// TODO 4. 结构化输出
// 4. 结构化输出
builder.targetType(
merge(() -> Objects.nonNull(contextRequest) ? contextRequest.getTargetType() : null,
() -> new TypeReference(annotationConfig.getTypeName()) {
}.getType())
);
builder.responseType(
merge(() -> Objects.nonNull(contextRequest) ? contextRequest.getResponseType() : null, annotationConfig::getResponseType)
);
builder.strict(
Boolean.TRUE.equals(merge(() -> Objects.nonNull(contextRequest) ? contextRequest.isStrict() : null, annotationConfig::isStrict))
);
return builder.build();
}

View File

@@ -70,6 +70,11 @@ public class ChatRequest implements ModelRequest {
*/
protected final ResponseType responseType;
/**
* 是否为严格模式
*/
protected final boolean strict;
/**
* 输出解析器,用于解析模型的输出结果,由 targetType 生成
*/
@@ -91,6 +96,7 @@ public class ChatRequest implements ModelRequest {
this.responseType = ResponseType.TEXT; // 默认响应类型为文本
TypeReference<String> targetType = new TypeReference<String>() {
}; // 默认目标类型为 String
this.strict = true;
this.outputParser = OutputParser.fromTypeReference(targetType);
}
@@ -114,6 +120,7 @@ public class ChatRequest implements ModelRequest {
this.resultHandler = resultHandler;
this.chunkCallbackTransformer = chunkCallbackTransformer;
this.responseType = responseType;
this.strict = strict;
this.outputParser = OutputParser.fromTypeReference(targetType, strict);
checkTransportConsistency();
checkResponseTypeConsistency();
@@ -133,7 +140,8 @@ public class ChatRequest implements ModelRequest {
this.resultHandler = builder.resultHandler;
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
this.responseType = builder.responseType;
this.outputParser = OutputParser.fromTypeReference(builder.targetType, builder().strict);
this.strict = builder.strict;
this.outputParser = OutputParser.fromType(builder.targetType, builder.strict);
checkTransportConsistency();
checkResponseTypeConsistency();
}
@@ -220,6 +228,10 @@ public class ChatRequest implements ModelRequest {
return outputParser;
}
public boolean isStrict() {
return strict;
}
public void setResultHandler(ResultHandler resultHandler) {
this.resultHandler = resultHandler;
}
@@ -238,8 +250,7 @@ public class ChatRequest implements ModelRequest {
protected ResultHandler resultHandler;
protected ChunkCallbackTransformer chunkCallbackTransformer;
protected ResponseType responseType = ResponseType.TEXT;
protected TypeReference<?> targetType = new TypeReference<String>() {
};
protected Type targetType = String.class;
protected boolean strict = true;
public abstract B self();
@@ -428,7 +439,7 @@ public class ChatRequest implements ModelRequest {
* @param targetType 目标类型引用
* @see TypeReference
*/
public B targetType(TypeReference<?> targetType) {
public B targetType(Type targetType) {
this.targetType = targetType;
return self();
}

View File

@@ -66,6 +66,8 @@ public class JsonSchemaGenerator {
});
// 全部属性都应当为 required
// https://platform.openai.com/docs/guides/structured-outputs/supported-schemas?api-mode=chat#all-fields-must-be-required
// 详细可见 OpenAI 官方说明,翻译成人话就是,都得加上 Required但是如果非 strict那么 type 允许加个 null
configBuilder.forFields().withRequiredCheck(field -> true);
return configBuilder;

View File

@@ -2,10 +2,12 @@ package com.yomahub.liteflow.ai.model.ollama.model;
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.embedding.EmbeddingModel;
import com.yomahub.liteflow.ai.model.ModelProviderRegistrar;
import com.yomahub.liteflow.ai.model.ollama.constants.OllamaConstant;
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatModel;
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatRequest;
import java.util.Optional;
@@ -20,6 +22,11 @@ import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent;
public class OllamaModelProvider extends ModelProviderRegistrar {
@Override
public Optional<ChatRequest.Builder<?>> createChatRequestBuilder() {
return Optional.of(OllamaChatRequest.builder());
}
@Override
public Optional<ChatModel> createChatModel(ModelConfigAggregator configAggregator) {
return Optional.of(OllamaChatModel.builder())

View File

@@ -4,7 +4,6 @@ import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.core.FlowExecutor;
import com.yomahub.liteflow.flow.LiteflowResponse;
import com.yomahub.liteflow.slot.DefaultContext;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
@@ -32,7 +31,7 @@ public class ChatTest {
@Test
public void testBlockingChat() {
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain1", null, ChatContext.class, DefaultContext.class);
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain1", null, ChatContext.class);
Assertions.assertTrue(liteflowResponse.isSuccess());
}
@@ -64,7 +63,7 @@ public class ChatTest {
ChatContext chatContext = new ChatContext(streamHandler);
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain2", null, chatContext, DefaultContext.class);
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain2", null, chatContext);
Assertions.assertTrue(liteflowResponse.isSuccess());
}
}

View File

@@ -34,8 +34,8 @@ import com.yomahub.liteflow.ai.util.TriState;
}
)
@AIOutput(
responseType = ResponseType.TEXT,
entityClass = Output.class,
responseType = ResponseType.JSON,
typeName = "com.yomahub.liteflow.test.ai.core.proxy.cmp.Output",
methodExpress = "setData",
useKeyIndex = true,
key = "result"

View File

@@ -35,7 +35,7 @@ import com.yomahub.liteflow.ai.util.TriState;
)
@AIOutput(
responseType = ResponseType.TEXT,
entityClass = Output.class,
typeName = "com.yomahub.liteflow.test.ai.core.proxy.cmp.Output",
methodExpress = "setData",
useKeyIndex = true,
key = "result"