test: 意图识别节点测试 + bug fix

This commit is contained in:
LuanY77
2025-09-17 21:18:36 +08:00
parent c04b44120d
commit 58eb5a26ed
15 changed files with 244 additions and 45 deletions

View File

@@ -46,7 +46,7 @@ public class ClassifyAIInvocationHandler extends AbstractAIInvocationHandler<Cla
List<String> resList = (List<String>) response.as(processorContext.getModelRequest().toChatRequest().getOutputParser());
return String.join(",", resList);
} else {
return response.getOutput().getContent();
return response.getOutput().getContentWithoutThink();
}
}
}

View File

@@ -34,7 +34,6 @@ public class ChatOptions implements ModelOptions {
protected Boolean enableThinking;
// ==== RequestBody 相关参数 =====
// TODO options 重构
protected static final String TEMPERATURE_KEY = "temperature";
protected static final String TOP_P_KEY = "top_p";
protected static final String TOP_K_KEY = "top_k";

View File

@@ -45,7 +45,7 @@ public class ChatResponse extends Response<AssistantMessage> {
* @return 转换后的对象
*/
public <T> T as(JsonSchemaParser<T> parser) {
String rawTextContent = this.getOutput().getContent();
String rawTextContent = this.getOutput().getContentWithoutThink();
if (StrUtil.isBlank(rawTextContent)) {
throw new IllegalStateException("Cannot convert empty content to target type: " + parser.getTargetType());
}

View File

@@ -1,5 +1,6 @@
package com.yomahub.liteflow.ai.engine.model.chat.message;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import com.yomahub.liteflow.ai.engine.tool.ToolCall;
@@ -38,4 +39,17 @@ public class AssistantMessage extends AbstractMessage {
public void setToolCalls(List<ToolCall> toolCalls) {
this.toolCalls = toolCalls;
}
/**
* 获取不包含 <think> 标签内容的消息内容
*
* @return 清理后的消息内容
*/
public String getContentWithoutThink() {
String content = this.getContent();
if (StrUtil.isBlank(content)) {
return "";
}
return content.replaceAll("(?s)<think>.*?</think>", "").trim();
}
}

View File

@@ -138,6 +138,10 @@ public class DashScopeProtocolTransformer implements ProtocolTransformer {
// 解析 AI 消息内容
String content = extractContentFromMessage(message);
String thinking = extractThinkingFromMessage(message);
if (StrUtil.isNotBlank(thinking)) {
content = "<think>\n" + thinking + "\n</think>\n" + (StrUtil.isNotBlank(content) ? content : "");
}
AssistantMessage assistantMessage = new AssistantMessage(content, toolCalls);
// 解析 Token 使用情况
@@ -208,6 +212,11 @@ public class DashScopeProtocolTransformer implements ProtocolTransformer {
return messageJson.path("content").asText(null);
}
private String extractThinkingFromMessage(JsonNode messageJson) {
if (Objects.isNull(messageJson) || !messageJson.has("reasoning_content")) return null;
return messageJson.path("reasoning_content").asText(null);
}
private String extractContentFromDelta(JsonNode deltaJson) {
if (Objects.isNull(deltaJson) || !deltaJson.has("content")) return null;
return deltaJson.path("content").asText(null);

View File

@@ -27,7 +27,7 @@ import java.util.List;
public class OllamaChatRequest extends ChatRequest {
// ==== RequestBody 相关参数 =====
private static final String THINKING_KEY = "thinking";
private static final String THINKING_KEY = "think";
private static final String FORMAT_KEY = "format";
// ==== RequestBody 相关参数 =====
@@ -59,7 +59,7 @@ public class OllamaChatRequest extends ChatRequest {
@Override
public RequestBody toRequestBody() {
return super.toRequestBody()
.remove("enableThinking")
.remove("enable_thinking")
.put(THINKING_KEY, this.options.getEnableThinking())
.putIf(ResponseType.JSON.equals(this.responseType), FORMAT_KEY, outputParser.getJsonSchema());
}

View File

@@ -138,6 +138,10 @@ public class OpenAIProtocolTransformer implements ProtocolTransformer {
// 解析 AI 消息内容
String content = extractContentFromMessage(message);
String thinking = extractThinkingFromMessage(message);
if (StrUtil.isNotBlank(thinking)) {
content = "<think>\n" + thinking + "\n</think>\n" + (StrUtil.isNotBlank(content) ? content : "");
}
AssistantMessage assistantMessage = new AssistantMessage(content, toolCalls);
// 解析 Token 使用情况
@@ -208,6 +212,11 @@ public class OpenAIProtocolTransformer implements ProtocolTransformer {
return messageJson.path("content").asText(null);
}
private String extractThinkingFromMessage(JsonNode messageJson) {
if (Objects.isNull(messageJson) || !messageJson.has("reasoning_content")) return null;
return messageJson.path("reasoning_content").asText(null);
}
private String extractContentFromDelta(JsonNode deltaJson) {
if (Objects.isNull(deltaJson) || !deltaJson.has("content")) return null;
return deltaJson.path("content").asText(null);

View File

@@ -14,33 +14,63 @@ import org.springframework.test.context.TestPropertySource;
import javax.annotation.Resource;
/**
* TODO
* 分类节点测试类
*
* @author 苍镜月
* @since TODO
*/
@TestPropertySource(properties = {"spring.config.location=classpath:core/classify/application.yaml"})
@SpringBootTest(classes = {ClassifyTest.class, SpringUtil.class})
@TestPropertySource(properties = { "spring.config.location=classpath:core/classify/application.yaml" })
@SpringBootTest(classes = { ClassifyTest.class, SpringUtil.class })
@EnableAutoConfiguration
@ComponentScan({"com.yomahub.liteflow.test.ai.core.classify.cmp"})
@ComponentScan({ "com.yomahub.liteflow.test.ai.core.classify.cmp" })
public class ClassifyTest {
@Resource
private FlowExecutor flowExecutor;
@Test
public void testClassify() {
public void testOpenAIClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain1", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>aiSwitch[aiSwitch]==>java", response.getExecuteStepStr());
Assertions.assertEquals("a==>openaiSwitch[openaiSwitch]==>java", response.getExecuteStepStr());
}
@Test
public void testMultiClassify() {
public void testOpenAIMultiClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain2", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>aiMultiSwitch[aiMultiSwitch]==>java==>python", response.getExecuteStepStr());
Assertions.assertEquals("a==>openaiMultiSwitch[openaiMultiSwitch]==>java==>python",
response.getExecuteStepStr());
}
@Test
public void testDashScopeClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain3", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>dashscopeSwitch[dashscopeSwitch]==>java", response.getExecuteStepStr());
}
@Test
public void testDashScopeMultiClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain4", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>dashscopeMultiSwitch[dashscopeMultiSwitch]==>java==>python",
response.getExecuteStepStr());
}
@Test
public void testOllamaClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain5", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>ollamaSwitch[ollamaSwitch]==>java", response.getExecuteStepStr());
}
@Test
public void testOllamaMultiClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain6", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>ollamaMultiSwitch[ollamaMultiSwitch]==>java==>python",
response.getExecuteStepStr());
}
}

