//////////////////////////////////////////////////////////////////////////////// /// DISCLAIMER /// /// Copyright 2016 by EMC Corporation, All Rights Reserved /// /// Licensed under the Apache License, Version 2.0 (the "License"); /// you may not use this file except in compliance with the License. /// You may obtain a copy of the License at /// /// http://www.apache.org/licenses/LICENSE-2.0 /// /// Unless required by applicable law or agreed to in writing, software /// distributed under the License is distributed on an "AS IS" BASIS, /// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. /// See the License for the specific language governing permissions and /// limitations under the License. /// /// Copyright holder is EMC Corporation /// /// @author Andrey Abramov /// @author Vasiliy Nabatchikov //////////////////////////////////////////////////////////////////////////////// #if defined(_MSC_VER) #pragma warning(disable: 4101) #pragma warning(disable: 4267) #endif #include #if defined(_MSC_VER) #pragma warning(default: 4267) #pragma warning(default: 4101) #endif #include #include #include #if defined(_MSC_VER) #pragma warning(disable: 4229) #endif #include // for u_cleanup #if defined(_MSC_VER) #pragma warning(default: 4229) #endif #include "common.hpp" #include "analysis/analyzers.hpp" #include "analysis/token_attributes.hpp" #include "index/directory_reader.hpp" #include "search/bm25.hpp" #include "search/boolean_filter.hpp" #include "search/filter.hpp" #include "search/phrase_filter.hpp" #include "search/prefix_filter.hpp" #include "search/score.hpp" #include "search/term_filter.hpp" #include "store/fs_directory.hpp" #include "utils/memory_pool.hpp" #include "index-search.hpp" // std::regex support only starting from GCC 4.9 #if !defined(__GNUC__) || (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 8)) #include #else #include "boost/regex.hpp" NS_BEGIN(std) typedef ::boost::regex regex; typedef ::boost::match_results smatch; template bool regex_match(Args&&... args) { return ::boost::regex_match(std::forward(args)...); } NS_END // std #endif NS_LOCAL const std::string HELP = "help"; const std::string INDEX_DIR = "index-dir"; const std::string OUTPUT = "out"; const std::string INPUT = "in"; const std::string MAX = "max-tasks"; const std::string THR = "threads"; const std::string TOPN = "topN"; const std::string RND = "random"; const std::string RPT = "repeat"; const std::string CSV = "csv"; const std::string SCORED_TERMS_LIMIT = "scored-terms-limit"; const std::string SCORER = "scorer"; const std::string SCORER_ARG = "scorer-arg"; const std::string SCORER_ARG_FMT = "scorer-arg-format"; const std::string DIR_TYPE = "dir-type"; const std::string FORMAT = "format"; static bool v = false; typedef std::unique_ptr ustringp; NS_END struct Line { typedef std::shared_ptr ptr; std::string category; std::string text; Line(const std::string& c, const std::string& t): category(c), text(t) {} }; struct Task { std::string category; std::string text; typedef std::shared_ptr ptr; irs::filter::prepared::ptr prepared; int taskId; int totalHitCount; int topN; size_t tdiff_msec; std::thread::id tid; Task(std::string& s, std::string& t, int n, irs::filter::prepared::ptr p) : category(s), text(t), prepared(p), taskId(0), totalHitCount(0), topN(n) { } virtual ~Task() { } void go(std::thread::id id, irs::directory_reader& reader) { tid = id; auto start = std::chrono::system_clock::now(); query(reader); tdiff_msec = std::chrono::duration_cast( std::chrono::system_clock::now() - start ).count(); } virtual int query(irs::directory_reader& reader) = 0; virtual void print(std::ostream& out) = 0; virtual void print_csv(std::ostream& out) = 0; }; struct SearchTask : public Task { SearchTask(std::string& s, std::string& t, int n, irs::filter::prepared::ptr p) : Task(s, t, n, p) { } struct Entry { irs::doc_id_t id; float score; Entry(irs::doc_id_t i, float s) : id(i), score(s) { } }; std::vector top_docs; virtual int query(irs::directory_reader& reader) override { SCOPED_TIMER("Query execution + Result processing time"); for (auto& segment : reader) { // iterate segments irs::order order; order.add(true); auto prepared_order = order.prepare(); auto docs = prepared->execute(segment, prepared_order); // query segment const irs::score* score = docs->attributes().get().get(); auto comparer = [&prepared_order](const irs::bstring& lhs, const irs::bstring& rhs)->bool { return prepared_order.less(lhs.c_str(), rhs.c_str()); }; std::multimap sorted(comparer); // ensure we avoid COW for pre c++11 std::basic_string const irs::bytes_ref score_value = score->value(); while (docs->next()) { SCOPED_TIMER("Result processing time"); ++totalHitCount; score->evaluate(); sorted.emplace( std::piecewise_construct, std::forward_as_tuple(score_value), std::forward_as_tuple(docs->value(), score ? prepared_order.get(score_value.c_str(), 0) : .0) ); if (sorted.size() > topN) { sorted.erase(--(sorted.end())); } } for (auto& entry: sorted) { top_docs.emplace_back(std::move(entry.second)); } } return 0; } void print(std::ostream& out) override { out << "TASK: cat=" << category << " q='body:" << text << "' hits=" << totalHitCount << std::endl; out << " " << tdiff_msec / 1000. << " msec" << std::endl; out << " thread " << tid << std::endl; for (auto& doc : top_docs) { out << " doc=" << doc.id << " score=" << doc.score << std::endl; } out << std::endl; } void print_csv(std::ostream& out) override { out << category << "," << text << "," << totalHitCount << "," << tdiff_msec / 1000. << "," << tdiff_msec << std::endl; } }; class TaskSource { std::atomic idx; std::vector tasks; std::random_device rd; std::mt19937 g; int parseLines(std::string& line, Line::ptr& p) { static const std::regex m1("(\\S+): (.+)"); std::smatch res; std::string category; std::string text; if (std::regex_match(line, res, m1)) { category.assign(res[1].first, res[1].second); text.assign(res[2].first, res[2].second); p = Line::ptr(new Line(category, text)); return 0; } return -1; } int loadLines(std::vector& lines, std::istream& stream) { while (!stream.eof()) { std::string line; std::getline(stream, line); Line::ptr p; if (0 == parseLines(line, p)) { lines.push_back(p); } } return 0; } void shuffle(std::vector& line) { // @todo provide custom random? std::shuffle(line.begin(), line.end(), g); } static int pruneLines(std::vector& lines, std::vector& pruned_lines, int maxtasks) { std::map cat_counts; for (auto& t : lines) { std::map::iterator cat = cat_counts.find(t->category); int count = 0; if (cat != cat_counts.end()) { count = cat->second; } if (count < maxtasks) { ++count; if (cat != cat_counts.end()) { cat->second = count; } else { cat_counts[t->category] = count; } pruned_lines.push_back(t); } } return 0; } bool splitFreq(std::string& text, std::string& term) { static const std::regex freqPattern1("(\\S+)\\s*#\\s*(.+)"); // single term, prefix static const std::regex freqPattern2("\"(.+)\"\\s*#\\s*(.+)"); // phrase static const std::regex freqPattern3("((?:\\S+\\s+)+)\\s*#\\s*(.+)"); // AND/OR groups std::smatch res; if (std::regex_match(text, res, freqPattern1)) { term.assign(res[1].first, res[1].second); return true; } else if (std::regex_match(text, res, freqPattern2)) { term.assign(res[1].first, res[1].second); return true; } else if (std::regex_match(text, res, freqPattern3)) { term.assign(res[1].first, res[1].second); return true; } return false; } int prepareQueries(std::vector& lines, irs::directory_reader& reader, int topN) { irs::order order; order.add(true); auto ord = order.prepare(); for (auto& line : lines) { irs::filter::prepared::ptr prepared = nullptr; std::string terms; if (line->category == "HighTerm" || line->category == "MedTerm" || line->category == "LowTerm") { if (splitFreq(line->text, terms)) { irs::by_term query; query.field("body").term(terms); prepared = query.prepare(reader, ord); } } else if (line->category == "HighPhrase" || line->category == "MedPhrase" || line->category == "LowPhrase") { // @todo what's the difference between irs::by_phrase and irs::And? if (splitFreq(line->text, terms)) { std::istringstream f(terms); std::string term; irs::by_phrase query; query.field("body"); while (getline(f, term, ' ')) { query.push_back(term); } prepared = query.prepare(reader, ord); } } else if (line->category == "AndHighHigh" || line->category == "AndHighMed" || line->category == "AndHighLow") { if (splitFreq(line->text, terms)) { std::istringstream f(terms); std::string term; irs::And query; while (getline(f, term, ' ')) { irs::by_term& part = query.add(); part.field("body").term(term.c_str() + 1); // skip '+' at the start of the term } prepared = query.prepare(reader, ord); } } else if (line->category == "OrHighHigh" || line->category == "OrHighMed" || line->category == "OrHighLow") { if (splitFreq(line->text, terms)) { std::istringstream f(terms); std::string term; irs::Or query; while (getline(f, term, ' ')) { irs::by_term& part = query.add(); part.field("body").term(term); } prepared = query.prepare(reader, ord); } } else if (line->category == "Prefix3") { irs::by_prefix query; terms.assign(line->text.begin(), line->text.end() - 1); // cut '~' at the end of the text query.field("body").term(terms); prepared = query.prepare(reader, ord); } if (prepared != nullptr) { tasks.emplace_back(new SearchTask(line->category, terms, topN, prepared)); if (v) std::cout << tasks.size() << ": cat=" << line->category << "; term=" << terms << std::endl; } } if (v) std::cout << "Tasks prepared=" << tasks.size() << std::endl; return 0; } int repeatLines(std::vector& lines, std::vector& rep_lines, int repeat, bool do_shuffle) { while (repeat != 0) { if (do_shuffle) { shuffle(lines); } rep_lines.insert(std::end(rep_lines), std::begin(lines), std::end(lines)); --repeat; } return 0; } public: TaskSource() : idx(0), g(rd()) { } int load(std::istream& stream, int maxtasks, int repeat, irs::directory_reader& reader, int topN, bool do_shuffle) { /// /// this fn mimics lucene-util's LocalTaskSource behavior /// -- many similar tasks generated /// std::vector rep_lines; { std::vector pruned_lines; { std::vector lines; // parse all lines to category:text loadLines(lines, stream); // shuffle if (do_shuffle) { shuffle(lines); } // prune tasks pruneLines(lines, pruned_lines, maxtasks); } // multiply pruned with shuffling repeatLines(pruned_lines, rep_lines, repeat, do_shuffle); } // prepare queries prepareQueries(rep_lines, reader, topN); return 0; } Task::ptr next() { int next = idx++; // atomic get and increment if (next < tasks.size()) { return tasks[next]; } return nullptr; } std::vector& getTasks() { return tasks; } }; class TaskThread { public: typedef std::shared_ptr ptr; struct Args { TaskSource& tasks; irs::directory_reader& reader; Args(TaskSource& t, irs::directory_reader& r) : tasks(t), reader(r) { } }; private: std::thread* thr; void worker(Args a) { auto task = a.tasks.next(); while (task != nullptr) { task->go(thr->get_id(), a.reader); task = a.tasks.next(); } } public: void start(Args a) { thr = new std::thread(std::bind(&TaskThread::worker, this, a)); } void join() { thr->join(); delete thr; } }; enum class category_t { HighTerm, MedTerm, LowTerm, HighPhrase, MedPhrase, LowPhrase, AndHighHigh, AndHighMed, AndHighLow, OrHighHigh, OrHighMed, OrHighLow, Prefix3, UNKNOWN }; struct task_t { category_t category; std::string text; task_t(category_t v_category, std::string&& v_text): category(v_category), text(std::move(v_text)) {} }; category_t parseCategory(const irs::string_ref& value) { if (value == "HighTerm") return category_t::HighTerm; if (value == "MedTerm") return category_t::MedTerm; if (value == "LowTerm") return category_t::LowTerm; if (value == "HighPhrase") return category_t::HighPhrase; if (value == "MedPhrase") return category_t::MedPhrase; if (value == "LowPhrase") return category_t::LowPhrase; if (value == "AndHighHigh") return category_t::AndHighHigh; if (value == "AndHighMed") return category_t::AndHighMed; if (value == "AndHighLow") return category_t::AndHighLow; if (value == "OrHighHigh") return category_t::OrHighHigh; if (value == "OrHighMed") return category_t::OrHighMed; if (value == "OrHighLow") return category_t::OrHighLow; if (value == "Prefix3") return category_t::Prefix3; return category_t::UNKNOWN; } irs::string_ref stringCategory(category_t category) { switch(category) { case category_t::HighTerm: return "HighTerm"; case category_t::MedTerm: return "MedTerm"; case category_t::LowTerm: return "LowTerm"; case category_t::HighPhrase: return "HighPhrase"; case category_t::MedPhrase: return "MedPhrase"; case category_t::LowPhrase: return "LowPhrase"; case category_t::AndHighHigh: return "AndHighHigh"; case category_t::AndHighMed: return "AndHighMed"; case category_t::AndHighLow: return "AndHighLow"; case category_t::OrHighHigh: return "OrHighHigh"; case category_t::OrHighMed: return "OrHighMed"; case category_t::OrHighLow: return "OrHighLow"; case category_t::Prefix3: return "Prefix3"; default: return ""; } } irs::string_ref splitFreq(const std::string& text) { static const std::regex freqPattern1("(\\S+)\\s*#\\s*(.+)"); // single term, prefix static const std::regex freqPattern2("\"(.+)\"\\s*#\\s*(.+)"); // phrase static const std::regex freqPattern3("((?:\\S+\\s+)+)\\s*#\\s*(.+)"); // AND/OR groups std::smatch res; if (std::regex_match(text, res, freqPattern1)) { return irs::string_ref(&*(res[1].first), std::distance(res[1].first, res[1].second)); } else if (std::regex_match(text, res, freqPattern2)) { return irs::string_ref(&*(res[1].first), std::distance(res[1].first, res[1].second)); } else if (std::regex_match(text, res, freqPattern3)) { return irs::string_ref(&*(res[1].first), std::distance(res[1].first, res[1].second)); } return irs::string_ref::NIL; } irs::filter::prepared::ptr prepareFilter( const irs::directory_reader& reader, const irs::order::prepared& order, category_t category, const std::string& text, const irs::analysis::analyzer::ptr& analyzer, std::string& tmpBuf, size_t scored_terms_limit ) { irs::string_ref terms; switch (category) { case category_t::HighTerm: // fall through case category_t::MedTerm: // fall through case category_t::LowTerm: { if ((terms = splitFreq(text)).null()) { return nullptr; } irs::by_term query; query.field("body").term(terms); return query.prepare(reader, order); } case category_t::HighPhrase: // fall through case category_t::MedPhrase: // fall through case category_t::LowPhrase: { if ((terms = splitFreq(text)).null()) { return nullptr; } irs::by_phrase query; query.field("body"); analyzer->reset(terms); for (auto& term = analyzer->attributes().get(); analyzer->next();) { query.push_back(term->value()); } return query.prepare(reader, order); } case category_t::AndHighHigh: // fall through case category_t::AndHighMed: // fall through case category_t::AndHighLow: { if ((terms = splitFreq(text)).null()) { return nullptr; } irs::And query; for (std::istringstream in(terms); std::getline(in, tmpBuf, ' ');) { query.add().field("body").term(tmpBuf.c_str() + 1); // +1 for skip '+' at the start of the term } return query.prepare(reader, order); } case category_t::OrHighHigh: // fall through case category_t::OrHighMed: // fall through case category_t::OrHighLow: { if ((terms = splitFreq(text)).null()) { return nullptr; } irs::Or query; for (std::istringstream in(terms); std::getline(in, tmpBuf, ' ');) { query.add().field("body").term(tmpBuf); } return query.prepare(reader, order); } case category_t::Prefix3: { irs::by_prefix query; query.scored_terms_limit(scored_terms_limit); terms = irs::string_ref(text, text.size() - 1); // cut '~' at the end of the text query.field("body").term(terms); return query.prepare(reader, order); } } return nullptr; } void prepareTasks(std::vector& buf, std::istream& in, size_t tasks_per_category) { std::map category_counts; std::string tmpBuf; // parse all lines to category:text while (!in.eof()) { static const std::regex m1("(\\S+): (.+)"); std::smatch res; std::getline(in, tmpBuf); if (std::regex_match(tmpBuf, res, m1)) { auto category = parseCategory(irs::string_ref(&*(res[1].first), std::distance(res[1].first, res[1].second))); auto& count = category_counts.emplace(category, 0).first->second; if (++count <= tasks_per_category) { buf.emplace_back(category, std::string(res[2].first, res[2].second)); } } } } static int testQueries(std::vector& tasks, irs::directory_reader& reader) { for (auto& segment : reader) { // iterate segments int cnt = 0; for (auto& task : tasks) { ++cnt; std::cout << "running query=" << cnt << std::endl; auto& query = task->prepared; auto docs = query->execute(segment); // query segment while (docs->next()) { const irs::doc_id_t doc_id = docs->value(); // get doc id std::cout << cnt << " : " << doc_id << std::endl; } } } return 0; } static int printResults(std::vector& tasks, std::ostream& out, bool csv) { for (auto& task : tasks) { csv ? task->print_csv(out) : task->print(out); } return 0; } int search( const std::string& path, const std::string& dir_type, const std::string& format, std::istream& in, std::ostream& out, size_t tasks_max, size_t repeat, size_t search_threads, size_t limit, bool shuffle, bool csv, size_t scored_terms_limit, const std::string& scorer, const std::string& scorer_arg_format, const irs::string_ref& scorer_arg ) { static const std::map text_formats = { { "csv", irs::text_format::csv }, { "json", irs::text_format::json }, { "text", irs::text_format::text }, { "xml", irs::text_format::xml }, }; auto arg_format_itr = text_formats.find(scorer_arg_format); if (arg_format_itr == text_formats.end()) { std::cerr << "Unknown scorer argument format '" << scorer_arg_format << "'" << std::endl; return 1; } auto scr = irs::scorers::get(scorer, arg_format_itr->second, scorer_arg); if (!scr) { if (scorer_arg.null()) { std::cerr << "Unable to instantiate scorer '" << scorer << "' with argument format '" << scorer_arg_format << "' with nil arguments" << std::endl; } else { std::cerr << "Unable to instantiate scorer '" << scorer << "' with argument format '" << scorer_arg_format << "' with arguments '" << scorer_arg << "'" << std::endl; } return 1; } auto dir = create_directory(dir_type, path); if (!dir) { std::cerr << "Unable to create directory of type '" << dir_type << "'" << std::endl; return 1; } auto codec = irs::formats::get(format); if (!codec) { std::cerr << "Unable to find format of type '" << format << "'" << std::endl; return 1; } repeat = (std::max)(size_t(1), repeat); search_threads = (std::max)(size_t(1), search_threads); scored_terms_limit = (std::max)(size_t(1), scored_terms_limit); SCOPED_TIMER("Total Time"); std::cout << "Configuration: " << std::endl; std::cout << INDEX_DIR << "=" << path << std::endl; std::cout << MAX << "=" << tasks_max << std::endl; std::cout << RPT << "=" << repeat << std::endl; std::cout << THR << "=" << search_threads << std::endl; std::cout << TOPN << "=" << limit << std::endl; std::cout << RND << "=" << shuffle << std::endl; std::cout << CSV << "=" << csv << std::endl; std::cout << SCORED_TERMS_LIMIT << "=" << scored_terms_limit << std::endl; std::cout << SCORER << "=" << scorer << std::endl; std::cout << SCORER_ARG_FMT << "=" << scorer_arg_format << std::endl; std::cout << SCORER_ARG << "=" << scorer_arg << std::endl; irs::directory_reader reader; irs::order::prepared order; irs::async_utils::thread_pool thread_pool(search_threads); { SCOPED_TIMER("Index read time"); reader = irs::directory_reader::open(*dir, codec); } { SCOPED_TIMER("Order build time"); irs::order sort; sort.add(true, scr); order = sort.prepare(); } struct task_provider_t { std::mutex mutex; size_t next_task; std::mt19937 randomizer; size_t repeat; bool shuffle; std::vector tasks; std::vector task_ids; task_provider_t(): repeat(0), shuffle(false) {} void operator=(std::vector&& lines) { next_task = 0; tasks = std::move(lines); task_ids.resize(tasks.size()); for (size_t i = 0, count = task_ids.size(); i < count; ++i) { task_ids[i] = i; } // initial shuffle if (shuffle) { std::shuffle(task_ids.begin(), task_ids.end(), randomizer); } } const task_t* operator++() { SCOPED_LOCK(mutex); if (next_task >= task_ids.size()) { return nullptr; } auto& task = tasks[task_ids[next_task++]]; // prepare tasks for next iteration if repeat requested if (next_task >= task_ids.size() && repeat) { next_task = 0; --repeat; // shuffle if (shuffle) { std::shuffle(task_ids.begin(), task_ids.end(), randomizer); } } return &task; } } task_provider; // prepare tasks set { std::vector tasks; prepareTasks(tasks, in, tasks_max); task_provider.repeat = repeat - 1; // -1 for first run (i.e. additional repeats) task_provider.shuffle = shuffle; task_provider = std::move(tasks); } struct Entry { Entry(irs::doc_id_t i, float s) : id(i), score(s) { } irs::doc_id_t id; float score; }; // indexer threads for (size_t i = search_threads; i; --i) { thread_pool.run([&task_provider, &dir, &reader, &order, limit, &out, csv, scored_terms_limit]()->void { static const std::string analyzer_name("text"); static const std::string analyzer_args("{\"locale\":\"en\", \"stopwords\":[\"abc\", \"def\", \"ghi\"]}"); // from index-put auto analyzer = irs::analysis::analyzers::get(analyzer_name, irs::text_format::json, analyzer_args); irs::filter::prepared::ptr filter; std::string tmpBuf; #ifdef IRESEARCH_COMPLEX_SCORING #if defined(_MSC_VER) && defined(IRESEARCH_DEBUG) typedef irs::memory::memory_multi_size_pool pool_t; typedef irs::memory::memory_pool_multi_size_allocator alloc_t; #else typedef irs::memory::memory_pool pool_t; typedef irs::memory::memory_pool_allocator alloc_t; #endif pool_t pool(limit + 1); // +1 for least significant overflow element auto comparer = [&order](const irs::bstring& lhs, const irs::bstring& rhs)->bool { return order.less(lhs.c_str(), rhs.c_str()); }; std::multimap sorted( comparer, alloc_t{pool} ); #else std::vector> sorted; sorted.reserve(limit + 1); // +1 for least significant overflow element #endif // process a single task for (const task_t* task; (task = ++task_provider) != nullptr;) { SCOPED_TIMER("Full task processing time"); size_t doc_count = 0; auto start = std::chrono::system_clock::now(); sorted.clear(); // parse task { static struct timers_t { std::vector stat; timers_t() { for (size_t i = 0, count = size_t(category_t::UNKNOWN); i <= count; ++i) { stat.emplace_back(&irs::timer_utils::get_stat(std::string("Query building (") + stringCategory(category_t(i)).c_str() + ") time")); } } } timers; SCOPED_TIMER("Query building time"); irs::timer_utils::scoped_timer timer(*(timers.stat[size_t(task->category)])); filter = prepareFilter(reader, order, task->category, task->text, analyzer, tmpBuf, scored_terms_limit); if (!filter) { continue; } } // execute task { static struct timers_t { std::vector stat; timers_t() { for (size_t i = 0, count = size_t(category_t::UNKNOWN); i <= count; ++i) { stat.emplace_back(&irs::timer_utils::get_stat(std::string("Query execution (") + stringCategory(category_t(i)).c_str() + ") time")); } } } timers; SCOPED_TIMER("Query execution time"); irs::timer_utils::scoped_timer timer(*(timers.stat[size_t(task->category)])); const float EMPTY_SCORE = 0.f; for (auto& segment: reader) { auto docs = filter->execute(segment, order); // query segment auto& attributes = docs->attributes(); const irs::score& score = irs::score::extract(attributes); const irs::document* doc = attributes.get().get(); #ifdef IRESEARCH_COMPLEX_SCORING // ensure we avoid COW for pre c++11 std::basic_string const irs::bytes_ref raw_score_value = score->value(); #endif const auto& score_value = &score != &irs::score::no_score() ? order.get(score.c_str(), 0) : EMPTY_SCORE; while (docs->next()) { ++doc_count; score.evaluate(); #ifdef IRESEARCH_COMPLEX_SCORING sorted.emplace( std::piecewise_construct, std::forward_as_tuple(raw_score_value), std::forward_as_tuple(docs->value(), score_value) ); sorted.emplace(score_value, doc->value); if (sorted.size() > limit) { sorted.erase(--(sorted.end())); } #else std::push_heap( sorted.begin(), sorted.end(), [](const std::pair& lhs, const std::pair& rhs) NOEXCEPT { return lhs.first < rhs.first; }); if (sorted.size() > limit) { std::pop_heap( sorted.begin(), sorted.end(), [](const std::pair& lhs, const std::pair& rhs) NOEXCEPT { return lhs.first < rhs.first; }); sorted.pop_back(); } auto end = sorted.end(); for (auto begin = sorted.begin(); begin != end; --end) { std::pop_heap( begin, end, [](const std::pair& lhs, const std::pair& rhs) NOEXCEPT { return lhs.first < rhs.first; }); } #endif } } } // output task results { static std::mutex mutex; SCOPED_LOCK(mutex); SCOPED_TIMER("Result processing time"); auto tdiff = std::chrono::duration_cast(std::chrono::system_clock::now() - start); if (csv) { out << stringCategory(task->category) << "," << task->text << "," << doc_count << "," << tdiff.count() / 1000. << "," << tdiff.count() << std::endl; } else { out << "TASK: cat=" << stringCategory(task->category) << " q='body:" << task->text << "' hits=" << doc_count << std::endl; out << " " << tdiff.count() / 1000. << " msec" << std::endl; out << " thread " << std::this_thread::get_id() << std::endl; for (auto& entry : sorted) { #ifdef IRESEARCH_COMPLEX_SCORING out << " doc=" << entry.second.id << " score=" << entry.second.score << std::endl; #else out << " doc=" << entry.second << " score=" << entry.first<< std::endl; #endif } out << std::endl; } } } }); } thread_pool.stop(); u_cleanup(); return 0; } int search(const cmdline::parser& args) { if (!args.exist(INDEX_DIR) || !args.exist(INPUT)) { return 1; } const auto& path = args.get(INDEX_DIR); if (path.empty()) { return 1; } const size_t maxtasks = args.get(MAX); const size_t repeat = args.get(RPT); const bool shuffle = args.exist(RND); const size_t thrs = args.get(THR); const size_t topN = args.get(TOPN); const bool csv = args.exist(CSV); const size_t scored_terms_limit = args.get(SCORED_TERMS_LIMIT); const auto scorer = args.get(SCORER); const auto scorer_arg = args.exist(SCORER_ARG) ? irs::string_ref(args.get(SCORER_ARG)) : irs::string_ref::NIL; const auto scorer_arg_format = args.get(SCORER_ARG_FMT); const auto dir_type = args.exist(DIR_TYPE) ? args.get(DIR_TYPE) : std::string("fs"); const auto format = args.exist(FORMAT) ? args.get(FORMAT) : std::string("1_0"); std::cout << "Max tasks in category=" << maxtasks << '\n' << "Task repeat count=" << repeat << '\n' << "Do task list shuffle=" << shuffle << '\n' << "Search threads=" << thrs << '\n' << "Number of top documents to collect=" << topN << '\n' << "Number of terms to in range/prefix queries=" << scored_terms_limit << '\n' << "Scorer used for ranking query results=" << scorer << '\n' << "Configuration argument format for query scorer=" << scorer_arg_format << '\n' << "Configuration argument for query scorer=" << scorer_arg << '\n' << "Output CSV=" << csv << std::endl; std::fstream in(args.get(INPUT), std::fstream::in); if (!in) { return 1; } if (args.exist(OUTPUT)) { std::fstream out( args.get(OUTPUT), std::fstream::out | std::fstream::trunc ); if (!out) { return 1; } return search(path, dir_type, format, in, out, maxtasks, repeat, thrs, topN, shuffle, csv, scored_terms_limit, scorer, scorer_arg_format, scorer_arg); } return search(path, dir_type, format, in, std::cout, maxtasks, repeat, thrs, topN, shuffle, csv, scored_terms_limit, scorer, scorer_arg_format, scorer_arg); } int search(int argc, char* argv[]) { // mode search cmdline::parser cmdsearch; cmdsearch.add(HELP, '?', "Produce help message"); cmdsearch.add(INDEX_DIR, 0, "Path to index directory", true); cmdsearch.add(DIR_TYPE, 0, "Directory type (fs|mmap)", false, std::string("fs")); cmdsearch.add(FORMAT, 0, "Format (1_0|1_0-optimized)", false, std::string("1_0")); cmdsearch.add(INPUT, 0, "Task file", true); cmdsearch.add(OUTPUT, 0, "Stats file", false); cmdsearch.add(MAX, 0, "Maximum tasks per category", false, size_t(1)); cmdsearch.add(RPT, 0, "Task repeat count", false, size_t(20)); cmdsearch.add(THR, 0, "Number of search threads", false, size_t(1)); cmdsearch.add(TOPN, 0, "Number of top search results", false, size_t(10)); cmdsearch.add(SCORED_TERMS_LIMIT, 0, "Number of terms to score in range/prefix queries", false, size_t(1024)); cmdsearch.add(SCORER, 0, "Scorer used for ranking query results", false, "bm25"); cmdsearch.add(SCORER_ARG, 0, "Configuration argument for query scorer", false); cmdsearch.add(SCORER_ARG_FMT, 0, "Configuration argument format for query scorer", false, "json"); // 'json' is the argument format for 'bm25' cmdsearch.add(RND, 0, "Shuffle tasks"); cmdsearch.add(CSV, 0, "CSV output"); cmdsearch.parse(argc, argv); if (cmdsearch.exist(HELP)) { std::cout << cmdsearch.usage() << std::endl; return 0; } return search(cmdsearch); } // ----------------------------------------------------------------------------- // --SECTION-- END-OF-FILE // -----------------------------------------------------------------------------