Program Listing for File translation_model.cpp

Source file: src/translator/translation_model.cpp

#include "translation_model.h"

#include "batch.h"
#include "byte_array_util.h"
#include "cache.h"
#include "common/logging.h"
#include "data/corpus.h"
#include "data/text_input.h"
#include "html.h"
#include "parser.h"
#include "translator/beam_search.h"

namespace marian {
namespace bergamot {

std::atomic<size_t> TranslationModel::modelCounter_ = 0;

TranslationModel::TranslationModel(const Config &options, MemoryBundle &&memory /*=MemoryBundle{}*/,
                                   size_t replicas /*=1*/)
    : modelId_(modelCounter_++),
      options_(options),
      memory_(std::move(memory)),
      vocabs_(options, std::move(memory_.vocabs)),
      textProcessor_(options, vocabs_, std::move(memory_.ssplitPrefixFile)),
      batchingPool_(options),
      qualityEstimator_(createQualityEstimator(getQualityEstimatorModel(memory_, options))) {
  ABORT_IF(replicas == 0, "At least one replica needs to be created.");
  backend_.resize(replicas);

  // Try to load the shortlist from the memory bundle. If unavailable, fall back to options_.

  int srcIdx = 0, trgIdx = 1;
  // vocabs_.sources().front() is used because we currently support only one source vocabulary.
  bool shared_vcb = (vocabs_.sources().front() == vocabs_.target());

  if (memory_.shortlist.size() > 0 && memory_.shortlist.begin() != nullptr) {
    bool check = options_->get<bool>("check-bytearray", false);
    shortlistGenerator_ = New<data::BinaryShortlistGenerator>(memory_.shortlist.begin(), memory_.shortlist.size(),
                                                              vocabs_.sources().front(), vocabs_.target(), srcIdx,
                                                              trgIdx, shared_vcb, check);
  } else if (options_->hasAndNotEmpty("shortlist")) {
    // BinaryShortlistGenerator is used here so that binary shortlist files can be loaded;
    // it also supports text shortlist files.
    shortlistGenerator_ = New<data::BinaryShortlistGenerator>(options_, vocabs_.sources().front(), vocabs_.target(),
                                                              srcIdx, trgIdx, shared_vcb);
  } else {
    // Neither the memory bundle nor the options provide a shortlist; translation proceeds without one.
    shortlistGenerator_ = nullptr;
  }
}
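
// A minimal construction sketch (hypothetical usage, not part of this file), assuming a YAML
// config loadable via parseOptionsFromFilePath() from parser.h and an empty MemoryBundle so
// that all artifacts are read from the paths named in the config:
//
//   auto options = parseOptionsFromFilePath("path/to/model.config.yml");
//   auto model = New<TranslationModel>(options, MemoryBundle{}, /*replicas=*/1);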

void TranslationModel::loadBackend(size_t idx) {
  auto &graph = backend_[idx].graph;
  auto &scorerEnsemble = backend_[idx].scorerEnsemble;

  marian::DeviceId device_(idx, DeviceType::cpu);
  graph = New<ExpressionGraph>(/*inference=*/true);  // set the graph to be inference only
  auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
  graph->setDefaultElementType(typeFromString(prec[0]));
  graph->setDevice(device_);
  graph->getBackend()->configureDevice(options_);
  graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));

  // If memory_.models is populated, then all models are in binary format.
  if (memory_.models.size() >= 1) {
    const std::vector<const void *> container = std::invoke([&]() {
      std::vector<const void *> model_ptrs(memory_.models.size());
      for (size_t i = 0; i < memory_.models.size(); ++i) {
        const AlignedMemory &model = memory_.models[i];

        ABORT_IF(model.size() == 0 || model.begin() == nullptr, "The provided memory is empty. Cannot load the model.");
        ABORT_IF(
            (uintptr_t)model.begin() % 256 != 0,
            "The provided memory is not aligned to 256 bytes and will crash when vector instructions are used on it.");
        if (options_->get<bool>("check-bytearray", false)) {
          ABORT_IF(!validateBinaryModel(model, model.size()),
                   "The binary file is invalid. Incomplete or corrupted download?");
        }

        model_ptrs[i] = model.begin();
        LOG(debug, "Loaded model {} of {} from memory", (i + 1), model_ptrs.size());
      }
      return model_ptrs;
    });

    scorerEnsemble = createScorers(options_, container);
  } else {
    // Load npz-format models (or a mixture of binary and npz formats) from file paths.
    scorerEnsemble = createScorers(options_);
    LOG(debug, "Loaded {} model(s) from file", scorerEnsemble.size());
  }

  for (auto scorer : scorerEnsemble) {
    scorer->init(graph);
    if (shortlistGenerator_) {
      scorer->setShortlistGenerator(shortlistGenerator_);
    }
  }
  graph->forward();
}
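
// loadBackend consumes a handful of keys from options_; a hedged sketch of the corresponding
// YAML (illustrative values only, not a complete model config):
//
//   precision: [float32]     # graph element type; only the first entry is used here
//   workspace: 128           # graph workspace to reserve, in MB
//   check-bytearray: false   # validate in-memory binary models before use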

// makeRequest is shared between the async and blocking translation workflows.
Ptr<Request> TranslationModel::makeRequest(size_t requestId, std::string &&source, CallbackType callback,
                                           const ResponseOptions &responseOptions,
                                           std::optional<TranslationCache> &cache) {
  Segments segments;
  AnnotatedText annotatedSource;

  textProcessor_.process(std::move(source), annotatedSource, segments);
  ResponseBuilder responseBuilder(responseOptions, std::move(annotatedSource), vocabs_, callback, *qualityEstimator_);

  Ptr<Request> request =
      New<Request>(requestId, /*model=*/*this, std::move(segments), std::move(responseBuilder), cache);
  return request;
}
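
// A hedged usage sketch (hypothetical; in practice the service layer drives this), assuming
// CallbackType is invocable with a finished Response:
//
//   std::optional<TranslationCache> cache;  // empty: no caching
//   auto request = model->makeRequest(
//       /*requestId=*/0, std::string("Hello world."),
//       [](Response &&response) { LOG(info, "{}", response.target.text); },
//       ResponseOptions{}, cache);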

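// makePivotRequest builds the second leg of a pivoted (source -> pivot -> target) translation:
// the first leg's target arrives already annotated, so its sentence boundaries are reused via
// processFromAnnotation instead of being recomputed.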
Ptr<Request> TranslationModel::makePivotRequest(size_t requestId, AnnotatedText &&previousTarget, CallbackType callback,
                                                const ResponseOptions &responseOptions,
                                                std::optional<TranslationCache> &cache) {
  Segments segments;

  textProcessor_.processFromAnnotation(previousTarget, segments);
  ResponseBuilder responseBuilder(responseOptions, std::move(previousTarget), vocabs_, callback, *qualityEstimator_);

  Ptr<Request> request = New<Request>(requestId, *this, std::move(segments), std::move(responseBuilder), cache);
  return request;
}

Ptr<marian::data::CorpusBatch> TranslationModel::convertToMarianBatch(Batch &batch) {
  std::vector<data::SentenceTuple> batchVector;
  auto &sentences = batch.sentences();

  size_t batchSequenceNumber{0};
  for (auto &sentence : sentences) {
    data::SentenceTuple sentence_tuple(batchSequenceNumber);
    Segment segment = sentence.getUnderlyingSegment();
    sentence_tuple.push_back(segment);
    batchVector.push_back(sentence_tuple);

    ++batchSequenceNumber;
  }

  // Usually one would expect inputs to be [B x T], where B = batch size and T = max sequence length among the
  // sentences in the batch. However, marian supports multi-source models and ensembling over different source
  // vocabularies with a shared target vocabulary, which makes the inputs three-dimensional once converted into
  // marian's batch format.
  //
  // Consequently, B x T projects to N x B x T, where N = the number of source streams (one per source vocabulary).
  // This code does not hard-code N = 1; it stays general, although with a single source vocabulary the loops over
  // the N dimension run exactly once.
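  //
  // Worked toy example (illustrative): with N = 1 source stream and B = 2 sentences of token lengths 3 and 2,
  // maxDims = {3}, and the single SubBatch below stores a T x B = 3 x 2 grid laid out time-major via the index
  // k * B + i:
  //
  //   data: [s0t0, s1t0, s0t1, s1t1, s0t2, pad]
  //   mask: [   1,    1,    1,    1,    1,   0]
  //
  // Sentence s1 contributes no token at position t2, so its mask entry stays 0.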

  size_t batchSize = batchVector.size();

  std::vector<size_t> sentenceIds;
  std::vector<int> maxDims;

  for (auto &example : batchVector) {
    if (maxDims.size() < example.size()) {
      maxDims.resize(example.size(), 0);
    }
    for (size_t i = 0; i < example.size(); ++i) {
      if (example[i].size() > static_cast<size_t>(maxDims[i])) {
        maxDims[i] = static_cast<int>(example[i].size());
      }
    }
    sentenceIds.push_back(example.getId());
  }

  using SubBatch = marian::data::SubBatch;
  std::vector<Ptr<SubBatch>> subBatches;
  for (size_t j = 0; j < maxDims.size(); ++j) {
    subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_.sources().at(j)));
  }

  std::vector<size_t> words(maxDims.size(), 0);
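  // Fill each SubBatch time-major: token position k advances in strides of batchSize while sentence index i selects
  // the column, so token k of sentence i lands at offset k * batchSize + i; mask entries past a sentence's length
  // remain 0 (padding).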
  for (size_t i = 0; i < batchSize; ++i) {
    for (size_t j = 0; j < maxDims.size(); ++j) {
      for (size_t k = 0; k < batchVector[i][j].size(); ++k) {
        subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k];
        subBatches[j]->mask()[k * batchSize + i] = 1.f;
        words[j]++;
      }
    }
  }

  for (size_t j = 0; j < maxDims.size(); ++j) {
    subBatches[j]->setWords(words[j]);
  }

  using CorpusBatch = marian::data::CorpusBatch;
  Ptr<CorpusBatch> corpusBatch = New<CorpusBatch>(subBatches);
  corpusBatch->setSentenceIds(sentenceIds);
  return corpusBatch;
}

void TranslationModel::translateBatch(size_t deviceId, Batch &batch) {
  auto &backend = backend_[deviceId];

  if (!backend.initialized) {
    loadBackend(deviceId);
    backend.initialized = true;
  }

  BeamSearch search(options_, backend.scorerEnsemble, vocabs_.target());
  Histories histories = search.search(backend.graph, convertToMarianBatch(batch));
  batch.completeBatch(histories);
}
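
// A hedged sketch of the calling pattern (hypothetical driver; in practice worker threads in the
// service layer own this loop), where each worker passes its own deviceId so that a backend
// replica is only ever touched by one thread:
//
//   // inside worker thread w:
//   Batch batch;  // filled from the model's batching pool
//   model->translateBatch(/*deviceId=*/w, batch);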

}  // namespace bergamot
}  // namespace marian