package dev.langchain4j.model.huggingface;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
import dev.langchain4j.model.huggingface.client.Options;
import dev.langchain4j.model.huggingface.client.Parameters;
import dev.langchain4j.model.huggingface.client.TextGenerationRequest;
import dev.langchain4j.model.huggingface.spi.HuggingFaceChatModelBuilderFactory;
import dev.langchain4j.model.huggingface.spi.HuggingFaceClientFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.spi.ServiceHelper;
import java.time.Duration;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * {@link ChatLanguageModel} that delegates to a {@link HuggingFaceClient}.
 *
 * <p>The texts of all chat messages are joined with {@code "\n"} into a single prompt, sent as one
 * {@link TextGenerationRequest}, and the client's generated text is returned as an {@link AiMessage}.
 */
public class HuggingFaceChatModel implements ChatLanguageModel {

    private final HuggingFaceClient client;
    private final Double temperature;
    private final Integer maxNewTokens;
    private final Boolean returnFullText;
    private final Boolean waitForModel;

    /**
     * Mutable builder for {@link HuggingFaceChatModel}.
     *
     * <p>Defaults: {@code modelId} = {@link HuggingFaceModelName#TII_UAE_FALCON_7B_INSTRUCT},
     * {@code timeout} = 15 seconds, {@code returnFullText} = {@code false},
     * {@code waitForModel} = {@code true}. The setters for those four properties ignore
     * {@code null} and keep the current value; {@code temperature} and {@code maxNewTokens} are
     * stored exactly as given, including {@code null}.
     */
    public static final class Builder {
        private String accessToken;
        private Double temperature;
        private Integer maxNewTokens;
        private String modelId = HuggingFaceModelName.TII_UAE_FALCON_7B_INSTRUCT;
        private Duration timeout = Duration.ofSeconds(15);
        private Boolean returnFullText = false;
        private Boolean waitForModel = true;

        /** Sets the Hugging Face access token; it is required by {@link #build()}. */
        public Builder accessToken(String accessToken) {
            this.accessToken = accessToken;
            return this;
        }

        /** Sets the model id; {@code null} keeps the current value. */
        public Builder modelId(String modelId) {
            if (modelId != null) {
                this.modelId = modelId;
            }
            return this;
        }

        /** Sets the client timeout; {@code null} keeps the current value. */
        public Builder timeout(Duration timeout) {
            if (timeout != null) {
                this.timeout = timeout;
            }
            return this;
        }

        /** Sets the sampling temperature forwarded in the request parameters; may be {@code null}. */
        public Builder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        /** Sets the max-new-tokens value forwarded in the request parameters; may be {@code null}. */
        public Builder maxNewTokens(Integer maxNewTokens) {
            this.maxNewTokens = maxNewTokens;
            return this;
        }

        /** Sets the return-full-text flag forwarded in the request parameters; {@code null} is ignored. */
        public Builder returnFullText(Boolean returnFullText) {
            if (returnFullText != null) {
                this.returnFullText = returnFullText;
            }
            return this;
        }

        /** Sets the wait-for-model flag forwarded in the request options; {@code null} is ignored. */
        public Builder waitForModel(Boolean waitForModel) {
            if (waitForModel != null) {
                this.waitForModel = waitForModel;
            }
            return this;
        }

        /**
         * Creates the model from the current builder state.
         *
         * @return a new {@link HuggingFaceChatModel}
         * @throws IllegalArgumentException if the access token is {@code null} or blank
         */
        public HuggingFaceChatModel build() {
            // The token check lives in the model constructor so every construction path enforces it.
            return new HuggingFaceChatModel(this);
        }
    }

