mirror of https://gitee.com/bigwinds/arangodb
Working on PageRank and SCC
This commit is contained in:
parent
1b65ff07b8
commit
afaab2e8d5
|
@ -369,6 +369,7 @@ SET(ARANGOD_SOURCES
|
|||
Pregel/Algos/RecoveringPageRank.cpp
|
||||
Pregel/Algos/LineRank.cpp
|
||||
Pregel/Algos/ConnectedComponents.cpp
|
||||
Pregel/Algos/SCC.cpp
|
||||
Pregel/Conductor.cpp
|
||||
Pregel/GraphStore.cpp
|
||||
Pregel/IncomingCache.cpp
|
||||
|
|
|
@ -113,6 +113,7 @@ struct ValueAggregator : public NumberAggregator<T> {
|
|||
void aggregate(void const* valuePtr) override { this->_value = *((T*)valuePtr); };
|
||||
void parse(VPackSlice slice) override {this-> _value = slice.getNumber<T>(); }
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "Pregel/Algos/ShortestPath.h"
|
||||
#include "Pregel/Algos/LineRank.h"
|
||||
#include "Pregel/Algos/ConnectedComponents.h"
|
||||
#include "Pregel/Algos/SCC.h"
|
||||
#include "Pregel/Utils.h"
|
||||
|
||||
using namespace arangodb;
|
||||
|
@ -46,6 +47,8 @@ IAlgorithm* AlgoRegistry::createAlgorithm(std::string const& algorithm,
|
|||
return new algos::LineRank(userParams);
|
||||
} else if (algorithm == "connectedcomponents") {
|
||||
return new algos::ConnectedComponents(userParams);
|
||||
} else if (algorithm == "scc") {
|
||||
return new algos::SCC(userParams);
|
||||
} else {
|
||||
THROW_ARANGO_EXCEPTION_MESSAGE(TRI_ERROR_BAD_PARAMETER,
|
||||
"Unsupported Algorithm");
|
||||
|
@ -85,6 +88,8 @@ IWorker* AlgoRegistry::createWorker(TRI_vocbase_t* vocbase,
|
|||
return createWorker(vocbase, new algos::LineRank(userParams), body);
|
||||
} else if (algorithm == "connectedcomponents") {
|
||||
return createWorker(vocbase, new algos::ConnectedComponents(userParams), body);
|
||||
} else if (algorithm == "scc") {
|
||||
return createWorker(vocbase, new algos::SCC(userParams), body);
|
||||
} else {
|
||||
THROW_ARANGO_EXCEPTION_MESSAGE(TRI_ERROR_BAD_PARAMETER,
|
||||
"Unsupported Algorithm");
|
||||
|
|
|
@ -84,7 +84,9 @@ struct Algorithm : IAlgorithm {
|
|||
}
|
||||
virtual GraphFormat* inputFormat() const = 0;
|
||||
virtual MessageFormat<M>* messageFormat() const = 0;
|
||||
virtual MessageCombiner<M>* messageCombiner() const = 0;
|
||||
virtual MessageCombiner<M>* messageCombiner() const {
|
||||
return nullptr;
|
||||
};
|
||||
virtual VertexComputation<V, E, M>* createComputation(
|
||||
WorkerConfig const*) const = 0;
|
||||
virtual VertexCompensation<V, E, M>* createCompensation(
|
||||
|
|
|
@ -38,7 +38,7 @@ struct ConnectedComponents : public SimpleAlgorithm<int64_t, int64_t, int64_t> {
|
|||
public:
|
||||
ConnectedComponents(VPackSlice userParams) : SimpleAlgorithm("ConnectedComponents", userParams) {}
|
||||
|
||||
bool supportsAsyncMode() const override { return false; }
|
||||
bool supportsAsyncMode() const override { return true; }
|
||||
bool supportsCompensation() const override { return true; }
|
||||
|
||||
GraphFormat* inputFormat() const override;
|
||||
|
|
|
@ -44,7 +44,7 @@ struct LineRank : public SimpleAlgorithm<float, float, float> {
|
|||
return new VertexGraphFormat<float, float>(_resultField, -1.0);
|
||||
}
|
||||
MessageFormat<float>* messageFormat() const override {
|
||||
return new FloatMessageFormat();
|
||||
return new NumberMessageFormat<float>();
|
||||
}
|
||||
|
||||
MessageCombiner<float>* messageCombiner() const override {
|
||||
|
|
|
@ -39,63 +39,56 @@ using namespace arangodb;
|
|||
using namespace arangodb::pregel;
|
||||
using namespace arangodb::pregel::algos;
|
||||
|
||||
static float EPS = 0.00001;
|
||||
static std::string const kConvergence = "convergence";
|
||||
|
||||
static double EPS = 0.00001;
|
||||
PageRank::PageRank(arangodb::velocypack::Slice params)
|
||||
: SimpleAlgorithm("PageRank", params) {
|
||||
VPackSlice t = params.get("convergenceThreshold");
|
||||
_threshold = t.isNumber() ? t.getNumber<double>() : EPS;
|
||||
}
|
||||
|
||||
struct PRComputation : public VertexComputation<double, float, double> {
|
||||
struct PRComputation : public VertexComputation<float, float, float> {
|
||||
|
||||
PRComputation() {}
|
||||
void compute(MessageIterator<double> const& messages) override {
|
||||
double* ptr = mutableVertexData();
|
||||
double copy = *ptr;
|
||||
void compute(MessageIterator<float> const& messages) override {
|
||||
float* ptr = mutableVertexData();
|
||||
float copy = *ptr;
|
||||
|
||||
if (globalSuperstep() == 0) {
|
||||
*ptr = 1.0 / context()->vertexCount();
|
||||
} else {
|
||||
double sum = 0.0;
|
||||
for (const double* msg : messages) {
|
||||
float sum = 0.0;
|
||||
for (const float* msg : messages) {
|
||||
sum += *msg;
|
||||
}
|
||||
*ptr = 0.85 * sum + 0.15 / context()->vertexCount();
|
||||
}
|
||||
double diff = fabs(copy - *ptr);
|
||||
float diff = fabs(copy - *ptr);
|
||||
aggregate(kConvergence, diff);
|
||||
|
||||
if (globalSuperstep() < 50) {
|
||||
RangeIterator<Edge<float>> edges = getEdges();
|
||||
double val = *ptr / edges.size();
|
||||
for (Edge<float>* edge : edges) {
|
||||
sendMessage(edge, val);
|
||||
}
|
||||
} else {
|
||||
voteHalt();
|
||||
RangeIterator<Edge<float>> edges = getEdges();
|
||||
float val = *ptr / edges.size();
|
||||
for (Edge<float>* edge : edges) {
|
||||
sendMessage(edge, val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
VertexComputation<double, float, double>* PageRank::createComputation(
|
||||
VertexComputation<float, float, float>* PageRank::createComputation(
|
||||
WorkerConfig const* config) const {
|
||||
return new PRComputation();
|
||||
}
|
||||
|
||||
IAggregator* PageRank::aggregator(std::string const& name) const {
|
||||
if (name == kConvergence) {
|
||||
return new MaxAggregator<double>(-1.0f, false);
|
||||
return new MaxAggregator<float>(MAXFLOAT, false);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
struct MyMasterContext : MasterContext {
|
||||
MyMasterContext() {}// TODO use _threashold
|
||||
bool postGlobalSuperstep(uint64_t gss) {
|
||||
double const* diff = getAggregatedValue<double>(kConvergence);
|
||||
return globalSuperstep() < 2 || *diff > EPS;
|
||||
struct MyMasterContext : public MasterContext {
|
||||
MyMasterContext() {
|
||||
//VPackSlice t = params.get("convergenceThreshold");
|
||||
//_threshold = t.isNumber() ? t.getNumber<float>() : EPS;
|
||||
}// TODO use _threashold
|
||||
bool postGlobalSuperstep() override {
|
||||
float const* diff = getAggregatedValue<float>(kConvergence);
|
||||
return globalSuperstep() < 50 && *diff > EPS;
|
||||
};
|
||||
};
|
||||
|
||||
|
|
|
@ -31,25 +31,24 @@ namespace pregel {
|
|||
namespace algos {
|
||||
|
||||
/// PageRank
|
||||
struct PageRank : public SimpleAlgorithm<double, float, double> {
|
||||
float _threshold;
|
||||
struct PageRank : public SimpleAlgorithm<float, float, float> {
|
||||
|
||||
public:
|
||||
PageRank(arangodb::velocypack::Slice params);
|
||||
PageRank(arangodb::velocypack::Slice params)
|
||||
: SimpleAlgorithm("PageRank", params) {}
|
||||
|
||||
GraphFormat* inputFormat() const override {
|
||||
return new VertexGraphFormat<double, float>(_resultField, 0);
|
||||
return new VertexGraphFormat<float, float>(_resultField, 0);
|
||||
}
|
||||
|
||||
MessageFormat<double>* messageFormat() const override {
|
||||
return new DoubleMessageFormat();
|
||||
MessageFormat<float>* messageFormat() const override {
|
||||
return new NumberMessageFormat<float>();
|
||||
}
|
||||
|
||||
MessageCombiner<double>* messageCombiner() const override {
|
||||
return new SumCombiner<double>();
|
||||
MessageCombiner<float>* messageCombiner() const override {
|
||||
return new SumCombiner<float>();
|
||||
}
|
||||
|
||||
VertexComputation<double, float, double>* createComputation(
|
||||
VertexComputation<float, float, float>* createComputation(
|
||||
WorkerConfig const*) const override;
|
||||
IAggregator* aggregator(std::string const& name) const override;
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ using namespace arangodb;
|
|||
using namespace arangodb::pregel;
|
||||
using namespace arangodb::pregel::algos;
|
||||
|
||||
static float EPS = 0.00001;
|
||||
static std::string const kConvergence = "convergence";
|
||||
static std::string const kStep = "step";
|
||||
static std::string const kRank = "rank";
|
||||
|
@ -47,15 +48,9 @@ static std::string const kNonFailedCount = "nonfailedCount";
|
|||
static std::string const kScale = "scale";
|
||||
|
||||
|
||||
RecoveringPageRank::RecoveringPageRank(arangodb::velocypack::Slice params)
|
||||
: SimpleAlgorithm("PageRank", params) {
|
||||
VPackSlice t = params.get("convergenceThreshold");
|
||||
_threshold = t.isNumber() ? t.getNumber<float>() : 0.000002f;
|
||||
}
|
||||
|
||||
struct RPRComputation : public VertexComputation<float, float, float> {
|
||||
float _limit;
|
||||
RPRComputation(float t) : _limit(t) {}
|
||||
struct MyComputation : public VertexComputation<float, float, float> {
|
||||
|
||||
MyComputation() {}
|
||||
void compute(MessageIterator<float> const& messages) override {
|
||||
float* ptr = mutableVertexData();
|
||||
float copy = *ptr;
|
||||
|
@ -72,26 +67,18 @@ struct RPRComputation : public VertexComputation<float, float, float> {
|
|||
float diff = fabsf(copy - *ptr);
|
||||
aggregate(kConvergence, diff);
|
||||
aggregate(kRank, ptr);
|
||||
// const float* val = getAggregatedValue<float>("convergence");
|
||||
// if (val) { // if global convergence is available use it
|
||||
// diff = *val;
|
||||
//}
|
||||
|
||||
if (globalSuperstep() < 50 && diff > _limit) {
|
||||
RangeIterator<Edge<float>> edges = getEdges();
|
||||
float val = *ptr / edges.size();
|
||||
for (Edge<float>* edge : edges) {
|
||||
sendMessage(edge, val);
|
||||
}
|
||||
} else {
|
||||
voteHalt();
|
||||
RangeIterator<Edge<float>> edges = getEdges();
|
||||
float val = *ptr / edges.size();
|
||||
for (Edge<float>* edge : edges) {
|
||||
sendMessage(edge, val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
VertexComputation<float, float, float>* RecoveringPageRank::createComputation(
|
||||
WorkerConfig const* config) const {
|
||||
return new RPRComputation(_threshold);
|
||||
return new MyComputation();
|
||||
}
|
||||
|
||||
IAggregator* RecoveringPageRank::aggregator(std::string const& name) const {
|
||||
|
@ -140,24 +127,31 @@ VertexCompensation<float, float, float>* RecoveringPageRank::createCompensation(
|
|||
}
|
||||
|
||||
struct MyMasterContext : public MasterContext {
|
||||
MyMasterContext(VPackSlice params) {};
|
||||
float _threshold;
|
||||
|
||||
MyMasterContext(VPackSlice params) {
|
||||
VPackSlice t = params.get("convergenceThreshold");
|
||||
_threshold = t.isNumber() ? t.getNumber<float>() : EPS;
|
||||
};
|
||||
|
||||
int32_t recoveryStep = 0;
|
||||
float totalRank = 0;
|
||||
|
||||
bool postGlobalSuperstep(uint64_t gss) override {
|
||||
bool postGlobalSuperstep() override {
|
||||
const float* convergence = getAggregatedValue<float>(kConvergence);
|
||||
LOG(INFO) << "Current convergence level" << *convergence;
|
||||
totalRank = *getAggregatedValue<float>(kRank);
|
||||
return true;
|
||||
|
||||
float const* diff = getAggregatedValue<float>(kConvergence);
|
||||
return globalSuperstep() < 50 && *diff > _threshold;
|
||||
}
|
||||
|
||||
bool preCompensation(uint64_t gss) override {
|
||||
bool preCompensation() override {
|
||||
aggregate(kStep, recoveryStep);
|
||||
return totalRank != 0;
|
||||
}
|
||||
|
||||
bool postCompensation(uint64_t gss) override {
|
||||
bool postCompensation() override {
|
||||
if (recoveryStep == 0) {
|
||||
recoveryStep = 1;
|
||||
|
||||
|
|
|
@ -32,10 +32,9 @@ namespace algos {
|
|||
|
||||
/// PageRank
|
||||
struct RecoveringPageRank : public SimpleAlgorithm<float, float, float> {
|
||||
float _threshold;
|
||||
|
||||
public:
|
||||
RecoveringPageRank(arangodb::velocypack::Slice params);
|
||||
|
||||
RecoveringPageRank(arangodb::velocypack::Slice params)
|
||||
: SimpleAlgorithm("PageRank", params) {}
|
||||
|
||||
bool supportsCompensation() const override { return true; }
|
||||
MasterContext* masterContext(VPackSlice userParams) const override;
|
||||
|
@ -45,7 +44,7 @@ struct RecoveringPageRank : public SimpleAlgorithm<float, float, float> {
|
|||
}
|
||||
|
||||
MessageFormat<float>* messageFormat() const override {
|
||||
return new FloatMessageFormat();
|
||||
return new NumberMessageFormat<float>();
|
||||
}
|
||||
|
||||
MessageCombiner<float>* messageCombiner() const override {
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// DISCLAIMER
|
||||
///
|
||||
/// Copyright 2016 ArangoDB GmbH, Cologne, Germany
|
||||
///
|
||||
/// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
/// you may not use this file except in compliance with the License.
|
||||
/// You may obtain a copy of the License at
|
||||
///
|
||||
/// http://www.apache.org/licenses/LICENSE-2.0
|
||||
///
|
||||
/// Unless required by applicable law or agreed to in writing, software
|
||||
/// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
/// See the License for the specific language governing permissions and
|
||||
/// limitations under the License.
|
||||
///
|
||||
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
|
||||
///
|
||||
/// @author Simon Grätzer
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#include "SCC.h"
|
||||
#include "Pregel/Aggregator.h"
|
||||
#include "Pregel/Algorithm.h"
|
||||
#include "Pregel/GraphStore.h"
|
||||
#include "Pregel/IncomingCache.h"
|
||||
#include "Pregel/VertexComputation.h"
|
||||
#include "Cluster/ClusterInfo.h"
|
||||
#include "Cluster/ServerState.h"
|
||||
#include "Pregel/MasterContext.h"
|
||||
|
||||
using namespace arangodb::pregel;
|
||||
using namespace arangodb::pregel::algos;
|
||||
|
||||
static std::string const kPhase = "phase";
|
||||
static std::string const kFoundNewMax = "max";
|
||||
static std::string const kConverged = "converged";
|
||||
|
||||
enum SCCPhase {
|
||||
TRANSPOSE = 0,
|
||||
TRIMMING = 1,
|
||||
FORWARD_TRAVERSAL = 2,
|
||||
BACKWARD_TRAVERSAL_START = 3,
|
||||
BACKWARD_TRAVERSAL_REST = 4
|
||||
};
|
||||
|
||||
struct MyComputation : public VertexComputation<SCCValue, int32_t, SenderMessage<uint64_t>> {
|
||||
MyComputation() {}
|
||||
|
||||
void compute(MessageIterator<SenderMessage<uint64_t>> const& messages) override {
|
||||
if (isActive() == false) {
|
||||
// color was already determinded or vertex was trimmed
|
||||
return;
|
||||
}
|
||||
|
||||
SCCValue *vertexState = mutableVertexData();
|
||||
|
||||
int64_t const* phase = getAggregatedValue<int64_t>(kPhase);
|
||||
switch (*phase) {
|
||||
|
||||
// let all our connected nodes know we are there
|
||||
case SCCPhase::TRANSPOSE: {
|
||||
|
||||
vertexState->parents.clear();
|
||||
SenderMessage<uint64_t> message(pregelId(), 0);
|
||||
sendMessageToAllEdges(message);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
// Creates list of parents based on the received ids and halts the vertices
|
||||
// that don't have any parent or outgoing edge, hence, they can't be
|
||||
// part of an SCC.
|
||||
case SCCPhase::TRIMMING:{
|
||||
|
||||
for (SenderMessage<uint64_t> const* msg : messages) {
|
||||
vertexState->parents.push_back(msg->pregelId);
|
||||
}
|
||||
// reset the color for vertices which are not active
|
||||
vertexState->color = vertexState->vertexID;
|
||||
// If this node doesn't have any parents or outgoing edges,
|
||||
// it can't be part of an SCC
|
||||
RangeIterator<Edge<int32_t>> edges = getEdges();
|
||||
if (vertexState->parents.size() == 0 || edges.size() == 0) {
|
||||
voteHalt();
|
||||
} else {
|
||||
SenderMessage<uint64_t> message(pregelId(), vertexState->color);
|
||||
sendMessageToAllEdges(message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SCCPhase::FORWARD_TRAVERSAL:{
|
||||
uint64_t old = vertexState->color;
|
||||
for (SenderMessage<uint64_t> const* msg : messages) {
|
||||
if (vertexState->color < msg->value) {
|
||||
vertexState->color = msg->value;
|
||||
}
|
||||
}
|
||||
if (old != vertexState->color) {
|
||||
SenderMessage<uint64_t> message(pregelId(), vertexState->color);
|
||||
sendMessageToAllEdges(message);
|
||||
aggregate(kFoundNewMax, true);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SCCPhase::BACKWARD_TRAVERSAL_START:{
|
||||
|
||||
// if I am the 'root' of a SCC start traversak
|
||||
if (vertexState->vertexID == vertexState->color) {
|
||||
SenderMessage<uint64_t> message(pregelId(), vertexState->color);
|
||||
// sendMessageToAllParents
|
||||
for (PregelID const& pid : vertexState->parents) {
|
||||
sendMessage(pid, message);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case SCCPhase::BACKWARD_TRAVERSAL_REST:{
|
||||
|
||||
for (SenderMessage<uint64_t> const* msg : messages) {
|
||||
if (vertexState->color == msg->value) {
|
||||
for (PregelID const& pid : vertexState->parents) {
|
||||
sendMessage(pid, *msg);
|
||||
}
|
||||
aggregate(kConverged, true);
|
||||
voteHalt();
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
VertexComputation<SCCValue, int32_t,
|
||||
SenderMessage<uint64_t>>* SCC::createComputation(
|
||||
WorkerConfig const* config) const {
|
||||
return new MyComputation();
|
||||
}
|
||||
|
||||
|
||||
struct MyGraphFormat : public GraphFormat {
|
||||
const std::string _resultField;
|
||||
uint64_t vertexIdRange = 0;
|
||||
|
||||
MyGraphFormat(std::string const& result) : _resultField(result) {}
|
||||
|
||||
void willLoadVertices(uint64_t count) override {
|
||||
// if we aren't running in a cluster it doesn't matter
|
||||
if (arangodb::ServerState::instance()->isRunningInCluster()) {
|
||||
arangodb::ClusterInfo *ci = arangodb::ClusterInfo::instance();
|
||||
if (ci) {
|
||||
vertexIdRange = ci->uniqid(count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t estimatedVertexSize() const override { return sizeof(SCCValue); };
|
||||
size_t estimatedEdgeSize() const override { return 0; };
|
||||
|
||||
size_t copyVertexData(VertexEntry const& vertex,
|
||||
std::string const& documentId,
|
||||
arangodb::velocypack::Slice document,
|
||||
void* targetPtr, size_t maxSize) override {
|
||||
SCCValue *senders = (SCCValue*) targetPtr;
|
||||
senders->vertexID = vertexIdRange++;
|
||||
return sizeof(SCCValue);
|
||||
}
|
||||
|
||||
size_t copyEdgeData(arangodb::velocypack::Slice document, void* targetPtr,
|
||||
size_t maxSize) override {
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool buildVertexDocument(arangodb::velocypack::Builder& b,
|
||||
const void* ptr,
|
||||
size_t size) override {
|
||||
SCCValue *senders = (SCCValue*) ptr;
|
||||
b.add(_resultField, VPackValue(senders->color));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool buildEdgeDocument(arangodb::velocypack::Builder& b, const void* ptr,
|
||||
size_t size) override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
GraphFormat* SCC::inputFormat() const {
|
||||
return new MyGraphFormat(_resultField);
|
||||
}
|
||||
|
||||
struct MyMasterContext : public MasterContext {
|
||||
MyMasterContext() {}// TODO use _threashold
|
||||
void preGlobalSuperstep() override {
|
||||
if (globalSuperstep() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t const* phase = getAggregatedValue<int64_t>(kPhase);
|
||||
switch (*phase) {
|
||||
case SCCPhase::TRANSPOSE:
|
||||
aggregate(kPhase, SCCPhase::TRIMMING);
|
||||
break;
|
||||
|
||||
case SCCPhase::TRIMMING:
|
||||
aggregate(kPhase, SCCPhase::FORWARD_TRAVERSAL);
|
||||
break;
|
||||
|
||||
case SCCPhase::FORWARD_TRAVERSAL: {
|
||||
bool const* newMaxFound = getAggregatedValue<bool>(kFoundNewMax);
|
||||
if (*newMaxFound == false) {
|
||||
aggregate(kPhase, SCCPhase::BACKWARD_TRAVERSAL_START);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case SCCPhase::BACKWARD_TRAVERSAL_START:
|
||||
aggregate(kPhase, SCCPhase::BACKWARD_TRAVERSAL_REST);
|
||||
break;
|
||||
|
||||
case SCCPhase::BACKWARD_TRAVERSAL_REST:
|
||||
bool const* converged = getAggregatedValue<bool>(kConverged);
|
||||
// continue until no more vertices are updated
|
||||
if (*converged == false) {
|
||||
aggregate(kPhase, SCCPhase::TRANSPOSE);
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
MasterContext* SCC::masterContext(VPackSlice userParams) const {
|
||||
return new MyMasterContext();
|
||||
}
|
||||
|
||||
IAggregator* SCC::aggregator(std::string const& name) const {
|
||||
if (name == kPhase) {
|
||||
return new ValueAggregator<int64_t>(SCCPhase::TRANSPOSE, true);
|
||||
} else if (name == kFoundNewMax) {
|
||||
return new ValueAggregator<bool>(false, false);
|
||||
} else if (name == kConverged) {
|
||||
return new ValueAggregator<bool>(false, true);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
/// DISCLAIMER
|
||||
///
|
||||
/// Copyright 2016 ArangoDB GmbH, Cologne, Germany
|
||||
///
|
||||
/// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
/// you may not use this file except in compliance with the License.
|
||||
/// You may obtain a copy of the License at
|
||||
///
|
||||
/// http://www.apache.org/licenses/LICENSE-2.0
|
||||
///
|
||||
/// Unless required by applicable law or agreed to in writing, software
|
||||
/// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
/// See the License for the specific language governing permissions and
|
||||
/// limitations under the License.
|
||||
///
|
||||
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
|
||||
///
|
||||
/// @author Simon Grätzer
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef ARANGODB_PREGEL_ALGOS_SCC_H
|
||||
#define ARANGODB_PREGEL_ALGOS_SCC_H 1
|
||||
|
||||
#include "Pregel/Algorithm.h"
|
||||
#include "Pregel/CommonFormats.h"
|
||||
|
||||
namespace arangodb {
|
||||
namespace pregel {
|
||||
namespace algos {
|
||||
|
||||
/// Finds strongly connected components of the graph.
|
||||
///
|
||||
/// 1. Each vertex starts with its vertex id as its "color".
|
||||
/// 2. Remove vertices which cannot be in a SCC (no incoming or no outgoing edges)
|
||||
/// 3. Propagate the color forward from each vertex, accept a neighbor's color if it's smaller than yours.
|
||||
/// At convergence, vertices with the same color represents all nodes that are visitable from the root of that color.
|
||||
/// 4. Reverse the graph.
|
||||
/// 5. Start at all roots, walk the graph. Visit a neighbor if it has the same color as you.
|
||||
/// All nodes visited belongs to the SCC identified by the root color.
|
||||
|
||||
struct SCC : public SimpleAlgorithm<SCCValue, int32_t, SenderMessage<uint64_t>> {
|
||||
public:
|
||||
SCC(VPackSlice userParams)
|
||||
: SimpleAlgorithm<SCCValue, int32_t, SenderMessage<uint64_t>>("SCC", userParams) {}
|
||||
|
||||
GraphFormat* inputFormat() const override;
|
||||
MessageFormat<SenderMessage<uint64_t>>* messageFormat() const override {
|
||||
return new SenderMessageFormat<uint64_t>();
|
||||
}
|
||||
|
||||
VertexComputation<SCCValue, int32_t, SenderMessage<uint64_t>>*
|
||||
createComputation( WorkerConfig const*) const override;
|
||||
|
||||
MasterContext* masterContext(VPackSlice userParams) const override;
|
||||
|
||||
IAggregator* aggregator(std::string const& name) const override;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -20,34 +20,48 @@
|
|||
/// @author Simon Grätzer
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#ifndef ARANGODB_PREGEL_ADDITIONAL_MFORMATS_H
|
||||
#define ARANGODB_PREGEL_ADDITIONAL_MFORMATS_H 1
|
||||
|
||||
// NOTE: this files exists primarily to include these algorithm specfic structs in the
|
||||
// cpp files to do template specialization
|
||||
|
||||
#ifndef ARANGODB_PREGEL_COMMON_MFORMATS_H
|
||||
#define ARANGODB_PREGEL_COMMON_MFORMATS_H 1
|
||||
|
||||
#include "Pregel/GraphFormat.h"
|
||||
#include "Pregel/MessageFormat.h"
|
||||
#include "Pregel/Graph.h"
|
||||
|
||||
|
||||
namespace arangodb {
|
||||
namespace pregel {
|
||||
|
||||
struct SCCValue {
|
||||
std::vector<PregelID> parents;
|
||||
uint64_t vertexID;
|
||||
uint64_t color;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct SenderValue {
|
||||
struct SenderMessage {
|
||||
SenderMessage() {}
|
||||
SenderMessage(PregelID const& pid, T const& val)
|
||||
: pregelId(pid), value(val) {}
|
||||
|
||||
PregelID pregelId;
|
||||
T value;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NumberSenderFormat : public MessageFormat<SenderValue<T>> {
|
||||
struct SenderMessageFormat : public MessageFormat<SenderMessage<T>> {
|
||||
static_assert(std::is_arithmetic<T>::value, "Message type must be numeric");
|
||||
NumberSenderFormat() {}
|
||||
void unwrapValue(VPackSlice s, SenderValue<T>& senderVal) const override {
|
||||
SenderMessageFormat() {}
|
||||
void unwrapValue(VPackSlice s, SenderMessage<T>& senderVal) const override {
|
||||
VPackArrayIterator array(s);
|
||||
senderVal.pregelId.shard = (*array).getUInt();
|
||||
senderVal.pregelId.key = (*(++array)).copyString();
|
||||
senderVal.value = (*(++array)).getNumber<T>();
|
||||
}
|
||||
void addValue(VPackBuilder& arrayBuilder, SenderValue<T> const& senderVal) const override {
|
||||
void addValue(VPackBuilder& arrayBuilder, SenderMessage<T> const& senderVal) const override {
|
||||
arrayBuilder.openArray();
|
||||
arrayBuilder.add(VPackValue(senderVal.pregelId.shard));
|
||||
arrayBuilder.add(VPackValue(senderVal.pregelId.key));
|
|
@ -141,7 +141,7 @@ bool Conductor::_startGlobalStep() {
|
|||
bool proceed = true;
|
||||
if (_masterContext && _globalSuperstep > 0) { // ask algorithm to evaluate aggregated values
|
||||
_masterContext->_globalSuperstep = _globalSuperstep - 1;
|
||||
proceed = _masterContext->postGlobalSuperstep(_globalSuperstep);
|
||||
proceed = _masterContext->postGlobalSuperstep();
|
||||
if (!proceed) {
|
||||
LOG(INFO) << "Master context ended execution";
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ bool Conductor::_startGlobalStep() {
|
|||
_masterContext->_globalSuperstep = _globalSuperstep;
|
||||
_masterContext->_vertexCount = _totalVerticesCount;
|
||||
_masterContext->_edgeCount = _totalEdgesCount;
|
||||
_masterContext->preGlobalSuperstep(_globalSuperstep);
|
||||
_masterContext->preGlobalSuperstep();
|
||||
}
|
||||
|
||||
b.clear();
|
||||
|
@ -291,7 +291,7 @@ void Conductor::finishedRecoveryStep(VPackSlice data) {
|
|||
// only compensations supported
|
||||
bool proceed = false;
|
||||
if (_masterContext) {
|
||||
proceed = proceed || _masterContext->postCompensation(_globalSuperstep);
|
||||
proceed = proceed || _masterContext->postCompensation();
|
||||
}
|
||||
|
||||
int res = TRI_ERROR_NO_ERROR;
|
||||
|
@ -299,7 +299,7 @@ void Conductor::finishedRecoveryStep(VPackSlice data) {
|
|||
// reset values which are calculated during the superstep
|
||||
_aggregators->resetValues();
|
||||
if (_masterContext) {
|
||||
_masterContext->preCompensation(_globalSuperstep);
|
||||
_masterContext->preCompensation();
|
||||
}
|
||||
|
||||
VPackBuilder b;
|
||||
|
@ -401,7 +401,7 @@ void Conductor::startRecovery() {
|
|||
|
||||
// Let's try recovery
|
||||
if (_masterContext) {
|
||||
bool proceed = _masterContext->preCompensation(_globalSuperstep);
|
||||
bool proceed = _masterContext->preCompensation();
|
||||
if (!proceed) {
|
||||
cancel();
|
||||
}
|
||||
|
|
|
@ -28,7 +28,8 @@
|
|||
#include "Indexes/EdgeIndex.h"
|
||||
#include "Indexes/Index.h"
|
||||
#include "Pregel/WorkerConfig.h"
|
||||
#include "Utils.h"
|
||||
#include "Pregel/Utils.h"
|
||||
#include "Pregel/CommonFormats.h"
|
||||
#include "Utils/CollectionNameResolver.h"
|
||||
#include "Utils/ExplicitTransaction.h"
|
||||
#include "Utils/OperationCursor.h"
|
||||
|
@ -439,3 +440,8 @@ template class arangodb::pregel::GraphStore<float, float>;
|
|||
template class arangodb::pregel::GraphStore<double, float>;
|
||||
template class arangodb::pregel::GraphStore<double, double>;
|
||||
|
||||
// specific algo combos
|
||||
template class arangodb::pregel::GraphStore<SCCValue, int32_t>;
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
|
||||
#include "Pregel/IncomingCache.h"
|
||||
#include "Pregel/Utils.h"
|
||||
//#include "Pregel/AdditionalFormats.h"
|
||||
#include "Pregel/CommonFormats.h"
|
||||
|
||||
#include "Basics/MutexLocker.h"
|
||||
#include "Basics/StaticStrings.h"
|
||||
|
@ -31,9 +31,6 @@
|
|||
#include <velocypack/Iterator.h>
|
||||
#include <velocypack/velocypack-aliases.h>
|
||||
|
||||
//#include <libcuckoo/city_hasher.hh>
|
||||
//#include <libcuckoo/cuckoohash_map.hh>
|
||||
|
||||
using namespace arangodb;
|
||||
using namespace arangodb::pregel;
|
||||
|
||||
|
@ -240,15 +237,17 @@ void CombiningInCache<M>::forEach(
|
|||
}
|
||||
|
||||
// template types to create
|
||||
//template class arangodb::pregel::InCache<SenderValue<int64_t>>;
|
||||
template class arangodb::pregel::InCache<int64_t>;
|
||||
template class arangodb::pregel::InCache<float>;
|
||||
template class arangodb::pregel::InCache<double>;
|
||||
//template class arangodb::pregel::ArrayInCache<SenderValue<int64_t>>;
|
||||
template class arangodb::pregel::ArrayInCache<int64_t>;
|
||||
template class arangodb::pregel::ArrayInCache<float>;
|
||||
template class arangodb::pregel::ArrayInCache<double>;
|
||||
//template class arangodb::pregel::CombiningInCache<SenderValue<int64_t>>;
|
||||
template class arangodb::pregel::CombiningInCache<int64_t>;
|
||||
template class arangodb::pregel::CombiningInCache<float>;
|
||||
template class arangodb::pregel::CombiningInCache<double>;
|
||||
|
||||
// algo specific
|
||||
template class arangodb::pregel::InCache<SenderMessage<uint64_t>>;
|
||||
template class arangodb::pregel::ArrayInCache<SenderMessage<uint64_t>>;
|
||||
template class arangodb::pregel::CombiningInCache<SenderMessage<uint64_t>>;
|
||||
|
|
|
@ -53,7 +53,8 @@ class MessageIterator {
|
|||
it._current = it._size;
|
||||
return it;
|
||||
}
|
||||
const M* operator*() const { return _data + _current; }
|
||||
|
||||
M const* operator*() const { return _data + _current; }
|
||||
|
||||
// prefix ++
|
||||
MessageIterator& operator++() {
|
||||
|
|
|
@ -65,17 +65,17 @@ class MasterContext {
|
|||
|
||||
/// @brief called before supersteps
|
||||
/// @return true to continue the computation
|
||||
virtual void preGlobalSuperstep(uint64_t gss) {};
|
||||
virtual void preGlobalSuperstep() {};
|
||||
/// @brief called after supersteps
|
||||
/// @return true to continue the computation
|
||||
virtual bool postGlobalSuperstep(uint64_t gss) { return true; };
|
||||
virtual bool postGlobalSuperstep() { return true; };
|
||||
virtual void postApplication(){};
|
||||
|
||||
/// should indicate if compensation is supposed to start by returning true
|
||||
virtual bool preCompensation(uint64_t gss) { return true; }
|
||||
virtual bool preCompensation() { return true; }
|
||||
/// should indicate if compensation is finished, by returning false.
|
||||
/// otherwise workers will be called again with the aggregated values
|
||||
virtual bool postCompensation(uint64_t gss) { return false; }
|
||||
virtual bool postCompensation() { return false; }
|
||||
|
||||
};
|
||||
}
|
||||
|
|
|
@ -50,7 +50,8 @@ struct IntegerMessageFormat : public MessageFormat<int64_t> {
|
|||
arrayBuilder.add(VPackValue(val));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
struct DoubleMessageFormat : public MessageFormat<double> {
|
||||
DoubleMessageFormat() {}
|
||||
void unwrapValue(VPackSlice s, double& value) const override {
|
||||
|
@ -69,7 +70,7 @@ struct FloatMessageFormat : public MessageFormat<float> {
|
|||
void addValue(VPackBuilder& arrayBuilder, float const& val) const override {
|
||||
arrayBuilder.add(VPackValue(val));
|
||||
}
|
||||
};
|
||||
};*/
|
||||
|
||||
template <typename M>
|
||||
struct NumberMessageFormat : public MessageFormat<M> {
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include "Pregel/IncomingCache.h"
|
||||
#include "Pregel/Utils.h"
|
||||
#include "Pregel/WorkerConfig.h"
|
||||
//#include "Pregel/AdditionalFormats.h"
|
||||
#include "Pregel/CommonFormats.h"
|
||||
|
||||
#include "Basics/MutexLocker.h"
|
||||
#include "Basics/StaticStrings.h"
|
||||
|
@ -258,15 +258,18 @@ void CombiningOutCache<M>::flushMessages() {
|
|||
}
|
||||
|
||||
// template types to create
|
||||
//template class arangodb::pregel::OutCache<SenderValue<int64_t>>;
|
||||
|
||||
template class arangodb::pregel::OutCache<int64_t>;
|
||||
template class arangodb::pregel::OutCache<float>;
|
||||
template class arangodb::pregel::OutCache<double>;
|
||||
//template class arangodb::pregel::ArrayOutCache<SenderValue<int64_t>>;
|
||||
template class arangodb::pregel::ArrayOutCache<int64_t>;
|
||||
template class arangodb::pregel::ArrayOutCache<float>;
|
||||
template class arangodb::pregel::ArrayOutCache<double>;
|
||||
//template class arangodb::pregel::CombiningOutCache<SenderValue<int64_t>>;
|
||||
template class arangodb::pregel::CombiningOutCache<int64_t>;
|
||||
template class arangodb::pregel::CombiningOutCache<float>;
|
||||
template class arangodb::pregel::CombiningOutCache<double>;
|
||||
|
||||
// algo specific
|
||||
template class arangodb::pregel::OutCache<SenderMessage<uint64_t>>;
|
||||
template class arangodb::pregel::ArrayOutCache<SenderMessage<uint64_t>>;
|
||||
template class arangodb::pregel::CombiningOutCache<SenderMessage<uint64_t>>;
|
||||
|
|
|
@ -81,6 +81,7 @@ class VertexContext {
|
|||
|
||||
void voteHalt() { _vertexEntry->setActive(false); }
|
||||
void voteActive() { _vertexEntry->setActive(true); }
|
||||
bool isActive() { return _vertexEntry->active(); }
|
||||
|
||||
inline uint64_t globalSuperstep() const { return _gss; }
|
||||
inline uint64_t localSuperstep() const { return _lss; }
|
||||
|
@ -101,6 +102,10 @@ class VertexComputation : public VertexContext<V, E, M> {
|
|||
_cache->appendMessage(edge->targetShard(), edge->toKey(), data);
|
||||
}
|
||||
|
||||
void sendMessage(PregelID const& pid, M const& data) {
|
||||
_cache->appendMessage(pid.shard, pid.key, data);
|
||||
}
|
||||
|
||||
// TODO optimize outgoing cache somehow
|
||||
void sendMessageToAllEdges(M const& data) {
|
||||
RangeIterator<Edge<E>> edges = this->getEdges();
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
#include "Pregel/Utils.h"
|
||||
#include "Pregel/VertexComputation.h"
|
||||
#include "Pregel/WorkerConfig.h"
|
||||
//#include "Pregel/AdditionalFormats.h"
|
||||
#include "Pregel/CommonFormats.h"
|
||||
|
||||
#include "Basics/MutexLocker.h"
|
||||
#include "Basics/ReadLocker.h"
|
||||
|
@ -662,5 +662,5 @@ Worker<V, E, M>::_callConductorWithResponse(std::string const& path,
|
|||
template class arangodb::pregel::Worker<int64_t, int64_t, int64_t>;
|
||||
template class arangodb::pregel::Worker<float, float, float>;
|
||||
template class arangodb::pregel::Worker<double, float, double>;
|
||||
// complex types
|
||||
//template class arangodb::pregel::Worker<int64_t, int64_t, SenderValue<int64_t>>;
|
||||
// custom algorihm types
|
||||
template class arangodb::pregel::Worker<SCCValue, int32_t, SenderMessage<uint64_t>>;
|
||||
|
|
Loading…
Reference in New Issue