1
0
Fork 0

some bugfixes for aggregators

This commit is contained in:
jsteemann 2016-01-11 23:53:03 +01:00
parent 9ee2646b53
commit 9a2d9924b9
9 changed files with 80 additions and 53 deletions

View File

@ -76,7 +76,7 @@ void AggregatorLength::reduce(AqlValue const&, TRI_document_collection_t const*)
++count; ++count;
} }
AqlValue AggregatorLength::getValue() { AqlValue AggregatorLength::stealValue() {
uint64_t copy = count; uint64_t copy = count;
count = 0; count = 0;
return AqlValue(new Json(static_cast<double>(copy))); return AqlValue(new Json(static_cast<double>(copy)));
@ -100,7 +100,7 @@ void AggregatorMin::reduce(AqlValue const& cmpValue,
} }
} }
AqlValue AggregatorMin::getValue() { AqlValue AggregatorMin::stealValue() {
AqlValue copy = value; AqlValue copy = value;
value.erase(); value.erase();
return copy; return copy;
@ -124,7 +124,7 @@ void AggregatorMax::reduce(AqlValue const& cmpValue,
} }
} }
AqlValue AggregatorMax::getValue() { AqlValue AggregatorMax::stealValue() {
AqlValue copy = value; AqlValue copy = value;
value.erase(); value.erase();
return copy; return copy;
@ -152,7 +152,7 @@ void AggregatorSum::reduce(AqlValue const& cmpValue,
invalid = true; invalid = true;
} }
AqlValue AggregatorSum::getValue() { AqlValue AggregatorSum::stealValue() {
if (invalid || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) { if (invalid || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) {
return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null)); return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null));
} }
@ -184,7 +184,7 @@ void AggregatorAverage::reduce(AqlValue const& cmpValue,
invalid = true; invalid = true;
} }
AqlValue AggregatorAverage::getValue() { AqlValue AggregatorAverage::stealValue() {
if (invalid || count == 0 || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) { if (invalid || count == 0 || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) {
return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null)); return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null));
} }

View File

