// Joshua: an open-source statistical hierarchical phrase-based machine translation system.
00001 #ifndef LM_PARTIAL_H 00002 #define LM_PARTIAL_H 00003 00004 #include "lm/return.hh" 00005 #include "lm/state.hh" 00006 00007 #include <algorithm> 00008 #include <cassert> 00009 00010 namespace lm { 00011 namespace ngram { 00012 00013 struct ExtendReturn { 00014 float adjust; 00015 bool make_full; 00016 unsigned char next_use; 00017 }; 00018 00019 template <class Model> ExtendReturn ExtendLoop( 00020 const Model &model, 00021 unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start, 00022 const uint64_t *pointers, const uint64_t *pointers_end, 00023 uint64_t *&pointers_write, 00024 float *backoff_write) { 00025 unsigned char add_length = add_rend - add_rbegin; 00026 00027 float backoff_buf[2][KENLM_MAX_ORDER - 1]; 00028 float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1]; 00029 std::copy(backoff_start, backoff_start + add_length, backoff_in); 00030 00031 ExtendReturn value; 00032 value.make_full = false; 00033 value.adjust = 0.0; 00034 value.next_use = add_length; 00035 00036 unsigned char i = 0; 00037 unsigned char length = pointers_end - pointers; 00038 // pointers_write is NULL means that the existing left state is full, so we should use completed probabilities. 00039 if (pointers_write) { 00040 // Using full context, writing to new left state. 00041 for (; i < length; ++i) { 00042 FullScoreReturn ret(model.ExtendLeft( 00043 add_rbegin, add_rbegin + value.next_use, 00044 backoff_in, 00045 pointers[i], i + seen + 1, 00046 backoff_out, 00047 value.next_use)); 00048 std::swap(backoff_in, backoff_out); 00049 if (ret.independent_left) { 00050 value.adjust += ret.prob; 00051 value.make_full = true; 00052 ++i; 00053 break; 00054 } 00055 value.adjust += ret.rest; 00056 *pointers_write++ = ret.extend_left; 00057 if (value.next_use != add_length) { 00058 value.make_full = true; 00059 ++i; 00060 break; 00061 } 00062 } 00063 } 00064 // Using some of the new context. 
00065 for (; i < length && value.next_use; ++i) { 00066 FullScoreReturn ret(model.ExtendLeft( 00067 add_rbegin, add_rbegin + value.next_use, 00068 backoff_in, 00069 pointers[i], i + seen + 1, 00070 backoff_out, 00071 value.next_use)); 00072 std::swap(backoff_in, backoff_out); 00073 value.adjust += ret.prob; 00074 } 00075 float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1); 00076 // Using none of the new context. 00077 value.adjust += unrest; 00078 00079 std::copy(backoff_in, backoff_in + value.next_use, backoff_write); 00080 return value; 00081 } 00082 00083 template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) { 00084 assert(seen < reveal.length || reveal_full); 00085 uint64_t *pointers_write = reveal_full ? NULL : left.pointers; 00086 float backoff_buffer[KENLM_MAX_ORDER - 1]; 00087 ExtendReturn value(ExtendLoop( 00088 model, 00089 seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen, 00090 left.pointers, left.pointers + left.length, 00091 pointers_write, 00092 left.full ? backoff_buffer : (right.backoff + right.length))); 00093 if (reveal_full) { 00094 left.length = 0; 00095 value.make_full = true; 00096 } else { 00097 left.length = pointers_write - left.pointers; 00098 value.make_full |= (left.length == model.Order() - 1); 00099 } 00100 if (left.full) { 00101 for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; 00102 } else { 00103 // If left wasn't full when it came in, put words into right state. 
00104 std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length); 00105 right.length += value.next_use; 00106 left.full = value.make_full || (right.length == model.Order() - 1); 00107 } 00108 return value.adjust; 00109 } 00110 00111 template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) { 00112 assert(seen < reveal.length || reveal.full); 00113 uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length); 00114 ExtendReturn value(ExtendLoop( 00115 model, 00116 seen, right.words, right.words + right.length, right.backoff, 00117 reveal.pointers + seen, reveal.pointers + reveal.length, 00118 pointers_write, 00119 right.backoff)); 00120 if (reveal.full) { 00121 for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i]; 00122 right.length = 0; 00123 value.make_full = true; 00124 } else { 00125 right.length = value.next_use; 00126 value.make_full |= (right.length == model.Order() - 1); 00127 } 00128 if (!left.full) { 00129 left.length = pointers_write - left.pointers; 00130 left.full = value.make_full || (left.length == model.Order() - 1); 00131 } 00132 return value.adjust; 00133 } 00134 00135 template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) { 00136 assert(first_right.length < KENLM_MAX_ORDER); 00137 assert(second_left.length < KENLM_MAX_ORDER); 00138 assert(between_length < KENLM_MAX_ORDER - 1); 00139 uint64_t *pointers_write = first_left.full ? 
NULL : (first_left.pointers + first_left.length); 00140 float backoff_buffer[KENLM_MAX_ORDER - 1]; 00141 ExtendReturn value(ExtendLoop( 00142 model, 00143 between_length, first_right.words, first_right.words + first_right.length, first_right.backoff, 00144 second_left.pointers, second_left.pointers + second_left.length, 00145 pointers_write, 00146 second_left.full ? backoff_buffer : (second_right.backoff + second_right.length))); 00147 if (second_left.full) { 00148 for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i]; 00149 } else { 00150 std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length); 00151 second_right.length += value.next_use; 00152 value.make_full |= (second_right.length == model.Order() - 1); 00153 } 00154 if (!first_left.full) { 00155 first_left.length = pointers_write - first_left.pointers; 00156 first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1); 00157 } 00158 assert(first_left.length < KENLM_MAX_ORDER); 00159 assert(second_right.length < KENLM_MAX_ORDER); 00160 return value.adjust; 00161 } 00162 00163 } // namespace ngram 00164 } // namespace lm 00165 00166 #endif // LM_PARTIAL_H