Program Listing for File batching_pool.cpp¶
↰ Return to documentation for file (src/translator/batching_pool.cpp
)
#include "batching_pool.h"
#include <cassert>
#include "batch.h"
#include "common/logging.h"
namespace marian {
namespace bergamot {
BatchingPool::BatchingPool(Ptr<Options> options)
: miniBatchWords_(options->get<int>("mini-batch-words")), maxActiveBucketLength_(0) {
size_t maxLengthBreak = options->get<int>("max-length-break");
float maxLengthFactor = options->get<float>("max-length-factor", 3.0);
// For the time being, we add some slack, which only BatchingPool is aware of. Since the TextProcessor still wraps at
// first request in, most of the Batches generated will be under max-length break.
//
// In the unlikely event of a few sentences overflowing, this allows the exceeding words to be put in the slack area.
// Very few batches are expected to be generated at a higher length.
size_t pivotSlack = maxLengthBreak * maxLengthFactor - maxLengthBreak;
bucket_.resize(maxLengthBreak + pivotSlack + 1);
ABORT_IF(bucket_.size() - 1 > miniBatchWords_,
"Fatal: max-length-break > mini-batch-words will lead to sentences "
"longer than what can fit in a batch.");
}
size_t BatchingPool::generateBatch(Batch &batch) {
// For now simply iterates on buckets and converts batches greedily. This
// has to be enhanced with optimizing over priority. The baseline
// implementation should at least be as fast as marian's maxi-batch with full
// corpus size as maxi-batch size.
batch.clear();
size_t paddedBatchSize = 0;
for (size_t length = 0; length <= maxActiveBucketLength_; length++) {
auto p = bucket_[length].begin();
while (p != bucket_[length].end()) {
paddedBatchSize = (batch.size() + 1) * length;
if (paddedBatchSize <= miniBatchWords_) {
auto q = p++;
batch.add(*q);
bucket_[length].erase(q);
} else {
// Check if elements exist
assert(batch.size() > 0);
return batch.size();
}
}
}
return batch.size();
}
size_t BatchingPool::enqueueRequest(Ptr<Request> request) {
size_t toBeFreshlyTranslated = 0;
for (size_t i = 0; i < request->numSegments(); i++) {
if (!request->cacheHitPrefilled(i)) {
RequestSentence sentence(i, request);
size_t bucket_id = sentence.numTokens();
// Due to a workaround for pivoting, unless we can discipline the
// vocabulary to get stronger static requirements, it is difficult to
// rework the rest of the components. Instead, we allow dynamic growth
// here. We let std::vector take care of the dynamic growth.
// https://en.cppreference.com/w/cpp/container/vector/resize#Complexity
if (bucket_id >= bucket_.size()) {
bucket_.resize(bucket_id + 1);
}
bucket_[bucket_id].insert(sentence);
maxActiveBucketLength_ = std::max<size_t>(bucket_id, maxActiveBucketLength_);
toBeFreshlyTranslated += 1;
}
}
return toBeFreshlyTranslated;
}
void BatchingPool::clear() {
for (size_t length = 0; length < bucket_.size(); length++) {
bucket_[length].clear();
}
}
} // namespace bergamot
} // namespace marian