"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.postEvaluateRoute = void 0;
var _securitysolutionEsUtils = require("@kbn/securitysolution-es-utils");
var _std = require("@kbn/std");
var _langsmith = require("langsmith");
var _evaluation = require("langsmith/evaluation");
var _uuid = require("uuid");
var _server = require("@kbn/data-plugin/server");
var _elasticAssistantCommon = require("@kbn/elastic-assistant-common");
var _common = require("@kbn/elastic-assistant-common/impl/schemas/common");
var _server2 = require("@kbn/langchain/server");
var _agents = require("langchain/agents");
var _fp = require("lodash/fp");
var _build_response = require("../../lib/build_response");
var _helpers = require("../helpers");
var _utils = require("./utils");
var _helpers2 = require("../../ai_assistant_data_clients/anonymization_fields/helpers");
var _evaluation2 = require("../../lib/attack_discovery/evaluation");
var _graph = require("../../lib/langchain/graphs/default_assistant_graph/graph");
var _prompts = require("../../lib/langchain/graphs/default_assistant_graph/prompts");
var _utils2 = require("../utils");
var _get_graphs_from_names = require("./get_graphs_from_names");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

// `size` placed in the skeleton request body that is handed to assistant tools.
const DEFAULT_SIZE = 20;

// Timeouts, in milliseconds. Each one is 10 seconds shorter than the one it is
// derived from (route handler -> LangChain -> connector) — presumably so inner
// work times out before the outer layer gives up; confirm against consumers.
const ROUTE_HANDLER_TIMEOUT = 600_000; // 10 minutes; used as the route's idle-socket timeout
const LANG_CHAIN_TIMEOUT = ROUTE_HANDLER_TIMEOUT - 10 * 1000; // 9 min 50 s
const CONNECTOR_TIMEOUT = LANG_CHAIN_TIMEOUT - 10 * 1000; // 9 min 40 s

/**
 * Number of graph executions above which a "this may take a while, and cost
 * some money" warning is logged.
 */
const MANY_GRAPH_EXECUTIONS_THRESHOLD = 50;

/**
 * Renders a thrown / rejected value as a readable string for log messages.
 *
 * `JSON.stringify(new Error('x'))` yields `{}` because `message` and `stack`
 * are non-enumerable, so `Error`s are rendered via their stack (or message).
 * Other values are JSON-encoded, falling back to `String()` when they cannot
 * be serialized (e.g. circular structures) or serialize to `undefined`.
 *
 * @param {unknown} err - the value that was thrown or rejected with
 * @returns {string}
 */
const formatError = err => {
  if (err instanceof Error) {
    return err.stack ?? err.message;
  }
  try {
    return JSON.stringify(err, null, 2) ?? String(err);
  } catch {
    return String(err);
  }
};

/**
 * Creates the LangChain agent runnable used by the default assistant graph,
 * choosing the agent flavour and prompt from the connector's LLM type:
 * OpenAI functions agent for OpenAI, tool-calling agent (with the provider's
 * prompt) for Bedrock / Gemini, and a structured chat agent otherwise.
 *
 * @param {object} params
 * @param params.llm - the LLM instance the agent drives
 * @param {string|undefined} params.llmType - result of `getLlmType` for the connector (may be falsy)
 * @param {boolean} params.isOpenAI - true for OpenAI connectors that are not open-source models
 * @param params.tools - assistant tools available to the agent
 * @returns {Promise<object>} the agent runnable
 */
const createAgentRunnable = async ({
  llm,
  llmType,
  isOpenAI,
  tools
}) => {
  if (isOpenAI) {
    return (0, _agents.createOpenAIFunctionsAgent)({
      llm,
      tools,
      prompt: _prompts.openAIFunctionAgentPrompt,
      streamRunnable: false
    });
  }
  if (llmType === 'bedrock' || llmType === 'gemini') {
    return (0, _agents.createToolCallingAgent)({
      llm,
      tools,
      prompt: llmType === 'bedrock' ? _prompts.bedrockToolCallingAgentPrompt : _prompts.geminiToolCallingAgentPrompt,
      streamRunnable: false
    });
  }
  return (0, _agents.createStructuredChatAgent)({
    llm,
    tools,
    prompt: _prompts.structuredChatAgentPrompt,
    streamRunnable: false
  });
};

/**
 * Registers the internal v1 `POST ELASTIC_AI_ASSISTANT_EVALUATE_URL` route.
 *
 * The handler:
 *  1. runs the license / authenticated-user / `assistantModelEvaluation` checks;
 *  2. fetches the requested LangSmith dataset and responds 400 if it is empty;
 *  3. if any attack discovery graphs are requested, starts their evaluation in
 *     the background and responds immediately (no other graphs are run);
 *  4. otherwise builds one default assistant graph per requested connector and
 *     starts a LangSmith `evaluate` run for each, also without awaiting it.
 *
 * A successful response (`{ evaluationId, success: true }`) only means the
 * evaluations were started; failures in the background runs are logged.
 * Synchronous failures before that point are returned as an error response.
 *
 * @param router - Kibana router the route is registered on
 * @param getElser - accepted for interface compatibility; not used by this route
 */
