/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services;

import java.io.Closeable;
import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;

/**
 * Base class for inference services that reach an external provider through a {@link Sender}.
 * <p>
 * Every public entry point first starts the sender asynchronously (via {@code sender.startAsynchronously})
 * and only then delegates to the subclass hook: {@link #doInfer}, {@link #doUnifiedCompletionInfer},
 * {@link #doChunkedInfer} or {@link #doStart}. The result of that hook is forwarded to the caller's listener.
 */
public abstract class SenderService implements InferenceService {
    /** Convenience task-type set for services that only support {@link TaskType#COMPLETION}. */
    protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION);

    private final Sender sender;
    private final ServiceComponents serviceComponents;
    private final ClusterService clusterService;

    /**
     * Creates the service and its {@link Sender}.
     *
     * @param factory           factory whose {@code createSender()} result becomes this service's sender
     * @param serviceComponents shared components exposed to subclasses via {@link #getServiceComponents()}
     * @param clusterService    passed to {@code ServiceUtils.resolveInferenceTimeout} when resolving request timeouts
     * @throws NullPointerException if any argument is {@code null}
     */
    public SenderService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ClusterService clusterService) {
        Objects.requireNonNull(factory);
        this.sender = factory.createSender();
        this.serviceComponents = Objects.requireNonNull(serviceComponents);
        this.clusterService = Objects.requireNonNull(clusterService);
    }

    /** Returns the sender created by the factory passed to the constructor. */
    public Sender getSender() {
        return this.sender;
    }

    /** Returns the service components supplied at construction time. */
    protected ServiceComponents getServiceComponents() {
        return this.serviceComponents;
    }

    /**
     * Runs an inference request once the sender has been started.
     * <p>
     * The timeout is resolved through {@code ServiceUtils.resolveInferenceTimeout}. The request input is
     * built according to the model's task type, and the request is handed to {@link #doInfer}.
     * A {@link ValidationException} raised while building the input aborts the request. It is presumably
     * delivered to {@code listener} as a failure by {@link SubscribableListener} — confirm against its contract.
     *
     * @param model           the model to run
     * @param query           query text; used for rerank requests
     * @param returnDocuments rerank option, validated by {@link #validateRerankParameters}
     * @param topN            rerank option, validated by {@link #validateRerankParameters}
     * @param input           the input texts
     * @param stream          whether a streaming response was requested
     * @param taskSettings    request-level task settings passed to {@link #doInfer}
     * @param inputType       the input type; validated for embedding task types
     * @param timeout         requested timeout, may be {@code null}
     * @param listener        receives the inference results or the failure
     */
    public void infer(
        Model model,
        @Nullable String query,
        @Nullable Boolean returnDocuments,
        @Nullable Integer topN,
        List<String> input,
        boolean stream,
        Map<String, Object> taskSettings,
        InputType inputType,
        @Nullable TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        // The explicit type witness is required: Java does not infer andThen's result type from the
        // chained addListener call, and would otherwise fall back to Object.
        SubscribableListener.newForked(this::init).<InferenceServiceResults>andThen(inferListener -> {
            TimeValue resolvedInferenceTimeout = ServiceUtils.resolveInferenceTimeout(timeout, inputType, clusterService);
            InferenceInputs inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
            doInfer(model, inferenceInput, taskSettings, resolvedInferenceTimeout, inferListener);
        }).addListener(listener);
    }

    /**
     * Builds the task-type-specific request input for {@link #infer}.
     *
     * @throws ValidationException          if rerank parameters or the input type fail the service's validation
     * @throws ElasticsearchStatusException (BAD_REQUEST) if the model's task type is not supported here
     */
    private static InferenceInputs createInput(
        SenderService service,
        Model model,
        List<String> input,
        InputType inputType,
        @Nullable String query,
        @Nullable Boolean returnDocuments,
        @Nullable Integer topN,
        boolean stream
    ) {
        return switch (model.getTaskType()) {
            case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream);
            case RERANK -> {
                ValidationException validationException = new ValidationException();
                service.validateRerankParameters(returnDocuments, topN, validationException);
                throwIfInvalid(validationException);
                yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream);
            }
            case TEXT_EMBEDDING, SPARSE_EMBEDDING -> {
                ValidationException validationException = new ValidationException();
                service.validateInputType(inputType, model, validationException);
                throwIfInvalid(validationException);
                yield new EmbeddingsInput(input, inputType, stream);
            }
            default -> throw new ElasticsearchStatusException(
                Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
                RestStatus.BAD_REQUEST
            );
        };
    }

    /** Throws {@code validationException} if any validation error has been recorded on it. */
    private static void throwIfInvalid(ValidationException validationException) {
        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }
    }

    /**
     * Runs a unified chat-completion request once the sender has been started.
     * The request is wrapped in a {@link UnifiedChatInput} with streaming enabled.
     *
     * @param model    the model to run
     * @param request  the unified completion request
     * @param timeout  passed to {@link #doUnifiedCompletionInfer} as-is
     * @param listener receives the results or the failure
     */
    public void unifiedCompletionInfer(
        Model model,
        UnifiedCompletionRequest request,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        SubscribableListener.newForked(this::init)
            .<InferenceServiceResults>andThen(
                completionInferListener -> doUnifiedCompletionInfer(
                    model,
                    new UnifiedChatInput(request, true),
                    timeout,
                    completionInferListener
                )
            )
            .addListener(listener);
    }

    /**
     * Runs a chunked inference request once the sender has been started.
     * {@link #validateInputType} is checked first; any recorded error aborts the request.
     *
     * @param model        the model to run
     * @param query        not used by this base implementation
     * @param input        the chunk inputs
     * @param taskSettings request-level task settings
     * @param inputType    the input type to validate and forward
     * @param timeout      passed to {@link #doChunkedInfer} as-is
     * @param listener     receives the chunked results or the failure
     */
    public void chunkedInfer(
        Model model,
        @Nullable String query,
        List<ChunkInferenceInput> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    ) {
        SubscribableListener.newForked(this::init).<List<ChunkedInference>>andThen(chunkedInferListener -> {
            ValidationException validationException = new ValidationException();
            validateInputType(inputType, model, validationException);
            throwIfInvalid(validationException);
            doChunkedInfer(model, input, taskSettings, inputType, timeout, chunkedInferListener);
        }).addListener(listener);
    }

    /** Performs the provider-specific inference call for {@link #infer}. */
    protected abstract void doInfer(
        Model model,
        InferenceInputs inputs,
        Map<String, Object> taskSettings,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    );

    /** Records an error on {@code validationException} if {@code inputType} is not acceptable for {@code model}. */
    protected abstract void validateInputType(InputType inputType, Model model, ValidationException validationException);

    /**
     * Validates rerank-specific request parameters, recording any problems on {@code validationException}.
     * The default implementation accepts everything.
     */
    protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
    }

    /** Performs the provider-specific unified chat-completion call. */
    protected abstract void doUnifiedCompletionInfer(
        Model model,
        UnifiedChatInput inputs,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    );

    /** Performs the provider-specific chunked inference call. */
    protected abstract void doChunkedInfer(
        Model model,
        List<ChunkInferenceInput> inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<List<ChunkedInference>> listener
    );

    /** Starts the sender, then delegates to {@link #doStart}. */
    public void start(Model model, ActionListener<Boolean> listener) {
        SubscribableListener.newForked(this::init)
            .<Boolean>andThen(doStartListener -> doStart(model, doStartListener))
            .addListener(listener);
    }

    /** Same as {@link #start(Model, ActionListener)}; the timeout argument is ignored. */
    public void start(Model model, @Nullable TimeValue unused, ActionListener<Boolean> listener) {
        start(model, listener);
    }

    /** Subclass hook run after the sender has started. The default implementation immediately responds {@code true}. */
    protected void doStart(Model model, ActionListener<Boolean> listener) {
        listener.onResponse(true);
    }

    /** Starts the sender asynchronously and notifies {@code listener} when that completes. */
    private void init(ActionListener<Void> listener) {
        this.sender.startAsynchronously(listener);
    }

    /** Closes the sender, suppressing any exception raised while closing. */
    public void close() throws IOException {
        IOUtils.closeWhileHandlingException((Closeable) this.sender);
    }
}

