mirror of https://gitee.com/bigwinds/arangodb
enable switching from http to vst fix vst authentication
This commit is contained in:
parent
33b6a2b8eb
commit
5b818c3243
|
@ -53,9 +53,10 @@ using namespace arangodb::rest;
|
|||
|
||||
GeneralCommTask::GeneralCommTask(EventLoop loop, GeneralServer* server,
|
||||
std::unique_ptr<Socket> socket,
|
||||
ConnectionInfo&& info, double keepAliveTimeout)
|
||||
ConnectionInfo&& info, double keepAliveTimeout,
|
||||
bool skipSocketInit)
|
||||
: Task(loop, "GeneralCommTask"),
|
||||
SocketTask(loop, std::move(socket), std::move(info), keepAliveTimeout),
|
||||
SocketTask(loop, std::move(socket), std::move(info), keepAliveTimeout, skipSocketInit),
|
||||
_server(server) {}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
|
|
@ -83,7 +83,7 @@ class GeneralCommTask : public SocketTask {
|
|||
|
||||
public:
|
||||
GeneralCommTask(EventLoop, GeneralServer*, std::unique_ptr<Socket>,
|
||||
ConnectionInfo&&, double keepAliveTimeout);
|
||||
ConnectionInfo&&, double keepAliveTimeout, bool skipSocketInit = false);
|
||||
|
||||
virtual void addResponse(GeneralResponse*) = 0;
|
||||
virtual arangodb::Endpoint::TransportType transportType() = 0;
|
||||
|
|
|
@ -33,6 +33,8 @@
|
|||
#include "Rest/HttpRequest.h"
|
||||
#include "VocBase/ticks.h"
|
||||
|
||||
#include "VppCommTask.h"
|
||||
|
||||
using namespace arangodb;
|
||||
using namespace arangodb::basics;
|
||||
using namespace arangodb::rest;
|
||||
|
@ -254,6 +256,23 @@ bool HttpCommTask::processRead() {
|
|||
return false;
|
||||
}
|
||||
|
||||
LOG_TOPIC(WARN, Logger::COMMUNICATION)
|
||||
<< std::string(_readBuffer.c_str(), _readBuffer.length());
|
||||
if (std::strncmp(_readBuffer.c_str(), "VST/1.0\r\n\r\n", 11) == 0) {
|
||||
LOG_TOPIC(INFO, Logger::COMMUNICATION) << "Switching from Http to Vst";
|
||||
std::shared_ptr<GeneralCommTask> commTask;
|
||||
_abandoned = true;
|
||||
cancelKeepAlive();
|
||||
commTask = std::make_shared<VppCommTask>(
|
||||
_loop, _server, std::move(_peer), std::move(_connectionInfo),
|
||||
GeneralServerFeature::keepAliveTimeout(), /*skipSocketInit*/ true);
|
||||
commTask->addToReadBuffer(_readBuffer.c_str() + 11,
|
||||
_readBuffer.length() - 11);
|
||||
commTask->processRead();
|
||||
commTask->start();
|
||||
// statistics?!
|
||||
return false;
|
||||
}
|
||||
// header is complete
|
||||
if (ptr < end) {
|
||||
_readPosition = ptr - _readBuffer.c_str() + 4;
|
||||
|
@ -531,9 +550,7 @@ bool HttpCommTask::processRead() {
|
|||
else if (authResult == rest::ResponseCode::FORBIDDEN) {
|
||||
handleSimpleError(authResult, TRI_ERROR_USER_CHANGE_PASSWORD,
|
||||
"change password", 1);
|
||||
}
|
||||
// not authenticated
|
||||
else {
|
||||
} else { // not authenticated
|
||||
HttpResponse response(rest::ResponseCode::UNAUTHORIZED);
|
||||
std::string realm = "Bearer token_type=\"JWT\", realm=\"ArangoDB\"";
|
||||
|
||||
|
|
|
@ -58,14 +58,14 @@ using namespace arangodb::rest;
|
|||
|
||||
VppCommTask::VppCommTask(EventLoop loop, GeneralServer* server,
|
||||
std::unique_ptr<Socket> socket, ConnectionInfo&& info,
|
||||
double timeout)
|
||||
double timeout, bool skipInit)
|
||||
: Task(loop, "VppCommTask"),
|
||||
GeneralCommTask(loop, server, std::move(socket), std::move(info),
|
||||
timeout),
|
||||
GeneralCommTask(loop, server, std::move(socket), std::move(info), timeout,
|
||||
skipInit),
|
||||
_authenticatedUser(),
|
||||
_authentication(nullptr) {
|
||||
_authentication = application_features::ApplicationServer::getFeature<AuthenticationFeature>(
|
||||
"Authentication");
|
||||
_authentication = application_features::ApplicationServer::getFeature<
|
||||
AuthenticationFeature>("Authentication");
|
||||
TRI_ASSERT(_authentication != nullptr);
|
||||
|
||||
_protocol = "vpp";
|
||||
|
@ -187,7 +187,8 @@ bool VppCommTask::isChunkComplete(char* start) {
|
|||
return true;
|
||||
}
|
||||
|
||||
void VppCommTask::handleAuthentication(VPackSlice const& header, uint64_t messageId) {
|
||||
void VppCommTask::handleAuthentication(VPackSlice const& header,
|
||||
uint64_t messageId) {
|
||||
// std::string encryption = header.at(2).copyString();
|
||||
std::string user = header.at(3).copyString();
|
||||
std::string pass = header.at(4).copyString();
|
||||
|
@ -199,20 +200,23 @@ void VppCommTask::handleAuthentication(VPackSlice const& header, uint64_t messag
|
|||
auto auth = basics::StringUtils::encodeBase64(user + ":" + pass);
|
||||
AuthResult result = _authentication->authInfo()->checkAuthentication(
|
||||
AuthInfo::AuthType::BASIC, auth);
|
||||
|
||||
|
||||
authOk = result._authorized;
|
||||
if (authOk) {
|
||||
_authenticatedUser = std::move(user);
|
||||
}
|
||||
}
|
||||
|
||||
if (authOk) {
|
||||
// mop: hmmm...user should be completely ignored if there is no auth IMHO
|
||||
_authenticatedUser = std::move(user);
|
||||
// obi: user who sends authentication expects a reply
|
||||
handleSimpleError(rest::ResponseCode::OK, TRI_ERROR_NO_ERROR,
|
||||
"authentication successful", messageId);
|
||||
"authentication successful", messageId);
|
||||
} else {
|
||||
_authenticatedUser.clear();
|
||||
handleSimpleError(rest::ResponseCode::UNAUTHORIZED,
|
||||
TRI_ERROR_HTTP_UNAUTHORIZED, "authentication failed",
|
||||
messageId);
|
||||
TRI_ERROR_HTTP_UNAUTHORIZED, "authentication failed",
|
||||
messageId);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -293,11 +297,13 @@ bool VppCommTask::processRead() {
|
|||
request->setUser(_authenticatedUser);
|
||||
|
||||
// check authentication
|
||||
std::string const& dbname = request->databaseName();
|
||||
AuthLevel level = AuthLevel::RW;
|
||||
if (!_authenticatedUser.empty() || !dbname.empty()) {
|
||||
level = _authentication->canUseDatabase(
|
||||
_authenticatedUser, dbname);
|
||||
if (_authentication->isEnabled()) { // only check authorization if
|
||||
// authentication is enabled
|
||||
std::string const& dbname = request->databaseName();
|
||||
if (!(_authenticatedUser.empty() && dbname.empty())) {
|
||||
level = _authentication->canUseDatabase(_authenticatedUser, dbname);
|
||||
}
|
||||
}
|
||||
|
||||
if (level != AuthLevel::RW) {
|
||||
|
|
|
@ -42,7 +42,7 @@ namespace rest {
|
|||
class VppCommTask : public GeneralCommTask {
|
||||
public:
|
||||
VppCommTask(EventLoop, GeneralServer*, std::unique_ptr<Socket> socket,
|
||||
ConnectionInfo&&, double timeout);
|
||||
ConnectionInfo&&, double timeout, bool skipSocketInit = false);
|
||||
|
||||
// convert from GeneralResponse to vppResponse ad dispatch request to class
|
||||
// internal addResponse
|
||||
|
@ -65,7 +65,7 @@ class VppCommTask : public GeneralCommTask {
|
|||
|
||||
std::unique_ptr<GeneralResponse> createResponse(
|
||||
rest::ResponseCode, uint64_t messageId) override final;
|
||||
|
||||
|
||||
void handleAuthentication(VPackSlice const& header, uint64_t messageId);
|
||||
void handleSimpleError(rest::ResponseCode code, uint64_t id) override {
|
||||
VppResponse response(code, id);
|
||||
|
|
|
@ -41,22 +41,24 @@ using namespace arangodb::rest;
|
|||
SocketTask::SocketTask(arangodb::EventLoop loop,
|
||||
std::unique_ptr<arangodb::Socket> socket,
|
||||
arangodb::ConnectionInfo&& connectionInfo,
|
||||
double keepAliveTimeout)
|
||||
double keepAliveTimeout, bool skipInit = false)
|
||||
: Task(loop, "SocketTask"),
|
||||
_connectionInfo(connectionInfo),
|
||||
_readBuffer(TRI_UNKNOWN_MEM_ZONE, READ_BLOCK_SIZE + 1, false),
|
||||
_peer(std::move(socket)),
|
||||
_keepAliveTimeout(static_cast<long>(keepAliveTimeout * 1000)),
|
||||
_useKeepAliveTimeout(static_cast<long>(keepAliveTimeout * 1000) > 0),
|
||||
_keepAliveTimer(_peer->_ioService, _keepAliveTimeout) {
|
||||
_keepAliveTimer(_peer->_ioService, _keepAliveTimeout),
|
||||
_abandoned(false) {
|
||||
ConnectionStatisticsAgent::acquire();
|
||||
connectionStatisticsAgentSetStart();
|
||||
|
||||
_peer->setNonBlocking(true);
|
||||
|
||||
if (!_peer->handshake()) {
|
||||
_closedSend = true;
|
||||
_closedReceive = true;
|
||||
if (!skipInit) {
|
||||
_peer->setNonBlocking(true);
|
||||
if (!_peer->handshake()) {
|
||||
_closedSend = true;
|
||||
_closedReceive = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,16 +191,20 @@ void SocketTask::completedWriteBuffer() {
|
|||
_writeBuffer = nullptr;
|
||||
|
||||
if (_writeBufferStatistics != nullptr) {
|
||||
#ifdef DEBUG_STATISTICS
|
||||
LOG_TOPIC(TRACE, Logger::REQUESTS)
|
||||
<< "SocketTask::addWriteBuffer - Statistics release: "
|
||||
<< _writeBufferStatistics->to_string();
|
||||
#endif
|
||||
_writeBufferStatistics->_writeEnd = TRI_StatisticsTime();
|
||||
TRI_ReleaseRequestStatistics(_writeBufferStatistics);
|
||||
_writeBufferStatistics = nullptr;
|
||||
} else {
|
||||
#ifdef DEBUG_STATISTICS
|
||||
LOG_TOPIC(TRACE, Logger::REQUESTS) << "SocketTask::addWriteBuffer - "
|
||||
"Statistics release: nullptr - "
|
||||
"nothing to realease";
|
||||
#endif
|
||||
}
|
||||
|
||||
if (_writeBuffers.empty()) {
|
||||
|
@ -263,6 +269,10 @@ void SocketTask::closeStream() {
|
|||
// -----------------------------------------------------------------------------
|
||||
// --SECTION-- private methods
|
||||
// -----------------------------------------------------------------------------
|
||||
void SocketTask::addToReadBuffer(char const* data, std::size_t len) {
|
||||
LOG_TOPIC(DEBUG, Logger::COMMUNICATION) << std::string(data, len);
|
||||
_readBuffer.appendText(data, len);
|
||||
}
|
||||
|
||||
void SocketTask::resetKeepAlive() {
|
||||
if (_useKeepAliveTimeout) {
|
||||
|
@ -308,7 +318,15 @@ bool SocketTask::reserveMemory() {
|
|||
|
||||
bool SocketTask::trySyncRead() {
|
||||
boost::system::error_code err;
|
||||
|
||||
if (_abandoned) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!_peer) {
|
||||
LOG_TOPIC(DEBUG, Logger::COMMUNICATION) << "SocketTask::trySyncRead "
|
||||
<< "- peer disappeared ";
|
||||
}
|
||||
|
||||
if (0 == _peer->available(err)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -321,9 +339,9 @@ bool SocketTask::trySyncRead() {
|
|||
}
|
||||
|
||||
size_t bytesRead = 0;
|
||||
|
||||
bytesRead = _peer->read(
|
||||
boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE), err);
|
||||
|
||||
bytesRead =
|
||||
_peer->read(boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE), err);
|
||||
|
||||
if (0 == bytesRead) {
|
||||
return false; // should not happen
|
||||
|
@ -347,6 +365,10 @@ bool SocketTask::trySyncRead() {
|
|||
|
||||
void SocketTask::asyncReadSome() {
|
||||
try {
|
||||
if (_abandoned) {
|
||||
return;
|
||||
}
|
||||
|
||||
JobGuard guard(_loop);
|
||||
guard.busy();
|
||||
|
||||
|
@ -355,8 +377,7 @@ void SocketTask::asyncReadSome() {
|
|||
|
||||
while (++n <= MAX_DIRECT_TRIES) {
|
||||
if (!reserveMemory()) {
|
||||
LOG_TOPIC(TRACE, Logger::COMMUNICATION)
|
||||
<< "failed to reserve memory";
|
||||
LOG_TOPIC(TRACE, Logger::COMMUNICATION) << "failed to reserve memory";
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -371,6 +392,9 @@ void SocketTask::asyncReadSome() {
|
|||
}
|
||||
|
||||
while (processRead()) {
|
||||
if (_abandoned) {
|
||||
return;
|
||||
}
|
||||
if (_closeRequested) {
|
||||
break;
|
||||
}
|
||||
|
@ -406,7 +430,7 @@ void SocketTask::asyncReadSome() {
|
|||
|
||||
auto self = shared_from_this();
|
||||
auto handler = [self, this](const boost::system::error_code& ec,
|
||||
std::size_t transferred) {
|
||||
std::size_t transferred) {
|
||||
if (ec) {
|
||||
LOG_TOPIC(DEBUG, Logger::COMMUNICATION)
|
||||
<< "SocketTask::asyncReadSome (async_read_some) - read on stream "
|
||||
|
@ -429,14 +453,18 @@ void SocketTask::asyncReadSome() {
|
|||
<< "close requested, closing receive stream";
|
||||
|
||||
closeReceiveStream();
|
||||
} else if (_abandoned) {
|
||||
return;
|
||||
} else {
|
||||
asyncReadSome();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
_peer->asyncRead(
|
||||
boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE), handler);
|
||||
if (!_abandoned && _peer) {
|
||||
_peer->asyncRead(boost::asio::buffer(_readBuffer.end(), READ_BLOCK_SIZE),
|
||||
handler);
|
||||
}
|
||||
}
|
||||
|
||||
void SocketTask::closeReceiveStream() {
|
||||
|
|
|
@ -29,14 +29,15 @@
|
|||
|
||||
#include <boost/asio/ssl.hpp>
|
||||
|
||||
#include "Basics/asio-helper.h"
|
||||
#include "Basics/StringBuffer.h"
|
||||
#include "Basics/asio-helper.h"
|
||||
#include "Scheduler/Socket.h"
|
||||
#include "Statistics/StatisticsAgent.h"
|
||||
|
||||
namespace arangodb {
|
||||
namespace rest {
|
||||
class SocketTask : virtual public Task, public ConnectionStatisticsAgent {
|
||||
friend class HttpCommTask;
|
||||
explicit SocketTask(SocketTask const&) = delete;
|
||||
SocketTask& operator=(SocketTask const&) = delete;
|
||||
|
||||
|
@ -45,16 +46,31 @@ class SocketTask : virtual public Task, public ConnectionStatisticsAgent {
|
|||
|
||||
public:
|
||||
SocketTask(EventLoop, std::unique_ptr<Socket>, ConnectionInfo&&,
|
||||
double keepAliveTimeout);
|
||||
double keepAliveTimeout, bool skipInit);
|
||||
|
||||
virtual ~SocketTask();
|
||||
|
||||
std::unique_ptr<Socket> releasePeer() {
|
||||
_abandoned = true;
|
||||
return std::move(_peer);
|
||||
}
|
||||
|
||||
ConnectionInfo&& releaseConnectionInfo() {
|
||||
_abandoned = true;
|
||||
return std::move(_connectionInfo);
|
||||
}
|
||||
|
||||
public:
|
||||
void start();
|
||||
|
||||
protected:
|
||||
virtual bool processRead() = 0;
|
||||
|
||||
// This function is used during the protocol switch from http
|
||||
// to VelocyStream. This way we no not require additional
|
||||
// constructor arguments. It should not be used otherwise.
|
||||
void addToReadBuffer(char const* data, std::size_t len);
|
||||
|
||||
protected:
|
||||
void addWriteBuffer(std::unique_ptr<basics::StringBuffer> buffer) {
|
||||
addWriteBuffer(std::move(buffer), (RequestStatisticsAgent*)nullptr);
|
||||
|
@ -89,6 +105,7 @@ class SocketTask : virtual public Task, public ConnectionStatisticsAgent {
|
|||
boost::asio::deadline_timer _keepAliveTimer;
|
||||
|
||||
bool _closeRequested = false;
|
||||
std::atomic_bool _abandoned;
|
||||
|
||||
private:
|
||||
bool reserveMemory();
|
||||
|
|
|
@ -751,7 +751,7 @@ VPackSlice HttpRequest::payload(VPackOptions const* options) {
|
|||
} else {
|
||||
return VPackSlice(_vpackBuilder->slice());
|
||||
}
|
||||
}
|
||||
}
|
||||
return VPackSlice::noneSlice(); // no body
|
||||
} else /*VPACK*/ {
|
||||
VPackValidator validator;
|
||||
|
|
|
@ -113,7 +113,7 @@ class HttpRequest final : public GeneralRequest {
|
|||
// key that do not get special treatment end um in the _headers map.
|
||||
void setHeader(char const* key, size_t keyLength, char const* value,
|
||||
size_t valueLength);
|
||||
|
||||
|
||||
void setHeader(std::string const& key, std::string const& value) {
|
||||
setHeader(key.c_str(), key.length(), value.c_str(), value.length());
|
||||
}
|
||||
|
|
|
@ -220,7 +220,7 @@ int Communicator::work_once() {
|
|||
void Communicator::wait() {
|
||||
static int const MAX_WAIT_MSECS = 1000; // wait max. 1 seconds
|
||||
|
||||
int numFds; // not used here
|
||||
int numFds; // not used here
|
||||
int res = curl_multi_wait(_curl, &_wakeup, 1, MAX_WAIT_MSECS, &numFds);
|
||||
if (res != CURLM_OK) {
|
||||
throw std::runtime_error(
|
||||
|
@ -378,7 +378,8 @@ void Communicator::handleResult(CURL* handle, CURLcode rc) {
|
|||
std::string prefix("Communicator(" + std::to_string(rip->_ticketId) +
|
||||
") // ");
|
||||
LOG_TOPIC(TRACE, Logger::REQUESTS)
|
||||
<< prefix << "Curl rc is : " << rc << " after " << Logger::FIXED(TRI_microtime() - rip->_startTime) << " s";
|
||||
<< prefix << "Curl rc is : " << rc << " after "
|
||||
<< Logger::FIXED(TRI_microtime() - rip->_startTime) << " s";
|
||||
if (strlen(rip->_errorBuffer) != 0) {
|
||||
LOG_TOPIC(TRACE, Logger::REQUESTS)
|
||||
<< prefix << "Curl error details: " << rip->_errorBuffer;
|
||||
|
|
Loading…
Reference in New Issue