mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-10 11:17:00 +08:00
Feat(core): 结构化输出
This commit is contained in:
@@ -99,8 +99,12 @@ public @interface AIOutput {
|
||||
* <p>
|
||||
* 表示输出的 JSON Schema 定义。如果需要添加描述信息,请使用{@link com.yomahub.liteflow.ai.engine.model.output.structure.Description}
|
||||
*/
|
||||
Class<?> entityClass() default String.class;
|
||||
// String entityClass() default "java.lang.String";
|
||||
String typeName() default "java.lang.String";
|
||||
|
||||
/**
|
||||
* 是否严格模式(默认为 true,输出模式为 JSON 时使用)
|
||||
*/
|
||||
boolean strict() default true;
|
||||
|
||||
/**
|
||||
* 如需启用,请设置 {@link AIOutput#responseType()} 为 {@link ResponseType#JSON}
|
||||
|
||||
@@ -19,7 +19,7 @@ public @interface OutputField {
|
||||
/**
|
||||
* 源字段名称(必需)。
|
||||
* <p>
|
||||
* 指定要从结构化输出对象 ({@link AIOutput#entityClass()}) 中读取的字段名。
|
||||
* 指定要从结构化输出对象 ({@link AIOutput#typeName()}) 中读取的字段名。
|
||||
*/
|
||||
String sourceField();
|
||||
|
||||
|
||||
@@ -1,21 +1,28 @@
|
||||
package com.yomahub.liteflow.ai.context;
|
||||
|
||||
import com.yomahub.liteflow.slot.DefaultContext;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
/**
|
||||
* chat 上下文
|
||||
* chat 上下文, 用于存储 StreamHandler
|
||||
* <p>
|
||||
* 对于 StreamHandler 参数,可以通过流程参数传入,也可以通过 ChatContext 的构造函数传入
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since TODO
|
||||
*/
|
||||
|
||||
public class ChatContext {
|
||||
public class ChatContext extends DefaultContext {
|
||||
|
||||
private String chatId;
|
||||
|
||||
private StreamHandler streamHandler;
|
||||
|
||||
public ChatContext() {}
|
||||
public ChatContext() {
|
||||
this.chatId = "chat_" + UUID.randomUUID();
|
||||
this.streamHandler = null;
|
||||
}
|
||||
|
||||
public ChatContext(StreamHandler streamHandler) {
|
||||
this.chatId = "chat_" + UUID.randomUUID();
|
||||
|
||||
@@ -14,7 +14,8 @@ public class ParsedAnnotationConfig {
|
||||
protected String systemPrompt;
|
||||
protected String userPrompt;
|
||||
protected ResponseType responseType = ResponseType.TEXT;
|
||||
protected Class<?> entityClass = String.class;
|
||||
protected String typeName = "java.lang.String";
|
||||
protected boolean strict = true;
|
||||
|
||||
public String getSystemPrompt() {
|
||||
return systemPrompt;
|
||||
@@ -28,8 +29,12 @@ public class ParsedAnnotationConfig {
|
||||
return responseType;
|
||||
}
|
||||
|
||||
public Class<?> getEntityClass() {
|
||||
return entityClass;
|
||||
public String getTypeName() {
|
||||
return typeName;
|
||||
}
|
||||
|
||||
public boolean isStrict() {
|
||||
return strict;
|
||||
}
|
||||
|
||||
public void setSystemPrompt(String systemPrompt) {
|
||||
@@ -44,7 +49,11 @@ public class ParsedAnnotationConfig {
|
||||
this.responseType = responseType;
|
||||
}
|
||||
|
||||
public void setEntityClass(Class<?> entityClass) {
|
||||
this.entityClass = entityClass;
|
||||
public void setTypeName(String typeName) {
|
||||
this.typeName = typeName;
|
||||
}
|
||||
|
||||
public void setStrict(boolean strict) {
|
||||
this.strict = strict;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.yomahub.liteflow.ai.model;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.embedding.EmbeddingModel;
|
||||
import com.yomahub.liteflow.ai.exception.LiteFlowAIException;
|
||||
import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
|
||||
@@ -47,6 +48,19 @@ public class ModelFactory {
|
||||
LOG.info("Registered model provider: {}", provider.getProviderName());
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定提供者名称的ChatRequest.Builder
|
||||
*
|
||||
* @param providerName 模型提供者名称
|
||||
* @return ChatRequest.Builder实例
|
||||
*/
|
||||
public static ChatRequest.Builder<?> getChatRequestBuilder(String providerName) {
|
||||
return MODEL_PROVIDER_MAP
|
||||
.get(providerName)
|
||||
.createChatRequestBuilder()
|
||||
.orElseThrow(() -> new LiteFlowAIException("ChatRequest.Builder is not supported for provider: " + providerName));
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取指定提供者名称的ChatModel实例
|
||||
*
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.yomahub.liteflow.ai.model;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.embedding.EmbeddingModel;
|
||||
|
||||
import java.util.Optional;
|
||||
@@ -22,6 +23,15 @@ public interface ModelProvider {
|
||||
*/
|
||||
String getProviderName();
|
||||
|
||||
/**
|
||||
* 创建ChatRequest构建器
|
||||
*
|
||||
* @return ChatRequest 的建造者
|
||||
*/
|
||||
default Optional<ChatRequest.Builder<?>> createChatRequestBuilder() {
|
||||
return Optional.of(ChatRequest.builder());
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建ChatModel实例
|
||||
*
|
||||
|
||||
@@ -118,12 +118,13 @@ public abstract class AbstractAnnotationProcessor<A extends Annotation, C extend
|
||||
if (Objects.equals(ResponseType.JSON, outputAnno.responseType())) {
|
||||
annotationConfig.setResponseType(ResponseType.JSON);
|
||||
// 设置输出实体类
|
||||
annotationConfig.setEntityClass(outputAnno.entityClass());
|
||||
annotationConfig.setTypeName(outputAnno.typeName());
|
||||
} else {
|
||||
annotationConfig.setResponseType(ResponseType.TEXT);
|
||||
// 文本输出,设置 entityClass 为 String
|
||||
annotationConfig.setEntityClass(String.class);
|
||||
annotationConfig.setTypeName("java.lang.String");
|
||||
}
|
||||
annotationConfig.setStrict(outputAnno.strict());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -9,6 +9,8 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.SystemMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.UserMessage;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.model.ModelFactory;
|
||||
import com.yomahub.liteflow.ai.util.SetUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -31,7 +33,7 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ChatRequest,
|
||||
.map(ChatRequest::getOptions)
|
||||
.orElse(ChatOptions.builder().build());
|
||||
|
||||
ChatRequest.Builder<?> builder = ChatRequest.builder();
|
||||
ChatRequest.Builder<?> builder = ModelFactory.getChatRequestBuilder(config.getProvider());
|
||||
|
||||
// 1. 连接 StreamHandler 回调
|
||||
StreamHandler streamHandler = context.getStreamHandler();
|
||||
@@ -76,7 +78,18 @@ public class ChatRequestAssembler extends AbstractRequestAssembler<ChatRequest,
|
||||
merge(() -> Objects.nonNull(contextRequest) ? contextRequest.getTransportType() : null, annotationConfig::getTransportType)
|
||||
);
|
||||
|
||||
// TODO 4. 结构化输出
|
||||
// 4. 结构化输出
|
||||
builder.targetType(
|
||||
merge(() -> Objects.nonNull(contextRequest) ? contextRequest.getTargetType() : null,
|
||||
() -> new TypeReference(annotationConfig.getTypeName()) {
|
||||
}.getType())
|
||||
);
|
||||
builder.responseType(
|
||||
merge(() -> Objects.nonNull(contextRequest) ? contextRequest.getResponseType() : null, annotationConfig::getResponseType)
|
||||
);
|
||||
builder.strict(
|
||||
Boolean.TRUE.equals(merge(() -> Objects.nonNull(contextRequest) ? contextRequest.isStrict() : null, annotationConfig::isStrict))
|
||||
);
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@@ -70,6 +70,11 @@ public class ChatRequest implements ModelRequest {
|
||||
*/
|
||||
protected final ResponseType responseType;
|
||||
|
||||
/**
|
||||
* 是否为严格模式
|
||||
*/
|
||||
protected final boolean strict;
|
||||
|
||||
/**
|
||||
* 输出解析器,用于解析模型的输出结果,由 targetType 生成
|
||||
*/
|
||||
@@ -91,6 +96,7 @@ public class ChatRequest implements ModelRequest {
|
||||
this.responseType = ResponseType.TEXT; // 默认响应类型为文本
|
||||
TypeReference<String> targetType = new TypeReference<String>() {
|
||||
}; // 默认目标类型为 String
|
||||
this.strict = true;
|
||||
this.outputParser = OutputParser.fromTypeReference(targetType);
|
||||
}
|
||||
|
||||
@@ -114,6 +120,7 @@ public class ChatRequest implements ModelRequest {
|
||||
this.resultHandler = resultHandler;
|
||||
this.chunkCallbackTransformer = chunkCallbackTransformer;
|
||||
this.responseType = responseType;
|
||||
this.strict = strict;
|
||||
this.outputParser = OutputParser.fromTypeReference(targetType, strict);
|
||||
checkTransportConsistency();
|
||||
checkResponseTypeConsistency();
|
||||
@@ -133,7 +140,8 @@ public class ChatRequest implements ModelRequest {
|
||||
this.resultHandler = builder.resultHandler;
|
||||
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
|
||||
this.responseType = builder.responseType;
|
||||
this.outputParser = OutputParser.fromTypeReference(builder.targetType, builder().strict);
|
||||
this.strict = builder.strict;
|
||||
this.outputParser = OutputParser.fromType(builder.targetType, builder.strict);
|
||||
checkTransportConsistency();
|
||||
checkResponseTypeConsistency();
|
||||
}
|
||||
@@ -220,6 +228,10 @@ public class ChatRequest implements ModelRequest {
|
||||
return outputParser;
|
||||
}
|
||||
|
||||
public boolean isStrict() {
|
||||
return strict;
|
||||
}
|
||||
|
||||
public void setResultHandler(ResultHandler resultHandler) {
|
||||
this.resultHandler = resultHandler;
|
||||
}
|
||||
@@ -238,8 +250,7 @@ public class ChatRequest implements ModelRequest {
|
||||
protected ResultHandler resultHandler;
|
||||
protected ChunkCallbackTransformer chunkCallbackTransformer;
|
||||
protected ResponseType responseType = ResponseType.TEXT;
|
||||
protected TypeReference<?> targetType = new TypeReference<String>() {
|
||||
};
|
||||
protected Type targetType = String.class;
|
||||
protected boolean strict = true;
|
||||
|
||||
public abstract B self();
|
||||
@@ -428,7 +439,7 @@ public class ChatRequest implements ModelRequest {
|
||||
* @param targetType 目标类型引用
|
||||
* @see TypeReference
|
||||
*/
|
||||
public B targetType(TypeReference<?> targetType) {
|
||||
public B targetType(Type targetType) {
|
||||
this.targetType = targetType;
|
||||
return self();
|
||||
}
|
||||
|
||||
@@ -66,6 +66,8 @@ public class JsonSchemaGenerator {
|
||||
});
|
||||
|
||||
// 全部属性都应当为 required
|
||||
// https://platform.openai.com/docs/guides/structured-outputs/supported-schemas?api-mode=chat#all-fields-must-be-required
|
||||
// 详细可见 OpenAI 官方说明,翻译成人话就是,都得加上 Required,但是如果非 strict,那么 type 允许加个 null
|
||||
configBuilder.forFields().withRequiredCheck(field -> true);
|
||||
|
||||
return configBuilder;
|
||||
|
||||
@@ -2,10 +2,12 @@ package com.yomahub.liteflow.ai.model.ollama.model;
|
||||
|
||||
import com.yomahub.liteflow.ai.domain.dto.ModelConfigAggregator;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.ChatModel;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.embedding.EmbeddingModel;
|
||||
import com.yomahub.liteflow.ai.model.ModelProviderRegistrar;
|
||||
import com.yomahub.liteflow.ai.model.ollama.constants.OllamaConstant;
|
||||
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatModel;
|
||||
import com.yomahub.liteflow.ai.model.ollama.model.chat.OllamaChatRequest;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
@@ -20,6 +22,11 @@ import static com.yomahub.liteflow.ai.util.SetUtil.setIfPresent;
|
||||
|
||||
public class OllamaModelProvider extends ModelProviderRegistrar {
|
||||
|
||||
@Override
|
||||
public Optional<ChatRequest.Builder<?>> createChatRequestBuilder() {
|
||||
return Optional.of(OllamaChatRequest.builder());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<ChatModel> createChatModel(ModelConfigAggregator configAggregator) {
|
||||
return Optional.of(OllamaChatModel.builder())
|
||||
|
||||
@@ -4,7 +4,6 @@ import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.core.FlowExecutor;
|
||||
import com.yomahub.liteflow.flow.LiteflowResponse;
|
||||
import com.yomahub.liteflow.slot.DefaultContext;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
|
||||
@@ -32,7 +31,7 @@ public class ChatTest {
|
||||
|
||||
@Test
|
||||
public void testBlockingChat() {
|
||||
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain1", null, ChatContext.class, DefaultContext.class);
|
||||
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain1", null, ChatContext.class);
|
||||
Assertions.assertTrue(liteflowResponse.isSuccess());
|
||||
}
|
||||
|
||||
@@ -64,7 +63,7 @@ public class ChatTest {
|
||||
|
||||
ChatContext chatContext = new ChatContext(streamHandler);
|
||||
|
||||
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain2", null, chatContext, DefaultContext.class);
|
||||
LiteflowResponse liteflowResponse = flowExecutor.execute2Resp("chain2", null, chatContext);
|
||||
Assertions.assertTrue(liteflowResponse.isSuccess());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,8 +34,8 @@ import com.yomahub.liteflow.ai.util.TriState;
|
||||
}
|
||||
)
|
||||
@AIOutput(
|
||||
responseType = ResponseType.TEXT,
|
||||
entityClass = Output.class,
|
||||
responseType = ResponseType.JSON,
|
||||
typeName = "com.yomahub.liteflow.test.ai.core.proxy.cmp.Output",
|
||||
methodExpress = "setData",
|
||||
useKeyIndex = true,
|
||||
key = "result"
|
||||
|
||||
@@ -35,7 +35,7 @@ import com.yomahub.liteflow.ai.util.TriState;
|
||||
)
|
||||
@AIOutput(
|
||||
responseType = ResponseType.TEXT,
|
||||
entityClass = Output.class,
|
||||
typeName = "com.yomahub.liteflow.test.ai.core.proxy.cmp.Output",
|
||||
methodExpress = "setData",
|
||||
useKeyIndex = true,
|
||||
key = "result"
|
||||
|
||||
Reference in New Issue
Block a user