Feat(engine): 实现ScanningToolRegistry,重构部分 JsonSchemaGenerator实现

This commit is contained in:
LuanY77
2025-08-18 11:05:31 +08:00
parent d8880a423a
commit b0e4094b19
8 changed files with 171 additions and 44 deletions

View File

@@ -7,7 +7,6 @@ import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchem
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
import com.yomahub.liteflow.ai.engine.tool.method.MethodToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
import com.yomahub.liteflow.log.LFLog;
@@ -15,9 +14,10 @@ import com.yomahub.liteflow.log.LFLoggerManager;
import org.springframework.context.ApplicationContext;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Type;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
@@ -111,51 +111,19 @@ public class SpringBeanToolRegistry implements ToolRegistry {
* @param toolAnnotation 工具注解实例
* @return ToolCallBack 实现
*/
// TODO 动态代理 Bean
private ToolCallBack createToolCallBack(String toolName, Object bean, Method method, Tool toolAnnotation) {
// 1. 获取工具描述
String description = String.join("\n", toolAnnotation.value());
// 2. 动态生成 JsonSchema
JsonNode schema;
if (method.getParameterCount() == 1 && !method.getParameterTypes()[0].isPrimitive()) {
// 如果方法参数只有一个非原始类型的参数,则直接使用该参数的类型作为 schema
Type paramType = method.getGenericParameterTypes()[0];
schema = JsonSchemaGenerator.generate(paramType);
} else {
// 多个参数,提取为 Map最后聚合为 Schema
Map<String, Type> typeMap = extractMethodParamsAsTypeMap(method);
schema = JsonSchemaGenerator.generateFromTypeMap(typeMap, true);
}
JsonNode inputSchema = JsonSchemaGenerator.generate(method, true);
// 3. 创建 ToolDefinition
ToolDefinition<?> toolDefinition = new ToolDefinition<>(toolName, description, schema);
ToolDefinition<?> toolDefinition = new ToolDefinition<>(toolName, description, inputSchema);
// 4. 创建 ToolCallBack
// TODO 动态代理 Bean 处理(
return new MethodToolCallBack(toolDefinition, bean, method);
}
/**
* 提取方法参数名称和类型
*
* @param method 方法实例
* @return 参数名和类型的映射
*/
private Map<String, Type> extractMethodParamsAsTypeMap(Method method) {
Map<String, Type> paramsMap = new LinkedHashMap<>();
for (Parameter parameter : method.getParameters()) {
ToolParam toolParam = parameter.getAnnotation(ToolParam.class);
String paramName;
// 这里将注解的 value 作为参数名,如果注解不存在,那么将使用方法参数的名字
// 但是需要开启 -parameters 参数,否则可能无法获取到
if (Objects.nonNull(toolParam) && StrUtil.isNotBlank(toolParam.value())) {
paramName = toolParam.value();
} else {
paramName = parameter.getName();
}
paramsMap.put(paramName, parameter.getParameterizedType());
}
return paramsMap;
}
}

View File

