diff --git a/arangod/Aql/Aggregator.cpp b/arangod/Aql/Aggregator.cpp index f658ca97bf..a3fa84740e 100644 --- a/arangod/Aql/Aggregator.cpp +++ b/arangod/Aql/Aggregator.cpp @@ -76,7 +76,7 @@ void AggregatorLength::reduce(AqlValue const&, TRI_document_collection_t const*) ++count; } -AqlValue AggregatorLength::getValue() { +AqlValue AggregatorLength::stealValue() { uint64_t copy = count; count = 0; return AqlValue(new Json(static_cast(copy))); @@ -100,7 +100,7 @@ void AggregatorMin::reduce(AqlValue const& cmpValue, } } -AqlValue AggregatorMin::getValue() { +AqlValue AggregatorMin::stealValue() { AqlValue copy = value; value.erase(); return copy; @@ -124,7 +124,7 @@ void AggregatorMax::reduce(AqlValue const& cmpValue, } } -AqlValue AggregatorMax::getValue() { +AqlValue AggregatorMax::stealValue() { AqlValue copy = value; value.erase(); return copy; @@ -152,7 +152,7 @@ void AggregatorSum::reduce(AqlValue const& cmpValue, invalid = true; } -AqlValue AggregatorSum::getValue() { +AqlValue AggregatorSum::stealValue() { if (invalid || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) { return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null)); } @@ -184,7 +184,7 @@ void AggregatorAverage::reduce(AqlValue const& cmpValue, invalid = true; } -AqlValue AggregatorAverage::getValue() { +AqlValue AggregatorAverage::stealValue() { if (invalid || count == 0 || std::isnan(sum) || sum == HUGE_VAL || sum == -HUGE_VAL) { return AqlValue(new triagens::basics::Json(triagens::basics::Json::Null)); } diff --git a/arangod/Aql/Aggregator.h b/arangod/Aql/Aggregator.h index f7dc4488f3..95102f8750 100644 --- a/arangod/Aql/Aggregator.h +++ b/arangod/Aql/Aggregator.h @@ -44,7 +44,7 @@ struct Aggregator { virtual char const* name() const = 0; virtual void reset() = 0; virtual void reduce(AqlValue const&, struct TRI_document_collection_t const*) = 0; - virtual AqlValue getValue() = 0; + virtual AqlValue stealValue() = 0; static Aggregator* fromTypeString(triagens::arango::AqlTransaction*, std::string const&); static Aggregator* fromJson(triagens::arango::AqlTransaction*, triagens::basics::Json const&, @@ -64,7 +64,7 @@ struct AggregatorLength final : public Aggregator { void reset() override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; - AqlValue getValue() override final; + AqlValue stealValue() override final; uint64_t count; }; @@ -80,7 +80,7 @@ struct AggregatorMin final : public Aggregator { void reset() override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; - AqlValue getValue() override final; + AqlValue stealValue() override final; AqlValue value; struct TRI_document_collection_t const* coll; @@ -97,7 +97,7 @@ struct AggregatorMax final : public Aggregator { void reset() override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; - AqlValue getValue() override final; + AqlValue stealValue() override final; AqlValue value; struct TRI_document_collection_t const* coll; @@ -114,7 +114,7 @@ struct AggregatorSum final : public Aggregator { void reset() override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; - AqlValue getValue() override final; + AqlValue stealValue() override final; double sum; bool invalid; @@ -131,7 +131,7 @@ struct AggregatorAverage final : public Aggregator { void reset() override final; void reduce(AqlValue const&, struct TRI_document_collection_t const*) override final; - AqlValue getValue() override final; + AqlValue stealValue() override final; uint64_t count; double sum; diff --git a/arangod/Aql/AqlItemBlock.cpp b/arangod/Aql/AqlItemBlock.cpp index 57114b5e70..913c9aafcd 100644 --- a/arangod/Aql/AqlItemBlock.cpp +++ b/arangod/Aql/AqlItemBlock.cpp @@ -30,7 +30,6 @@ using Json = triagens::basics::Json; using JsonHelper = triagens::basics::JsonHelper; - //////////////////////////////////////////////////////////////////////////////// /// @brief create the block //////////////////////////////////////////////////////////////////////////////// diff --git a/arangod/Aql/AqlItemBlock.h b/arangod/Aql/AqlItemBlock.h index 10856211ea..8aa68f3001 100644 --- a/arangod/Aql/AqlItemBlock.h +++ b/arangod/Aql/AqlItemBlock.h @@ -30,8 +30,6 @@ #include "Aql/Range.h" #include "Aql/types.h" -#include - struct TRI_document_collection_t; namespace triagens { @@ -58,7 +56,6 @@ namespace aql { class AqlItemBlock { friend class AqlItemBlockManager; - public: ////////////////////////////////////////////////////////////////////////////// @@ -94,9 +91,6 @@ class AqlItemBlock { ////////////////////////////////////////////////////////////////////////////// AqlValue const& getValueReference(size_t index, RegisterId varNr) const { - if (_data.capacity() <= index * _nrRegs + varNr) { - std::cout << "CAPACITY: " << _data.capacity() << ", INDEX: " << index << ", NRREGS: " << _nrRegs << ", VARNR: " << varNr << "\n"; - } TRI_ASSERT_EXPENSIVE(_data.capacity() > index * _nrRegs + varNr); return _data[index * _nrRegs + varNr]; } diff --git a/arangod/Aql/Ast.h b/arangod/Aql/Ast.h index 7e2d3a4a1c..b9a4c01c86 100644 --- a/arangod/Aql/Ast.h +++ b/arangod/Aql/Ast.h @@ -666,8 +666,16 @@ class Ast { AstNode* nodeFromJson(TRI_json_t const*, bool); - + ////////////////////////////////////////////////////////////////////////////// + /// @brief traverse the AST using a depth-first visitor + ////////////////////////////////////////////////////////////////////////////// + + static AstNode* traverseAndModify(AstNode*, + std::function, + void*); + private: + ////////////////////////////////////////////////////////////////////////////// /// @brief make condition from example ////////////////////////////////////////////////////////////////////////////// @@ -786,14 +794,6 @@ class Ast { std::function, void*); - ////////////////////////////////////////////////////////////////////////////// - /// @brief traverse the AST using a depth-first visitor - ////////////////////////////////////////////////////////////////////////////// - - static AstNode* traverseAndModify(AstNode*, - std::function, - void*); - ////////////////////////////////////////////////////////////////////////////// /// @brief traverse the AST, using pre- and post-order visitors ////////////////////////////////////////////////////////////////////////////// diff --git a/arangod/Aql/CollectBlock.cpp b/arangod/Aql/CollectBlock.cpp index 5c9e70ed29..39749731a1 100644 --- a/arangod/Aql/CollectBlock.cpp +++ b/arangod/Aql/CollectBlock.cpp @@ -27,8 +27,6 @@ #include "Basics/Exceptions.h" #include "VocBase/vocbase.h" -#include - using namespace std; using namespace triagens::arango; using namespace triagens::aql; @@ -576,7 +574,19 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost, std::unordered_map, AggregateValuesType*, GroupKeyHash, GroupKeyEqual> allGroups(1024, GroupKeyHash(_trx, groupColls), GroupKeyEqual(_trx, groupColls)); - + + // cleanup function for group values + auto cleanup = [&allGroups] () -> void { + for (auto& it : allGroups) { + for (auto& it2 : *(it.second)) { + delete it2; + } + delete it.second; + } + }; + + // prevent memory leaks by always cleaning up the groups + TRI_DEFER(cleanup()); auto buildResult = [&](AqlItemBlock const* src) { TRI_ASSERT(groupColls.size() == _groupRegisters.size()); @@ -617,13 +627,13 @@ int HashedCollectBlock::getOrSkipSome(size_t atLeast, size_t atMost, size_t j = 0; for (auto const& r : *(it.second)) { // TODO: check if cloning is necessary - result->setValue(row, _aggregateRegisters[j++].first, r->getValue()); + result->setValue(row, _aggregateRegisters[j++].first, r->stealValue()); } if (en->_count) { // set group count in result register // TODO: check if cloning is necessary - result->setValue(row, _collectRegister, it.second->back()->getValue()); + result->setValue(row, _collectRegister, it.second->back()->stealValue()); // int64_t value = (*(it.second))[0].toInt64(); //result->setValue(row, _collectRegister, AqlValue(new Json(static_cast(value)))); } diff --git a/arangod/Aql/CollectNode.cpp b/arangod/Aql/CollectNode.cpp index ca75986506..196c015346 100644 --- a/arangod/Aql/CollectNode.cpp +++ b/arangod/Aql/CollectNode.cpp @@ -233,10 +233,7 @@ std::vector CollectNode::getVariablesUsedHere() const { // copy result into vector std::vector vv; - vv.reserve(v.size()); - for (auto const& x : v) { - vv.emplace_back(x); - } + vv.insert(vv.begin(), v.begin(), v.end()); return vv; } diff --git a/arangod/Aql/CollectNode.h b/arangod/Aql/CollectNode.h index cdd58d8ffb..c21591b5c6 100644 --- a/arangod/Aql/CollectNode.h +++ b/arangod/Aql/CollectNode.h @@ -244,9 +244,8 @@ class CollectNode : public ExecutionNode { std::vector getVariablesSetHere() const override final { std::vector v; - size_t const n = - _groupVariables.size() + (_outVariable == nullptr ? 0 : 1); - v.reserve(n); + v.reserve( + _groupVariables.size() + _aggregateVariables.size() + (_outVariable == nullptr ? 0 : 1)); for (auto const& p : _groupVariables) { v.emplace_back(p.first); diff --git a/arangod/Aql/ExecutionPlan.cpp b/arangod/Aql/ExecutionPlan.cpp index ab42b6b0a5..e1eb4c045d 100644 --- a/arangod/Aql/ExecutionPlan.cpp +++ b/arangod/Aql/ExecutionPlan.cpp @@ -1077,10 +1077,10 @@ ExecutionNode* ExecutionPlan::fromNodeCollectCount(ExecutionNode* previous, std::vector>> const aggregateVariables{}; - auto en = registerNode( - new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr, + auto collectNode = new CollectNode(this, nextId(), options, groupVariables, aggregateVariables, nullptr, outVariable, std::vector(), - _ast->variables()->variables(false), true, false)); + _ast->variables()->variables(false), true, false); + auto en = registerNode(collectNode); return addDependency(previous, en); } @@ -1096,6 +1096,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous, auto options = createCollectOptions(node->getMember(0)); + std::unordered_map aliases; + // group variables std::vector> groupVariables; { @@ -1127,6 +1129,8 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous, auto calc = createTemporaryCalculation(expression, previous); previous = calc; groupVariables.emplace_back(std::make_pair(v, getOutVariable(calc))); + + aliases.emplace(v, groupVariables.back().second); } } } @@ -1134,6 +1138,17 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous, // aggregate variables std::vector>> aggregateVariables; { + auto variableReplacer = [&aliases, this] (AstNode* node, void*) -> AstNode* { + if (node->type == NODE_TYPE_REFERENCE) { + auto it = aliases.find(static_cast(node->getData())); + + if (it != aliases.end()) { + return _ast->createNodeReference((*it).second); + } + } + return node; + }; + auto list = node->getMember(2); size_t const numVars = list->numMembers(); @@ -1156,21 +1171,34 @@ ExecutionNode* ExecutionPlan::fromNodeCollectAggregate(ExecutionNode* previous, // operand is always a function call TRI_ASSERT(expression->type == NODE_TYPE_FCALL); - // function should have one argument (an array with the parameters) - TRI_ASSERT(expression->numMembers() == 1); - auto args = expression->getMember(0); - // the number of arguments should also be one (note: this has been - // validated before) - TRI_ASSERT(expression->numMembers() == 1); - - auto calc = createTemporaryCalculation(args->getMember(0), previous); - previous = calc; - // build aggregator auto func = static_cast(expression->getData()); TRI_ASSERT(func != nullptr); - aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName))); + // function should have one argument (an array with the parameters) + TRI_ASSERT(expression->numMembers() == 1); + + auto args = expression->getMember(0); + // the number of arguments should also be one (note: this has been + // validated before) + TRI_ASSERT(args->type == NODE_TYPE_ARRAY); + TRI_ASSERT(args->numMembers() == 1); + + auto arg = args->getMember(0); + arg = Ast::traverseAndModify(arg, variableReplacer, nullptr); + + + if (arg->type == NODE_TYPE_REFERENCE) { + // operand is a variable + auto e = static_cast(arg->getData()); + aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(e, func->externalName))); + } + else { + auto calc = createTemporaryCalculation(arg, previous); + previous = calc; + + aggregateVariables.emplace_back(std::make_pair(v, std::make_pair(getOutVariable(calc), func->externalName))); + } } }