"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.continueConversation = continueConversation;
var _gptTokenizer = require("gpt-tokenizer");
var _lodash = require("lodash");
var _rxjs = require("rxjs");
var _common = require("../../../../common");
var _conversation_complete = require("../../../../common/conversation_complete");
var _types = require("../../../../common/functions/types");
var _create_function_response_message = require("../../../../common/utils/create_function_response_message");
var _emit_with_concatenated_message = require("../../../../common/utils/emit_with_concatenated_message");
var _without_token_count_events = require("../../../../common/utils/without_token_count_events");
var _create_server_side_function_response_error = require("../../util/create_server_side_function_response_error");
var _get_system_message_from_instructions = require("../../util/get_system_message_from_instructions");
var _replace_system_message = require("../../util/replace_system_message");
var _catch_function_not_found_error = require("./catch_function_not_found_error");
var _extract_messages = require("./extract_messages");
var _hide_token_count_events = require("./hide_token_count_events");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Maximum number of gpt-tokenizer tokens a serialized function response may
 * contain before it is replaced by a truncated version.
 */
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

/**
 * Renders an error caught while running a function so it can be put in a log line.
 *
 * Plain `JSON.stringify` is not suitable:
 * - `Error` instances serialize to `{}`, because `name`, `message` and `stack` are
 *   non-enumerable.
 * - Values with circular references or BigInts make it throw. Inside `catchError`,
 *   that would turn a recoverable function failure into a stream error.
 *
 * @param {unknown} error - Whatever the function's promise rejected with.
 * @returns {string} A human-readable, never-throwing-to-produce description.
 */
function formatErrorForLog(error) {
  if (error instanceof Error) {
    return `${error.name}: ${error.message}`;
  }
  try {
    const serialized = JSON.stringify(error);
    // JSON.stringify returns undefined for undefined, functions and symbols
    if (serialized !== undefined) {
      return serialized;
    }
  } catch (serializationError) {
    // circular structure or BigInt value: fall back to String() below
  }
  return String(error);
}

/**
 * Executes a registered function and normalizes its outcome into conversation events.
 *
 * - `name`, `args`, `signal` and `messages` are forwarded to `functionClient.executeFunction`.
 * - `chat` is forwarded too, wrapped so that token count events from nested chat calls
 *   are hidden from the function.
 * - The resolved response is handled as follows:
 *   - an Observable is passed through unchanged;
 *   - an object with a `type` key (an already-formed event) is emitted as-is;
 *   - anything else becomes a function response message. Its JSON-serialized `content`
 *     is truncated when it is longer than MAX_FUNCTION_RESPONSE_TOKEN_COUNT tokens.
 * - A rejected `executeFunction` promise is logged and converted into a server-side
 *   function response error, so the LLM can react to it.
 *
 * @param {object} options
 * @param {string} options.name - Name of the function to run.
 * @param {*} options.args - Arguments forwarded unchanged to executeFunction.
 * @param {object} options.functionClient - Client exposing `executeFunction`.
 * @param {Array} options.messages - Conversation messages forwarded to the function.
 * @param {Function} options.chat - `(operationName, params) => Observable` used by functions for nested LLM calls.
 * @param {*} options.signal - Forwarded unchanged to executeFunction.
 * @param {object} options.logger - Logger used to report function failures.
 * @returns {Observable} Events produced by the function call.
 */
function executeFunctionAndCatchError({
  name,
  args,
  functionClient,
  messages,
  chat,
  signal,
  logger
}) {
  // hide token count events from functions to prevent them from
  // having to deal with it as well
  return (0, _hide_token_count_events.hideTokenCountEvents)(hide => {
    const executeFunctionResponse$ = (0, _rxjs.from)(functionClient.executeFunction({
      name,
      chat: (operationName, params) => {
        return chat(operationName, params).pipe(hide());
      },
      args,
      signal,
      messages
    }));
    return executeFunctionResponse$.pipe((0, _rxjs.catchError)(error => {
      logger.error(`Encountered error running function ${name}: ${formatErrorForLog(error)}`);
      // This catchError sits before the switchMap below, so it only recovers from a
      // rejected executeFunction promise. If the promise resolves to an Observable,
      // that Observable is subscribed inside switchMap and its errors are NOT caught.
      // This is intentional: the function may already have emitted values, and
      // recovering at that point could leave the conversation in an invalid state,
      // so the stream is allowed to fail instead.
      return (0, _rxjs.of)((0, _create_server_side_function_response_error.createServerSideFunctionResponseError)({
        name,
        error
      }));
    }), (0, _rxjs.switchMap)(response => {
      if ((0, _rxjs.isObservable)(response)) {
        return response;
      }

      // is messageAdd event
      if ('type' in response) {
        return (0, _rxjs.of)(response);
      }
      const encoded = (0, _gptTokenizer.encode)(JSON.stringify(response.content || {}));
      // A response of exactly MAX tokens still fits; only strictly longer ones are
      // truncated (and reported as such).
      const exceededTokenLimit = encoded.length > MAX_FUNCTION_RESPONSE_TOKEN_COUNT;
      return (0, _rxjs.of)((0, _create_function_response_message.createFunctionResponseMessage)({
        name,
        content: exceededTokenLimit ? {
          message: 'Function response exceeded the maximum length allowed and was truncated',
          truncated: (0, _gptTokenizer.decode)((0, _lodash.take)(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT))
        } : response.content,
        data: response.data
      }));
    }));
  });
}
/**
 * Builds the list of function definitions offered to the LLM for this step.
 *
 * When functions are disabled, or the function-call budget is exhausted, nothing is
 * offered. Otherwise the result contains, in this order:
 * - registered functions without a visibility, or visible to the assistant
 *   (AssistantOnly / All);
 * - the actions exposed by the function client.
 * Each entry is reduced to its `name`, `description` and `parameters`.
 *
 * @param {object} options
 * @param {object} options.functionClient - Provides `getFunctions()` and `getActions()`.
 * @param {boolean} options.functionLimitExceeded - True when no function calls are left.
 * @param {boolean} options.disableFunctions - True when functions are turned off entirely.
 * @returns {Array<object>} Trimmed function/action definitions.
 */
