mirror of
https://gitee.com/dromara/liteFlow.git
synced 2026-06-10 03:07:32 +08:00
feat(agent): track loaded skills
This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
package com.yomahub.liteflow.agent.skill;
|
||||
|
||||
import io.agentscope.core.hook.Hook;
|
||||
import io.agentscope.core.hook.HookEvent;
|
||||
import io.agentscope.core.hook.PostActingEvent;
|
||||
import io.agentscope.core.message.ContentBlock;
|
||||
import io.agentscope.core.message.TextBlock;
|
||||
import io.agentscope.core.message.ToolResultBlock;
|
||||
import io.agentscope.core.message.ToolUseBlock;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Tracks skills loaded by agentscope's skill-loading tool during a ReAct session.
|
||||
*/
|
||||
public class SkillTrackingHook implements Hook {
|
||||
|
||||
public static final String LOAD_SKILL_TOOL_NAME = "load_skill_through_path";
|
||||
private static final String SKILL_ID_INPUT_KEY = "skillId";
|
||||
|
||||
private final Map<String, String> skillIdToName;
|
||||
private final Set<String> usedSkills = Collections.synchronizedSet(new LinkedHashSet<>());
|
||||
|
||||
public SkillTrackingHook(Map<String, String> skillIdToName) {
|
||||
this.skillIdToName = skillIdToName == null
|
||||
? Map.of()
|
||||
: Collections.unmodifiableMap(new LinkedHashMap<>(skillIdToName));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends HookEvent> Mono<T> onEvent(T event) {
|
||||
if (event instanceof PostActingEvent postActingEvent) {
|
||||
recordSkillLoad(postActingEvent.getToolUse(), postActingEvent.getToolResult());
|
||||
}
|
||||
return Mono.just(event);
|
||||
}
|
||||
|
||||
public List<String> getUsedSkills() {
|
||||
synchronized (usedSkills) {
|
||||
return List.copyOf(usedSkills);
|
||||
}
|
||||
}
|
||||
|
||||
public void clear() {
|
||||
usedSkills.clear();
|
||||
}
|
||||
|
||||
private void recordSkillLoad(ToolUseBlock toolUse, ToolResultBlock toolResult) {
|
||||
if (toolUse == null || !LOAD_SKILL_TOOL_NAME.equals(toolUse.getName()) || isErrorResult(toolResult)) {
|
||||
return;
|
||||
}
|
||||
Map<String, Object> input = toolUse.getInput();
|
||||
if (input == null || !input.containsKey(SKILL_ID_INPUT_KEY)) {
|
||||
return;
|
||||
}
|
||||
Object skillId = input.get(SKILL_ID_INPUT_KEY);
|
||||
if (skillId == null) {
|
||||
return;
|
||||
}
|
||||
String skillName = skillIdToName.get(String.valueOf(skillId));
|
||||
if (skillName != null) {
|
||||
usedSkills.add(skillName);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean isErrorResult(ToolResultBlock toolResult) {
|
||||
if (toolResult == null || toolResult.getOutput() == null) {
|
||||
return false;
|
||||
}
|
||||
for (ContentBlock block : toolResult.getOutput()) {
|
||||
if (block instanceof TextBlock textBlock) {
|
||||
String text = textBlock.getText();
|
||||
if (text != null && text.startsWith("Error:")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,151 @@
|
||||
package com.yomahub.liteflow.test.agent;
|
||||
|
||||
import com.yomahub.liteflow.agent.skill.SkillTrackingHook;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import io.agentscope.core.agent.Agent;
|
||||
import io.agentscope.core.agent.Event;
|
||||
import io.agentscope.core.agent.StreamOptions;
|
||||
import io.agentscope.core.hook.PostActingEvent;
|
||||
import io.agentscope.core.message.Msg;
|
||||
import io.agentscope.core.message.ToolResultBlock;
|
||||
import io.agentscope.core.message.ToolUseBlock;
|
||||
import io.agentscope.core.tool.Toolkit;
|
||||
import reactor.core.publisher.Flux;
|
||||
import reactor.core.publisher.Mono;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class ReActAgentSkillTrackingHookTest {
|
||||
|
||||
private static final Agent TEST_AGENT = new Agent() {
|
||||
@Override
|
||||
public String getAgentId() {
|
||||
return "test-agent-id";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "test-agent";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void interrupt() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void interrupt(Msg msg) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Msg> call(List<Msg> messages) {
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Msg> call(List<Msg> messages, Class<?> responseType) {
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Msg> call(List<Msg> messages, JsonNode schema) {
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<Event> stream(List<Msg> messages, StreamOptions options) {
|
||||
return Flux.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<Event> stream(List<Msg> messages, StreamOptions options, Class<?> responseType) {
|
||||
return Flux.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<Event> stream(List<Msg> messages, StreamOptions options, JsonNode schema) {
|
||||
return Flux.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> observe(Msg msg) {
|
||||
return Mono.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Mono<Void> observe(List<Msg> messages) {
|
||||
return Mono.empty();
|
||||
}
|
||||
};
|
||||
|
||||
@Test
|
||||
public void testTracksLoadSkillToolUseByMappedSkillName() {
|
||||
SkillTrackingHook hook = new SkillTrackingHook(new LinkedHashMap<>(Map.of("skill-1", "demo")));
|
||||
|
||||
hook.onEvent(postActingEvent(SkillTrackingHook.LOAD_SKILL_TOOL_NAME, Map.of("skillId", "skill-1"))).block();
|
||||
|
||||
Assertions.assertEquals(List.of("demo"), hook.getUsedSkills());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDeduplicatesAndClearsUsedSkills() {
|
||||
Map<String, String> skillIdToName = new LinkedHashMap<>();
|
||||
skillIdToName.put("skill-1", "demo");
|
||||
skillIdToName.put("skill-2", "research");
|
||||
SkillTrackingHook hook = new SkillTrackingHook(skillIdToName);
|
||||
|
||||
hook.onEvent(postActingEvent(SkillTrackingHook.LOAD_SKILL_TOOL_NAME, Map.of("skillId", "skill-1"))).block();
|
||||
hook.onEvent(postActingEvent(SkillTrackingHook.LOAD_SKILL_TOOL_NAME, Map.of("skillId", "skill-1"))).block();
|
||||
hook.onEvent(postActingEvent(SkillTrackingHook.LOAD_SKILL_TOOL_NAME, Map.of("skillId", "skill-2"))).block();
|
||||
|
||||
Assertions.assertEquals(List.of("demo", "research"), hook.getUsedSkills());
|
||||
|
||||
hook.clear();
|
||||
|
||||
Assertions.assertEquals(List.of(), hook.getUsedSkills());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIgnoresNonSkillTools() {
|
||||
SkillTrackingHook hook = new SkillTrackingHook(Map.of("skill-1", "demo"));
|
||||
|
||||
hook.onEvent(postActingEvent("search", Map.of("skillId", "skill-1"))).block();
|
||||
|
||||
Assertions.assertEquals(List.of(), hook.getUsedSkills());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIgnoresUnknownSkillId() {
|
||||
SkillTrackingHook hook = new SkillTrackingHook(Map.of("skill-1", "demo"));
|
||||
|
||||
Assertions.assertDoesNotThrow(() ->
|
||||
hook.onEvent(postActingEvent(SkillTrackingHook.LOAD_SKILL_TOOL_NAME, Map.of("skillId", "unknown-skill"))).block());
|
||||
|
||||
Assertions.assertEquals(List.of(), hook.getUsedSkills());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIgnoresErrorResultForKnownSkillId() {
|
||||
SkillTrackingHook hook = new SkillTrackingHook(Map.of("skill-1", "demo"));
|
||||
|
||||
hook.onEvent(postActingEvent(
|
||||
SkillTrackingHook.LOAD_SKILL_TOOL_NAME,
|
||||
Map.of("skillId", "skill-1"),
|
||||
ToolResultBlock.error("failed to load skill"))).block();
|
||||
|
||||
Assertions.assertEquals(List.of(), hook.getUsedSkills());
|
||||
}
|
||||
|
||||
private static PostActingEvent postActingEvent(String toolName, Map<String, Object> input) {
|
||||
return postActingEvent(toolName, input, null);
|
||||
}
|
||||
|
||||
private static PostActingEvent postActingEvent(
|
||||
String toolName, Map<String, Object> input, ToolResultBlock toolResult) {
|
||||
ToolUseBlock toolUse = new ToolUseBlock("tool-call-1", toolName, input);
|
||||
return new PostActingEvent(TEST_AGENT, new Toolkit(), toolUse, toolResult);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user