Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_SEARCH_TRIE_H 00002 #define LM_SEARCH_TRIE_H 00003 00004 #include "lm/config.hh" 00005 #include "lm/model_type.hh" 00006 #include "lm/return.hh" 00007 #include "lm/trie.hh" 00008 #include "lm/weights.hh" 00009 00010 #include "util/file.hh" 00011 #include "util/file_piece.hh" 00012 00013 #include <vector> 00014 #include <cstdlib> 00015 #include <cassert> 00016 00017 namespace lm { 00018 namespace ngram { 00019 class BinaryFormat; 00020 class SortedVocabulary; 00021 namespace trie { 00022 00023 template <class Quant, class Bhiksha> class TrieSearch; 00024 class SortedFiles; 00025 template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); 00026 00027 template <class Quant, class Bhiksha> class TrieSearch { 00028 public: 00029 typedef NodeRange Node; 00030 00031 typedef ::lm::ngram::trie::UnigramPointer UnigramPointer; 00032 typedef typename Quant::MiddlePointer MiddlePointer; 00033 typedef typename Quant::LongestPointer LongestPointer; 00034 00035 static const bool kDifferentRest = false; 00036 00037 static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd); 00038 00039 static const unsigned int kVersion = 1; 00040 00041 static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) { 00042 Quant::UpdateConfigFromBinary(file, offset, config); 00043 // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. 00044 if (counts.size() > 2) 00045 Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config); 00046 } 00047 00048 static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { 00049 uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]); 00050 for (unsigned char i = 1; i < counts.size() - 1; ++i) { 00051 ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config); 00052 } 00053 return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]); 00054 } 00055 00056 TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {} 00057 00058 ~TrieSearch() { FreeMiddles(); } 00059 00060 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); 00061 00062 void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing); 00063 00064 unsigned char Order() const { 00065 return middle_end_ - middle_begin_ + 2; 00066 } 00067 00068 ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); } 00069 00070 UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { 00071 extend_left = static_cast<uint64_t>(word); 00072 UnigramPointer ret(unigram_.Find(word, next)); 00073 independent_left = (next.begin == next.end); 00074 return ret; 00075 } 00076 00077 MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { 00078 return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node)); 00079 } 00080 00081 MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const { 00082 util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left)); 00083 independent_left = (address.base == NULL) || (node.begin == node.end); 00084 return MiddlePointer(quant_, order_minus_2, address); 00085 } 00086 00087 LongestPointer LookupLongest(WordIndex word, const Node &node) const { 00088 return LongestPointer(quant_, longest_.Find(word, node)); 00089 } 00090 00091 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { 00092 assert(begin != end); 00093 bool independent_left; 00094 uint64_t ignored; 00095 LookupUnigram(*begin, node, independent_left, ignored); 00096 for (const WordIndex *i = begin + 1; i < end; ++i) { 00097 if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false; 00098 } 00099 return true; 00100 } 00101 00102 private: 00103 friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing); 00104 00105 // Middles are managed manually so we can delay construction and they don't have to be copyable. 00106 void FreeMiddles() { 00107 for (const Middle *i = middle_begin_; i != middle_end_; ++i) { 00108 i->~Middle(); 00109 } 00110 std::free(middle_begin_); 00111 } 00112 00113 typedef trie::BitPackedMiddle<Bhiksha> Middle; 00114 00115 typedef trie::BitPackedLongest Longest; 00116 Longest longest_; 00117 00118 Middle *middle_begin_, *middle_end_; 00119 Quant quant_; 00120 00121 typedef ::lm::ngram::trie::Unigram Unigram; 00122 Unigram unigram_; 00123 }; 00124 00125 } // namespace trie 00126 } // namespace ngram 00127 } // namespace lm 00128 00129 #endif // LM_SEARCH_TRIE_H