Joshua
open source statistical hierarchical phrase-based machine translation system
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
src/kenlm/lm/ngram_query.hh
00001 #ifndef LM_NGRAM_QUERY_H
00002 #define LM_NGRAM_QUERY_H
00003 
00004 #include "lm/enumerate_vocab.hh"
00005 #include "lm/model.hh"
00006 #include "util/file_stream.hh"
00007 #include "util/file_piece.hh"
00008 #include "util/usage.hh"
00009 
00010 #include <cstdlib>
00011 #include <string>
00012 #include <cmath>
00013 
00014 namespace lm {
00015 namespace ngram {
00016 
00017 class QueryPrinter {
00018   public:
00019     QueryPrinter(int fd, bool print_word, bool print_line, bool print_summary, bool flush)
00020       : out_(fd), print_word_(print_word), print_line_(print_line), print_summary_(print_summary), flush_(flush) {}
00021 
00022     void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) {
00023       if (!print_word_) return;
00024       out_ << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length)  << ' ' << ret.prob << '\t';
00025       if (flush_) out_.flush();
00026     }
00027 
00028     void Line(uint64_t oov, float total) {
00029       if (!print_line_) return;
00030       out_ << "Total: " << total << " OOV: " << oov << '\n';
00031       if (flush_) out_.flush();
00032     }
00033 
00034     void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
00035       if (!print_summary_) return;
00036       out_ <<
00037         "Perplexity including OOVs:\t" << ppl_including_oov << "\n"
00038         "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
00039         "OOVs:\t" << corpus_oov << "\n"
00040         "Tokens:\t" << corpus_tokens << '\n';
00041       out_.flush();
00042     }
00043 
00044   private:
00045     util::FileStream out_;
00046     bool print_word_;
00047     bool print_line_;
00048     bool print_summary_;
00049     bool flush_;
00050 };
00051 
00052 template <class Model, class Printer> void Query(const Model &model, bool sentence_context, Printer &printer) {
00053   typename Model::State state, out;
00054   lm::FullScoreReturn ret;
00055   StringPiece word;
00056 
00057   util::FilePiece in(0);
00058 
00059   double corpus_total = 0.0;
00060   double corpus_total_oov_only = 0.0;
00061   uint64_t corpus_oov = 0;
00062   uint64_t corpus_tokens = 0;
00063 
00064   while (true) {
00065     state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
00066     float total = 0.0;
00067     uint64_t oov = 0;
00068 
00069     while (in.ReadWordSameLine(word)) {
00070       lm::WordIndex vocab = model.GetVocabulary().Index(word);
00071       ret = model.FullScore(state, vocab, out);
00072       if (vocab == model.GetVocabulary().NotFound()) {
00073         ++oov;
00074         corpus_total_oov_only += ret.prob;
00075       }
00076       total += ret.prob;
00077       printer.Word(word, vocab, ret);
00078       ++corpus_tokens;
00079       state = out;
00080     }
00081     // If people don't have a newline after their last query, this won't add a </s>.
00082     // Sue me.
00083     try {
00084       UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused.");
00085     } catch (const util::EndOfFileException &e) { break; }
00086     if (sentence_context) {
00087       ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
00088       total += ret.prob;
00089       ++corpus_tokens;
00090       printer.Word("</s>", model.GetVocabulary().EndSentence(), ret);
00091     }
00092     printer.Line(oov, total);
00093     corpus_total += total;
00094     corpus_oov += oov;
00095   }
00096   printer.Summary(
00097       pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))), // PPL including OOVs
00098       pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))), // PPL excluding OOVs
00099       corpus_oov,
00100       corpus_tokens);
00101 }
00102 
00103 template <class Model> void Query(const char *file, const Config &config, bool sentence_context, QueryPrinter &printer) {
00104   Model model(file, config);
00105   Query<Model, QueryPrinter>(model, sentence_context, printer);
00106 }
00107 
00108 } // namespace ngram
00109 } // namespace lm
00110 
00111 #endif // LM_NGRAM_QUERY_H
00112 
00113