Program Listing for File byte_array_util.cpp
↰ Return to documentation for file (src/translator/byte_array_util.cpp)
#include "byte_array_util.h"
#include <cstdlib>
#include <memory>
#include "common/io.h"
#include "data/shortlist.h"
namespace marian {
namespace bergamot {
namespace {
// Helpers for a basic validator that checks that a binary model file has not been truncated:
// it reads the headers and checks that the file is large enough to hold everything they describe.
// Per-item header record of the binary model format.
// This struct and the getter are copied from the marian source, because it's located
// inside src/common/binary.cpp:15 and we can't include it.
// The validator in this file reads an array of these records directly out of the model memory.
struct Header {
// Number of chars occupied by the item's name.
uint64_t nameLength;
// Item type tag; not inspected anywhere in this file.
uint64_t type;
// Number of int entries in the item's shape.
uint64_t shapeLength;
// Number of bytes of data belonging to the item.
uint64_t dataLength;
};
// Treat the cursor as pointing at an array of T: remember where it currently points,
// step the cursor forward past `num` elements of T, and return the remembered position.
// Nothing is dereferenced here; the caller decides whether to read through the result.
template <typename T>
const T* get(const void*& current, uint64_t num = 1) {
  const T* const start = static_cast<const T*>(current);
  current = start + num;
  return start;
}
} // Anonymous namespace
bool validateBinaryModel(const AlignedMemory& model, uint64_t fileSize) {
const void* current = model.begin();
uint64_t memoryNeeded =
sizeof(uint64_t) * 2; // We keep track of how much memory we would need if we have a complete file
uint64_t numHeaders;
if (fileSize >= memoryNeeded) { // We have enough filesize to fetch the headers.
uint64_t binaryFileVersion = *get<uint64_t>(current);
numHeaders = *get<uint64_t>(current); // number of item headers that follow
} else {
return false;
}
memoryNeeded += numHeaders * sizeof(Header);
const Header* headers;
if (fileSize >= memoryNeeded) {
headers = get<Header>(current, numHeaders); // read that many headers
} else {
return false;
}
// Calculate how many bytes we are going to for reading just the names and the shape
for (uint64_t i = 0; i < numHeaders; i++) {
memoryNeeded += headers[i].nameLength + headers[i].shapeLength * sizeof(int);
// Advance the pointers.
get<char>(current, headers[i].nameLength);
get<int>(current, headers[i].shapeLength);
}
// Before we start reading the data, there is a small padding to ensure alignment
// Read that in, before calculating the actual tensor memory requirements.
uint64_t aligned_offset;
if (fileSize >= memoryNeeded) {
aligned_offset = *get<uint64_t>(current); // Offset to align memory to 256 size
memoryNeeded += aligned_offset + sizeof(uint64_t);
} else {
return false;
}
// Finally the tensor size:
for (uint64_t i = 0; i < numHeaders; i++) {
memoryNeeded += headers[i].dataLength;
}
// If this final check passes, the file is at least big enough to contain the model
if (fileSize >= memoryNeeded) {
return true;
} else {
return false;
}
}
// Reads the whole file at `path` into a freshly allocated AlignedMemory.
//
// @param path      file to load
// @param alignment alignment passed to the AlignedMemory constructor
// @return the buffer holding the file's bytes
// Aborts if the stream reports a bad state after opening, or if fewer bytes than the file
// size could be read.
AlignedMemory loadFileToMemory(const std::string& path, size_t alignment) {
  const uint64_t fileSize = filesystem::fileSize(path);
  io::InputFileStream in(path);
  ABORT_IF(in.bad(), "Failed opening file stream: {}", path);
  AlignedMemory alignedMemory(fileSize, alignment);
  in.read(reinterpret_cast<char*>(alignedMemory.begin()), fileSize);
  // The buffer size alone always matches what was just requested, so it says nothing about
  // the read; gcount() is what reveals a short or failed read.
  const bool readWholeFile = static_cast<uint64_t>(in.gcount()) == fileSize;
  ABORT_IF(alignedMemory.size() != fileSize || !readWholeFile, "Error reading file {}", path);
  return alignedMemory;
}
// Loads every model listed under the "models" option into 256-aligned memory.
//
// Returns one AlignedMemory per model, in option order, when all of them are `.bin` files.
// As soon as an `.npz` model is encountered an empty vector is returned instead, which tells
// the caller to fall back to loading all models from file. Any other extension aborts.
std::vector<AlignedMemory> getModelMemoryFromConfig(marian::Ptr<marian::Options> options) {
  const auto modelPaths = options->get<std::vector<std::string>>("models");
  std::vector<AlignedMemory> loaded(modelPaths.size());
  for (size_t idx = 0; idx != modelPaths.size(); ++idx) {
    const std::string& path = modelPaths[idx];
    if (marian::io::isBin(path)) {
      loaded[idx] = loadFileToMemory(path, 256);
      continue;
    }
    if (marian::io::isNpz(path)) {
      // if any of the models are npz format, we revert to loading from file for all models.
      LOG(debug, "Encountered an npz file {}; will use file loading for {} models", path, modelPaths.size());
      return {};
    }
    ABORT("Unknown extension for model: {}, should be one of `.bin` or `.npz`", path);
  }
  return loaded;
}
// Loads the shortlist named by the first entry of the "shortlist" option into 64-aligned
// memory. Returns an empty AlignedMemory when the option has no entries, and aborts when
// the named file is not a binary shortlist.
AlignedMemory getShortlistMemoryFromConfig(marian::Ptr<marian::Options> options) {
  const auto shortlistArgs = options->get<std::vector<std::string>>("shortlist");
  if (shortlistArgs.empty()) {
    return AlignedMemory();
  }
  const std::string& shortlistPath = shortlistArgs[0];
  ABORT_IF(!marian::data::isBinaryShortlist(shortlistPath),
           "Loading non-binary shortlist file into memory is not supported");
  return loadFileToMemory(shortlistPath, 64);
}
// Fills `vocabMemories` with one entry per path in the "vocabs" option, each loaded into
// 64-aligned memory. A path that appears more than once is loaded only once and every
// occurrence shares the same AlignedMemory. Aborts when fewer than two vocabularies are
// configured or when a path does not have the `.spm` extension.
void getVocabsMemoryFromConfig(marian::Ptr<marian::Options> options,
                               std::vector<std::shared_ptr<AlignedMemory>>& vocabMemories) {
  const auto vocabPaths = options->get<std::vector<std::string>>("vocabs");
  ABORT_IF(vocabPaths.size() < 2, "Insufficient number of vocabularies.");
  vocabMemories.resize(vocabPaths.size());

  // Files already loaded during this call, keyed by path.
  std::unordered_map<std::string, std::shared_ptr<AlignedMemory>> loadedByPath;
  for (size_t idx = 0; idx != vocabPaths.size(); ++idx) {
    const std::string& path = vocabPaths[idx];
    ABORT_IF(marian::filesystem::Path(path).extension() != marian::filesystem::Path(".spm"),
             "Loading non-SentencePiece vocab files into memory is not supported");
    auto entry = loadedByPath.find(path);
    if (entry == loadedByPath.end()) {
      // First time this path is seen: read it and remember the buffer.
      entry = loadedByPath.emplace(path, std::make_shared<AlignedMemory>(loadFileToMemory(path, 64))).first;
    }
    vocabMemories[idx] = entry->second;
  }
}
// Loads the file named by the "quality" option into 64-aligned memory.
// The option defaults to an empty string; an empty path yields an empty AlignedMemory.
AlignedMemory getQualityEstimatorModel(const marian::Ptr<marian::Options>& options) {
  const auto modelPath = options->get<std::string>("quality", "");
  if (!modelPath.empty()) {
    return loadFileToMemory(modelPath, 64);
  }
  return {};
}
// Returns the quality estimator memory held by `memoryBundle` when it is non-empty, moving
// it out of the bundle. When the bundle holds none, falls back to loading the model from
// the path given in `options`.
AlignedMemory getQualityEstimatorModel(MemoryBundle& memoryBundle, const marian::Ptr<marian::Options>& options) {
  const bool bundleHasModel = memoryBundle.qualityEstimatorMemory.size() != 0;
  if (bundleHasModel) {
    // The bundle's member is handed over to the caller rather than copied.
    return std::move(memoryBundle.qualityEstimatorMemory);
  }
  return getQualityEstimatorModel(options);
}
// Builds a MemoryBundle by loading, in this order, the models, the shortlist, the
// vocabularies, the ssplit prefix file and the quality estimator model that `options`
// refers to.
MemoryBundle getMemoryBundleFromConfig(marian::Ptr<marian::Options> options) {
  MemoryBundle bundle;
  bundle.models = getModelMemoryFromConfig(options);
  bundle.shortlist = getShortlistMemoryFromConfig(options);
  // Vocabularies are filled through an out-parameter rather than returned.
  getVocabsMemoryFromConfig(options, bundle.vocabs);
  bundle.ssplitPrefixFile = getSsplitPrefixFileMemoryFromConfig(options);
  bundle.qualityEstimatorMemory = getQualityEstimatorModel(options);
  return bundle;
}
// Loads the file named by the "ssplit-prefix-file" option into 64-aligned memory.
// The option defaults to an empty string; an empty path yields an empty AlignedMemory.
AlignedMemory getSsplitPrefixFileMemoryFromConfig(marian::Ptr<marian::Options> options) {
  const std::string prefixFilePath = options->get<std::string>("ssplit-prefix-file", "");
  if (prefixFilePath.empty()) {
    // Nothing configured: hand back an empty AlignedMemory.
    return AlignedMemory();
  }
  return loadFileToMemory(prefixFilePath, 64);
}
} // namespace bergamot
} // namespace marian