function getFunctionDefinitions({
  functionClient,
  functionLimitExceeded,
  disableFunctions
}) {
  const functionsAllowed = !functionLimitExceeded && !disableFunctions;
  if (!functionsAllowed) {
    return [];
  }
  const {
    pick
  } = _lodash;
  const {
    AssistantOnly,
    All
  } = _types.FunctionVisibility;
  // a definition without an explicit visibility is offered by default
  const isOfferedToAssistant = ({
    visibility
  }) => !visibility || visibility === AssistantOnly || visibility === All;
  const registeredDefinitions = functionClient.getFunctions().map(({
    definition
  }) => definition).filter(isOfferedToAssistant);
  const actionDefinitions = functionClient.getActions();
  return registeredDefinitions.concat(actionDefinitions).map(definition => pick(definition, 'name', 'description', 'parameters'));
}
/**
 * Runs one step of a conversation and recursively continues it while the step
 * produces new messages.
 *
 * The system message is first rebuilt from the registered, knowledge-base and
 * request instructions. The last message then decides the step:
 * - a user-role message is sent to `chat` together with the offered function
 *   definitions;
 * - a message without a `function_call` ends the stream (waiting for user input);
 * - a `function_call` is handled in one of four ways:
 *   - rejected, when the call budget is exhausted or the function is unknown;
 *   - validated, when it is a client-side action (the stream then closes and the
 *     client is left to run it);
 *   - executed server-side otherwise.
 * Every event of the step is forwarded. Messages extracted from those events are
 * appended to the conversation and fed into a recursive `continueConversation` call.
 *
 * @param {object} options
 * @param {Array} options.messages - Conversation so far; its system message is replaced.
 * @param {object} options.functionClient - Source of functions, actions, instructions and validation.
 * @param {Function} options.chat - `(operationName, params) => Observable` that calls the LLM.
 * @param {*} options.signal - Forwarded to functionClient.executeFunction for server-side functions.
 * @param {number} options.functionCallsLeft - Remaining function-call budget; at 0 or below no functions are offered.
 * @param {*} options.requestInstructions - Forwarded to getSystemMessageFromInstructions.
 * @param {*} options.knowledgeBaseInstructions - Forwarded to getSystemMessageFromInstructions.
 * @param {object} options.logger - Used to log function execution errors.
 * @param {boolean} options.disableFunctions - When true, no function definitions are offered.
 * @returns {Observable} Events of this step followed by those of all follow-up steps.
 */
