/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.ltr;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorer;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankRescorerContext;
import org.elasticsearch.xpack.ml.inference.ltr.LearningToRankService;

/**
 * {@link RescorerBuilder} for the {@code learning_to_rank} rescorer.
 * <p>
 * A freshly parsed builder only carries a {@code model_id} and optional {@code params}. {@link #rewrite(QueryRewriteContext)}
 * completes it in stages, depending on the kind of context it is given:
 * <ul>
 *   <li>a plain {@link QueryRewriteContext} (coordinator rewrite): the model's {@link LearningToRankConfig} is loaded through the
 *       {@link LearningToRankService};</li>
 *   <li>a data rewrite context: a {@link LocalModel} for the model id is loaded;</li>
 *   <li>a {@link SearchExecutionContext}: the already-loaded config is rewritten once more.</li>
 * </ul>
 * Asynchronous loads are registered on the context and a placeholder {@link RewritingLearningToRankRescorerBuilder} is returned;
 * rewriting that placeholder again yields the completed builder once the load has finished.
 */
public class LearningToRankRescorerBuilder extends RescorerBuilder<LearningToRankRescorerBuilder> {
    public static final ParseField NAME = new ParseField("learning_to_rank");
    public static final ParseField MODEL_FIELD = new ParseField("model_id");
    public static final ParseField PARAMS_FIELD = new ParseField("params");

