"use strict";

var _interopRequireDefault = require("@babel/runtime/helpers/interopRequireDefault");
Object.defineProperty(exports, "__esModule", {
  value: true
});
exports.FillMaskInference = void 0;
var _defineProperty2 = _interopRequireDefault(require("@babel/runtime/helpers/defineProperty"));
var _i18n = require("@kbn/i18n");
var _operators = require("rxjs/operators");
var _mlTrainedModelsUtils = require("@kbn/ml-trained-models-utils");
var _inference_base = require("../inference_base");
var _common = require("./common");
var _text_input = require("../text_input");
var _fill_mask_output = require("./fill_mask_output");
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

// Placeholder token for the missing word. FillMaskInference checks that every
// input text contains it, and predictedValue() substitutes its first occurrence.
const MASK = '[MASK]';
/**
 * Inference helper for the fill-mask PyTorch task.
 *
 * Extends InferenceBase with the fill-mask task type and labels, a check that
 * the input text contains the MASK placeholder, and helpers that turn raw
 * inference results into the shape used by the fill-mask output component.
 */
class FillMaskInference extends _inference_base.InferenceBase {
  /**
   * All arguments are forwarded unchanged to the InferenceBase constructor.
   *
   * @param trainedModelsApi forwarded to InferenceBase
   * @param model forwarded to InferenceBase; also passed to the response processors
   * @param inputType forwarded to InferenceBase; compared against INPUT_TYPE.TEXT in getInputComponent()
   * @param deploymentId forwarded to InferenceBase
   */
  constructor(trainedModelsApi, model, inputType, deploymentId) {
    super(trainedModelsApi, model, inputType, deploymentId);
    (0, _defineProperty2.default)(this, "inferenceType", _mlTrainedModelsUtils.SUPPORTED_PYTORCH_TASKS.FILL_MASK);
    (0, _defineProperty2.default)(this, "inferenceTypeLabel", _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.fillMask.label', {
      defaultMessage: 'Fill mask'
    }));
    (0, _defineProperty2.default)(this, "info", [_i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.fillMask.info1', {
      defaultMessage: 'Test how well the model predicts a missing word in a phrase.'
    })]);
    // Stream that emits true only when every input text contains the MASK token.
    // NOTE(review): presumably initialize() treats these streams as input
    // validators - confirm against InferenceBase.
    this.initialize([this.inputText$.pipe((0, _operators.map)(inputText => inputText.every(t => t.includes(MASK))))]);
  }

  /**
   * Runs inference through runInfer(), supplying the inference config built
   * from getNumTopClassesConfig() and a callback that converts each raw
   * response with processResponse().
   *
   * @returns the promise returned by runInfer()
   */
  async inferText() {
    return this.runInfer(() => {
      return this.getInferenceConfig(this.getNumTopClassesConfig());
    }, (resp, inputText) => {
      return (0, _common.processResponse)(resp, this.model, inputText);
    });
  }

  /**
   * Runs inference through runPipelineSimulate(), mapping each returned
   * document to its processed response, raw response and input text.
   *
   * @returns the promise returned by runPipelineSimulate()
   */
  async inferIndex() {
    return this.runPipelineSimulate(doc => {
      return {
        // The result is read from the _source field named after the inference type.
        response: (0, _common.processInferenceResult)(doc._source[this.inferenceType], this.model),
        rawResponse: doc._source[this.inferenceType],
        inputText: doc._source[this.getInputField()]
      };
    });
  }

  /**
   * @returns the result of getBasicProcessors() for the num-top-classes config
   */
  getProcessors() {
    return this.getBasicProcessors(this.getNumTopClassesConfig());
  }

  /**
   * Builds the input text with the first MASK occurrence replaced by the
   * value of the first entry in the response.
   *
   * @param resp object with `response` (array of results) and `inputText`
   * @returns the substituted text, or `inputText` unchanged when the first
   *   entry is missing or has no truthy `value`
   */
  predictedValue(resp) {
    const {
      response,
      inputText
    } = resp;
    const firstResult = response[0];
    if (firstResult === null || firstResult === undefined || !firstResult.value) {
      return inputText;
    }
    // Use a replacer function rather than a replacement string: replace()
    // interprets `$$`, `$&`, "$`" and `$'` inside a replacement string, which
    // would corrupt a predicted value that happens to contain them.
    return inputText.replace(MASK, () => firstResult.value);
  }

  /**
   * @returns the general text input component when the input type is
   *   INPUT_TYPE.TEXT, otherwise null
   */
  getInputComponent() {
    if (this.inputType !== _inference_base.INPUT_TYPE.TEXT) {
      return null;
    }
    const placeholder = _i18n.i18n.translate('xpack.ml.trainedModels.testModelsFlyout.fillMask.inputText', {
      defaultMessage: 'Enter a phrase to test. Use [MASK] as a placeholder for the missing words.'
    });
    return (0, _text_input.getGeneralInputComponent)(this, placeholder);
  }

  /**
   * @returns the fill-mask output component bound to this instance
   */
  getOutputComponent() {
    return (0, _fill_mask_output.getFillMaskOutputComponent)(this);
  }
}
exports.FillMaskInference = FillMaskInference;