Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/util/sorted_uniform.hh
00001 #ifndef UTIL_SORTED_UNIFORM_H
00002 #define UTIL_SORTED_UNIFORM_H
00003 
00004 #include <algorithm>
00005 #include <cstddef>
00006 #include <cassert>
00007 #include <stdint.h>
00008 
00009 namespace util {
00010 
// Accessor for ranges whose elements are themselves the search keys:
// dereferencing the iterator yields the key directly.
template <class T> class IdentityAccessor {
  public:
    // Key type that the search routines compare against.
    typedef T Key;

    // Returns a copy of the element that `in` points to.
    Key operator()(const Key *in) const {
      return *in;
    }
};
00016 
// Pivot estimator that works in floating point: scales `width` by the
// fraction off / range and truncates the result to an index.
struct Pivot64 {
  // off:   distance of the key from the lower bound value.
  // range: distance between the upper and lower bound values.
  // width: number of candidate slots.
  // Returns the estimated slot index, never larger than width - 1.
  static inline std::size_t Calc(uint64_t off, uint64_t range, std::size_t width) {
    const std::size_t estimate = static_cast<std::size_t>(
        static_cast<float>(off) / static_cast<float>(range) * static_cast<float>(width));
    // Floating point rounding can push the estimate to width or beyond; cap it
    // at the last slot.
    if (estimate < width) {
      return estimate;
    }
    return width - 1;
  }
};
00024 
// Integer-only pivot estimator.
// Use when off * width is < 2^64, which is guaranteed when each of them is
// actually a 32-bit value.
struct Pivot32 {
  // Returns floor(off * width / (range + 1)).
  static inline std::size_t Calc(uint64_t off, uint64_t range, uint64_t width) {
    const uint64_t scaled = off * width;
    const uint64_t divisor = range + 1;
    return static_cast<std::size_t>(scaled / divisor);
  }
};
00031 
00032 // Usage: PivotSelect<sizeof(DataType)>::T
00033 template <unsigned> struct PivotSelect;
00034 template <> struct PivotSelect<8> { typedef Pivot64 T; };
00035 template <> struct PivotSelect<4> { typedef Pivot32 T; };
00036 template <> struct PivotSelect<2> { typedef Pivot32 T; };
00037 
/* Binary search for key in the sorted range [begin, end).
 *
 * accessor maps an iterator to the key stored at that position.
 * On a hit, out is set to the matching position and true is returned.
 * On a miss, out is left untouched and false is returned.
 */
template <class Iterator, class Accessor> bool BinaryFind(
    const Accessor &accessor,
    Iterator begin,
    Iterator end,
    const typename Accessor::Key key, Iterator &out) {
  while (end > begin) {
    Iterator middle(begin + (end - begin) / 2);
    typename Accessor::Key probe(accessor(middle));
    if (probe > key) {
      // Key, if present, lies strictly before the midpoint.
      end = middle;
      continue;
    }
    if (probe < key) {
      // Key, if present, lies strictly after the midpoint.
      begin = middle + 1;
      continue;
    }
    out = middle;
    return true;
  }
  return false;
}
00058 
// Search the range [before_it + 1, after_it - 1] for key, choosing each probe
// position with Pivot::Calc instead of always taking the midpoint.
// Preconditions:
//   before_v <= key <= after_v
//   before_v <= all values in the range [before_it + 1, after_it - 1] <= after_v
//   the range is sorted.
// before_it and after_it themselves are never probed; they only delimit the
// range, and before_v / after_v are the values that bound it.
// On a hit, out is set to the matching position and true is returned.
// On a miss, out is left untouched and false is returned.
template <class Iterator, class Accessor, class Pivot> bool BoundedSortedUniformFind(
    const Accessor &accessor,
    Iterator before_it, typename Accessor::Key before_v,
    Iterator after_it, typename Accessor::Key after_v,
    const typename Accessor::Key key, Iterator &out) {
  // Keep going while at least one unprobed slot remains between the bounds.
  while (after_it - before_it > 1) {
    // Pivot::Calc yields an offset within the interior slots; the extra 1
    // skips past before_it itself.
    Iterator probe(before_it + (1 + Pivot::Calc(key - before_v, after_v - before_v, after_it - before_it - 1)));
    typename Accessor::Key probe_v(accessor(probe));
    if (probe_v > key) {
      // Probe overshoots: it becomes the new upper bound.
      after_it = probe;
      after_v = probe_v;
      continue;
    }
    if (probe_v < key) {
      // Probe undershoots: it becomes the new lower bound.
      before_it = probe;
      before_v = probe_v;
      continue;
    }
    out = probe;
    return true;
  }
  return false;
}
00085 
00086 template <class Iterator, class Accessor, class Pivot> bool SortedUniformFind(const Accessor &accessor, Iterator begin, Iterator end, const typename Accessor::Key key, Iterator &out) {
00087   if (begin == end) return false;
00088   typename Accessor::Key below(accessor(begin));
00089   if (key <= below) {
00090     if (key == below) { out = begin; return true; }
00091     return false;
00092   }
00093   // Make the range [begin, end].
00094   --end;
00095   typename Accessor::Key above(accessor(end));
00096   if (key >= above) {
00097     if (key == above) { out = end; return true; }
00098     return false;
00099   }
00100   return BoundedSortedUniformFind<Iterator, Accessor, Pivot>(accessor, begin, below, end, above, key, out);
00101 }
00102 
00103 } // namespace util
00104 
00105 #endif // UTIL_SORTED_UNIFORM_H