diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java index a03c7b1b7a..cf33dc5645 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java @@ -34,12 +34,13 @@ * @param the input type * @param the output type * @author Christian Tzolov + * @author Thomas Vitale */ public final class FunctionInvokingFunctionCallback extends AbstractFunctionCallback { private final BiFunction biFunction; - FunctionInvokingFunctionCallback(String name, String description, String inputTypeSchema, Type inputType, + public FunctionInvokingFunctionCallback(String name, String description, String inputTypeSchema, Type inputType, Function responseConverter, ObjectMapper objectMapper, BiFunction function) { super(name, description, inputTypeSchema, inputType, responseConverter, objectMapper); Assert.notNull(function, "Function must not be null"); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java index 9e23370b4f..2bdfd40f86 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java @@ -18,6 +18,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Parameter; import java.util.List; import java.util.Map; import java.util.function.Function; @@ -50,6 +51,7 @@ * Automatically infers the input parameters JSON schema from method's argument types. * * @author Christian Tzolov + * @author Thomas Vitale * @since 1.0.0 */ public class MethodInvokingFunctionCallback implements FunctionCallback { @@ -94,11 +96,11 @@ public class MethodInvokingFunctionCallback implements FunctionCallback { private final String name; /** - * + * Custom response converter function. */ private final Function responseConverter; - MethodInvokingFunctionCallback(Object functionObject, Method method, String description, ObjectMapper mapper, + public MethodInvokingFunctionCallback(Object functionObject, Method method, String description, ObjectMapper mapper, String name, Function responseConverter) { Assert.notNull(method, "Method must not be null"); @@ -116,11 +118,7 @@ public class MethodInvokingFunctionCallback implements FunctionCallback { Assert.isTrue(this.functionObject != null || Modifier.isStatic(this.method.getModifiers()), "Function object must be provided for non-static methods!"); - // Generate the JSON schema from the method input parameters - Map> methodParameters = Stream.of(method.getParameters()) - .collect(Collectors.toMap(param -> param.getName(), param -> param.getType())); - - this.inputSchema = this.generateJsonSchema(methodParameters); + this.inputSchema = this.generateJsonSchema(method); logger.debug("Generated JSON Schema: {}", this.inputSchema); } @@ -192,10 +190,14 @@ else if (returnType == Class.class || returnType.isRecord() || returnType == Lis /** * Generates a JSON schema from the given named classes. - * @param namedClasses The named classes to generate the schema from. + * @param method The method whose parameters will be used to generate the JSON schema. * @return The generated JSON schema. */ - protected String generateJsonSchema(Map> namedClasses) { + protected String generateJsonSchema(Method method) { + // Extract the method parameters to generate the JSON schema for. + Map> namedClasses = Stream.of(method.getParameters()) + .collect(Collectors.toMap(Parameter::getName, Parameter::getType)); + try { JsonSchemaGenerator schemaGen = new JsonSchemaGenerator(this.mapper);