Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/search_trie.hh
00001 #ifndef LM_SEARCH_TRIE_H
00002 #define LM_SEARCH_TRIE_H
00003 
00004 #include "lm/config.hh"
00005 #include "lm/model_type.hh"
00006 #include "lm/return.hh"
00007 #include "lm/trie.hh"
00008 #include "lm/weights.hh"
00009 
00010 #include "util/file.hh"
00011 #include "util/file_piece.hh"
00012 
00013 #include <vector>
00014 #include <cstdlib>
00015 #include <cassert>
00016 
00017 namespace lm {
00018 namespace ngram {
00019 class BinaryFormat;
00020 class SortedVocabulary;
00021 namespace trie {
00022 
00023 template <class Quant, class Bhiksha> class TrieSearch;
00024 class SortedFiles;
00025 template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
00026 
00027 template <class Quant, class Bhiksha> class TrieSearch {
00028   public:
00029     typedef NodeRange Node;
00030 
00031     typedef ::lm::ngram::trie::UnigramPointer UnigramPointer;
00032     typedef typename Quant::MiddlePointer MiddlePointer;
00033     typedef typename Quant::LongestPointer LongestPointer;
00034 
00035     static const bool kDifferentRest = false;
00036 
00037     static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
00038 
00039     static const unsigned int kVersion = 1;
00040 
00041     static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
00042       Quant::UpdateConfigFromBinary(file, offset, config);
00043       // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2.
00044       if (counts.size() > 2)
00045         Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
00046     }
00047 
00048     static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00049       uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
00050       for (unsigned char i = 1; i < counts.size() - 1; ++i) {
00051         ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
00052       }
00053       return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
00054     }
00055 
00056     TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}
00057 
00058     ~TrieSearch() { FreeMiddles(); }
00059 
00060     uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00061 
00062     void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
00063 
00064     unsigned char Order() const {
00065       return middle_end_ - middle_begin_ + 2;
00066     }
00067 
00068     ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); }
00069 
00070     UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
00071       extend_left = static_cast<uint64_t>(word);
00072       UnigramPointer ret(unigram_.Find(word, next));
00073       independent_left = (next.begin == next.end);
00074       return ret;
00075     }
00076 
00077     MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
00078       return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node));
00079     }
00080 
00081     MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const {
00082       util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left));
00083       independent_left = (address.base == NULL) || (node.begin == node.end);
00084       return MiddlePointer(quant_, order_minus_2, address);
00085     }
00086 
00087     LongestPointer LookupLongest(WordIndex word, const Node &node) const {
00088       return LongestPointer(quant_, longest_.Find(word, node));
00089     }
00090 
00091     bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00092       assert(begin != end);
00093       bool independent_left;
00094       uint64_t ignored;
00095       LookupUnigram(*begin, node, independent_left, ignored);
00096       for (const WordIndex *i = begin + 1; i < end; ++i) {
00097         if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false;
00098       }
00099       return true;
00100     }
00101 
00102   private:
00103     friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
00104 
00105     // Middles are managed manually so we can delay construction and they don't have to be copyable.
00106     void FreeMiddles() {
00107       for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
00108         i->~Middle();
00109       }
00110       std::free(middle_begin_);
00111     }
00112 
00113     typedef trie::BitPackedMiddle<Bhiksha> Middle;
00114 
00115     typedef trie::BitPackedLongest Longest;
00116     Longest longest_;
00117 
00118     Middle *middle_begin_, *middle_end_;
00119     Quant quant_;
00120 
00121     typedef ::lm::ngram::trie::Unigram Unigram;
00122     Unigram unigram_;
00123 };
00124 
00125 } // namespace trie
00126 } // namespace ngram
00127 } // namespace lm
00128 
00129 #endif // LM_SEARCH_TRIE_H