Joshua
Open-source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/value_build.hh
00001 #ifndef LM_VALUE_BUILD_H
00002 #define LM_VALUE_BUILD_H
00003 
00004 #include "lm/weights.hh"
00005 #include "lm/word_index.hh"
00006 #include "util/bit_packing.hh"
00007 
00008 #include <vector>
00009 
00010 namespace lm {
00011 namespace ngram {
00012 
00013 struct Config;
00014 struct BackoffValue;
00015 struct RestValue;
00016 
00017 class NoRestBuild {
00018   public:
00019     typedef BackoffValue Value;
00020 
00021     NoRestBuild() {}
00022 
00023     void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
00024     void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {}
00025 
00026     template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const {
00027       util::UnsetSign(weights.prob);
00028       return false;
00029     }
00030 
00031     // Probing doesn't need to go back to unigram.
00032     const static bool kMarkEvenLower = false;
00033 };
00034 
00035 class MaxRestBuild {
00036   public:
00037     typedef RestValue Value;
00038 
00039     MaxRestBuild() {}
00040 
00041     void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
00042     void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const {
00043       weights.rest = weights.prob;
00044       util::SetSign(weights.rest);
00045     }
00046 
00047     bool MarkExtends(RestWeights &weights, const RestWeights &to) const {
00048       util::UnsetSign(weights.prob);
00049       if (weights.rest >= to.rest) return false;
00050       weights.rest = to.rest;
00051       return true;
00052     }
00053     bool MarkExtends(RestWeights &weights, const Prob &to) const {
00054       util::UnsetSign(weights.prob);
00055       if (weights.rest >= to.prob) return false;
00056       weights.rest = to.prob;
00057       return true;
00058     }
00059 
00060     // Probing does need to go back to unigram.
00061     const static bool kMarkEvenLower = true;
00062 };
00063 
00064 template <class Model> class LowerRestBuild {
00065   public:
00066     typedef RestValue Value;
00067 
00068     LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab);
00069 
00070     ~LowerRestBuild();
00071 
00072     void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
00073     void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const {
00074       typename Model::State ignored;
00075       if (n == 1) {
00076         weights.rest = unigrams_[*vocab_ids];
00077       } else {
00078         weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob;
00079       }
00080     }
00081 
00082     template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const {
00083       util::UnsetSign(weights.prob);
00084       return false;
00085     }
00086 
00087     const static bool kMarkEvenLower = false;
00088 
00089     std::vector<float> unigrams_;
00090 
00091     std::vector<const Model*> models_;
00092 };
00093 
00094 } // namespace ngram
00095 } // namespace lm
00096 
00097 #endif // LM_VALUE_BUILD_H