Joshua
Open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/search_hashed.hh
00001 #ifndef LM_SEARCH_HASHED_H
00002 #define LM_SEARCH_HASHED_H
00003 
00004 #include "lm/model_type.hh"
00005 #include "lm/config.hh"
00006 #include "lm/read_arpa.hh"
00007 #include "lm/return.hh"
00008 #include "lm/weights.hh"
00009 
00010 #include "util/bit_packing.hh"
00011 #include "util/probing_hash_table.hh"
00012 
00013 #include <algorithm>
00014 #include <iostream>
00015 #include <vector>
00016 
00017 namespace util { class FilePiece; }
00018 
00019 namespace lm {
00020 namespace ngram {
00021 class BinaryFormat;
00022 class ProbingVocabulary;
00023 namespace detail {
00024 
00025 inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
00026   uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL);
00027   return ret;
00028 }
00029 
00030 #pragma pack(push)
00031 #pragma pack(4)
00032 struct ProbEntry {
00033   uint64_t key;
00034   Prob value;
00035   typedef uint64_t Key;
00036   typedef Prob Value;
00037   uint64_t GetKey() const {
00038     return key;
00039   }
00040 };
00041 
00042 #pragma pack(pop)
00043 
// Result of a highest-order n-gram lookup: either empty (not found) or a
// non-owning pointer to the stored probability. The referenced float must
// outlive this handle.
class LongestPointer {
  public:
    // Empty handle: the n-gram was not found.
    LongestPointer() : prob_(NULL) {}

    // Handle referring to a probability stored elsewhere (not copied).
    explicit LongestPointer(const float &to) : prob_(&to) {}

    // True when this handle refers to a stored probability.
    bool Found() const { return NULL != prob_; }

    // The referenced probability. Only valid when Found() is true.
    float Prob() const {
      const float value = *prob_;
      return value;
    }

  private:
    const float *prob_;
};
00061 
00062 template <class Value> class HashedSearch {
00063   public:
00064     typedef uint64_t Node;
00065 
00066     typedef typename Value::ProbingProxy UnigramPointer;
00067     typedef typename Value::ProbingProxy MiddlePointer;
00068     typedef ::lm::ngram::detail::LongestPointer LongestPointer;
00069 
00070     static const ModelType kModelType = Value::kProbingModelType;
00071     static const bool kDifferentRest = Value::kDifferentRest;
00072     static const unsigned int kVersion = 0;
00073 
00074     // TODO: move probing_multiplier here with next binary file format update.
00075     static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
00076 
00077     static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
00078       uint64_t ret = Unigram::Size(counts[0]);
00079       for (unsigned char n = 1; n < counts.size() - 1; ++n) {
00080         ret += Middle::Size(counts[n], config.probing_multiplier);
00081       }
00082       return ret + Longest::Size(counts.back(), config.probing_multiplier);
00083     }
00084 
00085     uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
00086 
00087     void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
00088 
00089     unsigned char Order() const {
00090       return middle_.size() + 2;
00091     }
00092 
00093     typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); }
00094 
00095     UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
00096       extend_left = static_cast<uint64_t>(word);
00097       next = extend_left;
00098       UnigramPointer ret(unigram_.Lookup(word));
00099       independent_left = ret.IndependentLeft();
00100       return ret;
00101     }
00102 
00103     MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
00104       node = extend_pointer;
00105       return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
00106     }
00107 
00108     MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
00109       node = CombineWordHash(node, word);
00110       typename Middle::ConstIterator found;
00111       if (!middle_[order_minus_2].Find(node, found)) {
00112         independent_left = true;
00113         return MiddlePointer();
00114       }
00115       extend_pointer = node;
00116       MiddlePointer ret(found->value);
00117       independent_left = ret.IndependentLeft();
00118       return ret;
00119     }
00120 
00121     LongestPointer LookupLongest(WordIndex word, const Node &node) const {
00122       // Sign bit is always on because longest n-grams do not extend left.
00123       typename Longest::ConstIterator found;
00124       if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
00125       return LongestPointer(found->value.prob);
00126     }
00127 
00128     // Generate a node without necessarily checking that it actually exists.
00129     // Optionally return false if it's know to not exist.
00130     bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
00131       assert(begin != end);
00132       node = static_cast<Node>(*begin);
00133       for (const WordIndex *i = begin + 1; i < end; ++i) {
00134         node = CombineWordHash(node, *i);
00135       }
00136       return true;
00137     }
00138 
00139   private:
00140     // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
00141     void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
00142 
00143     template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
00144 
00145     class Unigram {
00146       public:
00147         Unigram() {}
00148 
00149         Unigram(void *start, uint64_t count) :
00150           unigram_(static_cast<typename Value::Weights*>(start))
00151 #ifdef DEBUG
00152          ,  count_(count)
00153 #endif
00154       {}
00155 
00156         static uint64_t Size(uint64_t count) {
00157           return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk>
00158         }
00159 
00160         const typename Value::Weights &Lookup(WordIndex index) const {
00161 #ifdef DEBUG
00162           assert(index < count_);
00163 #endif
00164           return unigram_[index];
00165         }
00166 
00167         typename Value::Weights &Unknown() { return unigram_[0]; }
00168 
00169         // For building.
00170         typename Value::Weights *Raw() { return unigram_; }
00171 
00172       private:
00173         typename Value::Weights *unigram_;
00174 #ifdef DEBUG
00175         uint64_t count_;
00176 #endif
00177     };
00178 
00179     Unigram unigram_;
00180 
00181     typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
00182     std::vector<Middle> middle_;
00183 
00184     typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest;
00185     Longest longest_;
00186 };
00187 
00188 } // namespace detail
00189 } // namespace ngram
00190 } // namespace lm
00191 
00192 #endif // LM_SEARCH_HASHED_H