Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/util/probing_hash_table.hh
00001 #ifndef UTIL_PROBING_HASH_TABLE_H
00002 #define UTIL_PROBING_HASH_TABLE_H
00003 
00004 #include "util/exception.hh"
00005 #include "util/mmap.hh"
00006 
00007 #include <algorithm>
00008 #include <cstddef>
00009 #include <functional>
00010 #include <vector>
00011 
00012 #include <cassert>
00013 #include <stdint.h>
00014 
00015 namespace util {
00016 
00017 /* Thrown when table grows too large */
00018 class ProbingSizeException : public Exception {
00019   public:
00020     ProbingSizeException() throw() {}
00021     ~ProbingSizeException() throw() {}
00022 };
00023 
// Hash functor that hands its argument back unchanged.
// (std::identity is an SGI extension :-( so it is spelled out here.)
struct IdentityHash {
  template <typename Value> Value operator()(Value value) const {
    return value;
  }
};
00028 
// Bucket-selection policy that works for any bucket count by taking the hash
// modulo the number of buckets.
class DivMod {
  public:
    explicit DivMod(std::size_t buckets) : buckets_(buckets) {}

    // Any bucket count is acceptable, so the request is returned untouched.
    static std::size_t RoundBuckets(std::size_t from) { return from; }

    // Preferred bucket for a hash value: begin advanced by hash modulo the bucket count.
    template <class It> It Ideal(It begin, uint64_t hash) const {
      const uint64_t offset = hash % buckets_;
      return begin + offset;
    }

    // Step the probe position forward by one bucket, wrapping from end back to begin.
    template <class BaseIt, class OutIt> void Next(BaseIt begin, BaseIt end, OutIt &it) const {
      ++it;
      if (it == end) {
        it = begin;
      }
    }

    // Record that the table now has twice as many buckets.
    void Double() { buckets_ += buckets_; }

  private:
    std::size_t buckets_;
};
00052 
00053 class Power2Mod {
00054   public:
00055     explicit Power2Mod(std::size_t buckets) {
00056       UTIL_THROW_IF(!buckets || (((buckets - 1) & buckets)), ProbingSizeException, "Size " << buckets << " is not a power of 2.");
00057       mask_ = buckets - 1;
00058     }
00059 
00060     // Round up to next power of 2.
00061     static std::size_t RoundBuckets(std::size_t from) {
00062       --from;
00063       from |= from >> 1;
00064       from |= from >> 2;
00065       from |= from >> 4;
00066       from |= from >> 8;
00067       from |= from >> 16;
00068       from |= from >> 32;
00069       return from + 1;
00070     }
00071 
00072     template <class It> It Ideal(It begin, uint64_t hash) const {
00073       return begin + (hash & mask_);
00074     }
00075 
00076     template <class BaseIt, class OutIt> void Next(BaseIt begin, BaseIt /*end*/, OutIt &it) const {
00077       it = begin + ((it - begin + 1) & mask_);
00078     }
00079 
00080     void Double() {
00081       mask_ = (mask_ << 1) | 1;
00082     }
00083 
00084   private:
00085     std::size_t mask_;
00086 };
00087 
00088 template <class EntryT, class HashT, class EqualT> class AutoProbing;
00089 
00090 /* Non-standard hash table
00091  * Buckets must be set at the beginning and must be greater than maximum number
00092  * of elements, else it throws ProbingSizeException.
00093  * Memory management and initialization is externalized to make it easier to
00094  * serialize these to disk and load them quickly.
00095  * Uses linear probing to find value.
00096  * Only insert and lookup operations.
00097  */
00098 template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key>, class ModT = DivMod> class ProbingHashTable {
00099   public:
00100     typedef EntryT Entry;
00101     typedef typename Entry::Key Key;
00102     typedef const Entry *ConstIterator;
00103     typedef Entry *MutableIterator;
00104     typedef HashT Hash;
00105     typedef EqualT Equal;
00106     typedef ModT Mod;
00107 
00108     static uint64_t Size(uint64_t entries, float multiplier) {
00109       uint64_t buckets = Mod::RoundBuckets(std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries))));
00110       return buckets * sizeof(Entry);
00111     }
00112 
00113     // Must be assigned to later.
00114     ProbingHashTable() : mod_(1), entries_(0)
00115 #ifdef DEBUG
00116       , initialized_(false)
00117 #endif
00118     {}
00119 
00120     ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal())
00121       : begin_(reinterpret_cast<MutableIterator>(start)),
00122         end_(begin_ + allocated / sizeof(Entry)),
00123         buckets_(end_ - begin_),
00124         invalid_(invalid),
00125         hash_(hash_func),
00126         equal_(equal_func),
00127         mod_(end_ - begin_),
00128         entries_(0)
00129 #ifdef DEBUG
00130         , initialized_(true)
00131 #endif
00132     {}
00133 
00134     void Relocate(void *new_base) {
00135       begin_ = reinterpret_cast<MutableIterator>(new_base);
00136       end_ = begin_ + buckets_;
00137     }
00138 
00139     MutableIterator Ideal(const Key key) {
00140       return mod_.Ideal(begin_, hash_(key));
00141     }
00142     ConstIterator Ideal(const Key key) const {
00143       return mod_.Ideal(begin_, hash_(key));
00144     }
00145 
00146     template <class T> MutableIterator Insert(const T &t) {
00147 #ifdef DEBUG
00148       assert(initialized_);
00149 #endif
00150       UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
00151       return UncheckedInsert(t);
00152     }
00153 
00154     // Return true if the value was found (and not inserted).  This is consistent with Find but the opposite of hash_map!
00155     template <class T> bool FindOrInsert(const T &t, MutableIterator &out) {
00156 #ifdef DEBUG
00157       assert(initialized_);
00158 #endif
00159       for (MutableIterator i = Ideal(t.GetKey());;mod_.Next(begin_, end_, i)) {
00160         Key got(i->GetKey());
00161         if (equal_(got, t.GetKey())) { out = i; return true; }
00162         if (equal_(got, invalid_)) {
00163           UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full.");
00164           *i = t;
00165           out = i;
00166           return false;
00167         }
00168       }
00169     }
00170 
00171     void FinishedInserting() {}
00172 
00173     // Don't change anything related to GetKey,
00174     template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
00175 #ifdef DEBUG
00176       assert(initialized_);
00177 #endif
00178       for (MutableIterator i(Ideal(key));; mod_.Next(begin_, end_, i)) {
00179         Key got(i->GetKey());
00180         if (equal_(got, key)) { out = i; return true; }
00181         if (equal_(got, invalid_)) return false;
00182       }
00183     }
00184 
00185     // Like UnsafeMutableFind, but the key must be there.
00186     template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
00187       for (MutableIterator i(Ideal(key));; mod_.Next(begin_, end_, i)) {
00188         Key got(i->GetKey());
00189         if (equal_(got, key)) { return i; }
00190         assert(!equal_(got, invalid_));
00191       }
00192     }
00193 
00194     // Iterator is both input and output.
00195     template <class Key> bool FindFromIdeal(const Key key, ConstIterator &i) const {
00196 #ifdef DEBUG
00197       assert(initialized_);
00198 #endif
00199       for (;; mod_.Next(begin_, end_, i)) {
00200         Key got(i->GetKey());
00201         if (equal_(got, key)) return true;
00202         if (equal_(got, invalid_)) return false;
00203       }
00204     }
00205 
00206     template <class Key> bool Find(const Key key, ConstIterator &out) const {
00207       out = Ideal(key);
00208       return FindFromIdeal(key, out);
00209     }
00210 
00211     // Like Find but we're sure it must be there.
00212     template <class Key> ConstIterator MustFind(const Key key) const {
00213       for (ConstIterator i(Ideal(key));; mod_.Next(begin_, end_, i)) {
00214         Key got(i->GetKey());
00215         if (equal_(got, key)) { return i; }
00216         assert(!equal_(got, invalid_));
00217       }
00218     }
00219 
00220     void Clear() {
00221       Entry invalid;
00222       invalid.SetKey(invalid_);
00223       std::fill(begin_, end_, invalid);
00224       entries_ = 0;
00225     }
00226 
00227     // Return number of entries assuming no serialization went on.
00228     std::size_t SizeNoSerialization() const {
00229       return entries_;
00230     }
00231 
00232     // Return memory size expected by Double.
00233     std::size_t DoubleTo() const {
00234       return buckets_ * 2 * sizeof(Entry);
00235     }
00236 
00237     // Inform the table that it has double the amount of memory.
00238     // Pass clear_new = false if you are sure the new memory is initialized
00239     // properly (to invalid_) i.e. by mremap.
00240     void Double(void *new_base, bool clear_new = true) {
00241       begin_ = static_cast<MutableIterator>(new_base);
00242       MutableIterator old_end = begin_ + buckets_;
00243       buckets_ *= 2;
00244       end_ = begin_ + buckets_;
00245       mod_.Double();
00246       if (clear_new) {
00247         Entry invalid;
00248         invalid.SetKey(invalid_);
00249         std::fill(old_end, end_, invalid);
00250       }
00251       std::vector<Entry> rolled_over;
00252       // Move roll-over entries to a buffer because they might not roll over anymore.  This should be small.
00253       for (MutableIterator i = begin_; i != old_end && !equal_(i->GetKey(), invalid_); ++i) {
00254         rolled_over.push_back(*i);
00255         i->SetKey(invalid_);
00256       }
00257       /* Re-insert everything.  Entries might go backwards to take over a
00258        * recently opened gap, stay, move to new territory, or wrap around.   If
00259        * an entry wraps around, it might go to a pointer greater than i (which
00260        * can happen at the beginning) and it will be revisited to possibly fill
00261        * in a gap created later.
00262        */
00263       Entry temp;
00264       for (MutableIterator i = begin_; i != old_end; ++i) {
00265         if (!equal_(i->GetKey(), invalid_)) {
00266           temp = *i;
00267           i->SetKey(invalid_);
00268           UncheckedInsert(temp);
00269         }
00270       }
00271       // Put the roll-over entries back in.
00272       for (typename std::vector<Entry>::const_iterator i(rolled_over.begin()); i != rolled_over.end(); ++i) {
00273         UncheckedInsert(*i);
00274       }
00275     }
00276 
00277     // Mostly for tests, check consistency of every entry.
00278     void CheckConsistency() {
00279       MutableIterator last;
00280       for (last = end_ - 1; last >= begin_ && !equal_(last->GetKey(), invalid_); --last) {}
00281       UTIL_THROW_IF(last == begin_, ProbingSizeException, "Completely full");
00282       MutableIterator i;
00283       // Beginning can be wrap-arounds.
00284       for (i = begin_; !equal_(i->GetKey(), invalid_); ++i) {
00285         MutableIterator ideal = Ideal(i->GetKey());
00286         UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_));
00287       }
00288       MutableIterator pre_gap = i;
00289       for (; i != end_; ++i) {
00290         if (equal_(i->GetKey(), invalid_)) {
00291           pre_gap = i;
00292           continue;
00293         }
00294         MutableIterator ideal = Ideal(i->GetKey());
00295         UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_));
00296       }
00297     }
00298 
00299     ConstIterator RawBegin() const {
00300       return begin_;
00301     }
00302     ConstIterator RawEnd() const {
00303       return end_;
00304     }
00305 
00306   private:
00307     friend class AutoProbing<Entry, Hash, Equal>;
00308 
00309     template <class T> MutableIterator UncheckedInsert(const T &t) {
00310       for (MutableIterator i(Ideal(t.GetKey()));; mod_.Next(begin_, end_, i)) {
00311         if (equal_(i->GetKey(), invalid_)) { *i = t; return i; }
00312       }
00313     }
00314 
00315     MutableIterator begin_;
00316     MutableIterator end_;
00317     std::size_t buckets_;
00318     Key invalid_;
00319     Hash hash_;
00320     Equal equal_;
00321     Mod mod_;
00322 
00323     std::size_t entries_;
00324 #ifdef DEBUG
00325     bool initialized_;
00326 #endif
00327 };
00328 
00329 // Resizable linear probing hash table.  This owns the memory.
00330 template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing {
00331   private:
00332     typedef ProbingHashTable<EntryT, HashT, EqualT, Power2Mod> Backend;
00333   public:
00334     static std::size_t MemUsage(std::size_t size, float multiplier = 1.5) {
00335       return Backend::Size(size, multiplier);
00336     }
00337 
00338     typedef EntryT Entry;
00339     typedef typename Entry::Key Key;
00340     typedef const Entry *ConstIterator;
00341     typedef Entry *MutableIterator;
00342     typedef HashT Hash;
00343     typedef EqualT Equal;
00344 
00345     AutoProbing(std::size_t initial_size = 5, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) :
00346       allocated_(Backend::Size(initial_size, 1.2)), mem_(allocated_, KeyIsRawZero(invalid)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) {
00347       threshold_ = std::min<std::size_t>(backend_.buckets_ - 1, backend_.buckets_ * 0.9);
00348       if (!KeyIsRawZero(invalid)) {
00349         Clear();
00350       }
00351     }
00352 
00353     // Assumes that the key is unique.  Multiple insertions won't cause a failure, just inconsistent lookup.
00354     template <class T> MutableIterator Insert(const T &t) {
00355       ++backend_.entries_;
00356       DoubleIfNeeded();
00357       return backend_.UncheckedInsert(t);
00358     }
00359 
00360     template <class T> bool FindOrInsert(const T &t, MutableIterator &out) {
00361       DoubleIfNeeded();
00362       return backend_.FindOrInsert(t, out);
00363     }
00364 
00365     template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) {
00366       return backend_.UnsafeMutableFind(key, out);
00367     }
00368 
00369     template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) {
00370       return backend_.UnsafeMutableMustFind(key);
00371     }
00372 
00373     template <class Key> bool Find(const Key key, ConstIterator &out) const {
00374       return backend_.Find(key, out);
00375     }
00376 
00377     template <class Key> ConstIterator MustFind(const Key key) const {
00378       return backend_.MustFind(key);
00379     }
00380 
00381     std::size_t Size() const {
00382       return backend_.SizeNoSerialization();
00383     }
00384 
00385     void Clear() {
00386       backend_.Clear();
00387     }
00388 
00389     ConstIterator RawBegin() const {
00390       return backend_.RawBegin();
00391     }
00392     ConstIterator RawEnd() const {
00393       return backend_.RawEnd();
00394     }
00395 
00396   private:
00397     void DoubleIfNeeded() {
00398       if (UTIL_LIKELY(Size() < threshold_))
00399         return;
00400       HugeRealloc(backend_.DoubleTo(), KeyIsRawZero(backend_.invalid_), mem_);
00401       allocated_ = backend_.DoubleTo();
00402       backend_.Double(mem_.get(), !KeyIsRawZero(backend_.invalid_));
00403       threshold_ = std::min<std::size_t>(backend_.buckets_ - 1, backend_.buckets_ * 0.9);
00404     }
00405 
00406     bool KeyIsRawZero(const Key &key) {
00407       for (const uint8_t *i = reinterpret_cast<const uint8_t*>(&key); i < reinterpret_cast<const uint8_t*>(&key) + sizeof(Key); ++i) {
00408         if (*i) return false;
00409       }
00410       return true;
00411     }
00412 
00413     std::size_t allocated_;
00414     util::scoped_memory mem_;
00415     Backend backend_;
00416     std::size_t threshold_;
00417 };
00418 
00419 } // namespace util
00420 
00421 #endif // UTIL_PROBING_HASH_TABLE_H