@ -44,7 +44,7 @@ struct Aggregator {
virtual char const* name() const = 0; virtual char const* name() const = 0;
virtual void reset() = 0; virtual void reset() = 0;
virtual void reduce(AqlValue const&, struct TRI_document_collection_t const*) = 0; virtual void reduce(AqlValue const&, struct TRI_document_collection_t const*) = 0;
virtual AqlValue getValue() = 0; virtual AqlValue stealValue() = 0;
static Aggregator* fromTypeString(triagens::arango::AqlTransaction*, std::string const&); static Aggregator* fromTypeString(triagens::arango::AqlTransaction*, std::string const&);
static Aggregator* fromJson(triagens::arango::AqlTransaction*, triagens::basics::Json const&, static Aggregator* fromJson(triagens::arango::AqlTransaction*, triagens::basics::Json const&,
@ -64,7 +64,7 @@ struct AggregatorLength final : public Aggregator {
void reset() override final; void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final; AqlValue stealValue() override final;
uint64_t count; uint64_t count;
}; };
@ -80,7 +80,7 @@ struct AggregatorMin final : public Aggregator {
void reset() override final; void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final; AqlValue stealValue() override final;
AqlValue value; AqlValue value;
struct TRI_document_collection_t const* coll; struct TRI_document_collection_t const* coll;
@ -97,7 +97,7 @@ struct AggregatorMax final : public Aggregator {
void reset() override final; void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final; AqlValue stealValue() override final;
AqlValue value; AqlValue value;
struct TRI_document_collection_t const* coll; struct TRI_document_collection_t const* coll;
@ -114,7 +114,7 @@ struct AggregatorSum final : public Aggregator {
void reset() override final; void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final; AqlValue stealValue() override final;
double sum; double sum;
bool invalid; bool invalid;
@ -131,7 +131,7 @@ struct AggregatorAverage final : public Aggregator {
void reset() override final; void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final; AqlValue stealValue() override final;
uint64_t count; uint64_t count;
double sum; double sum;

View File

@ -30,7 +30,6 @@ using Json = triagens::basics::Json;
using JsonHelper = triagens::basics::JsonHelper; using JsonHelper = triagens::basics::JsonHelper;
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
/// @brief create the block /// @brief create the block
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -30,8 +30,6 @@
#include "Aql/Range.h" #include "Aql/Range.h"
#include "Aql/types.h" #include "Aql/types.h"
#include <iostream>
struct TRI_document_collection_t; struct TRI_document_collection_t;
namespace triagens { namespace triagens {
@ -59,7 +57,6 @@ namespace aql {
class AqlItemBlock { class AqlItemBlock {
friend class AqlItemBlockManager; friend class AqlItemBlockManager;
public: public:
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// @brief create the block /// @brief create the block
@ -94,9 +91,6 @@ class AqlItemBlock {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
AqlValue const& getValueReference(size_t index, RegisterId varNr) const { AqlValue const& getValueReference(size_t index, RegisterId varNr) const {
if (_data.capacity() <= index * _nrRegs + varNr) {
std::cout << "CAPACITY: " << _data.capacity() << ", INDEX: " << index << ", NRREGS: " << _nrRegs << ", VARNR: " << varNr << "\n";
}
TRI_ASSERT_EXPENSIVE(_data.capacity() > index * _nrRegs + varNr); TRI_ASSERT_EXPENSIVE(_data.capacity() > index * _nrRegs + varNr);
return _data[index * _nrRegs + varNr]; return _data[index * _nrRegs + varNr];
} }

View File

@ -666,8 +666,16 @@ class Ast {
AstNode* nodeFromJson(TRI_json_t const*, bool); AstNode* nodeFromJson(TRI_json_t const*, bool);
//////////////////////////////////////////////////////////////////////////////
/// @brief traverse the AST using a depth-first visitor
//////////////////////////////////////////////////////////////////////////////
static AstNode* traverseAndModify(AstNode*,
std::function<AstNode*(AstNode*, void*)>,
void*);
private: private:
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// @brief make condition from example /// @brief make condition from example
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -786,14 +794,6 @@ class Ast {
std::function<void(AstNode const*, void*)>, std::function<void(AstNode const*, void*)>,
void*); void*);
//////////////////////////////////////////////////////////////////////////////
/// @brief traverse the AST using a depth-first visitor
//////////////////////////////////////////////////////////////////////////////
static AstNode* traverseAndModify(AstNode*,
std::function<AstNode*(AstNode*, void*)>,
void*);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
/// @brief traverse the AST, using pre- and post-order visitors /// @brief traverse the AST, using pre- and post-order visitors
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -27,8 +27,6 @@
#include "Basics/Exceptions.h" #include "Basics/Exceptions.h"
#include "VocBase/vocbase.h" #include "VocBase/vocbase.h"
#include <iostream>
using namespace std; using namespace std;
using namespace triagens::arango; using namespace triagens::arango;
using namespace triagens::aql; using namespace triagens::aql;
@ -577,6 +575,18 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost,
std::unordered_map<std::vector<AqlValue>, AggregateValuesType*, GroupKeyHash, GroupKeyEqual> std::unordered_map<std::vector<AqlValue>, AggregateValuesType*, GroupKeyHash, GroupKeyEqual>
allGroups(1024, GroupKeyHash(_trx, groupColls), GroupKeyEqual(_trx, groupColls)); allGroups(1024, GroupKeyHash(_trx, groupColls), GroupKeyEqual(_trx, groupColls));
// cleanup function for group values
auto cleanup = [&allGroups] () -> void {
for (auto& it : allGroups) {
for (auto& it2 : *(it.second)) {
delete it2;
}
delete it.second;
}
};
// prevent memory leaks by always cleaning up the groups
TRI_DEFER(cleanup());
auto buildResult = [&](AqlItemBlock const* src) { auto buildResult = [&](AqlItemBlock const* src) {
TRI_ASSERT(groupColls.size() == _groupRegisters.size()); TRI_ASSERT(groupColls.size() == _groupRegisters.size());
@ -617,13 +627,13 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost,
size_t j = 0; size_t j = 0;
for (auto const& r : *(it.second)) { for (auto const& r : *(it.second)) {
// TODO: check if cloning is necessary // TODO: check if cloning is necessary
result->setValue(row, _aggregateRegisters[j++].first, r->getValue()); result->setValue(row, _aggregateRegisters[j++].first, r->stealValue());
} }
if (en->_count) { if (en->_count) {
// set group count in result register // set group count in result register
// TODO: check if cloning is necessary // TODO: check if cloning is necessary
result->setValue(row, _collectRegister, it.second->back()->getValue()); result->setValue(row, _collectRegister, it.second->back()->stealValue());
// int64_t value = (*(it.second))[0].toInt64(); // int64_t value = (*(it.second))[0].toInt64();
//result->setValue(row, _collectRegister, AqlValue(new Json(static_cast<double>(value)))); //result->setValue(row, _collectRegister, AqlValue(new Json(static_cast<double>(value))));
} }

View File

@ -233,10 +233,7 @@ std::vector<Variable const*> CollectNode::getVariablesUsedHere() const {
// copy result into vector // copy result into vector
std::vector<Variable const*> vv; std::vector<Variable const*> vv;
vv.reserve(v.size()); vv.insert(vv.begin(), v.begin(), v.end());
for (auto const& x : v) {
vv.emplace_back(x);
}
return vv; return vv;
} }

View File

@ -244,9 +244,8 @@ class CollectNode : public ExecutionNode {
std::vector<Variable const*> getVariablesSetHere() const override final { std::vector<Variable const*> getVariablesSetHere() const override final {
std::vector<Variable const*> v; std::vector<Variable const*> v;
size_t const n = v.reserve(
_groupVariables.size() + (_outVariable == nullptr ? 0 : 1); _groupVariables.size() + _aggregateVariables.size() + (_outVariable == nullptr ? 0 : 1));
v.reserve(n);
for (auto const& p : _groupVariables) { for (auto const& p : _groupVariables) {
v.emplace_back(p.first); v.emplace_back(p.first);

View File

@ -1077,10 +1077,10 @@ ExecutionNode* ExecutionPlan::fromNodeCollectCount(ExecutionNode* previous,
std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> const std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> const
aggregateVariables{}; aggregateVariables{};
auto en = registerNode( auto collectNode = new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr,
new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr,
outVariable, std::vector<Variable const*>(), outVariable, std::vector<Variable const*>(),
_ast->variables()->variables(false), true, false)); _ast->variables()->variables(false), true, false);
auto en = registerNode(collectNode);
return addDependency(previous, en); return addDependency(previous, en);
} }
@ -1096,6 +1096,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
auto options = createCollectOptions(node->getMember(0)); auto options = createCollectOptions(node->getMember(0));
std::unordered_map<Variable const*, Variable const*> aliases;
// group variables // group variables
std::vector<std::pair<Variable const*, Variable const*>> groupVariables; std::vector<std::pair<Variable const*, Variable const*>> groupVariables;
{ {
@ -1127,6 +1129,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
auto calc = createTemporaryCalculation(expression, previous); auto calc = createTemporaryCalculation(expression, previous);
previous = calc; previous = calc;
groupVariables.emplace_back(std::make_pair(v, getOutVariable(calc))); groupVariables.emplace_back(std::make_pair(v, getOutVariable(calc)));
aliases.emplace(v, groupVariables.back().second);
} }
} }
} }
@ -1134,6 +1138,17 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
// aggregate variables // aggregate variables
std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> aggregateVariables; std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> aggregateVariables;
{ {
auto variableReplacer = [&aliases, this] (AstNode* node, void*) -> AstNode* {
if (node->type == NODE_TYPE_REFERENCE) {
auto it = aliases.find(static_cast<Variable const*>(node->getData()));
if (it != aliases.end()) {
return _ast->createNodeReference((*it).second);
}
}
return node;
};
auto list = node->getMember(2); auto list = node->getMember(2);
size_t const numVars = list->numMembers(); size_t const numVars = list->numMembers();
@ -1156,23 +1171,36 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
// operand is always a function call // operand is always a function call
TRI_ASSERT(expression->type == NODE_TYPE_FCALL); TRI_ASSERT(expression->type == NODE_TYPE_FCALL);
// function should have one argument (an array with the parameters)
TRI_ASSERT(expression->numMembers() == 1);
auto args = expression->getMember(0);
// the number of arguments should also be one (note: this has been
// validated before)
TRI_ASSERT(expression->numMembers() == 1);
auto calc = createTemporaryCalculation(args->getMember(0), previous);
previous = calc;
// build aggregator // build aggregator
auto func = static_cast<Function*>(expression->getData()); auto func = static_cast<Function*>(expression->getData());
TRI_ASSERT(func != nullptr); TRI_ASSERT(func != nullptr);
// function should have one argument (an array with the parameters)
TRI_ASSERT(expression->numMembers() == 1);
auto args = expression->getMember(0);
// the number of arguments should also be one (note: this has been
// validated before)
TRI_ASSERT(args->type == NODE_TYPE_ARRAY);
TRI_ASSERT(args->numMembers() == 1);
auto arg = args->getMember(0);
arg = Ast::traverseAndModify(arg, variableReplacer, nullptr);
if (arg->type == NODE_TYPE_REFERENCE) {
// operand is a variable
auto e = static_cast<Variable*>(arg->getData());
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(e, func->externalName)));
}
else {
auto calc = createTemporaryCalculation(arg, previous);
previous = calc;
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName))); aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName)));
} }
} }
}
auto collectNode = new CollectNode( auto collectNode = new CollectNode(
this, nextId(), options, groupVariables, aggregateVariables, nullptr, this, nextId(), options, groupVariables, aggregateVariables, nullptr,