/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchemaPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.TaskAndApi;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticRerankPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticSparseEmbeddingPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic.ElasticTextEmbeddingPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiCompletionPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.openai.OpenAiTextEmbeddingPayload;

/**
 * Registry of the SageMaker request/response schemas, keyed by the pair of task type and API name.
 * <p>
 * Each payload handed to {@link #register} contributes one schema per task type it supports. Payloads that implement
 * {@link SageMakerStreamSchemaPayload} are wrapped in a {@link SageMakerStreamSchema}, which also makes them available
 * through {@link #streamSchemaFor(SageMakerModel)}.
 */
public class SageMakerSchemas {
    private static final Map<TaskAndApi, SageMakerSchema> schemas = register(
        new OpenAiTextEmbeddingPayload(),
        new OpenAiCompletionPayload(),
        new ElasticTextEmbeddingPayload(),
        new ElasticSparseEmbeddingPayload(),
        new ElasticCompletionPayload(),
        new ElasticRerankPayload()
    );

    // Subset of `schemas` whose schema can also serve streaming requests.
    private static final Map<TaskAndApi, SageMakerStreamSchema> streamSchemas = schemas.entrySet()
        .stream()
        .filter(e -> e.getValue() instanceof SageMakerStreamSchema)
        .collect(Collectors.toMap(Map.Entry::getKey, e -> (SageMakerStreamSchema) e.getValue()));

    // Per-API task types; only used to list the valid alternatives in error messages. EnumSet values keep that listing
    // in a stable (enum declaration) order.
    private static final Map<String, Set<TaskType>> tasksByApi = taskTypesByApi(schemas.keySet());
    private static final Map<String, Set<TaskType>> streamingTasksByApi = taskTypesByApi(streamSchemas.keySet());

    // Built from EnumSet.noneOf rather than EnumSet.copyOf(Collection), which throws on an empty non-EnumSet input.
    // Never handed out directly: the public accessors return copies so callers cannot mutate the shared registry.
    private static final EnumSet<TaskType> supportedStreamingTasks = taskTypes(streamSchemas.keySet());
    private static final EnumSet<TaskType> supportedTaskTypes = taskTypes(schemas.keySet());

