Feat: AI 意图识别节点多路选择初步实现

This commit is contained in:
LuanY77
2025-08-28 17:48:47 +08:00
parent 7720894361
commit ecc0db3e53
8 changed files with 73 additions and 14 deletions

View File

@@ -172,6 +172,7 @@ public abstract class AbstractAIComponentHandler<T extends Annotation> {
/**
* 获取拦截方法名称
*
* @param wrapBean 包装bean现在暂时不用后续如果单节点需要根据条件判断拦截方法可以用上
* @return 拦截方法名称
*/
protected abstract ElementMatcher<? super MethodDescription> getInterceptMethodName(AIProxyWrapBean<T> wrapBean);

View File

@@ -1,7 +1,7 @@
package com.yomahub.liteflow.ai.proxy.handler;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.domain.enums.AITypeEnum;
import com.yomahub.liteflow.ai.proxy.invocation.ClassifyAIInvocationHandler;
import com.yomahub.liteflow.ai.proxy.wrap.AIProxyWrapBean;
@@ -22,7 +22,6 @@ import java.lang.reflect.InvocationHandler;
public class ClassifyComponentHandler extends AbstractAIComponentHandler<AIClassify> {
private static final String INTERCEPT_SWITCH_METHOD_NAME = "processSwitch";
private static final String INTERCEPT_MULTI_SWITCH_METHOD_NAME = "processMultiSwitch";
@Override
public AITypeEnum getAIType() {
@@ -47,10 +46,6 @@ public class ClassifyComponentHandler extends AbstractAIComponentHandler<AIClass
@Override
protected ElementMatcher<? super MethodDescription> getInterceptMethodName(AIProxyWrapBean<AIClassify> wrapBean) {
if (wrapBean.getAnnotation().multiLabel()) {
return ElementMatchers.named(INTERCEPT_MULTI_SWITCH_METHOD_NAME);
} else {
return ElementMatchers.named(INTERCEPT_SWITCH_METHOD_NAME);
}
return ElementMatchers.named(INTERCEPT_SWITCH_METHOD_NAME);
}
}

View File

@@ -9,6 +9,8 @@ import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
import com.yomahub.liteflow.ai.proxy.wrap.ClassifyProxyWrapBean;
import com.yomahub.liteflow.ai.util.SetUtil;
import java.util.List;
/**
* 分类组件的调用处理器
*
@@ -33,6 +35,7 @@ public class ClassifyAIInvocationHandler extends AbstractAIInvocationHandler<Cla
}
}
@SuppressWarnings("unchecked")
@Override
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
ChatModel chatModel = ModelFactory.getChatModel(wrapBean);
@@ -40,7 +43,8 @@ public class ClassifyAIInvocationHandler extends AbstractAIInvocationHandler<Cla
ParsedClassifyAnnotationConfig annotationConfig = (ParsedClassifyAnnotationConfig) processorContext.getParsedAnnotationConfig();
// 如果是多标签则返回结构化转换的list, 单标签返回 String
if (annotationConfig.isMultiLabel()) {
return response.as(processorContext.getModelRequest().toChatRequest().getOutputParser());
List<String> resList = (List<String>) response.as(processorContext.getModelRequest().toChatRequest().getOutputParser());
return String.join(",", resList);
} else {
return response.getContent().getContent();
}

View File

@@ -78,7 +78,7 @@ public class OpenAIChatRequest extends ChatRequest {
ObjectNode responseFormat = ObjectMapperHolder.createObjectNode();
responseFormat.put("type", "json_schema");
ObjectNode jsonSchema = ObjectMapperHolder.createObjectNode();
jsonSchema.put("name", ((Class<?>) outputParser.getTargetType()).getSimpleName());
jsonSchema.put("name", outputParser.getTargetType().getTypeName());
jsonSchema.set("schema", outputParser.getJsonSchema());
responseFormat.set("json_schema", jsonSchema);

View File

@@ -33,5 +33,14 @@ public class ClassifyTest {
public void testClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain1", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>aiSwitch[aiSwitch]==>java", response.getExecuteStepStr());
}
@Test
public void testMultiClassify() {
LiteflowResponse response = flowExecutor.execute2Resp("chain2", null, ChatContext.class);
Assertions.assertTrue(response.isSuccess());
Assertions.assertEquals("a==>aiMultiSwitch[aiMultiSwitch]==>java==>python", response.getExecuteStepStr());
}
}

View File

@@ -15,8 +15,8 @@ import com.yomahub.liteflow.ai.util.TriState;
*/
@AIComponent(
nodeId = "ai",
nodeName = "ai",
nodeId = "aiSwitch",
nodeName = "aiSwitch",
// provider = "ollama",
// apiUrl = "http://localhost:11434",
// model = "qwen3:32b",
@@ -28,9 +28,9 @@ import com.yomahub.liteflow.ai.util.TriState;
connectTimeout = "10m"
)
@AIClassify(
systemPrompt = "你是一个意图识别高手,你需要识别用户的意图",
userPrompt = "{{question}}",
categories = {"java", "python"})
categories = {"java", "python"}
)
@AIInput(
mapping = {
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码"),

View File

@@ -0,0 +1,46 @@
package com.yomahub.liteflow.test.ai.core.classify.cmp;
import com.yomahub.liteflow.ai.annotation.AIComponent;
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
import com.yomahub.liteflow.ai.annotation.model.io.InputField;
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
import com.yomahub.liteflow.ai.util.TriState;
/**
* TODO
*
* @author 苍镜月
* @since TODO
*/
@AIComponent(
nodeId = "aiMultiSwitch",
nodeName = "aiMultiSwitch",
// provider = "ollama",
// apiUrl = "http://localhost:11434",
// model = "qwen3:32b",
provider = "openai",
apiUrl = "https://ark.cn-beijing.volces.com/api/v3",
model = "doubao-seed-1-6-250615",
enableThinking = TriState.FALSE,
readTimeout = "10m",
connectTimeout = "10m"
)
@AIClassify(
userPrompt = "{{question}}",
categories = {"java", "python"},
multiLabel = true
)
@AIInput(
mapping = {
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码, 同时给出 Python 代码"),
}
)
@AIOutput(
methodExpress = "setData",
useKeyIndex = true,
key = "result"
)
public interface AIMultiClassifyCmp {
}

View File

@@ -2,6 +2,10 @@
<!DOCTYPE flow PUBLIC "liteflow" "liteflow.dtd">
<flow>
<chain name="chain1">
THEN(a, SWITCH(ai).TO(java, python));
THEN(a, SWITCH(aiSwitch).TO(java, python));
</chain>
<chain name="chain2">
THEN(a, SWITCH(aiMultiSwitch).TO(java, python));
</chain>
</flow>