View File

@@ -0,0 +1,35 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp.dashscope;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
@AIComponent(
nodeId = "dashscopeSwitch",
nodeName = "dashscopeSwitch",
provider = "dashscope",
apiUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1",
model = "qwen-flash",
enableThinking = TriState.TRUE,
readTimeout = "10m",
connectTimeout = "10m"
)
@AIClassify(
userPrompt = "{{question}}",
categories = {"java", "python"}
)
@AIInput(
mapping = {
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码"),
}
)
@AIOutput(
methodExpress = "setData",
useKeyIndex = true,
key = "result"
)
public interface DashScopeClassifyCmp {
}

View File

@@ -0,0 +1,36 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp.dashscope;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
@AIComponent(
nodeId = "dashscopeMultiSwitch",
nodeName = "dashscopeMultiSwitch",
provider = "dashscope",
apiUrl = "https://dashscope.aliyuncs.com/compatible-mode/v1",
model = "qwen-flash",
enableThinking = TriState.FALSE,
readTimeout = "10m",
connectTimeout = "10m"
)
@AIClassify(
userPrompt = "{{question}}",
categories = {"java", "python"},
multiLabel = true
)
@AIInput(
mapping = {
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码, 同时给出 Python 代码"),
}
)
@AIOutput(
methodExpress = "setData",
useKeyIndex = true,
key = "result"
)
public interface DashScopeMultiClassifyCmp {
}

View File

