1
0
Fork 0

enable switching from http to vst fix vst authentication

This commit is contained in:
Jan Christoph Uhde 2016-10-26 23:23:35 +02:00
parent 33b6a2b8eb
commit 5b818c3243
10 changed files with 115 additions and 45 deletions

View File

@ -53,9 +53,10 @@ using namespace arangodb::rest;
GeneralCommTask::GeneralCommTask(EventLoop loop, GeneralServer* server,
std::unique_ptr<Socket> socket,
ConnectionInfo&& info, double keepAliveTimeout)
ConnectionInfo&& info, double keepAliveTimeout,
bool skipSocketInit)
: Task(loop, "GeneralCommTask"),
SocketTask(loop, std::move(socket), std::move(info), keepAliveTimeout),
SocketTask(loop, std::move(socket), std::move(info), keepAliveTimeout, skipSocketInit),
_server(server) {}
// -----------------------------------------------------------------------------

View File

@ -83,7 +83,7 @@ class GeneralCommTask : public SocketTask {
public:
GeneralCommTask(EventLoop, GeneralServer*, std::unique_ptr<Socket>,
ConnectionInfo&&, double keepAliveTimeout);
ConnectionInfo&&, double keepAliveTimeout, bool skipSocketInit = false);
virtual void addResponse(GeneralResponse*) = 0;
virtual arangodb::Endpoint::TransportType transportType() = 0;

View File

