/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.rest;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.core.inference.InferenceContext;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy;

abstract class BaseInferenceAction
extends BaseRestHandler {
    BaseInferenceAction() {
    }

    static Params parseParams(RestRequest restRequest) {
        if (restRequest.hasParam("inference_id")) {
            String inferenceEntityId = restRequest.param("inference_id");
            TaskType taskType = TaskType.fromStringOrStatusException((String)restRequest.param("task_type_or_id"));
            return new Params(inferenceEntityId, taskType);
        }
        return new Params(restRequest.param("task_type_or_id"), TaskType.ANY);
    }

    static TimeValue parseTimeout(RestRequest restRequest) {
        return restRequest.paramAsTime(InferenceAction.Request.TIMEOUT.getPreferredName(), InferenceAction.Request.DEFAULT_TIMEOUT);
    }

    protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        Params params = BaseInferenceAction.parseParams(restRequest);
        ReleasableBytesReference content = restRequest.requiredContent();
        TimeValue inferTimeout = BaseInferenceAction.parseTimeout(restRequest);
        String productUseCase = this.extractProductUseCase(restRequest);
        InferenceContext context = new InferenceContext(productUseCase);
        InferenceActionProxy.Request request = new InferenceActionProxy.Request(params.taskType(), params.inferenceEntityId(), (BytesReference)content, restRequest.getXContentType(), inferTimeout, this.shouldStream(), context);
        return channel -> client.execute((ActionType)InferenceActionProxy.INSTANCE, (ActionRequest)request, ActionListener.withRef(this.listener((RestChannel)channel), (RefCounted)content));
    }

    protected abstract boolean shouldStream();

    protected abstract ActionListener<InferenceAction.Response> listener(RestChannel var1);

    private String extractProductUseCase(RestRequest restRequest) {
        Map headers = restRequest.getHeaders();
        if (Objects.isNull(headers) || headers.isEmpty()) {
            return "";
        }
        List productUseCaseHeaders = (List)headers.get("X-elastic-product-use-case");
        if (Objects.isNull(productUseCaseHeaders) || productUseCaseHeaders.isEmpty()) {
            return "";
        }
        return (String)productUseCaseHeaders.get(0);
    }

    record Params(String inferenceEntityId, TaskType taskType) {
    }
}

