/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.dataframe.extractor;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ElasticsearchClient;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorContext;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.extractor.ProcessedField;

/**
 * Extracts rows of feature values from the source indices of a data frame analytics job.
 *
 * <p>Documents are paged in ascending order of the {@code ml__incremental_id} field. Each call to
 * {@link #next()} requests the documents whose id lies in
 * {@code [lastSortKey + 1, lastSortKey + 1 + scrollSize)}. A failed search is retried once; a
 * second consecutive failure is propagated to the caller.
 */
public class DataFrameDataExtractor {
    private static final Logger LOGGER = LogManager.getLogger(DataFrameDataExtractor.class);

    /** Placeholder emitted for a missing value when the analysis supports rows with missing values. */
    public static final String NULL_VALUE = "\u0000";

    /** Field holding the per-document incremental id used to sort and page through documents. */
    private static final String INCREMENTAL_ID_FIELD = "ml__incremental_id";

    private final Client client;
    private final DataFrameDataExtractorContext context;
    /** Sort key of the last row returned by {@link #next()}; {@code -1} before the first page. */
    private long lastSortKey = -1L;
    private boolean isCancelled;
    private boolean hasNext;
    /** True when the most recent search failed; a further failure is then rethrown instead of retried. */
    private boolean hasPreviousSearchFailed;
    /** Created lazily on first use and then cached. */
    private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
    private final String[] organicFeatures;
    private final String[] processedFeatures;
    /** All extracted fields keyed by name, in the iteration order of {@code getAllFields()}. */
    private final Map<String, ExtractedField> extractedFieldsByName;

    DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
        this.client = Objects.requireNonNull(client);
        this.context = Objects.requireNonNull(context);
        this.organicFeatures = context.extractedFields.extractOrganicFeatureNames();
        this.processedFeatures = context.extractedFields.extractProcessedFeatureNames();
        this.extractedFieldsByName = new LinkedHashMap<>();
        for (ExtractedField field : context.extractedFields.getAllFields()) {
            this.extractedFieldsByName.put(field.getName(), field);
        }
        this.hasNext = true;
        this.hasPreviousSearchFailed = false;
        this.trainTestSplitter = CachedSupplier.wrap(context.trainTestSplitterFactory::create);
    }

    /** Returns a read-only view of the headers used when executing searches. */
    public Map<String, String> getHeaders() {
        return Collections.unmodifiableMap(context.headers);
    }

    /** Returns whether {@link #next()} may still be called. */
    public boolean hasNext() {
        return hasNext;
    }

    public boolean isCancelled() {
        return isCancelled;
    }

    /** Marks the extractor as cancelled; row processing stops at the next hit that is examined. */
    public void cancel() {
        LOGGER.debug(() -> "[" + context.jobId + "] Data extractor was cancelled");
        isCancelled = true;
    }

    /**
     * Fetches the next page of rows.
     *
     * @return the rows of the next page, or an empty optional when no documents were found
     * @throws NoSuchElementException if {@link #hasNext()} is false
     * @throws IOException if the search fails twice in a row
     */
    public Optional<List<Row>> next() throws IOException {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        Optional<List<Row>> hits = Optional.ofNullable(nextSearch());
        if (hits.isPresent() && !hits.get().isEmpty()) {
            List<Row> rows = hits.get();
            lastSortKey = rows.get(rows.size() - 1).getSortKey();
        } else {
            hasNext = false;
        }
        return hits;
    }

    /**
     * Asynchronously fetches one page of documents matching the job query, without the incremental-id
     * paging and without applying the train/test split.
     *
     * @param listener notified with the extracted rows (an empty list when nothing matched) or with the failure
     */
    public void preview(ActionListener<List<Row>> listener) {
        SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client).setAllowPartialSearchResults(false)
            .setIndices(context.indices)
            .setSize(context.scrollSize)
            .setQuery(QueryBuilders.boolQuery().filter(context.query));
        setFetchSource(searchRequestBuilder);
        addDocValueFields(searchRequestBuilder);
        searchRequestBuilder.setRuntimeMappings(context.runtimeMappings);

        // No raw casts here: the listener's response type must be inferred as SearchResponse
        // from TransportSearchAction.TYPE so the lambda parameter is typed correctly.
        ClientHelper.executeWithHeadersAsync(
            context.headers,
            ClientHelper.ML_ORIGIN,
            client,
            TransportSearchAction.TYPE,
            searchRequestBuilder.request(),
            ActionListener.wrap(searchResponse -> {
                SearchHit[] hits = searchResponse.getHits().getHits();
                if (hits.length == 0) {
                    listener.onResponse(Collections.emptyList());
                    return;
                }
                List<Row> rows = new ArrayList<>(hits.length);
                for (SearchHit hit : hits) {
                    String[] extractedValues = extractValues(hit);
                    rows.add(extractedValues == null ? new Row(null, hit, true) : new Row(extractedValues, hit, false));
                }
                listener.onResponse(rows);
            }, listener::onFailure)
        );
    }

    /** Executes the search for the next page; overridable so tests can stub the search. */
    protected List<Row> nextSearch() throws IOException {
        return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest()));
    }

    /**
     * Runs the given search and converts its hits to rows, retrying once via {@link #nextSearch()}
     * if it fails and the previous search succeeded.
     */
    private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
        try {
            // allow_partial_search_results is false, so shard failures surface as an exception here.
            SearchResponse searchResponse = request.get();
            LOGGER.trace(() -> "[" + context.jobId + "] Search response was obtained");
            List<Row> rows = processSearchResponse(searchResponse);
            // Success: a later failure is allowed one retry again.
            hasPreviousSearchFailed = false;
            return rows;
        } catch (Exception e) {
            if (hasPreviousSearchFailed) {
                throw e;
            }
            LOGGER.warn(() -> "[" + context.jobId + "] Search resulted to failure; retrying once", e);
            markScrollAsErrored();
            return nextSearch();
        }
    }

    /** Executes the search synchronously with the job's headers under the ML origin. */
    protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
        return ClientHelper.executeWithHeaders(context.headers, ClientHelper.ML_ORIGIN, client, () -> searchRequestBuilder.get());
    }

    /** Builds the request for documents whose incremental id lies in {@code [lastSortKey + 1, lastSortKey + 1 + scrollSize)}. */
    private SearchRequestBuilder buildSearchRequest() {
        long from = lastSortKey + 1L;
        long to = from + (long) context.scrollSize;
        LOGGER.trace(
            () -> Strings.format("[%s] Searching docs with [%s] in [%s, %s)", context.jobId, INCREMENTAL_ID_FIELD, from, to)
        );

        SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client).setAllowPartialSearchResults(false)
            .addSort(INCREMENTAL_ID_FIELD, SortOrder.ASC)
            .setIndices(context.indices)
            .setSize(context.scrollSize);
        searchRequestBuilder.setQuery(
            QueryBuilders.boolQuery()
                .filter(context.query)
                .filter(QueryBuilders.rangeQuery(INCREMENTAL_ID_FIELD).gte(from).lt(to))
        );
        setFetchSource(searchRequestBuilder);
        addDocValueFields(searchRequestBuilder);
        searchRequestBuilder.setRuntimeMappings(context.runtimeMappings);
        return searchRequestBuilder;
    }

    /** Requests every extracted doc-value field with its configured format. */
    private void addDocValueFields(SearchRequestBuilder searchRequestBuilder) {
        for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
            searchRequestBuilder.addDocValueField(docValueField.getSearchField(), docValueField.getDocValueFormat());
        }
    }

    /**
     * Configures source fetching: the full source when the context asks for it; otherwise only the
     * extracted source fields, or no source and no stored fields when there are none.
     */
    private void setFetchSource(SearchRequestBuilder searchRequestBuilder) {
        if (context.includeSource) {
            searchRequestBuilder.setFetchSource(true);
            return;
        }
        String[] sourceFields = context.extractedFields.getSourceFields();
        if (sourceFields.length == 0) {
            searchRequestBuilder.setFetchSource(false);
            searchRequestBuilder.storedFields(new String[] { "_none_" });
        } else {
            searchRequestBuilder.setFetchSource(sourceFields, null);
        }
    }

    /**
     * Converts the response hits to rows.
     *
     * @return the rows, or {@code null} when the response has no hits (which also clears {@code hasNext})
     */
    private List<Row> processSearchResponse(SearchResponse searchResponse) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        if (hits.length == 0) {
            hasNext = false;
            return null;
        }
        List<Row> rows = new ArrayList<>(hits.length);
        for (SearchHit hit : hits) {
            if (isCancelled) {
                hasNext = false;
                break;
            }
            rows.add(createRow(hit));
        }
        return rows;
    }

    /**
     * Extracts the single value of an organic (non-processed) feature.
     *
     * @return the value as a string; {@link #NULL_VALUE} when the field is absent and missing values
     *         are supported; {@code null} when the row must be skipped
     */
    private String extractNonProcessedValues(SearchHit hit, String organicFeature) {
        ExtractedField field = extractedFieldsByName.get(organicFeature);
        Object[] values = field.value(hit);
        if (values.length == 1 && isValidValue(values[0])) {
            return Objects.toString(values[0]);
        }
        if (values.length == 0 && context.supportsRowsWithMissingValues) {
            return NULL_VALUE;
        }
        // Multi-valued or unsupported value types cannot be represented in a row.
        return null;
    }

    /**
     * Extracts the output values of a processed field.
     *
     * @return one string per output field name (missing outputs become {@link #NULL_VALUE} when
     *         supported), or {@code null} when the row must be skipped
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if the processor produced a
     *         number of values different from its declared outputs
     */
    private String[] extractProcessedValue(ProcessedField processedField, SearchHit hit) {
        Object[] values = processedField.value(hit, extractedFieldsByName::get);
        if (values.length == 0 && !context.supportsRowsWithMissingValues) {
            return null;
        }
        int outputSize = processedField.getOutputFieldNames().size();
        String[] extractedValue = new String[outputSize];
        Arrays.fill(extractedValue, NULL_VALUE);
        if (values.length == 0) {
            return extractedValue;
        }
        if (values.length != outputSize) {
            throw ExceptionsHelper.badRequestException(
                "field_processor [{}] output size expected to be [{}], instead it was [{}]",
                processedField.getProcessorName(),
                outputSize,
                values.length
            );
        }
        for (int i = 0; i < outputSize; i++) {
            Object value = values[i];
            if (value == null && context.supportsRowsWithMissingValues) {
                // Keep the NULL_VALUE placeholder already written above.
                continue;
            }
            if (!isValidValue(value)) {
                return null;
            }
            extractedValue[i] = Objects.toString(value);
        }
        return extractedValue;
    }

    /** Builds a row for the hit, deciding training membership only when the row is not skipped. */
    private Row createRow(SearchHit hit) {
        String[] extractedValues = extractValues(hit);
        if (extractedValues == null) {
            return new Row(null, hit, true);
        }
        boolean isTraining = trainTestSplitter.get().isTraining(extractedValues);
        Row row = new Row(extractedValues, hit, isTraining);
        LOGGER.trace(
            () -> Strings.format(
                "[%s] Extracted row: sort key = [%s], is_training = [%s], values = %s",
                context.jobId,
                row.getSortKey(),
                isTraining,
                Arrays.toString(row.values)
            )
        );
        return row;
    }

    /**
     * Extracts organic features followed by processed feature outputs.
     *
     * @return the values in {@link #getFieldNames()} order, or {@code null} if any feature makes the row unusable
     */
    private String[] extractValues(SearchHit hit) {
        String[] extractedValues = new String[numberOfFields()];
        int i = 0;
        for (String organicFeature : organicFeatures) {
            String extractedValue = extractNonProcessedValues(hit, organicFeature);
            if (extractedValue == null) {
                return null;
            }
            extractedValues[i++] = extractedValue;
        }
        for (ProcessedField processedField : context.extractedFields.getProcessedFields()) {
            String[] processedValues = extractProcessedValue(processedField, hit);
            if (processedValues == null) {
                return null;
            }
            for (String processedValue : processedValues) {
                extractedValues[i++] = processedValue;
            }
        }
        return extractedValues;
    }

    private void markScrollAsErrored() {
        hasPreviousSearchFailed = true;
    }

    /** Total number of row columns: organic features plus processed feature outputs. */
    private int numberOfFields() {
        return organicFeatures.length + processedFeatures.length;
    }

    /** Returns the organic feature names followed by the processed feature names. */
    public List<String> getFieldNames() {
        return Stream.concat(Arrays.stream(organicFeatures), Arrays.stream(processedFeatures)).collect(Collectors.toList());
    }

    public ExtractedFields getExtractedFields() {
        return context.extractedFields;
    }

    /** Synchronously counts the rows the analysis would see and pairs that with the column count. */
    public DataSummary collectDataSummary() {
        SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
        SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
        long rows = searchResponse.getHits().getTotalHits().value;
        LOGGER.debug(() -> Strings.format("[%s] Data summary rows [%s]", context.jobId, rows));
        return new DataSummary(rows, numberOfFields());
    }

    /** Asynchronous variant of {@link #collectDataSummary()}. */
    public void collectDataSummaryAsync(ActionListener<DataSummary> dataSummaryActionListener) {
        SearchRequestBuilder searchRequestBuilder = buildDataSummarySearchRequestBuilder();
        int numberOfFields = numberOfFields();
        ClientHelper.executeWithHeadersAsync(
            context.headers,
            ClientHelper.ML_ORIGIN,
            client,
            TransportSearchAction.TYPE,
            searchRequestBuilder.request(),
            ActionListener.wrap(
                searchResponse -> dataSummaryActionListener.onResponse(
                    new DataSummary(searchResponse.getHits().getTotalHits().value, numberOfFields)
                ),
                dataSummaryActionListener::onFailure
            )
        );
    }

    /**
     * Builds a size-0 request that tracks total hits. When rows with missing values are unsupported,
     * it also requires every extracted field to exist.
     */
    private SearchRequestBuilder buildDataSummarySearchRequestBuilder() {
        QueryBuilder summaryQuery = context.query;
        if (!context.supportsRowsWithMissingValues) {
            summaryQuery = QueryBuilders.boolQuery().filter(summaryQuery).filter(allExtractedFieldsExistQuery());
        }
        return new SearchRequestBuilder(client).setAllowPartialSearchResults(false)
            .setIndices(context.indices)
            .setSize(0)
            .setQuery(summaryQuery)
            .setTrackTotalHits(true)
            .setRuntimeMappings(context.runtimeMappings);
    }

    /** A filter-only bool query requiring an {@code exists} match on every extracted field. */
    private QueryBuilder allExtractedFieldsExistQuery() {
        BoolQueryBuilder query = QueryBuilders.boolQuery();
        for (ExtractedField field : context.extractedFields.getAllFields()) {
            query.filter(QueryBuilders.existsQuery(field.getName()));
        }
        return query;
    }

    public Set<String> getCategoricalFields(DataFrameAnalysis analysis) {
        return ExtractedFieldsDetector.getCategoricalOutputFields(context.extractedFields, analysis);
    }

    /** Returns whether the value is a {@link Number}, {@link String} or {@link Boolean}. */
    public static boolean isValidValue(Object value) {
        return value instanceof Number || value instanceof String || value instanceof Boolean;
    }

    /** One extracted document: its hit, its feature values (null when skipped) and its training flag. */
    public static class Row {
        private final SearchHit hit;
        @Nullable
        private final String[] values;
        private final boolean isTraining;

        private Row(String[] values, SearchHit hit, boolean isTraining) {
            this.values = values;
            this.hit = hit;
            this.isTraining = isTraining;
        }

        @Nullable
        public String[] getValues() {
            return values;
        }

        public SearchHit getHit() {
            return hit;
        }

        /** A row is skipped when its values could not be extracted. */
        public boolean shouldSkip() {
            return values == null;
        }

        public boolean isTraining() {
            return isTraining;
        }

        /** The sort key narrowed to an {@code int} (low 32 bits). */
        public int getChecksum() {
            return (int) getSortKey();
        }

        /** The first sort value of the hit, i.e. its incremental id. */
        public long getSortKey() {
            return (Long) hit.getSortValues()[0];
        }
    }

    /** Row and column counts of the data the analysis will process. */
    public static class DataSummary {
        public final long rows;
        public final int cols;

        public DataSummary(long rows, int cols) {
            this.rows = rows;
            this.cols = cols;
        }
    }
}

