Feat(engine): 补充结构化输出相关实现

This commit is contained in:
LuanY77
2025-08-12 23:54:48 +08:00
parent 9b4d6b6890
commit 4ea8ac6ba1
3 changed files with 111 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ import com.yomahub.liteflow.ai.engine.model.ModelRequest;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import com.yomahub.liteflow.ai.engine.model.output.structure.parser.OutputParser;
import com.yomahub.liteflow.ai.engine.util.request.RequestBody;
@@ -102,7 +103,8 @@ public class ChatRequest implements ModelRequest {
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType
TypeReference<?> targetType,
boolean strict
) {
this.messages = messages;
this.options = options;
@@ -112,7 +114,7 @@ public class ChatRequest implements ModelRequest {
this.resultHandler = resultHandler;
this.chunkCallbackTransformer = chunkCallbackTransformer;
this.responseType = responseType;
this.outputParser = OutputParser.fromTypeReference(targetType);
this.outputParser = OutputParser.fromTypeReference(targetType, strict);
checkTransportConsistency();
checkResponseTypeConsistency();
}
@@ -131,7 +133,7 @@ public class ChatRequest implements ModelRequest {
this.resultHandler = builder.resultHandler;
this.chunkCallbackTransformer = builder.chunkCallbackTransformer;
this.responseType = builder.responseType;
this.outputParser = OutputParser.fromTypeReference(builder.targetType);
this.outputParser = OutputParser.fromTypeReference(builder.targetType, builder().strict);
checkTransportConsistency();
checkResponseTypeConsistency();
}
@@ -163,12 +165,21 @@ public class ChatRequest implements ModelRequest {
@Override
public RequestBody toRequestBody() {
return RequestBody.of()
.putIfNotEmpty(MESSAGES_KEY, messages)
.putIfNotEmpty(MESSAGES_KEY, appendFormatInstructionsIfNeeded())
// 默认流式如果不需要流式输出则设置为false
.putIf(!streaming, STREAM_KEY, streaming)
.merge(options.toRequestBody());
}
/**
* 对于某些模型,如果需要附加结构化输出提示词,可以在此方法中实现。
*
* @return 添加了结构化输出提示词的上下文
*/
protected List<Message> appendFormatInstructionsIfNeeded() {
return this.messages;
}
public List<Message> getMessages() {
return messages;
}
@@ -229,6 +240,7 @@ public class ChatRequest implements ModelRequest {
protected ResponseType responseType = ResponseType.TEXT;
protected TypeReference<?> targetType = new TypeReference<String>() {
};
protected boolean strict = true;
public abstract B self();
@@ -421,6 +433,17 @@ public class ChatRequest implements ModelRequest {
return self();
}
/**
* 是否为严格模式,默认开启
*
* @param strict 是否为严格模式
* @see JsonSchemaGenerator#generate(Type, boolean)
*/
public B strict(boolean strict) {
this.strict = strict;
return self();
}
/**
* 内部聚合类
*/

View File

@@ -27,7 +27,7 @@ public class OutputParser<T> {
* @param reference TypeReference
*/
public static <T> OutputParser<T> fromTypeReference(TypeReference<T> reference) {
return new OutputParser<>(reference.getType());
return new OutputParser<>(reference.getType(), true);
}
/**
@@ -36,7 +36,7 @@ public class OutputParser<T> {
* @param type 目标类型
*/
public static <T> OutputParser<T> fromType(Type type) {
return new OutputParser<>(type);
return new OutputParser<>(type, true);
}
/**
@@ -45,18 +45,44 @@ public class OutputParser<T> {
* @param typeName 目标类型的全限定名
*/
public static <T> OutputParser<T> fromTypeName(String typeName) {
return new OutputParser<>(typeName);
return new OutputParser<>(typeName, true);
}
/**
* 通过 TypeReference 生成输出解析器
*
* @param reference TypeReference
*/
public static <T> OutputParser<T> fromTypeReference(TypeReference<T> reference, boolean strict) {
return new OutputParser<>(reference.getType(), strict);
}
/**
* 通过 Type 生成输出解析器
*
* @param type 目标类型
*/
public static <T> OutputParser<T> fromType(Type type, boolean strict) {
return new OutputParser<>(type, strict);
}
/**
* 通过类型全限定名生成输出解析器
*
* @param typeName 目标类型的全限定名
*/
public static <T> OutputParser<T> fromTypeName(String typeName, boolean strict) {
return new OutputParser<>(typeName, strict);
}
/**
* 通过类型生成输出解析器
*
* @param targetType 目标类型
*/
public OutputParser(Type targetType) {
this.targetType = targetType;
this.objectMapper = new ObjectMapper();
this.jsonSchema = JsonSchemaGenerator.generate(targetType);
this(targetType, true);
}
/**
@@ -65,9 +91,29 @@ public class OutputParser<T> {
* @param typeName 目标类型的全限定名
*/
public OutputParser(String typeName) {
this.targetType = JsonSchemaGenerator.typeFromString(typeName);
this(JsonSchemaGenerator.typeFromString(typeName), true);
}
/**
* 通过类型生成输出解析器
*
* @param targetType 目标类型
* @param strict 严格模式
*/
public OutputParser(Type targetType, boolean strict) {
this.targetType = targetType;
this.objectMapper = new ObjectMapper();
this.jsonSchema = JsonSchemaGenerator.generate(typeName);
this.jsonSchema = JsonSchemaGenerator.generate(targetType, strict);
}
/**
* 通过类全限定名生成输出解析器
*
* @param typeName 目标类型的全限定名
* @param strict 严格模式
*/
public OutputParser(String typeName, boolean strict) {
this(JsonSchemaGenerator.typeFromString(typeName), strict);
}
/**
@@ -139,6 +185,19 @@ public class OutputParser<T> {
return jsonSchema;
}
/**
* JsonSchema
*
* @return JsonSchema
*/
public String getJsonSchemaString() {
try {
return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(this.jsonSchema);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
/**
* 获取目标类型
*

View File

@@ -9,6 +9,7 @@ import com.yomahub.liteflow.ai.engine.model.chat.entity.ChatRequest;
import com.yomahub.liteflow.ai.engine.model.chat.message.Message;
import com.yomahub.liteflow.ai.engine.model.output.ResponseType;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.util.request.RequestBody;
import java.util.List;
@@ -16,11 +17,18 @@ import java.util.List;
* Ollama 聊天请求体
*
* @author 苍镜月
* @see <a href=
* "https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion">Chat
* Completion API</a>
* @since TODO
*/
public class OllamaChatRequest extends ChatRequest {
// ==== RequestBody 相关参数 =====
private static final String FORMAT_KEY = "format";
// ==== RequestBody 相关参数 =====
public OllamaChatRequest() {
super();
}
@@ -34,17 +42,24 @@ public class OllamaChatRequest extends ChatRequest {
ResultHandler resultHandler,
ChunkCallbackTransformer chunkCallbackTransformer,
ResponseType responseType,
TypeReference<?> targetType
TypeReference<?> targetType,
boolean strict
) {
super(messages, options, streaming, transportType,
transportListener, resultHandler, chunkCallbackTransformer,
responseType, targetType);
responseType, targetType, strict);
}
public OllamaChatRequest(Builder builder) {
super(builder);
}
@Override
public RequestBody toRequestBody() {
return super.toRequestBody()
.putIf(ResponseType.JSON.equals(this.responseType), FORMAT_KEY, outputParser.getJsonSchema());
}
@Override
protected void checkTransportConsistency() {
super.checkTransportConsistency();