mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-11 12:26:52 +08:00
Feat(engine): 补充结构化输出相关实现
This commit is contained in:
@@ -10,6 +10,7 @@ import com.yomahub.liteflow.ai.engine.model.ModelRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.OutputParser;
|
||||
import com.yomahub.liteflow.ai.engine.util.request.RequestBody;
|
||||
|
||||
@@ -102,7 +103,8 @@ public class ChatRequest implements ModelRequest {
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType
|
||||
TypeReference<?> targetType,
|
||||
boolean strict
|
||||
) {
|
||||
this.messages = messages;
|
||||
this.options = options;
|
||||
@@ -112,7 +114,7 @@ public class ChatRequest implements ModelRequest {
|
||||
this.resultHandler = resultHandler;
|
||||
this.chunkCallbackTransformer = chunkCallbackTransformer;
|
||||
this.responseType = responseType;
|
||||
this.outputParser = OutputParser.fromTypeReference(targetType);
|
||||
this.outputParser = OutputParser.fromTypeReference(targetType, strict);
|
||||
checkTransportConsistency();
|
||||
checkResponseTypeConsistency();
|
||||
}
|
||||
@@ -131,7 +133,7 @@ public class ChatRequest implements ModelRequest {
|
||||
this.resultHandler = builder.resultHandler;
|
||||
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
|
||||
this.responseType = builder.responseType;
|
||||
this.outputParser = OutputParser.fromTypeReference(builder.targetType);
|
||||
this.outputParser = OutputParser.fromTypeReference(builder.targetType, builder().strict);
|
||||
checkTransportConsistency();
|
||||
checkResponseTypeConsistency();
|
||||
}
|
||||
@@ -163,12 +165,21 @@ public class ChatRequest implements ModelRequest {
|
||||
@Override
|
||||
public RequestBody toRequestBody() {
|
||||
return RequestBody.of()
|
||||
.putIfNotEmpty(MESSAGES_KEY, messages)
|
||||
.putIfNotEmpty(MESSAGES_KEY, appendFormatInstructionsIfNeeded())
|
||||
// 默认流式,如果不需要流式输出,则设置为false
|
||||
.putIf(!streaming, STREAM_KEY, streaming)
|
||||
.merge(options.toRequestBody());
|
||||
}
|
||||
|
||||
/**
|
||||
* 对于某些模型,如果需要附加结构化输出提示词,可以在此方法中实现。
|
||||
*
|
||||
* @return 添加了结构化输出提示词的上下文
|
||||
*/
|
||||
protected List<Message> appendFormatInstructionsIfNeeded() {
|
||||
return this.messages;
|
||||
}
|
||||
|
||||
public List<Message> getMessages() {
|
||||
return messages;
|
||||
}
|
||||
@@ -229,6 +240,7 @@ public class ChatRequest implements ModelRequest {
|
||||
protected ResponseType responseType = ResponseType.TEXT;
|
||||
protected TypeReference<?> targetType = new TypeReference<String>() {
|
||||
};
|
||||
protected boolean strict = true;
|
||||
|
||||
public abstract B self();
|
||||
|
||||
@@ -421,6 +433,17 @@ public class ChatRequest implements ModelRequest {
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 是否为严格模式,默认开启
|
||||
*
|
||||
* @param strict 是否为严格模式
|
||||
* @see JsonSchemaGenerator#generate(Type, boolean)
|
||||
*/
|
||||
public B strict(boolean strict) {
|
||||
this.strict = strict;
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* 内部聚合类
|
||||
*/
|
||||
|
||||
@@ -27,7 +27,7 @@ public class OutputParser<T> {
|
||||
* @param reference TypeReference
|
||||
*/
|
||||
public static <T> OutputParser<T> fromTypeReference(TypeReference<T> reference) {
|
||||
return new OutputParser<>(reference.getType());
|
||||
return new OutputParser<>(reference.getType(), true);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,7 +36,7 @@ public class OutputParser<T> {
|
||||
* @param type 目标类型
|
||||
*/
|
||||
public static <T> OutputParser<T> fromType(Type type) {
|
||||
return new OutputParser<>(type);
|
||||
return new OutputParser<>(type, true);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -45,18 +45,44 @@ public class OutputParser<T> {
|
||||
* @param typeName 目标类型的全限定名
|
||||
*/
|
||||
public static <T> OutputParser<T> fromTypeName(String typeName) {
|
||||
return new OutputParser<>(typeName);
|
||||
return new OutputParser<>(typeName, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 TypeReference 生成输出解析器
|
||||
*
|
||||
* @param reference TypeReference
|
||||
*/
|
||||
public static <T> OutputParser<T> fromTypeReference(TypeReference<T> reference, boolean strict) {
|
||||
return new OutputParser<>(reference.getType(), strict);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过 Type 生成输出解析器
|
||||
*
|
||||
* @param type 目标类型
|
||||
*/
|
||||
public static <T> OutputParser<T> fromType(Type type, boolean strict) {
|
||||
return new OutputParser<>(type, strict);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过类型全限定名生成输出解析器
|
||||
*
|
||||
* @param typeName 目标类型的全限定名
|
||||
*/
|
||||
public static <T> OutputParser<T> fromTypeName(String typeName, boolean strict) {
|
||||
return new OutputParser<>(typeName, strict);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 通过类型生成输出解析器
|
||||
*
|
||||
* @param targetType 目标类型
|
||||
*/
|
||||
public OutputParser(Type targetType) {
|
||||
this.targetType = targetType;
|
||||
this.objectMapper = new ObjectMapper();
|
||||
this.jsonSchema = JsonSchemaGenerator.generate(targetType);
|
||||
this(targetType, true);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -65,9 +91,29 @@ public class OutputParser<T> {
|
||||
* @param typeName 目标类型的全限定名
|
||||
*/
|
||||
public OutputParser(String typeName) {
|
||||
this.targetType = JsonSchemaGenerator.typeFromString(typeName);
|
||||
this(JsonSchemaGenerator.typeFromString(typeName), true);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过类型生成输出解析器
|
||||
*
|
||||
* @param targetType 目标类型
|
||||
* @param strict 严格模式
|
||||
*/
|
||||
public OutputParser(Type targetType, boolean strict) {
|
||||
this.targetType = targetType;
|
||||
this.objectMapper = new ObjectMapper();
|
||||
this.jsonSchema = JsonSchemaGenerator.generate(typeName);
|
||||
this.jsonSchema = JsonSchemaGenerator.generate(targetType, strict);
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过类全限定名生成输出解析器
|
||||
*
|
||||
* @param typeName 目标类型的全限定名
|
||||
* @param strict 严格模式
|
||||
*/
|
||||
public OutputParser(String typeName, boolean strict) {
|
||||
this(JsonSchemaGenerator.typeFromString(typeName), strict);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -139,6 +185,19 @@ public class OutputParser<T> {
|
||||
return jsonSchema;
|
||||
}
|
||||
|
||||
/**
|
||||
* JsonSchema
|
||||
*
|
||||
* @return JsonSchema
|
||||
*/
|
||||
public String getJsonSchemaString() {
|
||||
try {
|
||||
return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(this.jsonSchema);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取目标类型
|
||||
*
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
|
||||
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.util.request.RequestBody;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -16,11 +17,18 @@ import java.util.List;
|
||||
* Ollama 聊天请求体
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @see <a href=
|
||||
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Chat
|
||||
* Completion API</a>
|
||||
* @since TODO
|
||||
*/
|
||||
|
||||
public class OllamaChatRequest extends ChatRequest {
|
||||
|
||||
// ==== RequestBody 相关参数 =====
|
||||
private static final String FORMAT_KEY = "format";
|
||||
// ==== RequestBody 相关参数 =====
|
||||
|
||||
public OllamaChatRequest() {
|
||||
super();
|
||||
}
|
||||
@@ -34,17 +42,24 @@ public class OllamaChatRequest extends ChatRequest {
|
||||
ResultHandler resultHandler,
|
||||
ChunkCallbackTransformer chunkCallbackTransformer,
|
||||
ResponseType responseType,
|
||||
TypeReference<?> targetType
|
||||
TypeReference<?> targetType,
|
||||
boolean strict
|
||||
) {
|
||||
super(messages, options, streaming, transportType,
|
||||
transportListener, resultHandler, chunkCallbackTransformer,
|
||||
responseType, targetType);
|
||||
responseType, targetType, strict);
|
||||
}
|
||||
|
||||
public OllamaChatRequest(Builder builder) {
|
||||
super(builder);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RequestBody toRequestBody() {
|
||||
return super.toRequestBody()
|
||||
.putIf(ResponseType.JSON.equals(this.responseType), FORMAT_KEY, outputParser.getJsonSchema());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void checkTransportConsistency() {
|
||||
super.checkTransportConsistency();
|
||||
|
||||
Reference in New Issue
Block a user