1
0
Fork 0

finalize distinct aggregator code

This commit is contained in:
jsteemann 2018-06-26 22:46:36 +02:00
parent 2610af55f8
commit c755c7927b
2 changed files with 108 additions and 44 deletions

View File

@ -305,12 +305,10 @@ struct AggregatorAverage : public Aggregator {
}
TRI_ASSERT(count > 0);
builder.clear();
builder.add(VPackValue(sum / count));
AqlValue temp(builder.slice());
double v = sum / count;
reset();
return temp;
return AqlValue(AqlValueHintDouble(v));
}
uint64_t count;
@ -432,19 +430,18 @@ struct AggregatorVariance : public AggregatorVarianceBase {
}
TRI_ASSERT(count > 0);
builder.clear();
double v;
if (!population) {
TRI_ASSERT(count > 1);
builder.add(VPackValue(sum / (count - 1)));
v = sum / (count - 1);
}
else {
builder.add(VPackValue(sum / count));
v = sum / count;
}
AqlValue temp(builder.slice());
reset();
return temp;
return AqlValue(AqlValueHintDouble(v));
}
};
@ -476,46 +473,81 @@ struct AggregatorVarianceBaseStep1 final : public AggregatorVarianceBase {
}
};
// NOTE(review): this span comes from a diff rendering whose +/- markers were
// lost, so removed and added lines appear interleaved (two struct headers,
// paired `this->invalid` / `invalid` statements, ...). The lines using
// `template<typename T>` and `this->` appear to be the removed side, and the
// non-template struct deriving from AggregatorVarianceBase the added side --
// confirm against the repository before relying on either.
/// @brief the coordinator variant of VARIANCE/STDDEV
template<typename T>
struct AggregatorVarianceBaseStep2 final : public T {
/// @brief the coordinator variant of VARIANCE
// Merges partial results: reduce() takes one array per partial result and
// stealValue() combines everything collected in `values` into one variance.
struct AggregatorVarianceBaseStep2 : public AggregatorVarianceBase {
AggregatorVarianceBaseStep2(transaction::Methods* trx, bool population)
: T(trx, population) {}
: AggregatorVarianceBase(trx, population) {}
// Resets the base class state and drops all collected partial results.
void reset() override {
AggregatorVarianceBase::reset();
values.clear();
}
// Consumes one partial result. The input must be an array; its elements at
// positions 0, 1 and 2 are read as count, sum and mean. The aggregator is
// flagged `invalid` (and the input ignored) when the input is not an array,
// when any of the three elements is null, when sum or mean cannot be
// converted to a double, or when the count is 0.
void reduce(AqlValue const& cmpValue) override {
if (!cmpValue.isArray()) {
this->invalid = true;
invalid = true;
return;
}
// The block comment below is a JavaScript prototype of the merge: it computes
// count/mean/sum per part, then the count-weighted overall average, then
// sums count * variance + count * (mean - average)^2 over all parts.
/*
data = [[10,8,5,3],[4,9,6],[2,10,7,2]]; parts = []; data.forEach(function(part) { let count = 0; let sum = 0; let mean = 0; part.forEach(function(value) { ++count; let delta = value - mean; mean += delta / count; sum += delta * (value - mean); }); parts.push({count, mean, sum, variance:sum/count })}); print(parts); average = 0; n = 0; parts.forEach(function(part) { n += part.count; average += part.mean * part.count; }); average /= n; squares = []; parts.forEach(function(part) { squares.push((part.count) * part.variance + part.count * Math.pow(part.mean - average, 2)); }); s = 0; squares.forEach(function(square) { s += square; }); print(s / (n ));
*/
// NOTE(review): `mustDestroy` is overwritten by every at() call and never
// read in this function -- confirm that no cleanup of the returned values is
// required here.
bool mustDestroy;
AqlValue const& countValue = cmpValue.at(this->trx, 0, mustDestroy, false);
AqlValue const& sumValue = cmpValue.at(this->trx, 1, mustDestroy, false);
AqlValue const& meanValue = cmpValue.at(this->trx, 2, mustDestroy, false);
AqlValue const& countValue = cmpValue.at(trx, 0, mustDestroy, false);
AqlValue const& sumValue = cmpValue.at(trx, 1, mustDestroy, false);
AqlValue const& meanValue = cmpValue.at(trx, 2, mustDestroy, false);
if (countValue.isNull(true) || sumValue.isNull(true) || meanValue.isNull(true)) {
this->invalid = true;
invalid = true;
return;
}
bool failed = false;
// v1 is the partial result's sum component
double v1 = sumValue.toDouble(this->trx, failed);
double v1 = sumValue.toDouble(trx, failed);
if (failed) {
this->invalid = true;
invalid = true;
return;
}
// v2 is the partial result's mean component
double v2 = meanValue.toDouble(this->trx, failed);
double v2 = meanValue.toDouble(trx, failed);
if (failed) {
this->invalid = true;
invalid = true;
return;
}
this->count += countValue.toInt64(this->trx);
this->sum += v1;
this->mean += v2;
int64_t c = countValue.toInt64(trx);
// a count of 0 is rejected because v1 is divided by c below
if (c == 0) {
invalid = true;
return;
}
count += c;
// `sum` accumulates mean * count (not v1), so that `sum / count` in
// stealValue() is the count-weighted average of the partial means
sum += v2 * c;
// NOTE(review): `mean` receives the same quantity as `sum`; stealValue()
// in this struct does not read the member `mean`
mean += v2 * c;
// remember (v1 / c, mean, count) for this partial result; stealValue() reads
// the first tuple element as the partial result's variance
values.emplace_back(std::make_tuple(v1 / c, v2, c));
}
// Combines all collected partial results into the final variance.
// Returns null if the aggregator is invalid, if the total count is 0, or if
// the total count is 1 and `population` is false. Otherwise sums
// count_i * (variance_i + (mean_i - avg)^2) over all partial results,
// divides by `count` when `population` is true and by `count - 1` otherwise,
// resets the aggregator and returns the result as a double.
AqlValue stealValue() override {
if (invalid || count == 0 || (count == 1 && !population)) {
return AqlValue(AqlValueHintNull());
}
double const avg = sum / count;
double v = 0.0;
for (auto const& it : values) {
// these locals hide the like-named members `mean` and `count` (presumably
// inherited from the base class) inside the loop body only
double variance = std::get<0>(it);
double mean = std::get<1>(it);
int64_t count = std::get<2>(it);
v += count * (variance + std::pow(mean - avg, 2));
}
// from here on `count` is the member again, i.e. the total count
if (!population) {
TRI_ASSERT(count > 1);
v /= count - 1;
}
else {
v /= count;
}
reset();
return AqlValue(AqlValueHintDouble(v));
}
// one (variance, mean, count) tuple per partial result accepted by reduce()
std::vector<std::tuple<double, double, int64_t>> values;
};
/// @brief the single server variant of STDDEV
@ -530,19 +562,51 @@ struct AggregatorStddev : public AggregatorVarianceBase {
}
TRI_ASSERT(count > 0);
builder.clear();
double v;
if (!population) {
TRI_ASSERT(count > 1);
builder.add(VPackValue(sqrt(sum / (count - 1))));
v = std::sqrt(sum / (count - 1));
}
else {
builder.add(VPackValue(sqrt(sum / count)));
v = std::sqrt(sum / count);
}
AqlValue temp(builder.slice());
reset();
return temp;
return AqlValue(AqlValueHintDouble(v));
}
};
/// @brief the coordinator variant of STDDEV
/// Reuses the partial results collected by AggregatorVarianceBaseStep2 and
/// returns the square root of the combined variance.
struct AggregatorStddevBaseStep2 final : public AggregatorVarianceBaseStep2 {
  AggregatorStddevBaseStep2(transaction::Methods* trx, bool population)
      : AggregatorVarianceBaseStep2(trx, population) {}

