Joshua
An open-source statistical hierarchical phrase-based machine translation system.
00001 #ifndef UTIL_BIT_PACKING_H 00002 #define UTIL_BIT_PACKING_H 00003 00004 /* Bit-level packing routines 00005 * 00006 * WARNING WARNING WARNING: 00007 * The write functions assume that memory is zero initially. This makes them 00008 * faster and is the appropriate case for mmapped language model construction. 00009 * These routines assume that unaligned access to uint64_t is fast. This is 00010 * the case on x86_64. I'm not sure how fast unaligned 64-bit access is on 00011 * x86 but my target audience is large language models for which 64-bit is 00012 * necessary. 00013 * 00014 * Call the BitPackingSanity function to sanity check. Calling once suffices, 00015 * but it may be called multiple times when that's inconvenient. 00016 * 00017 * ARM and MinGW ports contributed by Hideo Okuma and Tomoyuki Yoshimura at 00018 * NICT. 00019 */ 00020 00021 #include <cassert> 00022 #ifdef __APPLE__ 00023 #include <architecture/byte_order.h> 00024 #elif __linux__ 00025 #include <endian.h> 00026 #elif !defined(_WIN32) && !defined(_WIN64) 00027 #include <arpa/nameser_compat.h> 00028 #endif 00029 00030 #include <stdint.h> 00031 #include <cstring> 00032 00033 namespace util { 00034 00035 // Fun fact: __BYTE_ORDER is wrong on Solaris Sparc, but the version without __ is correct. 00036 #if BYTE_ORDER == LITTLE_ENDIAN 00037 inline uint8_t BitPackShift(uint8_t bit, uint8_t /*length*/) { 00038 return bit; 00039 } 00040 #elif BYTE_ORDER == BIG_ENDIAN 00041 inline uint8_t BitPackShift(uint8_t bit, uint8_t length) { 00042 return 64 - length - bit; 00043 } 00044 #else 00045 #error "Bit packing code isn't written for your byte order." 
00046 #endif 00047 00048 inline uint64_t ReadOff(const void *base, uint64_t bit_off) { 00049 #if defined(__arm) || defined(__arm__) 00050 const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3); 00051 uint64_t value64; 00052 memcpy(&value64, base_off, sizeof(value64)); 00053 return value64; 00054 #else 00055 return *reinterpret_cast<const uint64_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)); 00056 #endif 00057 } 00058 00059 /* Pack integers up to 57 bits using their least significant digits. 00060 * The length is specified using mask: 00061 * Assumes mask == (1 << length) - 1 where length <= 57. 00062 */ 00063 inline uint64_t ReadInt57(const void *base, uint64_t bit_off, uint8_t length, uint64_t mask) { 00064 return (ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, length)) & mask; 00065 } 00066 /* Assumes value < (1 << length) and length <= 57. 00067 * Assumes the memory is zero initially. 00068 */ 00069 inline void WriteInt57(void *base, uint64_t bit_off, uint8_t length, uint64_t value) { 00070 #if defined(__arm) || defined(__arm__) 00071 uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3); 00072 uint64_t value64; 00073 memcpy(&value64, base_off, sizeof(value64)); 00074 value64 |= (value << BitPackShift(bit_off & 7, length)); 00075 memcpy(base_off, &value64, sizeof(value64)); 00076 #else 00077 *reinterpret_cast<uint64_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= 00078 (value << BitPackShift(bit_off & 7, length)); 00079 #endif 00080 } 00081 00082 /* Same caveats as above, but for a 25 bit limit. 
*/ 00083 inline uint32_t ReadInt25(const void *base, uint64_t bit_off, uint8_t length, uint32_t mask) { 00084 #if defined(__arm) || defined(__arm__) 00085 const uint8_t *base_off = reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3); 00086 uint32_t value32; 00087 memcpy(&value32, base_off, sizeof(value32)); 00088 return (value32 >> BitPackShift(bit_off & 7, length)) & mask; 00089 #else 00090 return (*reinterpret_cast<const uint32_t*>(reinterpret_cast<const uint8_t*>(base) + (bit_off >> 3)) >> BitPackShift(bit_off & 7, length)) & mask; 00091 #endif 00092 } 00093 00094 inline void WriteInt25(void *base, uint64_t bit_off, uint8_t length, uint32_t value) { 00095 #if defined(__arm) || defined(__arm__) 00096 uint8_t *base_off = reinterpret_cast<uint8_t*>(base) + (bit_off >> 3); 00097 uint32_t value32; 00098 memcpy(&value32, base_off, sizeof(value32)); 00099 value32 |= (value << BitPackShift(bit_off & 7, length)); 00100 memcpy(base_off, &value32, sizeof(value32)); 00101 #else 00102 *reinterpret_cast<uint32_t*>(reinterpret_cast<uint8_t*>(base) + (bit_off >> 3)) |= 00103 (value << BitPackShift(bit_off & 7, length)); 00104 #endif 00105 } 00106 00107 typedef union { float f; uint32_t i; } FloatEnc; 00108 00109 inline float ReadFloat32(const void *base, uint64_t bit_off) { 00110 FloatEnc encoded; 00111 encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); 00112 return encoded.f; 00113 } 00114 inline void WriteFloat32(void *base, uint64_t bit_off, float value) { 00115 FloatEnc encoded; 00116 encoded.f = value; 00117 WriteInt57(base, bit_off, 32, encoded.i); 00118 } 00119 00120 const uint32_t kSignBit = 0x80000000; 00121 00122 inline void SetSign(float &to) { 00123 FloatEnc enc; 00124 enc.f = to; 00125 enc.i |= kSignBit; 00126 to = enc.f; 00127 } 00128 00129 inline void UnsetSign(float &to) { 00130 FloatEnc enc; 00131 enc.f = to; 00132 enc.i &= ~kSignBit; 00133 to = enc.f; 00134 } 00135 00136 inline float ReadNonPositiveFloat31(const void *base, uint64_t 
bit_off) { 00137 FloatEnc encoded; 00138 encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); 00139 // Sign bit set means negative. 00140 encoded.i |= kSignBit; 00141 return encoded.f; 00142 } 00143 inline void WriteNonPositiveFloat31(void *base, uint64_t bit_off, float value) { 00144 FloatEnc encoded; 00145 encoded.f = value; 00146 encoded.i &= ~kSignBit; 00147 WriteInt57(base, bit_off, 31, encoded.i); 00148 } 00149 00150 void BitPackingSanity(); 00151 00152 // Return bits required to store integers upto max_value. Not the most 00153 // efficient implementation, but this is only called a few times to size tries. 00154 uint8_t RequiredBits(uint64_t max_value); 00155 00156 struct BitsMask { 00157 static BitsMask ByMax(uint64_t max_value) { 00158 BitsMask ret; 00159 ret.FromMax(max_value); 00160 return ret; 00161 } 00162 static BitsMask ByBits(uint8_t bits) { 00163 BitsMask ret; 00164 ret.bits = bits; 00165 ret.mask = (1ULL << bits) - 1; 00166 return ret; 00167 } 00168 void FromMax(uint64_t max_value) { 00169 bits = RequiredBits(max_value); 00170 mask = (1ULL << bits) - 1; 00171 } 00172 uint8_t bits; 00173 uint64_t mask; 00174 }; 00175 00176 struct BitAddress { 00177 BitAddress(void *in_base, uint64_t in_offset) : base(in_base), offset(in_offset) {} 00178 00179 void *base; 00180 uint64_t offset; 00181 }; 00182 00183 } // namespace util 00184 00185 #endif // UTIL_BIT_PACKING_H