From ee98b59e2f8ca3fe93da5be59209e88e1ac89b35 Mon Sep 17 00:00:00 2001 From: Frank Celler Date: Fri, 3 Jun 2016 10:38:47 +0200 Subject: [PATCH] changed reload to outdated --- arangod/Cluster/HeartbeatThread.cpp | 297 ++++++++--------------- arangod/Cluster/HeartbeatThread.h | 12 - arangod/RestServer/RestServerFeature.cpp | 2 +- arangod/V8Server/v8-vocbase.cpp | 8 +- arangod/VocBase/AuthInfo.cpp | 153 +++++++----- arangod/VocBase/AuthInfo.h | 15 +- 6 files changed, 208 insertions(+), 279 deletions(-) diff --git a/arangod/Cluster/HeartbeatThread.cpp b/arangod/Cluster/HeartbeatThread.cpp index 4be0f2f164..9c10a97095 100644 --- a/arangod/Cluster/HeartbeatThread.cpp +++ b/arangod/Cluster/HeartbeatThread.cpp @@ -64,7 +64,6 @@ HeartbeatThread::HeartbeatThread(TRI_server_t* server, _statusLock(), _agency(), _condition(), - _refetchUsers(true), _myId(ServerState::instance()->getId()), _interval(interval), _maxFailsBeforeWarning(maxFailsBeforeWarning), @@ -109,13 +108,14 @@ void HeartbeatThread::run() { //////////////////////////////////////////////////////////////////////////////// void HeartbeatThread::runDBServer() { - LOG_TOPIC(TRACE, Logger::HEARTBEAT) + LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "starting heartbeat thread (DBServer version)"; - + // mop: the heartbeat thread itself is now ready setReady(); // mop: however we need to wait for the rest server here to come up - // otherwise we would already create collections and the coordinator would think + // otherwise we would already create collections and the coordinator would + // think // ohhh the dbserver is online...pump some documents into it // which fails when it is still in maintenance mode while (arangodb::rest::HttpHandlerFactory::isMaintenance()) { @@ -125,48 +125,48 @@ void HeartbeatThread::runDBServer() { // convert timeout to seconds double const interval = (double)_interval / 1000.0 / 1000.0; - std::function updatePlan = [&]( - VPackSlice const& result) { - if (!result.isNumber()) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) - << "Plan Version is not a number! " << result.toJson(); - return false; - } - uint64_t version = result.getNumber(); - - bool doSync = false; - { - MUTEX_LOCKER(mutexLocker, _statusLock); - if (version > _desiredVersions.plan) { - _desiredVersions.plan = version; - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) - << "Desired Current Version is now " << _desiredVersions.plan; - doSync = true; - } - } + std::function updatePlan = + [&](VPackSlice const& result) { + if (!result.isNumber()) { + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Plan Version is not a number! " + << result.toJson(); + return false; + } + uint64_t version = result.getNumber(); - if (doSync) { - syncDBServerStatusQuo(); - } + bool doSync = false; + { + MUTEX_LOCKER(mutexLocker, _statusLock); + if (version > _desiredVersions.plan) { + _desiredVersions.plan = version; + LOG_TOPIC(DEBUG, Logger::HEARTBEAT) + << "Desired Current Version is now " << _desiredVersions.plan; + doSync = true; + } + } + + if (doSync) { + syncDBServerStatusQuo(); + } + + return true; + }; - return true; - }; - auto planAgencyCallback = std::make_shared( _agency, "Plan/Version", updatePlan, true); - + bool registered = false; while (!registered) { registered = _agencyCallbackRegistry->registerCallback(planAgencyCallback); if (!registered) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Couldn't register plan change in agency!"; sleep(1); } } - + // we check Current/Version every few heartbeats: - int const currentCountStart = 1; // set to 1 by Max to speed up discovery + int const currentCountStart = 1; // set to 1 by Max to speed up discovery int currentCount = currentCountStart; while (!isStopping()) { @@ -186,16 +186,15 @@ void HeartbeatThread::runDBServer() { currentCount = currentCountStart; // send an initial GET request to Sync/Commands/my-id - LOG_TOPIC(TRACE, Logger::HEARTBEAT) + LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "Looking at Sync/Commands/" + _myId; - AgencyCommResult result = - _agency.getValues("Sync/Commands/" + _myId); - + AgencyCommResult result = _agency.getValues("Sync/Commands/" + _myId); + if (result.successful()) { handleStateChange(result); } - + if (isStopping()) { break; } @@ -203,15 +202,14 @@ void HeartbeatThread::runDBServer() { LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "Refetching Current/Version..."; AgencyCommResult res = _agency.getValues("Current/Version"); if (!res.successful()) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Could not read Current/Version from agency."; } else { - VPackSlice s - = res.slice()[0].get(std::vector( - {_agency.prefix(), std::string("Current"), - std::string("Version")})); + VPackSlice s = res.slice()[0].get( + std::vector({_agency.prefix(), std::string("Current"), + std::string("Version")})); if (!s.isInteger()) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Current/Version in agency is not an integer."; } else { uint64_t currentVersion = 0; @@ -220,14 +218,14 @@ void HeartbeatThread::runDBServer() { } catch (...) { } if (currentVersion == 0) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Current/Version in agency is 0."; } else { { MUTEX_LOCKER(mutexLocker, _statusLock); if (currentVersion > _desiredVersions.current) { _desiredVersions.current = currentVersion; - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) + LOG_TOPIC(DEBUG, Logger::HEARTBEAT) << "Found greater Current/Version in agency."; } } @@ -245,7 +243,7 @@ void HeartbeatThread::runDBServer() { // mop: execute at least once do { LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "Entering update loop"; - + bool wasNotified; { CONDITION_LOCKER(locker, _condition); @@ -284,7 +282,7 @@ void HeartbeatThread::runDBServer() { } usleep(1000); } - LOG_TOPIC(TRACE, Logger::HEARTBEAT) + LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "stopped heartbeat thread (DBServer version)"; } @@ -293,7 +291,7 @@ void HeartbeatThread::runDBServer() { //////////////////////////////////////////////////////////////////////////////// void HeartbeatThread::runCoordinator() { - LOG_TOPIC(TRACE, Logger::HEARTBEAT) + LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "starting heartbeat thread (coordinator version)"; uint64_t oldUserVersion = 0; @@ -320,25 +318,24 @@ void HeartbeatThread::runCoordinator() { break; } - AgencyReadTransaction trx(std::vector({ - _agency.prefixPath() + "Plan/Version", - _agency.prefixPath() + "Current/Version", - _agency.prefixPath() + "Sync/Commands/" + _myId, - _agency.prefixPath() + "Sync/UserVersion"})); + AgencyReadTransaction trx(std::vector( + {_agency.prefixPath() + "Plan/Version", + _agency.prefixPath() + "Current/Version", + _agency.prefixPath() + "Sync/Commands/" + _myId, + _agency.prefixPath() + "Sync/UserVersion"})); AgencyCommResult result = _agency.sendTransactionWithFailover(trx); if (!result.successful()) { - LOG_TOPIC(WARN, Logger::HEARTBEAT) + LOG_TOPIC(WARN, Logger::HEARTBEAT) << "Heartbeat: Could not read from agency!"; } else { - LOG_TOPIC(TRACE, Logger::HEARTBEAT) + LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "Looking at Sync/Commands/" + _myId; handleStateChange(result); - VPackSlice versionSlice - = result.slice()[0].get(std::vector( - {_agency.prefix(), "Plan", "Version"})); + VPackSlice versionSlice = result.slice()[0].get( + std::vector({_agency.prefix(), "Plan", "Version"})); if (versionSlice.isInteger()) { // there is a plan version @@ -346,8 +343,7 @@ void HeartbeatThread::runCoordinator() { uint64_t planVersion = 0; try { planVersion = versionSlice.getUInt(); - } - catch (...) { + } catch (...) { } if (planVersion > lastPlanVersionNoticed) { @@ -363,62 +359,34 @@ void HeartbeatThread::runCoordinator() { } } - VPackSlice slice = - result.slice()[0].get(std::vector( - {_agency.prefix(), "Sync", "UserVersion"})); + VPackSlice slice = result.slice()[0].get( + std::vector({_agency.prefix(), "Sync", "UserVersion"})); if (slice.isInteger()) { // there is a UserVersion uint64_t userVersion = 0; try { userVersion = slice.getUInt(); + } catch (...) { } - catch (...) { - } + if (userVersion > 0 && userVersion != oldUserVersion) { - // reload user cache for all databases - std::vector dbs = - ClusterInfo::instance()->databases(true); - std::vector::iterator i; - bool allOK = true; - for (i = dbs.begin(); i != dbs.end(); ++i) { - TRI_vocbase_t* vocbase = - TRI_UseCoordinatorDatabaseServer(_server, i->c_str()); - - if (vocbase != nullptr && TRI_EqualString(vocbase->_name, TRI_VOC_SYSTEM_DATABASE)) { - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) - << "Reloading users for database " << vocbase->_name - << "."; - - if (!fetchUsers()) { - // something is wrong... probably the database server - // with the _users collection is not yet available - allOK = false; - // we will not set oldUserVersion such that we will try this - // very same exercise again in the next heartbeat - } - TRI_ReleaseVocBase(vocbase); - } - } - if (allOK) { - oldUserVersion = userVersion; - } + oldUserVersion = userVersion; + RestServerFeature::AUTH_INFO.outdate(); } } - versionSlice = result.slice()[0].get(std::vector( - {_agency.prefix(), "Current", "Version"})); + versionSlice = result.slice()[0].get( + std::vector({_agency.prefix(), "Current", "Version"})); if (versionSlice.isInteger()) { - uint64_t currentVersion = 0; try { currentVersion = versionSlice.getUInt(); - } - catch (...) { + } catch (...) { } if (currentVersion > lastCurrentVersionNoticed) { LOG_TOPIC(TRACE, Logger::HEARTBEAT) - << "Found currentVersion " << currentVersion + << "Found currentVersion " << currentVersion << " which is newer than " << lastCurrentVersionNoticed; lastCurrentVersionNoticed = currentVersion; @@ -440,7 +408,6 @@ void HeartbeatThread::runCoordinator() { remain = 0.0; } } - } LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "stopped heartbeat thread"; @@ -470,8 +437,8 @@ void HeartbeatThread::removeDispatchedJob(DBServerAgencySyncResult result) { { MUTEX_LOCKER(mutexLocker, _statusLock); if (result.success) { - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) - << "Sync request successful. Now have Plan " << result.planVersion + LOG_TOPIC(DEBUG, Logger::HEARTBEAT) + << "Sync request successful. Now have Plan " << result.planVersion << ", Current " << result.currentVersion; _currentVersions = AgencyVersions(result); } else { @@ -496,44 +463,40 @@ void HeartbeatThread::removeDispatchedJob(DBServerAgencySyncResult result) { static std::string const prefixPlanChangeCoordinator = "Plan/Databases"; bool HeartbeatThread::handlePlanChangeCoordinator(uint64_t currentPlanVersion) { - bool fetchingUsersFailed = false; LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "found a plan update"; AgencyCommResult result; - + { AgencyCommLocker locker("Plan", "READ"); if (locker.successful()) { result = _agency.getValues(prefixPlanChangeCoordinator); } } - + if (result.successful()) { - std::vector ids; - velocypack::Slice databases = - result.slice()[0].get(std::vector( - {AgencyComm::prefix(), "Plan", "Databases"})); - + velocypack::Slice databases = result.slice()[0].get( + std::vector({AgencyComm::prefix(), "Plan", "Databases"})); + if (!databases.isObject()) { return false; } // loop over all database names we got and create a local database // instance if not yet present: - - for (auto const& options : VPackObjectIterator (databases)) { + for (auto const& options : VPackObjectIterator(databases)) { if (!options.value.isObject()) { continue; } auto nameSlice = options.value.get("name"); if (nameSlice.isNone()) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Missing name in agency database plan"; - continue; + continue; } - + std::string const name = options.value.get("name").copyString(); TRI_voc_tick_t id = 0; @@ -543,84 +506,66 @@ bool HeartbeatThread::handlePlanChangeCoordinator(uint64_t currentPlanVersion) { try { id = std::stoul(v.copyString()); } catch (std::exception const& e) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Failed to convert id string to number"; LOG_TOPIC(ERR, Logger::HEARTBEAT) << e.what(); } } } - + if (id > 0) { ids.push_back(id); } - + TRI_vocbase_t* vocbase = - TRI_UseCoordinatorDatabaseServer(_server, name.c_str()); - + TRI_UseCoordinatorDatabaseServer(_server, name.c_str()); + if (vocbase == nullptr) { // database does not yet exist, create it now - + if (id == 0) { // verify that we have an id id = ClusterInfo::instance()->uniqid(); } - + TRI_vocbase_defaults_t defaults; TRI_GetDatabaseDefaultsServer(_server, &defaults); - + // create a local database object... TRI_CreateCoordinatorDatabaseServer(_server, id, name.c_str(), &defaults, &vocbase); - - if (vocbase != nullptr && TRI_EqualString(vocbase->_name, TRI_VOC_SYSTEM_DATABASE)) { - HasRunOnce = 1; - - // insert initial user(s) for database - if (!fetchUsers()) { - TRI_ReleaseVocBase(vocbase); - return false; // We give up, we will try again in the - // next heartbeat - } - } - } else if (TRI_EqualString(vocbase->_name, TRI_VOC_SYSTEM_DATABASE)) { - if (_refetchUsers) { - // must re-fetch users for an existing database - if (!fetchUsers()) { - fetchingUsersFailed = true; - } - } - + } else { TRI_ReleaseVocBase(vocbase); } } - + // get the list of databases that we know about locally std::vector localIds = - TRI_GetIdsCoordinatorDatabaseServer(_server); - + TRI_GetIdsCoordinatorDatabaseServer(_server); + for (auto id : localIds) { auto r = std::find(ids.begin(), ids.end(), id); - + if (r == ids.end()) { // local database not found in the plan... TRI_DropByIdCoordinatorDatabaseServer(_server, id, false); } } - + } else { return false; } - + if (fetchingUsersFailed) { return false; } - + // invalidate our local cache ClusterInfo::instance()->flush(); - + // turn on error logging now if (!ClusterComm::instance()->enableConnectionErrorLogging(true)) { - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) + LOG_TOPIC(DEBUG, Logger::HEARTBEAT) << "created coordinator databases for the first time"; } @@ -645,15 +590,15 @@ bool HeartbeatThread::syncDBServerStatusQuo() { } if (_desiredVersions.plan > _currentVersions.plan) { - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) - << "Plan version " << _currentVersions.plan + LOG_TOPIC(DEBUG, Logger::HEARTBEAT) + << "Plan version " << _currentVersions.plan << " is lower than desired version " << _desiredVersions.plan; _isDispatchingChange = true; becauseOfPlan = true; } if (_desiredVersions.current > _currentVersions.current) { - LOG_TOPIC(DEBUG, Logger::HEARTBEAT) - << "Current version " << _currentVersions.current + LOG_TOPIC(DEBUG, Logger::HEARTBEAT) + << "Current version " << _currentVersions.current << " is lower than desired version " << _desiredVersions.current; _isDispatchingChange = true; becauseOfCurrent = true; @@ -677,7 +622,7 @@ bool HeartbeatThread::syncDBServerStatusQuo() { auto dispatcher = DispatcherFeature::DISPATCHER; if (dispatcher == nullptr) { - LOG_TOPIC(ERR, Logger::HEARTBEAT) + LOG_TOPIC(ERR, Logger::HEARTBEAT) << "could not schedule dbserver sync - dispatcher gone."; return false; } @@ -701,9 +646,8 @@ bool HeartbeatThread::syncDBServerStatusQuo() { //////////////////////////////////////////////////////////////////////////////// bool HeartbeatThread::handleStateChange(AgencyCommResult& result) { - VPackSlice const slice = result.slice()[0].get( - std::vector({ AgencyComm::prefix(), "Sync", - "Commands", _myId })); + VPackSlice const slice = result.slice()[0].get(std::vector( + {AgencyComm::prefix(), "Sync", "Commands", _myId})); if (slice.isString()) { std::string command = slice.copyString(); ServerState::StateEnum newState = ServerState::stringToState(command); @@ -724,7 +668,7 @@ bool HeartbeatThread::handleStateChange(AgencyCommResult& result) { bool HeartbeatThread::sendState() { const AgencyCommResult result = _agency.sendServerState(0.0); -// 8.0 * static_cast(_interval) / 1000.0 / 1000.0); + // 8.0 * static_cast(_interval) / 1000.0 / 1000.0); if (result.successful()) { _numFails = 0; @@ -734,38 +678,11 @@ bool HeartbeatThread::sendState() { if (++_numFails % _maxFailsBeforeWarning == 0) { std::string const endpoints = AgencyComm::getEndpointsString(); - LOG_TOPIC(WARN, Logger::HEARTBEAT) - << "heartbeat could not be sent to agency endpoints (" - << endpoints << "): http code: " << result.httpCode() - << ", body: " << result.body(); + LOG_TOPIC(WARN, Logger::HEARTBEAT) + << "heartbeat could not be sent to agency endpoints (" << endpoints + << "): http code: " << result.httpCode() << ", body: " << result.body(); _numFails = 0; } return false; } - -//////////////////////////////////////////////////////////////////////////////// -/// @brief fetch users for a database (run on coordinator only) -//////////////////////////////////////////////////////////////////////////////// - -bool HeartbeatThread::fetchUsers() { - VPackBuilder builder; - builder.openArray(); - - LOG_TOPIC(TRACE, Logger::HEARTBEAT) - << "fetching users for database"; - - bool result = RestServerFeature::AUTH_INFO.reload(); - - if (result) { - LOG_TOPIC(TRACE, Logger::HEARTBEAT) - << "fetching users successful"; - _refetchUsers = false; - } else { - LOG_TOPIC(TRACE, Logger::HEARTBEAT) - << "fetching users failed"; - _refetchUsers = true; - } - - return result; -} diff --git a/arangod/Cluster/HeartbeatThread.h b/arangod/Cluster/HeartbeatThread.h index e82c01b876..5417001f58 100644 --- a/arangod/Cluster/HeartbeatThread.h +++ b/arangod/Cluster/HeartbeatThread.h @@ -136,12 +136,6 @@ class HeartbeatThread : public Thread { bool sendState(); - ////////////////////////////////////////////////////////////////////////////// - /// @brief fetch users for a database (run on coordinator only) - ////////////////////////////////////////////////////////////////////////////// - - bool fetchUsers(); - ////////////////////////////////////////////////////////////////////////////// /// @brief bring the db server in sync with the desired state ////////////////////////////////////////////////////////////////////////////// @@ -179,12 +173,6 @@ class HeartbeatThread : public Thread { arangodb::basics::ConditionVariable _condition; - ////////////////////////////////////////////////////////////////////////////// - /// @brief users will be re-fetched the next time the heartbeat thread runs - ////////////////////////////////////////////////////////////////////////////// - - bool _refetchUsers; - ////////////////////////////////////////////////////////////////////////////// /// @brief this server's id ////////////////////////////////////////////////////////////////////////////// diff --git a/arangod/RestServer/RestServerFeature.cpp b/arangod/RestServer/RestServerFeature.cpp index 1e0aa9a0e0..9333082943 100644 --- a/arangod/RestServer/RestServerFeature.cpp +++ b/arangod/RestServer/RestServerFeature.cpp @@ -298,7 +298,7 @@ void RestServerFeature::start() { // populate the authentication cache. otherwise no one can access the new // database - RestServerFeature::AUTH_INFO.reload(); + RestServerFeature::AUTH_INFO.outdate(); } void RestServerFeature::stop() { diff --git a/arangod/V8Server/v8-vocbase.cpp b/arangod/V8Server/v8-vocbase.cpp index 31d50b39a1..cf4808b4c7 100644 --- a/arangod/V8Server/v8-vocbase.cpp +++ b/arangod/V8Server/v8-vocbase.cpp @@ -945,13 +945,9 @@ static void JS_ReloadAuth(v8::FunctionCallbackInfo const& args) { TRI_V8_THROW_EXCEPTION_USAGE("RELOAD_AUTH()"); } - bool result = RestServerFeature::AUTH_INFO.reload(); + RestServerFeature::AUTH_INFO.outdate(); - if (result) { - TRI_V8_RETURN_TRUE(); - } - - TRI_V8_RETURN_FALSE(); + TRI_V8_RETURN_TRUE(); TRI_V8_TRY_CATCH_END } diff --git a/arangod/VocBase/AuthInfo.cpp b/arangod/VocBase/AuthInfo.cpp index ca66d00b07..77167cbae5 100644 --- a/arangod/VocBase/AuthInfo.cpp +++ b/arangod/VocBase/AuthInfo.cpp @@ -103,7 +103,7 @@ static AuthEntry CreateAuthEntry(VPackSlice const& slice) { VPackSlice const databasesSlice = slice.get("databases"); std::unordered_map databases; AuthLevel allDatabases = AuthLevel::NONE; - + if (databasesSlice.isObject()) { for (auto const& obj : VPackObjectIterator(databasesSlice)) { std::string const key = obj.key.copyString(); @@ -112,26 +112,25 @@ static AuthEntry CreateAuthEntry(VPackSlice const& slice) { char const* value = obj.value.getString(length); if (TRI_CaseEqualString(value, "rw", 2)) { - if (key == "*") { - allDatabases = AuthLevel::RW; - } else { - databases.emplace(key, AuthLevel::RW); - } - } - else if (TRI_CaseEqualString(value, "ro", 2)) { - if (key == "*") { - allDatabases = AuthLevel::RO; - } else { - databases.emplace(key, AuthLevel::RO); - } + if (key == "*") { + allDatabases = AuthLevel::RW; + } else { + databases.emplace(key, AuthLevel::RW); + } + } else if (TRI_CaseEqualString(value, "ro", 2)) { + if (key == "*") { + allDatabases = AuthLevel::RO; + } else { + databases.emplace(key, AuthLevel::RO); + } } } } // build authentication entry return AuthEntry(userSlice.copyString(), methodSlice.copyString(), - saltSlice.copyString(), hashSlice.copyString(), - databases, allDatabases, active, mustChange); + saltSlice.copyString(), hashSlice.copyString(), databases, + allDatabases, active, mustChange); } AuthLevel AuthEntry::canUseDatabase(std::string const& dbname) const { @@ -217,15 +216,15 @@ bool AuthInfo::populate(VPackSlice const& slice) { return true; } -bool AuthInfo::reload() { +void AuthInfo::reload() { insertInitial(); - + TRI_vocbase_t* vocbase = DatabaseFeature::DATABASE->vocbase(); if (vocbase == nullptr) { LOG(DEBUG) << "system database is unknown, cannot load authentication " - << "and authorization information"; - return false; + << "and authorization information"; + return; } LOG(DEBUG) << "starting to load authentication and authorization information"; @@ -236,24 +235,25 @@ bool AuthInfo::reload() { int res = trx.begin(); if (res != TRI_ERROR_NO_ERROR) { - return false; + LOG(ERR) << "cannot start transaction to load authentication"; + return; } OperationResult users = - trx.all(TRI_COL_NAME_USERS, 0, UINT64_MAX, OperationOptions()); + trx.all(TRI_COL_NAME_USERS, 0, UINT64_MAX, OperationOptions()); trx.finish(users.code); if (users.failed()) { LOG(ERR) << "cannot read users from _users collection"; - return false; + return; } auto usersSlice = users.slice(); if (!usersSlice.isArray()) { LOG(ERR) << "cannot read users from _users collection"; - return false; + return; } if (usersSlice.length() == 0) { @@ -262,11 +262,15 @@ bool AuthInfo::reload() { populate(usersSlice); } - return true; + _outdated = false; } AuthResult AuthInfo::checkPassword(std::string const& username, - std::string const& password) { + std::string const& password) { + if (_outdated) { + reload(); + } + AuthResult result; // look up username @@ -299,22 +303,22 @@ AuthResult AuthInfo::checkPassword(std::string const& username, try { if (passwordMethod == "sha1") { arangodb::rest::SslInterface::sslSHA1(salted.c_str(), len, crypted, - cryptedLength); + cryptedLength); } else if (passwordMethod == "sha512") { arangodb::rest::SslInterface::sslSHA512(salted.c_str(), len, crypted, - cryptedLength); + cryptedLength); } else if (passwordMethod == "sha384") { arangodb::rest::SslInterface::sslSHA384(salted.c_str(), len, crypted, - cryptedLength); + cryptedLength); } else if (passwordMethod == "sha256") { arangodb::rest::SslInterface::sslSHA256(salted.c_str(), len, crypted, - cryptedLength); + cryptedLength); } else if (passwordMethod == "sha224") { arangodb::rest::SslInterface::sslSHA224(salted.c_str(), len, crypted, - cryptedLength); + cryptedLength); } else if (passwordMethod == "md5") { arangodb::rest::SslInterface::sslMD5(salted.c_str(), len, crypted, - cryptedLength); + cryptedLength); } else { // invalid algorithm... } @@ -329,8 +333,8 @@ AuthResult AuthInfo::checkPassword(std::string const& username, char* hex = TRI_EncodeHexString(crypted, cryptedLength, &hexLen); if (hex != nullptr) { - result._authorized = auth.checkPasswordHash(hex); - TRI_FreeString(TRI_CORE_MEM_ZONE, hex); + result._authorized = auth.checkPasswordHash(hex); + TRI_FreeString(TRI_CORE_MEM_ZONE, hex); } } @@ -340,7 +344,12 @@ AuthResult AuthInfo::checkPassword(std::string const& username, return result; } -AuthLevel AuthInfo::canUseDatabase(std::string const& username, std::string const& dbname) { +AuthLevel AuthInfo::canUseDatabase(std::string const& username, + std::string const& dbname) { + if (_outdated) { + reload(); + } + auto const& it = _authInfo.find(username); if (it == _authInfo.end()) { @@ -352,13 +361,18 @@ AuthLevel AuthInfo::canUseDatabase(std::string const& username, std::string cons return entry.canUseDatabase(dbname); } -AuthResult AuthInfo::checkAuthentication(AuthType authType, std::string const& secret) { - switch (authType) { - case AuthType::BASIC: - return checkAuthenticationBasic(secret); +AuthResult AuthInfo::checkAuthentication(AuthType authType, + std::string const& secret) { + if (_outdated) { + reload(); + } - case AuthType::JWT: - return checkAuthenticationJWT(secret); + switch (authType) { + case AuthType::BASIC: + return checkAuthenticationBasic(secret); + + case AuthType::JWT: + return checkAuthenticationJWT(secret); } return AuthResult(); @@ -376,7 +390,7 @@ AuthResult AuthInfo::checkAuthenticationBasic(std::string const& secret) { if (n == std::string::npos || n == 0 || n + 1 > up.size()) { LOG(TRACE) << "invalid authentication data found, cannot extract " - "username/password"; + "username/password"; return AuthResult(); } @@ -394,7 +408,7 @@ AuthResult AuthInfo::checkAuthenticationBasic(std::string const& secret) { AuthResult AuthInfo::checkAuthenticationJWT(std::string const& secret) { std::vector const parts = StringUtils::split(secret, '.'); - + if (parts.size() != 3) { LOG(DEBUG) << "Secret contains " << parts.size() << " parts"; return AuthResult(); @@ -403,33 +417,34 @@ AuthResult AuthInfo::checkAuthenticationJWT(std::string const& secret) { std::string const& header = parts[0]; std::string const& body = parts[1]; std::string const& signature = parts[2]; - + std::string const message = header + "." + body; if (!validateJwtHeader(header)) { LOG(DEBUG) << "Couldn't validate jwt header " << header; return AuthResult(); } - + std::string username; if (!validateJwtBody(body, &username)) { LOG(DEBUG) << "Couldn't validate jwt body " << body; return AuthResult(); } - + if (!validateJwtHMAC256Signature(message, signature)) { LOG(DEBUG) << "Couldn't validate jwt signature " << signature; return AuthResult(); } - + AuthResult result; result._username = username; result._authorized = true; - + return result; } -std::shared_ptr AuthInfo::parseJson(std::string const& str, std::string const& hint) { +std::shared_ptr AuthInfo::parseJson(std::string const& str, + std::string const& hint) { std::shared_ptr result; VPackParser parser; try { @@ -442,12 +457,13 @@ std::shared_ptr AuthInfo::parseJson(std::string const& str, std::s } catch (...) { LOG(ERR) << "Got unknown exception trying to parse " << hint; } - + return result; } bool AuthInfo::validateJwtHeader(std::string const& header) { - std::shared_ptr headerBuilder = parseJson(StringUtils::decodeBase64(header), "jwt header"); + std::shared_ptr headerBuilder = + parseJson(StringUtils::decodeBase64(header), "jwt header"); if (headerBuilder.get() == nullptr) { return false; } @@ -463,15 +479,15 @@ bool AuthInfo::validateJwtHeader(std::string const& header) { if (!algSlice.isString()) { return false; } - + if (!typSlice.isString()) { return false; } - + if (algSlice.copyString() != "HS256") { return false; } - + std::string typ = typSlice.copyString(); if (typ != "JWT") { return false; @@ -481,7 +497,8 @@ bool AuthInfo::validateJwtHeader(std::string const& header) { } bool AuthInfo::validateJwtBody(std::string const& body, std::string* username) { - std::shared_ptr bodyBuilder = parseJson(StringUtils::decodeBase64(body), "jwt body"); + std::shared_ptr bodyBuilder = + parseJson(StringUtils::decodeBase64(body), "jwt body"); if (bodyBuilder.get() == nullptr) { return false; } @@ -495,28 +512,30 @@ bool AuthInfo::validateJwtBody(std::string const& body, std::string* username) { if (!issSlice.isString()) { return false; } - + if (issSlice.copyString() != "arangodb") { return false; } - + VPackSlice const usernameSlice = bodySlice.get("preferred_username"); if (!usernameSlice.isString()) { return false; } - + *username = usernameSlice.copyString(); - + // mop: optional exp (cluster currently uses non expiring jwts) if (bodySlice.hasKey("exp")) { VPackSlice const expSlice = bodySlice.get("exp"); - + if (!expSlice.isNumber()) { return false; } - - std::chrono::system_clock::time_point expires(std::chrono::seconds(expSlice.getNumber())); - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + + std::chrono::system_clock::time_point expires( + std::chrono::seconds(expSlice.getNumber())); + std::chrono::system_clock::time_point now = + std::chrono::system_clock::now(); if (now >= expires) { return false; @@ -525,9 +544,13 @@ bool AuthInfo::validateJwtBody(std::string const& body, std::string* username) { return true; } -bool AuthInfo::validateJwtHMAC256Signature(std::string const& message, std::string const& signature) { +bool AuthInfo::validateJwtHMAC256Signature(std::string const& message, + std::string const& signature) { std::string decodedSignature = StringUtils::decodeBase64U(signature); - + std::string const& jwtSecret = RestServerFeature::getJwtSecret(); - return verifyHMAC(jwtSecret.c_str(), jwtSecret.length(), message.c_str(), message.length(), decodedSignature.c_str(), decodedSignature.length(), SslInterface::Algorithm::ALGORITHM_SHA256); + return verifyHMAC(jwtSecret.c_str(), jwtSecret.length(), message.c_str(), + message.length(), decodedSignature.c_str(), + decodedSignature.length(), + SslInterface::Algorithm::ALGORITHM_SHA256); } diff --git a/arangod/VocBase/AuthInfo.h b/arangod/VocBase/AuthInfo.h index f1242bb45b..dd057f85f3 100644 --- a/arangod/VocBase/AuthInfo.h +++ b/arangod/VocBase/AuthInfo.h @@ -47,7 +47,7 @@ class AuthEntry { AuthEntry(std::string const& username, std::string const& passwordMethod, std::string const& passwordSalt, std::string const& passwordHash, std::unordered_map databases, AuthLevel allDatabases, - bool active, bool mustChange) + bool active, bool mustChange) : _username(username), _passwordMethod(passwordMethod), _passwordSalt(passwordSalt), @@ -96,18 +96,22 @@ class AuthInfo { }; public: - bool reload(); + AuthInfo() : _outdated(true) {} + + public: + void outdate() { _outdated = true; } AuthResult checkPassword(std::string const& username, - std::string const& password); + std::string const& password); AuthResult checkAuthentication(AuthType authType, - std::string const& secret); + std::string const& secret); AuthLevel canUseDatabase(std::string const& username, - std::string const& dbname); + std::string const& dbname); private: + void reload(); void clear(); void insertInitial(); bool populate(velocypack::Slice const& slice); @@ -121,6 +125,7 @@ class AuthInfo { private: basics::ReadWriteLock _authInfoLock; + std::atomic _outdated; std::unordered_map _authInfo; std::unordered_map _authBasicCache;