  /// @brief produce the final standard deviation and reset the aggregator
  /// Returns null if the aggregator is invalid, if the total count is 0, or
  /// if the total count is 1 and `population` is false.
  AqlValue stealValue() override {
    bool const tooFewValues = (count == 0) || (count == 1 && !population);
    if (invalid || tooFewValues) {
      return AqlValue(AqlValueHintNull());
    }

    // count-weighted average over the means of all partial results
    double const overallMean = sum / count;

    // sum of count_i * (variance_i + (mean_i - overallMean)^2)
    double accumulated = 0.0;
    for (auto const& part : values) {
      double const partVariance = std::get<0>(part);
      double const partMean = std::get<1>(part);
      int64_t const partCount = std::get<2>(part);
      accumulated +=
          partCount * (partVariance + std::pow(partMean - overallMean, 2));
    }

    // population variant divides by n, sample variant by n - 1
    if (population) {
      accumulated /= count;
    }
    else {
      TRI_ASSERT(count > 1);
      accumulated /= count - 1;
    }

    double const result = std::sqrt(accumulated);
    reset();
    return AqlValue(AqlValueHintDouble(result));
  }
};
@ -822,7 +886,7 @@ std::unordered_map<std::string, AggregatorInfo> const aggregators = {
"VARIANCE_POPULATION_STEP1"
} },
{ "VARIANCE_POPULATION_STEP2", {
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep2<AggregatorVariance>>(trx, true); },
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep2>(trx, true); },
doesRequireInput, internalOnly,
"",
"VARIANCE_POPULATION_STEP2"
@ -840,7 +904,7 @@ std::unordered_map<std::string, AggregatorInfo> const aggregators = {
"VARIANCE_SAMPLE_STEP1"
} },
{ "VARIANCE_SAMPLE_STEP2", {
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep2<AggregatorVariance>>(trx, false); },
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep2>(trx, false); },
doesRequireInput, internalOnly,
"",
"VARIANCE_SAMPLE_STEP2"
@ -858,7 +922,7 @@ std::unordered_map<std::string, AggregatorInfo> const aggregators = {
"STDDEV_POPULATION_STEP1"
} },
{ "STDDEV_POPULATION_STEP2", {
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep2<AggregatorStddev>>(trx, true); },
[](transaction::Methods* trx) { return std::make_unique<AggregatorStddevBaseStep2>(trx, true); },
doesRequireInput, internalOnly,
"",
"STDDEV_POPULATION_STEP2"
@ -867,19 +931,19 @@ std::unordered_map<std::string, AggregatorInfo> const aggregators = {
[](transaction::Methods* trx) { return std::make_unique<AggregatorStddev>(trx, false); },
doesRequireInput, official,
"STDDEV_SAMPLE_STEP1",
"STDDEV_SAMPLE_2"
"STDDEV_SAMPLE_STEP2"
} },
{ "STDDEV_SAMPLE_STEP1", {
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep1>(trx, false); },
doesRequireInput, internalOnly,
"",
"STDDEV_SAMPLE_1"
"STDDEV_SAMPLE_STEP1"
} },
{ "STDDEV_SAMPLE_STEP2", {
[](transaction::Methods* trx) { return std::make_unique<AggregatorVarianceBaseStep2<AggregatorStddev>>(trx, false); },
[](transaction::Methods* trx) { return std::make_unique<AggregatorStddevBaseStep2>(trx, false); },
doesRequireInput, internalOnly,
"",
"STDDEV_SAMPLE_2"
"STDDEV_SAMPLE_STEP2"
} },
{ "UNIQUE", {
[](transaction::Methods* trx) { return std::make_unique<AggregatorUnique>(trx); },

View File

@ -311,7 +311,7 @@ static bool ValidateAggregates(Parser* parser, AstNode const* aggregates) {
}
else {
auto f = static_cast<arangodb::aql::Function*>(func->getData());
if (!Aggregator::isSupported(f->name)) {
if (!Aggregator::isValid(f->name)) {
// aggregate expression must be a call to MIN|MAX|LENGTH...
isValid = false;
}