Program Listing for File translation_model.h

Return to documentation for file (src/translator/translation_model.h)

#ifndef SRC_BERGAMOT_TRANSLATION_MODEL_H_
#define SRC_BERGAMOT_TRANSLATION_MODEL_H_

#include <atomic>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "batch.h"
#include "batching_pool.h"
#include "byte_array_util.h"
#include "cache.h"
#include "common/utils.h"
#include "data/shortlist.h"
#include "definitions.h"
#include "parser.h"
#include "request.h"
#include "text_processor.h"
#include "translator/history.h"
#include "translator/scorers.h"
#include "vocabs.h"

namespace marian {
namespace bergamot {


/// A translation model bundled with the state needed to serve requests for it:
/// parsed options, model memory, vocabularies, a text processor, and a pool that
/// batches queued requests.
///
/// Flow suggested by the public API: build a Request with makeRequest() or
/// makePivotRequest(), queue it with enqueueRequest(), pull work with
/// generateBatch(), then run it with translateBatch().
class TranslationModel {
 public:
  using Config = Ptr<Options>;
  using ShortlistGenerator = Ptr<data::ShortlistGenerator const>;

  /// Construct from an options string.
  /// @param config   options text, parsed with parseOptionsFromString() (validation disabled).
  /// @param memory   model resources supplied by the caller; moved into this object.
  /// @param replicas number of backend replicas (presumably one per worker -- confirm).
  TranslationModel(const std::string& config, MemoryBundle&& memory, size_t replicas = 1)
      : TranslationModel(parseOptionsFromString(config, /*validate=*/false), std::move(memory), replicas) {}

  /// Construct from already-parsed options and caller-supplied memory (defined out of line).
  TranslationModel(const Config& options, MemoryBundle&& memory, size_t replicas = 1);

  /// Construct from parsed options, obtaining the memory via getMemoryBundleFromConfig(options).
  /// `explicit` because this constructor is callable with a single argument: without it a bare
  /// Config would silently convert into a fully constructed TranslationModel.
  explicit TranslationModel(const Config& options, size_t replicas = 1)
      : TranslationModel(options, getMemoryBundleFromConfig(options), replicas) {}

  /// Create a Request for `source` bound to this model.
  /// @param requestId       identifier attached to the request.
  /// @param source          input text; moved from.
  /// @param callback        completion callback (presumably invoked with the Response -- confirm).
  /// @param responseOptions options controlling what the response contains.
  /// @param cache           optional translation cache, taken by non-const reference.
  /// @returns Request created from the query parameters wrapped within a shared-pointer.
  Ptr<Request> makeRequest(size_t requestId, std::string&& source, CallbackType callback,
                           const ResponseOptions& responseOptions, std::optional<TranslationCache>& cache);

  /// Like makeRequest(), but the source is an already-annotated text. Presumably this is the
  /// target produced by a previous model when pivoting through an intermediate language -- confirm.
  /// @returns Request wrapped within a shared-pointer.
  Ptr<Request> makePivotRequest(size_t requestId, AnnotatedText&& previousTarget, CallbackType callback,
                                const ResponseOptions& responseOptions, std::optional<TranslationCache>& cache);

  /// Hand `request` to this model's batching pool.
  /// @returns whatever BatchingPool::enqueueRequest reports (presumably the amount of work queued -- confirm).
  size_t enqueueRequest(Ptr<Request> request) { return batchingPool_.enqueueRequest(request); }

  /// Ask this model's batching pool to fill `batch` with queued work.
  /// @returns whatever BatchingPool::generateBatch reports (presumably the number of sentences added -- confirm).
  size_t generateBatch(Batch& batch) { return batchingPool_.generateBatch(batch); }

  /// Translate `batch` on worker/device `deviceId` (defined out of line).
  void translateBatch(size_t deviceId, Batch& batch);

  /// Identifier of this model instance.
  size_t modelId() const { return modelId_; }

 private:
  size_t modelId_;
  Config options_;
  MemoryBundle memory_;
  Vocabs vocabs_;
  TextProcessor textProcessor_;

  // Queue of pending requests from which batches are generated.
  BatchingPool batchingPool_;

  // Marian inference state for a single replica: expression graph, scorer ensemble, and a flag
  // recording whether it has been set up (presumably by loadBackend() -- confirm).
  struct MarianBackend {
    using Graph = Ptr<ExpressionGraph>;
    using ScorerEnsemble = std::vector<Ptr<Scorer>>;

    Graph graph;
    ScorerEnsemble scorerEnsemble;
    bool initialized{false};
  };

  // ShortlistGenerator is purely const, we don't need one per thread.
  ShortlistGenerator shortlistGenerator_;

  // One backend slot per replica (presumably indexed by deviceId -- confirm).
  std::vector<MarianBackend> backend_;
  Ptr<QualityEstimator> qualityEstimator_;

  // Set up backend_[idx] (defined out of line).
  void loadBackend(size_t idx);
  // Convert a bergamot Batch into Marian's CorpusBatch input format.
  Ptr<marian::data::CorpusBatch> convertToMarianBatch(Batch& batch);

  // Class-wide counter; presumably used to assign modelId_ (defined in the .cpp).
  static std::atomic<size_t> modelCounter_;
};

}  // namespace bergamot
}  // namespace marian

#endif  //  SRC_BERGAMOT_TRANSLATION_MODEL_H_