1
0
Fork 0

moved create of CommTask into ListenTask

This commit is contained in:
Frank Celler 2016-07-23 20:04:38 +02:00
parent a92add0a9f
commit 96e69b5b92
11 changed files with 132 additions and 213 deletions

View File

@ -58,8 +58,7 @@ GeneralCommTask::GeneralCommTask(GeneralServer* server, TRI_socket_t socket,
_writeBuffers(),
_writeBuffersStats(),
_isChunked(false),
_requestPending(false),
_setupDone(false) {
_requestPending(false) {
LOG(TRACE) << "connection established, client "
<< TRI_get_fd_or_handle_of_socket(socket) << ", server ip "
<< _connectionInfo.serverAddress << ", server port "
@ -122,22 +121,10 @@ void GeneralCommTask::handleSimpleError(
}
}
bool GeneralCommTask::setup(Scheduler* scheduler, EventLoop loop) {
bool ok = SocketTask::setup(scheduler, loop);
if (!ok) return false;
_scheduler = scheduler;
_loop = loop;
setupDone();
return true;
}
bool GeneralCommTask::handleEvent(EventToken token, EventType events) {
bool result = SocketTask::handleEvent(token, events);
if (_clientClosed) _scheduler->destroyTask(this);
return result;
}
void GeneralCommTask::handleTimeout() {
_clientClosed = true;
_server->handleCommunicationClosed(this);
}
void GeneralCommTask::handleTimeout() { _clientClosed = true; }

View File

@ -27,6 +27,8 @@
#include "Scheduler/SocketTask.h"
#include <openssl/ssl.h>
#include "Basics/Mutex.h"
#include "Basics/StringBuffer.h"
#include "Basics/WorkItem.h"
@ -55,9 +57,6 @@ class GeneralCommTask : public SocketTask, public RequestStatisticsAgent {
void handleSimpleError(GeneralResponse::ResponseCode, int code,
std::string const& errorMessage);
// task set up complete
void setupDone() { _setupDone.store(true, std::memory_order_relaxed); }
protected:
virtual ~GeneralCommTask();
@ -68,8 +67,6 @@ class GeneralCommTask : public SocketTask, public RequestStatisticsAgent {
virtual bool handleEvent(EventToken token,
EventType events) override; // called by TODO
virtual bool setup(Scheduler* scheduler,
EventLoop loop) override; // called by
void cleanup() override final { SocketTask::cleanup(); }
@ -91,12 +88,11 @@ class GeneralCommTask : public SocketTask, public RequestStatisticsAgent {
bool _startThread;
std::deque<basics::StringBuffer*> _writeBuffers;
std::deque<TRI_request_statistics_t*>
_writeBuffersStats; // statistics buffers
bool _isChunked; // true if within a chunked response
bool _requestPending; // true if request is complete but not handled
std::atomic<bool> _setupDone; // task ready
}; // Commontask
} // rest
} // arango
_writeBuffersStats; // statistics buffers
bool _isChunked; // true if within a chunked response
bool _requestPending; // true if request is complete but not handled
};
}
}
#endif

View File

