Joshua
Open-source statistical hierarchical phrase-based machine translation system
|
00001 // Step of trie builder: create sorted files. 00002 00003 #ifndef LM_TRIE_SORT_H 00004 #define LM_TRIE_SORT_H 00005 00006 #include "lm/max_order.hh" 00007 #include "lm/word_index.hh" 00008 00009 #include "util/file.hh" 00010 #include "util/scoped.hh" 00011 00012 #include <cstddef> 00013 #include <functional> 00014 #include <string> 00015 #include <vector> 00016 00017 #include <stdint.h> 00018 00019 namespace util { 00020 class FilePiece; 00021 } // namespace util 00022 00023 namespace lm { 00024 class PositiveProbWarn; 00025 namespace ngram { 00026 class SortedVocabulary; 00027 struct Config; 00028 00029 namespace trie { 00030 00031 class EntryCompare : public std::binary_function<const void*, const void*, bool> { 00032 public: 00033 explicit EntryCompare(unsigned char order) : order_(order) {} 00034 00035 bool operator()(const void *first_void, const void *second_void) const { 00036 const WordIndex *first = static_cast<const WordIndex*>(first_void); 00037 const WordIndex *second = static_cast<const WordIndex*>(second_void); 00038 const WordIndex *end = first + order_; 00039 for (; first != end; ++first, ++second) { 00040 if (*first < *second) return true; 00041 if (*first > *second) return false; 00042 } 00043 return false; 00044 } 00045 private: 00046 unsigned char order_; 00047 }; 00048 00049 class RecordReader { 00050 public: 00051 RecordReader() : remains_(true) {} 00052 00053 void Init(FILE *file, std::size_t entry_size); 00054 00055 void *Data() { return data_.get(); } 00056 const void *Data() const { return data_.get(); } 00057 00058 RecordReader &operator++() { 00059 std::size_t ret = fread(data_.get(), entry_size_, 1, file_); 00060 if (!ret) { 00061 UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file"); 00062 remains_ = false; 00063 } 00064 return *this; 00065 } 00066 00067 operator bool() const { return remains_; } 00068 00069 void Rewind(); 00070 00071 std::size_t EntrySize() const { return entry_size_; } 00072 00073 
void Overwrite(const void *start, std::size_t amount); 00074 00075 private: 00076 FILE *file_; 00077 00078 util::scoped_malloc data_; 00079 00080 bool remains_; 00081 00082 std::size_t entry_size_; 00083 }; 00084 00085 class SortedFiles { 00086 public: 00087 // Build from ARPA 00088 SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab); 00089 00090 int StealUnigram() { 00091 return unigram_.release(); 00092 } 00093 00094 FILE *Full(unsigned char order) { 00095 return full_[order - 2].get(); 00096 } 00097 00098 FILE *Context(unsigned char of_order) { 00099 return context_[of_order - 2].get(); 00100 } 00101 00102 private: 00103 void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); 00104 00105 util::scoped_fd unigram_; 00106 00107 util::scoped_FILE full_[KENLM_MAX_ORDER - 1], context_[KENLM_MAX_ORDER - 1]; 00108 }; 00109 00110 } // namespace trie 00111 } // namespace ngram 00112 } // namespace lm 00113 00114 #endif // LM_TRIE_SORT_H