@@ -1,5 +1,6 @@
package com.yomahub.liteflow.ai.engine.model.output.structure.generator;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
@@ -10,7 +11,10 @@ import com.github.victools.jsonschema.module.jackson.JacksonOption;
import com.yomahub.liteflow.ai.engine.model.output.structure.Description;
import com.yomahub.liteflow.ai.engine.model.output.structure.ParameterizedTypeImpl;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.*;
@@ -232,6 +236,52 @@ public class JsonSchemaGenerator {
return arguments;
}
// ===== 根据 Method 的参数 动态生成 =====
/**
* 根据 Method 的参数动态生成 JSON Schema
*
* @param method 方法实例
* @param strict 是否为严格模式
* @return 生成的 JSON Schema
*/
public static JsonNode generate(Method method, boolean strict) {
if (method.getParameterCount() == 1 && !method.getParameterTypes()[0].isPrimitive()) {
// 如果方法参数只有一个非原始类型的参数,则直接使用该参数的类型作为 schema
Type paramType = method.getGenericParameterTypes()[0];
return generate(paramType, strict);
} else {
// 多个参数,提取为 Map最后聚合为 Schema
Map<String, Type> typeMap = extractMethodParamsAsTypeMap(method);
return generateFromTypeMap(typeMap, strict);
}
}
/**
* 提取方法参数名称和类型
*
* @param method 方法实例
* @return 参数名和类型的映射
*/
private static Map<String, Type> extractMethodParamsAsTypeMap(Method method) {
Map<String, Type> paramsMap = new LinkedHashMap<>();
for (Parameter parameter : method.getParameters()) {
ToolParam toolParam = parameter.getAnnotation(ToolParam.class);
String paramName;
// 这里将注解的 value 作为参数名,如果注解不存在,那么将使用方法参数的名字
// 但是需要开启 -parameters 参数,否则可能无法获取到
// 如果未开启获取到的是 arg0, arg1, ... 的形式
if (Objects.nonNull(toolParam) && StrUtil.isNotBlank(toolParam.value())) {
paramName = toolParam.value();
} else {
paramName = parameter.getName();
}
paramsMap.put(paramName, parameter.getParameterizedType());
}
return paramsMap;
}
// ===== 根据 Map 动态生成 =====
/**

View File

@@ -66,6 +66,7 @@ public class MethodToolCallBack implements ToolCallBack {
String paramName;
// 这里将注解的 value 作为参数名,如果注解不存在,那么将使用方法参数的名字
// 但是需要开启 -parameters 参数,否则可能无法获取到
// 如果未开启获取到的是 arg0, arg1, ... 的形式
if (Objects.nonNull(toolParam) && StrUtil.isNotBlank(toolParam.value())) {
paramName = toolParam.value();
} else {
@@ -73,7 +74,7 @@ public class MethodToolCallBack implements ToolCallBack {
}
JsonNode argNode = inputNode.get(paramName);
if (argNode.isNull()) {
if (toolParam.required()) {
if (Objects.nonNull(toolParam) && toolParam.required()) {
throw new IllegalArgumentException("Missing required parameter: " + paramName);
}
args[i] = null;

View File

@@ -1,10 +1,21 @@
package com.yomahub.liteflow.ai.engine.tool.registry;
import cn.hutool.core.util.ClassUtil;
import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.databind.JsonNode;
import com.yomahub.liteflow.ai.engine.exception.LiteFlowAIEngineException;
import com.yomahub.liteflow.ai.engine.log.EngineLog;
import com.yomahub.liteflow.ai.engine.log.EngineLogManager;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
import com.yomahub.liteflow.ai.engine.tool.method.MethodToolCallBack;
import java.util.Collection;
import java.util.Map;
import java.lang.reflect.Method;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
* 自动扫描注册工具
@@ -15,6 +26,7 @@ import java.util.concurrent.ConcurrentHashMap;
public class ScanningToolRegistry implements ToolRegistry {
private final EngineLog LOG = EngineLogManager.getLogger(ScanningToolRegistry.class);
private final Map<String, ToolCallBack> registry = new ConcurrentHashMap<>();
/**
@@ -32,7 +44,54 @@ public class ScanningToolRegistry implements ToolRegistry {
* @param packages 要扫描的包路径列表
*/
private void scanAndRegisterTools(String... packages) {
// TODO
if (Objects.isNull(packages) || packages.length == 0) {
LOG.warn("No packages specified for tool scanning. Skipping tool registration.");
return;
}
// 扫描指定包路径下的所有类
Set<Class<?>> classes = Arrays.stream(packages)
.map(ClassUtil::scanPackage)
.flatMap(Collection::stream)
.collect(Collectors.toSet());
// 实例缓存
Map<Class<?>, Object> instanceCache = new HashMap<>();
for (Class<?> clazz : classes) {
for (Method method : clazz.getMethods()) {
if (method.isAnnotationPresent(Tool.class)) {
try {
// 获取实例
Object instance = instanceCache.computeIfAbsent(clazz, c -> {
try {
// 需要提供无参构造
return c.getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new LiteFlowAIEngineException("Failed to create instance of class: " + c.getName(), e);
}
});
// 获取工具注解
Tool toolAnnotation = method.getAnnotation(Tool.class);
String toolName = StrUtil.isNotBlank(toolAnnotation.name()) ?
toolAnnotation.name() : method.getName();
// 尝试注册工具
registry.computeIfAbsent(toolName, name -> {
String description = String.join("\n", toolAnnotation.value());
JsonNode inputSchema = JsonSchemaGenerator.generate(method, true);
ToolDefinition<?> toolDefinition = new ToolDefinition<>(name, description, inputSchema);
LOG.info("Registered tool: {} from method: [{}]", toolName, method.getName());
return new MethodToolCallBack(toolDefinition, instance, method);
});
} catch (Exception e) {
LOG.error("Failed to register tool from method: [{}] in class: [{}]. Error: {}",
method.getName(), clazz.getName(), e.getMessage(), e);
}
}
}
}
LOG.info("Tool scanning completed. Total tools registered: {}", registry.size());
}
@Override

View File