    /**
     * Builds the schema map from the given payloads, creating one schema per task type each payload supports.
     *
     * @param payloads the payload implementations to register
     * @return a map from (task type, api) to the schema handling that combination
     * @throws IllegalStateException if two payloads claim the same (task type, api) pair (via {@link Collectors#toMap})
     */
    private static Map<TaskAndApi, SageMakerSchema> register(SageMakerSchemaPayload... payloads) {
        return Arrays.stream(payloads)
            .flatMap(
                payload -> payload.supportedTasks()
                    .stream()
                    .map(taskType -> Map.entry(new TaskAndApi(taskType, payload.api()), schemaOf(payload)))
            )
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /** Wraps a payload in a stream schema when it supports streaming, otherwise in a plain schema. */
    private static SageMakerSchema schemaOf(SageMakerSchemaPayload payload) {
        return payload instanceof SageMakerStreamSchemaPayload streamPayload
            ? new SageMakerStreamSchema(streamPayload)
            : new SageMakerSchema(payload);
    }

    /** Collects the task types present in the given keys into a (possibly empty) EnumSet. */
    private static EnumSet<TaskType> taskTypes(Set<TaskAndApi> keys) {
        EnumSet<TaskType> result = EnumSet.noneOf(TaskType.class);
        for (TaskAndApi key : keys) {
            result.add(key.taskType());
        }
        return result;
    }

    /** Groups the task types present in the given keys by their API name. */
    private static Map<String, Set<TaskType>> taskTypesByApi(Set<TaskAndApi> keys) {
        Map<String, Set<TaskType>> byApi = new HashMap<>();
        for (TaskAndApi key : keys) {
            byApi.computeIfAbsent(key.api(), api -> EnumSet.noneOf(TaskType.class)).add(key.taskType());
        }
        return byApi;
    }

    /**
     * Returns the named-writeable registry entries for the no-op stored schemas and for every registered schema.
     * Entries are de-duplicated by writeable name; the first entry seen for a name wins.
     * <p>
     * NOTE(review): de-duplication is keyed on the name only, not on (category class, name) — two entries sharing a name
     * in different categories would collapse into one. Confirm this is intended.
     */
    public static List<NamedWriteableRegistry.Entry> namedWriteables() {
        Stream<NamedWriteableRegistry.Entry> noOpEntries = Stream.of(
            new NamedWriteableRegistry.Entry(
                SageMakerStoredServiceSchema.class,
                SageMakerStoredServiceSchema.NO_OP.getWriteableName(),
                in -> SageMakerStoredServiceSchema.NO_OP
            ),
            new NamedWriteableRegistry.Entry(
                SageMakerStoredTaskSchema.class,
                SageMakerStoredTaskSchema.NO_OP.getWriteableName(),
                in -> SageMakerStoredTaskSchema.NO_OP
            )
        );
        Stream<NamedWriteableRegistry.Entry> schemaEntries = schemas.values().stream().flatMap(SageMakerSchema::namedWriteables);
        return Stream.concat(noOpEntries, schemaEntries)
            .collect(
                HashMap<String, NamedWriteableRegistry.Entry>::new,
                (byName, entry) -> byName.putIfAbsent(entry.name, entry),
                // Only used for parallel streams; keep "first wins" semantics consistent with the accumulator.
                (left, right) -> right.forEach(left::putIfAbsent)
            )
            .values()
            .stream()
            .toList();
    }

    /**
     * Looks up the schema for the model's task type and API.
     *
     * @throws ElasticsearchStatusException with {@link RestStatus#METHOD_NOT_ALLOWED} if no schema is registered
     */
    public SageMakerSchema schemaFor(SageMakerModel model) throws ElasticsearchStatusException {
        return schemaFor(model.getTaskType(), model.api());
    }

    /**
     * Looks up the schema registered for the given task type and API.
     *
     * @param taskType the requested task type
     * @param api the SageMaker API name
     * @return the matching schema, never {@code null}
     * @throws ElasticsearchStatusException with {@link RestStatus#METHOD_NOT_ALLOWED} if no schema is registered; the
     *         message lists the task types supported for that API
     */
    public SageMakerSchema schemaFor(TaskType taskType, String api) throws ElasticsearchStatusException {
        SageMakerSchema schema = schemas.get(new TaskAndApi(taskType, api));
        if (schema == null) {
            // Argument order must follow the template: task first, then api.
            throw new ElasticsearchStatusException(
                Strings.format(
                    "Task [%s] is not compatible for service [sagemaker] and api [%s]. Supported tasks: [%s]",
                    taskType.toString(),
                    api,
                    tasksByApi.getOrDefault(api, Set.of())
                ),
                RestStatus.METHOD_NOT_ALLOWED
            );
        }
        return schema;
    }

    /**
     * Looks up the streaming schema for the model's task type and API.
     *
     * @param model the model whose task type and API select the schema
     * @return the matching streaming schema, never {@code null}
     * @throws ElasticsearchStatusException with {@link RestStatus#METHOD_NOT_ALLOWED} if streaming is not supported for
     *         that combination; the message lists the streaming task types supported for that API
     */
    public SageMakerStreamSchema streamSchemaFor(SageMakerModel model) throws ElasticsearchStatusException {
        String api = model.api();
        SageMakerStreamSchema schema = streamSchemas.get(new TaskAndApi(model.getTaskType(), api));
        if (schema == null) {
            throw new ElasticsearchStatusException(
                Strings.format(
                    "Streaming is not allowed for service [sagemaker], api [%s], and task [%s]. Supported streaming tasks: [%s]",
                    api,
                    model.getTaskType().toString(),
                    streamingTasksByApi.getOrDefault(api, Set.of())
                ),
                RestStatus.METHOD_NOT_ALLOWED
            );
        }
        return schema;
    }

    /** Returns a fresh copy of every task type that has at least one registered schema. */
    public EnumSet<TaskType> supportedTaskTypes() {
        return EnumSet.copyOf(supportedTaskTypes);
    }

    /** Returns a fresh copy of every task type that has at least one registered streaming schema. */
    public Set<TaskType> supportedStreamingTasks() {
        return EnumSet.copyOf(supportedStreamingTasks);
    }
}

