/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.inference.common.InferenceExceptions;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

public class TransportGetInferenceModelAction
extends HandledTransportAction<GetInferenceModelAction.Request, GetInferenceModelAction.Response> {
    private final ModelRegistry modelRegistry;
    private final InferenceServiceRegistry serviceRegistry;
    private final Executor executor;

    @Inject
    public TransportGetInferenceModelAction(TransportService transportService, ActionFilters actionFilters, ThreadPool threadPool, ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry) {
        super("cluster:monitor/xpack/inference/get", transportService, actionFilters, GetInferenceModelAction.Request::new, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.modelRegistry = modelRegistry;
        this.serviceRegistry = serviceRegistry;
        this.executor = threadPool.executor("inference_utility");
    }

    protected void doExecute(Task task, GetInferenceModelAction.Request request, ActionListener<GetInferenceModelAction.Response> listener) {
        boolean inferenceEntityIdIsWildCard = Strings.isAllOrWildcard((String)request.getInferenceEntityId());
        if (request.getTaskType() == TaskType.ANY && inferenceEntityIdIsWildCard) {
            this.getAllModels(request.isPersistDefaultConfig(), listener);
        } else if (inferenceEntityIdIsWildCard) {
            this.getModelsByTaskType(request.getTaskType(), listener);
        } else {
            this.getSingleModel(request.getInferenceEntityId(), request.getTaskType(), listener);
        }
    }

    private void getSingleModel(String inferenceEntityId, TaskType requestedTaskType, ActionListener<GetInferenceModelAction.Response> listener) {
        this.modelRegistry.getModel(inferenceEntityId, (ActionListener<UnparsedModel>)listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
            Optional service = this.serviceRegistry.getService(unparsedModel.service());
            if (service.isEmpty()) {
                delegate.onFailure((Exception)((Object)this.serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId())));
                return;
            }
            if (!requestedTaskType.isAnyOrSame(unparsedModel.taskType())) {
                delegate.onFailure((Exception)((Object)InferenceExceptions.mismatchedTaskTypeException(requestedTaskType, unparsedModel.taskType())));
                return;
            }
            Model model = ((InferenceService)service.get()).parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
            ((InferenceService)service.get()).updateModelsWithDynamicFields(List.of(model), delegate.delegateFailureAndWrap((l2, updatedModels) -> l2.onResponse((Object)new GetInferenceModelAction.Response(updatedModels.stream().map(Model::getConfigurations).collect(Collectors.toList())))));
        }));
    }

    private void getAllModels(boolean persistDefaultEndpoints, ActionListener<GetInferenceModelAction.Response> listener) {
        this.modelRegistry.getAllModels(persistDefaultEndpoints, (ActionListener<List<UnparsedModel>>)listener.delegateFailureAndWrap((l, models) -> this.executor.execute(() -> this.parseModels((List<UnparsedModel>)models, (ActionListener<GetInferenceModelAction.Response>)l))));
    }

    private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceModelAction.Response> listener) {
        this.modelRegistry.getModelsByTaskType(taskType, (ActionListener<List<UnparsedModel>>)listener.delegateFailureAndWrap((l, models) -> this.executor.execute(() -> this.parseModels((List<UnparsedModel>)models, listener))));
    }

    private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
        if (unparsedModels.isEmpty()) {
            listener.onResponse((Object)new GetInferenceModelAction.Response(List.of()));
            return;
        }
        HashMap<String, List> parsedModelsByService = new HashMap<String, List>();
        try {
            for (UnparsedModel unparsedModel : unparsedModels) {
                Optional service = this.serviceRegistry.getService(unparsedModel.service());
                if (service.isEmpty()) {
                    throw this.serviceNotFoundException(unparsedModel.service(), unparsedModel.inferenceEntityId());
                }
                List list = parsedModelsByService.computeIfAbsent(((InferenceService)service.get()).name(), s -> new ArrayList());
                list.add(((InferenceService)service.get()).parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings()));
            }
            GroupedActionListener groupedListener = new GroupedActionListener(parsedModelsByService.entrySet().size(), listener.delegateFailureAndWrap((delegate, listOfListOfModels) -> {
                ArrayList<Model> modifiable = new ArrayList<Model>();
                for (List l : listOfListOfModels) {
                    modifiable.addAll(l);
                }
                modifiable.sort(Comparator.comparing(Model::getInferenceEntityId));
                delegate.onResponse((Object)new GetInferenceModelAction.Response(modifiable.stream().map(Model::getConfigurations).collect(Collectors.toList())));
            }));
            for (Map.Entry entry : parsedModelsByService.entrySet()) {
                ((InferenceService)this.serviceRegistry.getService((String)entry.getKey()).get()).updateModelsWithDynamicFields((List)entry.getValue(), (ActionListener)groupedListener);
            }
        }
        catch (Exception e) {
            listener.onFailure(e);
        }
    }

    private ElasticsearchStatusException serviceNotFoundException(String service, String inferenceId) {
        throw new ElasticsearchStatusException("Unknown service [{}] for inference endpoint [{}]. ", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{service, inferenceId});
    }
}

