Joshua
Open-source statistical hierarchical phrase-based machine translation system
00001 #ifndef UTIL_PROBING_HASH_TABLE_H 00002 #define UTIL_PROBING_HASH_TABLE_H 00003 00004 #include "util/exception.hh" 00005 #include "util/mmap.hh" 00006 00007 #include <algorithm> 00008 #include <cstddef> 00009 #include <functional> 00010 #include <vector> 00011 00012 #include <cassert> 00013 #include <stdint.h> 00014 00015 namespace util { 00016 00017 /* Thrown when table grows too large */ 00018 class ProbingSizeException : public Exception { 00019 public: 00020 ProbingSizeException() throw() {} 00021 ~ProbingSizeException() throw() {} 00022 }; 00023 00024 // std::identity is an SGI extension :-( 00025 struct IdentityHash { 00026 template <class T> T operator()(T arg) const { return arg; } 00027 }; 00028 00029 class DivMod { 00030 public: 00031 explicit DivMod(std::size_t buckets) : buckets_(buckets) {} 00032 00033 static std::size_t RoundBuckets(std::size_t from) { 00034 return from; 00035 } 00036 00037 template <class It> It Ideal(It begin, uint64_t hash) const { 00038 return begin + (hash % buckets_); 00039 } 00040 00041 template <class BaseIt, class OutIt> void Next(BaseIt begin, BaseIt end, OutIt &it) const { 00042 if (++it == end) it = begin; 00043 } 00044 00045 void Double() { 00046 buckets_ *= 2; 00047 } 00048 00049 private: 00050 std::size_t buckets_; 00051 }; 00052 00053 class Power2Mod { 00054 public: 00055 explicit Power2Mod(std::size_t buckets) { 00056 UTIL_THROW_IF(!buckets || (((buckets - 1) & buckets)), ProbingSizeException, "Size " << buckets << " is not a power of 2."); 00057 mask_ = buckets - 1; 00058 } 00059 00060 // Round up to next power of 2. 
00061 static std::size_t RoundBuckets(std::size_t from) { 00062 --from; 00063 from |= from >> 1; 00064 from |= from >> 2; 00065 from |= from >> 4; 00066 from |= from >> 8; 00067 from |= from >> 16; 00068 from |= from >> 32; 00069 return from + 1; 00070 } 00071 00072 template <class It> It Ideal(It begin, uint64_t hash) const { 00073 return begin + (hash & mask_); 00074 } 00075 00076 template <class BaseIt, class OutIt> void Next(BaseIt begin, BaseIt /*end*/, OutIt &it) const { 00077 it = begin + ((it - begin + 1) & mask_); 00078 } 00079 00080 void Double() { 00081 mask_ = (mask_ << 1) | 1; 00082 } 00083 00084 private: 00085 std::size_t mask_; 00086 }; 00087 00088 template <class EntryT, class HashT, class EqualT> class AutoProbing; 00089 00090 /* Non-standard hash table 00091 * Buckets must be set at the beginning and must be greater than maximum number 00092 * of elements, else it throws ProbingSizeException. 00093 * Memory management and initialization is externalized to make it easier to 00094 * serialize these to disk and load them quickly. 00095 * Uses linear probing to find value. 00096 * Only insert and lookup operations. 00097 */ 00098 template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key>, class ModT = DivMod> class ProbingHashTable { 00099 public: 00100 typedef EntryT Entry; 00101 typedef typename Entry::Key Key; 00102 typedef const Entry *ConstIterator; 00103 typedef Entry *MutableIterator; 00104 typedef HashT Hash; 00105 typedef EqualT Equal; 00106 typedef ModT Mod; 00107 00108 static uint64_t Size(uint64_t entries, float multiplier) { 00109 uint64_t buckets = Mod::RoundBuckets(std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries)))); 00110 return buckets * sizeof(Entry); 00111 } 00112 00113 // Must be assigned to later. 
00114 ProbingHashTable() : mod_(1), entries_(0) 00115 #ifdef DEBUG 00116 , initialized_(false) 00117 #endif 00118 {} 00119 00120 ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) 00121 : begin_(reinterpret_cast<MutableIterator>(start)), 00122 end_(begin_ + allocated / sizeof(Entry)), 00123 buckets_(end_ - begin_), 00124 invalid_(invalid), 00125 hash_(hash_func), 00126 equal_(equal_func), 00127 mod_(end_ - begin_), 00128 entries_(0) 00129 #ifdef DEBUG 00130 , initialized_(true) 00131 #endif 00132 {} 00133 00134 void Relocate(void *new_base) { 00135 begin_ = reinterpret_cast<MutableIterator>(new_base); 00136 end_ = begin_ + buckets_; 00137 } 00138 00139 MutableIterator Ideal(const Key key) { 00140 return mod_.Ideal(begin_, hash_(key)); 00141 } 00142 ConstIterator Ideal(const Key key) const { 00143 return mod_.Ideal(begin_, hash_(key)); 00144 } 00145 00146 template <class T> MutableIterator Insert(const T &t) { 00147 #ifdef DEBUG 00148 assert(initialized_); 00149 #endif 00150 UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); 00151 return UncheckedInsert(t); 00152 } 00153 00154 // Return true if the value was found (and not inserted). This is consistent with Find but the opposite of hash_map! 
00155 template <class T> bool FindOrInsert(const T &t, MutableIterator &out) { 00156 #ifdef DEBUG 00157 assert(initialized_); 00158 #endif 00159 for (MutableIterator i = Ideal(t.GetKey());;mod_.Next(begin_, end_, i)) { 00160 Key got(i->GetKey()); 00161 if (equal_(got, t.GetKey())) { out = i; return true; } 00162 if (equal_(got, invalid_)) { 00163 UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); 00164 *i = t; 00165 out = i; 00166 return false; 00167 } 00168 } 00169 } 00170 00171 void FinishedInserting() {} 00172 00173 // Don't change anything related to GetKey, 00174 template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { 00175 #ifdef DEBUG 00176 assert(initialized_); 00177 #endif 00178 for (MutableIterator i(Ideal(key));; mod_.Next(begin_, end_, i)) { 00179 Key got(i->GetKey()); 00180 if (equal_(got, key)) { out = i; return true; } 00181 if (equal_(got, invalid_)) return false; 00182 } 00183 } 00184 00185 // Like UnsafeMutableFind, but the key must be there. 00186 template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { 00187 for (MutableIterator i(Ideal(key));; mod_.Next(begin_, end_, i)) { 00188 Key got(i->GetKey()); 00189 if (equal_(got, key)) { return i; } 00190 assert(!equal_(got, invalid_)); 00191 } 00192 } 00193 00194 // Iterator is both input and output. 00195 template <class Key> bool FindFromIdeal(const Key key, ConstIterator &i) const { 00196 #ifdef DEBUG 00197 assert(initialized_); 00198 #endif 00199 for (;; mod_.Next(begin_, end_, i)) { 00200 Key got(i->GetKey()); 00201 if (equal_(got, key)) return true; 00202 if (equal_(got, invalid_)) return false; 00203 } 00204 } 00205 00206 template <class Key> bool Find(const Key key, ConstIterator &out) const { 00207 out = Ideal(key); 00208 return FindFromIdeal(key, out); 00209 } 00210 00211 // Like Find but we're sure it must be there. 
00212 template <class Key> ConstIterator MustFind(const Key key) const { 00213 for (ConstIterator i(Ideal(key));; mod_.Next(begin_, end_, i)) { 00214 Key got(i->GetKey()); 00215 if (equal_(got, key)) { return i; } 00216 assert(!equal_(got, invalid_)); 00217 } 00218 } 00219 00220 void Clear() { 00221 Entry invalid; 00222 invalid.SetKey(invalid_); 00223 std::fill(begin_, end_, invalid); 00224 entries_ = 0; 00225 } 00226 00227 // Return number of entries assuming no serialization went on. 00228 std::size_t SizeNoSerialization() const { 00229 return entries_; 00230 } 00231 00232 // Return memory size expected by Double. 00233 std::size_t DoubleTo() const { 00234 return buckets_ * 2 * sizeof(Entry); 00235 } 00236 00237 // Inform the table that it has double the amount of memory. 00238 // Pass clear_new = false if you are sure the new memory is initialized 00239 // properly (to invalid_) i.e. by mremap. 00240 void Double(void *new_base, bool clear_new = true) { 00241 begin_ = static_cast<MutableIterator>(new_base); 00242 MutableIterator old_end = begin_ + buckets_; 00243 buckets_ *= 2; 00244 end_ = begin_ + buckets_; 00245 mod_.Double(); 00246 if (clear_new) { 00247 Entry invalid; 00248 invalid.SetKey(invalid_); 00249 std::fill(old_end, end_, invalid); 00250 } 00251 std::vector<Entry> rolled_over; 00252 // Move roll-over entries to a buffer because they might not roll over anymore. This should be small. 00253 for (MutableIterator i = begin_; i != old_end && !equal_(i->GetKey(), invalid_); ++i) { 00254 rolled_over.push_back(*i); 00255 i->SetKey(invalid_); 00256 } 00257 /* Re-insert everything. Entries might go backwards to take over a 00258 * recently opened gap, stay, move to new territory, or wrap around. If 00259 * an entry wraps around, it might go to a pointer greater than i (which 00260 * can happen at the beginning) and it will be revisited to possibly fill 00261 * in a gap created later. 
00262 */ 00263 Entry temp; 00264 for (MutableIterator i = begin_; i != old_end; ++i) { 00265 if (!equal_(i->GetKey(), invalid_)) { 00266 temp = *i; 00267 i->SetKey(invalid_); 00268 UncheckedInsert(temp); 00269 } 00270 } 00271 // Put the roll-over entries back in. 00272 for (typename std::vector<Entry>::const_iterator i(rolled_over.begin()); i != rolled_over.end(); ++i) { 00273 UncheckedInsert(*i); 00274 } 00275 } 00276 00277 // Mostly for tests, check consistency of every entry. 00278 void CheckConsistency() { 00279 MutableIterator last; 00280 for (last = end_ - 1; last >= begin_ && !equal_(last->GetKey(), invalid_); --last) {} 00281 UTIL_THROW_IF(last == begin_, ProbingSizeException, "Completely full"); 00282 MutableIterator i; 00283 // Beginning can be wrap-arounds. 00284 for (i = begin_; !equal_(i->GetKey(), invalid_); ++i) { 00285 MutableIterator ideal = Ideal(i->GetKey()); 00286 UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_)); 00287 } 00288 MutableIterator pre_gap = i; 00289 for (; i != end_; ++i) { 00290 if (equal_(i->GetKey(), invalid_)) { 00291 pre_gap = i; 00292 continue; 00293 } 00294 MutableIterator ideal = Ideal(i->GetKey()); 00295 UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_)); 00296 } 00297 } 00298 00299 ConstIterator RawBegin() const { 00300 return begin_; 00301 } 00302 ConstIterator RawEnd() const { 00303 return end_; 00304 } 00305 00306 private: 00307 friend class AutoProbing<Entry, Hash, Equal>; 00308 00309 template <class T> MutableIterator UncheckedInsert(const T &t) { 00310 for (MutableIterator i(Ideal(t.GetKey()));; mod_.Next(begin_, end_, i)) { 00311 if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } 00312 } 00313 } 00314 00315 MutableIterator begin_; 00316 MutableIterator end_; 00317 std::size_t buckets_; 00318 Key invalid_; 00319 Hash hash_; 00320 Equal 
equal_; 00321 Mod mod_; 00322 00323 std::size_t entries_; 00324 #ifdef DEBUG 00325 bool initialized_; 00326 #endif 00327 }; 00328 00329 // Resizable linear probing hash table. This owns the memory. 00330 template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing { 00331 private: 00332 typedef ProbingHashTable<EntryT, HashT, EqualT, Power2Mod> Backend; 00333 public: 00334 static std::size_t MemUsage(std::size_t size, float multiplier = 1.5) { 00335 return Backend::Size(size, multiplier); 00336 } 00337 00338 typedef EntryT Entry; 00339 typedef typename Entry::Key Key; 00340 typedef const Entry *ConstIterator; 00341 typedef Entry *MutableIterator; 00342 typedef HashT Hash; 00343 typedef EqualT Equal; 00344 00345 AutoProbing(std::size_t initial_size = 5, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) : 00346 allocated_(Backend::Size(initial_size, 1.2)), mem_(allocated_, KeyIsRawZero(invalid)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) { 00347 threshold_ = std::min<std::size_t>(backend_.buckets_ - 1, backend_.buckets_ * 0.9); 00348 if (!KeyIsRawZero(invalid)) { 00349 Clear(); 00350 } 00351 } 00352 00353 // Assumes that the key is unique. Multiple insertions won't cause a failure, just inconsistent lookup. 
00354 template <class T> MutableIterator Insert(const T &t) { 00355 ++backend_.entries_; 00356 DoubleIfNeeded(); 00357 return backend_.UncheckedInsert(t); 00358 } 00359 00360 template <class T> bool FindOrInsert(const T &t, MutableIterator &out) { 00361 DoubleIfNeeded(); 00362 return backend_.FindOrInsert(t, out); 00363 } 00364 00365 template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { 00366 return backend_.UnsafeMutableFind(key, out); 00367 } 00368 00369 template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { 00370 return backend_.UnsafeMutableMustFind(key); 00371 } 00372 00373 template <class Key> bool Find(const Key key, ConstIterator &out) const { 00374 return backend_.Find(key, out); 00375 } 00376 00377 template <class Key> ConstIterator MustFind(const Key key) const { 00378 return backend_.MustFind(key); 00379 } 00380 00381 std::size_t Size() const { 00382 return backend_.SizeNoSerialization(); 00383 } 00384 00385 void Clear() { 00386 backend_.Clear(); 00387 } 00388 00389 ConstIterator RawBegin() const { 00390 return backend_.RawBegin(); 00391 } 00392 ConstIterator RawEnd() const { 00393 return backend_.RawEnd(); 00394 } 00395 00396 private: 00397 void DoubleIfNeeded() { 00398 if (UTIL_LIKELY(Size() < threshold_)) 00399 return; 00400 HugeRealloc(backend_.DoubleTo(), KeyIsRawZero(backend_.invalid_), mem_); 00401 allocated_ = backend_.DoubleTo(); 00402 backend_.Double(mem_.get(), !KeyIsRawZero(backend_.invalid_)); 00403 threshold_ = std::min<std::size_t>(backend_.buckets_ - 1, backend_.buckets_ * 0.9); 00404 } 00405 00406 bool KeyIsRawZero(const Key &key) { 00407 for (const uint8_t *i = reinterpret_cast<const uint8_t*>(&key); i < reinterpret_cast<const uint8_t*>(&key) + sizeof(Key); ++i) { 00408 if (*i) return false; 00409 } 00410 return true; 00411 } 00412 00413 std::size_t allocated_; 00414 util::scoped_memory mem_; 00415 Backend backend_; 00416 std::size_t threshold_; 00417 }; 00418 00419 } // namespace util 
00420 00421 #endif // UTIL_PROBING_HASH_TABLE_H