Joshua
An open-source statistical hierarchical phrase-based machine translation system
/* Efficient left and right language model state for sentence fragments.
 * Intended usage:
 * Store ChartState with every chart entry.
 * To do a rule application:
 * 1. Make a ChartState object for your new entry.
 * 2. Construct RuleScore.
 * 3. Going from left to right, call Terminal or NonTerminal.
 * For terminals, just pass the vocab id.
 * For non-terminals, pass that non-terminal's ChartState.
 * If your decoder expects scores inclusive of subtree scores (i.e. you
 * label entries with the highest-scoring path), pass the non-terminal's
 * score as prob.
 * If your decoder expects relative scores and will walk the chart later,
 * pass prob = 0.0.
 * In other words, the only effect of prob is that it gets added to the
 * returned log probability.
 * 4. Call Finish. It returns the log probability.
 *
 * There are a couple more details:
 * Do not pass <s> to Terminal as it is formally not a word in the sentence,
 * only context. Instead, call BeginSentence. If called, it should be the
 * first call after RuleScore is constructed (since <s> is always the
 * leftmost).
 *
 * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
 *
 * Hashing and sorting comparison operators are provided. All state objects
 * are POD. If you intend to use memcmp on raw state objects, you must call
 * ZeroRemaining first, as the value of array entries beyond length is
 * otherwise undefined.
 *
 * Usage is of course not limited to chart decoding. Anything that generates
 * sentence fragments missing left context could benefit. For example, a
 * phrase-based decoder could pre-score phrases, storing ChartState with each
 * phrase, even if hypotheses are generated left-to-right.
 */

#ifndef LM_LEFT_H
#define LM_LEFT_H

#include "lm/max_order.hh"
#include "lm/state.hh"
#include "lm/return.hh"

#include "util/murmur_hash.hh"

#include <algorithm>

namespace lm {
namespace ngram {

// Accumulates the log probability of one rule application while building the
// resulting ChartState: out_->left holds extend_left pointers for words that
// may still be rescored with additional left context (charged at their rest
// cost for now), and out_->right holds the word/backoff context used to score
// whatever comes next.  left_done_ is set once no further left context can
// change the score.
template <class M> class RuleScore {
  public:
    // Begin scoring a rule whose resulting state is written into out.
    // Both sides of the state start empty.
    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
      out.left.length = 0;
      out.right.length = 0;
    }

    // Use instead of Terminal(<s>): <s> is context, not a scored word.
    void BeginSentence() {
      out_->right = model_.BeginSentenceState();
      // out_->left is empty.  <s> is full left context, so the left state is
      // final from the start.
      left_done_ = true;
    }

    // Score one terminal word against the current right context.
    void Terminal(WordIndex word) {
      State copy(out_->right);
      FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
      // Left state already final: just take the full score.
      if (left_done_) { prob_ += ret.prob; return; }
      if (ret.independent_left) {
        // The model guarantees more left context cannot change this score.
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      // Remember the pointer so this word can be rescored later with more
      // left context; charge its rest cost in the meantime.
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
      // The right context did not grow by exactly one word (the model backed
      // off), so mark the left state complete.
      if (out_->right.length != copy.length + 1)
        left_done_ = true;
    }

    // Faster version of NonTerminal for the case where the rule begins with a non-terminal.
    void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ = prob;
      *out_ = in;
      left_done_ = in.left.full;
    }

    // Append a scored non-terminal's state; prob is its (sub)score to fold in.
    void NonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ += prob;

      // Case 1: the child has no words awaiting more left context.
      if (!in.left.length) {
        if (in.left.full) {
          // Child saw full context on its left edge: charge backoff for our
          // dangling right words and adopt the child's right state wholesale.
          for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
          left_done_ = true;
          out_->right = in.right;
        }
        // Otherwise the child carries no state at all; nothing to merge.
        return;
      }

      // Case 2: we have no right context to extend the child's words with.
      if (!out_->right.length) {
        out_->right = in.right;
        if (left_done_) {
          // Our left edge is final, so the child's words will never be
          // rescored: convert their rest costs to full scores.
          prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
          return;
        }
        if (out_->left.length) {
          // NOTE(review): we already hold left-state words of our own, so the
          // child's pointers cannot be appended (they are not at the
          // fragment's left edge); the left state is treated as complete.
          left_done_ = true;
        } else {
          // We are still at the left edge: inherit the child's left state.
          out_->left = in.left;
          left_done_ = in.left.full;
        }
        return;
      }

      // Case 3: extend each of the child's left-state words through our right
      // context, double-buffering the backoff weights between iterations.
      float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
      float *back = backoffs, *back2 = backoffs2;
      unsigned char next_use = out_->right.length;

      // First word
      if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;

      // Words after the first, so extending a bigram to begin with
      for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
        if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
        std::swap(back, back2);
      }

      if (in.left.full) {
        // Child's left edge is complete: charge the surviving backoffs and
        // take its right state.
        for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
        left_done_ = true;
        out_->right = in.right;
        return;
      }

      // Right state was minimized, so it's already independent of the new words to the left.
      if (in.right.length < in.left.length) {
        out_->right = in.right;
        return;
      }

      // Shift existing words down.
      for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
        *(i + in.right.length) = *i;
      }
      // Add words from in.right.
      std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
      // Compose the backoff array: the child's backoffs first, followed by our
      // surviving (extended) backoffs.
      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
      std::copy(back, back + next_use, out_->right.backoff + in.right.length);
      out_->right.length = in.right.length + next_use;
    }

    // Finalize the state and return the accumulated log probability.
    float Finish() {
      // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
      out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
      return prob_;
    }

    // Clear the accumulator and the output state for reuse on another rule.
    void Reset() {
      prob_ = 0.0;
      left_done_ = false;
      out_->left.length = 0;
      out_->right.length = 0;
    }
    // Reuse this scorer with a different output state.
    void Reset(ChartState &replacement) {
      out_ = &replacement;
      Reset();
    }

  private:
    // Rescore the child's extend_length-th left-state word using our right
    // context (first next_use words).  Updates next_use to the context length
    // for the next extension.  Returns true on early exit: no context remains,
    // so the child's remaining rest costs were converted and in.right adopted.
    bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
      ProcessRet(model_.ExtendLeft(
            out_->right.words, out_->right.words + next_use, // Words to extend into
            back_in, // Backoffs to use
            in.left.pointers[extend_length - 1], extend_length, // Words to be extended
            back_out, // Backoffs for the next score
            next_use)); // Length of n-gram to use in next scoring.
      if (next_use != out_->right.length) {
        left_done_ = true;
        if (!next_use) {
          // Early exit.
          out_->right = in.right;
          prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
          return true;
        }
      }
      // Continue scoring.
      return false;
    }

    // Fold one FullScoreReturn into the accumulated score and left state,
    // mirroring the three cases in Terminal.
    void ProcessRet(const FullScoreReturn &ret) {
      if (left_done_) {
        prob_ += ret.prob;
        return;
      }
      if (ret.independent_left) {
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
    }

    // Language model used for all scoring calls.
    const M &model_;

    // Output state under construction (not owned).
    ChartState *out_;

    // True once no additional left context can change the score.
    bool left_done_;

    // Accumulated log probability (rest costs for unresolved left words).
    float prob_;
};

} // namespace ngram
} // namespace lm

#endif // LM_LEFT_H