1
0
Fork 0

some bugfixes for aggregators

This commit is contained in:
jsteemann 2016-01-11 23:53:03 +01:00
parent 9ee2646b53
commit 9a2d9924b9
9 changed files with 80 additions and 53 deletions

View File

@ -76,7 +76,7 @@ void AggregatorLength::reduce(AqlValue const&, TRI_document_collection_t const*)
++count;
}
AqlValue AggregatorLength::getValue() {
AqlValue AggregatorLength::stealValue() {
uint64_t copy = count;
count = 0;
return AqlValue(new Json(static_cast<double>(copy)));
@ -100,7 +100,7 @@ void AggregatorMin::reduce(AqlValue const& cmpValue,
}
}
AqlValue AggregatorMin::getValue() {
AqlValue AggregatorMin::stealValue() {
AqlValue copy = value;
value.erase();
return copy;
@ -124,7 +124,7 @@ void AggregatorMax::reduce(AqlValue const& cmpValue,
}
}
AqlValue AggregatorMax::getValue() {
AqlValue AggregatorMax::stealValue() {
AqlValue copy = value;
value.erase();
return copy;
@ -152,7 +152,7 @@ void AggregatorSum::reduce(AqlValue const& cmpValue,
invalid = true;
}
AqlValue AggregatorSum::getValue() {
AqlValue AggregatorSum::stealValue() {
if (invalid || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) {
return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null));
}
@ -184,7 +184,7 @@ void AggregatorAverage::reduce(AqlValue const& cmpValue,
invalid = true;
}
AqlValue AggregatorAverage::getValue() {
AqlValue AggregatorAverage::stealValue() {
if (invalid || count == 0 || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) {
return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null));
}

View File

@ -44,7 +44,7 @@ struct Aggregator {
virtual char const* name() const = 0;
virtual void reset() = 0;
virtual void reduce(AqlValue const&, struct TRI_document_collection_t const*) = 0;
virtual AqlValue getValue() = 0;
virtual AqlValue stealValue() = 0;
static Aggregator* fromTypeString(triagens::arango::AqlTransaction*, std::string const&);
static Aggregator* fromJson(triagens::arango::AqlTransaction*, triagens::basics::Json const&,
@ -64,7 +64,7 @@ struct AggregatorLength final : public Aggregator {
void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final;
AqlValue stealValue() override final;
uint64_t count;
};
@ -80,7 +80,7 @@ struct AggregatorMin final : public Aggregator {
void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final;
AqlValue stealValue() override final;
AqlValue value;
struct TRI_document_collection_t const* coll;
@ -97,7 +97,7 @@ struct AggregatorMax final : public Aggregator {
void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final;
AqlValue stealValue() override final;
AqlValue value;
struct TRI_document_collection_t const* coll;
@ -114,7 +114,7 @@ struct AggregatorSum final : public Aggregator {
void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final;
AqlValue stealValue() override final;
double sum;
bool invalid;
@ -131,7 +131,7 @@ struct AggregatorAverage final : public Aggregator {
void reset() override final;
void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final;
AqlValue getValue() override final;
AqlValue stealValue() override final;
uint64_t count;
double sum;

View File

@ -30,7 +30,6 @@ using Json = triagens::basics::Json;
using JsonHelper = triagens::basics::JsonHelper;
////////////////////////////////////////////////////////////////////////////////
/// @brief create the block
////////////////////////////////////////////////////////////////////////////////

View File

@ -30,8 +30,6 @@
#include "Aql/Range.h"
#include "Aql/types.h"
#include <iostream>
struct TRI_document_collection_t;
namespace triagens {
@ -58,7 +56,6 @@ namespace aql {
class AqlItemBlock {
friend class AqlItemBlockManager;
public:
//////////////////////////////////////////////////////////////////////////////
@ -94,9 +91,6 @@ class AqlItemBlock {
//////////////////////////////////////////////////////////////////////////////
AqlValue const& getValueReference(size_t index, RegisterId varNr) const {
if (_data.capacity() <= index * _nrRegs + varNr) {
std::cout << "CAPACITY: " << _data.capacity() << ", INDEX: " << index << ", NRREGS: " << _nrRegs << ", VARNR: " << varNr << "\n";
}
TRI_ASSERT_EXPENSIVE(_data.capacity() > index * _nrRegs + varNr);
return _data[index * _nrRegs + varNr];
}

View File

@ -666,8 +666,16 @@ class Ast {
AstNode* nodeFromJson(TRI_json_t const*, bool);
//////////////////////////////////////////////////////////////////////////////
/// @brief traverse the AST using a depth-first visitor
//////////////////////////////////////////////////////////////////////////////
static AstNode* traverseAndModify(AstNode*,
std::function<AstNode*(AstNode*, void*)>,
void*);
private:
//////////////////////////////////////////////////////////////////////////////
/// @brief make condition from example
//////////////////////////////////////////////////////////////////////////////
@ -786,14 +794,6 @@ class Ast {
std::function<void(AstNode const*, void*)>,
void*);
//////////////////////////////////////////////////////////////////////////////
/// @brief traverse the AST using a depth-first visitor
//////////////////////////////////////////////////////////////////////////////
static AstNode* traverseAndModify(AstNode*,
std::function<AstNode*(AstNode*, void*)>,
void*);
//////////////////////////////////////////////////////////////////////////////
/// @brief traverse the AST, using pre- and post-order visitors
//////////////////////////////////////////////////////////////////////////////

View File

@ -27,8 +27,6 @@
#include "Basics/Exceptions.h"
#include "VocBase/vocbase.h"
#include <iostream>
using namespace std;
using namespace triagens::arango;
using namespace triagens::aql;
@ -576,7 +574,19 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost,
std::unordered_map<std::vector<AqlValue>, AggregateValuesType*, GroupKeyHash, GroupKeyEqual>
allGroups(1024, GroupKeyHash(_trx, groupColls), GroupKeyEqual(_trx, groupColls));
// cleanup function for group values
auto cleanup = [&allGroups] () -> void {
for (auto& it : allGroups) {
for (auto& it2 : *(it.second)) {
delete it2;
}
delete it.second;
}
};
// prevent memory leaks by always cleaning up the groups
TRI_DEFER(cleanup());
auto buildResult = [&](AqlItemBlock const* src) {
TRI_ASSERT(groupColls.size() == _groupRegisters.size());
@ -617,13 +627,13 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost,
size_t j = 0;
for (auto const& r : *(it.second)) {
// TODO: check if cloning is necessary
result->setValue(row, _aggregateRegisters[j++].first, r->getValue());
result->setValue(row, _aggregateRegisters[j++].first, r->stealValue());
}
if (en->_count) {
// set group count in result register
// TODO: check if cloning is necessary
result->setValue(row, _collectRegister, it.second->back()->getValue());
result->setValue(row, _collectRegister, it.second->back()->stealValue());
// int64_t value = (*(it.second))[0].toInt64();
//result->setValue(row, _collectRegister, AqlValue(new Json(static_cast<double>(value))));
}

View File

@ -233,10 +233,7 @@ std::vector<Variable const*> CollectNode::getVariablesUsedHere() const {
// copy result into vector
std::vector<Variable const*> vv;
vv.reserve(v.size());
for (auto const& x : v) {
vv.emplace_back(x);
}
vv.insert(vv.begin(), v.begin(), v.end());
return vv;
}

View File

@ -244,9 +244,8 @@ class CollectNode : public ExecutionNode {
std::vector<Variable const*> getVariablesSetHere() const override final {
std::vector<Variable const*> v;
size_t const n =
_groupVariables.size() + (_outVariable == nullptr ? 0 : 1);
v.reserve(n);
v.reserve(
_groupVariables.size() + _aggregateVariables.size() + (_outVariable == nullptr ? 0 : 1));
for (auto const& p : _groupVariables) {
v.emplace_back(p.first);

View File

@ -1077,10 +1077,10 @@ ExecutionNode* ExecutionPlan::fromNodeCollectCount(ExecutionNode* previous,
std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> const
aggregateVariables{};
auto en = registerNode(
new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr,
auto collectNode = new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr,
outVariable, std::vector<Variable const*>(),
_ast->variables()->variables(false), true, false));
_ast->variables()->variables(false), true, false);
auto en = registerNode(collectNode);
return addDependency(previous, en);
}
@ -1096,6 +1096,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
auto options = createCollectOptions(node->getMember(0));
std::unordered_map<Variable const*, Variable const*> aliases;
// group variables
std::vector<std::pair<Variable const*, Variable const*>> groupVariables;
{
@ -1127,6 +1129,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
auto calc = createTemporaryCalculation(expression, previous);
previous = calc;
groupVariables.emplace_back(std::make_pair(v, getOutVariable(calc)));
aliases.emplace(v, groupVariables.back().second);
}
}
}
@ -1134,6 +1138,17 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
// aggregate variables
std::vector<std::pair<Variable const*, std::pair<Variable const*, std::string>>> aggregateVariables;
{
auto variableReplacer = [&aliases, this] (AstNode* node, void*) -> AstNode* {
if (node->type == NODE_TYPE_REFERENCE) {
auto it = aliases.find(static_cast<Variable const*>(node->getData()));
if (it != aliases.end()) {
return _ast->createNodeReference((*it).second);
}
}
return node;
};
auto list = node->getMember(2);
size_t const numVars = list->numMembers();
@ -1156,21 +1171,34 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous,
// operand is always a function call
TRI_ASSERT(expression->type == NODE_TYPE_FCALL);
// function should have one argument (an array with the parameters)
TRI_ASSERT(expression->numMembers() == 1);
auto args = expression->getMember(0);
// the number of arguments should also be one (note: this has been
// validated before)
TRI_ASSERT(expression->numMembers() == 1);
auto calc = createTemporaryCalculation(args->getMember(0), previous);
previous = calc;
// build aggregator
auto func = static_cast<Function*>(expression->getData());
TRI_ASSERT(func != nullptr);
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName)));
// function should have one argument (an array with the parameters)
TRI_ASSERT(expression->numMembers() == 1);
auto args = expression->getMember(0);
// the number of arguments should also be one (note: this has been
// validated before)
TRI_ASSERT(args->type == NODE_TYPE_ARRAY);
TRI_ASSERT(args->numMembers() == 1);
auto arg = args->getMember(0);
arg = Ast::traverseAndModify(arg, variableReplacer, nullptr);
if (arg->type == NODE_TYPE_REFERENCE) {
// operand is a variable
auto e = static_cast<Variable*>(arg->getData());
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(e, func->externalName)));
}
else {
auto calc = createTemporaryCalculation(arg, previous);
previous = calc;
aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName)));
}
}
}