Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/partial.hh
00001 #ifndef LM_PARTIAL_H
00002 #define LM_PARTIAL_H
00003 
00004 #include "lm/return.hh"
00005 #include "lm/state.hh"
00006 
00007 #include <algorithm>
00008 #include <cassert>
00009 
00010 namespace lm {
00011 namespace ngram {
00012 
// Result bundle produced by ExtendLoop and consumed by RevealBefore,
// RevealAfter and Subsume.
00013 struct ExtendReturn {
        // Running score adjustment. ExtendLoop starts it at 0 and adds the
        // values returned by the model; callers may add leftover backoffs
        // before returning it.
00014   float adjust;
        // Set by ExtendLoop when an extension reports independent_left or when
        // next_use dropped below the number of words offered. Callers OR in
        // their own length checks and use it to decide whether a left state
        // becomes full.
00015   bool make_full;
        // Starts as the number of context words offered to ExtendLoop and is
        // handed to ExtendLeft on each call (presumably updated by reference;
        // confirm against the model's ExtendLeft signature). Afterwards it is
        // the count of backoffs/words that callers copy or sum.
00016   unsigned char next_use;
00017 };
00018 
// Extends each left-state pointer in [pointers, pointers_end) with the context
// words [add_rbegin, add_rend) and accumulates the resulting score adjustment.
//
// Parameters:
//   model          - must provide ExtendLeft(), UnRest(); both are called here.
//   seen           - offset added to (i + 1) to form the order argument given
//                    to ExtendLeft/UnRest (presumably the number of words
//                    already accounted for; confirm with callers).
//   add_rbegin/add_rend - range of context words offered to ExtendLeft; the
//                    "r" naming suggests reverse order (not verifiable here).
//   backoff_start  - (add_rend - add_rbegin) backoff values that seed the
//                    working buffer.
//   pointers/pointers_end - left-state pointers to extend, processed in order.
//   pointers_write - output cursor taken by reference. When non-NULL, each
//                    extend_left value is stored through it and it is
//                    advanced; when NULL, the first loop is skipped entirely.
//   backoff_write  - receives value.next_use backoff values on exit.
//
// Returns an ExtendReturn holding the accumulated adjust, the make_full flag
// and the final next_use count.
00019 template <class Model> ExtendReturn ExtendLoop(
00020     const Model &model,
00021     unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start,
00022     const uint64_t *pointers, const uint64_t *pointers_end,
00023     uint64_t *&pointers_write,
00024     float *backoff_write) {
00025   unsigned char add_length = add_rend - add_rbegin;
00026 
        // Two scratch rows used alternately as input and output of ExtendLeft,
        // so the caller's backoff_start is never modified in place.
00027   float backoff_buf[2][KENLM_MAX_ORDER - 1];
00028   float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1];
00029   std::copy(backoff_start, backoff_start + add_length, backoff_in);
00030 
00031   ExtendReturn value;
00032   value.make_full = false;
00033   value.adjust = 0.0;
00034   value.next_use = add_length;
00035 
        // i is shared by both loops below: the second loop resumes where the
        // first one stopped.
00036   unsigned char i = 0;
00037   unsigned char length = pointers_end - pointers;
00038   // A NULL pointers_write means that the existing left state is full, so we should use completed probabilities.
00039   if (pointers_write) {
00040     // Using full context, writing to new left state.
00041     for (; i < length; ++i) {
00042       FullScoreReturn ret(model.ExtendLeft(
00043           add_rbegin, add_rbegin + value.next_use,
00044           backoff_in,
00045           pointers[i], i + seen + 1,
00046           backoff_out,
00047           value.next_use));
            // The buffer passed as output becomes the input of the next call.
00048       std::swap(backoff_in, backoff_out);
00049       if (ret.independent_left) {
              // Take the prob field instead of rest, store no pointer, and
              // stop writing to the new left state.
00050         value.adjust += ret.prob;
00051         value.make_full = true;
              // break skips the loop's own ++i, so advance past this pointer by hand.
00052         ++i;
00053         break;
00054       }
00055       value.adjust += ret.rest;
00056       *pointers_write++ = ret.extend_left;
            // next_use no longer equals the number of words offered: mark the
            // state full and hand the remaining pointers to the loop below.
00057       if (value.next_use != add_length) {
00058         value.make_full = true;
00059         ++i;
00060         break;
00061       }
00062     }
00063   }
00064   // Using some of the new context.
        // Runs only while next_use is nonzero; adds prob and writes no pointers.
00065   for (; i < length && value.next_use; ++i) {
00066     FullScoreReturn ret(model.ExtendLeft(
00067         add_rbegin, add_rbegin + value.next_use,
00068         backoff_in,
00069         pointers[i], i + seen + 1,
00070         backoff_out,
00071         value.next_use));
00072     std::swap(backoff_in, backoff_out);
00073     value.adjust += ret.prob;
00074   }
        // Whatever pointers remain, [pointers + i, pointers_end), are handed to UnRest.
00075   float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1);
00076   // Using none of the new context.
00077   value.adjust += unrest;
00078 
        // backoff_in is the most recently produced buffer because of the swaps above.
00079   std::copy(backoff_in, backoff_in + value.next_use, backoff_write);
00080   return value;
00081 }
00082 
00083 template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) {
00084   assert(seen < reveal.length || reveal_full);
00085   uint64_t *pointers_write = reveal_full ? NULL : left.pointers;
00086   float backoff_buffer[KENLM_MAX_ORDER - 1];
00087   ExtendReturn value(ExtendLoop(
00088       model,
00089       seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen,
00090       left.pointers, left.pointers + left.length,
00091       pointers_write,
00092       left.full ? backoff_buffer : (right.backoff + right.length)));
00093   if (reveal_full) {
00094     left.length = 0;
00095     value.make_full = true;
00096   } else {
00097     left.length = pointers_write - left.pointers;
00098     value.make_full |= (left.length == model.Order() - 1);
00099   }
00100   if (left.full) {
00101     for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
00102   } else {
00103     // If left wasn't full when it came in, put words into right state.
00104     std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length);
00105     right.length += value.next_use;
00106     left.full = value.make_full || (right.length == model.Order() - 1);
00107   }
00108   return value.adjust;
00109 }
00110 
00111 template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) {
00112   assert(seen < reveal.length || reveal.full);
00113   uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length);
00114   ExtendReturn value(ExtendLoop(
00115       model,
00116       seen, right.words, right.words + right.length, right.backoff,
00117       reveal.pointers + seen, reveal.pointers + reveal.length,
00118       pointers_write,
00119       right.backoff));
00120   if (reveal.full) {
00121     for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i];
00122     right.length = 0;
00123     value.make_full = true;
00124   } else {
00125     right.length = value.next_use;
00126     value.make_full |= (right.length == model.Order() - 1);
00127   }
00128   if (!left.full) {
00129     left.length = pointers_write - left.pointers;
00130     left.full = value.make_full || (left.length == model.Order() - 1);
00131   }
00132   return value.adjust;
00133 }
00134 
00135 template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) {
00136   assert(first_right.length < KENLM_MAX_ORDER);
00137   assert(second_left.length < KENLM_MAX_ORDER);
00138   assert(between_length < KENLM_MAX_ORDER - 1);
00139   uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length);
00140   float backoff_buffer[KENLM_MAX_ORDER - 1];
00141   ExtendReturn value(ExtendLoop(
00142         model,
00143         between_length, first_right.words, first_right.words + first_right.length, first_right.backoff,
00144         second_left.pointers, second_left.pointers + second_left.length,
00145         pointers_write,
00146         second_left.full ? backoff_buffer : (second_right.backoff + second_right.length)));
00147   if (second_left.full) {
00148     for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
00149   } else {
00150     std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length);
00151     second_right.length += value.next_use;
00152     value.make_full |= (second_right.length == model.Order() - 1);
00153   }
00154   if (!first_left.full) {
00155     first_left.length = pointers_write - first_left.pointers;
00156     first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1);
00157   }
00158   assert(first_left.length < KENLM_MAX_ORDER);
00159   assert(second_right.length < KENLM_MAX_ORDER);
00160   return value.adjust;
00161 }
00162 
00163 } // namespace ngram
00164 } // namespace lm
00165 
00166 #endif // LM_PARTIAL_H