Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_SEARCH_HASHED_H 00002 #define LM_SEARCH_HASHED_H 00003 00004 #include "lm/model_type.hh" 00005 #include "lm/config.hh" 00006 #include "lm/read_arpa.hh" 00007 #include "lm/return.hh" 00008 #include "lm/weights.hh" 00009 00010 #include "util/bit_packing.hh" 00011 #include "util/probing_hash_table.hh" 00012 00013 #include <algorithm> 00014 #include <iostream> 00015 #include <vector> 00016 00017 namespace util { class FilePiece; } 00018 00019 namespace lm { 00020 namespace ngram { 00021 class BinaryFormat; 00022 class ProbingVocabulary; 00023 namespace detail { 00024 00025 inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) { 00026 uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL); 00027 return ret; 00028 } 00029 00030 #pragma pack(push) 00031 #pragma pack(4) 00032 struct ProbEntry { 00033 uint64_t key; 00034 Prob value; 00035 typedef uint64_t Key; 00036 typedef Prob Value; 00037 uint64_t GetKey() const { 00038 return key; 00039 } 00040 }; 00041 00042 #pragma pack(pop) 00043 00044 class LongestPointer { 00045 public: 00046 explicit LongestPointer(const float &to) : to_(&to) {} 00047 00048 LongestPointer() : to_(NULL) {} 00049 00050 bool Found() const { 00051 return to_ != NULL; 00052 } 00053 00054 float Prob() const { 00055 return *to_; 00056 } 00057 00058 private: 00059 const float *to_; 00060 }; 00061 00062 template <class Value> class HashedSearch { 00063 public: 00064 typedef uint64_t Node; 00065 00066 typedef typename Value::ProbingProxy UnigramPointer; 00067 typedef typename Value::ProbingProxy MiddlePointer; 00068 typedef ::lm::ngram::detail::LongestPointer LongestPointer; 00069 00070 static const ModelType kModelType = Value::kProbingModelType; 00071 static const bool kDifferentRest = Value::kDifferentRest; 00072 static const unsigned int kVersion = 0; 00073 00074 // TODO: move probing_multiplier here with next binary file format update. 00075 static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {} 00076 00077 static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) { 00078 uint64_t ret = Unigram::Size(counts[0]); 00079 for (unsigned char n = 1; n < counts.size() - 1; ++n) { 00080 ret += Middle::Size(counts[n], config.probing_multiplier); 00081 } 00082 return ret + Longest::Size(counts.back(), config.probing_multiplier); 00083 } 00084 00085 uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config); 00086 00087 void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing); 00088 00089 unsigned char Order() const { 00090 return middle_.size() + 2; 00091 } 00092 00093 typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); } 00094 00095 UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const { 00096 extend_left = static_cast<uint64_t>(word); 00097 next = extend_left; 00098 UnigramPointer ret(unigram_.Lookup(word)); 00099 independent_left = ret.IndependentLeft(); 00100 return ret; 00101 } 00102 00103 MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const { 00104 node = extend_pointer; 00105 return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value); 00106 } 00107 00108 MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const { 00109 node = CombineWordHash(node, word); 00110 typename Middle::ConstIterator found; 00111 if (!middle_[order_minus_2].Find(node, found)) { 00112 independent_left = true; 00113 return MiddlePointer(); 00114 } 00115 extend_pointer = node; 00116 MiddlePointer ret(found->value); 00117 independent_left = ret.IndependentLeft(); 00118 return ret; 00119 } 00120 00121 LongestPointer LookupLongest(WordIndex word, const Node &node) const { 00122 // Sign bit is always on because longest n-grams do not extend left. 00123 typename Longest::ConstIterator found; 00124 if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer(); 00125 return LongestPointer(found->value.prob); 00126 } 00127 00128 // Generate a node without necessarily checking that it actually exists. 00129 // Optionally return false if it's know to not exist. 00130 bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const { 00131 assert(begin != end); 00132 node = static_cast<Node>(*begin); 00133 for (const WordIndex *i = begin + 1; i < end; ++i) { 00134 node = CombineWordHash(node, *i); 00135 } 00136 return true; 00137 } 00138 00139 private: 00140 // Interpret config's rest cost build policy and pass the right template argument to ApplyBuild. 00141 void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn); 00142 00143 template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build); 00144 00145 class Unigram { 00146 public: 00147 Unigram() {} 00148 00149 Unigram(void *start, uint64_t count) : 00150 unigram_(static_cast<typename Value::Weights*>(start)) 00151 #ifdef DEBUG 00152 , count_(count) 00153 #endif 00154 {} 00155 00156 static uint64_t Size(uint64_t count) { 00157 return (count + 1) * sizeof(typename Value::Weights); // +1 for hallucinate <unk> 00158 } 00159 00160 const typename Value::Weights &Lookup(WordIndex index) const { 00161 #ifdef DEBUG 00162 assert(index < count_); 00163 #endif 00164 return unigram_[index]; 00165 } 00166 00167 typename Value::Weights &Unknown() { return unigram_[0]; } 00168 00169 // For building. 00170 typename Value::Weights *Raw() { return unigram_; } 00171 00172 private: 00173 typename Value::Weights *unigram_; 00174 #ifdef DEBUG 00175 uint64_t count_; 00176 #endif 00177 }; 00178 00179 Unigram unigram_; 00180 00181 typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle; 00182 std::vector<Middle> middle_; 00183 00184 typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest; 00185 Longest longest_; 00186 }; 00187 00188 } // namespace detail 00189 } // namespace ngram 00190 } // namespace lm 00191 00192 #endif // LM_SEARCH_HASHED_H