/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.inference;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction;
import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest;
import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
import org.elasticsearch.xpack.core.security.support.Exceptions;
import org.elasticsearch.xpack.ml.aggs.inference.InferencePipelineAggregator;
import org.elasticsearch.xpack.ml.aggs.inference.InternalInferenceAggregation;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils;

/**
 * Builder for the {@code inference} pipeline aggregation.
 * <p>
 * The aggregation is configured with a {@code model_id}, a {@code buckets_path} map and an optional
 * {@code inference_config} update. Builders produced by the parser or by stream deserialization do not hold a
 * model. {@link #rewrite(QueryRewriteContext)} registers an asynchronous action that loads the model and returns a
 * new builder holding a supplier for it. {@link #createInternal(Map)} may only be called on such a rewritten builder.
 */
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {

    public static final String NAME = "inference";

    public static final ParseField MODEL_ID = new ParseField("model_id");
    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");

    /** The only results field that {@link #validate} accepts; {@link #adaptForAggregation} forces it onto every update. */
    static final String AGGREGATIONS_RESULTS_FIELD = "value";

    // args[0] is the buckets_path object, the sole constructor argument, parsed with p.mapStrings()
    // in the static block at the bottom of this class.
    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, ParserSupplement> PARSER =
        new ConstructingObjectParser<>(
            NAME,
            false,
            (args, context) -> new InferencePipelineAggregationBuilder(
                context.name,
                context.modelLoadingService,
                context.licenseState,
                context.settings,
                (Map<String, String>) args[0]
            )
        );

    /** The user-supplied {@code buckets_path} map; serialized as-is and handed to the aggregator. */
    private final Map<String, String> bucketPathMap;
    private String modelId;
    private InferenceConfigUpdate inferenceConfig;
    private final XPackLicenseState licenseState;
    private final Settings settings;
    /** Used by {@link #rewrite} to load the model; {@code null} on builders created by a rewrite. */
    private final SetOnce<ModelLoadingService> modelLoadingService;
    /** Supplies the loaded model; {@code null} until this builder has been produced by {@link #rewrite}. */
    private final Supplier<LocalModel> model;

    /**
     * Builds the plugin spec registering this aggregation's stream reader, XContent parser and result reader.
     *
     * @param modelLoadingService holder for the service used to load models during rewrite
     * @param xPackLicenseState   license state consulted when the model is loaded
     * @param settings            node settings (used to decide whether security checks run)
     * @return the pipeline aggregation spec for {@link #NAME}
     */
    public static SearchPlugin.PipelineAggregationSpec buildSpec(SetOnce<ModelLoadingService> modelLoadingService, XPackLicenseState xPackLicenseState, Settings settings) {
        SearchPlugin.PipelineAggregationSpec spec = new SearchPlugin.PipelineAggregationSpec(
            NAME,
            in -> new InferencePipelineAggregationBuilder(in, xPackLicenseState, settings, modelLoadingService),
            (parser, name) -> parse(modelLoadingService, xPackLicenseState, settings, name, parser)
        );
        spec.addResultReader(InternalInferenceAggregation::new);
        return spec;
    }

    /**
     * Parses an {@code inference} aggregation definition.
     *
     * @param pipelineAggregatorName the name the user gave this aggregation
     * @param parser                 parser positioned on the aggregation body
     * @return a builder that still has to be rewritten before it can create an aggregator
     */
    public static InferencePipelineAggregationBuilder parse(SetOnce<ModelLoadingService> modelLoadingService, XPackLicenseState licenseState, Settings settings, String pipelineAggregatorName, XContentParser parser) {
        return PARSER.apply(parser, new ParserSupplement(pipelineAggregatorName, licenseState, settings, modelLoadingService));
    }

    /**
     * Creates an un-rewritten builder. The values of {@code bucketsPath} are passed to the superclass ordered by key.
     */
    public InferencePipelineAggregationBuilder(String name, SetOnce<ModelLoadingService> modelLoadingService, XPackLicenseState licenseState, Settings settings, Map<String, String> bucketsPath) {
        super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[0]));
        this.modelLoadingService = modelLoadingService;
        this.bucketPathMap = bucketsPath;
        this.model = null;
        this.licenseState = licenseState;
        this.settings = settings;
    }

    /**
     * Reads an un-rewritten builder from the stream; the field order mirrors {@link #doWriteTo(StreamOutput)}.
     */
    public InferencePipelineAggregationBuilder(StreamInput in, XPackLicenseState licenseState, Settings settings, SetOnce<ModelLoadingService> modelLoadingService) throws IOException {
        super(in, NAME);
        this.modelId = in.readString();
        this.bucketPathMap = in.readMap(StreamInput::readString, StreamInput::readString);
        this.inferenceConfig = in.readOptionalNamedWriteable(InferenceConfigUpdate.class);
        this.modelLoadingService = modelLoadingService;
        this.model = null;
        this.licenseState = licenseState;
        this.settings = settings;
    }

    /**
     * Creates the rewritten form of a builder, which carries a model supplier instead of a loading service.
     */
    private InferencePipelineAggregationBuilder(String name, Map<String, String> bucketsPath, Supplier<LocalModel> model, String modelId, InferenceConfigUpdate inferenceConfig, XPackLicenseState licenseState, Settings settings) {
        super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[0]));
        this.modelLoadingService = null;
        this.bucketPathMap = bucketsPath;
        this.model = model;
        this.modelId = modelId;
        this.inferenceConfig = inferenceConfig;
        this.licenseState = licenseState;
        this.settings = settings;
    }

    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    public void setInferenceConfig(InferenceConfigUpdate inferenceConfig) {
        this.inferenceConfig = inferenceConfig;
    }

    /**
     * Checks that the aggregation has a parent and a model id. If an inference config update is supplied, it also
     * checks that the update does not redirect the results field (anything other than {@code value}) or, for
     * classification, the top-classes field (anything other than {@code top_classes}).
     */
    @Override
    protected void validate(PipelineAggregationBuilder.ValidationContext context) {
        context.validateHasParent(NAME, name);
        if (modelId == null) {
            context.addValidationError("[model_id] must be set");
        }
        if (inferenceConfig == null) {
            return;
        }
        String resultsField = inferenceConfig.getResultsField();
        if (!Strings.isNullOrEmpty(resultsField) && !AGGREGATIONS_RESULTS_FIELD.equals(resultsField)) {
            context.addValidationError(
                "setting option [" + ClassificationConfig.RESULTS_FIELD.getPreferredName() + "] to [" + resultsField + "] is not valid for inference aggregations"
            );
        }
        if (inferenceConfig instanceof ClassificationConfigUpdate classUpdate) {
            String topClassesField = classUpdate.getTopClassesResultsField();
            if (!Strings.isNullOrEmpty(topClassesField) && !"top_classes".equals(topClassesField)) {
                context.addValidationError("setting option [top_classes] to [" + topClassesField + "] is not valid for inference aggregations");
            }
        }
    }

    // The model itself is never serialized: the receiving side gets an un-rewritten builder.
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(modelId);
        out.writeMap(bucketPathMap, StreamOutput::writeString, StreamOutput::writeString);
        out.writeOptionalNamedWriteable(inferenceConfig);
    }

    /**
     * Schedules asynchronous loading of the model and returns a builder that reads it once loaded.
     * <p>
     * A builder that already holds a model supplier is returned unchanged. Otherwise the registered async action:
     * <ol>
     *   <li>if {@code xpack.security.enabled} is set, runs inside
     *   {@code SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable}. There it asks whether the current user has the
     *   {@code cluster:monitor/xpack/ml/inference/get} cluster privilege, and fails with an authorization error if
     *   not;</li>
     *   <li>loads the model through {@code ModelLoadingService#getModelForSearch} and records it;</li>
     *   <li>fails with a license compliance exception unless the model's license level is {@code BASIC} or the
     *   ML API feature check passes for the current license state.</li>
     * </ol>
     *
     * @return the rewritten builder, whose model supplier returns {@code null} until the async action has stored a model
     */
    @Override
    public InferencePipelineAggregationBuilder rewrite(QueryRewriteContext context) {
        if (model != null) {
            return this;
        }
        SetOnce<LocalModel> loadedModel = new SetOnce<>();
        BiConsumer<Client, ActionListener<?>> modelLoadAction = (client, listener) -> modelLoadingService.get()
            .getModelForSearch(modelId, listener.delegateFailure((delegate, localModel) -> {
                // The model is recorded even when the license check below fails.
                loadedModel.set(localModel);
                boolean isLicensed = localModel.getLicenseLevel() == License.OperationMode.BASIC
                    || MachineLearningField.ML_API_FEATURE.check(licenseState);
                if (isLicensed) {
                    delegate.onResponse(null);
                } else {
                    delegate.onFailure(LicenseUtils.newComplianceException("ml"));
                }
            }));
        context.registerAsyncAction((client, listener) -> {
            if (XPackSettings.SECURITY_ENABLED.get(settings)) {
                SecurityContext securityContext = new SecurityContext(settings, client.threadPool().getThreadContext());
                SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable(securityContext, () -> {
                    String username = securityContext.getUser().principal();
                    HasPrivilegesRequest privRequest = new HasPrivilegesRequest();
                    privRequest.username(username);
                    privRequest.clusterPrivileges(new String[] { "cluster:monitor/xpack/ml/inference/get" });
                    privRequest.indexPrivileges(new RoleDescriptor.IndicesPrivileges[0]);
                    privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[0]);
                    // Response type is inferred from HasPrivilegesAction.INSTANCE, giving a typed listener.
                    client.execute(HasPrivilegesAction.INSTANCE, privRequest, ActionListener.wrap(response -> {
                        if (response.isCompleteMatch()) {
                            modelLoadAction.accept(client, listener);
                        } else {
                            listener.onFailure(
                                Exceptions.authorizationError(
                                    "user [" + username + "] does not have the privilege to get trained models so cannot use ml inference"
                                )
                            );
                        }
                    }, listener::onFailure));
                });
            } else {
                modelLoadAction.accept(client, listener);
            }
        });
        return new InferencePipelineAggregationBuilder(name, bucketPathMap, loadedModel::get, modelId, inferenceConfig, licenseState, settings);
    }

    /**
     * Creates the aggregator using the loaded model and the aggregation-adapted config update.
     *
     * @throws IllegalStateException if this builder was not produced by {@link #rewrite}
     */
    @Override
    protected PipelineAggregator createInternal(Map<String, Object> metaData) {
        if (model == null) {
            throw new IllegalStateException("model must not be null, missing rewrite?");
        }
        InferenceConfigUpdate update = adaptForAggregation(inferenceConfig);
        return new InferencePipelineAggregator(name, bucketPathMap, metaData, update, model.get());
    }

    /**
     * Returns a config update whose results field is {@value #AGGREGATIONS_RESULTS_FIELD}.
     *
     * @param originalUpdate the user's update, or {@code null}
     * @return a {@code ResultsFieldUpdate} when {@code originalUpdate} is {@code null}; otherwise a copy of
     *         {@code originalUpdate} built with the results field overridden
     */
    static InferenceConfigUpdate adaptForAggregation(InferenceConfigUpdate originalUpdate) {
        if (originalUpdate == null) {
            return new ResultsFieldUpdate(AGGREGATIONS_RESULTS_FIELD);
        }
        return originalUpdate.newBuilder().setResultsField(AGGREGATIONS_RESULTS_FIELD).build();
    }

    // buckets_path is a map here and is rendered by internalXContent instead of the superclass.
    @Override
    protected boolean overrideBucketsPath() {
        return true;
    }

    @Override
    protected XContentBuilder internalXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(MODEL_ID.getPreferredName(), modelId);
        builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
        if (inferenceConfig != null) {
            builder.startObject(INFERENCE_CONFIG.getPreferredName());
            // Cast kept so the field(String, Object) overload is selected, as in the compiled original.
            builder.field(inferenceConfig.getName(), (Object) inferenceConfig);
            builder.endObject();
        }
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    // The model supplier and loading service are deliberately not part of equality.
    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        if (!super.equals(obj)) {
            return false;
        }
        InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
        return Objects.equals(bucketPathMap, other.bucketPathMap)
            && Objects.equals(modelId, other.modelId)
            && Objects.equals(inferenceConfig, other.inferenceConfig);
    }

    // Must follow the PARSER, MODEL_ID and INFERENCE_CONFIG initializers (static init runs in textual order).
    static {
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
        PARSER.declareString(InferencePipelineAggregationBuilder::setModelId, MODEL_ID);
        PARSER.declareNamedObject(
            InferencePipelineAggregationBuilder::setInferenceConfig,
            (p, c, n) -> p.namedObject(InferenceConfigUpdate.class, n, c),
            INFERENCE_CONFIG
        );
    }

    /** Context handed to {@link #PARSER}: the aggregation name plus the services the built builder needs. */
    private static class ParserSupplement {
        final XPackLicenseState licenseState;
        final Settings settings;
        final SetOnce<ModelLoadingService> modelLoadingService;
        final String name;

        ParserSupplement(String name, XPackLicenseState licenseState, Settings settings, SetOnce<ModelLoadingService> modelLoadingService) {
            this.name = name;
            this.licenseState = licenseState;
            this.settings = settings;
            this.modelLoadingService = modelLoadingService;
        }
    }
}

