Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/quantize.hh
00001 #ifndef LM_QUANTIZE_H
00002 #define LM_QUANTIZE_H
00003 
00004 #include "lm/blank.hh"
00005 #include "lm/config.hh"
00006 #include "lm/max_order.hh"
00007 #include "lm/model_type.hh"
00008 #include "util/bit_packing.hh"
00009 
00010 #include <algorithm>
00011 #include <vector>
00012 
00013 #include <stdint.h>
00014 
00015 #include <iostream>
00016 
00017 namespace lm {
00018 namespace ngram {
00019 
00020 struct Config;
00021 class BinaryFormat;
00022 
00023 /* Store values directly and don't quantize. */
00024 class DontQuantize {
00025   public:
00026     static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
00027     static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
00028     static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
00029     static uint8_t MiddleBits(const Config &/*config*/) { return 63; }
00030     static uint8_t LongestBits(const Config &/*config*/) { return 31; }
00031 
00032     class MiddlePointer {
00033       public:
00034         MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {}
00035 
00036         MiddlePointer() : address_(NULL, 0) {}
00037 
00038         bool Found() const {
00039           return address_.base != NULL;
00040         }
00041 
00042         float Prob() const {
00043           return util::ReadNonPositiveFloat31(address_.base, address_.offset);
00044         }
00045 
00046         float Backoff() const {
00047           return util::ReadFloat32(address_.base, address_.offset + 31);
00048         }
00049 
00050         float Rest() const { return Prob(); }
00051 
00052         void Write(float prob, float backoff) {
00053           util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
00054           util::WriteFloat32(address_.base, address_.offset + 31, backoff);
00055         }
00056 
00057       private:
00058         util::BitAddress address_;
00059     };
00060 
00061     class LongestPointer {
00062       public:
00063         explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {}
00064 
00065         LongestPointer() : address_(NULL, 0) {}
00066 
00067         bool Found() const {
00068           return address_.base != NULL;
00069         }
00070 
00071         float Prob() const {
00072           return util::ReadNonPositiveFloat31(address_.base, address_.offset);
00073         }
00074 
00075         void Write(float prob) {
00076           util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
00077         }
00078 
00079       private:
00080         util::BitAddress address_;
00081     };
00082 
00083     DontQuantize() {}
00084 
00085     void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {}
00086 
00087     static const bool kTrain = false;
00088     // These should never be called because kTrain is false.
00089     void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {}
00090     void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}
00091 
00092     void FinishedLoading(const Config &) {}
00093 };
00094 
00095 class SeparatelyQuantize {
00096   private:
00097     class Bins {
00098       public:
00099         // Sigh C++ default constructor
00100         Bins() {}
00101 
00102         Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
00103 
00104         float *Populate() { return begin_; }
00105 
00106         uint64_t EncodeProb(float value) const {
00107           return Encode(value, 0);
00108         }
00109 
00110         uint64_t EncodeBackoff(float value) const {
00111           if (value == 0.0) {
00112             return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant;
00113           }
00114           return Encode(value, 2);
00115         }
00116 
00117         float Decode(std::size_t off) const { return begin_[off]; }
00118 
00119         uint8_t Bits() const { return bits_; }
00120 
00121         uint64_t Mask() const { return mask_; }
00122 
00123       private:
00124         uint64_t Encode(float value, size_t reserved) const {
00125           const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);
00126           if (above == begin_ + reserved) return reserved;
00127           if (above == end_) return end_ - begin_ - 1;
00128           return above - begin_ - (value - *(above - 1) < *above - value);
00129         }
00130 
00131         float *begin_;
00132         const float *end_;
00133         uint8_t bits_;
00134         uint64_t mask_;
00135     };
00136 
00137   public:
00138     static const ModelType kModelTypeAdd = kQuantAdd;
00139 
00140     static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
00141 
00142     static uint64_t Size(uint8_t order, const Config &config) {
00143       uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
00144       uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
00145       // unigrams are currently not quantized so no need for a table.
00146       return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8;
00147     }
00148 
00149     static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
00150     static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
00151 
00152     class MiddlePointer {
00153       public:
00154         MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {}
00155 
00156         MiddlePointer() : address_(NULL, 0) {}
00157 
00158         bool Found() const { return address_.base != NULL; }
00159 
00160         float Prob() const {
00161           return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));
00162         }
00163 
00164         float Backoff() const {
00165           return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));
00166         }
00167 
00168         float Rest() const { return Prob(); }
00169 
00170         void Write(float prob, float backoff) const {
00171           util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),
00172               (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff));
00173         }
00174 
00175       private:
00176         const Bins &ProbBins() const { return bins_[0]; }
00177         const Bins &BackoffBins() const { return bins_[1]; }
00178         const Bins *bins_;
00179 
00180         util::BitAddress address_;
00181     };
00182 
00183     class LongestPointer {
00184       public:
00185         LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {}
00186 
00187         LongestPointer() : address_(NULL, 0) {}
00188 
00189         bool Found() const { return address_.base != NULL; }
00190 
00191         void Write(float prob) const {
00192           util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));
00193         }
00194 
00195         float Prob() const {
00196           return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));
00197         }
00198 
00199       private:
00200         const Bins *table_;
00201         util::BitAddress address_;
00202     };
00203 
00204     SeparatelyQuantize() {}
00205 
00206     void SetupMemory(void *start, unsigned char order, const Config &config);
00207 
00208     static const bool kTrain = true;
00209     // Assumes 0.0 is removed from backoff.
00210     void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff);
00211     // Train just probabilities (for longest order).
00212     void TrainProb(uint8_t order, std::vector<float> &prob);
00213 
00214     void FinishedLoading(const Config &config);
00215 
00216     const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; }
00217 
00218     const Bins &LongestTable() const { return longest_; }
00219 
00220   private:
00221     Bins tables_[KENLM_MAX_ORDER - 1][2];
00222 
00223     Bins longest_;
00224 
00225     uint8_t *actual_base_;
00226 
00227     uint8_t prob_bits_, backoff_bits_;
00228 };
00229 
00230 } // namespace ngram
00231 } // namespace lm
00232 
00233 #endif // LM_QUANTIZE_H