1
0
Fork 0
arangodb/arangod/RestServer/VocbaseContext.cpp

505 lines
15 KiB
C++

////////////////////////////////////////////////////////////////////////////////
/// DISCLAIMER
///
/// Copyright 2014-2016 ArangoDB GmbH, Cologne, Germany
/// Copyright 2004-2014 triAGENS GmbH, Cologne, Germany
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
/// http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
///
/// @author Dr. Frank Celler
////////////////////////////////////////////////////////////////////////////////
#include "VocbaseContext.h"
#include <chrono>
#include <velocypack/Builder.h>
#include <velocypack/Exception.h>
#include <velocypack/Parser.h>
#include <velocypack/velocypack-aliases.h>
#include "Basics/MutexLocker.h"
#include "Basics/tri-strings.h"
#include "Cluster/ServerState.h"
#include "Endpoint/ConnectionInfo.h"
#include "Logger/Logger.h"
#include "Ssl/SslInterface.h"
#include "VocBase/auth.h"
#include "VocBase/server.h"
#include "VocBase/vocbase.h"
using namespace arangodb;
using namespace arangodb::basics;
using namespace arangodb::rest;
////////////////////////////////////////////////////////////////////////////////
/// @brief sid lock
///
/// Guards every read and write of SidCache; the functions in this file that
/// touch SidCache acquire it via MUTEX_LOCKER before doing so.
////////////////////////////////////////////////////////////////////////////////
static arangodb::Mutex SidLock;
////////////////////////////////////////////////////////////////////////////////
/// @brief sid cache
///
/// Maps a database name to that database's sessions. Each session maps a sid
/// to a pair of (username, last-access timestamp); the timestamp is stored as
/// TRI_microtime() * 1000.0 (presumably milliseconds). Guarded by SidLock.
////////////////////////////////////////////////////////////////////////////////

#ifdef _WIN32
// MSVC only: silence warning C4503 (decorated type name too long for the
// debug symbols), which the nested map types below trigger
#pragma warning(disable : 4503)
#endif

using DatabaseSessionsType =
    std::unordered_map<std::string, std::pair<std::string, double>>;

