Joshua
Open-source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_STATE_H 00002 #define LM_STATE_H 00003 00004 #include "lm/max_order.hh" 00005 #include "lm/word_index.hh" 00006 #include "util/murmur_hash.hh" 00007 00008 #include <cstring> 00009 00010 namespace lm { 00011 namespace ngram { 00012 00013 // This is a POD but if you want memcmp to return the same as operator==, call 00014 // ZeroRemaining first. 00015 class State { 00016 public: 00017 bool operator==(const State &other) const { 00018 if (length != other.length) return false; 00019 return !memcmp(words, other.words, length * sizeof(WordIndex)); 00020 } 00021 00022 // Three way comparison function. 00023 int Compare(const State &other) const { 00024 if (length != other.length) return length < other.length ? -1 : 1; 00025 return memcmp(words, other.words, length * sizeof(WordIndex)); 00026 } 00027 00028 bool operator<(const State &other) const { 00029 if (length != other.length) return length < other.length; 00030 return memcmp(words, other.words, length * sizeof(WordIndex)) < 0; 00031 } 00032 00033 // Call this before using raw memcmp. 00034 void ZeroRemaining() { 00035 for (unsigned char i = length; i < KENLM_MAX_ORDER - 1; ++i) { 00036 words[i] = 0; 00037 backoff[i] = 0.0; 00038 } 00039 } 00040 00041 unsigned char Length() const { return length; } 00042 00043 // You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD. 00044 // This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit. 
00045 WordIndex words[KENLM_MAX_ORDER - 1]; 00046 float backoff[KENLM_MAX_ORDER - 1]; 00047 unsigned char length; 00048 }; 00049 00050 typedef State Right; 00051 00052 inline uint64_t hash_value(const State &state, uint64_t seed = 0) { 00053 return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed); 00054 } 00055 00056 struct Left { 00057 bool operator==(const Left &other) const { 00058 return 00059 length == other.length && 00060 (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full)); 00061 } 00062 00063 int Compare(const Left &other) const { 00064 if (length < other.length) return -1; 00065 if (length > other.length) return 1; 00066 if (length == 0) return 0; // Must be full. 00067 if (pointers[length - 1] > other.pointers[length - 1]) return 1; 00068 if (pointers[length - 1] < other.pointers[length - 1]) return -1; 00069 return (int)full - (int)other.full; 00070 } 00071 00072 bool operator<(const Left &other) const { 00073 return Compare(other) == -1; 00074 } 00075 00076 void ZeroRemaining() { 00077 for (uint64_t * i = pointers + length; i < pointers + KENLM_MAX_ORDER - 1; ++i) 00078 *i = 0; 00079 } 00080 00081 uint64_t pointers[KENLM_MAX_ORDER - 1]; 00082 unsigned char length; 00083 bool full; 00084 }; 00085 00086 inline uint64_t hash_value(const Left &left) { 00087 unsigned char add[2]; 00088 add[0] = left.length; 00089 add[1] = left.full; 00090 return util::MurmurHashNative(add, 2, left.length ? 
left.pointers[left.length - 1] : 0); 00091 } 00092 00093 struct ChartState { 00094 bool operator==(const ChartState &other) const { 00095 return (right == other.right) && (left == other.left); 00096 } 00097 00098 int Compare(const ChartState &other) const { 00099 int lres = left.Compare(other.left); 00100 if (lres) return lres; 00101 return right.Compare(other.right); 00102 } 00103 00104 bool operator<(const ChartState &other) const { 00105 return Compare(other) < 0; 00106 } 00107 00108 void ZeroRemaining() { 00109 left.ZeroRemaining(); 00110 right.ZeroRemaining(); 00111 } 00112 00113 Left left; 00114 State right; 00115 }; 00116 00117 inline uint64_t hash_value(const ChartState &state) { 00118 return hash_value(state.right, hash_value(state.left)); 00119 } 00120 00121 00122 } // namespace ngram 00123 } // namespace lm 00124 00125 #endif // LM_STATE_H