Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/trie.hh
00001 #ifndef LM_TRIE_H
00002 #define LM_TRIE_H
00003 
00004 #include "lm/weights.hh"
00005 #include "lm/word_index.hh"
00006 #include "util/bit_packing.hh"
00007 
00008 #include <cstddef>
00009 
00010 #include <stdint.h>
00011 
00012 namespace lm {
00013 namespace ngram {
00014 struct Config;
00015 namespace trie {
00016 
// Pair of 64-bit positions handed between trie levels: the lookup routines in
// this header fill it in as an out-parameter or read it as a search window.
// NOTE(review): looks like a half-open [begin, end) interval -- confirm
// against the implementation file.
struct NodeRange {
  uint64_t begin;
  uint64_t end;
};
00020 
00021 // TODO: if the number of unigrams is a concern, also bit pack these records.
00022 struct UnigramValue {
00023   ProbBackoff weights;
00024   uint64_t next;
00025   uint64_t Next() const { return next; }
00026 };
00027 
00028 class UnigramPointer {
00029   public:
00030     explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}
00031 
00032     UnigramPointer() : to_(NULL) {}
00033 
00034     bool Found() const { return to_ != NULL; }
00035 
00036     float Prob() const { return to_->prob; }
00037     float Backoff() const { return to_->backoff; }
00038     float Rest() const { return Prob(); }
00039 
00040   private:
00041     const ProbBackoff *to_;
00042 };
00043 
00044 class Unigram {
00045   public:
00046     Unigram() {}
00047 
00048     void Init(void *start) {
00049       unigram_ = static_cast<UnigramValue*>(start);
00050     }
00051 
00052     static uint64_t Size(uint64_t count) {
00053       // +1 in case unknown doesn't appear.  +1 for the final next.
00054       return (count + 2) * sizeof(UnigramValue);
00055     }
00056 
00057     const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
00058 
00059     ProbBackoff &Unknown() { return unigram_[0].weights; }
00060 
00061     UnigramValue *Raw() {
00062       return unigram_;
00063     }
00064 
00065     UnigramPointer Find(WordIndex word, NodeRange &next) const {
00066       UnigramValue *val = unigram_ + word;
00067       next.begin = val->next;
00068       next.end = (val+1)->next;
00069       return UnigramPointer(val->weights);
00070     }
00071 
00072   private:
00073     UnigramValue *unigram_;
00074 };
00075 
// Shared state for the bit-packed trie levels declared in this header
// (BitPackedMiddle and BitPackedLongest derive from it).
class BitPacked {
  public:
    // Zero/null-initialises every member so that a default-constructed object
    // has a defined state; in particular InsertIndex() returns 0 instead of
    // reading an indeterminate value.
    BitPacked()
      : word_bits_(0),
        total_bits_(0),
        word_mask_(0),
        base_(NULL),
        insert_index_(0),
        max_vocab_(0) {}

    // Current value of the insertion counter.
    uint64_t InsertIndex() const {
      return insert_index_;
    }

  protected:
    // Defined out of line.  NOTE(review): presumably the storage needed for
    // `entries` records given the vocabulary bound and the per-record payload
    // width -- confirm against the implementation file.
    static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);

    // Defined out of line.  NOTE(review): presumably binds `base` and derives
    // the bit widths below -- confirm against the implementation file.
    void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);

    uint8_t word_bits_;
    uint8_t total_bits_;
    uint64_t word_mask_;

    uint8_t *base_;

    uint64_t insert_index_;
    uint64_t max_vocab_;
};
00097 
00098 template <class Bhiksha> class BitPackedMiddle : public BitPacked {
00099   public:
00100     static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
00101 
00102     // next_source need not be initialized.
00103     BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
00104 
00105     util::BitAddress Insert(WordIndex word);
00106 
00107     void FinishedLoading(uint64_t next_end, const Config &config);
00108 
00109     util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
00110 
00111     util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
00112       uint64_t addr = pointer * total_bits_;
00113       addr += word_bits_;
00114       bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range);
00115       return util::BitAddress(base_, addr);
00116     }
00117 
00118   private:
00119     uint8_t quant_bits_;
00120     Bhiksha bhiksha_;
00121 
00122     const BitPacked *next_source_;
00123 };
00124 
00125 class BitPackedLongest : public BitPacked {
00126   public:
00127     static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
00128       return BaseSize(entries, max_vocab, quant_bits);
00129     }
00130 
00131     BitPackedLongest() {}
00132 
00133     void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
00134       BaseInit(base, max_vocab, quant_bits);
00135     }
00136 
00137     util::BitAddress Insert(WordIndex word);
00138 
00139     util::BitAddress Find(WordIndex word, const NodeRange &node) const;
00140 };
00141 
00142 } // namespace trie
00143 } // namespace ngram
00144 } // namespace lm
00145 
00146 #endif // LM_TRIE_H