/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.ingest;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.ingest.AbstractProcessor;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor;

/**
 * Ingest processor that runs a trained model against each ingested document.
 *
 * <p>The processor is asynchronous ({@link #isAsync()} returns {@code true}). For each document,
 * {@link #execute} sends an {@link InferModelAction} request through the client with the
 * {@code "ml"} origin. When the response arrives, the first inference result is written into the
 * document via {@link InferenceResults#writeResult} under the configured target field.
 */
public class InferenceProcessor extends AbstractProcessor {

    /**
     * Dynamic, node-scoped setting that caps how many inference processors may exist
     * (default 50, minimum 1). Enforced by {@link Factory#create}.
     */
    public static final Setting<Integer> MAX_INFERENCE_PROCESSORS = Setting.intSetting(
        "xpack.ml.max_inference_processors",
        50,
        1,
        Setting.Property.Dynamic,
        Setting.Property.NodeScope
    );
    public static final String TYPE = "inference";
    public static final String INFERENCE_CONFIG = "inference_config";
    public static final String TARGET_FIELD = "target_field";
    public static final String FIELD_MAPPINGS = "field_mappings";
    public static final String FIELD_MAP = "field_map";
    private static final String DEFAULT_TARGET_FIELD = "ml.inference";

    private final Client client;
    private final String modelId;
    private final String targetField;
    private final InferenceConfigUpdate inferenceConfig;
    private final Map<String, String> fieldMap;
    private final InferenceAuditor auditor;
    // Becomes true once the first response has been handled; forwarded with every later request.
    private volatile boolean previouslyLicensed;
    // Guards the "no longer licensed" audit warning so it is emitted at most once per instance.
    private final AtomicBoolean shouldAudit = new AtomicBoolean(true);

