Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_VALUE_BUILD_H 00002 #define LM_VALUE_BUILD_H 00003 00004 #include "lm/weights.hh" 00005 #include "lm/word_index.hh" 00006 #include "util/bit_packing.hh" 00007 00008 #include <vector> 00009 00010 namespace lm { 00011 namespace ngram { 00012 00013 struct Config; 00014 struct BackoffValue; 00015 struct RestValue; 00016 00017 class NoRestBuild { 00018 public: 00019 typedef BackoffValue Value; 00020 00021 NoRestBuild() {} 00022 00023 void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} 00024 void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {} 00025 00026 template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const { 00027 util::UnsetSign(weights.prob); 00028 return false; 00029 } 00030 00031 // Probing doesn't need to go back to unigram. 00032 const static bool kMarkEvenLower = false; 00033 }; 00034 00035 class MaxRestBuild { 00036 public: 00037 typedef RestValue Value; 00038 00039 MaxRestBuild() {} 00040 00041 void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} 00042 void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const { 00043 weights.rest = weights.prob; 00044 util::SetSign(weights.rest); 00045 } 00046 00047 bool MarkExtends(RestWeights &weights, const RestWeights &to) const { 00048 util::UnsetSign(weights.prob); 00049 if (weights.rest >= to.rest) return false; 00050 weights.rest = to.rest; 00051 return true; 00052 } 00053 bool MarkExtends(RestWeights &weights, const Prob &to) const { 00054 util::UnsetSign(weights.prob); 00055 if (weights.rest >= to.prob) return false; 00056 weights.rest = to.prob; 00057 return true; 00058 } 00059 00060 // Probing does need to go back to unigram. 00061 const static bool kMarkEvenLower = true; 00062 }; 00063 00064 template <class Model> class LowerRestBuild { 00065 public: 00066 typedef RestValue Value; 00067 00068 LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab); 00069 00070 ~LowerRestBuild(); 00071 00072 void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {} 00073 void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const { 00074 typename Model::State ignored; 00075 if (n == 1) { 00076 weights.rest = unigrams_[*vocab_ids]; 00077 } else { 00078 weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob; 00079 } 00080 } 00081 00082 template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const { 00083 util::UnsetSign(weights.prob); 00084 return false; 00085 } 00086 00087 const static bool kMarkEvenLower = false; 00088 00089 std::vector<float> unigrams_; 00090 00091 std::vector<const Model*> models_; 00092 }; 00093 00094 } // namespace ngram 00095 } // namespace lm 00096 00097 #endif // LM_VALUE_BUILD_H