Program Listing for File translation_model.h
Return to documentation for file (src/translator/translation_model.h)
#ifndef SRC_BERGAMOT_TRANSLATION_MODEL_H_
#define SRC_BERGAMOT_TRANSLATION_MODEL_H_
#include <atomic>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "batch.h"
#include "batching_pool.h"
#include "byte_array_util.h"
#include "cache.h"
#include "common/utils.h"
#include "data/shortlist.h"
#include "definitions.h"
#include "parser.h"
#include "request.h"
#include "text_processor.h"
#include "translator/history.h"
#include "translator/scorers.h"
#include "vocabs.h"
namespace marian {
namespace bergamot {
/// Owns one translation model and the state used to translate with it:
/// - the parsed options and the memory bundle supplied at construction;
/// - vocabularies, a text processor and a batching pool for queued requests;
/// - a single shared (const) shortlist generator and a quality estimator;
/// - a list of Marian backends, each holding an expression graph plus a scorer ensemble.
///
/// NOTE(review): backend_ is presumably sized from `replicas`, with one entry per worker/device (translateBatch takes a
/// deviceId). Confirm against the out-of-line constructor.
class TranslationModel {
 public:
  /// Parsed model configuration (marian options), held by shared pointer.
  using Config = Ptr<Options>;
  /// Pointer-to-const shortlist generator. Being const, one instance can be shared rather than duplicated.
  using ShortlistGenerator = Ptr<data::ShortlistGenerator const>;

  /// Builds a model from a configuration string.
  /// @param config    configuration text; parsed with parseOptionsFromString() with validation disabled.
  /// @param memory    memory bundle for the model; moved into the model.
  /// @param replicas  replica count forwarded to the delegated constructor (presumably the number of backends).
  TranslationModel(const std::string& config, MemoryBundle&& memory, size_t replicas = 1)
      : TranslationModel(parseOptionsFromString(config, /*validate=*/false), std::move(memory), replicas) {}

  /// Builds a model from already-parsed options and a memory bundle (defined out of line).
  /// @param options   parsed configuration.
  /// @param memory    memory bundle for the model; moved into the model.
  /// @param replicas  replica count (presumably the number of backends to prepare).
  TranslationModel(const Config& options, MemoryBundle&& memory, size_t replicas = 1);

  /// Builds a model from parsed options. The memory bundle is obtained via getMemoryBundleFromConfig(options).
  /// @param options   parsed configuration.
  /// @param replicas  replica count forwarded to the delegated constructor.
  TranslationModel(const Config& options, size_t replicas = 1)
      : TranslationModel(options, getMemoryBundleFromConfig(options), replicas) {}

  /// Creates a Request from the query parameters.
  /// @param requestId        identifier to associate with the request.
  /// @param source           source text; taken by rvalue reference.
  /// @param callback         callback handed to the request (presumably invoked on completion — TODO confirm).
  /// @param responseOptions  options controlling what the response contains.
  /// @param cache            translation cache to consult; may be empty (std::nullopt).
  /// @returns the created Request, wrapped in a shared pointer.
  Ptr<Request> makeRequest(size_t requestId, std::string&& source, CallbackType callback,
                           const ResponseOptions& responseOptions, std::optional<TranslationCache>& cache);

  /// Like makeRequest(), but the source is an already-annotated target text from a previous translation.
  /// NOTE(review): presumably used to chain translations through a pivot language; confirm at call sites.
  /// @returns the created Request, wrapped in a shared pointer.
  Ptr<Request> makePivotRequest(size_t requestId, AnnotatedText&& previousTarget, CallbackType callback,
                                const ResponseOptions& responseOptions, std::optional<TranslationCache>& cache);

  /// Hands `request` to this model's batching pool.
  /// @returns whatever BatchingPool::enqueueRequest reports (presumably the number of sentences queued).
  size_t enqueueRequest(Ptr<Request> request) { return batchingPool_.enqueueRequest(request); }

  /// Asks the batching pool to fill `batch` with queued work.
  /// @returns whatever BatchingPool::generateBatch reports (presumably the number of sentences placed in `batch`).
  size_t generateBatch(Batch& batch) { return batchingPool_.generateBatch(batch); }

  /// Translates `batch` (defined out of line). `deviceId` presumably selects the backend_ entry to use.
  void translateBatch(size_t deviceId, Batch& batch);

  /// @returns the identifier stored in modelId_.
  size_t modelId() const { return modelId_; }

 private:
  /// Identifier of this instance (presumably assigned from modelCounter_ — TODO confirm).
  size_t modelId_;
  /// Parsed configuration this model was built from.
  Config options_;
  /// Memory bundle supplied at construction; held for as long as the model exists.
  MemoryBundle memory_;
  /// Vocabularies for the model.
  Vocabs vocabs_;
  /// Text processing applied to inputs.
  TextProcessor textProcessor_;
  /// Pool that enqueueRequest() feeds and generateBatch() draws from.
  BatchingPool batchingPool_;

  /// One Marian execution slot: an expression graph and the scorers that run on it.
  struct MarianBackend {
    using Graph = Ptr<ExpressionGraph>;
    using ScorerEnsemble = std::vector<Ptr<Scorer>>;
    Graph graph;
    ScorerEnsemble scorerEnsemble;
    /// Starts false; presumably set by loadBackend() once graph/scorers are built.
    bool initialized{false};
  };

  // ShortlistGenerator is purely const, we don't need one per thread.
  ShortlistGenerator shortlistGenerator_;
  /// Backend slots, indexed by loadBackend(idx).
  std::vector<MarianBackend> backend_;
  /// Quality estimator for this model (same Ptr<> ownership alias as the other shared members).
  Ptr<QualityEstimator> qualityEstimator_;

  /// Prepares backend_[idx] (defined out of line).
  void loadBackend(size_t idx);
  /// Converts a bergamot Batch into the marian CorpusBatch representation (defined out of line).
  Ptr<marian::data::CorpusBatch> convertToMarianBatch(Batch& batch);

  /// Counter shared by all TranslationModel instances (presumably the source of modelId_ values).
  static std::atomic<size_t> modelCounter_;
};
} // namespace bergamot
} // namespace marian
#endif // SRC_BERGAMOT_TRANSLATION_MODEL_H_