/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.SingleClassReservoirTrainTestSplitter;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.StratifiedTrainTestSplitter;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;

/**
 * Builds the {@link TrainTestSplitter} that corresponds to the analysis of a
 * {@link DataFrameAnalyticsConfig}.
 *
 * <p>For {@link Regression} and {@link Classification} analyses the factory first runs a
 * size-0 search against the config's destination index to gather the counts the splitter
 * is constructed with. Searches are executed with the headers stored on the config.
 */
public class TrainTestSplitterFactory {
    private static final Logger LOGGER = LogManager.getLogger(TrainTestSplitterFactory.class);

    private final Client client;
    private final DataFrameAnalyticsConfig config;
    private final List<String> fieldNames;

    /**
     * @param client     client used to search the destination index; must not be {@code null}
     * @param config     analytics config supplying the analysis, destination index, headers and id;
     *                   must not be {@code null}
     * @param fieldNames field names handed through to the created splitter; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public TrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
        this.client = Objects.requireNonNull(client);
        this.config = Objects.requireNonNull(config);
        this.fieldNames = Objects.requireNonNull(fieldNames);
    }

    /**
     * Creates the splitter for the configured analysis.
     *
     * <ul>
     *   <li>{@link Regression}: a {@link SingleClassReservoirTrainTestSplitter}</li>
     *   <li>{@link Classification}: a {@link StratifiedTrainTestSplitter}</li>
     *   <li>anything else: a splitter that returns {@code true} for every row
     *       (presumably meaning "use every row for training" - confirm against
     *       {@link TrainTestSplitter})</li>
     * </ul>
     *
     * @return the splitter matching the analysis type
     * @throws ElasticsearchException if the preparatory search fails
     */
    public TrainTestSplitter create() {
        if (config.getAnalysis() instanceof Regression) {
            Regression regression = (Regression) config.getAnalysis();
            return createSingleClassSplitter(regression);
        }
        if (config.getAnalysis() instanceof Classification) {
            Classification classification = (Classification) config.getAnalysis();
            return createStratifiedSplitter(classification);
        }
        return row -> true;
    }

    /**
     * Counts the destination-index documents in which the dependent variable exists and
     * builds a single-class reservoir splitter from that count.
     */
    private TrainTestSplitter createSingleClassSplitter(Regression regression) {
        // Built outside the try block: only failures while executing the search or
        // constructing the splitter are wrapped below.
        SearchRequestBuilder countRequest = newDestIndexSearch()
            .setTrackTotalHits(true)
            .setQuery(QueryBuilders.existsQuery(regression.getDependentVariable()));

        try {
            SearchResponse response = executeWithConfigHeaders(countRequest);
            return new SingleClassReservoirTrainTestSplitter(
                fieldNames,
                regression.getDependentVariable(),
                regression.getTrainingPercent(),
                regression.getRandomizeSeed(),
                response.getHits().getTotalHits().value
            );
        } catch (Exception e) {
            String msg = "[" + config.getId() + "] Error searching total number of training docs";
            LOGGER.error(msg, e);
            throw new ElasticsearchException(msg, e);
        }
    }

    /**
     * Collects per-class document counts for the dependent variable through a terms
     * aggregation (requested with size 100) and builds a stratified splitter from them.
     */
    private TrainTestSplitter createStratifiedSplitter(Classification classification) {
        String aggName = "dependent_variable_terms";
        TermsAggregationBuilder classCountsAgg = AggregationBuilders.terms(aggName)
            .field(classification.getDependentVariable())
            .size(100);
        SearchRequestBuilder aggRequest = newDestIndexSearch().addAggregation(classCountsAgg);

        try {
            SearchResponse response = executeWithConfigHeaders(aggRequest);
            Terms terms = response.getAggregations().get(aggName);

            // Bucket keys are stringified so that every class is keyed by its string form.
            Map<String, Long> classCounts = new HashMap<>();
            terms.getBuckets().forEach(bucket -> classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount()));

            return new StratifiedTrainTestSplitter(
                fieldNames,
                classification.getDependentVariable(),
                classCounts,
                classification.getTrainingPercent(),
                classification.getRandomizeSeed()
            );
        } catch (Exception e) {
            String msg = "[" + config.getId() + "] Dependent variable terms search failed";
            LOGGER.error(msg, e);
            throw new ElasticsearchException(msg, e);
        }
    }

    /**
     * Starts a search on the destination index that returns no hits (size 0) and
     * disallows partial search results.
     */
    private SearchRequestBuilder newDestIndexSearch() {
        return client.prepareSearch(config.getDest().getIndex()).setSize(0).setAllowPartialSearchResults(false);
    }

    /** Executes the request through {@link ClientHelper} using the config's headers and the "ml" origin. */
    private SearchResponse executeWithConfigHeaders(SearchRequestBuilder request) {
        return ClientHelper.executeWithHeaders(config.getHeaders(), "ml", client, request::get);
    }
}

