1
0
Fork 0
This commit is contained in:
Andreas Streichardt 2016-05-31 14:28:15 +02:00
parent 6cb00f905d
commit 87f09b986a
6 changed files with 156 additions and 10 deletions

View File

@ -55,7 +55,7 @@ std::string RestAuthHandler::generateJwt(std::string const& username, std::strin
VPackBuilder bodyBuilder;
{
VPackObjectBuilder p(&bodyBuilder);
bodyBuilder.add("username", VPackValue(username));
bodyBuilder.add("preferred_username", VPackValue(username));
bodyBuilder.add("iss", VPackValue("arangodb"));
bodyBuilder.add("exp", VPackValue(exp.count()));
}

View File

@ -228,6 +228,7 @@ static TRI_vocbase_t* LookupDatabaseFromRequest(HttpRequest* request,
}
static bool SetRequestContext(HttpRequest* request, void* data) {
TRI_ASSERT(RestServerFeature::RESTSERVER != nullptr);
TRI_server_t* server = static_cast<TRI_server_t*>(data);
TRI_vocbase_t* vocbase = LookupDatabaseFromRequest(request, server);
@ -242,7 +243,7 @@ static bool SetRequestContext(HttpRequest* request, void* data) {
return false;
}
VocbaseContext* ctx = new arangodb::VocbaseContext(request, server, vocbase);
VocbaseContext* ctx = new arangodb::VocbaseContext(request, server, vocbase, RestServerFeature::getJwtSecret());
request->setRequestContext(ctx, true);
// the "true" means the request is the owner of the context

View File

@ -54,6 +54,11 @@ class RestServerFeature final
return RESTSERVER->trustedProxies();
}
static std::string getJwtSecret() {
TRI_ASSERT(RESTSERVER != nullptr);
return RESTSERVER->jwtSecret();
}
private:
static RestServerFeature* RESTSERVER;
static const size_t _maxSecretLength = 64;

View File

@ -23,11 +23,19 @@
#include "VocbaseContext.h"
#include <chrono>
#include <velocypack/Builder.h>
#include <velocypack/Exception.h>
#include <velocypack/Parser.h>
#include <velocypack/velocypack-aliases.h>
#include "Basics/MutexLocker.h"
#include "Basics/tri-strings.h"
#include "Cluster/ServerState.h"
#include "Endpoint/ConnectionInfo.h"
#include "Logger/Logger.h"
#include "Ssl/SslInterface.h"
#include "VocBase/auth.h"
#include "VocBase/server.h"
#include "VocBase/vocbase.h"
@ -139,8 +147,8 @@ double VocbaseContext::accessSid(std::string const& database,
}
VocbaseContext::VocbaseContext(HttpRequest* request, TRI_server_t* server,
TRI_vocbase_t* vocbase)
: RequestContext(request), _server(server), _vocbase(vocbase) {
TRI_vocbase_t* vocbase, std::string const& jwtSecret)
: RequestContext(request), _server(server), _vocbase(vocbase), _jwtSecret(jwtSecret) {
TRI_ASSERT(_server != nullptr);
TRI_ASSERT(_vocbase != nullptr);
}
@ -291,7 +299,7 @@ GeneralResponse::ResponseCode VocbaseContext::authenticate() {
if (!TRI_CaseEqualString(authStr.c_str(), "basic ", 6)) {
return basicAuthentication(auth);
} else if (TRI_CaseEqualString(authStr.c_str(), "bearer ", 7)) {
return jwtAuthentication(auth);
return jwtAuthentication(std::string(auth));
} else {
// mop: hmmm is 403 the correct status code? or 401? or 400? :S
return GeneralResponse::ResponseCode::FORBIDDEN;
@ -371,6 +379,126 @@ GeneralResponse::ResponseCode VocbaseContext::basicAuthentication(const char* au
/// @brief checks the authentication via jwt
////////////////////////////////////////////////////////////////////////////////
GeneralResponse::ResponseCode VocbaseContext::jwtAuthentication(const char* auth) {
return GeneralResponse::ResponseCode::FORBIDDEN;
GeneralResponse::ResponseCode VocbaseContext::jwtAuthentication(std::string const& auth) {
std::vector<std::string> const parts = StringUtils::split(auth, '.');
if (parts.size() != 3) {
return GeneralResponse::ResponseCode::FORBIDDEN;
}
std::string const& header = parts[0];
std::string const& body = parts[1];
std::string const& signature = parts[2];
std::string const message = header + "." + body;
if (!validateJwtHeader(header)) {
LOG(DEBUG) << "Couldn't validate jwt header " << header;
return GeneralResponse::ResponseCode::FORBIDDEN;
}
if (!validateJwtBody(body)) {
LOG(DEBUG) << "Couldn't validate jwt body " << body;
return GeneralResponse::ResponseCode::FORBIDDEN;
}
if (!validateJwtHMAC256Signature(message, signature)) {
LOG(DEBUG) << "Couldn't validate jwt signature " << signature;
return GeneralResponse::ResponseCode::FORBIDDEN;
}
return GeneralResponse::ResponseCode::OK;
}
std::shared_ptr<VPackBuilder> VocbaseContext::parseJson(std::string const& str, std::string const& hint) {
std::shared_ptr<VPackBuilder> result;
VPackParser parser;
try {
parser.parse(str);
result = parser.steal();
} catch (std::bad_alloc const&) {
LOG(ERR) << "Out of memory parsing " << hint << "!";
} catch (VPackException const& ex) {
LOG(DEBUG) << "Couldn't parse " << hint << ": " << ex.what();
} catch (...) {
LOG(ERR) << "Got unknown exception trying to parse " << hint;
}
return result;
}
bool VocbaseContext::validateJwtHeader(std::string const& header) {
std::shared_ptr<VPackBuilder> headerBuilder = parseJson(StringUtils::decodeBase64(header), "jwt header");
if (headerBuilder.get() == nullptr) {
return false;
}
VPackSlice const headerSlice = headerBuilder->slice();
if (!headerSlice.isObject()) {
return false;
}
VPackSlice const algSlice = headerSlice.get("alg");
VPackSlice const typSlice = headerSlice.get("typ");
if (!algSlice.isString()) {
return false;
}
if (!typSlice.isString()) {
return false;
}
if (algSlice.copyString() != "HS256") {
return false;
}
if (typSlice.copyString() != "jwt") {
return false;
}
return true;
}
bool VocbaseContext::validateJwtBody(std::string const& body) {
std::shared_ptr<VPackBuilder> bodyBuilder = parseJson(StringUtils::decodeBase64(body), "jwt body");
if (bodyBuilder.get() == nullptr) {
return false;
}
VPackSlice const bodySlice = bodyBuilder->slice();
if (!bodySlice.isObject()) {
return false;
}
VPackSlice const issSlice = bodySlice.get("iss");
if (!issSlice.isString()) {
return false;
}
if (issSlice.copyString() != "arangodb") {
return false;
}
// mop: optional exp (cluster currently uses non expiring jwts)
if (bodySlice.hasKey("exp")) {
VPackSlice const expSlice = bodySlice.get("exp");
if (!expSlice.isNumber()) {
return false;
}
std::chrono::system_clock::time_point expires(std::chrono::seconds(expSlice.getNumber<uint64_t>()));
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
if (now >= expires) {
return false;
}
}
return true;
}
bool VocbaseContext::validateJwtHMAC256Signature(std::string const& message, std::string const& signature) {
std::string decodedSignature = StringUtils::decodeBase64(signature);
return verifyHMAC(_jwtSecret.c_str(), _jwtSecret.length(), message.c_str(), message.length(), signature.c_str(), signature.length(), SslInterface::Algorithm::ALGORITHM_SHA256);
}

View File

@ -24,6 +24,9 @@
#ifndef ARANGOD_REST_SERVER_VOCBASE_CONTEXT_H
#define ARANGOD_REST_SERVER_VOCBASE_CONTEXT_H 1
#include <velocypack/Builder.h>
#include <velocypack/velocypack-aliases.h>
#include "Basics/Common.h"
#include "Rest/HttpRequest.h"
#include "Rest/HttpResponse.h"
@ -66,7 +69,7 @@ class VocbaseContext : public arangodb::RequestContext {
static double accessSid(std::string const& database, std::string const& sid);
public:
VocbaseContext(HttpRequest*, TRI_server_t*, TRI_vocbase_t*);
VocbaseContext(HttpRequest*, TRI_server_t*, TRI_vocbase_t*, std::string const&);
~VocbaseContext();
@ -95,6 +98,7 @@ class VocbaseContext : public arangodb::RequestContext {
GeneralResponse::ResponseCode authenticate() override final;
private:
//////////////////////////////////////////////////////////////////////////////
/// @brief checks the authentication (basic)
//////////////////////////////////////////////////////////////////////////////
@ -105,7 +109,13 @@ class VocbaseContext : public arangodb::RequestContext {
/// @brief checks the authentication (jwt)
//////////////////////////////////////////////////////////////////////////////
GeneralResponse::ResponseCode jwtAuthentication(const char*);
GeneralResponse::ResponseCode jwtAuthentication(std::string const&);
std::shared_ptr<VPackBuilder> parseJson(std::string const&, std::string const&);
bool validateJwtHeader(std::string const&);
bool validateJwtBody(std::string const&);
bool validateJwtHMAC256Signature(std::string const&, std::string const&);
public:
////////////////////////////////////////////////////////////////////////////////
@ -126,6 +136,8 @@ class VocbaseContext : public arangodb::RequestContext {
//////////////////////////////////////////////////////////////////////////////
TRI_vocbase_t* _vocbase;
std::string const _jwtSecret;
};
}

View File

@ -270,7 +270,7 @@ bool verifyHMAC(char const* challenge, size_t challengeLength,
// result must == BASE64(response, responseLen)
std::string s =
StringUtils::encodeHex(sslHMAC(challenge, challengeLength, secret, secretLen, algorithm));
sslHMAC(challenge, challengeLength, secret, secretLen, algorithm);
if (s.length() == responseLen &&
s.compare(std::string(response, responseLen)) == 0) {