@@ -2,6 +2,7 @@ package com.yomahub.liteflow.test.ai.core.tool;
import com.yomahub.liteflow.ai.context.ChatContext;
import com.yomahub.liteflow.ai.context.StreamHandler;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
import com.yomahub.liteflow.ai.util.SpringUtil;
import com.yomahub.liteflow.core.FlowExecutor;
@@ -39,9 +40,14 @@ public class ToolTest {
public void testToolRegistry() {
Assertions.assertNotNull(toolRegistry);
Assertions.assertFalse(toolRegistry.getAllTools().isEmpty(), "Tool registry should contain tools");
Assertions.assertEquals(2, toolRegistry.getAllTools().size());
Assertions.assertEquals(4, toolRegistry.getAllTools().size());
Assertions.assertEquals("weather_tool", toolRegistry.getTool("weather_tool").getName());
Assertions.assertEquals("assemble_tool", toolRegistry.getTool("assemble_tool").getName());
ToolCallBack testTool = toolRegistry.getTool("test_tool");
String res = testTool.call("{\"arg0\": \"Hello\", \"arg1\": \"World\"}");
System.out.println(res);
String nullRes = toolRegistry.getTool("null_tool").call(null);
System.out.println(nullRes);
}
@Test

View File

@@ -23,4 +23,14 @@ public class ToolConfig {
public String assemble(@ToolParam("a") String a, @ToolParam("b") String b) {
return "Assembled result: " + a + " and " + b;
}
@Tool(name = "test_tool", value = {"测试工具", "不使用 ToolParam 注解,参数名应当为 arg0 和 arg1"})
public String tesTool(String input1, String input2) {
return "Test tool executed with inputs: " + input1 + " and " + input2;
}
@Tool(name = "null_tool", value = {"空工具", "不执行任何操作"})
public void nullTool() {
System.out.println("这是一个空工具,不执行任何操作");
}
}

View File

@@ -2,16 +2,19 @@ package com.yomahub.liteflow.test.ai.engine.tool;
import com.yomahub.liteflow.ai.engine.model.output.structure.TypeReference;
import com.yomahub.liteflow.ai.engine.model.output.structure.generator.JsonSchemaGenerator;
import com.yomahub.liteflow.ai.engine.tool.ToolCallBack;
import com.yomahub.liteflow.ai.engine.tool.ToolDefinition;
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
import com.yomahub.liteflow.ai.engine.tool.function.FunctionToolCallback;
import com.yomahub.liteflow.ai.engine.tool.registry.ScanningToolRegistry;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@@ -48,6 +51,14 @@ public class ToolTest {
}
}
@Test
public void testScanningToolRegistry() {
ScanningToolRegistry scanningToolRegistry = new ScanningToolRegistry("com.yomahub.liteflow.test.ai.engine.tool.domain");
Collection<ToolCallBack> tools = scanningToolRegistry.getAllTools();
Assertions.assertEquals(4, tools.size());
System.out.println(tools);
}
@Test
public void testToolDefinition() {
ToolDefinition<Input<String>> toolDefinition = new ToolDefinition<>("testTool", "Test Tool", new TypeReference<Input<String>>() {

View File

@@ -1,5 +1,7 @@
package com.yomahub.liteflow.test.ai.engine.tool.domain;
import com.yomahub.liteflow.ai.engine.tool.annotation.Tool;
import com.yomahub.liteflow.ai.engine.tool.annotation.ToolParam;
import com.yomahub.liteflow.ai.engine.tool.function.FunctionToolCallback;
import com.yomahub.liteflow.ai.engine.tool.registry.StaticToolRegistry;
import com.yomahub.liteflow.ai.engine.tool.registry.ToolRegistry;
@@ -30,4 +32,24 @@ public class TestTools {
toolRegistry.register(weatherTool);
return toolRegistry;
}
@Tool(name = "weather_tool", value = {"查询天气", "获取指定位置的天气信息"})
public String queryWeatherWithLocation(@ToolParam("location") ToolInput location) {
return "The weather in " + location + " is sunny, 25°C.";
}
@Tool(name = "assemble_tool", value = {"组装工具", "将 a 和 b 组装成答案"})
public String assemble(@ToolParam("a") String a, @ToolParam("b") String b) {
return "Assembled result: " + a + " and " + b;
}
@Tool(name = "test_tool", value = {"测试工具", "不使用 ToolParam 注解,参数名应当为 arg0 和 arg1"})
public String tesTool(String input1, String input2) {
return "Test tool executed with inputs: " + input1 + " and " + input2;
}
@Tool(name = "null_tool", value = {"空工具", "不执行任何操作"})
public void nullTool() {
System.out.println("这是一个空工具,不执行任何操作");
}
}