    /**
     * Creates a processor bound to a single model.
     *
     * @param client          client used to dispatch inference requests; must not be null
     * @param auditor         auditor used for the license warning; must not be null
     * @param tag             processor tag (may be null)
     * @param description     processor description (may be null)
     * @param targetField     document field under which results are written; must not be null
     * @param modelId         id of the model to infer with; must not be null
     * @param inferenceConfig inference config update sent with every request; must not be null
     * @param fieldMap        field remapping applied to the request input; must not be null
     */
    public InferenceProcessor(
        Client client,
        InferenceAuditor auditor,
        String tag,
        String description,
        String targetField,
        String modelId,
        InferenceConfigUpdate inferenceConfig,
        Map<String, String> fieldMap
    ) {
        super(tag, description);
        this.client = ExceptionsHelper.requireNonNull(client, "client");
        this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
        this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor");
        this.modelId = ExceptionsHelper.requireNonNull(modelId, "model_id");
        this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
        this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP);
    }

    /** Returns the id of the model this processor infers with. */
    public String getModelId() {
        return modelId;
    }

    /**
     * Asynchronously infers on {@code ingestDocument}. A successful response is passed to
     * {@link #handleResponse}; a failed request is reported as {@code handler.accept(doc, e)}.
     */
    @Override
    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
        ClientHelper.executeAsyncWithOrigin(
            client,
            "ml",
            InferModelAction.INSTANCE,
            buildRequest(ingestDocument),
            ActionListener.wrap(r -> handleResponse(r, ingestDocument, handler), e -> handler.accept(ingestDocument, e))
        );
    }

    /**
     * Completes processing of one document once its inference response has arrived.
     *
     * <p>Marks the processor as having received a response (so later requests are built with
     * {@code previouslyLicensed = true}). Audits a one-time warning if the response reports the
     * model as unlicensed. Then writes the result into the document. {@code handler} is invoked
     * exactly once by this method: with {@code (ingestDocument, null)} on success, or with
     * {@code (ingestDocument, e)} if writing the result failed. An exception thrown by
     * {@code handler} itself is not caught here, so the handler is never re-invoked with its own
     * failure.
     */
    void handleResponse(InferModelAction.Response response, IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
        if (!previouslyLicensed) {
            // Read before write so the volatile field is only written on the first response.
            previouslyLicensed = true;
        }
        if (!response.isLicensed()) {
            auditWarningAboutLicenseIfNecessary();
        }
        try {
            mutateDocument(response, ingestDocument);
        } catch (Exception e) {
            // Route every failure to write the result (not just ElasticsearchException) to the handler.
            handler.accept(ingestDocument, e);
            return;
        }
        // Deliberately outside the try: a throwing handler must not be called a second time.
        handler.accept(ingestDocument, null);
    }

    /**
     * Builds the inference request for one document.
     *
     * <p>The request input is a copy of the document's source and metadata. The document's ingest
     * metadata is added under {@code "_ingest"} when it is non-empty. The copy is then passed
     * through {@link LocalModel#mapFieldsIfNecessary} with the configured field map.
     */
    InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
        Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
        if (!ingestDocument.getIngestMetadata().isEmpty()) {
            fields.put("_ingest", ingestDocument.getIngestMetadata());
        }
        LocalModel.mapFieldsIfNecessary(fields, fieldMap);
        return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed);
    }

    /** Emits the "no longer licensed" audit warning for this model, at most once per processor instance. */
    void auditWarningAboutLicenseIfNecessary() {
        if (shouldAudit.compareAndSet(true, false)) {
            auditor.warning(
                modelId,
                "This cluster is no longer licensed to use this model in the inference ingest processor. "
                    + "Please update your license information."
            );
        }
    }

    /**
     * Writes the first inference result of {@code response} into {@code ingestDocument}.
     *
     * <p>The result is attributed to the model id reported by the response, falling back to this
     * processor's configured model id when the response carries none.
     *
     * @throws ElasticsearchStatusException (500) if the response contains no results
     */
    void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) {
        if (response.getInferenceResults().isEmpty()) {
            throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
        }
        // Exactly one document is sent per request (see buildRequest), so one result is expected.
        assert response.getInferenceResults().size() == 1;
        String resultModelId = response.getId() != null ? response.getId() : modelId;
        InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, resultModelId);
    }

    @Override
    public boolean isAsync() {
        return true;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Creates {@link InferenceProcessor}s from pipeline configuration.
     *
     * <p>Also consumes cluster states to cache the data needed for validation: the minimum ML
     * config version across nodes and the current number of inference processors.
     */
    public static final class Factory implements Processor.Factory, Consumer<ClusterState> {

        private static final Logger logger = LogManager.getLogger(Factory.class);
        // MessageFormat pattern shared by the node-version checks below.
        private static final String CONFIG_NOT_SUPPORTED_ON_VERSION =
            "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";

        private final Client client;
        private final InferenceAuditor auditor;
        private volatile int currentInferenceProcessors;
        private volatile int maxIngestProcessors;
        private volatile MlConfigVersion minNodeVersion = MlConfigVersion.CURRENT;

        public Factory(Client client, ClusterService clusterService, Settings settings, boolean includeNodeInfo) {
            this.client = client;
            this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
            this.auditor = new InferenceAuditor(client, clusterService, includeNodeInfo);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
        }

        /**
         * Refreshes the cached minimum ML config version and inference processor count.
         * If counting fails, the error is logged at debug level and the previous count is kept.
         */
        @Override
        public void accept(ClusterState state) {
            minNodeVersion = MlConfigVersion.getMinMlConfigVersion(state.nodes());
            try {
                currentInferenceProcessors = InferenceProcessorInfoExtractor.countInferenceProcessors(state);
            } catch (Exception ex) {
                // Best effort: a stale count is preferable to failing the cluster-state update.
                logger.debug("failed gathering processors for pipelines", ex);
            }
        }

        /**
         * Creates a processor from {@code config}.
         *
         * <p>The following keys are read and removed from {@code config}, in this order:
         * <ul>
         *   <li>{@code model_id} (required);</li>
         *   <li>{@code target_field} (default {@code "ml.inference"}, or {@code "ml.inference.<tag>"}
         *       when a tag is set);</li>
         *   <li>{@code field_map} (or the deprecated {@code field_mappings});</li>
         *   <li>{@code inference_config}.</li>
         * </ul>
         *
         * @throws ElasticsearchStatusException (409) if the configured maximum number of inference
         *                                      processors has already been reached
         */
        @Override
        public InferenceProcessor create(
            Map<String, Processor.Factory> processorFactories,
            String tag,
            String description,
            Map<String, Object> config
        ) {
            if (maxIngestProcessors <= currentInferenceProcessors) {
                throw new ElasticsearchStatusException(
                    "Max number of inference processors reached, total inference processors [{}]. "
                        + "Adjust the setting [{}]: [{}] if a greater number is desired.",
                    RestStatus.CONFLICT,
                    currentInferenceProcessors,
                    MAX_INFERENCE_PROCESSORS.getKey(),
                    maxIngestProcessors
                );
            }
            String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, "model_id");
            String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag;
            String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField);
            Map<String, String> fieldMap = readFieldMap(tag, config);
            InferenceConfigUpdate inferenceConfigUpdate = readInferenceConfig(tag, config);
            return new InferenceProcessor(client, auditor, tag, description, targetField, modelId, inferenceConfigUpdate, fieldMap);
        }

        /** Reads {@code field_map}, falling back to the deprecated {@code field_mappings}; never returns null. */
        private static Map<String, String> readFieldMap(String tag, Map<String, Object> config) {
            Map<String, String> fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAP);
            if (fieldMap == null) {
                fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
                if (fieldMap != null) {
                    LoggingDeprecationHandler.INSTANCE.logRenamedField(null, () -> null, FIELD_MAPPINGS, FIELD_MAP);
                }
            }
            if (fieldMap == null) {
                return Collections.emptyMap();
            }
            return fieldMap;
        }

        /**
         * Reads {@code inference_config}. When it is absent, an {@link EmptyConfigUpdate} is used.
         * That fallback is rejected as a missing required property while the minimum node version
         * is before {@link EmptyConfigUpdate#minimumSupportedVersion()}.
         */
        private InferenceConfigUpdate readInferenceConfig(String tag, Map<String, Object> config) {
            Map<String, Object> inferenceConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, INFERENCE_CONFIG);
            if (inferenceConfigMap != null) {
                return inferenceConfigUpdateFromMap(inferenceConfigMap);
            }
            if (minNodeVersion.before(EmptyConfigUpdate.minimumSupportedVersion())) {
                throw ConfigurationUtils.newConfigurationException(TYPE, tag, INFERENCE_CONFIG, "required property is missing");
            }
            return new EmptyConfigUpdate();
        }

        void setMaxIngestProcessors(int maxIngestProcessors) {
            logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors);
            this.maxIngestProcessors = maxIngestProcessors;
        }

        /**
         * Parses an {@code inference_config} object holding exactly one entry: an inference type name
         * mapped to that type's settings object.
         *
         * @throws ElasticsearchStatusException (bad request) if the map does not have exactly one
         *                                      object-valued entry, if the type is unrecognized, or if
         *                                      the type is not supported by the minimum node version
         */
        InferenceConfigUpdate inferenceConfigUpdateFromMap(Map<String, Object> configMap) {
            ExceptionsHelper.requireNonNull(configMap, INFERENCE_CONFIG);
            if (configMap.size() != 1) {
                throw ExceptionsHelper.badRequestException(
                    "{} must be an object with one inference type mapped to an object.",
                    INFERENCE_CONFIG
                );
            }
            Object value = configMap.values().iterator().next();
            if (!(value instanceof Map)) {
                throw ExceptionsHelper.badRequestException(
                    "{} must be an object with one inference type mapped to an object.",
                    INFERENCE_CONFIG
                );
            }
            @SuppressWarnings("unchecked")
            Map<String, Object> valueMap = (Map<String, Object>) value;

            if (configMap.containsKey(ClassificationConfig.NAME.getPreferredName())) {
                checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
                return ClassificationConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("fill_mask")) {
                checkNlpSupported("fill_mask");
                return FillMaskConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("ner")) {
                checkNlpSupported("ner");
                return NerConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("pass_through")) {
                checkNlpSupported("pass_through");
                return PassThroughConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey(RegressionConfig.NAME.getPreferredName())) {
                checkSupportedVersion(RegressionConfig.EMPTY_PARAMS);
                return RegressionConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("text_classification")) {
                checkNlpSupported("text_classification");
                return TextClassificationConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("text_embedding")) {
                checkNlpSupported("text_embedding");
                return TextEmbeddingConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("text_expansion")) {
                checkNlpSupported("text_expansion");
                return TextExpansionConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("text_similarity")) {
                checkNlpSupported("text_similarity");
                return TextSimilarityConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("zero_shot_classification")) {
                checkNlpSupported("zero_shot_classification");
                return ZeroShotClassificationConfigUpdate.fromMap(valueMap);
            }
            if (configMap.containsKey("question_answering")) {
                checkNlpSupported("question_answering");
                return QuestionAnsweringConfigUpdate.fromMap(valueMap);
            }
            throw ExceptionsHelper.badRequestException(
                "unrecognized inference configuration type {}. Supported types {}",
                configMap.keySet(),
                List.of(
                    ClassificationConfig.NAME.getPreferredName(),
                    RegressionConfig.NAME.getPreferredName(),
                    "fill_mask",
                    "ner",
                    "pass_through",
                    "question_answering",
                    "text_classification",
                    "text_embedding",
                    "text_expansion",
                    "text_similarity",
                    "zero_shot_classification"
                )
            );
        }

        /** Rejects NLP task type {@code taskType} if the minimum node version predates NLP support. */
        void checkNlpSupported(String taskType) {
            if (NlpConfig.MINIMUM_NLP_SUPPORTED_VERSION.after(minNodeVersion)) {
                throw ExceptionsHelper.badRequestException(
                    Messages.getMessage(CONFIG_NOT_SUPPORTED_ON_VERSION, taskType, NlpConfig.MINIMUM_NLP_SUPPORTED_VERSION, minNodeVersion)
                );
            }
        }

        /** Rejects {@code config} if it requires a newer ML config version than the minimum node version. */
        void checkSupportedVersion(InferenceConfig config) {
            if (config.getMinimalSupportedMlConfigVersion().after(minNodeVersion)) {
                throw ExceptionsHelper.badRequestException(
                    Messages.getMessage(
                        CONFIG_NOT_SUPPORTED_ON_VERSION,
                        config.getName(),
                        config.getMinimalSupportedMlConfigVersion(),
                        minNodeVersion
                    )
                );
            }
        }
    }
}