static std::unordered_map<std::string, DatabaseSessionsType> SidCache;
////////////////////////////////////////////////////////////////////////////////
/// @brief time-to-live for aardvark server sessions, in seconds
///
/// Defaults to two hours (7200 s). authenticate() scales it by 1000.0 before
/// comparing it with the stored session timestamps.
////////////////////////////////////////////////////////////////////////////////
double VocbaseContext::ServerSessionTtl = 2.0 * 60.0 * 60.0;  // 2 hours
////////////////////////////////////////////////////////////////////////////////
/// @brief defines a sid
////////////////////////////////////////////////////////////////////////////////
void VocbaseContext::createSid(std::string const& database,
std::string const& sid,
std::string const& username) {
MUTEX_LOCKER(mutexLocker, SidLock);
// find entries for database first
auto it = SidCache.find(database);
if (it == SidCache.end()) {
it = SidCache.emplace(database, DatabaseSessionsType()).first;
}
// now insert a database-specific sid
double const now = TRI_microtime() * 1000.0;
(*it).second.emplace(sid, std::make_pair(username, now));
}
////////////////////////////////////////////////////////////////////////////////
/// @brief clears all sid entries for a database
///
/// Drops the whole session map of `database`; does nothing if none exists.
////////////////////////////////////////////////////////////////////////////////
void VocbaseContext::clearSid(std::string const& database) {
  MUTEX_LOCKER(mutexLocker, SidLock);

  auto const pos = SidCache.find(database);
  if (pos != SidCache.end()) {
    SidCache.erase(pos);
  }
}
////////////////////////////////////////////////////////////////////////////////
/// @brief clears a sid
///
/// Removes session `sid` from `database`. Unknown databases or sids are
/// ignored. The database's session map stays in place even if it becomes
/// empty.
////////////////////////////////////////////////////////////////////////////////
void VocbaseContext::clearSid(std::string const& database,
                              std::string const& sid) {
  MUTEX_LOCKER(mutexLocker, SidLock);

  auto const dbEntry = SidCache.find(database);
  if (dbEntry != SidCache.end()) {
    dbEntry->second.erase(sid);
  }
}
////////////////////////////////////////////////////////////////////////////////
/// @brief gets the last access time
////////////////////////////////////////////////////////////////////////////////
double VocbaseContext::accessSid(std::string const& database,
std::string const& sid) {
MUTEX_LOCKER(mutexLocker, SidLock);
auto it = SidCache.find(database);
if (it == SidCache.end()) {
// database not found. no need to go on
return 0.0;
}
auto const& sids = (*it).second;
auto it2 = sids.find(sid);
if (it2 == sids.end()) {
return 0.0;
}
return (*it2).second.second;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief creates a request context bound to a server and a database
///
/// `server` and `vocbase` must both be non-null (asserted). `jwtSecret` is
/// stored in _jwtSecret and later used as the HMAC key for JWT signatures.
////////////////////////////////////////////////////////////////////////////////
VocbaseContext::VocbaseContext(HttpRequest* request, TRI_server_t* server,
                               TRI_vocbase_t* vocbase,
                               std::string const& jwtSecret)
    : RequestContext(request),
      _server(server),
      _vocbase(vocbase),
      _jwtSecret(jwtSecret) {
  TRI_ASSERT(_vocbase != nullptr);
  TRI_ASSERT(_server != nullptr);
}
////////////////////////////////////////////////////////////////////////////////
/// @brief destroys the context, releasing the database handle it was given
////////////////////////////////////////////////////////////////////////////////
VocbaseContext::~VocbaseContext() {
  // NOTE(review): presumably pairs with a use-count the caller took on the
  // vocbase before constructing this context — confirm at the call site
  TRI_ReleaseVocBase(_vocbase);
}
////////////////////////////////////////////////////////////////////////////////
/// @brief whether or not to use special cluster authentication
///
/// True on DB servers for every request. On coordinators, true only for the
/// paths "/_api/shard-comm" and "/_admin/shutdown". False otherwise.
////////////////////////////////////////////////////////////////////////////////
bool VocbaseContext::useClusterAuthentication() const {
  auto state = ServerState::instance();
  auto const role = state->getRole();

  if (state->isDBServer(role)) {
    return true;
  }

  if (!state->isCoordinator(role)) {
    return false;
  }

  std::string const& path = _request->requestPath();
  return path == "/_api/shard-comm" || path == "/_admin/shutdown";
}
////////////////////////////////////////////////////////////////////////////////
/// @brief return authentication realm
///
/// The realm is the name of the attached database, or an empty string if no
/// database is attached.
////////////////////////////////////////////////////////////////////////////////
std::string VocbaseContext::realm() const {
  return (_vocbase != nullptr) ? std::string(_vocbase->_name) : std::string();
}
////////////////////////////////////////////////////////////////////////////////
/// @brief checks the authentication
///
/// Decides whether the current request may proceed, in this order:
///   - OK if the database does not require authentication at all;
///   - OK for unix domain socket connections if those are exempt (only when
///     built with ARANGODB_HAVE_DOMAIN_SOCKETS);
///   - OK for paths not starting with "/_" if only system paths need
///     authentication;
///   - OK for "/", "/_open/..." and "/_admin/aardvark/...";
///   - OK for a known, unexpired session cookie "arango_sid_<database>"
///     (its last-access time is refreshed), UNAUTHORIZED if it has expired;
///   - otherwise the Authorization header decides: "Basic ..." goes to
///     basicAuthentication(), "Bearer ..." to jwtAuthentication().
///
/// @return OK, UNAUTHORIZED (no/malformed Authorization header, expired
///         session), FORBIDDEN (unsupported scheme), or the result of the
///         scheme-specific check
////////////////////////////////////////////////////////////////////////////////
GeneralResponse::ResponseCode VocbaseContext::authenticate() {
  TRI_ASSERT(_vocbase != nullptr);

  if (!_vocbase->_settings.requireAuthentication) {
    // no authentication required at all
    return GeneralResponse::ResponseCode::OK;
  }

#ifdef ARANGODB_HAVE_DOMAIN_SOCKETS
  // check if we need to run authentication for this type of
  // endpoint
  ConnectionInfo const& ci = _request->connectionInfo();

  if (ci.endpointType == Endpoint::DomainType::UNIX &&
      !_vocbase->_settings.requireAuthenticationUnixSockets) {
    // no authentication required for unix socket domain connections
    return GeneralResponse::ResponseCode::OK;
  }
#endif

  std::string const& path = _request->requestPath();

  if (_vocbase->_settings.authenticateSystemOnly) {
    // authentication required, but only for /_api, /_admin etc., i.e. for
    // paths starting with "/_". An empty path still gets authenticated.
    if (!path.empty() &&
        (path[0] != '/' || path.size() < 2 || path[1] != '_')) {
      return GeneralResponse::ResponseCode::OK;
    }
  }

  if (StringUtils::isPrefix(path, "/_open/") ||
      StringUtils::isPrefix(path, "/_admin/aardvark/") || path == "/") {
    return GeneralResponse::ResponseCode::OK;
  }

  // .............................................................................
  // authentication required
  // .............................................................................

  bool found;
  char cn[4096];

  // session cookie name: "arango_sid_" followed by the database name
  cn[0] = '\0';
  strncat(cn, "arango_sid_", 11);
  strncat(cn + 11, _vocbase->_name, sizeof(cn) - 12);

  // extract the sid
  std::string const& sid = _request->cookieValue(cn, found);

  if (found) {
    MUTEX_LOCKER(mutexLocker, SidLock);
    auto it = SidCache.find(_vocbase->_name);

    if (it != SidCache.end()) {
      auto& sids = (*it).second;
      auto it2 = sids.find(sid);

      if (it2 != sids.end()) {
        double const now = TRI_microtime() * 1000.0;
        // fetch last access date of session
        double const lastAccess = (*it2).second.second;

        // check if session has expired
        if (lastAccess + (ServerSessionTtl * 1000.0) < now) {
          // session has expired: drop it, and do not attribute the rejected
          // request to the session's user
          sids.erase(it2);
          return GeneralResponse::ResponseCode::UNAUTHORIZED;
        }

        _request->setUser((*it2).second.first);
        // refresh the session's last access time
        (*it2).second.second = now;
        return GeneralResponse::ResponseCode::OK;
      }
    }
    // no valid session for this cookie. fall-through to regular HTTP
    // authentication
  }

  std::string const& authStr =
      _request->header(StaticStrings::Authorization, found);

  if (!found) {
    return GeneralResponse::ResponseCode::UNAUTHORIZED;
  }

  size_t methodPos = authStr.find_first_of(' ');

  if (methodPos == std::string::npos) {
    return GeneralResponse::ResponseCode::UNAUTHORIZED;
  }

  // skip over the scheme name ("basic " / "bearer ") and any further blanks
  char const* auth = authStr.c_str() + methodPos;

  while (*auth == ' ') {
    ++auth;
  }

  // TRI_CaseEqualString(a, b, n) is true when the first n characters match
  // (presumably case-insensitively, since clients send "Basic"/"Bearer").
  // Both scheme checks must use the same (non-negated) polarity.
  if (TRI_CaseEqualString(authStr.c_str(), "basic ", 6)) {
    return basicAuthentication(auth);
  } else if (TRI_CaseEqualString(authStr.c_str(), "bearer ", 7)) {
    return jwtAuthentication(std::string(auth));
  }

  // mop: hmmm is 403 the correct status code? or 401? or 400? :S
  return GeneralResponse::ResponseCode::FORBIDDEN;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief checks the authentication via basic
///
/// @param auth the credentials part of a "Basic" Authorization header
///             (base64 of "username:password"), leading blanks already skipped
///
/// In cluster-authentication mode, `auth` must equal the cluster-wide
/// authentication string from ServerState with its first 6 characters
/// (presumably the "Basic " scheme prefix) removed. Otherwise the credentials
/// are looked up in the auth-info cache and, on a miss, checked against the
/// stored users.
///
/// @return OK on success; UNAUTHORIZED on wrong credentials; BAD if the
///         decoded credentials have no ':' or an empty username; FORBIDDEN if
///         the user must change the password and the request is not a
///         PUT/PATCH to "/_api/user/..."
////////////////////////////////////////////////////////////////////////////////
GeneralResponse::ResponseCode VocbaseContext::basicAuthentication(const char* auth) {
  if (useClusterAuthentication()) {
    std::string const expected = ServerState::instance()->getAuthentication();

    // guard against an expected value shorter than its 6-character prefix
    // (e.g. not configured): substr(6) would throw std::out_of_range there
    if (expected.size() < 6 ||
        expected.compare(6, std::string::npos, auth) != 0) {
      return GeneralResponse::ResponseCode::UNAUTHORIZED;
    }

    std::string const up = StringUtils::decodeBase64(auth);
    std::string::size_type n = up.find(':', 0);

    if (n == std::string::npos || n == 0 || n + 1 > up.size()) {
      LOG(TRACE) << "invalid authentication data found, cannot extract "
                    "username/password";
      return GeneralResponse::ResponseCode::BAD;
    }

    _request->setUser(up.substr(0, n));
    return GeneralResponse::ResponseCode::OK;
  }

  // look up the info in the cache first. mustChange is initialized so it is
  // never read indeterminate, whatever the helpers below do with it.
  bool mustChange = false;
  std::string username = TRI_CheckCacheAuthInfo(_vocbase, auth, &mustChange);

  if (username.empty()) {
    // no entry found in cache, decode the basic auth info and look it up
    std::string const up = StringUtils::decodeBase64(auth);
    std::string::size_type n = up.find(':', 0);

    if (n == std::string::npos || n == 0 || n + 1 > up.size()) {
      LOG(TRACE) << "invalid authentication data found, cannot extract "
                    "username/password";
      return GeneralResponse::ResponseCode::BAD;
    }

    username = up.substr(0, n);

    LOG(TRACE) << "checking authentication for user '" << username << "'";
    bool res =
        TRI_CheckAuthenticationAuthInfo(_vocbase, auth, username.c_str(),
                                        up.substr(n + 1).c_str(), &mustChange);

    if (!res) {
      return GeneralResponse::ResponseCode::UNAUTHORIZED;
    }
  }

  _request->setUser(std::move(username));

  if (mustChange) {
    // a user who must change the password may only update user data
    if ((_request->requestType() == GeneralRequest::RequestType::PUT ||
         _request->requestType() == GeneralRequest::RequestType::PATCH) &&
        StringUtils::isPrefix(_request->requestPath(), "/_api/user/")) {
      return GeneralResponse::ResponseCode::OK;
    }

    return GeneralResponse::ResponseCode::FORBIDDEN;
  }

  return GeneralResponse::ResponseCode::OK;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief checks the authentication via jwt
///
/// @param auth the token part of a "Bearer" Authorization header, expected in
///             the form "<header>.<body>.<signature>"
///
/// Validates the header, then the body, then the HMAC-SHA256 signature over
/// "<header>.<body>". The first failing check is logged at DEBUG level.
///
/// @return OK if all three checks pass, FORBIDDEN otherwise
////////////////////////////////////////////////////////////////////////////////
GeneralResponse::ResponseCode VocbaseContext::jwtAuthentication(std::string const& auth) {
  std::vector<std::string> const segments = StringUtils::split(auth, '.');

  if (segments.size() != 3) {
    // not a three-segment token
    return GeneralResponse::ResponseCode::FORBIDDEN;
  }

  std::string const& encodedHeader = segments[0];
  if (!validateJwtHeader(encodedHeader)) {
    LOG(DEBUG) << "Couldn't validate jwt header " << encodedHeader;
    return GeneralResponse::ResponseCode::FORBIDDEN;
  }

  std::string const& encodedBody = segments[1];
  if (!validateJwtBody(encodedBody)) {
    LOG(DEBUG) << "Couldn't validate jwt body " << encodedBody;
    return GeneralResponse::ResponseCode::FORBIDDEN;
  }

  // the signature covers the still-encoded header and body
  std::string const& encodedSignature = segments[2];
  std::string const signedPart = encodedHeader + "." + encodedBody;
  if (!validateJwtHMAC256Signature(signedPart, encodedSignature)) {
    LOG(DEBUG) << "Couldn't validate jwt signature " << encodedSignature;
    return GeneralResponse::ResponseCode::FORBIDDEN;
  }

  return GeneralResponse::ResponseCode::OK;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief parses a JSON string into a velocypack builder
///
/// @param str  the JSON text to parse
/// @param hint short description of what is parsed; used only in log messages
/// @return the builder holding the parsed value, or an empty shared_ptr if
///         parsing failed. Exceptions raised while parsing are logged and
///         swallowed.
////////////////////////////////////////////////////////////////////////////////
std::shared_ptr<VPackBuilder> VocbaseContext::parseJson(std::string const& str, std::string const& hint) {
  VPackParser parser;

  try {
    parser.parse(str);
    return parser.steal();
  } catch (std::bad_alloc const&) {
    LOG(ERR) << "Out of memory parsing " << hint << "!";
  } catch (VPackException const& ex) {
    LOG(DEBUG) << "Couldn't parse " << hint << ": " << ex.what();
  } catch (...) {
    LOG(ERR) << "Got unknown exception trying to parse " << hint;
  }

  return std::shared_ptr<VPackBuilder>();
}
bool VocbaseContext::validateJwtHeader(std::string const& header) {
std::shared_ptr<VPackBuilder> headerBuilder = parseJson(StringUtils::decodeBase64(header), "jwt header");
if (headerBuilder.get() == nullptr) {
return false;
}
VPackSlice const headerSlice = headerBuilder->slice();
if (!headerSlice.isObject()) {
return false;
}
VPackSlice const algSlice = headerSlice.get("alg");
VPackSlice const typSlice = headerSlice.get("typ");
if (!algSlice.isString()) {
return false;
}
if (!typSlice.isString()) {
return false;
}
if (algSlice.copyString() != "HS256") {
return false;
}
if (typSlice.copyString() != "jwt") {
return false;
}
return true;
}
bool VocbaseContext::validateJwtBody(std::string const& body) {
std::shared_ptr<VPackBuilder> bodyBuilder = parseJson(StringUtils::decodeBase64(body), "jwt body");
if (bodyBuilder.get() == nullptr) {
return false;
}
VPackSlice const bodySlice = bodyBuilder->slice();
if (!bodySlice.isObject()) {
return false;
}
VPackSlice const issSlice = bodySlice.get("iss");
if (!issSlice.isString()) {
return false;
}
if (issSlice.copyString() != "arangodb") {
return false;
}
// mop: optional exp (cluster currently uses non expiring jwts)
if (bodySlice.hasKey("exp")) {
VPackSlice const expSlice = bodySlice.get("exp");
if (!expSlice.isNumber()) {
return false;
}
std::chrono::system_clock::time_point expires(std::chrono::seconds(expSlice.getNumber<uint64_t>()));
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
if (now >= expires) {
return false;
}
}
return true;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief verifies the HS256 signature of a JWT
///
/// @param message   the signed input, "<encoded header>.<encoded body>"
/// @param signature the base64-encoded signature segment of the token
/// @return true if verifyHMAC accepts the decoded signature as the
///         HMAC-SHA256 of `message` under _jwtSecret
////////////////////////////////////////////////////////////////////////////////
bool VocbaseContext::validateJwtHMAC256Signature(std::string const& message, std::string const& signature) {
  // The token carries the signature base64-encoded (RFC 7515), so the decoded
  // bytes are what must be verified. Passing the encoded text can never match
  // a real digest.
  // NOTE(review): JWS segments use the URL-safe base64 alphabet; confirm
  // decodeBase64 accepts '-' and '_' (and missing padding).
  std::string const decodedSignature = StringUtils::decodeBase64(signature);

  return verifyHMAC(_jwtSecret.c_str(), _jwtSecret.length(), message.c_str(),
                    message.length(), decodedSignature.c_str(),
                    decodedSignature.length(),
                    SslInterface::Algorithm::ALGORITHM_SHA256);
}