Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef UTIL_STREAM_CHAIN_H 00002 #define UTIL_STREAM_CHAIN_H 00003 00004 #include "util/stream/block.hh" 00005 #include "util/stream/config.hh" 00006 #include "util/stream/multi_progress.hh" 00007 #include "util/scoped.hh" 00008 00009 #include <boost/ptr_container/ptr_vector.hpp> 00010 #include <boost/thread/thread.hpp> 00011 00012 #include <cstddef> 00013 #include <cassert> 00014 00015 namespace util { 00016 template <class T> class PCQueue; 00017 namespace stream { 00018 00019 class ChainConfigException : public Exception { 00020 public: 00021 ChainConfigException() throw(); 00022 ~ChainConfigException() throw(); 00023 }; 00024 00025 class Chain; 00026 class RewindableStream; 00027 00033 class ChainPosition { 00034 public: 00035 const Chain &GetChain() const { return *chain_; } 00036 private: 00037 friend class Chain; 00038 friend class Link; 00039 friend class RewindableStream; 00040 ChainPosition(PCQueue<Block> &in, PCQueue<Block> &out, Chain *chain, MultiProgress &progress) 00041 : in_(&in), out_(&out), chain_(chain), progress_(progress.Add()) {} 00042 00043 PCQueue<Block> *in_, *out_; 00044 00045 Chain *chain_; 00046 00047 WorkerProgress progress_; 00048 }; 00049 00050 00056 class Thread { 00057 public: 00058 00066 template <class Position, class Worker> Thread(const Position &position, const Worker &worker) 00067 : thread_(boost::ref(*this), position, worker) {} 00068 00069 ~Thread(); 00070 00076 template <class Position, class Worker> void operator()(const Position &position, Worker &worker) { 00077 // try { 00078 worker.Run(position); 00079 // } catch (const std::exception &e) { 00080 // UnhandledException(e); 00081 // } 00082 } 00083 00084 private: 00085 void UnhandledException(const std::exception &e); 00086 00087 boost::thread thread_; 00088 }; 00089 00093 class Recycler { 00094 public: 00101 void Run(const ChainPosition &position); 00102 }; 00103 00104 extern const Recycler kRecycle; 00105 class WriteAndRecycle; 00106 class PWriteAndRecycle; 00107 00111 class Chain { 00112 private: 00113 template <class T, void (T::*ptr)(const ChainPosition &) = &T::Run> struct CheckForRun { 00114 typedef Chain type; 00115 }; 00116 00117 public: 00118 00124 explicit Chain(const ChainConfig &config); 00125 00132 ~Chain(); 00133 00134 void ActivateProgress() { 00135 assert(!Running()); 00136 progress_.Activate(); 00137 } 00138 00139 void SetProgressTarget(uint64_t target) { 00140 progress_.SetTarget(target); 00141 } 00142 00148 std::size_t EntrySize() const { 00149 return config_.entry_size; 00150 } 00151 00157 std::size_t BlockSize() const { 00158 return block_size_; 00159 } 00160 00164 std::size_t BlockCount() const { 00165 return config_.block_count; 00166 } 00167 00169 ChainPosition Add(); 00170 00179 template <class Worker> typename CheckForRun<Worker>::type &operator>>(const Worker &worker) { 00180 assert(!complete_called_); 00181 threads_.push_back(new Thread(Add(), worker)); 00182 return *this; 00183 } 00184 00193 template <class Worker> typename CheckForRun<Worker>::type &operator>>(const boost::reference_wrapper<Worker> &worker) { 00194 assert(!complete_called_); 00195 threads_.push_back(new Thread(Add(), worker)); 00196 return *this; 00197 } 00198 00199 // Note that Link and Stream also define operator>> outside this class. 00200 00201 // To complete the loop, call CompleteLoop(), >> kRecycle, or the destructor. 00202 void CompleteLoop() { 00203 threads_.push_back(new Thread(Complete(), kRecycle)); 00204 } 00205 00210 Chain &operator>>(const Recycler &) { 00211 CompleteLoop(); 00212 return *this; 00213 } 00214 00219 Chain &operator>>(const WriteAndRecycle &writer); 00220 Chain &operator>>(const PWriteAndRecycle &writer); 00221 00222 // Chains are reusable. Call Wait to wait for everything to finish and free memory. 00223 void Wait(bool release_memory = true); 00224 00225 // Waits for the current chain to complete (if any) then starts again. 00226 void Start(); 00227 00228 bool Running() const { return !queues_.empty(); } 00229 00230 private: 00231 ChainPosition Complete(); 00232 00233 ChainConfig config_; 00234 00235 std::size_t block_size_; 00236 00237 scoped_malloc memory_; 00238 00239 boost::ptr_vector<PCQueue<Block> > queues_; 00240 00241 bool complete_called_; 00242 00243 boost::ptr_vector<Thread> threads_; 00244 00245 MultiProgress progress_; 00246 }; 00247 00248 // Create the link in the worker thread using the position token. 00252 class Link { 00253 public: 00254 00255 // Either default construct and Init or just construct all at once. 00256 00262 explicit Link(const ChainPosition &position); 00263 00269 Link(); 00270 00276 void Init(const ChainPosition &position); 00277 00286 ~Link(); 00287 00291 Block &operator*() { return current_; } 00292 00296 const Block &operator*() const { return current_; } 00297 00301 Block *operator->() { return ¤t_; } 00302 00306 const Block *operator->() const { return ¤t_; } 00307 00311 Link &operator++(); 00312 00320 operator bool() const { return current_; } 00321 00328 void Poison(); 00329 00330 private: 00331 Block current_; 00332 PCQueue<Block> *in_, *out_; 00333 00334 bool poisoned_; 00335 00336 WorkerProgress progress_; 00337 }; 00338 00339 inline Chain &operator>>(Chain &chain, Link &link) { 00340 link.Init(chain.Add()); 00341 return chain; 00342 } 00343 00344 } // namespace stream 00345 } // namespace util 00346 00347 #endif // UTIL_STREAM_CHAIN_H