mirror of https://gitee.com/bigwinds/arangodb
Compiles
This commit is contained in:
parent
6cb00f905d
commit
87f09b986a
|
@ -55,7 +55,7 @@ std::string RestAuthHandler::generateJwt(std::string const& username, std::strin
|
|||
VPackBuilder bodyBuilder;
|
||||
{
|
||||
VPackObjectBuilder p(&bodyBuilder);
|
||||
bodyBuilder.add("username", VPackValue(username));
|
||||
bodyBuilder.add("preferred_username", VPackValue(username));
|
||||
bodyBuilder.add("iss", VPackValue("arangodb"));
|
||||
bodyBuilder.add("exp", VPackValue(exp.count()));
|
||||
}
|
||||
|
|
|
@ -228,6 +228,7 @@ static TRI_vocbase_t* LookupDatabaseFromRequest(HttpRequest* request,
|
|||
}
|
||||
|
||||
static bool SetRequestContext(HttpRequest* request, void* data) {
|
||||
TRI_ASSERT(RestServerFeature::RESTSERVER != nullptr);
|
||||
TRI_server_t* server = static_cast<TRI_server_t*>(data);
|
||||
TRI_vocbase_t* vocbase = LookupDatabaseFromRequest(request, server);
|
||||
|
||||
|
@ -242,7 +243,7 @@ static bool SetRequestContext(HttpRequest* request, void* data) {
|
|||
return false;
|
||||
}
|
||||
|
||||
VocbaseContext* ctx = new arangodb::VocbaseContext(request, server, vocbase);
|
||||
VocbaseContext* ctx = new arangodb::VocbaseContext(request, server, vocbase, RestServerFeature::getJwtSecret());
|
||||
request->setRequestContext(ctx, true);
|
||||
|
||||
// the "true" means the request is the owner of the context
|
||||
|
|
|
@ -54,6 +54,11 @@ class RestServerFeature final
|
|||
return RESTSERVER->trustedProxies();
|
||||
}
|
||||
|
||||
static std::string getJwtSecret() {
|
||||
TRI_ASSERT(RESTSERVER != nullptr);
|
||||
return RESTSERVER->jwtSecret();
|
||||
}
|
||||
|
||||
private:
|
||||
static RestServerFeature* RESTSERVER;
|
||||
static const size_t _maxSecretLength = 64;
|
||||
|
|
|
@ -23,11 +23,19 @@
|
|||
|
||||
#include "VocbaseContext.h"
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include <velocypack/Builder.h>
|
||||
#include <velocypack/Exception.h>
|
||||
#include <velocypack/Parser.h>
|
||||
#include <velocypack/velocypack-aliases.h>
|
||||
|
||||
#include "Basics/MutexLocker.h"
|
||||
#include "Basics/tri-strings.h"
|
||||
#include "Cluster/ServerState.h"
|
||||
#include "Endpoint/ConnectionInfo.h"
|
||||
#include "Logger/Logger.h"
|
||||
#include "Ssl/SslInterface.h"
|
||||
#include "VocBase/auth.h"
|
||||
#include "VocBase/server.h"
|
||||
#include "VocBase/vocbase.h"
|
||||
|
@ -139,8 +147,8 @@ double VocbaseContext::accessSid(std::string const& database,
|
|||
}
|
||||
|
||||
VocbaseContext::VocbaseContext(HttpRequest* request, TRI_server_t* server,
|
||||
TRI_vocbase_t* vocbase)
|
||||
: RequestContext(request), _server(server), _vocbase(vocbase) {
|
||||
TRI_vocbase_t* vocbase, std::string const& jwtSecret)
|
||||
: RequestContext(request), _server(server), _vocbase(vocbase), _jwtSecret(jwtSecret) {
|
||||
TRI_ASSERT(_server != nullptr);
|
||||
TRI_ASSERT(_vocbase != nullptr);
|
||||
}
|
||||
|
@ -291,7 +299,7 @@ GeneralResponse::ResponseCode VocbaseContext::authenticate() {
|
|||
if (!TRI_CaseEqualString(authStr.c_str(), "basic ", 6)) {
|
||||
return basicAuthentication(auth);
|
||||
} else if (TRI_CaseEqualString(authStr.c_str(), "bearer ", 7)) {
|
||||
return jwtAuthentication(auth);
|
||||
return jwtAuthentication(std::string(auth));
|
||||
} else {
|
||||
// mop: hmmm is 403 the correct status code? or 401? or 400? :S
|
||||
return GeneralResponse::ResponseCode::FORBIDDEN;
|
||||
|
@ -371,6 +379,126 @@ GeneralResponse::ResponseCode VocbaseContext::basicAuthentication(const char* au
|
|||
/// @brief checks the authentication via jwt
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
GeneralResponse::ResponseCode VocbaseContext::jwtAuthentication(const char* auth) {
|
||||
return GeneralResponse::ResponseCode::FORBIDDEN;
|
||||
GeneralResponse::ResponseCode VocbaseContext::jwtAuthentication(std::string const& auth) {
|
||||
std::vector<std::string> const parts = StringUtils::split(auth, '.');
|
||||
|
||||
if (parts.size() != 3) {
|
||||
return GeneralResponse::ResponseCode::FORBIDDEN;
|
||||
}
|
||||
|
||||
std::string const& header = parts[0];
|
||||
std::string const& body = parts[1];
|
||||
std::string const& signature = parts[2];
|
||||
|
||||
std::string const message = header + "." + body;
|
||||
|
||||
if (!validateJwtHeader(header)) {
|
||||
LOG(DEBUG) << "Couldn't validate jwt header " << header;
|
||||
return GeneralResponse::ResponseCode::FORBIDDEN;
|
||||
}
|
||||
|
||||
if (!validateJwtBody(body)) {
|
||||
LOG(DEBUG) << "Couldn't validate jwt body " << body;
|
||||
return GeneralResponse::ResponseCode::FORBIDDEN;
|
||||
}
|
||||
|
||||
if (!validateJwtHMAC256Signature(message, signature)) {
|
||||
LOG(DEBUG) << "Couldn't validate jwt signature " << signature;
|
||||
return GeneralResponse::ResponseCode::FORBIDDEN;
|
||||
}
|
||||
|
||||
return GeneralResponse::ResponseCode::OK;
|
||||
}
|
||||
|
||||
std::shared_ptr<VPackBuilder> VocbaseContext::parseJson(std::string const& str, std::string const& hint) {
|
||||
std::shared_ptr<VPackBuilder> result;
|
||||
VPackParser parser;
|
||||
try {
|
||||
parser.parse(str);
|
||||
result = parser.steal();
|
||||
} catch (std::bad_alloc const&) {
|
||||
LOG(ERR) << "Out of memory parsing " << hint << "!";
|
||||
} catch (VPackException const& ex) {
|
||||
LOG(DEBUG) << "Couldn't parse " << hint << ": " << ex.what();
|
||||
} catch (...) {
|
||||
LOG(ERR) << "Got unknown exception trying to parse " << hint;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool VocbaseContext::validateJwtHeader(std::string const& header) {
|
||||
std::shared_ptr<VPackBuilder> headerBuilder = parseJson(StringUtils::decodeBase64(header), "jwt header");
|
||||
if (headerBuilder.get() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VPackSlice const headerSlice = headerBuilder->slice();
|
||||
if (!headerSlice.isObject()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VPackSlice const algSlice = headerSlice.get("alg");
|
||||
VPackSlice const typSlice = headerSlice.get("typ");
|
||||
|
||||
if (!algSlice.isString()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!typSlice.isString()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (algSlice.copyString() != "HS256") {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (typSlice.copyString() != "jwt") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool VocbaseContext::validateJwtBody(std::string const& body) {
|
||||
std::shared_ptr<VPackBuilder> bodyBuilder = parseJson(StringUtils::decodeBase64(body), "jwt body");
|
||||
if (bodyBuilder.get() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VPackSlice const bodySlice = bodyBuilder->slice();
|
||||
if (!bodySlice.isObject()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
VPackSlice const issSlice = bodySlice.get("iss");
|
||||
if (!issSlice.isString()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (issSlice.copyString() != "arangodb") {
|
||||
return false;
|
||||
}
|
||||
|
||||
// mop: optional exp (cluster currently uses non expiring jwts)
|
||||
if (bodySlice.hasKey("exp")) {
|
||||
VPackSlice const expSlice = bodySlice.get("exp");
|
||||
|
||||
if (!expSlice.isNumber()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::chrono::system_clock::time_point expires(std::chrono::seconds(expSlice.getNumber<uint64_t>()));
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
|
||||
if (now >= expires) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool VocbaseContext::validateJwtHMAC256Signature(std::string const& message, std::string const& signature) {
|
||||
std::string decodedSignature = StringUtils::decodeBase64(signature);
|
||||
return verifyHMAC(_jwtSecret.c_str(), _jwtSecret.length(), message.c_str(), message.length(), signature.c_str(), signature.length(), SslInterface::Algorithm::ALGORITHM_SHA256);
|
||||
}
|
||||
|
|
|
@ -24,6 +24,9 @@
|
|||
#ifndef ARANGOD_REST_SERVER_VOCBASE_CONTEXT_H
|
||||
#define ARANGOD_REST_SERVER_VOCBASE_CONTEXT_H 1
|
||||
|
||||
#include <velocypack/Builder.h>
|
||||
#include <velocypack/velocypack-aliases.h>
|
||||
|
||||
#include "Basics/Common.h"
|
||||
#include "Rest/HttpRequest.h"
|
||||
#include "Rest/HttpResponse.h"
|
||||
|
@ -66,7 +69,7 @@ class VocbaseContext : public arangodb::RequestContext {
|
|||
static double accessSid(std::string const& database, std::string const& sid);
|
||||
|
||||
public:
|
||||
VocbaseContext(HttpRequest*, TRI_server_t*, TRI_vocbase_t*);
|
||||
VocbaseContext(HttpRequest*, TRI_server_t*, TRI_vocbase_t*, std::string const&);
|
||||
|
||||
~VocbaseContext();
|
||||
|
||||
|
@ -95,6 +98,7 @@ class VocbaseContext : public arangodb::RequestContext {
|
|||
|
||||
GeneralResponse::ResponseCode authenticate() override final;
|
||||
|
||||
private:
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// @brief checks the authentication (basic)
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -105,7 +109,13 @@ class VocbaseContext : public arangodb::RequestContext {
|
|||
/// @brief checks the authentication (jwt)
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
GeneralResponse::ResponseCode jwtAuthentication(const char*);
|
||||
GeneralResponse::ResponseCode jwtAuthentication(std::string const&);
|
||||
|
||||
std::shared_ptr<VPackBuilder> parseJson(std::string const&, std::string const&);
|
||||
|
||||
bool validateJwtHeader(std::string const&);
|
||||
bool validateJwtBody(std::string const&);
|
||||
bool validateJwtHMAC256Signature(std::string const&, std::string const&);
|
||||
|
||||
public:
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -126,6 +136,8 @@ class VocbaseContext : public arangodb::RequestContext {
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TRI_vocbase_t* _vocbase;
|
||||
|
||||
std::string const _jwtSecret;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -270,7 +270,7 @@ bool verifyHMAC(char const* challenge, size_t challengeLength,
|
|||
// result must == BASE64(response, responseLen)
|
||||
|
||||
std::string s =
|
||||
StringUtils::encodeHex(sslHMAC(challenge, challengeLength, secret, secretLen, algorithm));
|
||||
sslHMAC(challenge, challengeLength, secret, secretLen, algorithm);
|
||||
|
||||
if (s.length() == responseLen &&
|
||||
s.compare(std::string(response, responseLen)) == 0) {
|
||||
|
|
Loading…
Reference in New Issue