mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-12 23:41:04 +08:00
Feat(engine): 实现ScanningToolRegistry,重构部分 JsonSchemaGenerator实现
This commit is contained in:
@@ -7,7 +7,6 @@ import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchem
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
|
||||
import com.yomahub.liteflow.ai.engine.tool.method.MethodToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
import com.yomahub.liteflow.log.LFLog;
|
||||
@@ -15,9 +14,10 @@ import com.yomahub.liteflow.log.LFLoggerManager;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Parameter;
|
||||
import java.lang.reflect.Type;
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
@@ -111,51 +111,19 @@ public class SpringBeanToolRegistry implements ToolRegistry {
|
||||
* @param toolAnnotation 工具注解实例
|
||||
* @return ToolCallBack 实现
|
||||
*/
|
||||
// TODO 动态代理 Bean
|
||||
private ToolCallBack createToolCallBack(String toolName, Object bean, Method method, Tool toolAnnotation) {
|
||||
// 1. 获取工具描述
|
||||
String description = String.join("\n", toolAnnotation.value());
|
||||
|
||||
// 2. 动态生成 JsonSchema
|
||||
JsonNode schema;
|
||||
if (method.getParameterCount() == 1 && !method.getParameterTypes()[0].isPrimitive()) {
|
||||
// 如果方法参数只有一个非原始类型的参数,则直接使用该参数的类型作为 schema
|
||||
Type paramType = method.getGenericParameterTypes()[0];
|
||||
schema = JsonSchemaGenerator.generate(paramType);
|
||||
} else {
|
||||
// 多个参数,提取为 Map,最后聚合为 Schema
|
||||
Map<String, Type> typeMap = extractMethodParamsAsTypeMap(method);
|
||||
schema = JsonSchemaGenerator.generateFromTypeMap(typeMap, true);
|
||||
}
|
||||
JsonNode inputSchema = JsonSchemaGenerator.generate(method, true);
|
||||
|
||||
// 3. 创建 ToolDefinition
|
||||
ToolDefinition<?> toolDefinition = new ToolDefinition<>(toolName, description, schema);
|
||||
ToolDefinition<?> toolDefinition = new ToolDefinition<>(toolName, description, inputSchema);
|
||||
|
||||
// 4. 创建 ToolCallBack
|
||||
// TODO 动态代理 Bean 处理(?
|
||||
return new MethodToolCallBack(toolDefinition, bean, method);
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取方法参数名称和类型
|
||||
*
|
||||
* @param method 方法实例
|
||||
* @return 参数名和类型的映射
|
||||
*/
|
||||
private Map<String, Type> extractMethodParamsAsTypeMap(Method method) {
|
||||
Map<String, Type> paramsMap = new LinkedHashMap<>();
|
||||
for (Parameter parameter : method.getParameters()) {
|
||||
ToolParam toolParam = parameter.getAnnotation(ToolParam.class);
|
||||
|
||||
String paramName;
|
||||
// 这里将注解的 value 作为参数名,如果注解不存在,那么将使用方法参数的名字
|
||||
// 但是需要开启 -parameters 参数,否则可能无法获取到
|
||||
if (Objects.nonNull(toolParam) && StrUtil.isNotBlank(toolParam.value())) {
|
||||
paramName = toolParam.value();
|
||||
} else {
|
||||
paramName = parameter.getName();
|
||||
}
|
||||
paramsMap.put(paramName, parameter.getParameterizedType());
|
||||
}
|
||||
return paramsMap;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.yomahub.liteflow.ai.engine.model.output.structure.generator;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||
@@ -10,7 +11,10 @@ import com.github.victools.jsonschema.module.jackson.JacksonOption;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.Description;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.ParameterizedTypeImpl;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.lang.reflect.Parameter;
|
||||
import java.lang.reflect.Type;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.*;
|
||||
@@ -232,6 +236,52 @@ public class JsonSchemaGenerator {
|
||||
return arguments;
|
||||
}
|
||||
|
||||
// ===== 根据 Method 的参数 动态生成 =====
|
||||
|
||||
/**
|
||||
* 根据 Method 的参数动态生成 JSON Schema
|
||||
*
|
||||
* @param method 方法实例
|
||||
* @param strict 是否为严格模式
|
||||
* @return 生成的 JSON Schema
|
||||
*/
|
||||
public static JsonNode generate(Method method, boolean strict) {
|
||||
if (method.getParameterCount() == 1 && !method.getParameterTypes()[0].isPrimitive()) {
|
||||
// 如果方法参数只有一个非原始类型的参数,则直接使用该参数的类型作为 schema
|
||||
Type paramType = method.getGenericParameterTypes()[0];
|
||||
return generate(paramType, strict);
|
||||
} else {
|
||||
// 多个参数,提取为 Map,最后聚合为 Schema
|
||||
Map<String, Type> typeMap = extractMethodParamsAsTypeMap(method);
|
||||
return generateFromTypeMap(typeMap, strict);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 提取方法参数名称和类型
|
||||
*
|
||||
* @param method 方法实例
|
||||
* @return 参数名和类型的映射
|
||||
*/
|
||||
private static Map<String, Type> extractMethodParamsAsTypeMap(Method method) {
|
||||
Map<String, Type> paramsMap = new LinkedHashMap<>();
|
||||
for (Parameter parameter : method.getParameters()) {
|
||||
ToolParam toolParam = parameter.getAnnotation(ToolParam.class);
|
||||
|
||||
String paramName;
|
||||
// 这里将注解的 value 作为参数名,如果注解不存在,那么将使用方法参数的名字
|
||||
// 但是需要开启 -parameters 参数,否则可能无法获取到
|
||||
// 如果未开启获取到的是 arg0, arg1, ... 的形式
|
||||
if (Objects.nonNull(toolParam) && StrUtil.isNotBlank(toolParam.value())) {
|
||||
paramName = toolParam.value();
|
||||
} else {
|
||||
paramName = parameter.getName();
|
||||
}
|
||||
paramsMap.put(paramName, parameter.getParameterizedType());
|
||||
}
|
||||
return paramsMap;
|
||||
}
|
||||
|
||||
// ===== 根据 Map 动态生成 =====
|
||||
|
||||
/**
|
||||
|
||||
@@ -66,6 +66,7 @@ public class MethodToolCallBack implements ToolCallBack {
|
||||
String paramName;
|
||||
// 这里将注解的 value 作为参数名,如果注解不存在,那么将使用方法参数的名字
|
||||
// 但是需要开启 -parameters 参数,否则可能无法获取到
|
||||
// 如果未开启获取到的是 arg0, arg1, ... 的形式
|
||||
if (Objects.nonNull(toolParam) && StrUtil.isNotBlank(toolParam.value())) {
|
||||
paramName = toolParam.value();
|
||||
} else {
|
||||
@@ -73,7 +74,7 @@ public class MethodToolCallBack implements ToolCallBack {
|
||||
}
|
||||
JsonNode argNode = inputNode.get(paramName);
|
||||
if (argNode.isNull()) {
|
||||
if (toolParam.required()) {
|
||||
if (Objects.nonNull(toolParam) && toolParam.required()) {
|
||||
throw new IllegalArgumentException("Missing required parameter: " + paramName);
|
||||
}
|
||||
args[i] = null;
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
package com.yomahub.liteflow.ai.engine.tool.registry;
|
||||
|
||||
import cn.hutool.core.util.ClassUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLog;
|
||||
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
|
||||
import com.yomahub.liteflow.ai.engine.tool.method.MethodToolCallBack;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* 自动扫描注册工具
|
||||
@@ -15,6 +26,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class ScanningToolRegistry implements ToolRegistry {
|
||||
|
||||
private final EngineLog LOG = EngineLogManager.getLogger(ScanningToolRegistry.class);
|
||||
private final Map<String, ToolCallBack> registry = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
@@ -32,7 +44,54 @@ public class ScanningToolRegistry implements ToolRegistry {
|
||||
* @param packages 要扫描的包路径列表
|
||||
*/
|
||||
private void scanAndRegisterTools(String... packages) {
|
||||
// TODO
|
||||
if (Objects.isNull(packages) || packages.length == 0) {
|
||||
LOG.warn("No packages specified for tool scanning. Skipping tool registration.");
|
||||
return;
|
||||
}
|
||||
|
||||
// 扫描指定包路径下的所有类
|
||||
Set<Class<?>> classes = Arrays.stream(packages)
|
||||
.map(ClassUtil::scanPackage)
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
// 实例缓存
|
||||
Map<Class<?>, Object> instanceCache = new HashMap<>();
|
||||
|
||||
for (Class<?> clazz : classes) {
|
||||
for (Method method : clazz.getMethods()) {
|
||||
if (method.isAnnotationPresent(Tool.class)) {
|
||||
try {
|
||||
// 获取实例
|
||||
Object instance = instanceCache.computeIfAbsent(clazz, c -> {
|
||||
try {
|
||||
// 需要提供无参构造
|
||||
return c.getDeclaredConstructor().newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new LiteFlowAIEngineException("Failed to create instance of class: " + c.getName(), e);
|
||||
}
|
||||
});
|
||||
|
||||
// 获取工具注解
|
||||
Tool toolAnnotation = method.getAnnotation(Tool.class);
|
||||
String toolName = StrUtil.isNotBlank(toolAnnotation.name()) ?
|
||||
toolAnnotation.name() : method.getName();
|
||||
// 尝试注册工具
|
||||
registry.computeIfAbsent(toolName, name -> {
|
||||
String description = String.join("\n", toolAnnotation.value());
|
||||
JsonNode inputSchema = JsonSchemaGenerator.generate(method, true);
|
||||
ToolDefinition<?> toolDefinition = new ToolDefinition<>(name, description, inputSchema);
|
||||
LOG.info("Registered tool: {} from method: [{}]", toolName, method.getName());
|
||||
return new MethodToolCallBack(toolDefinition, instance, method);
|
||||
});
|
||||
} catch (Exception e) {
|
||||
LOG.error("Failed to register tool from method: [{}] in class: [{}]. Error: {}",
|
||||
method.getName(), clazz.getName(), e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG.info("Tool scanning completed. Total tools registered: {}", registry.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.yomahub.liteflow.test.ai.core.tool;
|
||||
|
||||
import com.yomahub.liteflow.ai.context.ChatContext;
|
||||
import com.yomahub.liteflow.ai.context.StreamHandler;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
import com.yomahub.liteflow.ai.util.SpringUtil;
|
||||
import com.yomahub.liteflow.core.FlowExecutor;
|
||||
@@ -39,9 +40,14 @@ public class ToolTest {
|
||||
public void testToolRegistry() {
|
||||
Assertions.assertNotNull(toolRegistry);
|
||||
Assertions.assertFalse(toolRegistry.getAllTools().isEmpty(), "Tool registry should contain tools");
|
||||
Assertions.assertEquals(2, toolRegistry.getAllTools().size());
|
||||
Assertions.assertEquals(4, toolRegistry.getAllTools().size());
|
||||
Assertions.assertEquals("weather_tool", toolRegistry.getTool("weather_tool").getName());
|
||||
Assertions.assertEquals("assemble_tool", toolRegistry.getTool("assemble_tool").getName());
|
||||
ToolCallBack testTool = toolRegistry.getTool("test_tool");
|
||||
String res = testTool.call("{\"arg0\": \"Hello\", \"arg1\": \"World\"}");
|
||||
System.out.println(res);
|
||||
String nullRes = toolRegistry.getTool("null_tool").call(null);
|
||||
System.out.println(nullRes);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -23,4 +23,14 @@ public class ToolConfig {
|
||||
public String assemble(@ToolParam("a") String a, @ToolParam("b") String b) {
|
||||
return "Assembled result: " + a + " and " + b;
|
||||
}
|
||||
|
||||
@Tool(name = "test_tool", value = {"测试工具", "不使用 ToolParam 注解,参数名应当为 arg0 和 arg1"})
|
||||
public String tesTool(String input1, String input2) {
|
||||
return "Test tool executed with inputs: " + input1 + " and " + input2;
|
||||
}
|
||||
|
||||
@Tool(name = "null_tool", value = {"空工具", "不执行任何操作"})
|
||||
public void nullTool() {
|
||||
System.out.println("这是一个空工具,不执行任何操作");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,16 +2,19 @@ package com.yomahub.liteflow.test.ai.engine.tool;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
|
||||
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
|
||||
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
|
||||
import com.yomahub.liteflow.ai.engine.tool.function.FunctionToolCallback;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@@ -48,6 +51,14 @@ public class ToolTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScanningToolRegistry() {
|
||||
ScanningToolRegistry scanningToolRegistry = new ScanningToolRegistry("com.yomahub.liteflow.test.ai.engine.tool.domain");
|
||||
Collection<ToolCallBack> tools = scanningToolRegistry.getAllTools();
|
||||
Assertions.assertEquals(4, tools.size());
|
||||
System.out.println(tools);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToolDefinition() {
|
||||
ToolDefinition<Input<String>> toolDefinition = new ToolDefinition<>("testTool", "Test Tool", new TypeReference<Input<String>>() {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.yomahub.liteflow.test.ai.engine.tool.domain;
|
||||
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
|
||||
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
|
||||
import com.yomahub.liteflow.ai.engine.tool.function.FunctionToolCallback;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
|
||||
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
|
||||
@@ -30,4 +32,24 @@ public class TestTools {
|
||||
toolRegistry.register(weatherTool);
|
||||
return toolRegistry;
|
||||
}
|
||||
|
||||
@Tool(name = "weather_tool", value = {"查询天气", "获取指定位置的天气信息"})
|
||||
public String queryWeatherWithLocation(@ToolParam("location") ToolInput location) {
|
||||
return "The weather in " + location + " is sunny, 25°C.";
|
||||
}
|
||||
|
||||
@Tool(name = "assemble_tool", value = {"组装工具", "将 a 和 b 组装成答案"})
|
||||
public String assemble(@ToolParam("a") String a, @ToolParam("b") String b) {
|
||||
return "Assembled result: " + a + " and " + b;
|
||||
}
|
||||
|
||||
@Tool(name = "test_tool", value = {"测试工具", "不使用 ToolParam 注解,参数名应当为 arg0 和 arg1"})
|
||||
public String tesTool(String input1, String input2) {
|
||||
return "Test tool executed with inputs: " + input1 + " and " + input2;
|
||||
}
|
||||
|
||||
@Tool(name = "null_tool", value = {"空工具", "不执行任何操作"})
|
||||
public void nullTool() {
|
||||
System.out.println("这是一个空工具,不执行任何操作");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user