Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_BUILDER_ADJUST_COUNTS_H 00002 #define LM_BUILDER_ADJUST_COUNTS_H 00003 00004 #include "lm/builder/discount.hh" 00005 #include "lm/lm_exception.hh" 00006 #include "util/exception.hh" 00007 00008 #include <vector> 00009 00010 #include <stdint.h> 00011 00012 namespace util { namespace stream { class ChainPositions; } } 00013 00014 namespace lm { 00015 namespace builder { 00016 00017 class BadDiscountException : public util::Exception { 00018 public: 00019 BadDiscountException() throw(); 00020 ~BadDiscountException() throw(); 00021 }; 00022 00023 struct DiscountConfig { 00024 // Overrides discounts for orders [1,discount_override.size()]. 00025 std::vector<Discount> overwrite; 00026 // If discounting fails for an order, copy them from here. 00027 Discount fallback; 00028 // What to do when discounts are out of range or would trigger divison by 00029 // zero. It it does something other than THROW_UP, use fallback_discount. 00030 WarningAction bad_action; 00031 }; 00032 00033 /* Compute adjusted counts. 00034 * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. 00035 * Output: [1,N]-grams with adjusted counts. 00036 * [1,N)-grams are in suffix order 00037 * N-grams are in undefined order (they're going to be sorted anyway). 00038 */ 00039 class AdjustCounts { 00040 public: 00041 // counts: output 00042 // counts_pruned: output 00043 // discounts: mostly output. If the input already has entries, they will be kept. 00044 // prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned. 00045 AdjustCounts( 00046 const std::vector<uint64_t> &prune_thresholds, 00047 std::vector<uint64_t> &counts, 00048 std::vector<uint64_t> &counts_pruned, 00049 const std::vector<bool> &prune_words, 00050 const DiscountConfig &discount_config, 00051 std::vector<Discount> &discounts) 00052 : prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned), 00053 prune_words_(prune_words), discount_config_(discount_config), discounts_(discounts) 00054 {} 00055 00056 void Run(const util::stream::ChainPositions &positions); 00057 00058 private: 00059 const std::vector<uint64_t> &prune_thresholds_; 00060 std::vector<uint64_t> &counts_; 00061 std::vector<uint64_t> &counts_pruned_; 00062 const std::vector<bool> &prune_words_; 00063 00064 DiscountConfig discount_config_; 00065 std::vector<Discount> &discounts_; 00066 }; 00067 00068 } // namespace builder 00069 } // namespace lm 00070 00071 #endif // LM_BUILDER_ADJUST_COUNTS_H 00072