Joshua
open source statistical hierarchical phrase-based machine translation system
|
00001 #ifndef LM_NGRAM_QUERY_H 00002 #define LM_NGRAM_QUERY_H 00003 00004 #include "lm/enumerate_vocab.hh" 00005 #include "lm/model.hh" 00006 #include "util/file_stream.hh" 00007 #include "util/file_piece.hh" 00008 #include "util/usage.hh" 00009 00010 #include <cstdlib> 00011 #include <string> 00012 #include <cmath> 00013 00014 namespace lm { 00015 namespace ngram { 00016 00017 class QueryPrinter { 00018 public: 00019 QueryPrinter(int fd, bool print_word, bool print_line, bool print_summary, bool flush) 00020 : out_(fd), print_word_(print_word), print_line_(print_line), print_summary_(print_summary), flush_(flush) {} 00021 00022 void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) { 00023 if (!print_word_) return; 00024 out_ << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t'; 00025 if (flush_) out_.flush(); 00026 } 00027 00028 void Line(uint64_t oov, float total) { 00029 if (!print_line_) return; 00030 out_ << "Total: " << total << " OOV: " << oov << '\n'; 00031 if (flush_) out_.flush(); 00032 } 00033 00034 void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) { 00035 if (!print_summary_) return; 00036 out_ << 00037 "Perplexity including OOVs:\t" << ppl_including_oov << "\n" 00038 "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n" 00039 "OOVs:\t" << corpus_oov << "\n" 00040 "Tokens:\t" << corpus_tokens << '\n'; 00041 out_.flush(); 00042 } 00043 00044 private: 00045 util::FileStream out_; 00046 bool print_word_; 00047 bool print_line_; 00048 bool print_summary_; 00049 bool flush_; 00050 }; 00051 00052 template <class Model, class Printer> void Query(const Model &model, bool sentence_context, Printer &printer) { 00053 typename Model::State state, out; 00054 lm::FullScoreReturn ret; 00055 StringPiece word; 00056 00057 util::FilePiece in(0); 00058 00059 double corpus_total = 0.0; 00060 double corpus_total_oov_only = 0.0; 00061 uint64_t corpus_oov = 0; 00062 uint64_t corpus_tokens = 0; 00063 00064 while (true) { 00065 state = sentence_context ? model.BeginSentenceState() : model.NullContextState(); 00066 float total = 0.0; 00067 uint64_t oov = 0; 00068 00069 while (in.ReadWordSameLine(word)) { 00070 lm::WordIndex vocab = model.GetVocabulary().Index(word); 00071 ret = model.FullScore(state, vocab, out); 00072 if (vocab == model.GetVocabulary().NotFound()) { 00073 ++oov; 00074 corpus_total_oov_only += ret.prob; 00075 } 00076 total += ret.prob; 00077 printer.Word(word, vocab, ret); 00078 ++corpus_tokens; 00079 state = out; 00080 } 00081 // If people don't have a newline after their last query, this won't add a </s>. 00082 // Sue me. 00083 try { 00084 UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused."); 00085 } catch (const util::EndOfFileException &e) { break; } 00086 if (sentence_context) { 00087 ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out); 00088 total += ret.prob; 00089 ++corpus_tokens; 00090 printer.Word("</s>", model.GetVocabulary().EndSentence(), ret); 00091 } 00092 printer.Line(oov, total); 00093 corpus_total += total; 00094 corpus_oov += oov; 00095 } 00096 printer.Summary( 00097 pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))), // PPL including OOVs 00098 pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))), // PPL excluding OOVs 00099 corpus_oov, 00100 corpus_tokens); 00101 } 00102 00103 template <class Model> void Query(const char *file, const Config &config, bool sentence_context, QueryPrinter &printer) { 00104 Model model(file, config); 00105 Query<Model, QueryPrinter>(model, sentence_context, printer); 00106 } 00107 00108 } // namespace ngram 00109 } // namespace lm 00110 00111 #endif // LM_NGRAM_QUERY_H 00112 00113