"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.streamGraph = exports.invokeGraph = void 0;
var _elasticApmNode = _interopRequireDefault(require("elastic-apm-node"));
var _server = require("@kbn/ml-response-stream/server");
var _event_based_telemetry = require("../../../telemetry/event_based_telemetry");
var _with_assistant_span = require("../../tracers/apm/with_assistant_span");
var _run_agent = require("./nodes/run_agent");
var _graph = require("./graph");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Execute the graph in streaming mode.
 *
 * Creates a streaming HTTP response and returns it right away; graph events are consumed in the
 * background and text emitted by the agent node is pushed to the response as
 * `{ payload, type: 'content' }`. The response (and the APM span, when one was started) is ended
 * exactly once: when the agent's final answer is seen, when consuming events fails, or - as a
 * fallback - when the event stream finishes without either of those having happened.
 *
 * Two event protocols are handled:
 *  - `streamEvents` v2 (`on_chat_model_stream` / `on_chat_model_end`) for the tool calling or
 *    structured chat agent. Used unless `inferenceChatModelDisabled` is true, and always for OSS
 *    models and the `bedrock` / `gemini` LLM types.
 *  - `streamEvents` v1 (`on_llm_stream` / `on_llm_end`) for the openai functions agent.
 *
 * @param apmTracer - callback handler placed first in the run's callbacks
 * @param assistantGraph - compiled graph exposing `streamEvents`
 * @param inputs - graph inputs; `isOssModel`, `llmType` and `provider` select protocol and options
 * @param inferenceChatModelDisabled - defaults to false; see protocol selection above
 * @param isEnabledKnowledgeBase - included in error telemetry
 * @param logger - used to log streaming failures and `onLlmResponse` failures
 * @param onLlmResponse - optional callback invoked once with (finalResponse, traceData, isError);
 *   its failures are logged and never prevent the response from being ended
 * @param request - request whose headers configure the stream and whose body feeds telemetry
 * @param telemetry - client used to report streaming errors
 * @param telemetryTracer - optional callback handler appended last when provided
 * @param traceOptions - optional `{ tracers, tags }` for the run
 * @returns the streaming response (`responseWithHeaders` from `streamFactory`)
 */
const streamGraph = async ({
  apmTracer,
  assistantGraph,
  inputs,
  inferenceChatModelDisabled = false,
  isEnabledKnowledgeBase,
  logger,
  onLlmResponse,
  request,
  telemetry,
  telemetryTracer,
  traceOptions
}) => {
  let streamingSpan;
  if (_elasticApmNode.default.isStarted()) {
    const span = _elasticApmNode.default.startSpan(`${_graph.DEFAULT_ASSISTANT_GRAPH_ID} (Streaming)`);
    streamingSpan = span != null ? span : undefined;
  }
  const {
    end: streamEnd,
    push,
    responseWithHeaders
  } = (0, _server.streamFactory)(request.headers, logger, false, false);

  // Run configuration shared by both streaming protocols below.
  const options = traceOptions != null ? traceOptions : {};
  const callbacks = [apmTracer, ...(options.tracers != null ? options.tracers : []), ...(telemetryTracer ? [telemetryTracer] : [])];
  const runTags = options.tags != null ? options.tags : [];
  const isOssModel = inputs != null ? inputs.isOssModel : undefined;
  const llmType = inputs != null ? inputs.llmType : undefined;

  // APM ids reported to `onLlmResponse`; both are undefined when no span was started.
  const getTraceData = () => {
    const transaction = streamingSpan != null ? streamingSpan.transaction : undefined;
    const transactionIds = transaction != null ? transaction.ids : undefined;
    const spanIds = streamingSpan != null ? streamingSpan.ids : undefined;
    return {
      transactionId: transactionIds != null ? transactionIds['transaction.id'] : undefined,
      traceId: spanIds != null ? spanIds['trace.id'] : undefined
    };
  };

  let didEnd = false;
  // Finish the response exactly once: report errors, notify `onLlmResponse`, end stream and span.
  const handleStreamEnd = (finalResponse, isError = false) => {
    if (didEnd) {
      return;
    }
    if (isError) {
      telemetry.reportEvent(_event_based_telemetry.INVOKE_ASSISTANT_ERROR_EVENT.eventType, {
        actionTypeId: request.body.actionTypeId,
        model: request.body.model,
        errorMessage: finalResponse,
        assistantStreamingEnabled: true,
        isEnabledKnowledgeBase,
        errorLocation: 'handleStreamEnd'
      });
    }
    if (onLlmResponse) {
      // The async wrapper turns a synchronous throw into a rejection, so a failing callback can
      // neither skip `streamEnd()` below nor disappear silently.
      (async () => onLlmResponse(finalResponse, getTraceData(), isError))().catch(err => {
        logger.error(`Error in onLlmResponse callback: ${err}`);
      });
    }
    streamEnd();
    didEnd = true;
    if (streamingSpan) {
      // Only set the outcome when it has not already been decided elsewhere.
      if (!streamingSpan.outcome || streamingSpan.outcome === 'unknown') {
        streamingSpan.outcome = isError ? 'failure' : 'success';
      }
      streamingSpan.end();
    }
  };

  // Failure while consuming graph events: log it and end the response as an error.
  const handleStreamError = err => {
    logger.error(`Error streaming graph: ${err}`);
    // Non-Error throwables (e.g. strings) carry no `message`; fall back to their string form.
    handleStreamEnd(err != null && err.message != null ? err.message : String(err), true);
  };

  // Stream is from tool calling agent or structured chat agent
  if (!inferenceChatModelDisabled || isOssModel || llmType === 'bedrock' || llmType === 'gemini') {
    const stream = await assistantGraph.streamEvents(inputs, {
      callbacks,
      runName: _graph.DEFAULT_ASSISTANT_GRAPH_ID,
      tags: runTags,
      version: 'v2',
      streamMode: 'values',
      recursionLimit: isOssModel ? 50 : 25
    }, llmType === 'bedrock' ? {
      includeNames: ['Summarizer']
    } : undefined);
    // Text pushed so far; only used if the stream finishes without a final agent message.
    let streamedContent = '';
    const pushStreamUpdate = async () => {
      for await (const {
        event,
        data,
        tags: eventTags
      } of stream) {
        if (!(eventTags || []).includes(_run_agent.AGENT_NODE_TAG)) {
          continue;
        }
        if (event === 'on_chat_model_stream' && !isOssModel) {
          const msg = data.chunk;
          const isToolCallChunk = Boolean(msg.tool_call_chunks && msg.tool_call_chunks.length);
          if (!didEnd && !isToolCallChunk && msg.content.length) {
            push({
              payload: msg.content,
              type: 'content'
            });
            if (typeof msg.content === 'string') {
              streamedContent += msg.content;
            }
          }
        }
        if (event === 'on_chat_model_end' && !didEnd) {
          const kwargs = data.output.lc_kwargs;
          // A model message that carries tool calls does not end the stream.
          const hasToolCalls = Boolean(kwargs && kwargs.tool_calls && kwargs.tool_calls.length);
          if (!hasToolCalls) {
            handleStreamEnd(data.output.content);
          }
        }
      }
      // Fallback: never leave the response, span and callback pending once events run out.
      if (!didEnd) {
        handleStreamEnd(streamedContent);
      }
    };
    pushStreamUpdate().catch(handleStreamError);
    return responseWithHeaders;
  }

  // Stream is from openai functions agent
  let finalMessage = '';
  const stream = await assistantGraph.streamEvents(inputs, {
    callbacks,
    runName: _graph.DEFAULT_ASSISTANT_GRAPH_ID,
    streamMode: 'values',
    tags: runTags,
    version: 'v1'
  }, (inputs != null ? inputs.provider : undefined) === 'bedrock' ? {
    includeNames: ['Summarizer']
  } : undefined);
  const pushStreamUpdate = async () => {
    for await (const {
      event,
      data,
      tags: eventTags
    } of stream) {
      if (!(eventTags || []).includes(_run_agent.AGENT_NODE_TAG)) {
        continue;
      }
      if (event === 'on_llm_stream') {
        const msg = data.chunk.message;
        const isToolCallChunk = msg != null && msg.tool_call_chunks && msg.tool_call_chunks.length > 0;
        if (!isToolCallChunk && !didEnd) {
          push({
            payload: msg.content,
            type: 'content'
          });
          finalMessage += msg.content;
        }
      }
      if (event === 'on_llm_end' && !didEnd) {
        const generation = data.output != null ? data.output.generations[0][0] : undefined;
        const generationInfo = generation != null ? generation.generationInfo : undefined;
        const finishReason = generationInfo != null ? generationInfo.finish_reason : undefined;
        // A null generation is treated as an error: completion is left to error handling (or to
        // the end-of-stream fallback below). No finish_reason means the stream was aborted.
        if (generation != null && (!finishReason || finishReason === 'stop')) {
          handleStreamEnd(generation.text && generation.text.length ? generation.text : finalMessage);
        }
      }
    }
    // Fallback, e.g. for a finish_reason other than 'stop': close with what was streamed.
    if (!didEnd) {
      handleStreamEnd(finalMessage);
    }
  };
  pushStreamUpdate().catch(handleStreamError);
  return responseWithHeaders;
};
exports.streamGraph = streamGraph;
/**
 * Execute the graph in non-streaming mode.
 *
 * Runs `assistantGraph.invoke` inside an assistant APM span, optionally hands the final output to
 * `onLlmResponse`, and returns it together with the span's trace identifiers.
 *
 * @param apmTracer - callback handler placed first in the run's callbacks
 * @param assistantGraph - compiled graph exposing `invoke`
 * @param inputs - graph inputs; a truthy `inputs.isOssModel` raises the recursion limit from 25 to 50
 * @param onLlmResponse - optional async callback awaited with (output, traceData)
 * @param telemetryTracer - optional callback handler appended last when provided
 * @param traceOptions - optional `{ tracers, tags, evaluationId }`
 * @returns `{ output, traceData, conversationId }`; `traceData` holds `transactionId` and `traceId`
 *   only when the span exposes both, otherwise it is an empty object
 */
