mirror of https://gitee.com/bigwinds/arangodb
some bugfixes for aggregators
This commit is contained in:
parent
9ee2646b53
commit
9a2d9924b9
|
@ -76,7 +76,7 @@ void AggregatorLength::reduce(AqlValue const&, TRI_document_collection_t const*)
|
|||
++count;
|
||||
}
|
||||
|
||||
AqlValue AggregatorLength::getValue() {
|
||||
AqlValue AggregatorLength::stealValue() {
|
||||
uint64_t copy = count;
|
||||
count = 0;
|
||||
return AqlValue(new Json(static_cast<double>(copy)));
|
||||
|
@ -100,7 +100,7 @@ void AggregatorMin::reduce(AqlValue const& cmpValue,
|
|||
}
|
||||
}
|
||||
|
||||
AqlValue AggregatorMin::getValue() {
|
||||
AqlValue AggregatorMin::stealValue() {
|
||||
AqlValue copy = value;
|
||||
value.erase();
|
||||
return copy;
|
||||
|
@ -124,7 +124,7 @@ void AggregatorMax::reduce(AqlValue const& cmpValue,
|
|||
}
|
||||
}
|
||||
|
||||
AqlValue AggregatorMax::getValue() {
|
||||
AqlValue AggregatorMax::stealValue() {
|
||||
AqlValue copy = value;
|
||||
value.erase();
|
||||
return copy;
|
||||
|
@ -152,7 +152,7 @@ void AggregatorSum::reduce(AqlValue const& cmpValue,
|
|||
invalid = true;
|
||||
}
|
||||
|
||||
AqlValue AggregatorSum::getValue() {
|
||||
AqlValue AggregatorSum::stealValue() {
|
||||
if (invalid || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) {
|
||||
return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null));
|
||||
}
|
||||
|
@ -184,7 +184,7 @@ void AggregatorAverage::reduce(AqlValue const& cmpValue,
|
|||
invalid = true;
|
||||
}
|
||||
|
||||
AqlValue AggregatorAverage::getValue() {
|
||||
AqlValue AggregatorAverage::stealValue() {
|
||||
if (invalid || count == 0 || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) {
|
||||
return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null));
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ struct Aggregator {
|
|||
virtual char const* name() const = 0;
|
||||
virtual void reset() = 0;
|
||||
virtual void reduce(AqlValue const&, struct TRI_document_collection_t const*) = 0;
|
||||
virtual AqlValue getValue() = 0;
|
||||
virtual AqlValue stealValue() = 0;
|
||||
|
||||
static Aggregator* fromTypeString(triagens::arango::AqlTransaction*, std::string const&);
|
||||
static Aggregator* fromJson(triagens::arango::AqlTransaction*, triagens::basics::Json const&,
|
||||
|
@ -64,7 +64,7 @@ struct AggregatorLength final : public Aggregator {
|
|||
|
||||
void reset() override final;
|
||||
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
|
||||
AqlValue getValue() override final;
|
||||
AqlValue stealValue() override final;
|
||||
|
||||
uint64_t count;
|
||||
};
|
||||
|
@ -80,7 +80,7 @@ struct AggregatorMin final : public Aggregator {
|
|||
|
||||
void reset() override final;
|
||||
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
|
||||
AqlValue getValue() override final;
|
||||
AqlValue stealValue() override final;
|
||||
|
||||
AqlValue value;
|
||||
struct TRI_document_collection_t const* coll;
|
||||
|
@ -97,7 +97,7 @@ struct AggregatorMax final : public Aggregator {
|
|||
|
||||
void reset() override final;
|
||||
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
|
||||
AqlValue getValue() override final;
|
||||
AqlValue stealValue() override final;
|
||||
|
||||
AqlValue value;
|
||||
struct TRI_document_collection_t const* coll;
|
||||
|
@ -114,7 +114,7 @@ struct AggregatorSum final : public Aggregator {
|
|||
|
||||
void reset() override final;
|
||||
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
|
||||
AqlValue getValue() override final;
|
||||
AqlValue stealValue() override final;
|
||||
|
||||
double sum;
|
||||
bool invalid;
|
||||
|
@ -131,7 +131,7 @@ struct AggregatorAverage final : public Aggregator {
|
|||
|
||||
void reset() override final;
|
||||
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
|
||||
AqlValue getValue() override final;
|
||||
AqlValue stealValue() override final;
|
||||
|
||||
uint64_t count;
|
||||
double sum;
|
||||
|
|
|
@ -30,7 +30,6 @@ using Json = triagens::basics::Json;
|
|||
using JsonHelper = triagens::basics::JsonHelper;
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// @brief create the block
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -30,8 +30,6 @@
|
|||
#include "Aql/Range.h"
|
||||
#include "Aql/types.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
struct TRI_document_collection_t;
|
||||
|
||||
namespace triagens {
|
||||
|
@ -58,7 +56,6 @@ namespace aql {
|
|||
|
||||
class AqlItemBlock {
|
||||
friend class AqlItemBlockManager;
|
||||
|
||||
|
||||
public:
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -94,9 +91,6 @@ class AqlItemBlock {
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
AqlValue const& getValueReference(size_t index, RegisterId varNr) const {
|
||||
if (_data.capacity() <= index * _nrRegs + varNr) {
|
||||
std::cout << "CAPACITY: " << _data.capacity() << ", INDEX: " << index << ", NRREGS: " << _nrRegs << ", VARNR: " << varNr << "\n";
|
||||
}
|
||||
TRI_ASSERT_EXPENSIVE(_data.capacity() > index * _nrRegs + varNr);
|
||||
return _data[index * _nrRegs + varNr];
|
||||
}
|
||||
|
|
|
@ -666,8 +666,16 @@ class Ast {
|
|||
|
||||
AstNode* nodeFromJson(TRI_json_t const*, bool);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// @brief traverse the AST using a depth-first visitor
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static AstNode* traverseAndModify(AstNode*,
|
||||
std::function<AstNode*(AstNode*, void*)>,
|
||||
void*);
|
||||
|
||||
private:
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// @brief make condition from example
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -786,14 +794,6 @@ class Ast {
|
|||
std::function<void(AstNode const*, void*)>,
|
||||
void*);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// @brief traverse the AST using a depth-first visitor
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static AstNode* traverseAndModify(AstNode*,
|
||||
std::function<AstNode*(AstNode*, void*)>,
|
||||
void*);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// @brief traverse the AST, using pre- and post-order visitors
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -27,8 +27,6 @@
|
|||
#include "Basics/Exceptions.h"
|
||||
#include "VocBase/vocbase.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
using namespace triagens::arango;
|
||||
using namespace triagens::aql;
|
||||
|
@ -576,7 +574,19 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost,
|
|||
|
||||
std::unordered_map<std::vector<AqlValue>, AggregateValuesType*, GroupKeyHash, GroupKeyEqual>
|
||||
allGroups(1024, GroupKeyHash(_trx, groupColls), GroupKeyEqual(_trx, groupColls));
|
||||
|
||||
|
||||
// cleanup function for group values
|
||||
auto cleanup = [&allGroups] () -> void {
|
||||
for (auto& it : allGroups) {
|
||||
for (auto& it2 : *(it.second)) {
|
||||
delete it2;
|
||||
}
|
||||
delete it.second;
|
||||
}
|
||||
};
|
||||
|
||||
// prevent memory leaks by always cleaning up the groups
|
||||
TRI_DEFER(cleanup());
|
||||
|
||||
auto buildResult = [&](AqlItemBlock const* src) {
|
||||
TRI_ASSERT(groupColls.size() == _groupRegisters.size());
|
||||
|
@ -617,13 +627,13 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost,
|
|||
size_t j = 0;
|
||||
for (auto const& r : *(it.second)) {
|
||||
// TODO: check if cloning is necessary
|
||||
result->setValue(row, _aggregateRegisters[j++].first, r->getValue());
|
||||
result->setValue(row, _aggregateRegisters[j++].first, r->stealValue());
|
||||
}
|
||||
|
||||
if (en->_count) {
|
||||
// set group count in result register
|
||||
// TODO: check if cloning is necessary
|
||||
result->setValue(row, _collectRegister, it.second->back()->getValue());
|
||||
result->setValue(row, _collectRegister, it.second->back()->stealValue());
|
||||
// int64_t value = (*(it.second))[0].toInt64();
|
||||
//result->setValue(row, _collectRegister, AqlValue(new Json(static_cast<double>(value))));
|
||||
}
|
||||
|
|
|
@ -233,10 +233,7 @@ std::vector<Variable const*> CollectNode::getVariablesUsedHere() const {
|
|||
|
||||
// copy result into vector
|
||||
std::vector<Variable const*> vv;
|
||||
vv.reserve(v.size());
|
||||
for (auto const& x : v) {
|
||||
vv.emplace_back(x);
|
||||
}
|
||||
vv.insert(vv.begin(), v.begin(), v.end());
|
||||
return vv;
|
||||
}
|
||||
|
||||
|
|
|
@ -244,9 +244,8 @@ class CollectNode : public ExecutionNode {
|
|||
|
||||
std::vector<Variable const*> getVariablesSetHere() const override final {
|
||||
std::vector<Variable const*> v;
|
||||
size_t const n =
|
||||
_groupVariables.size() + (_outVariable == nullptr ? 0 : 1);
|
||||
v.reserve(n);
|
||||
v.reserve(
|
||||
_groupVariables.size() + _aggregateVariables.size() + (_outVariable == nullptr ? 0 : 1));
|
||||
|
||||
for (auto const& p : _groupVariables) {
|
||||
v.emplace_back(p.first);
|
||||
|
|
|
@ -1077,10 +1077,10 @@ ExecutionNode* ExecutionPlan::fromNodeCollectCount(ExecutionNode* previous,
|
|||
std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> const
|
||||
aggregateVariables{};
|
||||
|
||||
auto en = registerNode(
|
||||
new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr,
|
||||
auto collectNode = new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr,
|
||||
outVariable, std::vector<Variable const*>(),
|
||||
_ast->variables()->variables(false), true, false));
|
||||
_ast->variables()->variables(false), true, false);
|
||||
auto en = registerNode(collectNode);
|
||||
|
||||
return addDependency(previous, en);
|
||||
}
|
||||
|
@ -1096,6 +1096,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
|
|||
|
||||
auto options = createCollectOptions(node->getMember(0));
|
||||
|
||||
std::unordered_map<Variable const*, Variable const*> aliases;
|
||||
|
||||
// group variables
|
||||
std::vector<std::pair<Variable const*, Variable const*>> groupVariables;
|
||||
{
|
||||
|
@ -1127,6 +1129,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
|
|||
auto calc = createTemporaryCalculation(expression, previous);
|
||||
previous = calc;
|
||||
groupVariables.emplace_back(std::make_pair(v, getOutVariable(calc)));
|
||||
|
||||
aliases.emplace(v, groupVariables.back().second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1134,6 +1138,17 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
|
|||
// aggregate variables
|
||||
std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> aggregateVariables;
|
||||
{
|
||||
auto variableReplacer = [&aliases, this] (AstNode* node, void*) -> AstNode* {
|
||||
if (node->type == NODE_TYPE_REFERENCE) {
|
||||
auto it = aliases.find(static_cast<Variable const*>(node->getData()));
|
||||
|
||||
if (it != aliases.end()) {
|
||||
return _ast->createNodeReference((*it).second);
|
||||
}
|
||||
}
|
||||
return node;
|
||||
};
|
||||
|
||||
auto list = node->getMember(2);
|
||||
size_t const numVars = list->numMembers();
|
||||
|
||||
|
@ -1156,21 +1171,34 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
|
|||
// operand is always a function call
|
||||
TRI_ASSERT(expression->type == NODE_TYPE_FCALL);
|
||||
|
||||
// function should have one argument (an array with the parameters)
|
||||
TRI_ASSERT(expression->numMembers() == 1);
|
||||
auto args = expression->getMember(0);
|
||||
// the number of arguments should also be one (note: this has been
|
||||
// validated before)
|
||||
TRI_ASSERT(expression->numMembers() == 1);
|
||||
|
||||
auto calc = createTemporaryCalculation(args->getMember(0), previous);
|
||||
previous = calc;
|
||||
|
||||
// build aggregator
|
||||
auto func = static_cast<Function*>(expression->getData());
|
||||
TRI_ASSERT(func != nullptr);
|
||||
|
||||
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName)));
|
||||
// function should have one argument (an array with the parameters)
|
||||
TRI_ASSERT(expression->numMembers() == 1);
|
||||
|
||||
auto args = expression->getMember(0);
|
||||
// the number of arguments should also be one (note: this has been
|
||||
// validated before)
|
||||
TRI_ASSERT(args->type == NODE_TYPE_ARRAY);
|
||||
TRI_ASSERT(args->numMembers() == 1);
|
||||
|
||||
auto arg = args->getMember(0);
|
||||
arg = Ast::traverseAndModify(arg, variableReplacer, nullptr);
|
||||
|
||||
|
||||
if (arg->type == NODE_TYPE_REFERENCE) {
|
||||
// operand is a variable
|
||||
auto e = static_cast<Variable*>(arg->getData());
|
||||
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(e, func->externalName)));
|
||||
}
|
||||
else {
|
||||
auto calc = createTemporaryCalculation(arg, previous);
|
||||
previous = calc;
|
||||
|
||||
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue