Joshua
Open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/vocab.hh
00001 #ifndef LM_VOCAB_H
00002 #define LM_VOCAB_H
00003 
00004 #include "lm/enumerate_vocab.hh"
00005 #include "lm/lm_exception.hh"
00006 #include "lm/virtual_interface.hh"
00007 #include "util/file_stream.hh"
00008 #include "util/murmur_hash.hh"
00009 #include "util/pool.hh"
00010 #include "util/probing_hash_table.hh"
00011 #include "util/sorted_uniform.hh"
00012 #include "util/string_piece.hh"
00013 
00014 #include <limits>
00015 #include <string>
00016 #include <vector>
00017 
00018 namespace lm {
00019 struct ProbBackoff;
00020 class EnumerateVocab;
00021 
00022 namespace ngram {
00023 struct Config;
00024 
00025 namespace detail {
00026 uint64_t HashForVocab(const char *str, std::size_t len);
00027 inline uint64_t HashForVocab(const StringPiece &str) {
00028   return HashForVocab(str.data(), str.length());
00029 }
00030 struct ProbingVocabularyHeader;
00031 } // namespace detail
00032 
00033 // Writes words immediately to a file instead of buffering, because we know
00034 // where in the file to put them.
00035 class ImmediateWriteWordsWrapper : public EnumerateVocab {
00036   public:
00037     ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start);
00038 
00039     void Add(WordIndex index, const StringPiece &str) {
00040       stream_ << str << '\0';
00041       if (inner_) inner_->Add(index, str);
00042     }
00043 
00044   private:
00045     EnumerateVocab *inner_;
00046 
00047     util::FileStream stream_;
00048 };
00049 
00050 // When the binary size isn't known yet.
00051 class WriteWordsWrapper : public EnumerateVocab {
00052   public:
00053     WriteWordsWrapper(EnumerateVocab *inner);
00054 
00055     void Add(WordIndex index, const StringPiece &str);
00056 
00057     const std::string &Buffer() const { return buffer_; }
00058     void Write(int fd, uint64_t start);
00059 
00060   private:
00061     EnumerateVocab *inner_;
00062 
00063     std::string buffer_;
00064 };
00065 
00066 // Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
00067 class SortedVocabulary : public base::Vocabulary {
00068   public:
00069     SortedVocabulary();
00070 
00071     WordIndex Index(const StringPiece &str) const {
00072       const uint64_t *found;
00073       if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
00074             util::IdentityAccessor<uint64_t>(),
00075             begin_ - 1, 0,
00076             end_, std::numeric_limits<uint64_t>::max(),
00077             detail::HashForVocab(str), found)) {
00078         return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
00079       } else {
00080         return 0;
00081       }
00082     }
00083 
00084     // Size for purposes of file writing
00085     static uint64_t Size(uint64_t entries, const Config &config);
00086 
00087     /* Read null-delimited words from file from_words, renumber according to
00088      * hash order, write null-delimited words to to_words, and create a mapping
00089      * from old id to new id.  The 0th vocab word must be <unk>.
00090      */
00091     static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping);
00092 
00093     // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.
00094     WordIndex Bound() const { return bound_; }
00095 
00096     // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
00097     void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);
00098 
00099     void Relocate(void *new_start);
00100 
00101     void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00102 
00103     // Insert and FinishedLoading go together.
00104     WordIndex Insert(const StringPiece &str);
00105     // Reorders reorder_vocab so that the IDs are sorted.
00106     void FinishedLoading(ProbBackoff *reorder_vocab);
00107 
00108     // Trie stores the correct counts including <unk> in the header.  If this was previously sized based on a count exluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
00109     std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }
00110 
00111     bool SawUnk() const { return saw_unk_; }
00112 
00113     void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
00114 
00115     uint64_t *&EndHack() { return end_; }
00116 
00117     void Populated();
00118 
00119   private:
00120     template <class T> void GenericFinished(T *reorder);
00121 
00122     uint64_t *begin_, *end_;
00123 
00124     WordIndex bound_;
00125 
00126     bool saw_unk_;
00127 
00128     EnumerateVocab *enumerate_;
00129 
00130     // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL
00131     util::Pool string_backing_;
00132 
00133     std::vector<StringPiece> strings_to_enumerate_;
00134 };
00135 
00136 #pragma pack(push)
00137 #pragma pack(4)
00138 struct ProbingVocabularyEntry {
00139   uint64_t key;
00140   WordIndex value;
00141 
00142   typedef uint64_t Key;
00143   uint64_t GetKey() const { return key; }
00144   void SetKey(uint64_t to) { key = to; }
00145 
00146   static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) {
00147     ProbingVocabularyEntry ret;
00148     ret.key = key;
00149     ret.value = value;
00150     return ret;
00151   }
00152 };
00153 #pragma pack(pop)
00154 
00155 // Vocabulary storing a map from uint64_t to WordIndex.
00156 class ProbingVocabulary : public base::Vocabulary {
00157   public:
00158     ProbingVocabulary();
00159 
00160     WordIndex Index(const StringPiece &str) const {
00161       Lookup::ConstIterator i;
00162       return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
00163     }
00164 
00165     static uint64_t Size(uint64_t entries, float probing_multiplier);
00166     // This just unwraps Config to get the probing_multiplier.
00167     static uint64_t Size(uint64_t entries, const Config &config);
00168 
00169     // Vocab words are [0, Bound()).
00170     WordIndex Bound() const { return bound_; }
00171 
00172     // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
00173     void SetupMemory(void *start, std::size_t allocated);
00174     void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
00175       SetupMemory(start, allocated);
00176     }
00177 
00178     void Relocate(void *new_start);
00179 
00180     void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);
00181 
00182     WordIndex Insert(const StringPiece &str);
00183 
00184     template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
00185       InternalFinishedLoading();
00186     }
00187 
00188     std::size_t UnkCountChangePadding() const { return 0; }
00189 
00190     bool SawUnk() const { return saw_unk_; }
00191 
00192     void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);
00193 
00194   private:
00195     void InternalFinishedLoading();
00196 
00197     typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;
00198 
00199     Lookup lookup_;
00200 
00201     WordIndex bound_;
00202 
00203     bool saw_unk_;
00204 
00205     EnumerateVocab *enumerate_;
00206 
00207     detail::ProbingVocabularyHeader *header_;
00208 };
00209 
// Reports that the vocabulary has no <unk> entry.  Defined out of line; the
// declaration permits it to throw SpecialWordMissingException.
// NOTE(review): dynamic exception specifications such as
// throw(SpecialWordMissingException) are ill-formed since C++17.  Remove them
// here and in the out-of-line definitions (presumably in vocab.cc) in the same
// change: before C++17 all declarations of a function must carry compatible
// specifications, so editing only this header would break the build.
void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
// Reports that the sentence marker named by str (e.g. "<s>" or "</s>") is
// missing.  Same exception-specification caveat as MissingUnknown.
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
00212 
00213 template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
00214   if (!vocab.SawUnk()) MissingUnknown(config);
00215   if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
00216   if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
00217 }
00218 
00219 class WriteUniqueWords {
00220   public:
00221     explicit WriteUniqueWords(int fd) : word_list_(fd) {}
00222 
00223     void operator()(const StringPiece &word) {
00224       word_list_ << word << '\0';
00225     }
00226 
00227   private:
00228     util::FileStream word_list_;
00229 };
00230 
00231 class NoOpUniqueWords {
00232   public:
00233     NoOpUniqueWords() {}
00234     void operator()(const StringPiece &word) {}
00235 };
00236 
00237 template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
00238   public:
00239     static std::size_t MemUsage(WordIndex content) {
00240       return Lookup::MemUsage(content > 2 ? content : 2);
00241     }
00242 
00243     // Does not take ownership of write_wordi
00244     template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
00245       : lookup_(initial_size), new_word_(new_word_construct) {
00246       FindOrInsert("<unk>"); // Force 0
00247       FindOrInsert("<s>"); // Force 1
00248       FindOrInsert("</s>"); // Force 2
00249     }
00250 
00251     WordIndex Index(const StringPiece &str) const {
00252       Lookup::ConstIterator i;
00253       return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
00254     }
00255 
00256     WordIndex FindOrInsert(const StringPiece &word) {
00257       ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size());
00258       Lookup::MutableIterator it;
00259       if (!lookup_.FindOrInsert(entry, it)) {
00260         new_word_(word);
00261         UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh");
00262       }
00263       return it->value;
00264     }
00265 
00266     WordIndex Size() const { return lookup_.Size(); }
00267 
00268   private:
00269     typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;
00270 
00271     Lookup lookup_;
00272 
00273     NewWordAction new_word_;
00274 };
00275 
00276 } // namespace ngram
00277 } // namespace lm
00278 
00279 #endif // LM_VOCAB_H