const invokeGraph = async ({
  apmTracer,
  assistantGraph,
  inputs,
  onLlmResponse,
  telemetryTracer,
  traceOptions
}) => (0, _with_assistant_span.withAssistantSpan)(_graph.DEFAULT_ASSISTANT_GRAPH_ID, async span => {
  const options = traceOptions == null ? {} : traceOptions;

  // The transaction id is used since this span is the parent. `span.ids` is only consulted once a
  // transaction id has been found.
  const transactionId = span == null || span.transaction == null ? undefined : span.transaction.ids['transaction.id'];
  const traceId = transactionId == null ? undefined : span.ids['trace.id'];
  const hasApmIds = transactionId != null && traceId != null;
  const traceData = hasApmIds ? {
    transactionId,
    traceId
  } : {};
  if (hasApmIds) {
    span.addLabels({
      evaluationId: options.evaluationId
    });
  }

  const extraTracers = options.tracers == null ? [] : options.tracers;
  const optionalTelemetry = telemetryTracer ? [telemetryTracer] : [];
  const result = await assistantGraph.invoke(inputs, {
    callbacks: [apmTracer, ...extraTracers, ...optionalTelemetry],
    runName: _graph.DEFAULT_ASSISTANT_GRAPH_ID,
    tags: options.tags == null ? [] : options.tags,
    recursionLimit: inputs != null && inputs.isOssModel ? 50 : 25
  });

  const {
    agentOutcome,
    conversation
  } = result;
  const output = agentOutcome.returnValues.output;
  const conversationId = conversation == null ? undefined : conversation.id;
  if (onLlmResponse) {
    await onLlmResponse(output, traceData);
  }
  return {
    output,
    traceData,
    conversationId
  };
});
exports.invokeGraph = invokeGraph;