    /** Strict parser (unknown fields rejected); {@code model_id} is required, {@code params} is an optional object. */
    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), false, Builder::new);

    static {
        PARSER.declareString(Builder::setModelId, MODEL_FIELD);
        PARSER.declareObject(Builder::setParams, (p, c) -> p.map(), PARAMS_FIELD);
        PARSER.declareRequiredFieldSet(MODEL_FIELD.getPreferredName());
    }

    private final String modelId;
    private final Map<String, Object> params;
    private final LearningToRankService learningToRankService;
    private final LocalModel localModel;
    private final LearningToRankConfig learningToRankConfig;
    /** Set by {@link #innerBuildContext}; {@link #doWriteTo} asserts a populated local model is only written after this. */
    private boolean rescoreOccurred = false;

    /**
     * Parses a {@code learning_to_rank} rescorer definition.
     *
     * @param parser                parser positioned on the rescorer object
     * @param learningToRankService service used later to load the model config and local model
     * @return a builder holding only the model id and params
     */
    public static LearningToRankRescorerBuilder fromXContent(XContentParser parser, LearningToRankService learningToRankService) {
        return PARSER.apply(parser, null).build(learningToRankService);
    }

    LearningToRankRescorerBuilder(String modelId, Map<String, Object> params, LearningToRankService learningToRankService) {
        this(modelId, (LearningToRankConfig) null, params, learningToRankService);
    }

    LearningToRankRescorerBuilder(
        String modelId,
        LearningToRankConfig learningToRankConfig,
        Map<String, Object> params,
        LearningToRankService learningToRankService
    ) {
        this(modelId, null, learningToRankConfig, params, learningToRankService);
    }

    /** Builds from a loaded local model; the model id is taken from {@code localModel}, which must not be null. */
    LearningToRankRescorerBuilder(
        LocalModel localModel,
        LearningToRankConfig learningToRankConfig,
        Map<String, Object> params,
        LearningToRankService learningToRankService
    ) {
        this(localModel.getModelId(), localModel, learningToRankConfig, params, learningToRankService);
    }

    /** Canonical constructor; {@code localModel} may be null, in which case {@code modelId} is kept as given. */
    private LearningToRankRescorerBuilder(
        String modelId,
        LocalModel localModel,
        LearningToRankConfig learningToRankConfig,
        Map<String, Object> params,
        LearningToRankService learningToRankService
    ) {
        this.modelId = modelId;
        this.localModel = localModel;
        this.learningToRankConfig = learningToRankConfig;
        this.params = params;
        this.learningToRankService = learningToRankService;
    }

    /**
     * Reads a builder from the wire. The local model is never serialized, so it is always null after deserialization.
     */
    public LearningToRankRescorerBuilder(StreamInput input, LearningToRankService learningToRankService) throws IOException {
        super(input);
        this.modelId = input.readString();
        this.params = input.readGenericMap();
        this.learningToRankConfig = (LearningToRankConfig) input.readOptionalNamedWriteable(InferenceConfig.class);
        this.learningToRankService = learningToRankService;
        this.localModel = null;
    }

    public String modelId() {
        return modelId;
    }

    public Map<String, Object> params() {
        return params;
    }

    public LearningToRankConfig learningToRankConfig() {
        return learningToRankConfig;
    }

    public LearningToRankService learningToRankService() {
        return learningToRankService;
    }

    public LocalModel localModel() {
        return localModel;
    }

    /**
     * Dispatches to the rewrite stage matching the context type: data rewrite context first, then search execution
     * context, otherwise the coordinator rewrite.
     */
    @Override
    public RescorerBuilder<LearningToRankRescorerBuilder> rewrite(QueryRewriteContext ctx) throws IOException {
        if (ctx.convertToDataRewriteContext() != null) {
            return doDataNodeRewrite(ctx);
        }
        if (ctx.convertToSearchExecutionContext() != null) {
            return doSearchRewrite(ctx);
        }
        return doCoordinatorNodeRewrite(ctx);
    }

    /** Copies {@code windowSize} onto {@code target} when it is set, and returns {@code target}. */
    private static LearningToRankRescorerBuilder copyWindowSize(Integer windowSize, LearningToRankRescorerBuilder target) {
        if (windowSize != null) {
            target.windowSize(windowSize);
        }
        return target;
    }

    /**
     * Coordinator stage: loads the stored {@link LearningToRankConfig} for {@link #modelId}, or, if it is already present,
     * rewrites it further.
     *
     * @return {@code this} if nothing changed, a new builder with the rewritten config, or a placeholder awaiting the async load
     * @throws IllegalStateException if the config must be loaded but no {@link LearningToRankService} is available
     */
    private RescorerBuilder<LearningToRankRescorerBuilder> doCoordinatorNodeRewrite(QueryRewriteContext ctx) throws IOException {
        if (learningToRankConfig != null) {
            // The config has already been fetched; only produce a new builder if rewriting actually changed it.
            LearningToRankConfig rewrittenConfig = Rewriteable.rewrite(learningToRankConfig, ctx);
            if (rewrittenConfig == learningToRankConfig) {
                return this;
            }
            return copyWindowSize(windowSize, new LearningToRankRescorerBuilder(modelId, rewrittenConfig, params, learningToRankService));
        }
        if (learningToRankService == null) {
            throw new IllegalStateException("Learning to rank service must be available");
        }
        SetOnce<LearningToRankConfig> configSetOnce = new SetOnce<>();
        ctx.registerAsyncAction(
            (client, listener) -> learningToRankService.loadLearningToRankConfig(modelId, params, ActionListener.wrap(config -> {
                configSetOnce.set(config);
                listener.onResponse(null);
            }, listener::onFailure))
        );
        // Until the config has been loaded the placeholder rewrites to itself.
        RewritingLearningToRankRescorerBuilder placeholder = new RewritingLearningToRankRescorerBuilder(
            rewritingBuilder -> configSetOnce.get() == null
                ? rewritingBuilder
                : new LearningToRankRescorerBuilder(modelId, configSetOnce.get(), params, learningToRankService)
        );
        copyWindowSize(windowSize(), placeholder);
        return placeholder;
    }

    /**
     * Data node stage: rewrites the config and loads the {@link LocalModel} for {@link #modelId}.
     * Expects the coordinator stage to have populated the config already.
     *
     * @return {@code this} if the local model is already loaded, otherwise a placeholder awaiting the async load
     * @throws IllegalStateException if no {@link LearningToRankService} is available
     */
    private RescorerBuilder<LearningToRankRescorerBuilder> doDataNodeRewrite(QueryRewriteContext ctx) throws IOException {
        assert learningToRankConfig != null;
        if (localModel != null) {
            return this;
        }
        if (learningToRankService == null) {
            throw new IllegalStateException("Learning to rank service must be available");
        }
        LearningToRankConfig rewrittenConfig = Rewriteable.rewrite(learningToRankConfig, ctx);
        SetOnce<LocalModel> localModelSetOnce = new SetOnce<>();
        ctx.registerAsyncAction((client, listener) -> learningToRankService.loadLocalModel(modelId, ActionListener.wrap(model -> {
            localModelSetOnce.set(model);
            listener.onResponse(null);
        }, listener::onFailure)));
        // Until the local model has been loaded the placeholder rewrites to itself.
        RewritingLearningToRankRescorerBuilder placeholder = new RewritingLearningToRankRescorerBuilder(
            rewritingBuilder -> localModelSetOnce.get() == null
                ? rewritingBuilder
                : new LearningToRankRescorerBuilder(localModelSetOnce.get(), rewrittenConfig, params, learningToRankService)
        );
        copyWindowSize(windowSize(), placeholder);
        return placeholder;
    }

    /**
     * Search execution stage: gives an already-present config one more chance to rewrite.
     *
     * @return {@code this} if there is no config or it did not change, otherwise a new builder with the rewritten config
     */
    private RescorerBuilder<LearningToRankRescorerBuilder> doSearchRewrite(QueryRewriteContext ctx) throws IOException {
        if (learningToRankConfig == null) {
            return this;
        }
        LearningToRankConfig rewrittenConfig = Rewriteable.rewrite(learningToRankConfig, ctx);
        if (rewrittenConfig == learningToRankConfig) {
            return this;
        }
        // Use the canonical constructor: the LocalModel constructor would NPE when the local model is not loaded yet.
        // When it is loaded, modelId already equals localModel.getModelId(), so the result is unchanged.
        return copyWindowSize(
            windowSize,
            new LearningToRankRescorerBuilder(modelId, localModel, rewrittenConfig, params, learningToRankService)
        );
    }

    @Override
    protected LearningToRankRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) {
        rescoreOccurred = true;
        return new LearningToRankRescorerContext(windowSize, LearningToRankRescorer.INSTANCE, learningToRankConfig, localModel, context);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersion.current();
    }

    @Override
    protected boolean isWindowSizeRequired() {
        return true;
    }

    /** Writes model id, params and optional config; the local model is intentionally not serialized. */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        assert localModel == null || rescoreOccurred : "Unnecessarily populated local model object";
        out.writeString(modelId);
        out.writeGenericMap(params);
        out.writeOptionalNamedWriteable(learningToRankConfig);
    }

    /** Renders only the user-facing fields ({@code model_id} and, if set, {@code params}). */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME.getPreferredName());
        builder.field(MODEL_FIELD.getPreferredName(), modelId);
        // The method parameter shadows the field, hence the explicit this.params.
        if (this.params != null) {
            builder.field(PARAMS_FIELD.getPreferredName(), this.params);
        }
        builder.endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        if (super.equals(o) == false) {
            return false;
        }
        LearningToRankRescorerBuilder that = (LearningToRankRescorerBuilder) o;
        return Objects.equals(modelId, that.modelId)
            && Objects.equals(params, that.params)
            && Objects.equals(learningToRankConfig, that.learningToRankConfig)
            && Objects.equals(localModel, that.localModel)
            && Objects.equals(learningToRankService, that.learningToRankService)
            && rescoreOccurred == that.rescoreOccurred;
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), modelId, params, learningToRankConfig, localModel, learningToRankService, rescoreOccurred);
    }

    /** Mutable accumulator used by {@link #PARSER}. */
    static class Builder {
        private String modelId;
        private Map<String, Object> params = null;

        Builder() {}

        public void setModelId(String modelId) {
            this.modelId = modelId;
        }

        public void setParams(Map<String, Object> params) {
            this.params = params;
        }

        LearningToRankRescorerBuilder build(LearningToRankService learningToRankService) {
            return new LearningToRankRescorerBuilder(modelId, params, learningToRankService);
        }
    }

    /**
     * Placeholder returned while an async load registered on the rewrite context is pending. Rewriting it applies
     * {@code rewriteFunction} to itself and copies its window size onto the result.
     */
    private static class RewritingLearningToRankRescorerBuilder extends LearningToRankRescorerBuilder {
        private final Function<RewritingLearningToRankRescorerBuilder, LearningToRankRescorerBuilder> rewriteFunction;

        RewritingLearningToRankRescorerBuilder(Function<RewritingLearningToRankRescorerBuilder, LearningToRankRescorerBuilder> rewriteFunction) {
            super(null, null, null);
            this.rewriteFunction = rewriteFunction;
        }

        @Override
        public RescorerBuilder<LearningToRankRescorerBuilder> rewrite(QueryRewriteContext ctx) throws IOException {
            return copyWindowSize(windowSize(), rewriteFunction.apply(this));
        }
    }
}