@@ -0,0 +1,35 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp.ollama;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
@AIComponent(
nodeId = "ollamaSwitch",
nodeName = "ollamaSwitch",
provider = "ollama",
apiUrl = "http://localhost:11434",
model = "qwen3:32b",
enableThinking = TriState.FALSE,
readTimeout = "10m",
connectTimeout = "10m"
)
@AIClassify(
userPrompt = "{{question}}",
categories = {"java", "python"}
)
@AIInput(
mapping = {
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码"),
}
)
@AIOutput(
methodExpress = "setData",
useKeyIndex = true,
key = "result"
)
public interface OllamaClassifyCmp {
}

View File

@@ -0,0 +1,36 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp.ollama;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
@AIComponent(
nodeId = "ollamaMultiSwitch",
nodeName = "ollamaMultiSwitch",
provider = "ollama",
apiUrl = "http://localhost:11434",
model = "qwen3:32b",
enableThinking = TriState.FALSE,
readTimeout = "10m",
connectTimeout = "10m"
)
@AIClassify(
userPrompt = "{{question}}",
categories = {"java", "python"},
multiLabel = true
)
@AIInput(
mapping = {
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码, 同时给出 Python 代码"),
}
)
@AIOutput(
methodExpress = "setData",
useKeyIndex = true,
key = "result"
)
public interface OllamaMultiClassifyCmp {
}

View File

@@ -1,4 +1,4 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp;
package com.yomahub.liteflow.test.ai.core.classify.cmp.openai;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
@@ -7,23 +7,13 @@ import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
/**
* TODO
*
* @author 苍镜月
* @since TODO
*/
@AIComponent(
nodeId = "aiSwitch",
nodeName = "aiSwitch",
// provider = "ollama",
// apiUrl = "http://localhost:11434",
// model = "qwen3:32b",
nodeId = "openaiSwitch",
nodeName = "openaiSwitch",
provider = "openai",
apiUrl = "https://ark.cn-beijing.volces.com/api/v3",
model = "doubao-seed-1-6-250615",
enableThinking = TriState.FALSE,
enableThinking = TriState.TRUE,
readTimeout = "10m",
connectTimeout = "10m"
)
@@ -41,5 +31,5 @@ import com.yomahub.liteflow.ai.util.TriState;
useKeyIndex = true,
key = "result"
)
public interface AIClassifyCmp {
public interface OpenAIClassifyCmp {
}

View File

@@ -1,4 +1,4 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp;
package com.yomahub.liteflow.test.ai.core.classify.cmp.openai;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
@@ -7,23 +7,13 @@ import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
/**
* TODO
*
* @author 苍镜月
* @since TODO
*/
@AIComponent(
nodeId = "aiMultiSwitch",
nodeName = "aiMultiSwitch",
// provider = "ollama",
// apiUrl = "http://localhost:11434",
// model = "qwen3:32b",
nodeId = "openaiMultiSwitch",
nodeName = "openaiMultiSwitch",
provider = "openai",
apiUrl = "https://ark.cn-beijing.volces.com/api/v3",
model = "doubao-seed-1-6-250615",
enableThinking = TriState.FALSE,
enableThinking = TriState.TRUE,
readTimeout = "10m",
connectTimeout = "10m"
)
@@ -42,5 +32,5 @@ import com.yomahub.liteflow.ai.util.TriState;
useKeyIndex = true,
key = "result"
)
public interface AIMultiClassifyCmp {
public interface OpenAIMultiClassifyCmp {
}

View File

@@ -2,10 +2,26 @@
<!DOCTYPE flow PUBLIC "liteflow" "liteflow.dtd">
<flow>
<chain name="chain1">
THEN(a, SWITCH(aiSwitch).TO(java, python));
THEN(a, SWITCH(openaiSwitch).TO(java, python));
</chain>
<chain name="chain2">
THEN(a, SWITCH(aiMultiSwitch).TO(java, python));
THEN(a, SWITCH(openaiMultiSwitch).TO(java, python));
</chain>
<chain name="chain3">
THEN(a, SWITCH(dashscopeSwitch).TO(java, python));
</chain>
<chain name="chain4">
THEN(a, SWITCH(dashscopeMultiSwitch).TO(java, python));
</chain>
<chain name="chain5">
THEN(a, SWITCH(ollamaSwitch).TO(java, python));
</chain>
<chain name="chain6">
THEN(a, SWITCH(ollamaMultiSwitch).TO(java, python));
</chain>
</flow>