mirror of https://gitee.com/bigwinds/arangodb
413 lines
11 KiB
C++
413 lines
11 KiB
C++
////////////////////////////////////////////////////////////////////////////////
|
|
/// DISCLAIMER
|
|
///
|
|
/// Copyright 2016 by EMC Corporation, All Rights Reserved
|
|
///
|
|
/// Licensed under the Apache License, Version 2.0 (the "License");
|
|
/// you may not use this file except in compliance with the License.
|
|
/// You may obtain a copy of the License at
|
|
///
|
|
/// http://www.apache.org/licenses/LICENSE-2.0
|
|
///
|
|
/// Unless required by applicable law or agreed to in writing, software
|
|
/// distributed under the License is distributed on an "AS IS" BASIS,
|
|
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
/// See the License for the specific language governing permissions and
|
|
/// limitations under the License.
|
|
///
|
|
/// Copyright holder is EMC Corporation
|
|
///
|
|
/// @author Andrey Abramov
|
|
/// @author Vasiliy Nabatchikov
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#include <rapidjson/rapidjson/document.h> // for rapidjson::Document
|
|
|
|
#include "tfidf.hpp"
|
|
|
|
#include "scorers.hpp"
|
|
#include "analysis/token_attributes.hpp"
|
|
#include "index/index_reader.hpp"
|
|
#include "index/field_meta.hpp"
|
|
|
|
NS_LOCAL
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief create a tfidf scorer from a JSON document whose root is a boolean
/// @param json parsed arguments, the root value must be a boolean (asserted)
/// @return a new irs::tfidf_sort constructed from that boolean
////////////////////////////////////////////////////////////////////////////////
irs::sort::ptr make_from_bool(
    const rapidjson::Document& json,
    const irs::string_ref& /*args*/) {
  assert(json.IsBool());

  // the boolean root is forwarded as the only constructor argument
  const bool with_norms = json.GetBool();

  return irs::memory::make_shared<irs::tfidf_sort>(with_norms);
}
|
|
|
|
// name of the optional boolean member that make_from_object() reads from JSON object arguments
static const irs::string_ref WITH_NORMS_PARAM_NAME = "withNorms";
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief create a tfidf scorer from a JSON object
/// @param json parsed arguments, the root value must be an object (asserted)
/// @param args raw argument string, used only for error reporting
/// @return a new scorer, or nullptr if the optional 'withNorms' member is
///         present but is not a boolean
////////////////////////////////////////////////////////////////////////////////
irs::sort::ptr make_from_object(
    const rapidjson::Document& json,
    const irs::string_ref& args) {
  assert(json.IsObject());

  auto ptr = irs::memory::make_shared<irs::tfidf_sort>();

#ifdef IRESEARCH_DEBUG
  auto& scorer = dynamic_cast<irs::tfidf_sort&>(*ptr);
#else
  auto& scorer = static_cast<irs::tfidf_sort&>(*ptr);
#endif

  // optional bool: a single lookup replaces the former HasMember() + two
  // operator[] calls, each of which searched the object again
  const auto with_norms = json.FindMember(WITH_NORMS_PARAM_NAME.c_str());

  if (with_norms != json.MemberEnd()) {
    if (!with_norms->value.IsBool()) {
      IR_FRMT_ERROR("Non-boolean value in '%s' while constructing tfidf scorer from jSON arguments: %s",
                    WITH_NORMS_PARAM_NAME.c_str(),
                    args.c_str());

      return nullptr;
    }

    scorer.normalize(with_norms->value.GetBool());
  }

  return ptr;
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief create a tfidf scorer from a JSON array of positional arguments
/// @param json parsed arguments, the root value must be an array (asserted)
/// @param args raw argument string, used only for error reporting
/// @return a new scorer, or nullptr if the array holds more than one element
///         or its first element is not a boolean
/// @note position 0 is the optional boolean passed to the tfidf_sort
///       constructor; when the array is empty tfidf_sort::WITH_NORMS() is used
////////////////////////////////////////////////////////////////////////////////
irs::sort::ptr make_from_array(
    const rapidjson::Document& json,
    const irs::string_ref& args) {
  assert(json.IsArray());

  const auto array = json.GetArray();
  const auto size = array.Size();

  if (size > 1) {
    // wrong number of arguments
    IR_FRMT_ERROR(
      "Wrong number of arguments while constructing tfidf scorer from jSON arguments (must be <= 1): %s",
      args.c_str()
    );
    return nullptr;
  }

  // default args
  auto norms = irs::tfidf_sort::WITH_NORMS();

  // parse `withNorms` optional argument
  if (!array.Empty()) {
    auto& arg = array[0];

    if (!arg.IsBool()) {
      // the message used to be a copy of the bm25 one ("Non-float ... bm25 scorer"),
      // which misreported both the expected type and the scorer being built
      IR_FRMT_ERROR(
        "Non-boolean value on position `0` while constructing tfidf scorer from jSON arguments: %s",
        args.c_str()
      );
      return nullptr;
    }

    norms = arg.GetBool();
  }

  return irs::memory::make_shared<irs::tfidf_sort>(norms);
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief create a tfidf scorer from a JSON encoded argument string
/// @param args JSON text; a null string_ref yields a default-constructed scorer
/// @return a new scorer, or nullptr if the text fails to parse or its root
///         value is neither a boolean, an object nor an array
////////////////////////////////////////////////////////////////////////////////
irs::sort::ptr make_json(const irs::string_ref& args) {
  // no arguments at all: default scorer
  if (args.null()) {
    return irs::memory::make_shared<irs::tfidf_sort>();
  }

  rapidjson::Document json;

  json.Parse(args.c_str(), args.size());

  if (json.HasParseError()) {
    IR_FRMT_ERROR(
      "Invalid jSON arguments passed while constructing tfidf scorer, arguments: %s",
      args.c_str()
    );

    return nullptr;
  }

  // dispatch on the type of the root value
  if (json.IsBool()) {
    return make_from_bool(json, args);
  }

  if (json.IsObject()) {
    return make_from_object(json, args);
  }

  if (json.IsArray()) {
    return make_from_array(json, args);
  }

  // wrong type
  IR_FRMT_ERROR(
    "Invalid jSON arguments passed while constructing tfidf scorer, arguments: %s",
    args.c_str()
  );

  return nullptr;
}
|
|
|
|
// associates make_json with irs::tfidf_sort as its JSON-argument factory
// NOTE(review): the exact registration semantics live in the REGISTER_SCORER_JSON macro, which is defined elsewhere
REGISTER_SCORER_JSON(irs::tfidf_sort, make_json);
|
|
|
|
struct byte_ref_iterator
|
|
: public std::iterator<std::input_iterator_tag, irs::byte_type, void, void, void> {
|
|
const irs::byte_type* end_;
|
|
const irs::byte_type* pos_;
|
|
byte_ref_iterator(const irs::bytes_ref& in)
|
|
: end_(in.c_str() + in.size()), pos_(in.c_str()) {
|
|
}
|
|
|
|
irs::byte_type operator*() {
|
|
if (pos_ >= end_) {
|
|
throw irs::io_error("invalid read past end of input");
|
|
}
|
|
|
|
return *pos_;
|
|
}
|
|
|
|
void operator++() { ++pos_; }
|
|
};
|
|
|
|
struct field_collector final: public irs::sort::field_collector {
|
|
uint64_t docs_with_field = 0; // number of documents containing the matched field (possibly without matching terms)
|
|
|
|
virtual void collect(
|
|
const irs::sub_reader& segment,
|
|
const irs::term_reader& field
|
|
) override {
|
|
docs_with_field += field.docs_count();
|
|
}
|
|
|
|
virtual void collect(const irs::bytes_ref& in) override {
|
|
byte_ref_iterator itr(in);
|
|
auto docs_with_field_value = irs::vread<uint64_t>(itr);
|
|
|
|
if (itr.pos_ != itr.end_) {
|
|
throw irs::io_error("input not read fully");
|
|
}
|
|
|
|
docs_with_field += docs_with_field_value;
|
|
}
|
|
|
|
virtual void write(irs::data_output& out) const override {
|
|
out.write_vlong(docs_with_field);
|
|
}
|
|
};
|
|
|
|
struct term_collector final: public irs::sort::term_collector {
|
|
uint64_t docs_with_term = 0; // number of documents containing the matched term
|
|
|
|
virtual void collect(
|
|
const irs::sub_reader& segment,
|
|
const irs::term_reader& field,
|
|
const irs::attribute_view& term_attrs
|
|
) override {
|
|
auto& meta = term_attrs.get<irs::term_meta>();
|
|
|
|
if (meta) {
|
|
docs_with_term += meta->docs_count;
|
|
}
|
|
}
|
|
|
|
virtual void collect(const irs::bytes_ref& in) override {
|
|
byte_ref_iterator itr(in);
|
|
auto docs_with_term_value = irs::vread<uint64_t>(itr);
|
|
|
|
if (itr.pos_ != itr.end_) {
|
|
throw irs::io_error("input not read fully");
|
|
}
|
|
|
|
docs_with_term += docs_with_term_value;
|
|
}
|
|
|
|
virtual void write(irs::data_output& out) const override {
|
|
out.write_vlong(docs_with_term);
|
|
}
|
|
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief tf-idf value for a single term: sqrt(freq) scaled by 'idf'
/// @param freq term frequency
/// @param idf multiplier applied to the square root of the frequency
////////////////////////////////////////////////////////////////////////////////
FORCE_INLINE float_t tfidf(float_t freq, float_t idf) NOEXCEPT {
  // guard against std::sqrt silently selecting an overload of another width
  static_assert(
    std::is_same<decltype(std::sqrt(freq)), float_t>::value,
    "float_t expected"
  );

  return std::sqrt(freq) * idf;
}
|
|
|
|
NS_END // LOCAL
|
|
|
|
NS_ROOT
|
|
NS_BEGIN(tfidf)
|
|
|
|
// empty frequency
|
|
// default-constructed frequency that score_ctx points at when its constructor receives a null frequency pointer
const frequency EMPTY_FREQ;
|
|
|
|
// statistic kept in the scorer's stats buffer: sort::collect(...) adds to 'value'
// and the score contexts multiply it by the boost
struct idf final : attribute {
  float_t value = 0.f;
};
|
|
|
|
// score value type, taken from tfidf_sort; the prepared tfidf::sort below writes scores of this type
typedef tfidf_sort::score_t score_t;
|
|
|
|
struct const_score_ctx final : public irs::sort::score_ctx {
|
|
explicit const_score_ctx(irs::boost_t boost) NOEXCEPT
|
|
: boost_(boost) {
|
|
}
|
|
|
|
const irs::boost_t boost_;
|
|
}; // const_score_ctx
|
|
|
|
struct score_ctx : public irs::sort::score_ctx {
|
|
score_ctx(
|
|
irs::boost_t boost,
|
|
const tfidf::idf& idf,
|
|
const frequency* freq) NOEXCEPT
|
|
: idf_(boost * idf.value),
|
|
freq_(freq ? freq : &EMPTY_FREQ) {
|
|
assert(freq_);
|
|
}
|
|
|
|
float_t idf_; // precomputed : boost * idf
|
|
const frequency* freq_;
|
|
}; // score_ctx
|
|
|
|
struct norm_score_ctx final : public score_ctx {
|
|
norm_score_ctx(
|
|
irs::norm&& norm,
|
|
irs::boost_t boost,
|
|
const tfidf::idf& idf,
|
|
const frequency* freq) NOEXCEPT
|
|
: score_ctx(boost, idf, freq),
|
|
norm_(std::move(norm)) {
|
|
}
|
|
|
|
irs::norm norm_;
|
|
}; // norm_score_ctx
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief prepared tf-idf scorer: keeps a tfidf::idf statistic in the stats
///        buffer and produces scores of type tfidf::score_t, optionally
///        multiplied by the field norm
////////////////////////////////////////////////////////////////////////////////
class sort final: public irs::sort::prepared_basic<tfidf::score_t, tfidf::idf> {
 public:
  DEFINE_FACTORY_INLINE(prepared)

  // 'normalize' selects whether scores are multiplied by the field norm
  explicit sort(bool normalize) NOEXCEPT
    : normalize_(normalize) {
  }

  // adds log((docs_with_field + 1) / (docs_with_term + 1)) + 1 to the idf
  // statistic stored in 'stats_buf'
  virtual void collect(
      byte_type* stats_buf,
      const irs::index_reader& index,
      const irs::sort::field_collector* field,
      const irs::sort::term_collector* term
  ) const override {
    auto& stats = stats_cast(stats_buf);

#ifdef IRESEARCH_DEBUG
    auto* field_stats = dynamic_cast<const field_collector*>(field);
    assert(!field || field_stats);
    auto* term_stats = dynamic_cast<const term_collector*>(term);
    assert(!term || term_stats);
#else
    auto* field_stats = static_cast<const field_collector*>(field);
    auto* term_stats = static_cast<const term_collector*>(term);
#endif

    // nullptr possible if e.g. 'all' filter
    const auto field_docs = field_stats ? field_stats->docs_with_field : 0;
    // nullptr possible if e.g.'by_column_existence' filter
    const auto term_docs = term_stats ? term_stats->docs_with_term : 0;

    const auto ratio = (field_docs + 1) / double_t(term_docs + 1);

    stats.value += float_t(std::log(ratio) + 1.0);
    assert(stats.value >= 0.f);
  }

  // index features required for scoring: frequency always, norm only when
  // normalization was requested
  virtual const flags& features() const override {
    static const irs::flags FEATURES[] = {
      irs::flags({ irs::frequency::type() }), // without normalization
      irs::flags({ irs::frequency::type(), irs::norm::type() }), // with normalization
    };

    return FEATURES[normalize_ ? 1 : 0];
  }

  virtual irs::sort::field_collector::ptr prepare_field_collector() const override {
    return irs::memory::make_unique<field_collector>();
  }

  virtual irs::sort::term_collector::ptr prepare_term_collector() const override {
    return irs::memory::make_unique<term_collector>();
  }

  // builds the scoring context and the matching scoring function;
  // returns { nullptr, nullptr } if 'doc_attrs' exposes no frequency, or if
  // normalization is requested and no 'document' attribute is exposed
  virtual std::pair<score_ctx::ptr, score_f> prepare_scorer(
      const sub_reader& segment,
      const term_reader& field,
      const byte_type* stats_buf,
      const attribute_view& doc_attrs,
      boost_t boost
  ) const override {
    auto& freq = doc_attrs.get<frequency>();

    if (!freq) {
      return { nullptr, nullptr };
    }

    auto& stats = stats_cast(stats_buf);

    // add norm attribute if requested
    if (normalize_) {
      irs::norm norm;

      auto& doc = doc_attrs.get<document>();

      if (!doc) {
        // we need 'document' attribute to be exposed
        return { nullptr, nullptr };
      }

      // if the norm can not be reset we fall through to the plain scorer below
      if (norm.reset(segment, field.meta().norm, *doc)) {
        return {
          memory::make_unique<tfidf::norm_score_ctx>(std::move(norm), boost, stats, freq.get()),
          [](const void* ctx, byte_type* score_buf) NOEXCEPT {
            auto& self = *static_cast<const tfidf::norm_score_ctx*>(ctx);

            irs::sort::score_cast<tfidf::score_t>(score_buf) =
              ::tfidf(self.freq_->value, self.idf_)*self.norm_.read();
          }
        };
      }
    }

    return {
      memory::make_unique<tfidf::score_ctx>(boost, stats, freq.get()),
      [](const void* ctx, byte_type* score_buf) NOEXCEPT {
        auto& self = *static_cast<const tfidf::score_ctx*>(ctx);

        irs::sort::score_cast<score_t>(score_buf) =
          ::tfidf(self.freq_->value, self.idf_);
      }
    };
  }

 private:
  bool normalize_; // multiply scores by the field norm when true
}; // sort
|
|
|
|
NS_END // tfidf
|
|
|
|
// NOTE(review): presumably defines tfidf_sort::type() (used by the constructor below) under the name "tfidf" - confirm against the macro definition
DEFINE_SORT_TYPE_NAMED(irs::tfidf_sort, "tfidf")
|
|
// NOTE(review): presumably generates the default factory for tfidf_sort - confirm against the macro definition
DEFINE_FACTORY_DEFAULT(irs::tfidf_sort)
|
|
|
|
// 'normalize' is stored as-is and later handed to the prepared scorer by prepare()
tfidf_sort::tfidf_sort(bool normalize) NOEXCEPT
  : sort(tfidf_sort::type()), normalize_(normalize) {
}
|
|
|
|
// explicit registration entry point, repeats the file-scope registration of make_json
/*static*/ void tfidf_sort::init() {
  REGISTER_SCORER_JSON(tfidf_sort, make_json); // match registration above
}
|
|
|
|
// creates the prepared tfidf::sort carrying this scorer's normalization flag
sort::prepared::ptr tfidf_sort::prepare() const {
  return tfidf::sort::make<tfidf::sort>(normalize_);
}
|
|
|
|
NS_END // ROOT
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// --SECTION-- END-OF-FILE
|
|
// -----------------------------------------------------------------------------
|