Joshua: an open-source statistical hierarchical phrase-based machine translation system
00001 #ifndef LM_QUANTIZE_H 00002 #define LM_QUANTIZE_H 00003 00004 #include "lm/blank.hh" 00005 #include "lm/config.hh" 00006 #include "lm/max_order.hh" 00007 #include "lm/model_type.hh" 00008 #include "util/bit_packing.hh" 00009 00010 #include <algorithm> 00011 #include <vector> 00012 00013 #include <stdint.h> 00014 00015 #include <iostream> 00016 00017 namespace lm { 00018 namespace ngram { 00019 00020 struct Config; 00021 class BinaryFormat; 00022 00023 /* Store values directly and don't quantize. */ 00024 class DontQuantize { 00025 public: 00026 static const ModelType kModelTypeAdd = static_cast<ModelType>(0); 00027 static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} 00028 static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } 00029 static uint8_t MiddleBits(const Config &/*config*/) { return 63; } 00030 static uint8_t LongestBits(const Config &/*config*/) { return 31; } 00031 00032 class MiddlePointer { 00033 public: 00034 MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {} 00035 00036 MiddlePointer() : address_(NULL, 0) {} 00037 00038 bool Found() const { 00039 return address_.base != NULL; 00040 } 00041 00042 float Prob() const { 00043 return util::ReadNonPositiveFloat31(address_.base, address_.offset); 00044 } 00045 00046 float Backoff() const { 00047 return util::ReadFloat32(address_.base, address_.offset + 31); 00048 } 00049 00050 float Rest() const { return Prob(); } 00051 00052 void Write(float prob, float backoff) { 00053 util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); 00054 util::WriteFloat32(address_.base, address_.offset + 31, backoff); 00055 } 00056 00057 private: 00058 util::BitAddress address_; 00059 }; 00060 00061 class LongestPointer { 00062 public: 00063 explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {} 00064 00065 LongestPointer() : 
address_(NULL, 0) {} 00066 00067 bool Found() const { 00068 return address_.base != NULL; 00069 } 00070 00071 float Prob() const { 00072 return util::ReadNonPositiveFloat31(address_.base, address_.offset); 00073 } 00074 00075 void Write(float prob) { 00076 util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); 00077 } 00078 00079 private: 00080 util::BitAddress address_; 00081 }; 00082 00083 DontQuantize() {} 00084 00085 void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {} 00086 00087 static const bool kTrain = false; 00088 // These should never be called because kTrain is false. 00089 void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {} 00090 void TrainProb(uint8_t, std::vector<float> &/*prob*/) {} 00091 00092 void FinishedLoading(const Config &) {} 00093 }; 00094 00095 class SeparatelyQuantize { 00096 private: 00097 class Bins { 00098 public: 00099 // Sigh C++ default constructor 00100 Bins() {} 00101 00102 Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} 00103 00104 float *Populate() { return begin_; } 00105 00106 uint64_t EncodeProb(float value) const { 00107 return Encode(value, 0); 00108 } 00109 00110 uint64_t EncodeBackoff(float value) const { 00111 if (value == 0.0) { 00112 return HasExtension(value) ? 
kExtensionQuant : kNoExtensionQuant; 00113 } 00114 return Encode(value, 2); 00115 } 00116 00117 float Decode(std::size_t off) const { return begin_[off]; } 00118 00119 uint8_t Bits() const { return bits_; } 00120 00121 uint64_t Mask() const { return mask_; } 00122 00123 private: 00124 uint64_t Encode(float value, size_t reserved) const { 00125 const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value); 00126 if (above == begin_ + reserved) return reserved; 00127 if (above == end_) return end_ - begin_ - 1; 00128 return above - begin_ - (value - *(above - 1) < *above - value); 00129 } 00130 00131 float *begin_; 00132 const float *end_; 00133 uint8_t bits_; 00134 uint64_t mask_; 00135 }; 00136 00137 public: 00138 static const ModelType kModelTypeAdd = kQuantAdd; 00139 00140 static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); 00141 00142 static uint64_t Size(uint8_t order, const Config &config) { 00143 uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); 00144 uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table; 00145 // unigrams are currently not quantized so no need for a table. 
00146 return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding) */ 8; 00147 } 00148 00149 static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } 00150 static uint8_t LongestBits(const Config &config) { return config.prob_bits; } 00151 00152 class MiddlePointer { 00153 public: 00154 MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {} 00155 00156 MiddlePointer() : address_(NULL, 0) {} 00157 00158 bool Found() const { return address_.base != NULL; } 00159 00160 float Prob() const { 00161 return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask())); 00162 } 00163 00164 float Backoff() const { 00165 return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask())); 00166 } 00167 00168 float Rest() const { return Prob(); } 00169 00170 void Write(float prob, float backoff) const { 00171 util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(), 00172 (ProbBins().EncodeProb(prob) << BackoffBins().Bits()) | BackoffBins().EncodeBackoff(backoff)); 00173 } 00174 00175 private: 00176 const Bins &ProbBins() const { return bins_[0]; } 00177 const Bins &BackoffBins() const { return bins_[1]; } 00178 const Bins *bins_; 00179 00180 util::BitAddress address_; 00181 }; 00182 00183 class LongestPointer { 00184 public: 00185 LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {} 00186 00187 LongestPointer() : address_(NULL, 0) {} 00188 00189 bool Found() const { return address_.base != NULL; } 00190 00191 void Write(float prob) const { 00192 util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob)); 00193 } 00194 00195 
float Prob() const { 00196 return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask())); 00197 } 00198 00199 private: 00200 const Bins *table_; 00201 util::BitAddress address_; 00202 }; 00203 00204 SeparatelyQuantize() {} 00205 00206 void SetupMemory(void *start, unsigned char order, const Config &config); 00207 00208 static const bool kTrain = true; 00209 // Assumes 0.0 is removed from backoff. 00210 void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff); 00211 // Train just probabilities (for longest order). 00212 void TrainProb(uint8_t order, std::vector<float> &prob); 00213 00214 void FinishedLoading(const Config &config); 00215 00216 const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; } 00217 00218 const Bins &LongestTable() const { return longest_; } 00219 00220 private: 00221 Bins tables_[KENLM_MAX_ORDER - 1][2]; 00222 00223 Bins longest_; 00224 00225 uint8_t *actual_base_; 00226 00227 uint8_t prob_bits_, backoff_bits_; 00228 }; 00229 00230 } // namespace ngram 00231 } // namespace lm 00232 00233 #endif // LM_QUANTIZE_H