/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plan.logical;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiConsumer;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.index.IndexMode;
import org.elasticsearch.xpack.core.enrich.EnrichPolicy;
import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware;
import org.elasticsearch.xpack.esql.capabilities.TelemetryAware;
import org.elasticsearch.xpack.esql.common.Failure;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.capabilities.Resolvables;
import org.elasticsearch.xpack.esql.core.expression.Alias;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.AttributeSet;
import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute;
import org.elasticsearch.xpack.esql.core.expression.Expression;
import org.elasticsearch.xpack.esql.core.expression.Expressions;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.core.expression.NameId;
import org.elasticsearch.xpack.esql.core.expression.NamedExpression;
import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute;
import org.elasticsearch.xpack.esql.core.tree.Node;
import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
import org.elasticsearch.xpack.esql.core.tree.Source;
import org.elasticsearch.xpack.esql.expression.NamedExpressions;
import org.elasticsearch.xpack.esql.index.EsIndex;
import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
import org.elasticsearch.xpack.esql.plan.GeneratingPlan;
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.Fork;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.plan.logical.SortAgnostic;
import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan;
import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin;

public class Enrich
extends UnaryPlan
implements GeneratingPlan<Enrich>,
PostAnalysisPlanVerificationAware,
TelemetryAware,
SortAgnostic {
    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(LogicalPlan.class, "Enrich", Enrich::readFrom);
    private final Expression policyName;
    private final NamedExpression matchField;
    private final EnrichPolicy policy;
    private final Map<String, String> concreteIndices;
    private final List<NamedExpression> enrichFields;
    private List<Attribute> output;
    private final Mode mode;

    public Enrich(Source source, LogicalPlan child, Mode mode, Expression policyName, NamedExpression matchField, EnrichPolicy policy, Map<String, String> concreteIndices, List<NamedExpression> enrichFields) {
        super(source, child);
        this.mode = mode == null ? Mode.ANY : mode;
        this.policyName = policyName;
        this.matchField = matchField;
        this.policy = policy;
        this.concreteIndices = concreteIndices;
        this.enrichFields = enrichFields;
    }

    private static Enrich readFrom(StreamInput in) throws IOException {
        Map<String, String> concreteIndices;
        Mode mode = Mode.ANY;
        if (in.getTransportVersion().onOrAfter((VersionId)TransportVersions.V_8_13_0)) {
            mode = (Mode)in.readEnum(Mode.class);
        }
        Source source = Source.readFrom((StreamInput)((PlanStreamInput)in));
        LogicalPlan child = (LogicalPlan)in.readNamedWriteable(LogicalPlan.class);
        Expression policyName = (Expression)in.readNamedWriteable(Expression.class);
        NamedExpression matchField = (NamedExpression)in.readNamedWriteable(NamedExpression.class);
        if (in.getTransportVersion().before((VersionId)TransportVersions.V_8_13_0)) {
            in.readString();
        }
        EnrichPolicy policy = new EnrichPolicy(in);
        if (in.getTransportVersion().onOrAfter((VersionId)TransportVersions.V_8_13_0)) {
            concreteIndices = in.readMap(StreamInput::readString, StreamInput::readString);
        } else {
            EsIndex esIndex = EsIndex.readFrom(in);
            if (esIndex.concreteIndices().size() > 1) {
                throw new IllegalStateException("expected a single enrich index; got " + String.valueOf(esIndex));
            }
            concreteIndices = Map.of("", (String)Iterables.get(esIndex.concreteIndices(), (int)0));
        }
        return new Enrich(source, child, mode, policyName, matchField, policy, concreteIndices, in.readNamedWriteableCollectionAsList(NamedExpression.class));
    }

    public void writeTo(StreamOutput out) throws IOException {
        if (out.getTransportVersion().onOrAfter((VersionId)TransportVersions.V_8_13_0)) {
            out.writeEnum((Enum)this.mode());
        }
        Source.EMPTY.writeTo(out);
        out.writeNamedWriteable((NamedWriteable)this.child());
        out.writeNamedWriteable((NamedWriteable)this.policyName());
        out.writeNamedWriteable((NamedWriteable)this.matchField());
        if (out.getTransportVersion().before((VersionId)TransportVersions.V_8_13_0)) {
            out.writeString(BytesRefs.toString((Object)this.policyName().fold(FoldContext.small())));
        }
        this.policy().writeTo(out);
        if (out.getTransportVersion().onOrAfter((VersionId)TransportVersions.V_8_13_0)) {
            out.writeMap(this.concreteIndices(), StreamOutput::writeString, StreamOutput::writeString);
        } else {
            Map<String, String> concreteIndices = this.concreteIndices();
            if (concreteIndices.keySet().equals(Set.of(""))) {
                String enrichIndex = concreteIndices.get("");
                EsIndex esIndex = new EsIndex(enrichIndex, Map.of(), Map.of(enrichIndex, IndexMode.STANDARD));
                esIndex.writeTo(out);
            } else {
                throw new IllegalStateException("expected a single enrich index; got " + String.valueOf(concreteIndices));
            }
        }
        out.writeNamedWriteableCollection(this.enrichFields());
    }

    public String getWriteableName() {
        return Enrich.ENTRY.name;
    }

    public NamedExpression matchField() {
        return this.matchField;
    }

    public List<NamedExpression> enrichFields() {
        return this.enrichFields;
    }

    public EnrichPolicy policy() {
        return this.policy;
    }

    public Map<String, String> concreteIndices() {
        return this.concreteIndices;
    }

    public Expression policyName() {
        return this.policyName;
    }

    public Mode mode() {
        return this.mode;
    }

    @Override
    protected AttributeSet computeReferences() {
        return this.matchField.references();
    }

    @Override
    public boolean expressionsResolved() {
        return this.policyName.resolved() && !(this.matchField instanceof EmptyAttribute) && this.matchField.resolved() && Resolvables.resolved(this.enrichFields());
    }

    @Override
    public UnaryPlan replaceChild(LogicalPlan newChild) {
        return new Enrich(this.source(), newChild, this.mode, this.policyName, this.matchField, this.policy, this.concreteIndices, this.enrichFields);
    }

    protected NodeInfo<? extends LogicalPlan> info() {
        return NodeInfo.create((Node)this, Enrich::new, (Object)((Object)this.child()), (Object)((Object)this.mode), (Object)this.policyName, (Object)this.matchField, (Object)this.policy, this.concreteIndices, this.enrichFields);
    }

    @Override
    public List<Attribute> output() {
        if (this.enrichFields == null) {
            return this.child().output();
        }
        if (this.output == null) {
            this.output = NamedExpressions.mergeOutputAttributes(this.enrichFields(), this.child().output());
        }
        return this.output;
    }

    @Override
    public List<Attribute> generatedAttributes() {
        return Expressions.asAttributes(this.enrichFields);
    }

    @Override
    public Enrich withGeneratedNames(List<String> newNames) {
        this.checkNumberOfNewNames(newNames);
        ArrayList<NamedExpression> newEnrichFields = new ArrayList<NamedExpression>(this.enrichFields.size());
        for (int i = 0; i < this.enrichFields.size(); ++i) {
            NamedExpression enrichField = this.enrichFields.get(i);
            String newName = newNames.get(i);
            if (enrichField.name().equals(newName)) {
                newEnrichFields.add(enrichField);
                continue;
            }
            if (enrichField instanceof ReferenceAttribute) {
                ReferenceAttribute ra = (ReferenceAttribute)enrichField;
                newEnrichFields.add((NamedExpression)new Alias(ra.source(), newName, (Expression)ra, new NameId(), ra.synthetic()));
                continue;
            }
            if (enrichField instanceof Alias) {
                Alias a = (Alias)enrichField;
                newEnrichFields.add((NamedExpression)new Alias(a.source(), newName, a.child(), new NameId(), a.synthetic()));
                continue;
            }
            throw new IllegalArgumentException("Enrich field must be Alias or ReferenceAttribute");
        }
        return new Enrich(this.source(), this.child(), this.mode(), this.policyName(), this.matchField(), this.policy(), this.concreteIndices(), newEnrichFields);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        Enrich enrich = (Enrich)o;
        return Objects.equals((Object)this.mode, (Object)enrich.mode) && Objects.equals(this.policyName, enrich.policyName) && Objects.equals(this.matchField, enrich.matchField) && Objects.equals(this.policy, enrich.policy) && Objects.equals(this.concreteIndices, enrich.concreteIndices) && Objects.equals(this.enrichFields, enrich.enrichFields);
    }

    @Override
    public int hashCode() {
        return Objects.hash(new Object[]{super.hashCode(), this.mode, this.policyName, this.matchField, this.policy, this.concreteIndices, this.enrichFields});
    }

    @Override
    public BiConsumer<LogicalPlan, Failures> postAnalysisPlanVerification() {
        return Enrich::checkRemoteEnrich;
    }

    private static void checkRemoteEnrich(LogicalPlan plan, Failures failures) {
        plan.forEachUp(Enrich.class, enrich -> Enrich.checkForPlansForbiddenBeforeRemoteEnrich(enrich, failures));
    }

    private static void checkForPlansForbiddenBeforeRemoteEnrich(Enrich enrich, Failures failures) {
        if (enrich.mode != Mode.REMOTE) {
            return;
        }
        HashSet badCommands = new HashSet();
        enrich.forEachUp(LogicalPlan.class, u -> {
            Enrich upstreamEnrich;
            if (u instanceof Aggregate) {
                badCommands.add("STATS");
            } else if (u instanceof Enrich && (upstreamEnrich = (Enrich)u).mode() == Mode.COORDINATOR) {
                badCommands.add("another ENRICH with coordinator policy");
            } else if (u instanceof LookupJoin) {
                badCommands.add("LOOKUP JOIN");
            } else if (u instanceof Fork) {
                badCommands.add("FORK");
            }
        });
        badCommands.forEach(c -> failures.add(Failure.fail(enrich, "ENRICH with remote policy can't be executed after " + c, new Object[0])));
    }

    public static enum Mode {
        ANY,
        COORDINATOR,
        REMOTE;

        private static final Map<String, Mode> map;

        public static Mode from(String name) {
            return name == null ? null : map.get(name.toUpperCase(Locale.ROOT));
        }

        static {
            Mode[] values = Mode.values();
            map = Maps.newMapWithExpectedSize((int)values.length);
            for (Mode m : values) {
                map.put(m.name(), m);
            }
        }
    }
}

