/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.rescorer;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.rescorer.InferenceRescorer;
import org.elasticsearch.xpack.ml.inference.rescorer.InferenceRescorerContext;

/**
 * Builder for the {@code inference} rescorer, which rescores hits with a Learn To Rank trained model
 * identified by {@code model_id}.
 * <p>
 * The builder is resolved in stages by {@link #rewrite(QueryRewriteContext)}:
 * <ol>
 *   <li>Default rewrite (neither a data rewrite nor a search execution context): fetches the model's
 *       {@link TrainedModelConfig} via {@link GetTrainedModelsAction}, applies the optional
 *       {@link LearnToRankConfigUpdate}, validates every feature extractor and keeps the resulting
 *       {@link LearnToRankConfig}.</li>
 *   <li>Data rewrite context: loads the {@link LocalModel} through the {@link ModelLoadingService}.</li>
 *   <li>Search execution context: rewrites the resolved {@link LearnToRankConfig} against that context.</li>
 * </ol>
 * Fetching and loading are registered as asynchronous actions on the rewrite context. Their results are
 * published through suppliers held by intermediate builders, and such intermediate builders cannot be
 * serialized (see {@link #doWriteTo(StreamOutput)}).
 */
public class InferenceRescorerBuilder extends RescorerBuilder<InferenceRescorerBuilder> {
    public static final String NAME = "inference";
    private static final ParseField MODEL = new ParseField("model_id");
    private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
    private static final ParseField INTERNAL_INFERENCE_CONFIG = new ParseField("_internal_inference_config");
    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, false, Builder::new);

    static {
        PARSER.declareString(Builder::setModelId, MODEL);
        PARSER.declareNamedObject(
            Builder::setInferenceConfigUpdate,
            (p, c, name) -> p.namedObject(InferenceConfigUpdate.class, name, false),
            INFERENCE_CONFIG
        );
        PARSER.declareNamedObject(
            Builder::setInferenceConfig,
            (p, c, name) -> p.namedObject(StrictlyParsedInferenceConfig.class, name, false),
            INTERNAL_INFERENCE_CONFIG
        );
    }

    private final String modelId;
    /** User-supplied update applied to the stored model config during the default rewrite; may be null. */
    private final LearnToRankConfigUpdate inferenceConfigUpdate;
    /** Fully resolved config; null until the default rewrite has completed. */
    private final LearnToRankConfig inferenceConfig;
    /** Loaded model; only set once the data rewrite has completed. */
    private final LocalModel inferenceDefinition;
    /** Yields the model loaded by a pending async action, or null while it is still running. */
    private final Supplier<LocalModel> inferenceDefinitionSupplier;
    private final Supplier<ModelLoadingService> modelLoadingServiceSupplier;
    /** Yields the config fetched by a pending async action, or null while it is still running. */
    private final Supplier<LearnToRankConfig> inferenceConfigSupplier;
    /** Set by {@link #innerBuildContext}; only consulted by an assertion in {@link #doWriteTo}. */
    private boolean rescoreOccurred;

    /**
     * Parses an {@code inference} rescorer definition.
     *
     * @param parser the parser positioned on the rescorer object
     * @param modelLoadingServiceSupplier supplier of the service used to load the model during the data rewrite
     * @return the parsed builder
     */
    public static InferenceRescorerBuilder fromXContent(XContentParser parser, Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
        return PARSER.apply(parser, null).build(modelLoadingServiceSupplier);
    }

    /**
     * Creates an unresolved builder. The model's stored config is fetched during the default rewrite and
     * {@code inferenceConfigUpdate}, when non-null, is applied to it.
     *
     * @throws NullPointerException if {@code modelId} is null
     */
    public InferenceRescorerBuilder(
        String modelId,
        LearnToRankConfigUpdate inferenceConfigUpdate,
        Supplier<ModelLoadingService> modelLoadingServiceSupplier
    ) {
        this.modelId = Objects.requireNonNull(modelId);
        this.inferenceConfigUpdate = inferenceConfigUpdate;
        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
        this.inferenceDefinition = null;
        this.inferenceDefinitionSupplier = null;
        this.inferenceConfigSupplier = null;
        this.inferenceConfig = null;
    }

    /**
     * Creates a builder whose config is already resolved; the model itself still has to be loaded.
     *
     * @throws NullPointerException if {@code modelId} or {@code inferenceConfig} is null
     */
    InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
        this.modelId = Objects.requireNonNull(modelId);
        this.inferenceConfigUpdate = null;
        this.inferenceDefinition = null;
        this.inferenceDefinitionSupplier = null;
        this.inferenceConfigSupplier = null;
        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
        this.inferenceConfig = Objects.requireNonNull(inferenceConfig);
    }

    /** Intermediate builder waiting for the asynchronously fetched config exposed by {@code inferenceConfigSupplier}. */
    private InferenceRescorerBuilder(
        String modelId,
        LearnToRankConfigUpdate update,
        Supplier<ModelLoadingService> modelLoadingServiceSupplier,
        Supplier<LearnToRankConfig> inferenceConfigSupplier
    ) {
        this.modelId = Objects.requireNonNull(modelId);
        this.inferenceConfigUpdate = update;
        this.inferenceDefinition = null;
        this.inferenceDefinitionSupplier = null;
        this.inferenceConfigSupplier = inferenceConfigSupplier;
        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
        this.inferenceConfig = null;
    }

    /** Intermediate builder waiting for the asynchronously loaded model exposed by {@code inferenceDefinitionSupplier}. */
    private InferenceRescorerBuilder(
        String modelId,
        LearnToRankConfig inferenceConfig,
        Supplier<ModelLoadingService> modelLoadingServiceSupplier,
        Supplier<LocalModel> inferenceDefinitionSupplier
    ) {
        this.modelId = Objects.requireNonNull(modelId);
        this.inferenceConfigUpdate = null;
        this.inferenceDefinition = null;
        this.inferenceDefinitionSupplier = inferenceDefinitionSupplier;
        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
        this.inferenceConfigSupplier = null;
        this.inferenceConfig = inferenceConfig;
    }

    /**
     * Creates a fully resolved builder holding both the config and the loaded model.
     *
     * @throws NullPointerException if {@code modelId} is null
     */
    InferenceRescorerBuilder(String modelId, LearnToRankConfig inferenceConfig, LocalModel inferenceDefinition) {
        this.modelId = Objects.requireNonNull(modelId);
        this.inferenceConfigUpdate = null;
        this.inferenceDefinition = inferenceDefinition;
        this.inferenceDefinitionSupplier = null;
        this.modelLoadingServiceSupplier = null;
        this.inferenceConfigSupplier = null;
        this.inferenceConfig = inferenceConfig;
    }

    /**
     * Reads a builder from the wire. Reads the fields in the order {@link #doWriteTo(StreamOutput)} writes them:
     * model id, optional config update, optional config.
     */
    public InferenceRescorerBuilder(StreamInput input, Supplier<ModelLoadingService> modelLoadingServiceSupplier) throws IOException {
        super(input);
        this.modelId = input.readString();
        this.inferenceConfigUpdate = (LearnToRankConfigUpdate) input.readOptionalNamedWriteable(InferenceConfigUpdate.class);
        this.inferenceDefinitionSupplier = null;
        this.inferenceConfigSupplier = null;
        this.inferenceDefinition = null;
        this.inferenceConfig = (LearnToRankConfig) input.readOptionalNamedWriteable(InferenceConfig.class);
        this.modelLoadingServiceSupplier = modelLoadingServiceSupplier;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Returns {@link TransportVersion#current()}.
     * NOTE(review): presumably this limits the builder to nodes on the current transport version; confirm.
     */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersion.current();
    }

    /**
     * Copies this builder's window size, if set, onto {@code target}.
     *
     * @return {@code target}, for chaining
     */
    private InferenceRescorerBuilder copyWindowSizeTo(InferenceRescorerBuilder target) {
        if (this.windowSize != null) {
            target.windowSize(this.windowSize);
        }
        return target;
    }

    /** Default rewrite stage: resolves the {@link LearnToRankConfig}, fetching it from the trained model if needed. */
    private RescorerBuilder<InferenceRescorerBuilder> doRewrite(QueryRewriteContext ctx) throws IOException {
        if (this.inferenceConfigSupplier != null) {
            LearnToRankConfig fetchedConfig = this.inferenceConfigSupplier.get();
            if (fetchedConfig == null) {
                // The async fetch registered by an earlier rewrite has not produced a config yet.
                return this;
            }
            LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(fetchedConfig, ctx);
            return copyWindowSizeTo(new InferenceRescorerBuilder(this.modelId, rewrittenConfig, this.modelLoadingServiceSupplier));
        }
        if (this.inferenceConfig != null) {
            LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(this.inferenceConfig, ctx);
            if (rewrittenConfig == this.inferenceConfig) {
                return this;
            }
            return copyWindowSizeTo(new InferenceRescorerBuilder(this.modelId, rewrittenConfig, this.modelLoadingServiceSupplier));
        }

        // Neither a config nor a pending fetch: register an async fetch of the trained model config.
        SetOnce<LearnToRankConfig> configSetOnce = new SetOnce<>();
        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(this.modelId);
        request.setAllowNoResources(false);
        ctx.registerAsyncAction(
            (client, listener) -> ClientHelper.executeAsyncWithOrigin(
                client,
                ClientHelper.ML_ORIGIN,
                GetTrainedModelsAction.INSTANCE,
                request,
                ActionListener.wrap(trainedModels -> {
                    TrainedModelConfig config = trainedModels.getResources().results().get(0);
                    InferenceConfig storedConfig = config.getInferenceConfig();
                    if (!(storedConfig instanceof LearnToRankConfig)) {
                        listener.onFailure(
                            ExceptionsHelper.badRequestException(
                                Messages.getMessage(
                                    "Inference config of type [{0}] is invalid, must be of type [{1}]",
                                    Optional.ofNullable(storedConfig).map(NamedXContentObject::getName).orElse("null"),
                                    LearnToRankConfig.NAME.getPreferredName()
                                )
                            )
                        );
                        return;
                    }
                    LearnToRankConfig resolvedConfig = this.inferenceConfigUpdate == null
                        ? (LearnToRankConfig) storedConfig
                        : this.inferenceConfigUpdate.apply(storedConfig);
                    for (LearnToRankFeatureExtractorBuilder extractorBuilder : resolvedConfig.getFeatureExtractorBuilders()) {
                        extractorBuilder.validate();
                    }
                    configSetOnce.set(resolvedConfig);
                    listener.onResponse(null);
                }, listener::onFailure)
            )
        );
        return copyWindowSizeTo(
            new InferenceRescorerBuilder(this.modelId, this.inferenceConfigUpdate, this.modelLoadingServiceSupplier, configSetOnce::get)
        );
    }

    /**
     * Data rewrite stage: loads the {@link LocalModel} for the already resolved config.
     *
     * @throws IllegalStateException if no model loading service is available when a load must be started
     */
    private RescorerBuilder<InferenceRescorerBuilder> doDataNodeRewrite(QueryRewriteContext ctx) {
        assert this.inferenceConfig != null;
        if (this.inferenceDefinition != null) {
            return this;
        }
        if (this.inferenceDefinitionSupplier != null) {
            LocalModel loadedModel = this.inferenceDefinitionSupplier.get();
            if (loadedModel == null) {
                // The async load registered by an earlier rewrite has not produced a model yet.
                return this;
            }
            return copyWindowSizeTo(new InferenceRescorerBuilder(this.modelId, this.inferenceConfig, loadedModel));
        }
        if (this.modelLoadingServiceSupplier == null || this.modelLoadingServiceSupplier.get() == null) {
            throw new IllegalStateException("Model loading service must be available");
        }
        SetOnce<LocalModel> inferenceDefinitionSetOnce = new SetOnce<>();
        ctx.registerAsyncAction(
            (client, listener) -> this.modelLoadingServiceSupplier.get()
                .getModelForLearnToRank(this.modelId, ActionListener.wrap(localModel -> {
                    inferenceDefinitionSetOnce.set(localModel);
                    listener.onResponse(null);
                }, listener::onFailure))
        );
        return copyWindowSizeTo(
            new InferenceRescorerBuilder(this.modelId, this.inferenceConfig, this.modelLoadingServiceSupplier, inferenceDefinitionSetOnce::get)
        );
    }

    /**
     * Search execution rewrite stage: rewrites the resolved config against the search execution context.
     * Any already loaded model is kept.
     */
    private RescorerBuilder<InferenceRescorerBuilder> doSearchRewrite(QueryRewriteContext ctx) throws IOException {
        if (this.inferenceConfig == null) {
            return this;
        }
        LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(this.inferenceConfig, ctx);
        if (rewrittenConfig == this.inferenceConfig) {
            return this;
        }
        InferenceRescorerBuilder builder = this.inferenceDefinition == null
            ? new InferenceRescorerBuilder(this.modelId, rewrittenConfig, this.modelLoadingServiceSupplier)
            : new InferenceRescorerBuilder(this.modelId, rewrittenConfig, this.inferenceDefinition);
        return copyWindowSizeTo(builder);
    }

    /** Dispatches to the rewrite stage that matches the kind of context supplied. */
    @Override
    public RescorerBuilder<InferenceRescorerBuilder> rewrite(QueryRewriteContext ctx) throws IOException {
        if (ctx.convertToDataRewriteContext() != null) {
            return this.doDataNodeRewrite(ctx);
        }
        if (ctx.convertToSearchExecutionContext() != null) {
            return this.doSearchRewrite(ctx);
        }
        return this.doRewrite(ctx);
    }

    public String getModelId() {
        return this.modelId;
    }

    LearnToRankConfig getInferenceConfig() {
        return this.inferenceConfig;
    }

    /**
     * Writes the model id, the optional config update and the optional resolved config.
     * The loaded model is never serialized.
     *
     * @throws IllegalStateException if this is an intermediate builder still holding a supplier
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        if (this.inferenceDefinitionSupplier != null || this.inferenceConfigSupplier != null) {
            throw new IllegalStateException("suppliers must be null, missing a rewriteAndFetch?");
        }
        assert this.inferenceDefinition == null || this.rescoreOccurred : "Unnecessarily populated local model object";
        out.writeString(this.modelId);
        out.writeOptionalNamedWriteable(this.inferenceConfigUpdate);
        out.writeOptionalNamedWriteable(this.inferenceConfig);
    }

    /**
     * Renders the builder. The user-facing update is written under {@code inference_config}; a resolved config
     * is written under {@code _internal_inference_config}.
     */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(MODEL.getPreferredName(), this.modelId);
        if (this.inferenceConfigUpdate != null) {
            NamedXContentObjectHelper.writeNamedObject(builder, params, INFERENCE_CONFIG.getPreferredName(), this.inferenceConfigUpdate);
        }
        if (this.inferenceConfig != null) {
            NamedXContentObjectHelper.writeNamedObject(builder, params, INTERNAL_INFERENCE_CONFIG.getPreferredName(), this.inferenceConfig);
        }
        builder.endObject();
    }

    @Override
    protected InferenceRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) {
        this.rescoreOccurred = true;
        return new InferenceRescorerContext(windowSize, InferenceRescorer.INSTANCE, this.inferenceConfig, this.inferenceDefinition, context);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        InferenceRescorerBuilder that = (InferenceRescorerBuilder) o;
        // Both suppliers are compared: a builder awaiting an async result is a different stage than its source.
        return Objects.equals(this.modelId, that.modelId)
            && Objects.equals(this.inferenceDefinition, that.inferenceDefinition)
            && Objects.equals(this.inferenceConfigUpdate, that.inferenceConfigUpdate)
            && Objects.equals(this.inferenceConfig, that.inferenceConfig)
            && Objects.equals(this.inferenceDefinitionSupplier, that.inferenceDefinitionSupplier)
            && Objects.equals(this.inferenceConfigSupplier, that.inferenceConfigSupplier)
            && Objects.equals(this.modelLoadingServiceSupplier, that.modelLoadingServiceSupplier);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            super.hashCode(),
            this.modelId,
            this.inferenceConfigUpdate,
            this.inferenceConfig,
            this.inferenceDefinition,
            this.inferenceDefinitionSupplier,
            this.inferenceConfigSupplier,
            this.modelLoadingServiceSupplier
        );
    }

    LearnToRankConfigUpdate getInferenceConfigUpdate() {
        return this.inferenceConfigUpdate;
    }

    Supplier<ModelLoadingService> modelLoadingServiceSupplier() {
        return this.modelLoadingServiceSupplier;
    }

    LocalModel getInferenceDefinition() {
        return this.inferenceDefinition;
    }

    /** Mutable accumulator populated by {@link #PARSER}. */
    static class Builder {
        private String modelId;
        private LearnToRankConfigUpdate inferenceConfigUpdate;
        private LearnToRankConfig inferenceConfig;

        Builder() {
        }

        public void setModelId(String modelId) {
            this.modelId = modelId;
        }

        /**
         * @throws IllegalArgumentException if the update is not a {@link LearnToRankConfigUpdate}
         */
        public void setInferenceConfigUpdate(InferenceConfigUpdate inferenceConfigUpdate) {
            if (inferenceConfigUpdate instanceof LearnToRankConfigUpdate) {
                this.inferenceConfigUpdate = (LearnToRankConfigUpdate) inferenceConfigUpdate;
                return;
            }
            throw new IllegalArgumentException(
                Strings.format(
                    "[%s] only allows a [%s] object to be configured",
                    INFERENCE_CONFIG.getPreferredName(),
                    LearnToRankConfigUpdate.NAME.getPreferredName()
                )
            );
        }

        /**
         * @throws IllegalArgumentException if the config is not a {@link LearnToRankConfig}
         */
        void setInferenceConfig(InferenceConfig inferenceConfig) {
            if (inferenceConfig instanceof LearnToRankConfig) {
                this.inferenceConfig = (LearnToRankConfig) inferenceConfig;
                return;
            }
            // Report the field actually being parsed and the type it requires.
            throw new IllegalArgumentException(
                Strings.format(
                    "[%s] only allows a [%s] object to be configured",
                    INTERNAL_INFERENCE_CONFIG.getPreferredName(),
                    LearnToRankConfig.NAME.getPreferredName()
                )
            );
        }

        /** Builds a resolved builder if an internal config was parsed, otherwise an unresolved one. */
        InferenceRescorerBuilder build(Supplier<ModelLoadingService> modelLoadingServiceSupplier) {
            assert this.inferenceConfig == null || this.inferenceConfigUpdate == null;
            if (this.inferenceConfig != null) {
                return new InferenceRescorerBuilder(this.modelId, this.inferenceConfig, modelLoadingServiceSupplier);
            }
            return new InferenceRescorerBuilder(this.modelId, this.inferenceConfigUpdate, modelLoadingServiceSupplier);
        }
    }
}

