//////////////////////////////////////////////////////////////////////////////// /// DISCLAIMER /// /// Copyright 2016 ArangoDB GmbH, Cologne, Germany /// /// Licensed under the Apache License, Version 2.0 (the "License"); /// you may not use this file except in compliance with the License. /// You may obtain a copy of the License at /// /// http://www.apache.org/licenses/LICENSE-2.0 /// /// Unless required by applicable law or agreed to in writing, software /// distributed under the License is distributed on an "AS IS" BASIS, /// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. /// See the License for the specific language governing permissions and /// limitations under the License. /// /// Copyright holder is ArangoDB GmbH, Cologne, Germany /// /// @author Simon Grätzer //////////////////////////////////////////////////////////////////////////////// #include "PageRank.h" #include "Pregel/Aggregator.h" #include "Pregel/GraphFormat.h" #include "Pregel/Iterators.h" #include "Pregel/MasterContext.h" #include "Pregel/Utils.h" #include "Pregel/VertexComputation.h" using namespace arangodb; using namespace arangodb::pregel; using namespace arangodb::pregel::algos; static float EPS = 0.00001; static std::string const kConvergence = "convergence"; PageRank::PageRank(VPackSlice const& params) : SimpleAlgorithm("PageRank", params) { _maxGSS = basics::VelocyPackHelper::getNumericValue(params, "maxIterations", 250); } struct PRComputation : public VertexComputation { PRComputation() {} void compute(MessageIterator const& messages) override { float* ptr = mutableVertexData(); float copy = *ptr; if (globalSuperstep() == 0) { *ptr = 1.0 / context()->vertexCount(); } else { float sum = 0.0; for (const float* msg : messages) { sum += *msg; } *ptr = 0.85 * sum + 0.15 / context()->vertexCount(); } float diff = fabs(copy - *ptr); aggregate(kConvergence, diff); RangeIterator> edges = getEdges(); float val = *ptr / edges.size(); for (Edge* edge : edges) { sendMessage(edge, val); } } }; VertexComputation* PageRank::createComputation( WorkerConfig const* config) const { return new PRComputation(); } IAggregator* PageRank::aggregator(std::string const& name) const { if (name == kConvergence) { return new MaxAggregator(-1, false); } return nullptr; } struct PRMasterContext : public MasterContext { float _threshold = EPS; PRMasterContext(VPackSlice params) { VPackSlice t = params.get("threshold"); _threshold = t.isNumber() ? t.getNumber() : EPS; } void preApplication() override { LOG_TOPIC(INFO, Logger::PREGEL) << "Using threshold " << _threshold; }; bool postGlobalSuperstep() override { float const* diff = getAggregatedValue(kConvergence); return globalSuperstep() < 1 || *diff > _threshold; }; }; MasterContext* PageRank::masterContext(VPackSlice userParams) const { return new PRMasterContext(userParams); }