@ -25,6 +25,10 @@
#include "GeneralListenTask.h"
#include "GeneralServer/GeneralServer.h"
#include "GeneralServer/GeneralServerFeature.h"
#include "Scheduler/Scheduler.h"
#include "Scheduler/SchedulerFeature.h"
#include "Ssl/SslServerFeature.h"
using namespace arangodb;
using namespace arangodb::rest;
@ -38,9 +42,45 @@ GeneralListenTask::GeneralListenTask(GeneralServer* server, Endpoint* endpoint,
: Task("GeneralListenTask"),
ListenTask(endpoint),
_server(server),
_connectionType(connectionType) {}
_connectionType(connectionType) {
_keepAliveTimeout = GeneralServerFeature::keepAliveTimeout();
bool GeneralListenTask::handleConnected(TRI_socket_t s, ConnectionInfo&& info) {
_server->handleConnected(s, std::move(info), _connectionType);
SslServerFeature* ssl =
application_features::ApplicationServer::getFeature<SslServerFeature>(
"SslServer");
if (ssl != nullptr) {
_sslContext = ssl->sslContext();
}
_verificationMode = GeneralServerFeature::verificationMode();
_verificationCallback = GeneralServerFeature::verificationCallback();
}
bool GeneralListenTask::handleConnected(TRI_socket_t socket,
ConnectionInfo&& info) {
GeneralCommTask* commTask;
switch (_connectionType) {
case ConnectionType::VPPS:
commTask =
new HttpCommTask(_server, socket, std::move(info), _keepAliveTimeout);
break;
case ConnectionType::VPP:
commTask =
new HttpCommTask(_server, socket, std::move(info), _keepAliveTimeout);
break;
case ConnectionType::HTTPS:
commTask = new HttpsCommTask(_server, socket, std::move(info),
_keepAliveTimeout, _sslContext,
_verificationMode, _verificationCallback);
break;
case ConnectionType::HTTP:
commTask =
new HttpCommTask(_server, socket, std::move(info), _keepAliveTimeout);
break;
}
SchedulerFeature::SCHEDULER->registerTask(commTask);
return true;
}

View File

@ -26,6 +26,9 @@
#define ARANGOD_HTTP_SERVER_HTTP_LISTEN_TASK_H 1
#include "Scheduler/ListenTask.h"
#include <openssl/ssl.h>
#include "GeneralServer/GeneralDefinitions.h"
namespace arangodb {
@ -34,19 +37,11 @@ class Endpoint;
namespace rest {
class GeneralServer;
////////////////////////////////////////////////////////////////////////////////
/// @brief task used to establish connections
////////////////////////////////////////////////////////////////////////////////
class GeneralListenTask : public ListenTask {
GeneralListenTask(GeneralListenTask const&) = delete;
GeneralListenTask& operator=(GeneralListenTask const&) = delete;
public:
//////////////////////////////////////////////////////////////////////////////
/// @brief listen to given port
//////////////////////////////////////////////////////////////////////////////
GeneralListenTask(GeneralServer* server, Endpoint* endpoint,
ConnectionType connectionType);
@ -54,12 +49,13 @@ class GeneralListenTask : public ListenTask {
bool handleConnected(TRI_socket_t s, ConnectionInfo&& info) override;
private:
//////////////////////////////////////////////////////////////////////////////
/// @brief underlying general server
//////////////////////////////////////////////////////////////////////////////
GeneralServer* _server;
ConnectionType _connectionType;
double _keepAliveTimeout = 300.0;
SSL_CTX* _sslContext = nullptr;
int _verificationMode = SSL_VERIFY_NONE;
int (*_verificationCallback)(int, X509_STORE_CTX*) = nullptr
;
};
}
}

View File

@ -66,18 +66,12 @@ int GeneralServer::sendChunk(uint64_t taskId, std::string const& data) {
////////////////////////////////////////////////////////////////////////////////
GeneralServer::GeneralServer(
double keepAliveTimeout, bool allowMethodOverride,
std::vector<std::string> const& accessControlAllowOrigins, SSL_CTX* ctx)
bool allowMethodOverride,
std::vector<std::string> const& accessControlAllowOrigins)
: _listenTasks(),
_endpointList(nullptr),
_commTasks(),
_keepAliveTimeout(keepAliveTimeout),
_allowMethodOverride(allowMethodOverride),
_accessControlAllowOrigins(accessControlAllowOrigins),
_ctx(ctx),
_verificationMode(SSL_VERIFY_NONE),
_verificationCallback(nullptr),
_sslAllowed(ctx != nullptr) {}
_accessControlAllowOrigins(accessControlAllowOrigins) {}
////////////////////////////////////////////////////////////////////////////////
/// @brief destructs a general server
@ -85,27 +79,6 @@ GeneralServer::GeneralServer(
GeneralServer::~GeneralServer() { stopListening(); }
////////////////////////////////////////////////////////////////////////////////
/// @brief generates a suitable communication task
////////////////////////////////////////////////////////////////////////////////
GeneralCommTask* GeneralServer::createCommTask(TRI_socket_t s,
ConnectionInfo&& info,
ConnectionType conntype) {
switch (conntype) {
case ConnectionType::VPPS:
return new HttpCommTask(this, s, std::move(info), _keepAliveTimeout);
case ConnectionType::VPP:
return new HttpCommTask(this, s, std::move(info), _keepAliveTimeout);
case ConnectionType::HTTPS:
// check _ctx and friends? REVIEW
return new HttpsCommTask(this, s, std::move(info), _keepAliveTimeout,
_ctx, _verificationMode, _verificationCallback);
default:
return new HttpCommTask(this, s, std::move(info), _keepAliveTimeout);
}
}
////////////////////////////////////////////////////////////////////////////////
/// @brief add the endpoint list
////////////////////////////////////////////////////////////////////////////////
@ -138,7 +111,7 @@ void GeneralServer::startListening() {
}
////////////////////////////////////////////////////////////////////////////////
/// @brief stops listening
/// @brief removes all listen and comm tasks
////////////////////////////////////////////////////////////////////////////////
void GeneralServer::stopListening() {
@ -149,69 +122,6 @@ void GeneralServer::stopListening() {
_listenTasks.clear();
}
////////////////////////////////////////////////////////////////////////////////
/// @brief removes all listen and comm tasks
////////////////////////////////////////////////////////////////////////////////
void GeneralServer::stop() {
while (true) {
GeneralCommTask* task = nullptr;
{
MUTEX_LOCKER(mutexLocker, _commTasksLock);
if (_commTasks.empty()) {
break;
}
task = *_commTasks.begin();
_commTasks.erase(task);
}
SchedulerFeature::SCHEDULER->destroyTask(task);
}
}
////////////////////////////////////////////////////////////////////////////////
/// @brief handles connection request
////////////////////////////////////////////////////////////////////////////////
void GeneralServer::handleConnected(TRI_socket_t s, ConnectionInfo&& info,
ConnectionType connectionType) {
GeneralCommTask* task = createCommTask(s, std::move(info), connectionType);
try {
MUTEX_LOCKER(mutexLocker, _commTasksLock);
_commTasks.emplace(task);
} catch (...) {
// destroy the task to prevent a leak
deleteTask(task);
throw;
}
// registers the task and get the number of the scheduler thread
ssize_t n;
SchedulerFeature::SCHEDULER->registerTask(task, &n);
}
////////////////////////////////////////////////////////////////////////////////
/// @brief handles a connection close
////////////////////////////////////////////////////////////////////////////////
void GeneralServer::handleCommunicationClosed(GeneralCommTask* task) {
MUTEX_LOCKER(mutexLocker, _commTasksLock);
_commTasks.erase(task);
}
////////////////////////////////////////////////////////////////////////////////
/// @brief handles a connection failure
////////////////////////////////////////////////////////////////////////////////
void GeneralServer::handleCommunicationFailure(GeneralCommTask* task) {
MUTEX_LOCKER(mutexLocker, _commTasksLock);
_commTasks.erase(task);
}
////////////////////////////////////////////////////////////////////////////////
/// @brief create a job for asynchronous execution (using the dispatcher)
////////////////////////////////////////////////////////////////////////////////
@ -297,20 +207,12 @@ bool GeneralServer::openEndpoint(Endpoint* endpoint) {
if (endpoint->transport() == Endpoint::TransportType::HTTP) {
if (endpoint->encryption() == Endpoint::EncryptionType::SSL) {
if (!_sslAllowed) { // we should not end up here
LOG(FATAL) << "no ssl context";
FATAL_ERROR_EXIT();
}
connectionType = ConnectionType::HTTPS;
} else {
connectionType = ConnectionType::HTTP;
}
} else {
if (endpoint->encryption() == Endpoint::EncryptionType::SSL) {
if (!_sslAllowed) { // we should not end up here
LOG(FATAL) << "no ssl context";
FATAL_ERROR_EXIT();
}
connectionType = ConnectionType::VPPS;
} else {
connectionType = ConnectionType::VPP;

View File

@ -26,14 +26,14 @@
#ifndef ARANGOD_HTTP_SERVER_HTTP_SERVER_H
#define ARANGOD_HTTP_SERVER_HTTP_SERVER_H 1
#include "GeneralServer/GeneralDefinitions.h"
#include "Scheduler/TaskManager.h"
#include "Basics/Mutex.h"
#include "Endpoint/ConnectionInfo.h"
#include "GeneralServer/RestHandler.h"
#include "GeneralServer/GeneralDefinitions.h"
#include "GeneralServer/HttpCommTask.h"
#include "GeneralServer/HttpsCommTask.h"
#include <openssl/ssl.h>
#include "GeneralServer/RestHandler.h"
namespace arangodb {
class EndpointList;
@ -57,9 +57,8 @@ class GeneralServer : protected TaskManager {
static int sendChunk(uint64_t, std::string const&);
public:
GeneralServer(double keepAliveTimeout, bool allowMethodOverride,
std::vector<std::string> const& accessControlAllowOrigins,
SSL_CTX* ctx = nullptr);
GeneralServer(bool allowMethodOverride,
std::vector<std::string> const& accessControlAllowOrigins);
virtual ~GeneralServer();
public:
@ -69,15 +68,6 @@ class GeneralServer : protected TaskManager {
// check, if we allow a method override
bool allowMethodOverride() { return _allowMethodOverride; }
// generates a suitable communication task
virtual GeneralCommTask* createCommTask(
TRI_socket_t, ConnectionInfo&&, ConnectionType = ConnectionType::HTTP);
void setVerificationMode(int mode) { _verificationMode = mode; }
void setVerificationCallback(int (*func)(int, X509_STORE_CTX*)) {
_verificationCallback = func;
}
public:
// list of trusted origin urls for CORS
std::vector<std::string> const& trustedOrigins() const {
@ -93,18 +83,6 @@ class GeneralServer : protected TaskManager {
// stops listining
void stopListening();
// removes all listen and comm tasks
void stop();
// handles connection request
void handleConnected(TRI_socket_t s, ConnectionInfo&& info, ConnectionType);
// handles a connection close
void handleCommunicationClosed(GeneralCommTask*);
// handles a connection failure
void handleCommunicationFailure(GeneralCommTask*);
// creates a job for asynchronous execution
bool handleRequestAsync(GeneralCommTask*,
arangodb::WorkItem::uptr<RestHandler>&,
@ -138,15 +116,6 @@ class GeneralServer : protected TaskManager {
// defined ports and addresses
const EndpointList* _endpointList;
// mutex for comm tasks
arangodb::Mutex _commTasksLock;
// active comm tasks
std::unordered_set<GeneralCommTask*> _commTasks;
// keep-alive timeout
double _keepAliveTimeout;
// allow to override the method
bool _allowMethodOverride;
@ -154,10 +123,6 @@ class GeneralServer : protected TaskManager {
std::vector<std::string> const _accessControlAllowOrigins;
private:
SSL_CTX* _ctx;
int _verificationMode;
int (*_verificationCallback)(int, X509_STORE_CTX*);
bool _sslAllowed;
};
}
}

View File

@ -87,7 +87,6 @@ AuthInfo GeneralServerFeature::AUTH_INFO;
GeneralServerFeature::GeneralServerFeature(
application_features::ApplicationServer* server)
: ApplicationFeature(server, "GeneralServer"),
_keepAliveTimeout(300.0),
_allowMethodOverride(false),
_authentication(true),
_authenticationUnixSockets(true),
@ -322,7 +321,7 @@ void GeneralServerFeature::stop() {
}
for (auto& server : _servers) {
server->stop();
server->stopListening();
}
}
@ -345,7 +344,6 @@ void GeneralServerFeature::buildServers() {
auto const& endpointList = endpoint->endpointList();
// check if endpointList contains ssl featured server
SSL_CTX* sslContext = nullptr;
if (endpointList.hasSsl()) {
SslServerFeature* ssl =
application_features::ApplicationServer::getFeature<SslServerFeature>(
@ -356,12 +354,11 @@ void GeneralServerFeature::buildServers() {
"please use the '--ssl.keyfile' option";
FATAL_ERROR_EXIT();
}
sslContext = ssl->sslContext();
}
GeneralServer* server =
new GeneralServer(_keepAliveTimeout, _allowMethodOverride,
_accessControlAllowOrigins, sslContext);
new GeneralServer(_allowMethodOverride,
_accessControlAllowOrigins);
server->setEndpointList(&endpointList);
_servers.push_back(server);

View File

@ -25,6 +25,8 @@
#include "ApplicationFeatures/ApplicationFeature.h"
#include <openssl/ssl.h>
#include "Actions/RestActionHandler.h"
#include "VocBase/AuthInfo.h"
@ -39,12 +41,30 @@ class RestServerThread;
class GeneralServerFeature final
: public application_features::ApplicationFeature {
public:
typedef int (*verification_callback_fptr)(int, X509_STORE_CTX*);
public:
static rest::RestHandlerFactory* HANDLER_FACTORY;
static rest::AsyncJobManager* JOB_MANAGER;
static AuthInfo AUTH_INFO;
public:
static double keepAliveTimeout() {
return GENERAL_SERVER != nullptr ? GENERAL_SERVER->_keepAliveTimeout
: 300.0;
};
static int verificationMode() {
return GENERAL_SERVER != nullptr ? GENERAL_SERVER->_verificationMode
: SSL_VERIFY_NONE;
};
static verification_callback_fptr verificationCallback() {
return GENERAL_SERVER != nullptr ? GENERAL_SERVER->_verificationCallback
: nullptr;
};
static bool authenticationEnabled() {
return GENERAL_SERVER != nullptr && GENERAL_SERVER->authentication();
}
@ -57,6 +77,7 @@ class GeneralServerFeature final
if (GENERAL_SERVER == nullptr) {
return std::vector<std::string>();
}
return GENERAL_SERVER->trustedProxies();
}
@ -64,6 +85,7 @@ class GeneralServerFeature final
if (GENERAL_SERVER == nullptr) {
return std::string();
}
return GENERAL_SERVER->jwtSecret();
}
@ -82,8 +104,14 @@ class GeneralServerFeature final
void stop() override final;
void unprepare() override final;
public:
void setVerificationMode(int mode) { _verificationMode = mode; }
void setVerificationCallback(int (*func)(int, X509_STORE_CTX*)) {
_verificationCallback = func;
}
private:
double _keepAliveTimeout;
double _keepAliveTimeout = 300.0;
bool _allowMethodOverride;
bool _authentication;
bool _authenticationUnixSockets;
@ -94,6 +122,8 @@ class GeneralServerFeature final
std::vector<std::string> _accessControlAllowOrigins;
std::string _jwtSecret;
int _verificationMode;
verification_callback_fptr _verificationCallback;
public:
bool authentication() const { return _authentication; }

View File

@ -844,7 +844,6 @@ void HttpCommTask::signalTask(TaskData* data) {
bool HttpCommTask::handleRead() {
bool res = true;
if (!_setupDone.load(std::memory_order_relaxed)) return res;
if (!_closeRequested) {
res = fillReadBuffer();
@ -862,10 +861,8 @@ bool HttpCommTask::handleRead() {
if (_clientClosed) {
res = false;
_server->handleCommunicationClosed(this);
} else if (!res) {
_clientClosed = true;
_server->handleCommunicationFailure(this);
}
return res;
@ -886,7 +883,6 @@ void HttpCommTask::completedWriteBuffer() {
if (!_clientClosed && _closeRequested && !hasWriteBuffer() &&
_writeBuffers.empty() && !_isChunked) {
_clientClosed = true;
_server->handleCommunicationClosed(this);
}
}

View File

@ -111,7 +111,6 @@ bool HttpsCommTask::handleEvent(EventToken token, EventType revents) {
// status is somehow invalid. we got here even though no accept was ever
// successful
_clientClosed = true;
_server->handleCommunicationFailure(this);
_scheduler->destroyTask(this);
}

View File

@ -24,9 +24,9 @@
#include "SocketTask.h"
#include "Logger/Logger.h"
#include "Basics/StringBuffer.h"
#include "Basics/socket-utils.h"
#include "Logger/Logger.h"
#include "Scheduler/Scheduler.h"
#include <errno.h>
@ -122,17 +122,20 @@ bool SocketTask::fillReadBuffer() {
return fillReadBuffer();
}
// condition is required like this because g++ 6 will complain about
// condition is required like this because g++ 6 will complain about
// if (myerrno != EWOULDBLOCK && myerrno != EAGAIN)
// having two identical branches (because EWOULDBLOCK == EAGAIN on Linux).
// however, posix states that there may be systems where EWOULDBLOCK != EAGAIN...
// however, posix states that there may be systems where EWOULDBLOCK !=
// EAGAIN...
if (myerrno != EWOULDBLOCK && (EWOULDBLOCK == EAGAIN || myerrno != EAGAIN)) {
LOG(DEBUG) << "read from socket failed with " << myerrno << ": " << strerror(myerrno);
LOG(DEBUG) << "read from socket failed with " << myerrno << ": "
<< strerror(myerrno);
return false;
}
TRI_ASSERT(myerrno == EWOULDBLOCK || (EWOULDBLOCK != EAGAIN && myerrno == EAGAIN));
TRI_ASSERT(myerrno == EWOULDBLOCK ||
(EWOULDBLOCK != EAGAIN && myerrno == EAGAIN));
// from man(2) read:
// The file descriptor fd refers to a socket and has been marked
@ -141,7 +144,8 @@ bool SocketTask::fillReadBuffer() {
// either error to be returned for this case, and does not require these
// constants to have the same value,
// so a portable application should check for both possibilities.
LOG(TRACE) << "read would block with " << myerrno << ": " << strerror(myerrno);
LOG(TRACE) << "read would block with " << myerrno << ": "
<< strerror(myerrno);
return true;
}
@ -172,13 +176,16 @@ bool SocketTask::handleWrite() {
return handleWrite();
}
if (myerrno != EWOULDBLOCK && (EAGAIN == EWOULDBLOCK || myerrno != EAGAIN)) {
LOG(DEBUG) << "writing to socket failed with " << myerrno << ": " << strerror(myerrno);
if (myerrno != EWOULDBLOCK &&
(EAGAIN == EWOULDBLOCK || myerrno != EAGAIN)) {
LOG(DEBUG) << "writing to socket failed with " << myerrno << ": "
<< strerror(myerrno);
return false;
}
TRI_ASSERT(myerrno == EWOULDBLOCK || (EWOULDBLOCK != EAGAIN && myerrno == EAGAIN));
TRI_ASSERT(myerrno == EWOULDBLOCK ||
(EWOULDBLOCK != EAGAIN && myerrno == EAGAIN));
nr = 0;
}
@ -269,7 +276,8 @@ bool SocketTask::setup(Scheduler* scheduler, EventLoop loop) {
LOG(TRACE) << "attempting to convert socket handle to socket descriptor";
if (!TRI_isvalidsocket(_commSocket)) {
LOG(ERR) << "In SocketTask::setup could not convert socket handle to socket descriptor -- invalid socket handle";
LOG(ERR) << "In SocketTask::setup could not convert socket handle to "
"socket descriptor -- invalid socket handle";
return false;
}
@ -282,12 +290,15 @@ bool SocketTask::setup(Scheduler* scheduler, EventLoop loop) {
int res = (int)_commSocket.fileHandle;
if (res == -1) {
LOG(ERR) << "In SocketTask::setup could not convert socket handle to socket descriptor -- _open_osfhandle(...) failed";
LOG(ERR) << "In SocketTask::setup could not convert socket handle to "
"socket descriptor -- _open_osfhandle(...) failed";
res = TRI_CLOSE_SOCKET(_commSocket);
if (res != 0) {
res = WSAGetLastError();
LOG(ERR) << "In SocketTask::setup closesocket(...) failed with error code: " << res;
LOG(ERR)
<< "In SocketTask::setup closesocket(...) failed with error code: "
<< res;
}
TRI_invalidatesocket(&_commSocket);
@ -301,11 +312,6 @@ bool SocketTask::setup(Scheduler* scheduler, EventLoop loop) {
_scheduler = scheduler;
_loop = loop;
_readWatcher = _scheduler->installSocketEvent(loop, EVENT_SOCKET_READ, this,
_commSocket);
_writeWatcher = _scheduler->installSocketEvent(loop, EVENT_SOCKET_WRITE, this,
_commSocket);
// install timer for keep-alive timeout with some high default value
_keepAliveWatcher = _scheduler->installTimerEvent(loop, this, 60.0);
@ -314,6 +320,11 @@ bool SocketTask::setup(Scheduler* scheduler, EventLoop loop) {
_tid = Thread::currentThreadId();
_writeWatcher = _scheduler->installSocketEvent(loop, EVENT_SOCKET_WRITE, this,
_commSocket);
_readWatcher = _scheduler->installSocketEvent(loop, EVENT_SOCKET_READ, this,
_commSocket);
return true;
}