/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.nlp.TaskType;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * Couples an {@link NlpConfig} with the {@link NlpTokenizer} built for it and declares the
 * contracts used around an NLP inference call: input validation, request building and
 * result processing.
 */
public class NlpTask {
    private final NlpConfig config;
    private final NlpTokenizer tokenizer;

    /**
     * Stores the configuration and builds the tokenizer from the vocabulary and the
     * tokenization settings carried by the configuration.
     *
     * @param config     the NLP task configuration
     * @param vocabulary the vocabulary handed to {@link NlpTokenizer#build}
     */
    public NlpTask(NlpConfig config, Vocabulary vocabulary) {
        this.config = config;
        this.tokenizer = NlpTokenizer.build(vocabulary, config.getTokenization());
    }

    /**
     * Resolves the {@link TaskType} from the configuration name and asks it for a processor
     * wired with this task's tokenizer and configuration.
     *
     * @return the processor created by the resolved task type
     * @throws ValidationException declared by this method; presumably raised while the
     *                             processor is being created (not visible from here)
     */
    public Processor createProcessor() throws ValidationException {
        TaskType taskType = TaskType.fromString(this.config.getName());
        return taskType.createProcessor(this.tokenizer, this.config);
    }

    /**
     * Reads the single input field named by {@code input} out of {@code doc} and returns it
     * as a string.
     *
     * <p>A bad-request exception (as built by {@code ExceptionsHelper.badRequestException})
     * is thrown when the field is missing from the document or when its value is not a
     * {@link String}.
     *
     * @param input the model input definition; asserted to name exactly one field
     * @param doc   the source document to read the field from
     * @return the string value found at the input field
     */
    public static String extractInput(TrainedModelInput input, Map<String, Object> doc) {
        assert input.getFieldNames().size() == 1;
        final String fieldName = input.getFieldNames().get(0);
        final Object rawValue = XContentMapValues.extractValue(fieldName, doc);

        // Guard clauses: missing value first, then wrong type.
        if (rawValue == null) {
            throw ExceptionsHelper.badRequestException("Input field [{}] does not exist in the source document", fieldName);
        }
        if (!(rawValue instanceof String)) {
            throw ExceptionsHelper.badRequestException("Input value [{}] for field [{}] must be a string", rawValue, fieldName);
        }
        return (String) rawValue;
    }

    /**
     * Entry point for a concrete NLP task: validates inputs and hands out the request
     * builder and result processor for a given configuration.
     */
    public interface Processor {
        /** Validates the given input strings; how a violation is reported is implementation-defined. */
        void validateInputs(List<String> var1);

        /** Returns the {@link RequestBuilder} to use for the given configuration. */
        RequestBuilder getRequestBuilder(NlpConfig var1);

        /** Returns the {@link ResultProcessor} to use for the given configuration. */
        ResultProcessor getResultProcessor(NlpConfig var1);
    }

    /**
     * Immutable pair of a tokenization result and the encoded input built from it.
     * Neither component may be {@code null}.
     */
    public static class Request {
        public final TokenizationResult tokenization;
        public final BytesReference processInput;

        /**
         * @param tokenization the tokenization result, must not be {@code null}
         * @param processInput the encoded request payload, must not be {@code null}
         * @throws NullPointerException if either argument is {@code null}
         */
        public Request(TokenizationResult tokenization, BytesReference processInput) {
            this.tokenization = Objects.requireNonNull(tokenization);
            this.processInput = Objects.requireNonNull(processInput);
        }
    }

    /** Turns a tokenization and a {@link PyTorchResult} into {@link InferenceResults}. */
    public interface ResultProcessor {
        InferenceResults processResult(TokenizationResult var1, PyTorchResult var2);
    }

    /**
     * Builds {@link Request}s and offers static helpers for writing per-sequence integer
     * arrays to an {@link XContentBuilder}.
     */
    public interface RequestBuilder {
        /**
         * Builds a request from raw input strings.
         * NOTE(review): the meaning of the second argument is not visible here (presumably
         * a request identifier) - confirm against implementations.
         */
        Request buildRequest(List<String> var1, String var2) throws IOException;

        /**
         * Builds a request from an existing tokenization result.
         * NOTE(review): the meaning of the second argument is not visible here (presumably
         * a request identifier) - confirm against implementations.
         */
        Request buildRequest(TokenizationResult var1, String var2) throws IOException;

        /**
         * Writes an array field named {@code fieldName} holding one inner array per
         * tokenization. Each inner array contains the values produced by {@code generator}
         * for every token position, followed by {@code padToken} repeated until the inner
         * array reaches the longest sequence length reported by {@code tokenization}.
         *
         * @param fieldName    name of the outer array field
         * @param tokenization source of the tokenizations and of the longest sequence length
         * @param padToken     value appended to short sequences
         * @param generator    supplies the value written for a tokenization and position
         * @param builder      destination builder
         * @throws IOException if the builder fails to write
         */
        static void writePaddedTokens(String fieldName, TokenizationResult tokenization, int padToken, TokenLookupFunction generator, XContentBuilder builder) throws IOException {
            builder.startArray(fieldName);
            for (TokenizationResult.Tokenization sequence : tokenization.getTokenizations()) {
                builder.startArray();
                int position = 0;
                // Real tokens first ...
                while (position < sequence.getTokenIds().length) {
                    builder.value(generator.apply(sequence, position));
                    position++;
                }
                // ... then pad up to the longest sequence.
                for (; position < tokenization.getLongestSequenceLength(); position++) {
                    builder.value(padToken);
                }
                builder.endArray();
            }
            builder.endArray();
        }

        /**
         * Writes an array field named {@code fieldName} holding {@code numTokenizations}
         * inner arrays, each with {@code longestSequenceLength} values obtained by applying
         * {@code generator} to the position index.
         *
         * @param fieldName             name of the outer array field
         * @param numTokenizations      number of inner arrays to write
         * @param longestSequenceLength number of values in every inner array
         * @param generator             maps a position index to the value written
         * @param builder               destination builder
         * @throws IOException if the builder fails to write
         */
        static void writeNonPaddedArguments(String fieldName, int numTokenizations, int longestSequenceLength, IntToIntFunction generator, XContentBuilder builder) throws IOException {
            builder.startArray(fieldName);
            int row = 0;
            while (row < numTokenizations) {
                builder.startArray();
                for (int position = 0; position < longestSequenceLength; position++) {
                    builder.value(generator.applyAsInt(position));
                }
                builder.endArray();
                row++;
            }
            builder.endArray();
        }

        /** Produces the int to write for a given tokenization and position index. */
        @FunctionalInterface
        interface TokenLookupFunction {
            int apply(TokenizationResult.Tokenization var1, int var2);
        }

        /** Maps one int to another. */
        @FunctionalInterface
        interface IntToIntFunction {
            int applyAsInt(int var1);
        }
    }
}