function continueConversation({
  messages: initialMessages,
  functionClient,
  chat,
  signal,
  functionCallsLeft,
  requestInstructions,
  knowledgeBaseInstructions,
  logger,
  disableFunctions
}) {
  // Mutable budget: executeNextStep() decrements it when it handles a function
  // call, and handleEvents() passes the updated value to the recursive call.
  let nextFunctionCallsLeft = functionCallsLeft;
  const functionLimitExceeded = functionCallsLeft <= 0;
  const definitions = getFunctionDefinitions({
    functionLimitExceeded,
    functionClient,
    disableFunctions
  });
  // Rebuild the system message so it lists only the functions actually offered
  // in this step.
  const messagesWithUpdatedSystemMessage = (0, _replace_system_message.replaceSystemMessage)((0, _get_system_message_from_instructions.getSystemMessageFromInstructions)({
    registeredInstructions: functionClient.getInstructions(),
    knowledgeBaseInstructions,
    requestInstructions,
    availableFunctionNames: definitions.map(def => def.name)
  }), initialMessages);
  // The last message determines what executeNextStep() does.
  const lastMessage = messagesWithUpdatedSystemMessage[messagesWithUpdatedSystemMessage.length - 1].message;
  const isUserMessage = lastMessage.role === _common.MessageRole.User;
  // executeNextStep() is invoked eagerly here, so any decrement of
  // nextFunctionCallsLeft has already happened when handleEvents() reads it.
  return executeNextStep().pipe(handleEvents());

  /**
   * Produces the events for the current step, based on `lastMessage`.
   * @returns {Observable}
   */
  function executeNextStep() {
    // temporary used by the transpiled `lastMessage.function_call?.name`
    var _lastMessage$function;
    if (isUserMessage) {
      // A named user-role message (other than 'context') is traced as the
      // response of that function; anything else is traced as a user message.
      const operationName = lastMessage.name && lastMessage.name !== 'context' ? `function_response ${lastMessage.name}` : 'user_message';
      // When the budget is exhausted no functions are offered. catchFunctionNotFoundError
      // is then applied, presumably to handle the LLM calling one anyway
      // (NOTE(review): confirm against catch_function_not_found_error).
      return chat(operationName, {
        messages: messagesWithUpdatedSystemMessage,
        functions: definitions
      }).pipe((0, _emit_with_concatenated_message.emitWithConcatenatedMessage)(), functionLimitExceeded ? (0, _catch_function_not_found_error.catchFunctionNotFoundError)() : _rxjs.identity);
    }
    const functionCallName = (_lastMessage$function = lastMessage.function_call) === null || _lastMessage$function === void 0 ? void 0 : _lastMessage$function.name;
    if (!functionCallName) {
      // reply from the LLM without a function request,
      // so we can close the stream and wait for input from the user
      return _rxjs.EMPTY;
    }

    // we know we are executing a function here, so we can already
    // subtract one, and reference the old count for if clauses
    const currentFunctionCallsLeft = nextFunctionCallsLeft;
    nextFunctionCallsLeft--;
    // actions are registered with the function client separately from server-side functions
    const isAction = functionCallName && functionClient.hasAction(functionCallName);
    if (currentFunctionCallsLeft === 0) {
      // create a function call response error so the LLM knows it needs to stop calling functions
      return (0, _rxjs.of)((0, _create_server_side_function_response_error.createServerSideFunctionResponseError)({
        name: functionCallName,
        error: (0, _conversation_complete.createFunctionLimitExceededError)()
      }));
    }
    if (currentFunctionCallsLeft < 0) {
      // LLM tried calling it anyway, throw an error
      return (0, _rxjs.throwError)(() => (0, _conversation_complete.createFunctionLimitExceededError)());
    }

    // if it's an action, we close the stream and wait for the action response
    // from the client/browser
    if (isAction) {
      try {
        // JSON.parse failures are caught here too and reported back to the LLM
        functionClient.validate(functionCallName, JSON.parse(lastMessage.function_call.arguments || '{}'));
      } catch (error) {
        // return a function response error for the LLM to handle
        return (0, _rxjs.of)((0, _create_server_side_function_response_error.createServerSideFunctionResponseError)({
          name: functionCallName,
          error
        }));
      }
      return _rxjs.EMPTY;
    }
    if (!functionClient.hasFunction(functionCallName)) {
      // tell the LLM the function was not found
      return (0, _rxjs.of)((0, _create_server_side_function_response_error.createServerSideFunctionResponseError)({
        name: functionCallName,
        error: (0, _common.createFunctionNotFoundError)(functionCallName)
      }));
    }
    // server-side function: run it; arguments are passed through unparsed
    return executeFunctionAndCatchError({
      name: functionCallName,
      args: lastMessage.function_call.arguments,
      chat,
      functionClient,
      messages: messagesWithUpdatedSystemMessage,
      signal,
      logger
    });
  }

  /**
   * Operator that forwards every event of the current step. Once the step completes,
   * it extracts the messages the step produced and, if there are any, continues the
   * conversation with them appended.
   * @returns {Function} An `Observable -> Observable` operator.
   */
  function handleEvents() {
    return events$ => {
      // shareReplay lets the step's events be consumed twice (forwarded, then
      // scanned for messages) without re-subscribing to the source, i.e. without
      // calling chat or the function a second time.
      const shared$ = events$.pipe((0, _rxjs.shareReplay)());
      // concat subscribes to the follow-up stream only after this step's events
      // have completed.
      return (0, _rxjs.concat)(shared$, shared$.pipe((0, _without_token_count_events.withoutTokenCountEvents)(), (0, _extract_messages.extractMessages)(), (0, _rxjs.switchMap)(extractedMessages => {
        if (!extractedMessages.length) {
          // nothing new was added to the conversation: stop here
          return _rxjs.EMPTY;
        }
        return continueConversation({
          messages: messagesWithUpdatedSystemMessage.concat(extractedMessages),
          chat,
          functionCallsLeft: nextFunctionCallsLeft,
          functionClient,
          signal,
          knowledgeBaseInstructions,
          requestInstructions,
          logger,
          disableFunctions
        });
      })));
    };
  }
}