1
0
Fork 0

changed reload to outdated

This commit is contained in:
Frank Celler 2016-06-03 10:38:47 +02:00
parent 5c0dd05308
commit ee98b59e2f
6 changed files with 208 additions and 279 deletions

View File

@ -64,7 +64,6 @@ HeartbeatThread::HeartbeatThread(TRI_server_t* server,
_statusLock(),
_agency(),
_condition(),
_refetchUsers(true),
_myId(ServerState::instance()->getId()),
_interval(interval),
_maxFailsBeforeWarning(maxFailsBeforeWarning),
@ -109,13 +108,14 @@ void HeartbeatThread::run() {
////////////////////////////////////////////////////////////////////////////////
void HeartbeatThread::runDBServer() {
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "starting heartbeat thread (DBServer version)";
// mop: the heartbeat thread itself is now ready
setReady();
// mop: however we need to wait for the rest server here to come up
// otherwise we would already create collections and the coordinator would think
// otherwise we would already create collections and the coordinator would
// think
// ohhh the dbserver is online...pump some documents into it
// which fails when it is still in maintenance mode
while (arangodb::rest::HttpHandlerFactory::isMaintenance()) {
@ -125,48 +125,48 @@ void HeartbeatThread::runDBServer() {
// convert timeout to seconds
double const interval = (double)_interval / 1000.0 / 1000.0;
std::function<bool(VPackSlice const& result)> updatePlan = [&](
VPackSlice const& result) {
if (!result.isNumber()) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Plan Version is not a number! " << result.toJson();
return false;
}
uint64_t version = result.getNumber<uint64_t>();
bool doSync = false;
{
MUTEX_LOCKER(mutexLocker, _statusLock);
if (version > _desiredVersions.plan) {
_desiredVersions.plan = version;
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Desired Current Version is now " << _desiredVersions.plan;
doSync = true;
}
}
std::function<bool(VPackSlice const& result)> updatePlan =
[&](VPackSlice const& result) {
if (!result.isNumber()) {
LOG_TOPIC(ERR, Logger::HEARTBEAT) << "Plan Version is not a number! "
<< result.toJson();
return false;
}
uint64_t version = result.getNumber<uint64_t>();
if (doSync) {
syncDBServerStatusQuo();
}
bool doSync = false;
{
MUTEX_LOCKER(mutexLocker, _statusLock);
if (version > _desiredVersions.plan) {
_desiredVersions.plan = version;
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Desired Current Version is now " << _desiredVersions.plan;
doSync = true;
}
}
if (doSync) {
syncDBServerStatusQuo();
}
return true;
};
return true;
};
auto planAgencyCallback = std::make_shared<AgencyCallback>(
_agency, "Plan/Version", updatePlan, true);
bool registered = false;
while (!registered) {
registered = _agencyCallbackRegistry->registerCallback(planAgencyCallback);
if (!registered) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Couldn't register plan change in agency!";
sleep(1);
}
}
// we check Current/Version every few heartbeats:
int const currentCountStart = 1; // set to 1 by Max to speed up discovery
int const currentCountStart = 1; // set to 1 by Max to speed up discovery
int currentCount = currentCountStart;
while (!isStopping()) {
@ -186,16 +186,15 @@ void HeartbeatThread::runDBServer() {
currentCount = currentCountStart;
// send an initial GET request to Sync/Commands/my-id
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "Looking at Sync/Commands/" + _myId;
AgencyCommResult result =
_agency.getValues("Sync/Commands/" + _myId);
AgencyCommResult result = _agency.getValues("Sync/Commands/" + _myId);
if (result.successful()) {
handleStateChange(result);
}
if (isStopping()) {
break;
}
@ -203,15 +202,14 @@ void HeartbeatThread::runDBServer() {
LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "Refetching Current/Version...";
AgencyCommResult res = _agency.getValues("Current/Version");
if (!res.successful()) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Could not read Current/Version from agency.";
} else {
VPackSlice s
= res.slice()[0].get(std::vector<std::string>(
{_agency.prefix(), std::string("Current"),
std::string("Version")}));
VPackSlice s = res.slice()[0].get(
std::vector<std::string>({_agency.prefix(), std::string("Current"),
std::string("Version")}));
if (!s.isInteger()) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Current/Version in agency is not an integer.";
} else {
uint64_t currentVersion = 0;
@ -220,14 +218,14 @@ void HeartbeatThread::runDBServer() {
} catch (...) {
}
if (currentVersion == 0) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Current/Version in agency is 0.";
} else {
{
MUTEX_LOCKER(mutexLocker, _statusLock);
if (currentVersion > _desiredVersions.current) {
_desiredVersions.current = currentVersion;
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Found greater Current/Version in agency.";
}
}
@ -245,7 +243,7 @@ void HeartbeatThread::runDBServer() {
// mop: execute at least once
do {
LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "Entering update loop";
bool wasNotified;
{
CONDITION_LOCKER(locker, _condition);
@ -284,7 +282,7 @@ void HeartbeatThread::runDBServer() {
}
usleep(1000);
}
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "stopped heartbeat thread (DBServer version)";
}
@ -293,7 +291,7 @@ void HeartbeatThread::runDBServer() {
////////////////////////////////////////////////////////////////////////////////
void HeartbeatThread::runCoordinator() {
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "starting heartbeat thread (coordinator version)";
uint64_t oldUserVersion = 0;
@ -320,25 +318,24 @@ void HeartbeatThread::runCoordinator() {
break;
}
AgencyReadTransaction trx(std::vector<std::string>({
_agency.prefixPath() + "Plan/Version",
_agency.prefixPath() + "Current/Version",
_agency.prefixPath() + "Sync/Commands/" + _myId,
_agency.prefixPath() + "Sync/UserVersion"}));
AgencyReadTransaction trx(std::vector<std::string>(
{_agency.prefixPath() + "Plan/Version",
_agency.prefixPath() + "Current/Version",
_agency.prefixPath() + "Sync/Commands/" + _myId,
_agency.prefixPath() + "Sync/UserVersion"}));
AgencyCommResult result = _agency.sendTransactionWithFailover(trx);
if (!result.successful()) {
LOG_TOPIC(WARN, Logger::HEARTBEAT)
LOG_TOPIC(WARN, Logger::HEARTBEAT)
<< "Heartbeat: Could not read from agency!";
} else {
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "Looking at Sync/Commands/" + _myId;
handleStateChange(result);
VPackSlice versionSlice
= result.slice()[0].get(std::vector<std::string>(
{_agency.prefix(), "Plan", "Version"}));
VPackSlice versionSlice = result.slice()[0].get(
std::vector<std::string>({_agency.prefix(), "Plan", "Version"}));
if (versionSlice.isInteger()) {
// there is a plan version
@ -346,8 +343,7 @@ void HeartbeatThread::runCoordinator() {
uint64_t planVersion = 0;
try {
planVersion = versionSlice.getUInt();
}
catch (...) {
} catch (...) {
}
if (planVersion > lastPlanVersionNoticed) {
@ -363,62 +359,34 @@ void HeartbeatThread::runCoordinator() {
}
}
VPackSlice slice =
result.slice()[0].get(std::vector<std::string>(
{_agency.prefix(), "Sync", "UserVersion"}));
VPackSlice slice = result.slice()[0].get(
std::vector<std::string>({_agency.prefix(), "Sync", "UserVersion"}));
if (slice.isInteger()) {
// there is a UserVersion
uint64_t userVersion = 0;
try {
userVersion = slice.getUInt();
} catch (...) {
}
catch (...) {
}
if (userVersion > 0 && userVersion != oldUserVersion) {
// reload user cache for all databases
std::vector<DatabaseID> dbs =
ClusterInfo::instance()->databases(true);
std::vector<DatabaseID>::iterator i;
bool allOK = true;
for (i = dbs.begin(); i != dbs.end(); ++i) {
TRI_vocbase_t* vocbase =
TRI_UseCoordinatorDatabaseServer(_server, i->c_str());
if (vocbase != nullptr && TRI_EqualString(vocbase->_name, TRI_VOC_SYSTEM_DATABASE)) {
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Reloading users for database " << vocbase->_name
<< ".";
if (!fetchUsers()) {
// something is wrong... probably the database server
// with the _users collection is not yet available
allOK = false;
// we will not set oldUserVersion such that we will try this
// very same exercise again in the next heartbeat
}
TRI_ReleaseVocBase(vocbase);
}
}
if (allOK) {
oldUserVersion = userVersion;
}
oldUserVersion = userVersion;
RestServerFeature::AUTH_INFO.outdate();
}
}
versionSlice = result.slice()[0].get(std::vector<std::string>(
{_agency.prefix(), "Current", "Version"}));
versionSlice = result.slice()[0].get(
std::vector<std::string>({_agency.prefix(), "Current", "Version"}));
if (versionSlice.isInteger()) {
uint64_t currentVersion = 0;
try {
currentVersion = versionSlice.getUInt();
}
catch (...) {
} catch (...) {
}
if (currentVersion > lastCurrentVersionNoticed) {
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "Found currentVersion " << currentVersion
<< "Found currentVersion " << currentVersion
<< " which is newer than " << lastCurrentVersionNoticed;
lastCurrentVersionNoticed = currentVersion;
@ -440,7 +408,6 @@ void HeartbeatThread::runCoordinator() {
remain = 0.0;
}
}
}
LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "stopped heartbeat thread";
@ -470,8 +437,8 @@ void HeartbeatThread::removeDispatchedJob(DBServerAgencySyncResult result) {
{
MUTEX_LOCKER(mutexLocker, _statusLock);
if (result.success) {
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Sync request successful. Now have Plan " << result.planVersion
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Sync request successful. Now have Plan " << result.planVersion
<< ", Current " << result.currentVersion;
_currentVersions = AgencyVersions(result);
} else {
@ -496,44 +463,40 @@ void HeartbeatThread::removeDispatchedJob(DBServerAgencySyncResult result) {
static std::string const prefixPlanChangeCoordinator = "Plan/Databases";
bool HeartbeatThread::handlePlanChangeCoordinator(uint64_t currentPlanVersion) {
bool fetchingUsersFailed = false;
LOG_TOPIC(TRACE, Logger::HEARTBEAT) << "found a plan update";
AgencyCommResult result;
{
AgencyCommLocker locker("Plan", "READ");
if (locker.successful()) {
result = _agency.getValues(prefixPlanChangeCoordinator);
}
}
if (result.successful()) {
std::vector<TRI_voc_tick_t> ids;
velocypack::Slice databases =
result.slice()[0].get(std::vector<std::string>(
{AgencyComm::prefix(), "Plan", "Databases"}));
velocypack::Slice databases = result.slice()[0].get(
std::vector<std::string>({AgencyComm::prefix(), "Plan", "Databases"}));
if (!databases.isObject()) {
return false;
}
// loop over all database names we got and create a local database
// instance if not yet present:
for (auto const& options : VPackObjectIterator (databases)) {
for (auto const& options : VPackObjectIterator(databases)) {
if (!options.value.isObject()) {
continue;
}
auto nameSlice = options.value.get("name");
if (nameSlice.isNone()) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Missing name in agency database plan";
continue;
continue;
}
std::string const name = options.value.get("name").copyString();
TRI_voc_tick_t id = 0;
@ -543,84 +506,66 @@ bool HeartbeatThread::handlePlanChangeCoordinator(uint64_t currentPlanVersion) {
try {
id = std::stoul(v.copyString());
} catch (std::exception const& e) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "Failed to convert id string to number";
LOG_TOPIC(ERR, Logger::HEARTBEAT) << e.what();
}
}
}
if (id > 0) {
ids.push_back(id);
}
TRI_vocbase_t* vocbase =
TRI_UseCoordinatorDatabaseServer(_server, name.c_str());
TRI_UseCoordinatorDatabaseServer(_server, name.c_str());
if (vocbase == nullptr) {
// database does not yet exist, create it now
if (id == 0) {
// verify that we have an id
id = ClusterInfo::instance()->uniqid();
}
TRI_vocbase_defaults_t defaults;
TRI_GetDatabaseDefaultsServer(_server, &defaults);
// create a local database object...
TRI_CreateCoordinatorDatabaseServer(_server, id, name.c_str(),
&defaults, &vocbase);
if (vocbase != nullptr && TRI_EqualString(vocbase->_name, TRI_VOC_SYSTEM_DATABASE)) {
HasRunOnce = 1;
// insert initial user(s) for database
if (!fetchUsers()) {
TRI_ReleaseVocBase(vocbase);
return false; // We give up, we will try again in the
// next heartbeat
}
}
} else if (TRI_EqualString(vocbase->_name, TRI_VOC_SYSTEM_DATABASE)) {
if (_refetchUsers) {
// must re-fetch users for an existing database
if (!fetchUsers()) {
fetchingUsersFailed = true;
}
}
} else {
TRI_ReleaseVocBase(vocbase);
}
}
// get the list of databases that we know about locally
std::vector<TRI_voc_tick_t> localIds =
TRI_GetIdsCoordinatorDatabaseServer(_server);
TRI_GetIdsCoordinatorDatabaseServer(_server);
for (auto id : localIds) {
auto r = std::find(ids.begin(), ids.end(), id);
if (r == ids.end()) {
// local database not found in the plan...
TRI_DropByIdCoordinatorDatabaseServer(_server, id, false);
}
}
} else {
return false;
}
if (fetchingUsersFailed) {
return false;
}
// invalidate our local cache
ClusterInfo::instance()->flush();
// turn on error logging now
if (!ClusterComm::instance()->enableConnectionErrorLogging(true)) {
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "created coordinator databases for the first time";
}
@ -645,15 +590,15 @@ bool HeartbeatThread::syncDBServerStatusQuo() {
}
if (_desiredVersions.plan > _currentVersions.plan) {
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Plan version " << _currentVersions.plan
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Plan version " << _currentVersions.plan
<< " is lower than desired version " << _desiredVersions.plan;
_isDispatchingChange = true;
becauseOfPlan = true;
}
if (_desiredVersions.current > _currentVersions.current) {
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Current version " << _currentVersions.current
LOG_TOPIC(DEBUG, Logger::HEARTBEAT)
<< "Current version " << _currentVersions.current
<< " is lower than desired version " << _desiredVersions.current;
_isDispatchingChange = true;
becauseOfCurrent = true;
@ -677,7 +622,7 @@ bool HeartbeatThread::syncDBServerStatusQuo() {
auto dispatcher = DispatcherFeature::DISPATCHER;
if (dispatcher == nullptr) {
LOG_TOPIC(ERR, Logger::HEARTBEAT)
LOG_TOPIC(ERR, Logger::HEARTBEAT)
<< "could not schedule dbserver sync - dispatcher gone.";
return false;
}
@ -701,9 +646,8 @@ bool HeartbeatThread::syncDBServerStatusQuo() {
////////////////////////////////////////////////////////////////////////////////
bool HeartbeatThread::handleStateChange(AgencyCommResult& result) {
VPackSlice const slice = result.slice()[0].get(
std::vector<std::string>({ AgencyComm::prefix(), "Sync",
"Commands", _myId }));
VPackSlice const slice = result.slice()[0].get(std::vector<std::string>(
{AgencyComm::prefix(), "Sync", "Commands", _myId}));
if (slice.isString()) {
std::string command = slice.copyString();
ServerState::StateEnum newState = ServerState::stringToState(command);
@ -724,7 +668,7 @@ bool HeartbeatThread::handleStateChange(AgencyCommResult& result) {
bool HeartbeatThread::sendState() {
const AgencyCommResult result = _agency.sendServerState(0.0);
// 8.0 * static_cast<double>(_interval) / 1000.0 / 1000.0);
// 8.0 * static_cast<double>(_interval) / 1000.0 / 1000.0);
if (result.successful()) {
_numFails = 0;
@ -734,38 +678,11 @@ bool HeartbeatThread::sendState() {
if (++_numFails % _maxFailsBeforeWarning == 0) {
std::string const endpoints = AgencyComm::getEndpointsString();
LOG_TOPIC(WARN, Logger::HEARTBEAT)
<< "heartbeat could not be sent to agency endpoints ("
<< endpoints << "): http code: " << result.httpCode()
<< ", body: " << result.body();
LOG_TOPIC(WARN, Logger::HEARTBEAT)
<< "heartbeat could not be sent to agency endpoints (" << endpoints
<< "): http code: " << result.httpCode() << ", body: " << result.body();
_numFails = 0;
}
return false;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief fetch users for a database (run on coordinator only)
////////////////////////////////////////////////////////////////////////////////
bool HeartbeatThread::fetchUsers() {
VPackBuilder builder;
builder.openArray();
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "fetching users for database";
bool result = RestServerFeature::AUTH_INFO.reload();
if (result) {
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "fetching users successful";
_refetchUsers = false;
} else {
LOG_TOPIC(TRACE, Logger::HEARTBEAT)
<< "fetching users failed";
_refetchUsers = true;
}
return result;
}

View File

@ -136,12 +136,6 @@ class HeartbeatThread : public Thread {
bool sendState();
//////////////////////////////////////////////////////////////////////////////
/// @brief fetch users for a database (run on coordinator only)
//////////////////////////////////////////////////////////////////////////////
bool fetchUsers();
//////////////////////////////////////////////////////////////////////////////
/// @brief bring the db server in sync with the desired state
//////////////////////////////////////////////////////////////////////////////
@ -179,12 +173,6 @@ class HeartbeatThread : public Thread {
arangodb::basics::ConditionVariable _condition;
//////////////////////////////////////////////////////////////////////////////
/// @brief users will be re-fetched the next time the heartbeat thread runs
//////////////////////////////////////////////////////////////////////////////
bool _refetchUsers;
//////////////////////////////////////////////////////////////////////////////
/// @brief this server's id
//////////////////////////////////////////////////////////////////////////////

View File

@ -298,7 +298,7 @@ void RestServerFeature::start() {
// populate the authentication cache. otherwise no one can access the new
// database
RestServerFeature::AUTH_INFO.reload();
RestServerFeature::AUTH_INFO.outdate();
}
void RestServerFeature::stop() {

View File

@ -945,13 +945,9 @@ static void JS_ReloadAuth(v8::FunctionCallbackInfo<v8::Value> const& args) {
TRI_V8_THROW_EXCEPTION_USAGE("RELOAD_AUTH()");
}
bool result = RestServerFeature::AUTH_INFO.reload();
RestServerFeature::AUTH_INFO.outdate();
if (result) {
TRI_V8_RETURN_TRUE();
}
TRI_V8_RETURN_FALSE();
TRI_V8_RETURN_TRUE();
TRI_V8_TRY_CATCH_END
}

View File

@ -103,7 +103,7 @@ static AuthEntry CreateAuthEntry(VPackSlice const& slice) {
VPackSlice const databasesSlice = slice.get("databases");
std::unordered_map<std::string, AuthLevel> databases;
AuthLevel allDatabases = AuthLevel::NONE;
if (databasesSlice.isObject()) {
for (auto const& obj : VPackObjectIterator(databasesSlice)) {
std::string const key = obj.key.copyString();
@ -112,26 +112,25 @@ static AuthEntry CreateAuthEntry(VPackSlice const& slice) {
char const* value = obj.value.getString(length);
if (TRI_CaseEqualString(value, "rw", 2)) {
if (key == "*") {
allDatabases = AuthLevel::RW;
} else {
databases.emplace(key, AuthLevel::RW);
}
}
else if (TRI_CaseEqualString(value, "ro", 2)) {
if (key == "*") {
allDatabases = AuthLevel::RO;
} else {
databases.emplace(key, AuthLevel::RO);
}
if (key == "*") {
allDatabases = AuthLevel::RW;
} else {
databases.emplace(key, AuthLevel::RW);
}
} else if (TRI_CaseEqualString(value, "ro", 2)) {
if (key == "*") {
allDatabases = AuthLevel::RO;
} else {
databases.emplace(key, AuthLevel::RO);
}
}
}
}
// build authentication entry
return AuthEntry(userSlice.copyString(), methodSlice.copyString(),
saltSlice.copyString(), hashSlice.copyString(),
databases, allDatabases, active, mustChange);
saltSlice.copyString(), hashSlice.copyString(), databases,
allDatabases, active, mustChange);
}
AuthLevel AuthEntry::canUseDatabase(std::string const& dbname) const {
@ -217,15 +216,15 @@ bool AuthInfo::populate(VPackSlice const& slice) {
return true;
}
bool AuthInfo::reload() {
void AuthInfo::reload() {
insertInitial();
TRI_vocbase_t* vocbase = DatabaseFeature::DATABASE->vocbase();
if (vocbase == nullptr) {
LOG(DEBUG) << "system database is unknown, cannot load authentication "
<< "and authorization information";
return false;
<< "and authorization information";
return;
}
LOG(DEBUG) << "starting to load authentication and authorization information";
@ -236,24 +235,25 @@ bool AuthInfo::reload() {
int res = trx.begin();
if (res != TRI_ERROR_NO_ERROR) {
return false;
LOG(ERR) << "cannot start transaction to load authentication";
return;
}
OperationResult users =
trx.all(TRI_COL_NAME_USERS, 0, UINT64_MAX, OperationOptions());
trx.all(TRI_COL_NAME_USERS, 0, UINT64_MAX, OperationOptions());
trx.finish(users.code);
if (users.failed()) {
LOG(ERR) << "cannot read users from _users collection";
return false;
return;
}
auto usersSlice = users.slice();
if (!usersSlice.isArray()) {
LOG(ERR) << "cannot read users from _users collection";
return false;
return;
}
if (usersSlice.length() == 0) {
@ -262,11 +262,15 @@ bool AuthInfo::reload() {
populate(usersSlice);
}
return true;
_outdated = false;
}
AuthResult AuthInfo::checkPassword(std::string const& username,
std::string const& password) {
std::string const& password) {
if (_outdated) {
reload();
}
AuthResult result;
// look up username
@ -299,22 +303,22 @@ AuthResult AuthInfo::checkPassword(std::string const& username,
try {
if (passwordMethod == "sha1") {
arangodb::rest::SslInterface::sslSHA1(salted.c_str(), len, crypted,
cryptedLength);
cryptedLength);
} else if (passwordMethod == "sha512") {
arangodb::rest::SslInterface::sslSHA512(salted.c_str(), len, crypted,
cryptedLength);
cryptedLength);
} else if (passwordMethod == "sha384") {
arangodb::rest::SslInterface::sslSHA384(salted.c_str(), len, crypted,
cryptedLength);
cryptedLength);
} else if (passwordMethod == "sha256") {
arangodb::rest::SslInterface::sslSHA256(salted.c_str(), len, crypted,
cryptedLength);
cryptedLength);
} else if (passwordMethod == "sha224") {
arangodb::rest::SslInterface::sslSHA224(salted.c_str(), len, crypted,
cryptedLength);
cryptedLength);
} else if (passwordMethod == "md5") {
arangodb::rest::SslInterface::sslMD5(salted.c_str(), len, crypted,
cryptedLength);
cryptedLength);
} else {
// invalid algorithm...
}
@ -329,8 +333,8 @@ AuthResult AuthInfo::checkPassword(std::string const& username,
char* hex = TRI_EncodeHexString(crypted, cryptedLength, &hexLen);
if (hex != nullptr) {
result._authorized = auth.checkPasswordHash(hex);
TRI_FreeString(TRI_CORE_MEM_ZONE, hex);
result._authorized = auth.checkPasswordHash(hex);
TRI_FreeString(TRI_CORE_MEM_ZONE, hex);
}
}
@ -340,7 +344,12 @@ AuthResult AuthInfo::checkPassword(std::string const& username,
return result;
}
AuthLevel AuthInfo::canUseDatabase(std::string const& username, std::string const& dbname) {
AuthLevel AuthInfo::canUseDatabase(std::string const& username,
std::string const& dbname) {
if (_outdated) {
reload();
}
auto const& it = _authInfo.find(username);
if (it == _authInfo.end()) {
@ -352,13 +361,18 @@ AuthLevel AuthInfo::canUseDatabase(std::string const& username, std::string cons
return entry.canUseDatabase(dbname);
}
AuthResult AuthInfo::checkAuthentication(AuthType authType, std::string const& secret) {
switch (authType) {
case AuthType::BASIC:
return checkAuthenticationBasic(secret);
AuthResult AuthInfo::checkAuthentication(AuthType authType,
std::string const& secret) {
if (_outdated) {
reload();
}
case AuthType::JWT:
return checkAuthenticationJWT(secret);
switch (authType) {
case AuthType::BASIC:
return checkAuthenticationBasic(secret);
case AuthType::JWT:
return checkAuthenticationJWT(secret);
}
return AuthResult();
@ -376,7 +390,7 @@ AuthResult AuthInfo::checkAuthenticationBasic(std::string const& secret) {
if (n == std::string::npos || n == 0 || n + 1 > up.size()) {
LOG(TRACE) << "invalid authentication data found, cannot extract "
"username/password";
"username/password";
return AuthResult();
}
@ -394,7 +408,7 @@ AuthResult AuthInfo::checkAuthenticationBasic(std::string const& secret) {
AuthResult AuthInfo::checkAuthenticationJWT(std::string const& secret) {
std::vector<std::string> const parts = StringUtils::split(secret, '.');
if (parts.size() != 3) {
LOG(DEBUG) << "Secret contains " << parts.size() << " parts";
return AuthResult();
@ -403,33 +417,34 @@ AuthResult AuthInfo::checkAuthenticationJWT(std::string const& secret) {
std::string const& header = parts[0];
std::string const& body = parts[1];
std::string const& signature = parts[2];
std::string const message = header + "." + body;
if (!validateJwtHeader(header)) {
LOG(DEBUG) << "Couldn't validate jwt header " << header;
return AuthResult();
}
std::string username;
if (!validateJwtBody(body, &username)) {
LOG(DEBUG) << "Couldn't validate jwt body " << body;
return AuthResult();
}
if (!validateJwtHMAC256Signature(message, signature)) {
LOG(DEBUG) << "Couldn't validate jwt signature " << signature;
return AuthResult();
}
AuthResult result;
result._username = username;
result._authorized = true;
return result;
}
std::shared_ptr<VPackBuilder> AuthInfo::parseJson(std::string const& str, std::string const& hint) {
std::shared_ptr<VPackBuilder> AuthInfo::parseJson(std::string const& str,
std::string const& hint) {
std::shared_ptr<VPackBuilder> result;
VPackParser parser;
try {
@ -442,12 +457,13 @@ std::shared_ptr<VPackBuilder> AuthInfo::parseJson(std::string const& str, std::s
} catch (...) {
LOG(ERR) << "Got unknown exception trying to parse " << hint;
}
return result;
}
bool AuthInfo::validateJwtHeader(std::string const& header) {
std::shared_ptr<VPackBuilder> headerBuilder = parseJson(StringUtils::decodeBase64(header), "jwt header");
std::shared_ptr<VPackBuilder> headerBuilder =
parseJson(StringUtils::decodeBase64(header), "jwt header");
if (headerBuilder.get() == nullptr) {
return false;
}
@ -463,15 +479,15 @@ bool AuthInfo::validateJwtHeader(std::string const& header) {
if (!algSlice.isString()) {
return false;
}
if (!typSlice.isString()) {
return false;
}
if (algSlice.copyString() != "HS256") {
return false;
}
std::string typ = typSlice.copyString();
if (typ != "JWT") {
return false;
@ -481,7 +497,8 @@ bool AuthInfo::validateJwtHeader(std::string const& header) {
}
bool AuthInfo::validateJwtBody(std::string const& body, std::string* username) {
std::shared_ptr<VPackBuilder> bodyBuilder = parseJson(StringUtils::decodeBase64(body), "jwt body");
std::shared_ptr<VPackBuilder> bodyBuilder =
parseJson(StringUtils::decodeBase64(body), "jwt body");
if (bodyBuilder.get() == nullptr) {
return false;
}
@ -495,28 +512,30 @@ bool AuthInfo::validateJwtBody(std::string const& body, std::string* username) {
if (!issSlice.isString()) {
return false;
}
if (issSlice.copyString() != "arangodb") {
return false;
}
VPackSlice const usernameSlice = bodySlice.get("preferred_username");
if (!usernameSlice.isString()) {
return false;
}
*username = usernameSlice.copyString();
// mop: optional exp (cluster currently uses non expiring jwts)
if (bodySlice.hasKey("exp")) {
VPackSlice const expSlice = bodySlice.get("exp");
if (!expSlice.isNumber()) {
return false;
}
std::chrono::system_clock::time_point expires(std::chrono::seconds(expSlice.getNumber<uint64_t>()));
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::chrono::system_clock::time_point expires(
std::chrono::seconds(expSlice.getNumber<uint64_t>()));
std::chrono::system_clock::time_point now =
std::chrono::system_clock::now();
if (now >= expires) {
return false;
@ -525,9 +544,13 @@ bool AuthInfo::validateJwtBody(std::string const& body, std::string* username) {
return true;
}
bool AuthInfo::validateJwtHMAC256Signature(std::string const& message, std::string const& signature) {
bool AuthInfo::validateJwtHMAC256Signature(std::string const& message,
std::string const& signature) {
std::string decodedSignature = StringUtils::decodeBase64U(signature);
std::string const& jwtSecret = RestServerFeature::getJwtSecret();
return verifyHMAC(jwtSecret.c_str(), jwtSecret.length(), message.c_str(), message.length(), decodedSignature.c_str(), decodedSignature.length(), SslInterface::Algorithm::ALGORITHM_SHA256);
return verifyHMAC(jwtSecret.c_str(), jwtSecret.length(), message.c_str(),
message.length(), decodedSignature.c_str(),
decodedSignature.length(),
SslInterface::Algorithm::ALGORITHM_SHA256);
}

View File

@ -47,7 +47,7 @@ class AuthEntry {
AuthEntry(std::string const& username, std::string const& passwordMethod,
std::string const& passwordSalt, std::string const& passwordHash,
std::unordered_map<std::string, AuthLevel> databases, AuthLevel allDatabases,
bool active, bool mustChange)
bool active, bool mustChange)
: _username(username),
_passwordMethod(passwordMethod),
_passwordSalt(passwordSalt),
@ -96,18 +96,22 @@ class AuthInfo {
};
public:
bool reload();
AuthInfo() : _outdated(true) {}
public:
void outdate() { _outdated = true; }
AuthResult checkPassword(std::string const& username,
std::string const& password);
std::string const& password);
AuthResult checkAuthentication(AuthType authType,
std::string const& secret);
std::string const& secret);
AuthLevel canUseDatabase(std::string const& username,
std::string const& dbname);
std::string const& dbname);
private:
void reload();
void clear();
void insertInitial();
bool populate(velocypack::Slice const& slice);
@ -121,6 +125,7 @@ class AuthInfo {
private:
basics::ReadWriteLock _authInfoLock;
std::atomic<bool> _outdated;
std::unordered_map<std::string, arangodb::AuthEntry> _authInfo;
std::unordered_map<std::string, arangodb::AuthResult> _authBasicCache;