Joshua
An open-source statistical hierarchical phrase-based machine translation system
#ifndef LM_MODEL_H
#define LM_MODEL_H

#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
#include "lm/config.hh"
#include "lm/facade.hh"
#include "lm/quantize.hh"
#include "lm/search_hashed.hh"
#include "lm/search_trie.hh"
#include "lm/state.hh"
#include "lm/value.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"

#include "util/murmur_hash.hh"

#include <algorithm>
#include <vector>
#include <cstring>

namespace util { class FilePiece; }

namespace lm {
namespace ngram {
namespace detail {

// Should return the same results as SRI.
// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
  private:
    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
  public:
    // This is the model type returned by RecognizeBinary.
    static const ModelType kModelType;

    static const unsigned int kVersion = Search::kVersion;

    /* Get the size of memory that will be mapped given ngram counts. This
     * does not include small non-mapped control structures, such as this class
     * itself.
     */
    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

    /* Load the model from a file. It may be an ARPA or binary file. Binary
     * files must have the format expected by this class or you'll get an
     * exception. So TrieModel can only load ARPA or binary created by
     * TrieModel. To classify binary files, call RecognizeBinary in
     * lm/binary_format.hh.
     */
    explicit GenericModel(const char *file, const Config &config = Config());

    /* Score p(new_word | in_state) and incorporate new_word into out_state.
     * Note that in_state and out_state must be different references:
     * &in_state != &out_state.
     */
    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;

    /* Slower call without in_state. Try to remember state, but sometimes it
     * would cost too much memory or your decoder isn't set up properly.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order. Then, pass the bounds of the array:
     * [context_rbegin, context_rend). The new_word is not part of the context
     * array unless you intend to repeat words.
     */
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    /* Get the state for a context. Don't use this if you can avoid it. Use
     * BeginSentenceState or NullContextState and extend from those. If
     * you're only going to use this state to call FullScore once, use
     * FullScoreForgotState.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order. Then, pass the bounds of the array:
     * [context_rbegin, context_rend).
     */
    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;

    /* More efficient version of FullScore where a partial n-gram has already
     * been scored.
     * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE.
     */
    FullScoreReturn ExtendLeft(
        // Additional context in reverse order. This will update add_rend to
        const WordIndex *add_rbegin, const WordIndex *add_rend,
        // Backoff weights to use.
        const float *backoff_in,
        // extend_left returned by a previous query.
        uint64_t extend_pointer,
        // Length of n-gram that the pointer corresponds to.
        unsigned char extend_length,
        // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)]
        float *backoff_out,
        // Amount of additional content that should be considered by the next call.
        unsigned char &next_use) const;

    /* Return probabilities minus rest costs for an array of pointers. The
     * first length should be the length of the n-gram to which pointers_begin
     * points.
     */
    float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
      // Compiler should optimize this if away.
      return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
    }

  private:
    FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

    // Score bigrams and above. Do not include backoff.
    void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;

    // Appears after Size in the cc file.
    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);

    void InitializeFromARPA(int fd, const char *file, const Config &config);

    float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;

    BinaryFormat backing_;

    VocabularyT vocab_;

    Search search_;
};

} // namespace detail

// Instead of typedef, inherit. This allows the Model etc to be forward declared.
// Oh the joys of C and C++.
#define LM_COMMA() ,
#define LM_NAME_MODEL(name, from)\
class name : public from {\
  public:\
    name(const char *file, const Config &config = Config()) : from(file, config) {}\
};

LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);

// Default implementation. No real reason for it to be the default.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef ProbingModel Model;

/* Autorecognize the file type, load, and return the virtual base class. Don't
 * use the virtual base class if you can avoid it. Instead, use the above
 * classes as template arguments to your own virtual feature function.*/
base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);

} // namespace ngram
} // namespace lm

#endif // LM_MODEL_H
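As a usage sketch (not part of the header above): scoring a sentence means chaining FullScore calls, threading each call's out_state into the next call's in_state, while the *ForgotState variants take an explicit context array with vocabulary ids in reverse order. The model path "5gram.binary", the example words, and the main() wrapper below are placeholders, and error handling (the library throws an exception on a malformed file) is omitted.

    #include "lm/model.hh"

    #include <iostream>

    // Usage sketch only; "5gram.binary" is a placeholder path for an ARPA
    // file or a ProbingModel binary.
    int main() {
      using lm::ngram::Model;
      Model model("5gram.binary");
      const Model::Vocabulary &vocab = model.GetVocabulary();

      // FullScore requires &in_state != &out_state, so alternate between two
      // states, feeding each call's out_state in as the next call's in_state.
      lm::ngram::State states[2];
      states[0] = model.BeginSentenceState();
      const char *words[] = {"this", "is", "a", "sentence", "</s>"};
      float total = 0.0;  // log10 probability, matching the ARPA convention
      for (unsigned i = 0; i < sizeof(words) / sizeof(words[0]); ++i) {
        // Index maps a string to its WordIndex; out-of-vocabulary words map to <unk>.
        lm::FullScoreReturn ret =
            model.FullScore(states[i % 2], vocab.Index(words[i]), states[(i + 1) % 2]);
        total += ret.prob;
      }
      std::cout << "Total log10 probability: " << total << std::endl;

      // The reversed-context calls take vocabulary ids most-recent-word-first:
      // to score p(sentence | is a), the context array is {id("a"), id("is")}.
      lm::WordIndex context[2] = {vocab.Index("a"), vocab.Index("is")};
      lm::ngram::State state;
      lm::FullScoreReturn forgot = model.FullScoreForgotState(
          context, context + 2, vocab.Index("sentence"), state);
      std::cout << "p(sentence | is a) = " << forgot.prob << std::endl;
      return 0;
    }

When the concrete model type is not known until runtime, LoadVirtual (together with RecognizeBinary in lm/binary_format.hh) returns the base::Model interface instead, at the cost of virtual dispatch on every query; the header above recommends the concrete classes where possible.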