Joshua
Open-source statistical hierarchical phrase-based machine translation system
00001 #ifndef LM_VOCAB_H 00002 #define LM_VOCAB_H 00003 00004 #include "lm/enumerate_vocab.hh" 00005 #include "lm/lm_exception.hh" 00006 #include "lm/virtual_interface.hh" 00007 #include "util/file_stream.hh" 00008 #include "util/murmur_hash.hh" 00009 #include "util/pool.hh" 00010 #include "util/probing_hash_table.hh" 00011 #include "util/sorted_uniform.hh" 00012 #include "util/string_piece.hh" 00013 00014 #include <limits> 00015 #include <string> 00016 #include <vector> 00017 00018 namespace lm { 00019 struct ProbBackoff; 00020 class EnumerateVocab; 00021 00022 namespace ngram { 00023 struct Config; 00024 00025 namespace detail { 00026 uint64_t HashForVocab(const char *str, std::size_t len); 00027 inline uint64_t HashForVocab(const StringPiece &str) { 00028 return HashForVocab(str.data(), str.length()); 00029 } 00030 struct ProbingVocabularyHeader; 00031 } // namespace detail 00032 00033 // Writes words immediately to a file instead of buffering, because we know 00034 // where in the file to put them. 00035 class ImmediateWriteWordsWrapper : public EnumerateVocab { 00036 public: 00037 ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start); 00038 00039 void Add(WordIndex index, const StringPiece &str) { 00040 stream_ << str << '\0'; 00041 if (inner_) inner_->Add(index, str); 00042 } 00043 00044 private: 00045 EnumerateVocab *inner_; 00046 00047 util::FileStream stream_; 00048 }; 00049 00050 // When the binary size isn't known yet. 00051 class WriteWordsWrapper : public EnumerateVocab { 00052 public: 00053 WriteWordsWrapper(EnumerateVocab *inner); 00054 00055 void Add(WordIndex index, const StringPiece &str); 00056 00057 const std::string &Buffer() const { return buffer_; } 00058 void Write(int fd, uint64_t start); 00059 00060 private: 00061 EnumerateVocab *inner_; 00062 00063 std::string buffer_; 00064 }; 00065 00066 // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices. 
00067 class SortedVocabulary : public base::Vocabulary { 00068 public: 00069 SortedVocabulary(); 00070 00071 WordIndex Index(const StringPiece &str) const { 00072 const uint64_t *found; 00073 if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>( 00074 util::IdentityAccessor<uint64_t>(), 00075 begin_ - 1, 0, 00076 end_, std::numeric_limits<uint64_t>::max(), 00077 detail::HashForVocab(str), found)) { 00078 return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table. 00079 } else { 00080 return 0; 00081 } 00082 } 00083 00084 // Size for purposes of file writing 00085 static uint64_t Size(uint64_t entries, const Config &config); 00086 00087 /* Read null-delimited words from file from_words, renumber according to 00088 * hash order, write null-delimited words to to_words, and create a mapping 00089 * from old id to new id. The 0th vocab word must be <unk>. 00090 */ 00091 static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping); 00092 00093 // Vocab words are [0, Bound()) Only valid after FinishedLoading/LoadedBinary. 00094 WordIndex Bound() const { return bound_; } 00095 00096 // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. 00097 void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config); 00098 00099 void Relocate(void *new_start); 00100 00101 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); 00102 00103 // Insert and FinishedLoading go together. 00104 WordIndex Insert(const StringPiece &str); 00105 // Reorders reorder_vocab so that the IDs are sorted. 00106 void FinishedLoading(ProbBackoff *reorder_vocab); 00107 00108 // Trie stores the correct counts including <unk> in the header. 
If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>. 00109 std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); } 00110 00111 bool SawUnk() const { return saw_unk_; } 00112 00113 void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); 00114 00115 uint64_t *&EndHack() { return end_; } 00116 00117 void Populated(); 00118 00119 private: 00120 template <class T> void GenericFinished(T *reorder); 00121 00122 uint64_t *begin_, *end_; 00123 00124 WordIndex bound_; 00125 00126 bool saw_unk_; 00127 00128 EnumerateVocab *enumerate_; 00129 00130 // Actual strings. Used only when loading from ARPA and enumerate_ != NULL 00131 util::Pool string_backing_; 00132 00133 std::vector<StringPiece> strings_to_enumerate_; 00134 }; 00135 00136 #pragma pack(push) 00137 #pragma pack(4) 00138 struct ProbingVocabularyEntry { 00139 uint64_t key; 00140 WordIndex value; 00141 00142 typedef uint64_t Key; 00143 uint64_t GetKey() const { return key; } 00144 void SetKey(uint64_t to) { key = to; } 00145 00146 static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) { 00147 ProbingVocabularyEntry ret; 00148 ret.key = key; 00149 ret.value = value; 00150 return ret; 00151 } 00152 }; 00153 #pragma pack(pop) 00154 00155 // Vocabulary storing a map from uint64_t to WordIndex. 00156 class ProbingVocabulary : public base::Vocabulary { 00157 public: 00158 ProbingVocabulary(); 00159 00160 WordIndex Index(const StringPiece &str) const { 00161 Lookup::ConstIterator i; 00162 return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; 00163 } 00164 00165 static uint64_t Size(uint64_t entries, float probing_multiplier); 00166 // This just unwraps Config to get the probing_multiplier. 00167 static uint64_t Size(uint64_t entries, const Config &config); 00168 00169 // Vocab words are [0, Bound()). 
00170 WordIndex Bound() const { return bound_; } 00171 00172 // Everything else is for populating. I'm too lazy to hide and friend these, but you'll only get a const reference anyway. 00173 void SetupMemory(void *start, std::size_t allocated); 00174 void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) { 00175 SetupMemory(start, allocated); 00176 } 00177 00178 void Relocate(void *new_start); 00179 00180 void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries); 00181 00182 WordIndex Insert(const StringPiece &str); 00183 00184 template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) { 00185 InternalFinishedLoading(); 00186 } 00187 00188 std::size_t UnkCountChangePadding() const { return 0; } 00189 00190 bool SawUnk() const { return saw_unk_; } 00191 00192 void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); 00193 00194 private: 00195 void InternalFinishedLoading(); 00196 00197 typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup; 00198 00199 Lookup lookup_; 00200 00201 WordIndex bound_; 00202 00203 bool saw_unk_; 00204 00205 EnumerateVocab *enumerate_; 00206 00207 detail::ProbingVocabularyHeader *header_; 00208 }; 00209 00210 void MissingUnknown(const Config &config) throw(SpecialWordMissingException); 00211 void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException); 00212 00213 template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) { 00214 if (!vocab.SawUnk()) MissingUnknown(config); 00215 if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>"); 00216 if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); 00217 } 00218 00219 class WriteUniqueWords { 00220 public: 00221 explicit WriteUniqueWords(int fd) : word_list_(fd) {} 00222 00223 void operator()(const StringPiece 
&word) { 00224 word_list_ << word << '\0'; 00225 } 00226 00227 private: 00228 util::FileStream word_list_; 00229 }; 00230 00231 class NoOpUniqueWords { 00232 public: 00233 NoOpUniqueWords() {} 00234 void operator()(const StringPiece &word) {} 00235 }; 00236 00237 template <class NewWordAction = NoOpUniqueWords> class GrowableVocab { 00238 public: 00239 static std::size_t MemUsage(WordIndex content) { 00240 return Lookup::MemUsage(content > 2 ? content : 2); 00241 } 00242 00243 // Does not take ownership of write_wordi 00244 template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction()) 00245 : lookup_(initial_size), new_word_(new_word_construct) { 00246 FindOrInsert("<unk>"); // Force 0 00247 FindOrInsert("<s>"); // Force 1 00248 FindOrInsert("</s>"); // Force 2 00249 } 00250 00251 WordIndex Index(const StringPiece &str) const { 00252 Lookup::ConstIterator i; 00253 return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0; 00254 } 00255 00256 WordIndex FindOrInsert(const StringPiece &word) { 00257 ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size()); 00258 Lookup::MutableIterator it; 00259 if (!lookup_.FindOrInsert(entry, it)) { 00260 new_word_(word); 00261 UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh"); 00262 } 00263 return it->value; 00264 } 00265 00266 WordIndex Size() const { return lookup_.Size(); } 00267 00268 private: 00269 typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup; 00270 00271 Lookup lookup_; 00272 00273 NewWordAction new_word_; 00274 }; 00275 00276 } // namespace ngram 00277 } // namespace lm 00278 00279 #endif // LM_VOCAB_H