1
0
Fork 0

backport fix from iresearch upstream (#8732)

This commit is contained in:
Andrey Abramov 2019-04-11 17:44:31 +03:00 committed by GitHub
parent 7b012d8fca
commit ac60c99079
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 364 additions and 438 deletions

View File

@ -88,6 +88,9 @@ NS_BEGIN(burst_trie)
NS_BEGIN(detail)
// minimum size of string weight we store in FST
CONSTEXPR const size_t MIN_WEIGHT_SIZE = 2;
///////////////////////////////////////////////////////////////////////////////
/// @struct block_meta
/// @brief Provides set of helper functions to work with block metadata
@ -519,7 +522,10 @@ class term_iterator final : public irs::seek_term_iterator {
}
// Pushes a new block built from the FST output 'out' and the given 'prefix'
// onto the block stack and returns a pointer to it.
// Returns nullptr, without modifying the stack, when 'out' is shorter than
// MIN_WEIGHT_SIZE; callers are expected to check for that.
inline block_iterator* push_block(byte_weight&& out, size_t prefix) {
  // Do not assert on the weight size here: a too short (including empty)
  // weight is reported to the caller via nullptr, and an assertion would
  // abort debug builds before that graceful path could be taken.
  if (out.Size() < MIN_WEIGHT_SIZE) {
    return nullptr;
  }
  block_stack_.emplace_back(std::move(out), prefix, this);
  return &block_stack_.back();
}
@ -1013,12 +1019,68 @@ ptrdiff_t term_iterator::seek_cached(
return cmp;
}
//bool term_iterator::seek_to_block(const bytes_ref& term, size_t& prefix) {
// assert(owner_->fst_);
//
// typedef fst_t::Weight weight_t;
//
// const auto& fst = *owner_->fst_;
//
// prefix = 0; // number of current symbol to process
// arc::stateid_t state = fst.Start(); // start state
// weight_.Clear(); // clear aggregated fst output
//
// if (cur_block_) {
// const auto cmp = seek_cached(prefix, state, weight_, term);
// if (cmp > 0) {
// // target term is before the current term
// cur_block_->reset();
// } else if (0 == cmp) {
// // we're already at current term
// return true;
// }
// } else {
// cur_block_ = push_block(fst.Final(state), prefix);
// }
//
// term_.oversize(term.size());
// term_.reset(prefix); // reset to common seek prefix
// sstate_.resize(prefix); // remove invalid cached arcs
//
// bool found = fst_byte_builder::final != state;
// while (found && prefix < term.size()) {
// matcher_.SetState(state);
// if (found = matcher_.Find(term[prefix])) {
// const auto& arc = matcher_.Value();
// term_ += byte_type(arc.ilabel);
// fst_utils::append(weight_, arc.weight);
// ++prefix;
//
// const auto weight = fst.Final(state = arc.nextstate);
// if (weight_t::One() != weight && weight_t::Zero() != weight) {
// cur_block_ = push_block(fst::Times(weight_, weight), prefix);
// } else if (fst_byte_builder::final == state) {
// cur_block_ = push_block(std::move(weight_), prefix);
// found = false;
// }
//
// // cache found arcs, we can reuse it in further seeks
// // avoiding relatively expensive FST lookups
// sstate_.emplace_back(state, arc.weight, cur_block_);
// }
// }
//
// assert(cur_block_);
// sstate_.resize(cur_block_->prefix());
// cur_block_->scan_to_block(term);
//
// return false;
//}
bool term_iterator::seek_to_block(const bytes_ref& term, size_t& prefix) {
assert(owner_->fst_);
assert(owner_->fst_ && owner_->fst_->GetImpl());
typedef fst_t::Weight weight_t;
const auto& fst = *owner_->fst_;
const auto& fst = *owner_->fst_->GetImpl();
prefix = 0; // number of current symbol to process
arc::stateid_t state = fst.Start(); // start state
@ -1035,33 +1097,60 @@ bool term_iterator::seek_to_block(const bytes_ref& term, size_t& prefix) {
}
} else {
cur_block_ = push_block(fst.Final(state), prefix);
if (!cur_block_) {
// stepped on invalid weight
return false;
}
}
term_.oversize(term.size());
term_.reset(prefix); // reset to common seek prefix
sstate_.resize(prefix); // remove invalid cached arcs
bool found = fst_byte_builder::final != state;
while (found && prefix < term.size()) {
while (fst_byte_builder::final != state && prefix < term.size()) {
matcher_.SetState(state);
if (found = matcher_.Find(term[prefix])) {
const auto& arc = matcher_.Value();
term_ += byte_type(arc.ilabel);
fst_utils::append(weight_, arc.weight);
++prefix;
const auto weight = fst.Final(state = arc.nextstate);
if (weight_t::One() != weight && weight_t::Zero() != weight) {
cur_block_ = push_block(fst::Times(weight_, weight), prefix);
} else if (fst_byte_builder::final == state) {
cur_block_ = push_block(std::move(weight_), prefix);
found = false;
}
// cache found arcs, we can reuse it in further seeks
// avoiding relatively expensive FST lookups
sstate_.emplace_back(state, arc.weight, cur_block_);
if (!matcher_.Find(term[prefix])) {
break;
}
const auto& arc = matcher_.Value();
term_ += byte_type(arc.ilabel); // aggregate arc label
weight_.PushBack(arc.weight); // aggregate arc weight
++prefix;
const auto& weight = fst.FinalRef(arc.nextstate);
if (!weight.Empty()) {
cur_block_ = push_block(fst::Times(weight_, weight), prefix);
if (!cur_block_) {
// stepped on invalid weight
return false;
}
} else if (fst_byte_builder::final == arc.nextstate) {
// ensure final state has no weight assigned
// the only case when that does not hold is a degenerate FST composed of
// only the 'fst_byte_builder::final' state.
// in that case we never get here due to the loop condition above.
assert(fst.FinalRef(fst_byte_builder::final).Empty());
cur_block_ = push_block(std::move(weight_), prefix);
if (!cur_block_) {
// stepped on invalid weight
return false;
}
}
// cache found arcs, we can reuse it in further seeks
// avoiding relatively expensive FST lookups
sstate_.emplace_back(arc.nextstate, arc.weight, cur_block_);
// proceed to the next state
state = arc.nextstate;
}
assert(cur_block_);
@ -1077,7 +1166,9 @@ SeekResult term_iterator::seek_equal(const bytes_ref& term) {
return SeekResult::FOUND;
}
assert(cur_block_);
if (!cur_block_) {
return SeekResult::NOT_FOUND;
}
if (!block_meta::terms(cur_block_->meta())) {
// current block has no terms
@ -1096,7 +1187,9 @@ SeekResult term_iterator::seek_ge(const bytes_ref& term) {
}
UNUSED(prefix);
assert(cur_block_);
if (!cur_block_) {
return SeekResult::END;
}
cur_block_->load();
switch (cur_block_->scan_to_term(term)) {

View File

@ -26,6 +26,7 @@
#include "shared.hpp"
#include "utils/string.hpp"
#include "utils/noncopyable.hpp"
#if defined(_MSC_VER)
#pragma warning(disable : 4018)
@ -62,7 +63,12 @@
#include <boost/functional/hash.hpp>
NS_ROOT
//////////////////////////////////////////////////////////////////////////////
/// @class fst_builder
/// @brief helper class for building minimal acyclic subsequential transducers
/// algorithm is described there:
/// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.24.3698
//////////////////////////////////////////////////////////////////////////////
template<typename Char, typename Fst>
class fst_builder : util::noncopyable {
public:
@ -76,67 +82,64 @@ class fst_builder : util::noncopyable {
static const stateid_t final = 0;
fst_builder(fst_t& fst) : fst_(fst) {
explicit fst_builder(fst_t& fst) : fst_(fst) {
/* initialize final state */
fst_.AddState();
fst_.SetFinal(final, weight_t::One());
}
void add( const key_t& in, const weight_t& out ) {
/* inputs should be sorted */
assert( last_.empty() || last_ < in );
void add(const key_t& in, const weight_t& out) {
// inputs should be sorted
assert(last_.empty() || last_ < in);
if ( in.empty() ) {
start_out_ = fst::Times( start_out_, out );
if (in.empty()) {
start_out_ = fst::Times(start_out_, out);
return;
}
/* determine common prefix */
const size_t pref = 1 + prefix( last_, in );
// determine common prefix
const size_t pref = 1 + prefix(last_, in);
/* add states for current input */
add_states( in.size() );
// add states for current input
add_states(in.size());
/* minimize last word suffix */
// minimize last word suffix
minimize(pref);
/* add current word suffix */
for ( size_t i = pref; i <= in.size(); ++i ) {
states_[i - 1].arcs.emplace_back( in[i - 1], &states_[i] );
// add current word suffix
for (size_t i = pref; i <= in.size(); ++i) {
states_[i - 1].arcs.emplace_back(in[i - 1], &states_[i]);
}
const bool is_final = last_.size() != in.size() || pref != ( in.size() + 1 );
const bool is_final = last_.size() != in.size() || pref != (in.size() + 1);
weight_t output = out;
for ( size_t i = 1; i < pref; ++i ) {
weight_t output = out; // FIXME remove temporary variable
for (size_t i = 1; i < pref; ++i) {
state& s = states_[i];
state& p = states_[i - 1];
assert( p.arcs.size() );
assert( p.arcs.back().label == in[i - 1] );
assert(!p.arcs.empty() && p.arcs.back().label == in[i - 1]);
const weight_t& last_out = p.arcs.back().out;
if ( last_out != weight_t::One() ) {
const weight_t prefix = fst::Plus( last_out, out );
// otherwise we'll get invalid suffix (fst::kStringBad)
assert(prefix != weight_t::Zero());
if (last_out != weight_t::One()) {
const weight_t prefix = fst::Plus(last_out, out); // FIXME remove temporary variable
const weight_t suffix = fst::Divide(last_out, prefix, fst::DIVIDE_LEFT);
p.arcs.back().out = prefix;
for ( arc& a : s.arcs ) {
a.out = fst::Times( suffix, a.out );
for (arc& a : s.arcs) {
a.out = fst::Times(suffix, a.out);
}
if ( s.final ) {
s.out = fst::Times( suffix, s.out );
if (s.final) {
s.out = fst::Times(suffix, s.out);
}
output = fst::Divide( output, prefix, fst::DIVIDE_LEFT );
output = fst::Divide(output, prefix, fst::DIVIDE_LEFT);
}
}
if ( is_final ) {
if (is_final) {
// set final state
{
state& s = states_[in.size()];
@ -146,13 +149,13 @@ class fst_builder : util::noncopyable {
// set output
{
state& s = states_[pref - 1];
assert(s.arcs.size());
assert(!s.arcs.empty() && s.arcs.back().label == in[pref-1]);
s.arcs.back().out = std::move(output);
}
} else {
state& s = states_[in.size()];
assert( s.arcs.size() );
assert( s.arcs.back().label == in[pref - 1] );
assert(s.arcs.size());
assert(s.arcs.back().label == in[pref - 1]);
s.arcs.back().out = fst::Times(s.arcs.back().out, output);
}
@ -161,24 +164,24 @@ class fst_builder : util::noncopyable {
void finish() {
stateid_t start;
if ( states_.empty() ) {
if (states_.empty()) {
start = final;
} else {
/* minimize last word suffix */
minimize( 1 );
// minimize last word suffix
minimize(1);
start = states_map_.insert(states_[0], fst_);
}
/* set the start state */
fst_.SetStart( start );
fst_.SetFinal( start, start_out_ );
// set the start state
fst_.SetStart(start);
fst_.SetFinal(start, start_out_);
}
void reset() {
/* remove states */
// remove states
fst_.DeleteStates();
/* initialize final state */
// initialize final state
fst_.AddState();
fst_.SetFinal(final, weight_t::One());
@ -253,7 +256,7 @@ class fst_builder : util::noncopyable {
return seed;
}
std::vector< arc > arcs;
std::vector<arc> arcs;
weight_t out{ weight_t::One() };
bool final{ false };
}; // state
@ -272,11 +275,11 @@ class fst_builder : util::noncopyable {
const size_t mask = states_.size() - 1;
size_t pos = hash_value(s) % mask;
for ( ;; ++pos, pos %= mask ) { // TODO: maybe use quadratic probing here
if ( fst::kNoStateId == states_[pos] ) {
if (fst::kNoStateId == states_[pos]) {
states_[pos] = id = add_state(s, fst);
++count_;
if ( count_ > 2 * states_.size() / 3 ) {
if (count_ > 2 * states_.size() / 3) {
rehash(fst);
}
break;
@ -310,7 +313,7 @@ class fst_builder : util::noncopyable {
static size_t hash(stateid_t id, const fst_t& fst) {
size_t hash = 0;
for ( fst::ArcIterator< fst_t > it(fst, id); !it.Done(); it.Next()) {
for (fst::ArcIterator< fst_t > it(fst, id); !it.Done(); it.Next()) {
const arc_t& a = it.Value();
::boost::hash_combine(hash, a.ilabel);
::boost::hash_combine(hash, a.nextstate);
@ -330,7 +333,7 @@ class fst_builder : util::noncopyable {
size_t pos = hash(id, fst) % mask;
for (;;++pos, pos %= mask) { // TODO: maybe use quadratic probing here
if ( fst::kNoStateId == states[pos] ) {
if (fst::kNoStateId == states[pos] ) {
states[pos] = id;
break;
}
@ -354,11 +357,11 @@ class fst_builder : util::noncopyable {
}
// TODO: maybe use "buckets" here
std::vector< stateid_t > states_;
std::vector<stateid_t> states_;
size_t count_{};
}; // state_map
static size_t prefix( const key_t& lhs, const key_t& rhs ) {
static size_t prefix(const key_t& lhs, const key_t& rhs) {
size_t pref = 0;
const size_t max = std::min( lhs.size(), rhs.size() );
while ( pref < max && lhs[pref] == rhs[pref] ) {
@ -367,20 +370,20 @@ class fst_builder : util::noncopyable {
return pref;
}
void add_states( size_t size ) {
/* reserve size + 1 for root state */
void add_states(size_t size) {
// reserve size + 1 for root state
if ( states_.size() < ++size ) {
states_.resize(size);
}
}
void minimize( size_t pref ) {
assert( pref > 0 );
void minimize(size_t pref) {
assert(pref > 0);
for ( size_t i = last_.size(); i >= pref; --i ) {
for (size_t i = last_.size(); i >= pref; --i) {
state& s = states_[i];
state& p = states_[i - 1];
assert( !p.arcs.empty() );
assert(!p.arcs.empty());
p.arcs.back().id = states_map_.insert(s, fst_);
s.clear();
@ -392,346 +395,8 @@ class fst_builder : util::noncopyable {
weight_t start_out_; /* output for "empty" input */
bstring last_;
fst_t& fst_;
};
//template< typename Key, typename Weight >
//class fst_builder : util::noncopyable {
// public:
// typedef Key key_t;
// typedef Weight weight_t;
// typedef typename fst::ArcTpl< weight_t > arc_t;
// typedef typename fst::VectorFst< arc_t > fst_t;
// typedef typename arc_t::Label label_t;
// typedef typename fst_t::StateId stateid_t;
//
// static const stateid_t final = 0;
//
// fst_builder() : states_map_( &fst_ ) {
// /* initialize final state */
// fst_.AddState();
// fst_.SetFinal( final, weight_t::One() );
// }
//
// fst_t& fst() {
// return fst_;
// }
//
// void add( const key_t& in, const weight_t& out ) {
// /* inputs should be sorted */
// assert( last_.empty() || last_ < in );
//
// if ( in.empty() ) {
// start_out_ = fst::Times( start_out_, out );
// return;
// }
//
// /* determine common prefix */
// const size_t pref = 1 + prefix( last_, in );
//
// /* add states for current input */
// add_states( in.size() );
//
// /* minimize last word suffix */
// minimize( pref );
//
// /* add current word suffix */
// for ( size_t i = pref; i <= in.size(); ++i ) {
// states_[i - 1].arcs.emplace_back( in[i - 1], &states_[i] );
// }
//
// const bool is_final = last_.size() != in.size() || pref != ( in.size() + 1 );
//
// weight_t output = out;
// for ( size_t i = 1; i < pref; ++i ) {
// state& s = states_[i];
// state& p = states_[i - 1];
//
// assert( p.arcs.size() );
// assert( p.arcs.back().label == in[i - 1] );
//
// const weight_t& last_out = p.arcs.back().out;
// if ( last_out != weight_t::One() ) {
// const weight_t prefix = fst::Plus( last_out, out );
// // otherwise we'll get invalid suffix (fst::kStringBad)
// assert(prefix != weight_t::Zero());
// const weight_t suffix = fst::Divide(last_out, prefix, fst::DIVIDE_LEFT);
//
// p.arcs.back().out = prefix;
//
// for ( arc& a : s.arcs ) {
// a.out = fst::Times( suffix, a.out );
// }
//
// if ( s.final ) {
// s.out = fst::Times( suffix, s.out );
// }
//
// output = fst::Divide( output, prefix, fst::DIVIDE_LEFT );
// }
// }
//
// if ( is_final ) {
// /* set final state */
// state* s = &states_[in.size()];
// s->final = true;
// s->out = weight_t::One();
//
// /* set output */
// s = &states_[pref - 1];
// assert( s->arcs.size() );
// s->arcs.back().out = std::move( output );
// } else {
// state& s = states_[in.size()];
// assert( s.arcs.size() );
// assert( s.arcs.back().label == in[pref - 1] );
// s.arcs.back().out == fst::Times( s.arcs.back().out, output );
// }
//
// last_ = in;
// }
//
// void finish() {
// stateid_t start;
// if ( states_.empty() ) {
// start = final;
// } else {
// /* minimize last word suffix */
// minimize( 1 );
// start = states_map_.insert( states_[0] );
// }
//
// /* set the start state */
// fst_.SetStart( start );
// fst_.SetFinal( start, start_out_ );
// }
//
// void reset() {
// /* remove states */
// fst_.DeleteStates();
//
// /* initialize final state */
// fst_.AddState();
// fst_.SetFinal( final, weight_t::One() );
//
// states_map_.reset();
// last_ = key_t();
// start_out_ = weight_t();
// }
//
// private:
// struct state;
//
// struct arc : private util::noncopyable {
// arc(label_t label, state* target)
// : target(target),
// label(label) {
// }
//
// arc(arc&& rhs) NOEXCEPT
// : target(rhs.target),
// label(rhs.label),
// out(std::move(rhs.out)) {
// }
//
// bool operator==(const arc_t& rhs) const {
// return label == rhs.ilabel
// && id == rhs.nextstate
// && out == rhs.weight;
// }
//
// bool operator!=(const arc_t& rhs) const {
// return !(*this == rhs);
// }
//
// union {
// state* target;
// stateid_t id;
// };
// label_t label;
// weight_t out{ weight_t::One() };
// }; // arc
//
// struct state : private util::noncopyable {
// state() = default;
//
// state(state&& rhs) NOEXCEPT
// : arcs(std::move(rhs.arcs)),
// out(std::move(rhs.out)),
// final(rhs.final) {
// }
//
// void clear() {
// arcs.clear();
// final = false;
// out = weight_t::One();
// }
//
// std::vector< arc > arcs;
// weight_t out{ weight_t::One() };
// bool final{ false };
// }; // state
//
// class state_map {
// public:
// state_map( fst_t* fst )
// : states_( InitialSize, fst::kNoStateId ),
// count_( 0 ),
// fst_( fst ) {
// assert( fst_ );
// }
//
// stateid_t insert( const state& s ) {
// if ( s.arcs.empty() && s.final ) {
// return fst_builder::final;
// }
//
// stateid_t id;
// const size_t mask = states_.size() - 1;
// size_t pos = hash( s ) % mask;
// for ( ;; ++pos, pos %= mask ) { // TODO: maybe use quadratic probing here
// if ( fst::kNoStateId == states_[pos] ) {
// states_[pos] = id = add_state( s );
// ++count_;
//
// if ( count_ > 2 * states_.size() / 3 ) {
// rehash();
// }
// break;
// } else if ( equals( s, states_[pos] ) ) {
// id = states_[pos];
// break;
// }
// }
//
// return id;
// }
//
// private:
// bool equals( const state& lhs, stateid_t rhs ) {
// if ( fst_->NumArcs( rhs ) != lhs.arcs.size() ) {
// return false;
// }
//
// for ( fst::ArcIterator<fst_t> it( *fst_, rhs ); !it.Done(); it.Next() ) {
// if (lhs.arcs[it.Position()] != it.Value()) {
// return false;
// }
// }
//
// return true;
// }
//
// inline static size_t hash( size_t h, label_t label,
// stateid_t next,
// const weight_t& out ) {
// const size_t prime = 31;
// h = prime*h + label;
// h = prime*h + next;
// h = prime*h + out.Hash();
// return h;
// }
//
// static size_t hash( const state& s ) {
// size_t h = 0;
// for ( const arc& a : s.arcs ) {
// h = hash( h, a.label, a.id, a.out );
// }
//
// return h;
// }
//
// size_t hash( stateid_t id ) {
// size_t h = 0;
// for ( fst::ArcIterator< fst_t > it( *fst_, id ); !it.Done(); it.Next() ) {
// const arc_t& a = it.Value();
// h = hash( h, a.ilabel, a.nextstate, a.weight );
// }
//
// return h;
// }
//
// void rehash() {
// std::vector< stateid_t > states( states_.size() * 2, fst::kNoStateId );
// const size_t mask = states.size() - 1;
// for ( stateid_t id:states_ ) {
//
// if ( fst::kNoStateId == id ) {
// continue;
// }
//
// size_t pos = hash( id ) % mask;
// for ( ;; ++pos, pos %= mask ) { // TODO: maybe use quadratic probing here
// if ( fst::kNoStateId == states[pos] ) {
// states[pos] = id;
// break;
// }
// }
// }
//
// states_ = std::move( states );
// }
//
// stateid_t add_state( const state& s ) {
// const stateid_t id = fst_->AddState();
// if ( s.final ) {
// fst_->SetFinal( id, s.out );
// }
//
// for ( const arc& a : s.arcs ) {
// fst_->AddArc( id, arc_t( a.label, a.label, a.out, a.id ) );
// }
//
// return id;
// }
//
// void reset() {
// count_ = 0;
// std::fill( states_.begin(), states_.end(), fst::kNoStateId );
// }
//
// static const size_t InitialSize = 16;
//
// std::vector< stateid_t > states_; // TODO: maybe use "buckets" here
// size_t count_;
// fst_t* fst_;
// };
//
// state_map states_map_;
// std::vector< state > states_; /* current states */
// weight_t start_out_; /* output for "empty" input */
// key_t last_;
// fst_t fst_;
//
// static size_t prefix( const key_t& lhs, const key_t& rhs ) {
// size_t pref = 0;
// const size_t max = std::min( lhs.size(), rhs.size() );
// while ( pref < max && lhs[pref] == rhs[pref] ) {
// ++pref;
// }
// return pref;
// }
//
// void add_states( size_t size ) {
// /* reserve size + 1 for root state */
// if ( states_.size() < ++size ) {
// states_.resize(size);
// }
// }
//
// void minimize( size_t pref ) {
// assert( pref > 0 );
//
// for ( size_t i = last_.size(); i >= pref; --i ) {
// state& s = states_[i];
// state& p = states_[i - 1];
// assert( !p.arcs.empty() );
//
// p.arcs.back().id = states_map_.insert( s );
// s.clear();
// }
// }
//};
}; // fst_builder
NS_END
#endif
#endif

View File

@ -43,29 +43,28 @@
NS_BEGIN(fst)
template <typename Label>
class StringLeftWeight;
template <typename Label>
struct StringLeftWeightTraits {
inline static const StringLeftWeight<Label>& Zero();
inline static const StringLeftWeight<Label>& One();
inline static const StringLeftWeight<Label>& NoWeight();
inline static bool Member(const StringLeftWeight<Label>& weight);
}; // StringLeftWeightTraits
// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
template <typename Label>
class StringLeftWeight {
class StringLeftWeight : public StringLeftWeightTraits<Label> {
public:
typedef StringLeftWeight<Label> ReverseWeight;
typedef std::basic_string<Label> str_t;
typedef typename str_t::const_iterator iterator;
static const StringLeftWeight<Label>& Zero() {
static const StringLeftWeight<Label> zero((Label)kStringInfinity); // cast same as in FST
return zero;
}
static const StringLeftWeight<Label>& One() {
static const StringLeftWeight<Label> one;
return one;
}
static const StringLeftWeight<Label>& NoWeight() {
static const StringLeftWeight<Label> no_weight((Label)kStringBad); // cast same as in FST
return no_weight;
}
static const std::string& Type() {
static const std::string type = "left_string";
return type;
@ -111,7 +110,7 @@ class StringLeftWeight {
}
bool Member() const NOEXCEPT {
return NoWeight() != *this;
return StringLeftWeightTraits<Label>::Member(*this);
}
std::istream& Read(std::istream& strm) {
@ -212,10 +211,37 @@ class StringLeftWeight {
iterator begin() const NOEXCEPT { return str_.begin(); }
iterator end() const NOEXCEPT { return str_.end(); }
// Explicit conversion to a string reference built from the underlying
// label string of this weight.
explicit operator irs::basic_string_ref<Label>() const NOEXCEPT {
  const irs::basic_string_ref<Label> ref = str_;
  return ref;
}
private:
str_t str_;
}; // StringLeftWeight
// Returns the shared 'Zero' weight: a weight constructed from the single
// label kStringInfinity.
template <typename Label>
/*static*/ const StringLeftWeight<Label>& StringLeftWeightTraits<Label>::Zero() {
  typedef StringLeftWeight<Label> weight_t;

  // cast same as in FST
  static const weight_t infinity(static_cast<Label>(kStringInfinity));

  return infinity;
}
// Returns the shared 'One' weight: a default-constructed weight.
template <typename Label>
/*static*/ const StringLeftWeight<Label>& StringLeftWeightTraits<Label>::One() {
  typedef StringLeftWeight<Label> weight_t;

  static const weight_t epsilon;

  return epsilon;
}
// Returns the shared 'NoWeight' value: a weight constructed from the single
// label kStringBad.
template <typename Label>
/*static*/ const StringLeftWeight<Label>& StringLeftWeightTraits<Label>::NoWeight() {
  typedef StringLeftWeight<Label> weight_t;

  // cast same as in FST
  static const weight_t bad(static_cast<Label>(kStringBad));

  return bad;
}
// Returns true unless 'weight' compares equal to NoWeight().
template <typename Label>
/*static*/ bool StringLeftWeightTraits<Label>::Member(const StringLeftWeight<Label>& weight) {
  const StringLeftWeight<Label>& invalid = NoWeight();

  return weight != invalid;
}
template <typename Label>
inline bool operator!=(
const StringLeftWeight<Label>& w1,
@ -341,7 +367,7 @@ inline StringLeftWeight<Label> DivideLeft(
const StringLeftWeight<Label>& lhs,
const StringLeftWeight<Label>& rhs) {
typedef StringLeftWeight<Label> Weight;
if (!lhs.Member() || !rhs.Member()) {
return Weight::NoWeight();
}
@ -356,6 +382,11 @@ inline StringLeftWeight<Label> DivideLeft(
return Weight();
}
assert(irs::starts_with(
irs::basic_string_ref<Label>(lhs),
irs::basic_string_ref<Label>(rhs)
));
return Weight(lhs.begin() + rhs.Size(), lhs.end());
}
@ -368,6 +399,136 @@ inline StringLeftWeight<Label> Divide(
return DivideLeft(lhs, rhs);
}
// -----------------------------------------------------------------------------
// --SECTION-- StringLeftWeight<irs::byte_type>
// -----------------------------------------------------------------------------
// Traits for weights over raw bytes.
// Zero(), One() and NoWeight() all refer to the same default-constructed
// weight, and every weight is reported as a member.
template <>
struct StringLeftWeightTraits<irs::byte_type> {
  // the single shared default-constructed weight
  static const StringLeftWeight<irs::byte_type>& Zero() NOEXCEPT {
    static const StringLeftWeight<irs::byte_type> empty;
    return empty;
  }

  // same object as Zero()
  static const StringLeftWeight<irs::byte_type>& One() NOEXCEPT {
    return Zero();
  }

  // same object as Zero()
  static const StringLeftWeight<irs::byte_type>& NoWeight() NOEXCEPT {
    return Zero();
  }

  // always a member, the argument is not inspected
  static bool Member(const StringLeftWeight<irs::byte_type>& /*weight*/) NOEXCEPT {
    return true;
  }
}; // StringLeftWeightTraits
// Writes a textual representation of 'weight' to 'strm':
// "Epsilon" for an empty weight, otherwise the labels as decimal numbers
// separated by kStringSeparator.
inline std::ostream& operator<<(
    std::ostream& strm,
    const StringLeftWeight<irs::byte_type>& weight) {
  if (weight.Empty()) {
    return strm << "Epsilon";
  }

  auto begin = weight.begin();
  const auto end = weight.end();
  if (begin != end) {
    // Promote each label to 'int' before streaming it: a character-typed
    // label streamed directly is written as a raw character, whereas the
    // matching operator>> parses decimal numbers.
    strm << static_cast<int>(*begin);
    for (++begin; begin != end; ++begin) {
      strm << kStringSeparator << static_cast<int>(*begin);
    }
  }

  return strm;
}
// Reads a textual representation of a weight from 'strm' into 'weight'.
// The token "Epsilon" yields One(); any other token is parsed as decimal
// labels separated by kStringSeparator.
// On malformed input, or when a number does not fit into irs::byte_type,
// the stream state is set to badbit and parsing stops; labels parsed before
// the error remain in 'weight'.
inline std::istream& operator>>(
    std::istream& strm,
    StringLeftWeight<irs::byte_type>& weight) {
  std::string str;
  strm >> str;

  if (str == "Epsilon") {
    weight = StringLeftWeight<irs::byte_type>::One();
  } else {
    weight.Clear();
    char *p = nullptr;
    for (const char *cs = str.c_str(); !p || *p != '\0'; cs = p + 1) {
      const long long value = strtoll(cs, &p, 10);
      const irs::byte_type label = static_cast<irs::byte_type>(value);

      // reject input if:
      //  - no digits were consumed
      //  - the number is followed by something other than a separator
      //  - the number does not survive the narrowing to a label
      //    (it would be silently wrapped otherwise)
      if (p == cs
          || (*p != 0 && *p != kStringSeparator)
          || static_cast<long long>(label) != value) {
        strm.clear(std::ios::badbit);
        break;
      }

      weight.PushBack(label);
    }
  }

  return strm;
}
// Plus for the left string semiring over bytes: returns the longest common
// prefix of 'lhs' and 'rhs'.
// Zero() and NoWeight() cannot be used as special values for binary strings,
// since they may interfere with real values, so both arguments are treated
// as ordinary strings.
inline StringLeftWeight<irs::byte_type> Plus(
    const StringLeftWeight<irs::byte_type>& lhs,
    const StringLeftWeight<irs::byte_type>& rhs) {
  typedef StringLeftWeight<irs::byte_type> Weight;

  // ensure that the first range passed to std::mismatch is not longer than
  // the second one:
  // The behavior is undefined if the second range is shorter than the first range.
  // (http://en.cppreference.com/w/cpp/algorithm/mismatch)
  const bool lhs_is_shorter = lhs.Size() < rhs.Size();
  const Weight& shorter = lhs_is_shorter ? lhs : rhs;
  const Weight& longer = lhs_is_shorter ? rhs : lhs;

  assert(shorter.Size() <= longer.Size());

  // first position in 'shorter' at which the two strings differ
  const auto prefix_end =
    std::mismatch(shorter.begin(), shorter.end(), longer.begin()).first;

  return Weight(shorter.begin(), prefix_end);
}
// Times for the left string semiring over bytes: returns the concatenation
// of 'lhs' followed by 'rhs'.
// Zero() and NoWeight() cannot be used as special values for binary strings,
// since they may interfere with real values, so both arguments are treated
// as ordinary strings.
inline StringLeftWeight<irs::byte_type> Times(
    const StringLeftWeight<irs::byte_type>& lhs,
    const StringLeftWeight<irs::byte_type>& rhs) {
  typedef StringLeftWeight<irs::byte_type> Weight;

  const auto total_size = lhs.Size() + rhs.Size();

  Weight concatenated;
  concatenated.Reserve(total_size); // reserve space for both parts up front
  concatenated.PushBack(lhs.begin(), lhs.end());
  concatenated.PushBack(rhs.begin(), rhs.end());

  return concatenated;
}
// Left division in the left string semiring over bytes.
// Returns a default-constructed weight when 'rhs' is longer than 'lhs'.
// Otherwise 'rhs' is expected (asserted) to be a prefix of 'lhs', and the
// part of 'lhs' that follows this prefix is returned.
// Zero() and NoWeight() cannot be used as special values for binary strings,
// since they may interfere with real values.
inline StringLeftWeight<irs::byte_type> DivideLeft(
    const StringLeftWeight<irs::byte_type>& lhs,
    const StringLeftWeight<irs::byte_type>& rhs) {
  typedef StringLeftWeight<irs::byte_type> Weight;

  const auto divisor_size = rhs.Size();

  if (lhs.Size() < divisor_size) {
    // divisor is longer than the dividend
    return Weight();
  }

  assert(irs::starts_with(
    irs::basic_string_ref<irs::byte_type>(lhs),
    irs::basic_string_ref<irs::byte_type>(rhs)
  ));

  // strip the leading 'divisor_size' labels
  return Weight(lhs.begin() + divisor_size, lhs.end());
}
NS_END // fst
#endif // IRESEARCH_FST_STRING_WEIGHT_H

View File

@ -859,6 +859,8 @@ class ImplToFst : public FST {
Weight Final(StateId s) const override { return impl_->Final(s); }
// Same as Final(), but forwards to the implementation's FinalRef() and
// returns the final weight of state 's' by reference instead of by value.
const Weight& FinalRef(StateId s) const {
  return impl_->FinalRef(s);
}
size_t NumArcs(StateId s) const override { return impl_->NumArcs(s); }
size_t NumInputEpsilons(StateId s) const override {

View File

@ -56,6 +56,8 @@ class VectorState {
Weight Final() const { return final_; }
// Returns a reference to the stored final weight of this state
// (Final() returns it by value).
const Weight& FinalRef() const {
  return final_;
}
size_t NumInputEpsilons() const { return niepsilons_; }
size_t NumOutputEpsilons() const { return noepsilons_; }
@ -148,6 +150,8 @@ class VectorFstBaseImpl : public FstImpl<typename S::Arc> {
Weight Final(StateId state) const { return states_[state]->Final(); }
// Returns a reference to the final weight of the state with id 'state'
// (Final() returns it by value). 'state' is used as an index into the
// state table without any range check.
const Weight& FinalRef(StateId state) const {
  return states_[state]->FinalRef();
}
StateId NumStates() const { return states_.size(); }
size_t NumArcs(StateId state) const { return states_[state]->NumArcs(); }
@ -508,9 +512,10 @@ class VectorFst : public ImplToMutableFst<internal::VectorFstImpl<S>> {
using ImplToMutableFst<Impl, MutableFst<Arc>>::ReserveArcs;
using ImplToMutableFst<Impl, MutableFst<Arc>>::ReserveStates;
using ImplToMutableFst<Impl, MutableFst<Arc>>::GetImpl;
private:
using ImplToMutableFst<Impl, MutableFst<Arc>>::GetImpl;
//using ImplToMutableFst<Impl, MutableFst<Arc>>::GetImpl;
using ImplToMutableFst<Impl, MutableFst<Arc>>::MutateCheck;
using ImplToMutableFst<Impl, MutableFst<Arc>>::SetImpl;