@ -33,6 +33,8 @@
#include "Rest/HttpRequest.h"
#include "VocBase/ticks.h"
#include "VppCommTask.h"
using namespace arangodb;
using namespace arangodb::basics;
using namespace arangodb::rest;
@ -254,6 +256,23 @@ bool HttpCommTask::processRead() {
return false;
}
LOG_TOPIC(WARN, Logger::COMMUNICATION)
<< std::string(_readBuffer.c_str(), _readBuffer.length());
if (std::strncmp(_readBuffer.c_str(), "VST/1.0\r\n\r\n", 11) == 0) {
LOG_TOPIC(INFO, Logger::COMMUNICATION) << "Switching from Http to Vst";
std::shared_ptr<GeneralCommTask> commTask;
_abandoned = true;
cancelKeepAlive();
commTask = std::make_shared<VppCommTask>(
_loop, _server, std::move(_peer), std::move(_connectionInfo),
GeneralServerFeature::keepAliveTimeout(), /*skipSocketInit*/ true);
commTask->addToReadBuffer(_readBuffer.c_str() + 11,
_readBuffer.length() - 11);
commTask->processRead();
commTask->start();
// statistics?!
return false;
}
// header is complete
if (ptr < end) {
_readPosition = ptr - _readBuffer.c_str() + 4;
@ -531,9 +550,7 @@ bool HttpCommTask::processRead() {
else if (authResult == rest::ResponseCode::FORBIDDEN) {
handleSimpleError(authResult, TRI_ERROR_USER_CHANGE_PASSWORD,
"change password", 1);
}
// not authenticated
else {
} else { // not authenticated
HttpResponse response(rest::ResponseCode::UNAUTHORIZED);
std::string realm = "Bearer token_type=\"JWT\", realm=\"ArangoDB\"";

View File

@ -58,14 +58,14 @@ using namespace arangodb::rest;
VppCommTask::VppCommTask(EventLoop loop, GeneralServer* server,
std::unique_ptr<Socket> socket, ConnectionInfo&& info,
double timeout)
double timeout, bool skipInit)
: Task(loop, "VppCommTask"),
GeneralCommTask(loop, server, std::move(socket), std::move(info),
timeout),
GeneralCommTask(loop, server, std::move(socket), std::move(info), timeout,
skipInit),
_authenticatedUser(),
_authentication(nullptr) {
_authentication = application_features::ApplicationServer::getFeature<AuthenticationFeature>(
"Authentication");
_authentication = application_features::ApplicationServer::getFeature<
AuthenticationFeature>("Authentication");
TRI_ASSERT(_authentication != nullptr);
_protocol = "vpp";
@ -187,7 +187,8 @@ bool VppCommTask::isChunkComplete(char* start) {
return true;
}
void VppCommTask::handleAuthentication(VPackSlice const& header, uint64_t messageId) {
void VppCommTask::handleAuthentication(VPackSlice const& header,
uint64_t messageId) {
// std::string encryption = header.at(2).copyString();
std::string user = header.at(3).copyString();
std::string pass = header.at(4).copyString();
@ -199,20 +200,23 @@ void VppCommTask::handleAuthentication(VPackSlice const& header, uint64_t messag
auto auth = basics::StringUtils::encodeBase64(user + ":" + pass);
AuthResult result = _authentication->authInfo()->checkAuthentication(
AuthInfo::AuthType::BASIC, auth);
authOk = result._authorized;
if (authOk) {
_authenticatedUser = std::move(user);
}
}
if (authOk) {
// mop: hmmm...user should be completely ignored if there is no auth IMHO
_authenticatedUser = std::move(user);
// obi: user who sends authentication expects a reply
handleSimpleError(rest::ResponseCode::OK, TRI_ERROR_NO_ERROR,
"authentication successful", messageId);
"authentication successful", messageId);
} else {
_authenticatedUser.clear();
handleSimpleError(rest::ResponseCode::UNAUTHORIZED,
TRI_ERROR_HTTP_UNAUTHORIZED, "authentication failed",
messageId);
TRI_ERROR_HTTP_UNAUTHORIZED, "authentication failed",
messageId);
}
}
@ -293,11 +297,13 @@ bool VppCommTask::processRead() {
request->setUser(_authenticatedUser);
// check authentication
std::string const& dbname = request->databaseName();
AuthLevel level = AuthLevel::RW;
if (!_authenticatedUser.empty() || !dbname.empty()) {
level = _authentication->canUseDatabase(
_authenticatedUser, dbname);
if (_authentication->isEnabled()) { // only check authorization if
// authentication is enabled
std::string const& dbname = request->databaseName();
if (!(_authenticatedUser.empty() && dbname.empty())) {
level = _authentication->canUseDatabase(_authenticatedUser, dbname);
}
}
if (level != AuthLevel::RW) {

View File

@ -42,7 +42,7 @@ namespace rest {
class VppCommTask : public GeneralCommTask {
public:
VppCommTask(EventLoop, GeneralServer*, std::unique_ptr<Socket> socket,
ConnectionInfo&&, double timeout);
ConnectionInfo&&, double timeout, bool skipSocketInit = false);
// convert from GeneralResponse to vppResponse ad dispatch request to class
// internal addResponse
@ -65,7 +65,7 @@ class VppCommTask : public GeneralCommTask {
std::unique_ptr<GeneralResponse> createResponse(
rest::ResponseCode, uint64_t messageId) override final;
void handleAuthentication(VPackSlice const& header, uint64_t messageId);
void handleSimpleError(rest::ResponseCode code, uint64_t id) override {
VppResponse response(code, id);

View File

@ -41,22 +41,24 @@ using namespace arangodb::rest;
SocketTask::SocketTask(arangodb::EventLoop loop,
std::unique_ptr<arangodb::Socket> socket,
arangodb::ConnectionInfo&& connectionInfo,
double keepAliveTimeout)
double keepAliveTimeout, bool skipInit = false)
: Task(loop, "SocketTask"),
_connectionInfo(connectionInfo),
_readBuffer(TRI_UNKNOWN_MEM_ZONE, READ_BLOCK_SIZE + 1, false),
_peer(std::move(socket)),
_keepAliveTimeout(static_cast<long>(keepAliveTimeout * 1000)),
_useKeepAliveTimeout(static_cast<long>(keepAliveTimeout * 1000) > 0),
_keepAliveTimer(_peer->_ioService, _keepAliveTimeout) {
_keepAliveTimer(_peer->_ioService, _keepAliveTimeout),
_abandoned(false) {
ConnectionStatisticsAgent::acquire();
connectionStatisticsAgentSetStart();
_peer->setNonBlocking(true);
if (!_peer->handshake()) {
_closedSend = true;
_closedReceive = true;
if (!skipInit) {
_peer->setNonBlocking(true);
if (!_peer->handshake()) {
_closedSend = true;
_closedReceive = true;
}
}
}
@ -189,16 +191,20 @@ void SocketTask::completedWriteBuffer() {
_writeBuffer = nullptr;
if (_writeBufferStatistics != nullptr) {
#ifdef DEBUG_STATISTICS
LOG_TOPIC(TRACE, Logger::REQUESTS)
<< "SocketTask::addWriteBuffer - Statistics release: "
<< _writeBufferStatistics->to_string();
#endif
_writeBufferStatistics->_writeEnd = TRI_StatisticsTime();
TRI_ReleaseRequestStatistics(_writeBufferStatistics);
_writeBufferStatistics = nullptr;
} else {
#ifdef DEBUG_STATISTICS
LOG_TOPIC(TRACE, Logger::REQUESTS) << "SocketTask::addWriteBuffer - "
"Statistics release: nullptr - "
"nothing to realease";
#endif
}
if (_writeBuffers.empty()) {
@ -263,6 +269,10 @@ void SocketTask::closeStream() {
// -----------------------------------------------------------------------------
// --SECTION-- private methods
// -----------------------------------------------------------------------------
void SocketTask::addToReadBuffer(char const* data, std::size_t len) {
LOG_TOPIC(DEBUG, Logger::COMMUNICATION) << std::string(data, len);
_readBuffer.appendText(data, len);
}
void SocketTask::resetKeepAlive() {
if (_useKeepAliveTimeout) {
@ -308,7 +318,15 @@ bool SocketTask::reserveMemory() {
bool SocketTask::trySyncRead() {
boost::system::error_code err;
if (_abandoned) {
return false;
}
if (!_peer) {
LOG_TOPIC(DEBUG, Logger::COMMUNICATION) << "SocketTask::trySyncRead "
<< "- peer disappeared ";
}
if (0 == _peer->available(err)) {
return false;
}
@ -321,9 +339,9 @@ bool SocketTask::trySyncRead() {
}
size_t bytesRead = 0;
bytesRead = _peer->read(
boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE), err);
bytesRead =
_peer->read(boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE), err);
if (0 == bytesRead) {
return false; // should not happen
@ -347,6 +365,10 @@ bool SocketTask::trySyncRead() {
void SocketTask::asyncReadSome() {
try {
if (_abandoned) {
return;
}
JobGuard guard(_loop);
guard.busy();
@ -355,8 +377,7 @@ void SocketTask::asyncReadSome() {
while (++n <= MAX_DIRECT_TRIES) {
if (!reserveMemory()) {
LOG_TOPIC(TRACE, Logger::COMMUNICATION)
<< "failed to reserve memory";
LOG_TOPIC(TRACE, Logger::COMMUNICATION) << "failed to reserve memory";
return;
}
@ -371,6 +392,9 @@ void SocketTask::asyncReadSome() {
}
while (processRead()) {
if (_abandoned) {
return;
}
if (_closeRequested) {
break;
}
@ -406,7 +430,7 @@ void SocketTask::asyncReadSome() {
auto self = shared_from_this();
auto handler = [self, this](const boost::system::error_code& ec,
std::size_t transferred) {
std::size_t transferred) {
if (ec) {
LOG_TOPIC(DEBUG, Logger::COMMUNICATION)
<< "SocketTask::asyncReadSome (async_read_some) - read on stream "
@ -429,14 +453,18 @@ void SocketTask::asyncReadSome() {
<< "close requested, closing receive stream";
closeReceiveStream();
} else if (_abandoned) {
return;
} else {
asyncReadSome();
}
}
};
_peer->asyncRead(
boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE), handler);
if (!_abandoned && _peer) {
_peer->asyncRead(boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE),
handler);
}
}
void SocketTask::closeReceiveStream() {

View File

@ -29,14 +29,15 @@
#include <boost/asio/ssl.hpp>
#include "Basics/asio-helper.h"
#include "Basics/StringBuffer.h"
#include "Basics/asio-helper.h"
#include "Scheduler/Socket.h"
#include "Statistics/StatisticsAgent.h"
namespace arangodb {
namespace rest {
class SocketTask : virtual public Task, public ConnectionStatisticsAgent {
friend class HttpCommTask;
explicit SocketTask(SocketTask const&) = delete;
SocketTask& operator=(SocketTask const&) = delete;
@ -45,16 +46,31 @@ class SocketTask : virtual public Task, public ConnectionStatisticsAgent {
public:
SocketTask(EventLoop, std::unique_ptr<Socket>, ConnectionInfo&&,
double keepAliveTimeout);
double keepAliveTimeout, bool skipInit);
virtual ~SocketTask();
std::unique_ptr<Socket> releasePeer() {
_abandoned = true;
return std::move(_peer);
}
ConnectionInfo&& releaseConnectionInfo() {
_abandoned = true;
return std::move(_connectionInfo);
}
public:
void start();
protected:
virtual bool processRead() = 0;
// This function is used during the protocol switch from http
// to VelocyStream. This way we no not require additional
// constructor arguments. It should not be used otherwise.
void addToReadBuffer(char const* data, std::size_t len);
protected:
void addWriteBuffer(std::unique_ptr<basics::StringBuffer> buffer) {
addWriteBuffer(std::move(buffer), (RequestStatisticsAgent*)nullptr);
@ -89,6 +105,7 @@ class SocketTask : virtual public Task, public ConnectionStatisticsAgent {
boost::asio::deadline_timer _keepAliveTimer;
bool _closeRequested = false;
std::atomic_bool _abandoned;
private:
bool reserveMemory();

View File

@ -751,7 +751,7 @@ VPackSlice HttpRequest::payload(VPackOptions const* options) {
} else {
return VPackSlice(_vpackBuilder->slice());
}
}
}
return VPackSlice::noneSlice(); // no body
} else /*VPACK*/ {
VPackValidator validator;

View File

@ -113,7 +113,7 @@ class HttpRequest final : public GeneralRequest {
// key that do not get special treatment end um in the _headers map.
void setHeader(char const* key, size_t keyLength, char const* value,
size_t valueLength);
void setHeader(std::string const& key, std::string const& value) {
setHeader(key.c_str(), key.length(), value.c_str(), value.length());
}

View File

@ -220,7 +220,7 @@ int Communicator::work_once() {
void Communicator::wait() {
static int const MAX_WAIT_MSECS = 1000; // wait max. 1 seconds
int numFds; // not used here
int numFds; // not used here
int res = curl_multi_wait(_curl, &_wakeup, 1, MAX_WAIT_MSECS, &numFds);
if (res != CURLM_OK) {
throw std::runtime_error(
@ -378,7 +378,8 @@ void Communicator::handleResult(CURL* handle, CURLcode rc) {
std::string prefix("Communicator(" + std::to_string(rip->_ticketId) +
") // ");
LOG_TOPIC(TRACE, Logger::REQUESTS)
<< prefix << "Curl rc is : " << rc << " after " << Logger::FIXED(TRI_microtime() - rip->_startTime) << " s";
<< prefix << "Curl rc is : " << rc << " after "
<< Logger::FIXED(TRI_microtime() - rip->_startTime) << " s";
if (strlen(rip->_errorBuffer) != 0) {
LOG_TOPIC(TRACE, Logger::REQUESTS)
<< prefix << "Curl error details: " << rip->_errorBuffer;