const postEvaluateRoute = (router, getElser) => {
  router.versioned.post({
    access: _elasticAssistantCommon.INTERNAL_API_ACCESS,
    path: _elasticAssistantCommon.ELASTIC_AI_ASSISTANT_EVALUATE_URL,
    options: {
      tags: ['access:elasticAssistant'],
      timeout: {
        idleSocket: ROUTE_HANDLER_TIMEOUT
      }
    }
  }).addVersion({
    version: _elasticAssistantCommon.API_VERSIONS.internal.v1,
    validate: {
      request: {
        body: (0, _common.buildRouteValidationWithZod)(_elasticAssistantCommon.PostEvaluateBody)
      },
      response: {
        200: {
          body: {
            custom: (0, _common.buildRouteValidationWithZod)(_elasticAssistantCommon.PostEvaluateResponse)
          }
        }
      }
    }
  }, async (context, request, response) => {
    const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']);
    const assistantContext = ctx.elasticAssistant;
    const actions = assistantContext.actions;
    const logger = assistantContext.logger.get('evaluate');
    const abortSignal = (0, _server.getRequestAbortedSignal)(request.events.aborted$);
    const v2KnowledgeBaseEnabled = (0, _helpers.isV2KnowledgeBaseEnabled)({
      context: ctx,
      request
    });

    // Perform license, authenticated user and evaluation FF checks
    const checkResponse = (0, _helpers.performChecks)({
      capability: 'assistantModelEvaluation',
      context: ctx,
      request,
      response
    });
    if (!checkResponse.isSuccess) {
      return checkResponse.response;
    }
    try {
      const evaluationId = (0, _uuid.v4)();
      const {
        alertsIndexPattern,
        datasetName,
        evaluatorConnectorId,
        graphs: graphNames,
        langSmithApiKey,
        langSmithProject,
        connectorIds,
        size,
        replacements,
        runName = evaluationId
      } = request.body;
      const dataset = await (0, _utils.fetchLangSmithDataset)(datasetName, logger, langSmithApiKey);
      if (dataset.length === 0) {
        return response.badRequest({
          body: {
            message: `No LangSmith dataset found for name: ${datasetName}`
          }
        });
      }
      logger.info('postEvaluateRoute:');
      logger.info(`request.query:\n${JSON.stringify(request.query, null, 2)}`);
      // Never log the LangSmith API key
      logger.info(`request.body:\n${JSON.stringify((0, _fp.omit)(['langSmithApiKey'], request.body), null, 2)}`);
      logger.info(`Evaluation ID: ${evaluationId}`);
      const totalExecutions = connectorIds.length * graphNames.length * dataset.length;
      logger.info('Creating graphs:');
      logger.info(`\tconnectors/models: ${connectorIds.length}`);
      logger.info(`\tgraphs: ${graphNames.length}`);
      logger.info(`\tdataset: ${dataset.length}`);
      logger.warn(`\ttotal graph executions: ${totalExecutions} `);
      if (totalExecutions > MANY_GRAPH_EXECUTIONS_THRESHOLD) {
        logger.warn(`Total baseline graph executions > ${MANY_GRAPH_EXECUTIONS_THRESHOLD}! This may take a while, and cost some money...`);
      }

      // Setup graph params
      // Get a scoped esClient for esStore + writing results to the output index
      const esClient = ctx.core.elasticsearch.client.asCurrentUser;
      const inference = assistantContext.inference;

      // Data clients (missing clients are normalized from null to undefined)
      const anonymizationFieldsDataClient = (await assistantContext.getAIAssistantAnonymizationFieldsDataClient()) ?? undefined;
      const conversationsDataClient = (await assistantContext.getAIAssistantConversationsDataClient()) ?? undefined;
      const kbDataClient = (await assistantContext.getAIAssistantKnowledgeBaseDataClient({
        v2KnowledgeBaseEnabled
      })) ?? undefined;
      const dataClients = {
        anonymizationFieldsDataClient,
        conversationsDataClient,
        kbDataClient
      };

      // Actions
      const actionsClient = await actions.getActionsClientWithRequest(request);
      const connectors = await actionsClient.getBulk({
        ids: connectorIds,
        throwIfSystemAction: false
      });

      // Fetch any tools registered to the security assistant
      const assistantTools = assistantContext.getRegisteredTools(_helpers.DEFAULT_PLUGIN_NAME);
      const {
        attackDiscoveryGraphs
      } = (0, _get_graphs_from_names.getGraphsFromNames)(graphNames);
      if (attackDiscoveryGraphs.length > 0) {
        // NOTE: we don't wait for the evaluation to finish here, because
        // the client will retry / timeout when evaluations take too long.
        // The async wrapper funnels both synchronous throws and asynchronous
        // rejections into `.catch`, so a failure is logged instead of becoming
        // an unhandled promise rejection.
        void (async () => (0, _evaluation2.evaluateAttackDiscovery)({
          actionsClient,
          alertsIndexPattern,
          attackDiscoveryGraphs,
          connectors,
          connectorTimeout: CONNECTOR_TIMEOUT,
          datasetName,
          esClient,
          evaluationId,
          evaluatorConnectorId,
          langSmithApiKey,
          langSmithProject,
          logger,
          runName,
          size
        }))().catch(err => {
          logger.error(() => `Error evaluating attack discovery: ${err}`);
        });

        // Return early whenever any attack discovery graph was requested: the
        // default assistant graph evaluation below is not started in that case.
        return response.ok({
          body: {
            evaluationId,
            success: true
          }
        });
      }

      // These lookups do not depend on the connector, so run them once instead
      // of once per connector.
      const [anonymizationFieldsRes, isModelDeployed] = await Promise.all([anonymizationFieldsDataClient?.findDocuments({
        perPage: 1000,
        page: 1
      }), kbDataClient?.isModelDeployed()]);
      const anonymizationFields = anonymizationFieldsRes ? (0, _helpers2.transformESSearchToAnonymizationFields)(anonymizationFieldsRes.data) : undefined;

      // Check if KB is available
      const isEnabledKnowledgeBase = isModelDeployed ?? false;
      const graphs = await Promise.all(connectors.map(async connector => {
        const llmType = (0, _utils2.getLlmType)(connector.actionTypeId);
        const isOssModel = (0, _utils2.isOpenSourceModel)(connector);
        const isOpenAI = llmType === 'openai' && !isOssModel;
        const llmClass = (0, _utils2.getLlmClass)(llmType);
        const createLlmInstance = () => new llmClass({
          actionsClient,
          connectorId: connector.id,
          llmType,
          logger,
          temperature: (0, _server2.getDefaultArguments)(llmType).temperature,
          signal: abortSignal,
          streaming: false,
          maxRetries: 0
        });
        const llm = createLlmInstance();

        // Skeleton request from route to pass to the agents
        // params will be passed to the actions executor
        const skeletonRequest = {
          ...request,
          body: {
            alertsIndexPattern: '',
            allow: [],
            allowReplacement: [],
            subAction: 'invokeAI',
            // The actionTypeId is irrelevant when used with the invokeAI subaction
            actionTypeId: '.gen-ai',
            replacements: {},
            size: DEFAULT_SIZE,
            conversationId: ''
          }
        };

        // Fetch any applicable tools that the source plugin may have registered
        const assistantToolParams = {
          anonymizationFields,
          esClient,
          isEnabledKnowledgeBase,
          kbDataClient,
          llm,
          isOssModel,
          logger,
          request: skeletonRequest,
          alertsIndexPattern,
          replacements,
          inference,
          connectorId: connector.id,
          size,
          telemetry: assistantContext.telemetry
        };
        const tools = assistantTools.flatMap(tool => tool.getTool(assistantToolParams) ?? []);
        const agentRunnable = await createAgentRunnable({
          llm,
          llmType,
          isOpenAI,
          tools
        });
        return {
          name: `${runName} - ${connector.name}`,
          llmType,
          isOssModel,
          graph: (0, _graph.getDefaultAssistantGraph)({
            agentRunnable,
            dataClients,
            createLlmInstance,
            logger,
            tools,
            replacements: {}
          })
        };
      }));

      // Run an evaluation for each graph so they show up separately (resulting
      // in each dataset run grouped by connector). The evaluations are
      // intentionally not awaited: the route responds once they are started and
      // each run logs its own outcome.
      for (const {
        name,
        graph,
        llmType,
        isOssModel
      } of graphs) {
        // Wrapper function for invoking the graph (to parse different input/output formats)
        const predict = async input => {
          logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
          const r = await graph.invoke({
            input: input.input,
            conversationId: undefined,
            responseLanguage: 'English',
            llmType,
            isStreaming: false,
            isOssModel
          },
          // TODO: Update to use the correct input format per dataset type
          {
            runName,
            tags: ['evaluation']
          });
          return r.agentOutcome.returnValues.output;
        };
        void (0, _evaluation.evaluate)(predict, {
          data: datasetName ?? '',
          evaluators: [],
          // Evals to be managed in LangSmith for now
          experimentPrefix: name,
          client: new _langsmith.Client({
            apiKey: langSmithApiKey
          }),
          // prevent rate limiting and unexpected multiple experiment runs
          maxConcurrency: 5
        }).then(output => {
          logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`);
        }).catch(err => {
          // JSON.stringify(err) would print `{}` for Error instances
          logger.error(`evaluation error:\n ${formatError(err)}`);
        });
      }
      return response.ok({
        body: {
          evaluationId,
          success: true
        }
      });
    } catch (err) {
      logger.error(err);
      const error = (0, _securitysolutionEsUtils.transformError)(err);
      const resp = (0, _build_response.buildResponse)(response);
      return resp.error({
        body: {
          success: false,
          error: error.message
        },
        statusCode: error.statusCode
      });
    }
  });
};
exports.postEvaluateRoute = postEvaluateRoute;