diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 9ba3b81572..891ac2bec9 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -17,7 +17,6 @@ package org.springframework.ai.mistralai; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -30,11 +29,13 @@ import com.fasterxml.jackson.annotation.JsonProperty; import org.jspecify.annotations.Nullable; +import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; import org.springframework.ai.mistralai.api.MistralAiApi.FunctionTool; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; import org.springframework.ai.model.tool.StructuredOutputChatOptions; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; @@ -48,6 +49,7 @@ * @author Thomas Vitale * @author Alexandros Pappas * @author Jason Smith + * @author Sebastien Deleuze * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) @@ -163,29 +165,57 @@ public class MistralAiChatOptions implements ToolCallingChatOptions, StructuredO @JsonIgnore private Map toolContext = new HashMap<>(); - public static Builder builder() { - return new Builder(); + // Temporary constructor to maintain compat with ModelOptionUtils + public MistralAiChatOptions() { + } + + protected MistralAiChatOptions(String model, @Nullable Double temperature, @Nullable Double topP, + @Nullable Integer maxTokens, @Nullable Boolean safePrompt, @Nullable Integer randomSeed, + @Nullable ResponseFormat responseFormat, @Nullable List stop, @Nullable Double frequencyPenalty, + @Nullable Double presencePenalty, @Nullable Integer n, @Nullable List tools, + @Nullable ToolChoice toolChoice, @Nullable List toolCallbacks, + @Nullable Set toolNames, @Nullable Boolean internalToolExecutionEnabled, + @Nullable Map toolContext) { + + this.model = model; + this.temperature = temperature; + if (topP != null) { + this.topP = topP; + } + this.maxTokens = maxTokens; + if (safePrompt != null) { + this.safePrompt = safePrompt; + } + this.randomSeed = randomSeed; + this.responseFormat = responseFormat; + this.stop = stop; + if (frequencyPenalty != null) { + this.frequencyPenalty = frequencyPenalty; + } + if (presencePenalty != null) { + this.presencePenalty = presencePenalty; + } + this.n = n; + this.tools = tools; + this.toolChoice = toolChoice; + if (toolCallbacks != null) { + this.toolCallbacks = new ArrayList<>(toolCallbacks); + } + if (toolNames != null) { + this.toolNames = new HashSet<>(toolNames); + } + this.internalToolExecutionEnabled = internalToolExecutionEnabled; + if (toolContext != null) { + this.toolContext = new HashMap<>(toolContext); + } + } + + public static Builder builder() { + return new Builder<>(); } public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) { - return builder().model(fromOptions.getModel()) - .maxTokens(fromOptions.getMaxTokens()) - .safePrompt(fromOptions.getSafePrompt()) - .randomSeed(fromOptions.getRandomSeed()) - .temperature(fromOptions.getTemperature()) - .topP(fromOptions.getTopP()) - .responseFormat(fromOptions.getResponseFormat()) - .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) - .frequencyPenalty(fromOptions.getFrequencyPenalty()) - .presencePenalty(fromOptions.getPresencePenalty()) - .n(fromOptions.getN()) - .tools(fromOptions.getTools() != null ? new ArrayList<>(fromOptions.getTools()) : null) - .toolChoice(fromOptions.getToolChoice()) - .toolCallbacks(new ArrayList<>(fromOptions.getToolCallbacks())) - .toolNames(new HashSet<>(fromOptions.getToolNames())) - .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) - .toolContext(new HashMap<>(fromOptions.getToolContext())) - .build(); + return fromOptions.mutate().build(); } @Override @@ -386,9 +416,33 @@ public void setOutputSchema(String outputSchema) { } @Override - @SuppressWarnings("unchecked") public MistralAiChatOptions copy() { - return fromOptions(this); + return mutate().build(); + } + + public MistralAiChatOptions.Builder mutate() { + return builder() + // ChatOptions + .model(this.model) + .frequencyPenalty(this.frequencyPenalty) + .maxTokens(this.maxTokens) + .presencePenalty(this.presencePenalty) + .stop(this.stop == null ? null : new ArrayList<>(this.stop)) + .temperature(this.temperature) + .topP(this.topP) + .topK(this.getTopK()) // always null but here for consistency + // ToolCallingChatOptions + .toolCallbacks(new ArrayList<>(this.getToolCallbacks())) + .toolNames(new HashSet<>(this.getToolNames())) + .toolContext(new HashMap<>(this.getToolContext())) + .internalToolExecutionEnabled(this.getInternalToolExecutionEnabled()) + // Mistral AI specific + .safePrompt(this.safePrompt) + .randomSeed(this.randomSeed) + .responseFormat(this.responseFormat) + .n(this.n) + .tools(this.tools != null ? new ArrayList<>(this.tools) : null) + .toolChoice(this.toolChoice); } @Override @@ -424,125 +478,115 @@ public boolean equals(Object obj) { && Objects.equals(this.toolContext, other.toolContext); } - public static final class Builder { + public static class Builder> extends DefaultToolCallingChatOptions.Builder + implements StructuredOutputChatOptions.Builder { - private final MistralAiChatOptions options = new MistralAiChatOptions(); + private @Nullable Boolean safePrompt; - public Builder model(String model) { - this.options.setModel(model); - return this; - } - - public Builder model(MistralAiApi.ChatModel chatModel) { - this.options.setModel(chatModel.getName()); - return this; - } + private @Nullable Integer randomSeed; - public Builder maxTokens(@Nullable Integer maxTokens) { - this.options.setMaxTokens(maxTokens); - return this; - } + private @Nullable ResponseFormat responseFormat; - public Builder safePrompt(Boolean safePrompt) { - this.options.setSafePrompt(safePrompt); - return this; - } + private @Nullable Integer n; - public Builder randomSeed(@Nullable Integer randomSeed) { - this.options.setRandomSeed(randomSeed); - return this; - } + private @Nullable List tools; - public Builder stop(@Nullable List stop) { - this.options.setStop(stop); - return this; - } - - public Builder frequencyPenalty(Double frequencyPenalty) { - this.options.setFrequencyPenalty(frequencyPenalty); - return this; - } - - public Builder presencePenalty(Double presencePenalty) { - this.options.presencePenalty = presencePenalty; - return this; - } + private @Nullable ToolChoice toolChoice; - public Builder n(@Nullable Integer n) { - this.options.setN(n); - return this; - } - - public Builder temperature(@Nullable Double temperature) { - this.options.setTemperature(temperature); - return this; - } - - public Builder topP(Double topP) { - this.options.setTopP(topP); - return this; - } - - public Builder responseFormat(@Nullable ResponseFormat responseFormat) { - this.options.setResponseFormat(responseFormat); - return this; + public B model(MistralAiApi.@Nullable ChatModel chatModel) { + if (chatModel != null) { + this.model(chatModel.getName()); + } + else { + this.model((String) null); + } + return self(); } - public Builder tools(@Nullable List tools) { - this.options.setTools(tools); - return this; + public B safePrompt(@Nullable Boolean safePrompt) { + this.safePrompt = safePrompt; + return self(); } - public Builder toolChoice(@Nullable ToolChoice toolChoice) { - this.options.setToolChoice(toolChoice); - return this; + public B randomSeed(@Nullable Integer randomSeed) { + this.randomSeed = randomSeed; + return self(); } - public Builder toolCallbacks(List toolCallbacks) { - this.options.setToolCallbacks(toolCallbacks); - return this; + public B stop(@Nullable List stop) { + super.stopSequences(stop); + return self(); } - public Builder toolCallbacks(ToolCallback... toolCallbacks) { - Assert.notNull(toolCallbacks, "toolCallbacks cannot be null"); - this.options.getToolCallbacks().addAll(Arrays.asList(toolCallbacks)); - return this; + public B responseFormat(@Nullable ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + return self(); } - public Builder toolNames(Set toolNames) { - Assert.notNull(toolNames, "toolNames cannot be null"); - this.options.setToolNames(toolNames); - return this; + public B n(@Nullable Integer n) { + this.n = n; + return self(); } - public Builder toolNames(String... toolNames) { - Assert.notNull(toolNames, "toolNames cannot be null"); - this.options.getToolNames().addAll(Set.of(toolNames)); - return this; + public B tools(@Nullable List tools) { + this.tools = tools; + return self(); } - public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) { - this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); - return this; + public B toolChoice(@Nullable ToolChoice toolChoice) { + this.toolChoice = toolChoice; + return self(); } - public Builder toolContext(Map toolContext) { - if (this.options.toolContext == null) { - this.options.toolContext = toolContext; + @Override + public B outputSchema(@Nullable String outputSchema) { + if (outputSchema != null) { + this.responseFormat = ResponseFormat.builder() + .type(ResponseFormat.Type.JSON_SCHEMA) + .jsonSchema(outputSchema) + .build(); } else { - this.options.toolContext.putAll(toolContext); + this.responseFormat = null; } - return this; + return self(); } - public Builder outputSchema(String outputSchema) { - this.options.setOutputSchema(outputSchema); - return this; + @Override + public B combineWith(ChatOptions.Builder other) { + super.combineWith(other); + if (other instanceof Builder that) { + if (that.safePrompt != null) { + this.safePrompt = that.safePrompt; + } + if (that.randomSeed != null) { + this.randomSeed = that.randomSeed; + } + if (that.responseFormat != null) { + this.responseFormat = that.responseFormat; + } + if (that.n != null) { + this.n = that.n; + } + if (that.tools != null) { + this.tools = that.tools; + } + if (that.toolChoice != null) { + this.toolChoice = that.toolChoice; + } + } + return self(); } + @Override + @SuppressWarnings("NullAway") public MistralAiChatOptions build() { - return this.options; + // TODO: add assertions, remove SuppressWarnings + // Assert.state(this.model != null, "model must be set"); + return new MistralAiChatOptions(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, + this.randomSeed, this.responseFormat, this.stopSequences, this.frequencyPenalty, + this.presencePenalty, this.n, this.tools, this.toolChoice, new ArrayList<>(this.toolCallbacks), + new HashSet<>(this.toolNames), this.internalToolExecutionEnabled, new HashMap<>(this.toolContext)); } } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java index 49bfbef4c3..f811e9b187 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java @@ -27,6 +27,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.model.tool.StructuredOutputChatOptions; +import org.springframework.ai.test.options.AbstractChatOptionsTests; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -36,7 +37,8 @@ * * @author Alexandros Pappas */ -class MistralAiChatOptionsTests { +class MistralAiChatOptionsTests> + extends AbstractChatOptionsTests { @Test void testBuilderWithAllFields() { @@ -118,7 +120,7 @@ void testSetters() { @Test void testDefaultValues() { - MistralAiChatOptions options = new MistralAiChatOptions(); + MistralAiChatOptions options = MistralAiChatOptions.builder().build(); assertThat(options.getModel()).isNull(); assertThat(options.getTemperature()).isNull(); assertThat(options.getTopP()).isEqualTo(1.0); @@ -547,6 +549,17 @@ void testResponseFormatWithOptionsIntegration() { assertThat(options.getResponseFormat().getType()).isEqualTo(ResponseFormat.Type.JSON_SCHEMA); } + @Override + protected Class getConcreteOptionsClass() { + return MistralAiChatOptions.class; + } + + @Override + @SuppressWarnings("unchecked") + protected B readyToBuildBuilder() { + return (B) MistralAiChatOptions.builder().model(MistralAiApi.ChatModel.MISTRAL_SMALL).maxTokens(500); + } + // Test record for schema generation tests record TestRecord(String name, int age, List tags) {