1
0
Fork 0

Working on PageRank and SCC

This commit is contained in:
Simon Grätzer 2017-01-24 16:37:21 +01:00
parent 1b65ff07b8
commit afaab2e8d5
22 changed files with 445 additions and 108 deletions

View File

@ -369,6 +369,7 @@ SET(ARANGOD_SOURCES
Pregel/Algos/RecoveringPageRank.cpp
Pregel/Algos/LineRank.cpp
Pregel/Algos/ConnectedComponents.cpp
Pregel/Algos/SCC.cpp
Pregel/Conductor.cpp
Pregel/GraphStore.cpp
Pregel/IncomingCache.cpp

View File

@ -113,6 +113,7 @@ struct ValueAggregator : public NumberAggregator<T> {
void aggregate(void const* valuePtr) override { this->_value = *((T*)valuePtr); };
void parse(VPackSlice slice) override {this-> _value = slice.getNumber<T>(); }
};
}
}
#endif

View File

@ -27,6 +27,7 @@
#include "Pregel/Algos/ShortestPath.h"
#include "Pregel/Algos/LineRank.h"
#include "Pregel/Algos/ConnectedComponents.h"
#include "Pregel/Algos/SCC.h"
#include "Pregel/Utils.h"
using namespace arangodb;
@ -46,6 +47,8 @@ IAlgorithm* AlgoRegistry::createAlgorithm(std::string const& algorithm,
return new algos::LineRank(userParams);
} else if (algorithm == "connectedcomponents") {
return new algos::ConnectedComponents(userParams);
} else if (algorithm == "scc") {
return new algos::SCC(userParams);
} else {
THROW_ARANGO_EXCEPTION_MESSAGE(TRI_ERROR_BAD_PARAMETER,
"Unsupported Algorithm");
@ -85,6 +88,8 @@ IWorker* AlgoRegistry::createWorker(TRI_vocbase_t* vocbase,
return createWorker(vocbase, new algos::LineRank(userParams), body);
} else if (algorithm == "connectedcomponents") {
return createWorker(vocbase, new algos::ConnectedComponents(userParams), body);
} else if (algorithm == "scc") {
return createWorker(vocbase, new algos::SCC(userParams), body);
} else {
THROW_ARANGO_EXCEPTION_MESSAGE(TRI_ERROR_BAD_PARAMETER,
"Unsupported Algorithm");

View File

@ -84,7 +84,9 @@ struct Algorithm : IAlgorithm {
}
virtual GraphFormat* inputFormat() const = 0;
virtual MessageFormat<M>* messageFormat() const = 0;
virtual MessageCombiner<M>* messageCombiner() const = 0;
virtual MessageCombiner<M>* messageCombiner() const {
return nullptr;
};
virtual VertexComputation<V, E, M>* createComputation(
WorkerConfig const*) const = 0;
virtual VertexCompensation<V, E, M>* createCompensation(

View File

@ -38,7 +38,7 @@ struct ConnectedComponents : public SimpleAlgorithm<int64_t, int64_t, int64_t> {
public:
ConnectedComponents(VPackSlice userParams) : SimpleAlgorithm("ConnectedComponents", userParams) {}
bool supportsAsyncMode() const override { return false; }
bool supportsAsyncMode() const override { return true; }
bool supportsCompensation() const override { return true; }
GraphFormat* inputFormat() const override;

View File

@ -44,7 +44,7 @@ struct LineRank : public SimpleAlgorithm<float, float, float> {
return new VertexGraphFormat<float, float>(_resultField, -1.0);
}
MessageFormat<float>* messageFormat() const override {
return new FloatMessageFormat();
return new NumberMessageFormat<float>();
}
MessageCombiner<float>* messageCombiner() const override {

View File

@ -39,63 +39,56 @@ using namespace arangodb;
using namespace arangodb::pregel;
using namespace arangodb::pregel::algos;
static float EPS = 0.00001;
static std::string const kConvergence = "convergence";
static double EPS = 0.00001;
PageRank::PageRank(arangodb::velocypack::Slice params)
: SimpleAlgorithm("PageRank", params) {
VPackSlice t = params.get("convergenceThreshold");
_threshold = t.isNumber() ? t.getNumber<double>() : EPS;
}
struct PRComputation : public VertexComputation<double, float, double> {
struct PRComputation : public VertexComputation<float, float, float> {
PRComputation() {}
void compute(MessageIterator<double> const& messages) override {
double* ptr = mutableVertexData();
double copy = *ptr;
void compute(MessageIterator<float> const& messages) override {
float* ptr = mutableVertexData();
float copy = *ptr;
if (globalSuperstep() == 0) {
*ptr = 1.0 / context()->vertexCount();
} else {
double sum = 0.0;
for (const double* msg : messages) {
float sum = 0.0;
for (const float* msg : messages) {
sum += *msg;
}
*ptr = 0.85 * sum + 0.15 / context()->vertexCount();
}
double diff = fabs(copy - *ptr);
float diff = fabs(copy - *ptr);
aggregate(kConvergence, diff);
if (globalSuperstep() < 50) {
RangeIterator<Edge<float>> edges = getEdges();
double val = *ptr / edges.size();
for (Edge<float>* edge : edges) {
sendMessage(edge, val);
}
} else {
voteHalt();
RangeIterator<Edge<float>> edges = getEdges();
float val = *ptr / edges.size();
for (Edge<float>* edge : edges) {
sendMessage(edge, val);
}
}
};
VertexComputation<double, float, double>* PageRank::createComputation(
VertexComputation<float, float, float>* PageRank::createComputation(
WorkerConfig const* config) const {
return new PRComputation();
}
IAggregator* PageRank::aggregator(std::string const& name) const {
if (name == kConvergence) {
return new MaxAggregator<double>(-1.0f, false);
return new MaxAggregator<float>(MAXFLOAT, false);
}
return nullptr;
}
struct MyMasterContext : MasterContext {
MyMasterContext() {}// TODO use _threashold
bool postGlobalSuperstep(uint64_t gss) {
double const* diff = getAggregatedValue<double>(kConvergence);
return globalSuperstep() < 2 || *diff > EPS;
struct MyMasterContext : public MasterContext {
MyMasterContext() {
//VPackSlice t = params.get("convergenceThreshold");
//_threshold = t.isNumber() ? t.getNumber<float>() : EPS;
}// TODO use _threashold
bool postGlobalSuperstep() override {
float const* diff = getAggregatedValue<float>(kConvergence);
return globalSuperstep() < 50 && *diff > EPS;
};
};

View File

@ -31,25 +31,24 @@ namespace pregel {
namespace algos {
/// PageRank
struct PageRank : public SimpleAlgorithm<double, float, double> {
float _threshold;
struct PageRank : public SimpleAlgorithm<float, float, float> {
public:
PageRank(arangodb::velocypack::Slice params);
PageRank(arangodb::velocypack::Slice params)
: SimpleAlgorithm("PageRank", params) {}
GraphFormat* inputFormat() const override {
return new VertexGraphFormat<double, float>(_resultField, 0);
return new VertexGraphFormat<float, float>(_resultField, 0);
}
MessageFormat<double>* messageFormat() const override {
return new DoubleMessageFormat();
MessageFormat<float>* messageFormat() const override {
return new NumberMessageFormat<float>();
}
MessageCombiner<double>* messageCombiner() const override {
return new SumCombiner<double>();
MessageCombiner<float>* messageCombiner() const override {
return new SumCombiner<float>();
}
VertexComputation<double, float, double>* createComputation(
VertexComputation<float, float, float>* createComputation(
WorkerConfig const*) const override;
IAggregator* aggregator(std::string const& name) const override;

View File

@ -39,6 +39,7 @@ using namespace arangodb;
using namespace arangodb::pregel;
using namespace arangodb::pregel::algos;
static float EPS = 0.00001;
static std::string const kConvergence = "convergence";
static std::string const kStep = "step";
static std::string const kRank = "rank";
@ -47,15 +48,9 @@ static std::string const kNonFailedCount = "nonfailedCount";
static std::string const kScale = "scale";
RecoveringPageRank::RecoveringPageRank(arangodb::velocypack::Slice params)
: SimpleAlgorithm("PageRank", params) {
VPackSlice t = params.get("convergenceThreshold");
_threshold = t.isNumber() ? t.getNumber<float>() : 0.000002f;
}
struct RPRComputation : public VertexComputation<float, float, float> {
float _limit;
RPRComputation(float t) : _limit(t) {}
struct MyComputation : public VertexComputation<float, float, float> {
MyComputation() {}
void compute(MessageIterator<float> const& messages) override {
float* ptr = mutableVertexData();
float copy = *ptr;
@ -72,26 +67,18 @@ struct RPRComputation : public VertexComputation<float, float, float> {
float diff = fabsf(copy - *ptr);
aggregate(kConvergence, diff);
aggregate(kRank, ptr);
// const float* val = getAggregatedValue<float>("convergence");
// if (val) { // if global convergence is available use it
// diff = *val;
//}
if (globalSuperstep() < 50 && diff > _limit) {
RangeIterator<Edge<float>> edges = getEdges();
float val = *ptr / edges.size();
for (Edge<float>* edge : edges) {
sendMessage(edge, val);
}
} else {
voteHalt();
RangeIterator<Edge<float>> edges = getEdges();
float val = *ptr / edges.size();
for (Edge<float>* edge : edges) {
sendMessage(edge, val);
}
}
};
VertexComputation<float, float, float>* RecoveringPageRank::createComputation(
WorkerConfig const* config) const {
return new RPRComputation(_threshold);
return new MyComputation();
}
IAggregator* RecoveringPageRank::aggregator(std::string const& name) const {
@ -140,24 +127,31 @@ VertexCompensation<float, float, float>* RecoveringPageRank::createCompensation(
}
struct MyMasterContext : public MasterContext {
MyMasterContext(VPackSlice params) {};
float _threshold;
MyMasterContext(VPackSlice params) {
VPackSlice t = params.get("convergenceThreshold");
_threshold = t.isNumber() ? t.getNumber<float>() : EPS;
};
int32_t recoveryStep = 0;
float totalRank = 0;
bool postGlobalSuperstep(uint64_t gss) override {
bool postGlobalSuperstep() override {
const float* convergence = getAggregatedValue<float>(kConvergence);
LOG(INFO) << "Current convergence level" << *convergence;
totalRank = *getAggregatedValue<float>(kRank);
return true;
float const* diff = getAggregatedValue<float>(kConvergence);
return globalSuperstep() < 50 && *diff > _threshold;
}
bool preCompensation(uint64_t gss) override {
bool preCompensation() override {
aggregate(kStep, recoveryStep);
return totalRank != 0;
}
bool postCompensation(uint64_t gss) override {
bool postCompensation() override {
if (recoveryStep == 0) {
recoveryStep = 1;

View File

@ -32,10 +32,9 @@ namespace algos {
/// PageRank
struct RecoveringPageRank : public SimpleAlgorithm<float, float, float> {
float _threshold;
public:
RecoveringPageRank(arangodb::velocypack::Slice params);
RecoveringPageRank(arangodb::velocypack::Slice params)
: SimpleAlgorithm("PageRank", params) {}
bool supportsCompensation() const override { return true; }
MasterContext* masterContext(VPackSlice userParams) const override;
@ -45,7 +44,7 @@ struct RecoveringPageRank : public SimpleAlgorithm<float, float, float> {
}
MessageFormat<float>* messageFormat() const override {
return new FloatMessageFormat();
return new NumberMessageFormat<float>();
}
MessageCombiner<float>* messageCombiner() const override {

View File

@ -0,0 +1,251 @@
////////////////////////////////////////////////////////////////////////////////
/// DISCLAIMER
///
/// Copyright 2016 ArangoDB GmbH, Cologne, Germany
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
///
/// @author Simon Grätzer
////////////////////////////////////////////////////////////////////////////////
#include "SCC.h"
#include "Pregel/Aggregator.h"
#include "Pregel/Algorithm.h"
#include "Pregel/GraphStore.h"
#include "Pregel/IncomingCache.h"
#include "Pregel/VertexComputation.h"
#include "Cluster/ClusterInfo.h"
#include "Cluster/ServerState.h"
#include "Pregel/MasterContext.h"
using namespace arangodb::pregel;
using namespace arangodb::pregel::algos;
static std::string const kPhase = "phase";
static std::string const kFoundNewMax = "max";
static std::string const kConverged = "converged";
enum SCCPhase {
TRANSPOSE = 0,
TRIMMING = 1,
FORWARD_TRAVERSAL = 2,
BACKWARD_TRAVERSAL_START = 3,
BACKWARD_TRAVERSAL_REST = 4
};
struct MyComputation : public VertexComputation<SCCValue, int32_t, SenderMessage<uint64_t>> {
MyComputation() {}
void compute(MessageIterator<SenderMessage<uint64_t>> const& messages) override {
if (isActive() == false) {
// color was already determinded or vertex was trimmed
return;
}
SCCValue *vertexState = mutableVertexData();
int64_t const* phase = getAggregatedValue<int64_t>(kPhase);
switch (*phase) {
// let all our connected nodes know we are there
case SCCPhase::TRANSPOSE: {
vertexState->parents.clear();
SenderMessage<uint64_t> message(pregelId(), 0);
sendMessageToAllEdges(message);
break;
}
// Creates list of parents based on the received ids and halts the vertices
// that don't have any parent or outgoing edge, hence, they can't be
// part of an SCC.
case SCCPhase::TRIMMING:{
for (SenderMessage<uint64_t> const* msg : messages) {
vertexState->parents.push_back(msg->pregelId);
}
// reset the color for vertices which are not active
vertexState->color = vertexState->vertexID;
// If this node doesn't have any parents or outgoing edges,
// it can't be part of an SCC
RangeIterator<Edge<int32_t>> edges = getEdges();
if (vertexState->parents.size() == 0 || edges.size() == 0) {
voteHalt();
} else {
SenderMessage<uint64_t> message(pregelId(), vertexState->color);
sendMessageToAllEdges(message);
}
break;
}
case SCCPhase::FORWARD_TRAVERSAL:{
uint64_t old = vertexState->color;
for (SenderMessage<uint64_t> const* msg : messages) {
if (vertexState->color < msg->value) {
vertexState->color = msg->value;
}
}
if (old != vertexState->color) {
SenderMessage<uint64_t> message(pregelId(), vertexState->color);
sendMessageToAllEdges(message);
aggregate(kFoundNewMax, true);
}
break;
}
case SCCPhase::BACKWARD_TRAVERSAL_START:{
// if I am the 'root' of a SCC start traversak
if (vertexState->vertexID == vertexState->color) {
SenderMessage<uint64_t> message(pregelId(), vertexState->color);
// sendMessageToAllParents
for (PregelID const& pid : vertexState->parents) {
sendMessage(pid, message);
}
}
break;
}
case SCCPhase::BACKWARD_TRAVERSAL_REST:{
for (SenderMessage<uint64_t> const* msg : messages) {
if (vertexState->color == msg->value) {
for (PregelID const& pid : vertexState->parents) {
sendMessage(pid, *msg);
}
aggregate(kConverged, true);
voteHalt();
}
}
break;
}
}
}
};
VertexComputation<SCCValue, int32_t,
SenderMessage<uint64_t>>* SCC::createComputation(
WorkerConfig const* config) const {
return new MyComputation();
}
struct MyGraphFormat : public GraphFormat {
const std::string _resultField;
uint64_t vertexIdRange = 0;
MyGraphFormat(std::string const& result) : _resultField(result) {}
void willLoadVertices(uint64_t count) override {
// if we aren't running in a cluster it doesn't matter
if (arangodb::ServerState::instance()->isRunningInCluster()) {
arangodb::ClusterInfo *ci = arangodb::ClusterInfo::instance();
if (ci) {
vertexIdRange = ci->uniqid(count);
}
}
}
size_t estimatedVertexSize() const override { return sizeof(SCCValue); };
size_t estimatedEdgeSize() const override { return 0; };
size_t copyVertexData(VertexEntry const& vertex,
std::string const& documentId,
arangodb::velocypack::Slice document,
void* targetPtr, size_t maxSize) override {
SCCValue *senders = (SCCValue*) targetPtr;
senders->vertexID = vertexIdRange++;
return sizeof(SCCValue);
}
size_t copyEdgeData(arangodb::velocypack::Slice document, void* targetPtr,
size_t maxSize) override {
return 0;
}
bool buildVertexDocument(arangodb::velocypack::Builder& b,
const void* ptr,
size_t size) override {
SCCValue *senders = (SCCValue*) ptr;
b.add(_resultField, VPackValue(senders->color));
return true;
}
bool buildEdgeDocument(arangodb::velocypack::Builder& b, const void* ptr,
size_t size) override {
return false;
}
};
GraphFormat* SCC::inputFormat() const {
return new MyGraphFormat(_resultField);
}
struct MyMasterContext : public MasterContext {
MyMasterContext() {}// TODO use _threashold
void preGlobalSuperstep() override {
if (globalSuperstep() == 0) {
return;
}
int64_t const* phase = getAggregatedValue<int64_t>(kPhase);
switch (*phase) {
case SCCPhase::TRANSPOSE:
aggregate(kPhase, SCCPhase::TRIMMING);
break;
case SCCPhase::TRIMMING:
aggregate(kPhase, SCCPhase::FORWARD_TRAVERSAL);
break;
case SCCPhase::FORWARD_TRAVERSAL: {
bool const* newMaxFound = getAggregatedValue<bool>(kFoundNewMax);
if (*newMaxFound == false) {
aggregate(kPhase, SCCPhase::BACKWARD_TRAVERSAL_START);
}
}
break;
case SCCPhase::BACKWARD_TRAVERSAL_START:
aggregate(kPhase, SCCPhase::BACKWARD_TRAVERSAL_REST);
break;
case SCCPhase::BACKWARD_TRAVERSAL_REST:
bool const* converged = getAggregatedValue<bool>(kConverged);
// continue until no more vertices are updated
if (*converged == false) {
aggregate(kPhase, SCCPhase::TRANSPOSE);
}
break;
}
};
};
MasterContext* SCC::masterContext(VPackSlice userParams) const {
return new MyMasterContext();
}
IAggregator* SCC::aggregator(std::string const& name) const {
if (name == kPhase) {
return new ValueAggregator<int64_t>(SCCPhase::TRANSPOSE, true);
} else if (name == kFoundNewMax) {
return new ValueAggregator<bool>(false, false);
} else if (name == kConverged) {
return new ValueAggregator<bool>(false, true);
}
return nullptr;
}

View File

@ -0,0 +1,63 @@
////////////////////////////////////////////////////////////////////////////////
/// DISCLAIMER
///
/// Copyright 2016 ArangoDB GmbH, Cologne, Germany
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
///
/// @author Simon Grätzer
////////////////////////////////////////////////////////////////////////////////
#ifndef ARANGODB_PREGEL_ALGOS_SCC_H
#define ARANGODB_PREGEL_ALGOS_SCC_H 1
#include "Pregel/Algorithm.h"
#include "Pregel/CommonFormats.h"
namespace arangodb {
namespace pregel {
namespace algos {
/// Finds strongly connected components of the graph.
///
/// 1. Each vertex starts with its vertex id as its "color".
/// 2. Remove vertices which cannot be in a SCC (no incoming or no outgoing edges)
/// 3. Propagate the color forward from each vertex, accept a neighbor's color if it's smaller than yours.
/// At convergence, vertices with the same color represents all nodes that are visitable from the root of that color.
/// 4. Reverse the graph.
/// 5. Start at all roots, walk the graph. Visit a neighbor if it has the same color as you.
/// All nodes visited belongs to the SCC identified by the root color.
struct SCC : public SimpleAlgorithm<SCCValue, int32_t, SenderMessage<uint64_t>> {
public:
SCC(VPackSlice userParams)
: SimpleAlgorithm<SCCValue, int32_t, SenderMessage<uint64_t>>("SCC", userParams) {}
GraphFormat* inputFormat() const override;
MessageFormat<SenderMessage<uint64_t>>* messageFormat() const override {
return new SenderMessageFormat<uint64_t>();
}
VertexComputation<SCCValue, int32_t, SenderMessage<uint64_t>>*
createComputation( WorkerConfig const*) const override;
MasterContext* masterContext(VPackSlice userParams) const override;
IAggregator* aggregator(std::string const& name) const override;
};
}
}
}
#endif

View File

@ -20,34 +20,48 @@
/// @author Simon Grätzer
////////////////////////////////////////////////////////////////////////////////
#ifndef ARANGODB_PREGEL_ADDITIONAL_MFORMATS_H
#define ARANGODB_PREGEL_ADDITIONAL_MFORMATS_H 1
// NOTE: this files exists primarily to include these algorithm specfic structs in the
// cpp files to do template specialization
#ifndef ARANGODB_PREGEL_COMMON_MFORMATS_H
#define ARANGODB_PREGEL_COMMON_MFORMATS_H 1
#include "Pregel/GraphFormat.h"
#include "Pregel/MessageFormat.h"
#include "Pregel/Graph.h"
namespace arangodb {
namespace pregel {
struct SCCValue {
std::vector<PregelID> parents;
uint64_t vertexID;
uint64_t color;
};
template<typename T>
struct SenderValue {
struct SenderMessage {
SenderMessage() {}
SenderMessage(PregelID const& pid, T const& val)
: pregelId(pid), value(val) {}
PregelID pregelId;
T value;
};
template <typename T>
struct NumberSenderFormat : public MessageFormat<SenderValue<T>> {
struct SenderMessageFormat : public MessageFormat<SenderMessage<T>> {
static_assert(std::is_arithmetic<T>::value, "Message type must be numeric");
NumberSenderFormat() {}
void unwrapValue(VPackSlice s, SenderValue<T>& senderVal) const override {
SenderMessageFormat() {}
void unwrapValue(VPackSlice s, SenderMessage<T>& senderVal) const override {
VPackArrayIterator array(s);
senderVal.pregelId.shard = (*array).getUInt();
senderVal.pregelId.key = (*(++array)).copyString();
senderVal.value = (*(++array)).getNumber<T>();
}
void addValue(VPackBuilder& arrayBuilder, SenderValue<T> const& senderVal) const override {
void addValue(VPackBuilder& arrayBuilder, SenderMessage<T> const& senderVal) const override {
arrayBuilder.openArray();
arrayBuilder.add(VPackValue(senderVal.pregelId.shard));
arrayBuilder.add(VPackValue(senderVal.pregelId.key));

View File

@ -141,7 +141,7 @@ bool Conductor::_startGlobalStep() {
bool proceed = true;
if (_masterContext && _globalSuperstep > 0) { // ask algorithm to evaluate aggregated values
_masterContext->_globalSuperstep = _globalSuperstep - 1;
proceed = _masterContext->postGlobalSuperstep(_globalSuperstep);
proceed = _masterContext->postGlobalSuperstep();
if (!proceed) {
LOG(INFO) << "Master context ended execution";
}
@ -159,7 +159,7 @@ bool Conductor::_startGlobalStep() {
_masterContext->_globalSuperstep = _globalSuperstep;
_masterContext->_vertexCount = _totalVerticesCount;
_masterContext->_edgeCount = _totalEdgesCount;
_masterContext->preGlobalSuperstep(_globalSuperstep);
_masterContext->preGlobalSuperstep();
}
b.clear();
@ -291,7 +291,7 @@ void Conductor::finishedRecoveryStep(VPackSlice data) {
// only compensations supported
bool proceed = false;
if (_masterContext) {
proceed = proceed || _masterContext->postCompensation(_globalSuperstep);
proceed = proceed || _masterContext->postCompensation();
}
int res = TRI_ERROR_NO_ERROR;
@ -299,7 +299,7 @@ void Conductor::finishedRecoveryStep(VPackSlice data) {
// reset values which are calculated during the superstep
_aggregators->resetValues();
if (_masterContext) {
_masterContext->preCompensation(_globalSuperstep);
_masterContext->preCompensation();
}
VPackBuilder b;
@ -401,7 +401,7 @@ void Conductor::startRecovery() {
// Let's try recovery
if (_masterContext) {
bool proceed = _masterContext->preCompensation(_globalSuperstep);
bool proceed = _masterContext->preCompensation();
if (!proceed) {
cancel();
}

View File

@ -28,7 +28,8 @@
#include "Indexes/EdgeIndex.h"
#include "Indexes/Index.h"
#include "Pregel/WorkerConfig.h"
#include "Utils.h"
#include "Pregel/Utils.h"
#include "Pregel/CommonFormats.h"
#include "Utils/CollectionNameResolver.h"
#include "Utils/ExplicitTransaction.h"
#include "Utils/OperationCursor.h"
@ -439,3 +440,8 @@ template class arangodb::pregel::GraphStore<float, float>;
template class arangodb::pregel::GraphStore<double, float>;
template class arangodb::pregel::GraphStore<double, double>;
// specific algo combos
template class arangodb::pregel::GraphStore<SCCValue, int32_t>;

View File

@ -22,7 +22,7 @@
#include "Pregel/IncomingCache.h"
#include "Pregel/Utils.h"
//#include "Pregel/AdditionalFormats.h"
#include "Pregel/CommonFormats.h"
#include "Basics/MutexLocker.h"
#include "Basics/StaticStrings.h"
@ -31,9 +31,6 @@
#include <velocypack/Iterator.h>
#include <velocypack/velocypack-aliases.h>
//#include <libcuckoo/city_hasher.hh>
//#include <libcuckoo/cuckoohash_map.hh>
using namespace arangodb;
using namespace arangodb::pregel;
@ -240,15 +237,17 @@ void CombiningInCache<M>::forEach(
}
// template types to create
//template class arangodb::pregel::InCache<SenderValue<int64_t>>;
template class arangodb::pregel::InCache<int64_t>;
template class arangodb::pregel::InCache<float>;
template class arangodb::pregel::InCache<double>;
//template class arangodb::pregel::ArrayInCache<SenderValue<int64_t>>;
template class arangodb::pregel::ArrayInCache<int64_t>;
template class arangodb::pregel::ArrayInCache<float>;
template class arangodb::pregel::ArrayInCache<double>;
//template class arangodb::pregel::CombiningInCache<SenderValue<int64_t>>;
template class arangodb::pregel::CombiningInCache<int64_t>;
template class arangodb::pregel::CombiningInCache<float>;
template class arangodb::pregel::CombiningInCache<double>;
// algo specific
template class arangodb::pregel::InCache<SenderMessage<uint64_t>>;
template class arangodb::pregel::ArrayInCache<SenderMessage<uint64_t>>;
template class arangodb::pregel::CombiningInCache<SenderMessage<uint64_t>>;

View File

@ -53,7 +53,8 @@ class MessageIterator {
it._current = it._size;
return it;
}
const M* operator*() const { return _data + _current; }
M const* operator*() const { return _data + _current; }
// prefix ++
MessageIterator& operator++() {

View File

@ -65,17 +65,17 @@ class MasterContext {
/// @brief called before supersteps
/// @return true to continue the computation
virtual void preGlobalSuperstep(uint64_t gss) {};
virtual void preGlobalSuperstep() {};
/// @brief called after supersteps
/// @return true to continue the computation
virtual bool postGlobalSuperstep(uint64_t gss) { return true; };
virtual bool postGlobalSuperstep() { return true; };
virtual void postApplication(){};
/// should indicate if compensation is supposed to start by returning true
virtual bool preCompensation(uint64_t gss) { return true; }
virtual bool preCompensation() { return true; }
/// should indicate if compensation is finished, by returning false.
/// otherwise workers will be called again with the aggregated values
virtual bool postCompensation(uint64_t gss) { return false; }
virtual bool postCompensation() { return false; }
};
}

View File

@ -50,7 +50,8 @@ struct IntegerMessageFormat : public MessageFormat<int64_t> {
arrayBuilder.add(VPackValue(val));
}
};
/*
struct DoubleMessageFormat : public MessageFormat<double> {
DoubleMessageFormat() {}
void unwrapValue(VPackSlice s, double& value) const override {
@ -69,7 +70,7 @@ struct FloatMessageFormat : public MessageFormat<float> {
void addValue(VPackBuilder& arrayBuilder, float const& val) const override {
arrayBuilder.add(VPackValue(val));
}
};
};*/
template <typename M>
struct NumberMessageFormat : public MessageFormat<M> {

View File

@ -24,7 +24,7 @@
#include "Pregel/IncomingCache.h"
#include "Pregel/Utils.h"
#include "Pregel/WorkerConfig.h"
//#include "Pregel/AdditionalFormats.h"
#include "Pregel/CommonFormats.h"
#include "Basics/MutexLocker.h"
#include "Basics/StaticStrings.h"
@ -258,15 +258,18 @@ void CombiningOutCache<M>::flushMessages() {
}
// template types to create
//template class arangodb::pregel::OutCache<SenderValue<int64_t>>;
template class arangodb::pregel::OutCache<int64_t>;
template class arangodb::pregel::OutCache<float>;
template class arangodb::pregel::OutCache<double>;
//template class arangodb::pregel::ArrayOutCache<SenderValue<int64_t>>;
template class arangodb::pregel::ArrayOutCache<int64_t>;
template class arangodb::pregel::ArrayOutCache<float>;
template class arangodb::pregel::ArrayOutCache<double>;
//template class arangodb::pregel::CombiningOutCache<SenderValue<int64_t>>;
template class arangodb::pregel::CombiningOutCache<int64_t>;
template class arangodb::pregel::CombiningOutCache<float>;
template class arangodb::pregel::CombiningOutCache<double>;
// algo specific
template class arangodb::pregel::OutCache<SenderMessage<uint64_t>>;
template class arangodb::pregel::ArrayOutCache<SenderMessage<uint64_t>>;
template class arangodb::pregel::CombiningOutCache<SenderMessage<uint64_t>>;

View File

@ -81,6 +81,7 @@ class VertexContext {
void voteHalt() { _vertexEntry->setActive(false); }
void voteActive() { _vertexEntry->setActive(true); }
bool isActive() { return _vertexEntry->active(); }
inline uint64_t globalSuperstep() const { return _gss; }
inline uint64_t localSuperstep() const { return _lss; }
@ -101,6 +102,10 @@ class VertexComputation : public VertexContext<V, E, M> {
_cache->appendMessage(edge->targetShard(), edge->toKey(), data);
}
void sendMessage(PregelID const& pid, M const& data) {
_cache->appendMessage(pid.shard, pid.key, data);
}
// TODO optimize outgoing cache somehow
void sendMessageToAllEdges(M const& data) {
RangeIterator<Edge<E>> edges = this->getEdges();

View File

@ -29,7 +29,7 @@
#include "Pregel/Utils.h"
#include "Pregel/VertexComputation.h"
#include "Pregel/WorkerConfig.h"
//#include "Pregel/AdditionalFormats.h"
#include "Pregel/CommonFormats.h"
#include "Basics/MutexLocker.h"
#include "Basics/ReadLocker.h"
@ -662,5 +662,5 @@ Worker<V, E, M>::_callConductorWithResponse(std::string const& path,
template class arangodb::pregel::Worker<int64_t, int64_t, int64_t>;
template class arangodb::pregel::Worker<float, float, float>;
template class arangodb::pregel::Worker<double, float, double>;
// complex types
//template class arangodb::pregel::Worker<int64_t, int64_t, SenderValue<int64_t>>;
// custom algorihm types
template class arangodb::pregel::Worker<SCCValue, int32_t, SenderMessage<uint64_t>>;