"use strict";

Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.getGenerateNode = void 0;
var _discard_previous_generations = require("./helpers/discard_previous_generations");
var _extract_json = require("../helpers/extract_json");
var _get_anonymized_docs_from_state = require("./helpers/get_anonymized_docs_from_state");
var _get_chain_with_format_instructions = require("../helpers/get_chain_with_format_instructions");
var _get_combined = require("../helpers/get_combined");
var _generations_are_repeating = require("../helpers/generations_are_repeating");
var _get_use_unrefined_results = require("./helpers/get_use_unrefined_results");
var _parse_combined_or_throw = require("../helpers/parse_combined_or_throw");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

/**
 * Creates the `generate` node of the generation graph.
 *
 * Each invocation of the returned node asks the LLM for (the next part of) a
 * response, appends it to the generations accumulated in `state`, and tries
 * to parse the combined output with `generationSchema`.
 *
 * @param {object} params
 * @param {object} params.llm - model handed to `getChainWithFormatInstructions`;
 *   its `_llmType()` is also used in some log messages
 * @param {object} [params.logger] - optional logger; only `debug` is called,
 *   always with a message callback rather than a string
 * @param {Function} params.getCombinedPromptFn - builds the query from the
 *   anonymized docs, the prompt, the generations combined so far, and the
 *   continue prompt
 * @param {Function} params.responseIsHallucinated - predicate applied to the
 *   extracted partial response; `true` discards the accumulated generations
 * @param {object} params.generationSchema - passed to
 *   `getChainWithFormatInstructions` and `parseCombinedOrThrow`
 * @returns {(state: object) => Promise<object>} the async graph node
 */
const getGenerateNode = ({
  llm,
  logger,
  getCombinedPromptFn,
  responseIsHallucinated,
  generationSchema
}) => {
  // Forwards a message callback to `logger.debug` when a logger was provided;
  // the callback (not a prebuilt string) is passed so the logger can build it lazily.
  const logDebug = getMessage => {
    if (logger !== null && logger !== void 0) {
      logger.debug(getMessage);
    }
  };

  /**
   * Runs one generation attempt and resolves to the next graph state.
   *
   * Outcomes:
   * - the response is hallucinated, or `generationsAreRepeating` reports that
   *   it repeats the last `maxRepeatedGenerations - 1` generations: the state
   *   built by `discardPreviousGenerations` is returned (start over);
   * - the combined response parses: `combinedGenerations`, `generations`,
   *   `unrefinedResults` and `generationAttempts` are updated, and `insights`
   *   is set to the unrefined results only when `getUseUnrefinedResults`
   *   returns truthy (otherwise `null`);
   * - anything inside the `try` throws: the error is logged at debug level,
   *   appended to `state.errors`, and the attempt counter is incremented.
   *
   * `getAnonymizedDocsFromState` runs outside the `try`, so errors it throws
   * propagate to the caller instead of being recorded in `state.errors`.
   */
  const generate = async state => {
    logDebug(() => `---GENERATE---`);
    const anonymizedDocs = (0, _get_anonymized_docs_from_state.getAnonymizedDocsFromState)(state);
    const {
      prompt,
      continuePrompt,
      combinedGenerations,
      generationAttempts,
      generations,
      hallucinationFailures,
      maxGenerationAttempts,
      maxRepeatedGenerations
    } = state;

    // `let` (not `const`) because the catch block reports whatever was computed
    // before the failure; both stay '' if the failure happens earlier.
    let combinedResponse = '';
    let partialResponse = '';

    try {
      const query = getCombinedPromptFn({
        anonymizedDocs,
        prompt,
        combinedMaybePartialResults: combinedGenerations,
        continuePrompt
      });
      const {
        chain,
        formatInstructions,
        llmType
      } = (0, _get_chain_with_format_instructions.getChainWithFormatInstructions)({
        llm,
        generationSchema
      });
      logDebug(() => `generate node is invoking the chain (${llmType}), attempt ${generationAttempts}`);
      const rawResponse = await chain.invoke({
        format_instructions: formatInstructions,
        query
      });

      // LOCAL MUTATION: strip the surrounding ```json``` fence
      partialResponse = (0, _extract_json.extractJson)(rawResponse);

      // a hallucinated response invalidates everything accumulated so far:
      if (responseIsHallucinated(partialResponse)) {
        logDebug(() => `generate node detected a hallucination (${llmType}), on attempt ${generationAttempts}; discarding the accumulated generations and starting over`);
        return (0, _discard_previous_generations.discardPreviousGenerations)({
          generationAttempts,
          hallucinationFailures,
          isHallucinationDetected: true,
          state
        });
      }

      // the model is stuck repeating itself; discard and start over:
      const isRepeating = (0, _generations_are_repeating.generationsAreRepeating)({
        currentGeneration: partialResponse,
        previousGenerations: generations,
        sampleLastNGenerations: maxRepeatedGenerations - 1
      });
      if (isRepeating) {
        logDebug(() => `generate node detected ${maxRepeatedGenerations} repeated generations (${llmType}) on attempt ${generationAttempts}; discarding the accumulated results and starting over`);
        return (0, _discard_previous_generations.discardPreviousGenerations)({
          generationAttempts,
          hallucinationFailures,
          isHallucinationDetected: false,
          state
        });
      }

      // LOCAL MUTATION: append the new partial response to the previous ones
      combinedResponse = (0, _get_combined.getCombined)({
        combinedGenerations,
        partialResponse
      });

      // throws when the combined response does not (yet) parse; handled below
      const unrefinedResults = (0, _parse_combined_or_throw.parseCombinedOrThrow)({
        combinedResponse,
        generationAttempts,
        llmType,
        logger,
        nodeName: 'generate',
        generationSchema
      });

      // use the unrefined results if we already reached the max number of retries:
      const useUnrefinedResults = (0, _get_use_unrefined_results.getUseUnrefinedResults)({
        generationAttempts,
        maxGenerationAttempts,
        unrefinedResults
      });
      if (useUnrefinedResults) {
        // NOTE(review): uses llm._llmType() rather than the `llmType` above;
        // presumably the same value — confirm before unifying.
        logDebug(() => `generate node is using unrefined results response (${llm._llmType()}) from attempt ${generationAttempts}, because all attempts have been used`);
      }
      return {
        ...state,
        // a non-null `insights` optionally skips the refinement step by returning the final answer
        insights: useUnrefinedResults ? unrefinedResults : null,
        combinedGenerations: combinedResponse,
        generationAttempts: generationAttempts + 1,
        generations: [...generations, partialResponse],
        unrefinedResults
      };
    } catch (error) {
      // Logged at debug level because the error is expected when the model
      // returns an incomplete response. `llmType` is scoped to the try block,
      // so the type is read from `llm` directly here.
      const parsingError = `generate node is unable to parse (${llm._llmType()}) response from attempt ${generationAttempts}; (this may be an incomplete response from the model): ${error}`;
      logDebug(() => parsingError);

      // NOTE(review): if the failure happened before `getCombined` ran (e.g. the
      // chain invocation rejected), `combinedGenerations` is reset to '' and, if
      // before `extractJson`, '' is appended to `generations` — confirm intended.
      return {
        ...state,
        combinedGenerations: combinedResponse,
        errors: [...state.errors, parsingError],
        generationAttempts: generationAttempts + 1,
        generations: [...generations, partialResponse]
      };
    }
  };
  return generate;
};
exports.getGenerateNode = getGenerateNode;