Joshua
Open-source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_TRIE_H 00002 #define LM_TRIE_H 00003 00004 #include "lm/weights.hh" 00005 #include "lm/word_index.hh" 00006 #include "util/bit_packing.hh" 00007 00008 #include <cstddef> 00009 00010 #include <stdint.h> 00011 00012 namespace lm { 00013 namespace ngram { 00014 struct Config; 00015 namespace trie { 00016 00017 struct NodeRange { 00018 uint64_t begin, end; 00019 }; 00020 00021 // TODO: if the number of unigrams is a concern, also bit pack these records. 00022 struct UnigramValue { 00023 ProbBackoff weights; 00024 uint64_t next; 00025 uint64_t Next() const { return next; } 00026 }; 00027 00028 class UnigramPointer { 00029 public: 00030 explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {} 00031 00032 UnigramPointer() : to_(NULL) {} 00033 00034 bool Found() const { return to_ != NULL; } 00035 00036 float Prob() const { return to_->prob; } 00037 float Backoff() const { return to_->backoff; } 00038 float Rest() const { return Prob(); } 00039 00040 private: 00041 const ProbBackoff *to_; 00042 }; 00043 00044 class Unigram { 00045 public: 00046 Unigram() {} 00047 00048 void Init(void *start) { 00049 unigram_ = static_cast<UnigramValue*>(start); 00050 } 00051 00052 static uint64_t Size(uint64_t count) { 00053 // +1 in case unknown doesn't appear. +1 for the final next. 
00054 return (count + 2) * sizeof(UnigramValue); 00055 } 00056 00057 const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; } 00058 00059 ProbBackoff &Unknown() { return unigram_[0].weights; } 00060 00061 UnigramValue *Raw() { 00062 return unigram_; 00063 } 00064 00065 UnigramPointer Find(WordIndex word, NodeRange &next) const { 00066 UnigramValue *val = unigram_ + word; 00067 next.begin = val->next; 00068 next.end = (val+1)->next; 00069 return UnigramPointer(val->weights); 00070 } 00071 00072 private: 00073 UnigramValue *unigram_; 00074 }; 00075 00076 class BitPacked { 00077 public: 00078 BitPacked() {} 00079 00080 uint64_t InsertIndex() const { 00081 return insert_index_; 00082 } 00083 00084 protected: 00085 static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits); 00086 00087 void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits); 00088 00089 uint8_t word_bits_; 00090 uint8_t total_bits_; 00091 uint64_t word_mask_; 00092 00093 uint8_t *base_; 00094 00095 uint64_t insert_index_, max_vocab_; 00096 }; 00097 00098 template <class Bhiksha> class BitPackedMiddle : public BitPacked { 00099 public: 00100 static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config); 00101 00102 // next_source need not be initialized. 
00103 BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config); 00104 00105 util::BitAddress Insert(WordIndex word); 00106 00107 void FinishedLoading(uint64_t next_end, const Config &config); 00108 00109 util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const; 00110 00111 util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) { 00112 uint64_t addr = pointer * total_bits_; 00113 addr += word_bits_; 00114 bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range); 00115 return util::BitAddress(base_, addr); 00116 } 00117 00118 private: 00119 uint8_t quant_bits_; 00120 Bhiksha bhiksha_; 00121 00122 const BitPacked *next_source_; 00123 }; 00124 00125 class BitPackedLongest : public BitPacked { 00126 public: 00127 static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) { 00128 return BaseSize(entries, max_vocab, quant_bits); 00129 } 00130 00131 BitPackedLongest() {} 00132 00133 void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) { 00134 BaseInit(base, max_vocab, quant_bits); 00135 } 00136 00137 util::BitAddress Insert(WordIndex word); 00138 00139 util::BitAddress Find(WordIndex word, const NodeRange &node) const; 00140 }; 00141 00142 } // namespace trie 00143 } // namespace ngram 00144 } // namespace lm 00145 00146 #endif // LM_TRIE_H