    /**
     * Creates a model from individual settings; {@code null} arguments fall back to the
     * {@link Builder} defaults where the builder defines one.
     *
     * @throws IllegalArgumentException if {@code accessToken} is {@code null} or blank
     */
    public HuggingFaceChatModel(String accessToken,
                                String modelId,
                                Duration timeout,
                                Double temperature,
                                Integer maxNewTokens,
                                Boolean returnFullText,
                                Boolean waitForModel) {
        this(builder()
                .accessToken(accessToken)
                .modelId(modelId)
                .timeout(timeout)
                .temperature(temperature)
                .maxNewTokens(maxNewTokens)
                .returnFullText(returnFullText)
                .waitForModel(waitForModel));
    }

    /**
     * Creates a model from the given builder's current state.
     *
     * @param builder source of the configuration; later changes to it do not affect this model
     * @throws IllegalArgumentException if the builder's access token is {@code null} or blank
     */
    public HuggingFaceChatModel(Builder builder) {
        // Validate before creating the client so that no construction path (including the public
        // constructors, which bypass Builder.build()) can produce a model without a token.
        ensureAccessToken(builder.accessToken);
        // Pass a snapshot of the builder's values rather than a view over the mutable builder.
        this.client = FactoryCreator.FACTORY.create(
                clientInput(builder.accessToken, builder.modelId, builder.timeout));
        this.temperature = builder.temperature;
        this.maxNewTokens = builder.maxNewTokens;
        this.returnFullText = builder.returnFullText;
        this.waitForModel = builder.waitForModel;
    }

    /** Throws {@link IllegalArgumentException} when {@code accessToken} is {@code null} or blank. */
    private static void ensureAccessToken(String accessToken) {
        if (Utils.isNullOrBlank(accessToken)) {
            throw new IllegalArgumentException("HuggingFace access token must be defined. It can be generated here: https://huggingface.co/settings/tokens");
        }
    }

    /**
     * Builds an immutable client-factory input over fixed values. It is created in a static context
     * so the anonymous class holds no reference to a partially constructed model.
     */
    private static HuggingFaceClientFactory.Input clientInput(String apiKey, String modelId, Duration timeout) {
        return new HuggingFaceClientFactory.Input() {
            @Override
            public String apiKey() {
                return apiKey;
            }

            @Override
            public String modelId() {
                return modelId;
            }

            @Override
            public Duration timeout() {
                return timeout;
            }
        };
    }

    /**
     * Joins the texts of {@code messages} with {@code "\n"}, sends them as one
     * {@link TextGenerationRequest} together with this model's parameters and options, and wraps
     * the generated text in an {@link AiMessage}.
     *
     * @param messages the conversation to send
     * @return the client's generated text as an AI message
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        String prompt = messages.stream()
                .map(ChatMessage::text)
                .collect(Collectors.joining("\n"));

        TextGenerationRequest request = TextGenerationRequest.builder()
                .inputs(prompt)
                .parameters(Parameters.builder()
                        .temperature(temperature)
                        .maxNewTokens(maxNewTokens)
                        .returnFullText(returnFullText)
                        .build())
                .options(Options.builder()
                        .waitForModel(waitForModel)
                        .build())
                .build();

        String generatedText = client.chat(request).generatedText();
        return Response.from(AiMessage.from(generatedText));
    }

    /**
     * Returns a new builder. The builder comes from the first {@link HuggingFaceChatModelBuilderFactory}
     * found by {@link ServiceHelper#loadFactories}; if none is found, a plain {@link Builder} is returned.
     */
    public static Builder builder() {
        Iterator<HuggingFaceChatModelBuilderFactory> factories =
                ServiceHelper.loadFactories(HuggingFaceChatModelBuilderFactory.class).iterator();
        return factories.hasNext() ? factories.next().get() : new Builder();
    }

    /**
     * Shortcut for {@code builder().accessToken(accessToken).build()}.
     *
     * @throws IllegalArgumentException if {@code accessToken} is {@code null} or blank
     */
    public static HuggingFaceChatModel withAccessToken(String accessToken) {
        return builder().accessToken(accessToken).build();
    }
}
