mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-10 11:17:00 +08:00
Feat: AI 意图识别节点多路选择初步实现
This commit is contained in:
@@ -172,6 +172,7 @@ public abstract class AbstractAIComponentHandler<T extends Annotation> {
|
||||
/**
|
||||
* 获取拦截方法名称
|
||||
*
|
||||
* @param wrapBean 包装bean,现在暂时不用,后续如果单节点需要根据条件判断拦截方法,可以用上
|
||||
* @return 拦截方法名称
|
||||
*/
|
||||
protected abstract ElementMatcher<? super MethodDescription> getInterceptMethodName(AIProxyWrapBean<T> wrapBean);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.yomahub.liteflow.ai.proxy.handler;
|
||||
|
||||
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
|
||||
import com.yomahub.liteflow.ai.annotation.AIComponent;
|
||||
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
|
||||
import com.yomahub.liteflow.ai.domain.enums.AITypeEnum;
|
||||
import com.yomahub.liteflow.ai.proxy.invocation.ClassifyAIInvocationHandler;
|
||||
import com.yomahub.liteflow.ai.proxy.wrap.AIProxyWrapBean;
|
||||
@@ -22,7 +22,6 @@ import java.lang.reflect.InvocationHandler;
|
||||
public class ClassifyComponentHandler extends AbstractAIComponentHandler<AIClassify> {
|
||||
|
||||
private static final String INTERCEPT_SWITCH_METHOD_NAME = "processSwitch";
|
||||
private static final String INTERCEPT_MULTI_SWITCH_METHOD_NAME = "processMultiSwitch";
|
||||
|
||||
@Override
|
||||
public AITypeEnum getAIType() {
|
||||
@@ -47,10 +46,6 @@ public class ClassifyComponentHandler extends AbstractAIComponentHandler<AIClass
|
||||
|
||||
@Override
|
||||
protected ElementMatcher<? super MethodDescription> getInterceptMethodName(AIProxyWrapBean<AIClassify> wrapBean) {
|
||||
if (wrapBean.getAnnotation().multiLabel()) {
|
||||
return ElementMatchers.named(INTERCEPT_MULTI_SWITCH_METHOD_NAME);
|
||||
} else {
|
||||
return ElementMatchers.named(INTERCEPT_SWITCH_METHOD_NAME);
|
||||
}
|
||||
return ElementMatchers.named(INTERCEPT_SWITCH_METHOD_NAME);
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,8 @@ import com.yomahub.liteflow.ai.parse.context.ProcessorContext;
|
||||
import com.yomahub.liteflow.ai.proxy.wrap.ClassifyProxyWrapBean;
|
||||
import com.yomahub.liteflow.ai.util.SetUtil;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 分类组件的调用处理器
|
||||
*
|
||||
@@ -33,6 +35,7 @@ public class ClassifyAIInvocationHandler extends AbstractAIInvocationHandler<Cla
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
protected Object doExecuteAIProcess(ProcessorContext<?> processorContext, Object[] args) {
|
||||
ChatModel chatModel = ModelFactory.getChatModel(wrapBean);
|
||||
@@ -40,7 +43,8 @@ public class ClassifyAIInvocationHandler extends AbstractAIInvocationHandler<Cla
|
||||
ParsedClassifyAnnotationConfig annotationConfig = (ParsedClassifyAnnotationConfig) processorContext.getParsedAnnotationConfig();
|
||||
// 如果是多标签则返回结构化转换的list, 单标签返回 String
|
||||
if (annotationConfig.isMultiLabel()) {
|
||||
return response.as(processorContext.getModelRequest().toChatRequest().getOutputParser());
|
||||
List<String> resList = (List<String>) response.as(processorContext.getModelRequest().toChatRequest().getOutputParser());
|
||||
return String.join(",", resList);
|
||||
} else {
|
||||
return response.getContent().getContent();
|
||||
}
|
||||
|
||||
@@ -78,7 +78,7 @@ public class OpenAIChatRequest extends ChatRequest {
|
||||
ObjectNode responseFormat = ObjectMapperHolder.createObjectNode();
|
||||
responseFormat.put("type", "json_schema");
|
||||
ObjectNode jsonSchema = ObjectMapperHolder.createObjectNode();
|
||||
jsonSchema.put("name", ((Class<?>) outputParser.getTargetType()).getSimpleName());
|
||||
jsonSchema.put("name", outputParser.getTargetType().getTypeName());
|
||||
jsonSchema.set("schema", outputParser.getJsonSchema());
|
||||
|
||||
responseFormat.set("json_schema", jsonSchema);
|
||||
|
||||
@@ -33,5 +33,14 @@ public class ClassifyTest {
|
||||
public void testClassify() {
|
||||
LiteflowResponse response = flowExecutor.execute2Resp("chain1", null, ChatContext.class);
|
||||
Assertions.assertTrue(response.isSuccess());
|
||||
Assertions.assertEquals("a==>aiSwitch[aiSwitch]==>java", response.getExecuteStepStr());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiClassify() {
|
||||
LiteflowResponse response = flowExecutor.execute2Resp("chain2", null, ChatContext.class);
|
||||
Assertions.assertTrue(response.isSuccess());
|
||||
Assertions.assertEquals("a==>aiMultiSwitch[aiMultiSwitch]==>java==>python", response.getExecuteStepStr());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@ import com.yomahub.liteflow.ai.util.TriState;
|
||||
*/
|
||||
|
||||
@AIComponent(
|
||||
nodeId = "ai",
|
||||
nodeName = "ai",
|
||||
nodeId = "aiSwitch",
|
||||
nodeName = "aiSwitch",
|
||||
// provider = "ollama",
|
||||
// apiUrl = "http://localhost:11434",
|
||||
// model = "qwen3:32b",
|
||||
@@ -28,9 +28,9 @@ import com.yomahub.liteflow.ai.util.TriState;
|
||||
connectTimeout = "10m"
|
||||
)
|
||||
@AIClassify(
|
||||
systemPrompt = "你是一个意图识别高手,你需要识别用户的意图",
|
||||
userPrompt = "{{question}}",
|
||||
categories = {"java", "python"})
|
||||
categories = {"java", "python"}
|
||||
)
|
||||
@AIInput(
|
||||
mapping = {
|
||||
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码"),
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package com.yomahub.liteflow.test.ai.core.classify.cmp;
|
||||
|
||||
import com.yomahub.liteflow.ai.annotation.AIComponent;
|
||||
import com.yomahub.liteflow.ai.annotation.model.io.AIInput;
|
||||
import com.yomahub.liteflow.ai.annotation.model.io.AIOutput;
|
||||
import com.yomahub.liteflow.ai.annotation.model.io.InputField;
|
||||
import com.yomahub.liteflow.ai.annotation.model.node.AIClassify;
|
||||
import com.yomahub.liteflow.ai.util.TriState;
|
||||
|
||||
/**
|
||||
* TODO
|
||||
*
|
||||
* @author 苍镜月
|
||||
* @since TODO
|
||||
*/
|
||||
|
||||
@AIComponent(
|
||||
nodeId = "aiMultiSwitch",
|
||||
nodeName = "aiMultiSwitch",
|
||||
// provider = "ollama",
|
||||
// apiUrl = "http://localhost:11434",
|
||||
// model = "qwen3:32b",
|
||||
provider = "openai",
|
||||
apiUrl = "https://ark.cn-beijing.volces.com/api/v3",
|
||||
model = "doubao-seed-1-6-250615",
|
||||
enableThinking = TriState.FALSE,
|
||||
readTimeout = "10m",
|
||||
connectTimeout = "10m"
|
||||
)
|
||||
@AIClassify(
|
||||
userPrompt = "{{question}}",
|
||||
categories = {"java", "python"},
|
||||
multiLabel = true
|
||||
)
|
||||
@AIInput(
|
||||
mapping = {
|
||||
@InputField(name = "question", expression = "test", defaultValue = "请帮我写一段Java代码, 同时给出 Python 代码"),
|
||||
}
|
||||
)
|
||||
@AIOutput(
|
||||
methodExpress = "setData",
|
||||
useKeyIndex = true,
|
||||
key = "result"
|
||||
)
|
||||
public interface AIMultiClassifyCmp {
|
||||
}
|
||||
@@ -2,6 +2,10 @@
|
||||
<!DOCTYPE flow PUBLIC "liteflow" "liteflow.dtd">
|
||||
<flow>
|
||||
<chain name="chain1">
|
||||
THEN(a, SWITCH(ai).TO(java, python));
|
||||
THEN(a, SWITCH(aiSwitch).TO(java, python));
|
||||
</chain>
|
||||
|
||||
<chain name="chain2">
|
||||
THEN(a, SWITCH(aiMultiSwitch).TO(java, python));
|
||||
</chain>
|
||||
</flow>
|
||||
|
||||
Reference in New Issue
Block a user