Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/left.hh
00001 /* Efficient left and right language model state for sentence fragments.
00002  * Intended usage:
00003  * Store ChartState with every chart entry.
00004  * To do a rule application:
00005  * 1. Make a ChartState object for your new entry.
00006  * 2. Construct RuleScore.
00007  * 3. Going from left to right, call Terminal or NonTerminal.
00008  *   For terminals, just pass the vocab id.
00009  *   For non-terminals, pass that non-terminal's ChartState.
00010  *     If your decoder expects scores inclusive of subtree scores (i.e. you
00011  *     label entries with the highest-scoring path), pass the non-terminal's
00012  *     score as prob.
00013  *     If your decoder expects relative scores and will walk the chart later,
00014  *     pass prob = 0.0.
00015  *     In other words, the only effect of prob is that it gets added to the
00016  *     returned log probability.
00017  * 4. Call Finish.  It returns the log probability.
00018  *
00019  * There are a couple more details:
00020  * Do not pass <s> to Terminal as it is formally not a word in the sentence,
00021  * only context.  Instead, call BeginSentence.  If called, it should be the
00022  * first call after RuleScore is constructed (since <s> is always the
00023  * leftmost).
00024  *
00025  * If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
00026  *
00027  * Hashing and sorting comparison operators are provided.  All state objects
00028  * are POD.  If you intend to use memcmp on raw state objects, you must call
00029  * ZeroRemaining first, as the value of array entries beyond length is
00030  * otherwise undefined.
00031  *
00032  * Usage is of course not limited to chart decoding.  Anything that generates
00033  * sentence fragments missing left context could benefit.  For example, a
00034  * phrase-based decoder could pre-score phrases, storing ChartState with each
00035  * phrase, even if hypotheses are generated left-to-right.
00036  */
00037 
00038 #ifndef LM_LEFT_H
00039 #define LM_LEFT_H
00040 
00041 #include "lm/max_order.hh"
00042 #include "lm/state.hh"
00043 #include "lm/return.hh"
00044 
00045 #include "util/murmur_hash.hh"
00046 
00047 #include <algorithm>
00048 
00049 namespace lm {
00050 namespace ngram {
00051 
template <class M> class RuleScore {
  public:
    // model must outlive this object; out receives the left/right state of
    // the fragment being built and is owned by the caller.
    explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
      out.left.length = 0;
      out.right.length = 0;
    }

    // Declare that the fragment begins with <s>.  If used, this should be the
    // first call after construction; <s> is context only and never scored.
    void BeginSentence() {
      out_->right = model_.BeginSentenceState();
      // out_->left is empty.
      left_done_ = true;
    }

    // Append a terminal word, scoring it against the accumulated right state.
    void Terminal(WordIndex word) {
      State copy(out_->right);
      FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
      if (left_done_) { prob_ += ret.prob; return; }
      if (ret.independent_left) {
        // More left context cannot change this word's probability.
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      // Record a pointer so the score can be revised when left context arrives;
      // until then, charge the rest cost.
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
      // If the right state did not grow by exactly one word, context was
      // truncated, so extending further left cannot matter.
      if (out_->right.length != copy.length + 1)
        left_done_ = true;
    }

    // Faster version of NonTerminal for the case where the rule begins with a non-terminal.
    void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ = prob;
      *out_ = in;
      left_done_ = in.left.full;
    }

    // Absorb a scored sub-fragment (non-terminal).  prob is simply added to
    // the accumulated log probability (see the header comment).
    void NonTerminal(const ChartState &in, float prob = 0.0) {
      prob_ += prob;

      if (!in.left.length) {
        if (in.left.full) {
          // in is independent of anything to its left: charge all pending
          // backoffs from our right state and adopt in's right state.
          for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
          left_done_ = true;
          out_->right = in.right;
        }
        return;
      }

      if (!out_->right.length) {
        // We have no right context to extend in's words into.
        out_->right = in.right;
        if (left_done_) {
          // Our left is settled, so replace in's rest costs with full scores.
          prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
          return;
        }
        if (out_->left.length) {
          left_done_ = true;
        } else {
          // Nothing accumulated yet: inherit in's left state outright.
          out_->left = in.left;
          left_done_ = in.left.full;
        }
        return;
      }

      // Double-buffered backoffs: back holds backoffs for the current score,
      // back2 receives those for the next, then they swap.
      float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
      float *back = backoffs, *back2 = backoffs2;
      unsigned char next_use = out_->right.length;

      // First word
      if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;

      // Words after the first, so extending a bigram to begin with
      for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
        if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
        std::swap(back, back2);
      }

      if (in.left.full) {
        // All of in's left words were scored and in is now independent of
        // further left context: charge remaining backoffs and take in's right.
        for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
        left_done_ = true;
        out_->right = in.right;
        return;
      }

      // Right state was minimized, so it's already independent of the new words to the left.
      if (in.right.length < in.left.length) {
        out_->right = in.right;
        return;
      }

      // Shift existing words down.
      for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
        *(i + in.right.length) = *i;
      }
      // Add words from in.right.
      std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
      // Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
      std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
      std::copy(back, back + next_use, out_->right.backoff + in.right.length);
      out_->right.length = in.right.length + next_use;
    }

    // Finalize the left state and return the accumulated log probability.
    float Finish() {
      // A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
      out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
      return prob_;
    }

    // Reuse this scorer for another rule, writing into the same ChartState.
    void Reset() {
      prob_ = 0.0;
      left_done_ = false;
      out_->left.length = 0;
      out_->right.length = 0;
    }
    // Reuse this scorer, redirecting output to a different ChartState.
    void Reset(ChartState &replacement) {
      out_ = &replacement;
      Reset();
    }

  private:
    // Extend in's left pointer number extend_length into our current right
    // context.  Returns true if scoring finished early (caller should return).
    bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
      ProcessRet(model_.ExtendLeft(
            out_->right.words, out_->right.words + next_use, // Words to extend into
            back_in, // Backoffs to use
            in.left.pointers[extend_length - 1], extend_length, // Words to be extended
            back_out, // Backoffs for the next score
            next_use)); // Length of n-gram to use in next scoring.
      if (next_use != out_->right.length) {
        // Context was truncated, so our left state is settled.
        left_done_ = true;
        if (!next_use) {
          // Early exit.
          out_->right = in.right;
          // Remaining left pointers of in will never be extended: swap their
          // rest costs for full scores now.
          prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
          return true;
        }
      }
      // Continue scoring.
      return false;
    }

    // Accumulate one score, recording an extend-left pointer while additional
    // left context could still revise the probability.
    void ProcessRet(const FullScoreReturn &ret) {
      if (left_done_) {
        prob_ += ret.prob;
        return;
      }
      if (ret.independent_left) {
        prob_ += ret.prob;
        left_done_ = true;
        return;
      }
      out_->left.pointers[out_->left.length++] = ret.extend_left;
      prob_ += ret.rest;
    }

    const M &model_;    // Language model used for scoring; not owned.

    ChartState *out_;   // Destination state; owned by the caller.

    bool left_done_;    // True once further left context cannot change prob_.

    float prob_;        // Accumulated log probability.
};
00212 
00213 } // namespace ngram
00214 } // namespace lm
00215 
00216 #endif // LM_LEFT_H