package dev.langchain4j.model.googleai;

import dev.langchain4j.Experimental;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Synchronous chat model that sends requests through the inherited {@code geminiService}
 * and converts the first returned candidate into a {@link ChatResponse}.
 *
 * <p>Request construction and listener notification are delegated to methods inherited from
 * {@link BaseGeminiChatModel}; this class adds the retry loop around the service call, the
 * response conversion, and token counting via a {@link GoogleAiGeminiTokenizer}.
 *
 * <p>Instances are normally created through {@link #builder()}.
 */
@Experimental
public class GoogleAiGeminiChatModel extends BaseGeminiChatModel implements ChatLanguageModel, TokenCountEstimator {
    private static final Logger log = LoggerFactory.getLogger(GoogleAiGeminiChatModel.class);

    /** Tokenizer used by {@link #estimateTokenCount(List)}; configured in the constructor. */
    private final GoogleAiGeminiTokenizer geminiTokenizer;

    /**
     * Fluent builder for {@link GoogleAiGeminiChatModel}.
     *
     * <p>Every setter stores its argument as-is and returns this builder; {@link #build()}
     * forwards all stored values (including {@code null}s) to the model constructor.
     */
    public static class GoogleAiGeminiChatModelBuilder {
        private String apiKey;
        private String modelName;
        private Integer maxRetries;
        private Double temperature;
        private Integer topK;
        private Double topP;
        private Integer maxOutputTokens;
        private Duration timeout;
        private ResponseFormat responseFormat;
        private List<String> stopSequences;
        private GeminiFunctionCallingConfig toolConfig;
        private Boolean allowCodeExecution;
        private Boolean includeCodeExecutionOutput;
        private Boolean logRequestsAndResponses;
        private List<GeminiSafetySetting> safetySettings;
        private List<ChatModelListener> listeners;

        /** Package-private: obtain instances via {@link GoogleAiGeminiChatModel#builder()}. */
        GoogleAiGeminiChatModelBuilder() {
        }

        /**
         * Sets the function-calling configuration.
         *
         * @param mode          mode handed to {@link GeminiFunctionCallingConfig}
         * @param functionNames strings handed to {@link GeminiFunctionCallingConfig} as a list
         *                      (presumably the function names the mode applies to -- confirm
         *                      against {@code GeminiFunctionCallingConfig})
         * @return this builder
         */
        public GoogleAiGeminiChatModelBuilder toolConfig(GeminiMode mode, String... functionNames) {
            this.toolConfig = new GeminiFunctionCallingConfig(mode, Arrays.asList(functionNames));
            return this;
        }

        /**
         * Sets the safety settings from a category-to-threshold map; each map entry becomes
         * one {@link GeminiSafetySetting}.
         *
         * @param thresholdsByCategory block threshold per harm category; must not be {@code null}
         * @return this builder
         */
        public GoogleAiGeminiChatModelBuilder safetySettings(Map<GeminiHarmCategory, GeminiHarmBlockThreshold> thresholdsByCategory) {
            this.safetySettings = thresholdsByCategory.entrySet().stream()
                    .map(entry -> new GeminiSafetySetting(entry.getKey(), entry.getValue()))
                    .collect(Collectors.toList());
            return this;
        }

        /** Sets the API key. @return this builder */
        public GoogleAiGeminiChatModelBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /** Sets the model name. @return this builder */
        public GoogleAiGeminiChatModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /** Sets the maximum number of attempts used by the retry loop. @return this builder */
        public GoogleAiGeminiChatModelBuilder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        /** Sets the sampling temperature. @return this builder */
        public GoogleAiGeminiChatModelBuilder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        /** Sets the top-K sampling parameter. @return this builder */
        public GoogleAiGeminiChatModelBuilder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        /** Sets the top-P sampling parameter. @return this builder */
        public GoogleAiGeminiChatModelBuilder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        /** Sets the maximum number of output tokens. @return this builder */
        public GoogleAiGeminiChatModelBuilder maxOutputTokens(Integer maxOutputTokens) {
            this.maxOutputTokens = maxOutputTokens;
            return this;
        }

        /** Sets the timeout; also used for the tokenizer, which falls back to 60 seconds when {@code null}. @return this builder */
        public GoogleAiGeminiChatModelBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /** Sets the default response format, used when a request does not specify one. @return this builder */
        public GoogleAiGeminiChatModelBuilder responseFormat(ResponseFormat responseFormat) {
            this.responseFormat = responseFormat;
            return this;
        }

        /** Sets the stop sequences. @return this builder */
        public GoogleAiGeminiChatModelBuilder stopSequences(List<String> stopSequences) {
            this.stopSequences = stopSequences;
            return this;
        }

        /** Sets whether code execution is allowed. @return this builder */
        public GoogleAiGeminiChatModelBuilder allowCodeExecution(Boolean allowCodeExecution) {
            this.allowCodeExecution = allowCodeExecution;
            return this;
        }

        /** Sets whether code-execution output is included when mapping response parts. @return this builder */
        public GoogleAiGeminiChatModelBuilder includeCodeExecutionOutput(Boolean includeCodeExecutionOutput) {
            this.includeCodeExecutionOutput = includeCodeExecutionOutput;
            return this;
        }

        /** Sets whether requests and responses are logged. @return this builder */
        public GoogleAiGeminiChatModelBuilder logRequestsAndResponses(Boolean logRequestsAndResponses) {
            this.logRequestsAndResponses = logRequestsAndResponses;
            return this;
        }

        /** Sets the chat model listeners. @return this builder */
        public GoogleAiGeminiChatModelBuilder listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            return this;
        }

        /**
         * Creates the model from the values collected so far.
         *
         * @return a new {@link GoogleAiGeminiChatModel}
         */
        public GoogleAiGeminiChatModel build() {
            return new GoogleAiGeminiChatModel(
                    this.apiKey,
                    this.modelName,
                    this.maxRetries,
                    this.temperature,
                    this.topK,
                    this.topP,
                    this.maxOutputTokens,
                    this.timeout,
                    this.responseFormat,
                    this.stopSequences,
                    this.toolConfig,
                    this.allowCodeExecution,
                    this.includeCodeExecutionOutput,
                    this.logRequestsAndResponses,
                    this.safetySettings,
                    this.listeners);
        }

        /**
         * Returns a debug representation of the builder state.
         *
         * <p>The API key is a credential, so only its presence is shown (masked), never its
         * value; this keeps the key out of logs that print the builder.
         */
        @Override
        public String toString() {
            String maskedApiKey = this.apiKey == null ? "null" : "********";
            return "GoogleAiGeminiChatModel.GoogleAiGeminiChatModelBuilder(apiKey=" + maskedApiKey
                    + ", modelName=" + this.modelName
                    + ", maxRetries=" + this.maxRetries
                    + ", temperature=" + this.temperature
                    + ", topK=" + this.topK
                    + ", topP=" + this.topP
                    + ", maxOutputTokens=" + this.maxOutputTokens
                    + ", timeout=" + this.timeout
                    + ", responseFormat=" + this.responseFormat
                    + ", stopSequences=" + this.stopSequences
                    + ", toolConfig=" + this.toolConfig
                    + ", allowCodeExecution=" + this.allowCodeExecution
                    + ", includeCodeExecutionOutput=" + this.includeCodeExecutionOutput
                    + ", logRequestsAndResponses=" + this.logRequestsAndResponses
                    + ", safetySettings=" + this.safetySettings
                    + ", listeners=" + this.listeners
                    + ")";
        }
    }

    /**
     * Creates the model and its companion tokenizer.
     *
     * <p>All arguments are forwarded to the {@link BaseGeminiChatModel} constructor. The
     * tokenizer is then configured from the inherited {@code modelName}, {@code apiKey} and
     * {@code maxRetries} fields (not the raw arguments), with the timeout defaulting to
     * 60 seconds and request/response logging defaulting to {@code false} when {@code null}.
     *
     * @param apiKey                     API key
     * @param modelName                  model name
     * @param maxRetries                 maximum number of attempts for the retry loop
     * @param temperature                sampling temperature
     * @param topK                       top-K sampling parameter
     * @param topP                       top-P sampling parameter
     * @param maxOutputTokens            maximum number of output tokens
     * @param timeout                    timeout; {@code null} means 60 seconds for the tokenizer
     * @param responseFormat             default response format
     * @param stopSequences              stop sequences
     * @param toolConfig                 function-calling configuration
     * @param allowCodeExecution         whether code execution is allowed
     * @param includeCodeExecutionOutput whether code-execution output is included in messages
     * @param logRequestsAndResponses    whether requests and responses are logged
     * @param safetySettings             safety settings
     * @param listeners                  chat model listeners
     */
    public GoogleAiGeminiChatModel(String apiKey, String modelName, Integer maxRetries, Double temperature, Integer topK, Double topP, Integer maxOutputTokens, Duration timeout, ResponseFormat responseFormat, List<String> stopSequences, GeminiFunctionCallingConfig toolConfig, Boolean allowCodeExecution, Boolean includeCodeExecutionOutput, Boolean logRequestsAndResponses, List<GeminiSafetySetting> safetySettings, List<ChatModelListener> listeners) {
        super(apiKey, modelName, temperature, topK, topP, maxOutputTokens, timeout, responseFormat, stopSequences, toolConfig, allowCodeExecution, includeCodeExecutionOutput, logRequestsAndResponses, safetySettings, listeners, maxRetries);
        this.geminiTokenizer = GoogleAiGeminiTokenizer.builder()
                .modelName(this.modelName)
                .apiKey(this.apiKey)
                .timeout(Utils.getOrDefault(timeout, Duration.ofSeconds(60L)))
                .maxRetries(this.maxRetries)
                .logRequestsAndResponses(Utils.getOrDefault(logRequestsAndResponses, false))
                .build();
    }

    /**
     * Generates a reply to the given messages without any tools.
     *
     * @param list conversation messages
     * @return the AI message with token usage and finish reason
     */
    public Response<AiMessage> generate(List<ChatMessage> list) {
        return toResponse(chat(ChatRequest.builder().messages(list).build()));
    }

    /**
     * Generates a reply with a single tool available; equivalent to calling
     * {@link #generate(List, List)} with a singleton list.
     *
     * @param list              conversation messages
     * @param toolSpecification the only tool offered to the model
     * @return the AI message with token usage and finish reason
     */
    public Response<AiMessage> generate(List<ChatMessage> list, ToolSpecification toolSpecification) {
        return generate(list, Collections.singletonList(toolSpecification));
    }

    /**
     * Generates a reply with the given tools available.
     *
     * @param list  conversation messages
     * @param list2 tool specifications offered to the model
     * @return the AI message with token usage and finish reason
     */
    public Response<AiMessage> generate(List<ChatMessage> list, List<ToolSpecification> list2) {
        return toResponse(chat(ChatRequest.builder().messages(list).toolSpecifications(list2).build()));
    }

    /** Repackages a {@link ChatResponse} as the {@link Response} type returned by the {@code generate} methods. */
    private static Response<AiMessage> toResponse(ChatResponse chatResponse) {
        return Response.from(chatResponse.aiMessage(), chatResponse.tokenUsage(), chatResponse.finishReason());
    }

    /**
     * Sends the request to the Gemini service, retrying up to {@code maxRetries} attempts,
     * and notifies listeners of the request, the response, or the failure.
     *
     * <p>The request's response format takes precedence; the model-level default is used
     * only when the request does not specify one.
     *
     * @param chatRequest messages, optional tool specifications and optional response format
     * @return the converted response built from the first returned candidate
     * @throws RuntimeException if the call or the response conversion fails; listeners are
     *                          notified of the error before it is rethrown unchanged
     */
    public ChatResponse chat(ChatRequest chatRequest) {
        ResponseFormat effectiveFormat = Utils.getOrDefault(chatRequest.responseFormat(), this.responseFormat);
        GeminiGenerateContentRequest geminiRequest =
                createGenerateContentRequest(chatRequest.messages(), chatRequest.toolSpecifications(), effectiveFormat);
        ChatModelRequest listenerRequest =
                createChatModelRequest(chatRequest.messages(), chatRequest.toolSpecifications());

        // One attribute map per call, shared by the request/response/error listener callbacks.
        ConcurrentHashMap<Object, Object> listenerAttributes = new ConcurrentHashMap<>();
        notifyListenersOnRequest(new ChatModelRequestContext(listenerRequest, listenerAttributes));
        try {
            GeminiGenerateContentResponse geminiResponse = RetryUtils.withRetry(
                    () -> this.geminiService.generateContent(this.modelName, this.apiKey, geminiRequest),
                    this.maxRetries);
            // Conversion stays inside the try so conversion failures also reach the error listeners.
            return processResponse(geminiResponse, listenerRequest, listenerAttributes);
        } catch (RuntimeException e) {
            notifyListenersOnError(e, listenerRequest, listenerAttributes);
            throw e;
        }
    }

    /**
     * Converts the raw Gemini response into a {@link ChatResponse} and notifies the response
     * listeners. Only the first candidate is used.
     *
     * @throws RuntimeException if the response is {@code null} or carries no candidates
     */
    private ChatResponse processResponse(GeminiGenerateContentResponse geminiResponse, ChatModelRequest listenerRequest, ConcurrentHashMap<Object, Object> listenerAttributes) {
        if (geminiResponse == null) {
            throw new RuntimeException("Gemini response was null");
        }
        List<GeminiCandidate> candidates = geminiResponse.getCandidates();
        // Fail with a clear message instead of a bare NPE / IndexOutOfBoundsException when the
        // service answers without any candidate.
        if (candidates == null || candidates.isEmpty()) {
            throw new RuntimeException("Gemini response contained no candidates");
        }
        GeminiCandidate firstCandidate = candidates.get(0);
        FinishReason finishReason = FinishReasonMapper.fromGFinishReasonToFinishReason(firstCandidate.getFinishReason());
        AiMessage aiMessage = createAiMessage(firstCandidate, finishReason);
        TokenUsage tokenUsage = createTokenUsage(geminiResponse.getUsageMetadata());

        notifyListenersOnResponse(Response.from(aiMessage, tokenUsage, finishReason), listenerRequest, listenerAttributes);
        return ChatResponse.builder()
                .aiMessage(aiMessage)
                .finishReason(finishReason)
                .tokenUsage(tokenUsage)
                .build();
    }

    /**
     * Builds the AI message for a candidate. A candidate without content yields a placeholder
     * text message that states the finish reason; otherwise the content parts are mapped.
     */
    private AiMessage createAiMessage(GeminiCandidate geminiCandidate, FinishReason finishReason) {
        if (geminiCandidate.getContent() == null) {
            return AiMessage.from("No text was returned by the model. The model finished generating because of the following reason: " + finishReason);
        }
        return PartsAndContentsMapper.fromGPartsToAiMessage(geminiCandidate.getContent().getParts(), this.includeCodeExecutionOutput);
    }

    /**
     * Maps Gemini usage metadata to a {@link TokenUsage}. When the response carries no usage
     * metadata, a {@link TokenUsage} with all counts {@code null} is returned so that a
     * successful generation is not discarded over missing statistics.
     */
    private TokenUsage createTokenUsage(GeminiUsageMetadata geminiUsageMetadata) {
        if (geminiUsageMetadata == null) {
            return new TokenUsage(null, null, null);
        }
        return new TokenUsage(
                geminiUsageMetadata.getPromptTokenCount(),
                geminiUsageMetadata.getCandidatesTokenCount(),
                geminiUsageMetadata.getTotalTokenCount());
    }

    /**
     * Estimates the token count of the given messages by delegating to the tokenizer.
     *
     * @param list messages to count
     * @return the tokenizer's estimate
     */
    public int estimateTokenCount(List<ChatMessage> list) {
        return this.geminiTokenizer.estimateTokenCountInMessages(list);
    }

    /**
     * Reports the capabilities of this instance: {@link Capability#RESPONSE_FORMAT_JSON_SCHEMA}
     * is included only when the model-level response format is of type JSON.
     *
     * @return a new mutable set, possibly empty
     */
    public Set<Capability> supportedCapabilities() {
        Set<Capability> capabilities = new HashSet<>();
        if (this.responseFormat != null && ResponseFormatType.JSON.equals(this.responseFormat.type())) {
            capabilities.add(Capability.RESPONSE_FORMAT_JSON_SCHEMA);
        }
        return capabilities;
    }

    /**
     * Creates a new builder.
     *
     * @return an empty {@link GoogleAiGeminiChatModelBuilder}
     */
    public static GoogleAiGeminiChatModelBuilder builder() {
        return new GoogleAiGeminiChatModelBuilder();
    }
}
