From 8d8044ba47b57564c90dd3f190bf3f11799f68eb Mon Sep 17 00:00:00 2001 From: Jan Steemann Date: Wed, 25 Jul 2012 17:47:54 +0200 Subject: [PATCH] added proper SSL support for arangosh and arangoimp --- arangod/RestServer/ArangoServer.cpp | 10 +- arangosh/V8Client/V8ClientConnection.cpp | 18 ++-- arangosh/V8Client/V8ClientConnection.h | 12 +-- arangosh/V8Client/arangoimp.cpp | 22 ++--- arangosh/V8Client/arangosh.cpp | 29 +++--- lib/GeneralServer/GeneralServer.h | 6 +- lib/HttpServer/HttpServer.h | 6 +- lib/HttpsServer/HttpsServer.h | 6 +- lib/Rest/Endpoint.cpp | 93 +++++++++++++------ lib/Rest/Endpoint.h | 44 +++++++-- lib/Rest/EndpointList.cpp | 31 +++++-- lib/Rest/EndpointList.h | 55 +++++++---- lib/Scheduler/ListenTask.cpp | 2 +- lib/SimpleHttpClient/ClientConnection.cpp | 26 ++---- .../GeneralClientConnection.cpp | 26 ++++++ .../GeneralClientConnection.h | 6 ++ lib/SimpleHttpClient/SslClientConnection.cpp | 31 ++----- lib/SimpleHttpClient/SslClientConnection.h | 5 - 18 files changed, 269 insertions(+), 159 deletions(-) diff --git a/arangod/RestServer/ArangoServer.cpp b/arangod/RestServer/ArangoServer.cpp index ac43f14327..fed53e6c67 100755 --- a/arangod/RestServer/ArangoServer.cpp +++ b/arangod/RestServer/ArangoServer.cpp @@ -613,7 +613,7 @@ int ArangoServer::startupServer () { assert(endpoint); - bool ok = _endpointList.addEndpoint(endpoint->getProtocol(), endpoint); + bool ok = _endpointList.addEndpoint(endpoint->getProtocol(), endpoint->getEncryption(), endpoint); if (! ok) { LOGGER_FATAL << "invalid endpoint '" << *i << "'"; cerr << "invalid endpoint '" << *i << "'\n"; @@ -635,8 +635,8 @@ int ArangoServer::startupServer () { httpOptions._contexts.insert("api"); httpOptions._contexts.insert("admin"); - // HTTP endpoints - if (_endpointList.count(Endpoint::PROTOCOL_HTTP) > 0) { + // unencrypted endpoints + if (_endpointList.count(Endpoint::PROTOCOL_HTTP, Endpoint::ENCRYPTION_NONE) > 0) { // create the http server _httpServer = _applicationHttpServer->buildServer(&_endpointList); @@ -651,8 +651,8 @@ int ArangoServer::startupServer () { } #ifdef TRI_OPENSSL_VERSION - // HTTPS endpoints - if (_endpointList.count(Endpoint::PROTOCOL_HTTPS) > 0) { + // SSL endpoints + if (_endpointList.count(Endpoint::PROTOCOL_HTTP, Endpoint::ENCRYPTION_SSL) > 0) { // create the https server _httpsServer = _applicationHttpsServer->buildServer(&_endpointList); diff --git a/arangosh/V8Client/V8ClientConnection.cpp b/arangosh/V8Client/V8ClientConnection.cpp index ea77f6e8e7..b96c09916b 100644 --- a/arangosh/V8Client/V8ClientConnection.cpp +++ b/arangosh/V8Client/V8ClientConnection.cpp @@ -41,9 +41,7 @@ #include #include "Basics/StringUtils.h" -#include "Rest/Endpoint.h" -#include "SimpleHttpClient/ClientConnection.h" -#include "SimpleHttpClient/SslClientConnection.h" +#include "SimpleHttpClient/GeneralClientConnection.h" #include "SimpleHttpClient/SimpleHttpClient.h" #include "SimpleHttpClient/SimpleHttpResult.h" #include "Variant/VariantArray.h" @@ -73,15 +71,20 @@ using namespace std; V8ClientConnection::V8ClientConnection (Endpoint* endpoint, double requestTimeout, - size_t retries, - double connectionTimeout, + double connectTimeout, + size_t numRetries, bool warn) - : _connection(new ClientConnection(endpoint, requestTimeout, connectionTimeout, retries)), + : _connection(0), _lastHttpReturnCode(0), _lastErrorMessage(""), _client(0), _httpResult(0) { - + + _connection = GeneralClientConnection::factory(endpoint, 3.0, 3.0, 3); + if (_connection == 0) { + throw "out of memory"; + } + _client = new SimpleHttpClient(_connection, requestTimeout, warn); // connect to server and get version number @@ -158,7 +161,6 @@ bool V8ClientConnection::isConnected () { return _connection->isConnected(); } - //////////////////////////////////////////////////////////////////////////////// /// @brief returns the version and build number of the arango server //////////////////////////////////////////////////////////////////////////////// diff --git a/arangosh/V8Client/V8ClientConnection.h b/arangosh/V8Client/V8ClientConnection.h index 9c4186e272..7c14421032 100644 --- a/arangosh/V8Client/V8ClientConnection.h +++ b/arangosh/V8Client/V8ClientConnection.h @@ -40,7 +40,6 @@ #define TRIAGENS_V8_CLIENT_CONNECTION_H 1 #include -#include #include @@ -54,6 +53,10 @@ namespace triagens { class SimpleHttpClient; class SimpleHttpResult; } + + namespace rest { + class Endpoint; + } } // ----------------------------------------------------------------------------- @@ -94,17 +97,12 @@ namespace triagens { //////////////////////////////////////////////////////////////////////////////// /// @brief constructor -/// -/// @param Endpoint endpoint Endpoint to connect to -/// @param double requestTimeout timeout in seconds for one request -/// @param size_t retries maximum number of request retries -/// @param double connTimeout timeout in seconds for the tcp connect //////////////////////////////////////////////////////////////////////////////// V8ClientConnection (triagens::rest::Endpoint*, double, + double, size_t, - double, bool); //////////////////////////////////////////////////////////////////////////////// diff --git a/arangosh/V8Client/arangoimp.cpp b/arangosh/V8Client/arangoimp.cpp index 8124cf8003..23ad4da56b 100644 --- a/arangosh/V8Client/arangoimp.cpp +++ b/arangosh/V8Client/arangoimp.cpp @@ -38,11 +38,12 @@ #include "BasicsC/init.h" #include "BasicsC/logging.h" #include "BasicsC/strings.h" +#include "ImportHelper.h" #include "Logger/Logger.h" +#include "Rest/Endpoint.h" +#include "Rest/Initialise.h" #include "SimpleHttpClient/SimpleHttpClient.h" #include "SimpleHttpClient/SimpleHttpResult.h" -#include "ImportHelper.h" -#include "Rest/Endpoint.h" #include "V8ClientConnection.h" using namespace std; @@ -65,8 +66,8 @@ using namespace triagens::v8client; //////////////////////////////////////////////////////////////////////////////// static int64_t DEFAULT_REQUEST_TIMEOUT = 300; -static size_t DEFAULT_RETRIES = 5; -static int64_t DEFAULT_CONNECTION_TIMEOUT = 5; +static size_t DEFAULT_RETRIES = 2; +static int64_t DEFAULT_CONNECTION_TIMEOUT = 3; //////////////////////////////////////////////////////////////////////////////// /// @brief endpoint to connect to @@ -196,6 +197,8 @@ static void ParseProgramOptions (int argc, char* argv[]) { int main (int argc, char* argv[]) { TRIAGENS_C_INITIALISE(argc, argv); + TRIAGENS_REST_INITIALISE(argc, argv); + TRI_InitialiseLogging(false); EndpointString = Endpoint::getDefaultEndpoint(); @@ -221,13 +224,8 @@ int main (int argc, char* argv[]) { } assert(_endpoint); - - clientConnection = new V8ClientConnection( - _endpoint, - (double) requestTimeout, - DEFAULT_RETRIES, - (double) connectTimeout, - true); + + clientConnection = new V8ClientConnection(_endpoint, (double) requestTimeout, (double) connectTimeout, DEFAULT_RETRIES, false); if (!clientConnection->isConnected()) { cerr << "Could not connect to endpoint " << _endpoint->getSpecification() << endl; @@ -320,6 +318,8 @@ int main (int argc, char* argv[]) { cerr << "error message : " << ih.getErrorMessage() << endl; } + TRIAGENS_REST_SHUTDOWN; + return EXIT_SUCCESS; } diff --git a/arangosh/V8Client/arangosh.cpp b/arangosh/V8Client/arangosh.cpp index 773efffc2c..e151b6f023 100644 --- a/arangosh/V8Client/arangosh.cpp +++ b/arangosh/V8Client/arangosh.cpp @@ -85,8 +85,8 @@ using namespace triagens::arango; //////////////////////////////////////////////////////////////////////////////// static int64_t const DEFAULT_REQUEST_TIMEOUT = 300; -static size_t const DEFAULT_RETRIES = 5; -static int64_t const DEFAULT_CONNECTION_TIMEOUT = 5; +static size_t const DEFAULT_RETRIES = 2; +static int64_t const DEFAULT_CONNECTION_TIMEOUT = 3; //////////////////////////////////////////////////////////////////////////////// /// @brief colors for output @@ -512,6 +512,18 @@ static void StopPager () } } +//////////////////////////////////////////////////////////////////////////////// +/// @brief return a new client connection instance +//////////////////////////////////////////////////////////////////////////////// + +static V8ClientConnection* createConnection () { + return new V8ClientConnection(_endpoint, + (double) RequestTimeout, + (double) ConnectTimeout, + DEFAULT_RETRIES, + false); +} + //////////////////////////////////////////////////////////////////////////////// /// @brief parses the program options //////////////////////////////////////////////////////////////////////////////// @@ -656,8 +668,6 @@ static v8::Handle wrapV8ClientConnection (V8ClientConnection* connec static v8::Handle ClientConnection_ConstructorCallback (v8::Arguments const& argv) { v8::HandleScope scope; - size_t retries = DEFAULT_RETRIES; - if (argv.Length() > 0 && argv[0]->IsString()) { string definition = TRI_ObjectToString(argv[0]); @@ -679,7 +689,7 @@ static v8::Handle ClientConnection_ConstructorCallback (v8::Arguments assert(_endpoint); - V8ClientConnection* connection = new V8ClientConnection(_endpoint, (double) RequestTimeout, retries, (double) ConnectTimeout, false); + V8ClientConnection* connection = createConnection(); if (connection->isConnected()) { cout << "Connected to ArangoDB '" << _endpoint->getSpecification() << "' Version " << connection->getVersion() << endl; @@ -1223,12 +1233,7 @@ int main (int argc, char* argv[]) { assert(_endpoint); - _clientConnection = new V8ClientConnection( - _endpoint, - (double) RequestTimeout, - DEFAULT_RETRIES, - (double) ConnectTimeout, - false); + _clientConnection = createConnection(); } // ............................................................................. @@ -1410,7 +1415,7 @@ int main (int argc, char* argv[]) { exit(EXIT_FAILURE); } } - + // ............................................................................. // run normal shell // ............................................................................. diff --git a/lib/GeneralServer/GeneralServer.h b/lib/GeneralServer/GeneralServer.h index 6d623a0aea..bcaee6b299 100644 --- a/lib/GeneralServer/GeneralServer.h +++ b/lib/GeneralServer/GeneralServer.h @@ -155,10 +155,10 @@ namespace triagens { public: //////////////////////////////////////////////////////////////////////////////// -/// @brief return the protocol to be used +/// @brief return the encryption to be used //////////////////////////////////////////////////////////////////////////////// - virtual Endpoint::Protocol getProtocol () = 0; + virtual Endpoint::Encryption getEncryption () const = 0; //////////////////////////////////////////////////////////////////////////////// /// @brief return the scheduler @@ -181,7 +181,7 @@ namespace triagens { //////////////////////////////////////////////////////////////////////////////// void startListening () { - EndpointList::ListType endpoints = _endpointList->getEndpoints(this->getProtocol()); + EndpointList::ListType endpoints = _endpointList->getEndpoints(Endpoint::PROTOCOL_HTTP, this->getEncryption()); for (EndpointList::ListType::const_iterator i = endpoints.begin(); i != endpoints.end(); ++i) { LOGGER_TRACE << "trying to bind to endpoint '" << (*i)->getSpecification() << "' for requests"; diff --git a/lib/HttpServer/HttpServer.h b/lib/HttpServer/HttpServer.h index 15ba2d6572..f32ba31581 100644 --- a/lib/HttpServer/HttpServer.h +++ b/lib/HttpServer/HttpServer.h @@ -97,11 +97,11 @@ namespace triagens { public: //////////////////////////////////////////////////////////////////////////////// -/// @brief return protocol to be used +/// @brief return encryption to be used //////////////////////////////////////////////////////////////////////////////// - virtual Endpoint::Protocol getProtocol () { - return Endpoint::PROTOCOL_HTTP; + virtual Endpoint::Encryption getEncryption () const { + return Endpoint::ENCRYPTION_NONE; } //////////////////////////////////////////////////////////////////////////////// diff --git a/lib/HttpsServer/HttpsServer.h b/lib/HttpsServer/HttpsServer.h index ea2cdf10f9..89fc313b38 100644 --- a/lib/HttpsServer/HttpsServer.h +++ b/lib/HttpsServer/HttpsServer.h @@ -153,11 +153,11 @@ namespace triagens { public: //////////////////////////////////////////////////////////////////////////////// -/// @brief return protocol to be used +/// @brief return encryption to be used //////////////////////////////////////////////////////////////////////////////// - virtual Endpoint::Protocol getProtocol () { - return Endpoint::PROTOCOL_HTTPS; + virtual Endpoint::Encryption getEncryption () const { + return Endpoint::ENCRYPTION_SSL; } //////////////////////////////////////////////////////////////////////////////// diff --git a/lib/Rest/Endpoint.cpp b/lib/Rest/Endpoint.cpp index 3b32b18580..5acc0c9c9a 100644 --- a/lib/Rest/Endpoint.cpp +++ b/lib/Rest/Endpoint.cpp @@ -80,12 +80,14 @@ const std::string EndpointIp::_defaultHost = "127.0.0.1"; Endpoint::Endpoint (const Endpoint::Type type, const Endpoint::DomainType domainType, const Endpoint::Protocol protocol, + const Endpoint::Encryption encryption, const string& specification) : _connected(false), _socket(0), _type(type), _domainType(domainType), _protocol(protocol), + _encryption(encryption), _specification(specification) { } @@ -141,7 +143,8 @@ Endpoint* Endpoint::factory (const Endpoint::Type type, copy = copy.substr(0, copy.size() - 1); } - Endpoint::Protocol protocol = PROTOCOL_UNKNOWN; + // default protocol is HTTP + Endpoint::Protocol protocol = PROTOCOL_HTTP; // read protocol from string size_t found = copy.find('@'); @@ -151,14 +154,7 @@ Endpoint* Endpoint::factory (const Endpoint::Type type, protocol = PROTOCOL_BINARY; copy = copy.substr(strlen("pb@")); } -#ifdef TRI_OPENSSL_VERSION - else if (protoString == "https") { - protocol = PROTOCOL_HTTPS; - copy = copy.substr(strlen("https@")); - } -#endif else if (protoString == "http") { - protocol = PROTOCOL_HTTP; copy = copy.substr(strlen("http@")); } else { @@ -166,22 +162,23 @@ Endpoint* Endpoint::factory (const Endpoint::Type type, return 0; } } - else { - // no protocol specified, use HTTP - protocol = PROTOCOL_HTTP; - } + Encryption encryption = ENCRYPTION_NONE; string domainType = StringUtils::tolower(copy.substr(0, 7)); if (StringUtils::isPrefix(domainType, "unix://")) { // unix socket return new EndpointUnix(type, protocol, specification, copy.substr(strlen("unix://"))); } + else if (StringUtils::isPrefix(domainType, "ssl://")) { + // ssl + encryption = ENCRYPTION_SSL; + } else if (! StringUtils::isPrefix(domainType, "tcp://")) { // invalid type return 0; } - // tcp/ip + // tcp/ip or ssl copy = copy.substr(strlen("tcp://"), copy.length()); if (copy[0] == '[') { @@ -191,14 +188,14 @@ Endpoint* Endpoint::factory (const Endpoint::Type type, // hostname and port (e.g. [address]:port) uint16_t port = (uint16_t) StringUtils::uint32(copy.substr(found + 2)); - return new EndpointIpV6(type, protocol, specification, copy.substr(1, found - 1), port); + return new EndpointIpV6(type, protocol, encryption, specification, copy.substr(1, found - 1), port); } found = copy.find("]", 1); if (found != string::npos && found + 1 == copy.size()) { // hostname only (e.g. [address]) - return new EndpointIpV6(type, protocol, specification, copy.substr(1, found - 1), EndpointIp::_defaultPort); + return new EndpointIpV6(type, protocol, encryption, specification, copy.substr(1, found - 1), EndpointIp::_defaultPort); } // invalid address specification @@ -212,11 +209,11 @@ Endpoint* Endpoint::factory (const Endpoint::Type type, // hostname and port uint16_t port = (uint16_t) StringUtils::uint32(copy.substr(found + 1)); - return new EndpointIpV4(type, protocol, specification, copy.substr(0, found), port); + return new EndpointIpV4(type, protocol, encryption, specification, copy.substr(0, found), port); } // hostname only - return new EndpointIpV4(type, protocol, specification, copy, EndpointIp::_defaultPort); + return new EndpointIpV4(type, protocol, encryption, specification, copy, EndpointIp::_defaultPort); } //////////////////////////////////////////////////////////////////////////////// @@ -235,12 +232,25 @@ const std::string Endpoint::getDefaultEndpoint () { return "tcp://" + EndpointIp::_defaultHost + ":" + StringUtils::itoa(EndpointIp::_defaultPort); } +//////////////////////////////////////////////////////////////////////////////// +/// @brief set socket timeout +//////////////////////////////////////////////////////////////////////////////// + +void Endpoint::setTimeout (socket_t s, double timeout) { + struct timeval tv; + tv.tv_sec = (uint64_t) timeout; + tv.tv_usec = ((uint64_t) (timeout * 1000000.0)) % 1000000; + + setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + setsockopt(s, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)); +} + //////////////////////////////////////////////////////////////////////////////// /// @brief set common socket flags //////////////////////////////////////////////////////////////////////////////// bool Endpoint::setSocketFlags (socket_t _socket) { - if (_protocol == PROTOCOL_HTTPS && _type == ENDPOINT_CLIENT) { + if (_encryption == ENCRYPTION_SSL && _type == ENDPOINT_CLIENT) { // SSL client endpoints are not set to non-blocking return true; } @@ -291,7 +301,7 @@ EndpointUnix::EndpointUnix (const Endpoint::Type type, const Endpoint::Protocol protocol, string const& specification, string const& path) : - Endpoint(type, ENDPOINT_UNIX, protocol, specification), + Endpoint(type, ENDPOINT_UNIX, protocol, ENCRYPTION_NONE, specification), _path(path) { } @@ -322,7 +332,7 @@ EndpointUnix::~EndpointUnix () { /// @brief connect the endpoint //////////////////////////////////////////////////////////////////////////////// -socket_t EndpointUnix::connect () { +socket_t EndpointUnix::connect (double connectTimeout, double requestTimeout) { assert(_socket == 0); assert(!_connected); @@ -387,7 +397,15 @@ socket_t EndpointUnix::connect () { } else if (_type == ENDPOINT_CLIENT) { // connect to endpoint, executed for client endpoints only - ::connect(listenSocket, (const struct sockaddr*) &address, SUN_LEN(&address)); + + // set timeout + setTimeout(listenSocket, connectTimeout); + + if (::connect(listenSocket, (const struct sockaddr*) &address, SUN_LEN(&address)) != 0) { + close(listenSocket); + + return 0; + } } if (!setSocketFlags(listenSocket)) { @@ -396,6 +414,10 @@ socket_t EndpointUnix::connect () { return 0; } + if (_type == ENDPOINT_CLIENT) { + setTimeout(listenSocket, requestTimeout); + } + _connected = true; _socket = listenSocket; @@ -456,10 +478,11 @@ bool EndpointUnix::initIncoming (socket_t incoming) { EndpointIp::EndpointIp (const Endpoint::Type type, const Endpoint::DomainType domainType, const Endpoint::Protocol protocol, + const Endpoint::Encryption encryption, string const& specification, string const& host, const uint16_t port) : - Endpoint(type, domainType, protocol, specification), _host(host), _port(port) { + Endpoint(type, domainType, protocol, encryption, specification), _host(host), _port(port) { assert(domainType == ENDPOINT_IPV4 || domainType == ENDPOINT_IPV6); } @@ -487,7 +510,7 @@ EndpointIp::~EndpointIp () { /// @{ //////////////////////////////////////////////////////////////////////////////// -socket_t EndpointIp::connectSocket (const struct addrinfo* aip) { +socket_t EndpointIp::connectSocket (const struct addrinfo* aip, double connectTimeout, double requestTimeout) { // set address and port char host[NI_MAXHOST], serv[NI_MAXSERV]; if (getnameinfo(aip->ai_addr, aip->ai_addrlen, @@ -534,7 +557,15 @@ socket_t EndpointIp::connectSocket (const struct addrinfo* aip) { } else if (_type == ENDPOINT_CLIENT) { // connect to endpoint, executed for client endpoints only - ::connect(listenSocket, (const struct sockaddr*) aip->ai_addr, aip->ai_addrlen); + + // set timeout + setTimeout(listenSocket, connectTimeout); + + if (::connect(listenSocket, (const struct sockaddr*) aip->ai_addr, aip->ai_addrlen) != 0) { + close(listenSocket); + + return 0; + } } if (!setSocketFlags(listenSocket)) { @@ -542,6 +573,10 @@ socket_t EndpointIp::connectSocket (const struct addrinfo* aip) { return 0; } + + if (_type == ENDPOINT_CLIENT) { + setTimeout(listenSocket, requestTimeout); + } _connected = true; _socket = listenSocket; @@ -566,7 +601,7 @@ socket_t EndpointIp::connectSocket (const struct addrinfo* aip) { /// @brief connect the endpoint //////////////////////////////////////////////////////////////////////////////// -socket_t EndpointIp::connect () { +socket_t EndpointIp::connect (double connectTimeout, double requestTimeout) { struct addrinfo* result = 0; struct addrinfo* aip; struct addrinfo hints; @@ -599,7 +634,7 @@ socket_t EndpointIp::connect () { // Try all returned addresses until one works for (aip = result; aip != NULL; aip = aip->ai_next) { // try to bind the address info pointer - listenSocket = connectSocket(aip); + listenSocket = connectSocket(aip, connectTimeout, requestTimeout); if (listenSocket != 0) { // OK break; @@ -667,10 +702,11 @@ bool EndpointIp::initIncoming (socket_t incoming) { EndpointIpV4::EndpointIpV4 (const Endpoint::Type type, const Endpoint::Protocol protocol, + const Endpoint::Encryption encryption, string const& specification, string const& host, const uint16_t port) : - EndpointIp(type, ENDPOINT_IPV4, protocol, specification, host, port) { + EndpointIp(type, ENDPOINT_IPV4, protocol, encryption, specification, host, port) { } //////////////////////////////////////////////////////////////////////////////// @@ -703,10 +739,11 @@ EndpointIpV4::~EndpointIpV4 () { EndpointIpV6::EndpointIpV6 (const Endpoint::Type type, const Endpoint::Protocol protocol, + const Endpoint::Encryption encryption, string const& specification, string const& host, const uint16_t port) : - EndpointIp(type, ENDPOINT_IPV6, protocol, specification, host, port) { + EndpointIp(type, ENDPOINT_IPV6, protocol, encryption, specification, host, port) { } //////////////////////////////////////////////////////////////////////////////// diff --git a/lib/Rest/Endpoint.h b/lib/Rest/Endpoint.h index 2837983dd8..6f834af1cd 100644 --- a/lib/Rest/Endpoint.h +++ b/lib/Rest/Endpoint.h @@ -95,12 +95,18 @@ namespace triagens { enum Protocol { PROTOCOL_UNKNOWN, PROTOCOL_BINARY, -#ifdef TRI_OPENSSL_VERSION - PROTOCOL_HTTPS, -#endif PROTOCOL_HTTP }; +//////////////////////////////////////////////////////////////////////////////// +/// @brief encryption used when talking to endpoint +//////////////////////////////////////////////////////////////////////////////// + + enum Encryption { + ENCRYPTION_NONE = 0, + ENCRYPTION_SSL + }; + //////////////////////////////////////////////////////////////////////////////// /// @} //////////////////////////////////////////////////////////////////////////////// @@ -123,6 +129,7 @@ namespace triagens { Endpoint (const Type, const DomainType, const Protocol, + const Encryption, const string&); public: @@ -183,7 +190,7 @@ namespace triagens { /// @brief connect the endpoint //////////////////////////////////////////////////////////////////////////////// - virtual int connect () = 0; + virtual int connect (double, double) = 0; //////////////////////////////////////////////////////////////////////////////// /// @brief disconnect the endpoint @@ -197,6 +204,12 @@ namespace triagens { virtual bool initIncoming (socket_t) = 0; +//////////////////////////////////////////////////////////////////////////////// +/// @brief set socket timeout +//////////////////////////////////////////////////////////////////////////////// + + virtual void setTimeout (socket_t, double); + //////////////////////////////////////////////////////////////////////////////// /// @brief initialise socket flags //////////////////////////////////////////////////////////////////////////////// @@ -227,6 +240,14 @@ namespace triagens { return _protocol; } +//////////////////////////////////////////////////////////////////////////////// +/// @brief get the encryption used +//////////////////////////////////////////////////////////////////////////////// + + Encryption getEncryption () const { + return _encryption; + } + //////////////////////////////////////////////////////////////////////////////// /// @brief get the original endpoint specification //////////////////////////////////////////////////////////////////////////////// @@ -304,6 +325,12 @@ namespace triagens { Protocol _protocol; +//////////////////////////////////////////////////////////////////////////////// +/// @brief encryption used +//////////////////////////////////////////////////////////////////////////////// + + Encryption _encryption; + //////////////////////////////////////////////////////////////////////////////// /// @brief original endpoint specification //////////////////////////////////////////////////////////////////////////////// @@ -369,7 +396,7 @@ namespace triagens { /// @brief connect the endpoint //////////////////////////////////////////////////////////////////////////////// - socket_t connect (); + socket_t connect (double, double); //////////////////////////////////////////////////////////////////////////////// /// @brief disconnect the endpoint @@ -468,6 +495,7 @@ namespace triagens { EndpointIp (const Type, const DomainType, const Protocol, + const Encryption, string const&, string const&, const uint16_t); @@ -527,7 +555,7 @@ namespace triagens { /// @brief connect the socket //////////////////////////////////////////////////////////////////////////////// - socket_t connectSocket (const struct addrinfo*); + socket_t connectSocket (const struct addrinfo*, double, double); //////////////////////////////////////////////////////////////////////////////// /// @} @@ -548,7 +576,7 @@ namespace triagens { /// @brief connect the endpoint //////////////////////////////////////////////////////////////////////////////// - socket_t connect (); + socket_t connect (double, double); //////////////////////////////////////////////////////////////////////////////// /// @brief disconnect the endpoint @@ -642,6 +670,7 @@ namespace triagens { EndpointIpV4 (const Type, const Protocol, + const Encryption, string const&, string const&, const uint16_t); @@ -704,6 +733,7 @@ namespace triagens { EndpointIpV6 (const Type, const Protocol, + const Encryption, string const&, string const&, const uint16_t); diff --git a/lib/Rest/EndpointList.cpp b/lib/Rest/EndpointList.cpp index f73377d3e5..3ea30431b0 100644 --- a/lib/Rest/EndpointList.cpp +++ b/lib/Rest/EndpointList.cpp @@ -57,7 +57,7 @@ EndpointList::EndpointList () : //////////////////////////////////////////////////////////////////////////////// EndpointList::~EndpointList () { - for (map::iterator i = _lists.begin(); i != _lists.end(); ++i) { + for (map::iterator i = _lists.begin(); i != _lists.end(); ++i) { for (ListType::iterator i2 = (*i).second.begin(); i2 != (*i).second.end(); ++i2) { delete *i2; } @@ -79,25 +79,38 @@ EndpointList::~EndpointList () { /// @{ //////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +/// @brief count the number of elements in a sub-list +//////////////////////////////////////////////////////////////////////////////// + +size_t EndpointList::count (const Endpoint::Protocol protocol, const Endpoint::Encryption encryption) const { + map::const_iterator i = _lists.find(getKey(protocol, encryption)); + + if (i == _lists.end()) { + return 0; + } + + return i->second.size(); +} + //////////////////////////////////////////////////////////////////////////////// /// @brief dump all endpoints used //////////////////////////////////////////////////////////////////////////////// -void EndpointList::dump () { - for (map::const_iterator i = _lists.begin(); i != _lists.end(); ++i) { +void EndpointList::dump () const { + for (map::const_iterator i = _lists.begin(); i != _lists.end(); ++i) { for (ListType::const_iterator i2 = (*i).second.begin(); i2 != (*i).second.end(); ++i2) { - LOGGER_INFO << "using endpoint '" << (*i2)->getSpecification() << "' for " << getName((*i).first) << " requests"; + LOGGER_INFO << "using endpoint '" << (*i2)->getSpecification() << "' for " << (*i).first << " requests"; } } } - //////////////////////////////////////////////////////////////////////////////// /// @brief return all endpoints for a specific protocol //////////////////////////////////////////////////////////////////////////////// -EndpointList::ListType EndpointList::getEndpoints (const Endpoint::Protocol protocol) const { +EndpointList::ListType EndpointList::getEndpoints (const Endpoint::Protocol protocol, const Endpoint::Encryption encryption) const { EndpointList::ListType result; - map::const_iterator i = _lists.find(protocol); + map::const_iterator i = _lists.find(getKey(protocol, encryption)); if (i != _lists.end()) { for (ListType::const_iterator i2 = i->second.begin(); i2 != i->second.end(); ++i2) { @@ -112,8 +125,8 @@ EndpointList::ListType EndpointList::getEndpoints (const Endpoint::Protocol prot /// @brief adds an endpoint for a specific protocol //////////////////////////////////////////////////////////////////////////////// -bool EndpointList::addEndpoint (const Endpoint::Protocol protocol, Endpoint* endpoint) { - _lists[protocol].insert(endpoint); +bool EndpointList::addEndpoint (const Endpoint::Protocol protocol, const Endpoint::Encryption encryption, Endpoint* endpoint) { + _lists[getKey(protocol, encryption)].insert(endpoint); return true; } diff --git a/lib/Rest/EndpointList.h b/lib/Rest/EndpointList.h index c24ea47a30..1c2de74ec7 100644 --- a/lib/Rest/EndpointList.h +++ b/lib/Rest/EndpointList.h @@ -93,6 +93,25 @@ namespace triagens { // --SECTION-- public methods // ----------------------------------------------------------------------------- +//////////////////////////////////////////////////////////////////////////////// +/// @addtogroup Rest +/// @{ +//////////////////////////////////////////////////////////////////////////////// + +// ----------------------------------------------------------------------------- +// --SECTION-- private methods +// ----------------------------------------------------------------------------- + + static const string getKey (const Endpoint::Protocol protocol, + const Endpoint::Encryption encryption) { + return string(getProtocolName(protocol) + " " + getEncryptionName(encryption)); + } + + +//////////////////////////////////////////////////////////////////////////////// +/// @} +//////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////// /// @addtogroup Rest /// @{ @@ -104,12 +123,10 @@ namespace triagens { /// @brief return a protocol name //////////////////////////////////////////////////////////////////////////////// - static const string getName (const Endpoint::Protocol protocol) { + static const string getProtocolName (const Endpoint::Protocol protocol) { switch (protocol) { case Endpoint::PROTOCOL_BINARY: return "binary"; - case Endpoint::PROTOCOL_HTTPS: - return "https"; case Endpoint::PROTOCOL_HTTP: return "http"; default: @@ -117,37 +134,43 @@ namespace triagens { } } +//////////////////////////////////////////////////////////////////////////////// +/// @brief return a encryption name +//////////////////////////////////////////////////////////////////////////////// + + static const string getEncryptionName (const Endpoint::Encryption encryption) { + switch (encryption) { + case Endpoint::ENCRYPTION_SSL: + return "ssl"; + case Endpoint::ENCRYPTION_NONE: + default: + return "tcp"; + } + } + //////////////////////////////////////////////////////////////////////////////// /// @brief count the number of elements in a sub-list //////////////////////////////////////////////////////////////////////////////// - size_t count (const Endpoint::Protocol protocol) { - map::const_iterator i = _lists.find(protocol); - - if (i == _lists.end()) { - return 0; - } - - return i->second.size(); - } + size_t count (const Endpoint::Protocol, const Endpoint::Encryption) const; //////////////////////////////////////////////////////////////////////////////// /// @brief dump all used endpoints //////////////////////////////////////////////////////////////////////////////// - void dump(); + void dump() const ; //////////////////////////////////////////////////////////////////////////////// /// @brief return all endpoints for a specific protocol //////////////////////////////////////////////////////////////////////////////// - ListType getEndpoints (const Endpoint::Protocol) const; + ListType getEndpoints (const Endpoint::Protocol, const Endpoint::Encryption) const; //////////////////////////////////////////////////////////////////////////////// /// @brief adds an endpoint for a specific protocol //////////////////////////////////////////////////////////////////////////////// - bool addEndpoint (const Endpoint::Protocol, Endpoint*); + bool addEndpoint (const Endpoint::Protocol, const Endpoint::Encryption, Endpoint*); //////////////////////////////////////////////////////////////////////////////// /// @} @@ -168,7 +191,7 @@ namespace triagens { /// @brief lists of endpoints //////////////////////////////////////////////////////////////////////////////// - map _lists; + map _lists; //////////////////////////////////////////////////////////////////////////////// /// @} diff --git a/lib/Scheduler/ListenTask.cpp b/lib/Scheduler/ListenTask.cpp index 22ec2ef36d..2f7cf28a82 100644 --- a/lib/Scheduler/ListenTask.cpp +++ b/lib/Scheduler/ListenTask.cpp @@ -193,7 +193,7 @@ bool ListenTask::handleEvent (EventToken token, EventType revents) { // ----------------------------------------------------------------------------- bool ListenTask::bindSocket () { - listenSocket = _endpoint->connect(); + listenSocket = _endpoint->connect(30, 300); // connect timeout in seconds if (listenSocket == 0) { return false; } diff --git a/lib/SimpleHttpClient/ClientConnection.cpp b/lib/SimpleHttpClient/ClientConnection.cpp index 537a9e3229..b88d26db39 100644 --- a/lib/SimpleHttpClient/ClientConnection.cpp +++ b/lib/SimpleHttpClient/ClientConnection.cpp @@ -116,34 +116,22 @@ bool ClientConnection::checkSocket () { //////////////////////////////////////////////////////////////////////////////// bool ClientConnection::connectSocket () { - _socket = _endpoint->connect(); + if (_endpoint->isConnected()) { + _endpoint->disconnect(); + } + _socket = _endpoint->connect(_connectTimeout, _requestTimeout); if (_socket == 0) { return false; } - struct timeval tv; - fd_set fdset; - - tv.tv_sec = (uint64_t) _connectTimeout; - tv.tv_usec = ((uint64_t) (_connectTimeout * 1000000.0)) % 1000000; - - FD_ZERO(&fdset); - FD_SET(_socket, &fdset); - - if (select(_socket + 1, NULL, &fdset, NULL, &tv) > 0) { - if (checkSocket()) { - return _endpoint->isConnected(); - } - - return false; + if (checkSocket()) { + return _endpoint->isConnected(); } - // connect timeout reached - disconnect(); - return false; } + //////////////////////////////////////////////////////////////////////////////// /// @brief disconnect //////////////////////////////////////////////////////////////////////////////// diff --git a/lib/SimpleHttpClient/GeneralClientConnection.cpp b/lib/SimpleHttpClient/GeneralClientConnection.cpp index 53eeba25ce..1a06320e02 100644 --- a/lib/SimpleHttpClient/GeneralClientConnection.cpp +++ b/lib/SimpleHttpClient/GeneralClientConnection.cpp @@ -26,6 +26,11 @@ //////////////////////////////////////////////////////////////////////////////// #include "GeneralClientConnection.h" +#include "SimpleHttpClient/ClientConnection.h" + +#ifdef TRI_OPENSSL_VERSION +#include "SimpleHttpClient/SslClientConnection.h" +#endif using namespace triagens::basics; using namespace triagens::rest; @@ -79,6 +84,27 @@ GeneralClientConnection::~GeneralClientConnection () { /// @{ //////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +/// @brief create a new connection from an endpoint +//////////////////////////////////////////////////////////////////////////////// + +GeneralClientConnection* GeneralClientConnection::factory (Endpoint* endpoint, + double requestTimeout, + double connectTimeout, + size_t numRetries) { + if (endpoint->getEncryption() == Endpoint::ENCRYPTION_NONE) { + return new ClientConnection(endpoint, requestTimeout, connectTimeout, numRetries); + } +#ifdef TRI_OPENSSL_VERSION + else if (endpoint->getEncryption() == Endpoint::ENCRYPTION_SSL) { + return new SslClientConnection(endpoint, requestTimeout, connectTimeout, numRetries); + } +#endif + else { + return 0; + } +} + //////////////////////////////////////////////////////////////////////////////// /// @brief connect //////////////////////////////////////////////////////////////////////////////// diff --git a/lib/SimpleHttpClient/GeneralClientConnection.h b/lib/SimpleHttpClient/GeneralClientConnection.h index 829db89bf6..96bc245bc1 100644 --- a/lib/SimpleHttpClient/GeneralClientConnection.h +++ b/lib/SimpleHttpClient/GeneralClientConnection.h @@ -113,6 +113,12 @@ namespace triagens { public: +//////////////////////////////////////////////////////////////////////////////// +/// @brief create a new connection from an endpoint +//////////////////////////////////////////////////////////////////////////////// + + static GeneralClientConnection* factory (triagens::rest::Endpoint*, double, double, size_t); + //////////////////////////////////////////////////////////////////////////////// /// @brief return the endpoint //////////////////////////////////////////////////////////////////////////////// diff --git a/lib/SimpleHttpClient/SslClientConnection.cpp b/lib/SimpleHttpClient/SslClientConnection.cpp index b95bb31516..0b83cf07da 100644 --- a/lib/SimpleHttpClient/SslClientConnection.cpp +++ b/lib/SimpleHttpClient/SslClientConnection.cpp @@ -102,14 +102,17 @@ SslClientConnection::~SslClientConnection () { //////////////////////////////////////////////////////////////////////////////// bool SslClientConnection::connectSocket () { - _socket = _endpoint->connect(); - if (_socket <= 0) { + _socket = _endpoint->connect(_connectTimeout, _requestTimeout); + + if (_socket <= 0 || _ctx == 0) { return false; } + _ssl = SSL_new(_ctx); if (_ssl == 0) { _endpoint->disconnect(); _socket = 0; + return false; } @@ -118,6 +121,7 @@ bool SslClientConnection::connectSocket () { SSL_free(_ssl); _ssl = 0; _socket = 0; + return false; } @@ -131,11 +135,7 @@ bool SslClientConnection::connectSocket () { _socket = 0; return false; } - /* - _writeBlockedOnRead = false; - _readBlockedOnWrite = false; - _readBlocked = false; -*/ + return true; } @@ -194,13 +194,7 @@ bool SslClientConnection::write (void* buffer, size_t length, size_t* bytesWritt if (_ssl == 0) { return false; } -/* - if (! (_canWrite || _writeBlockedOnRead && _canRead)) { - return false; - } - - _writeBlockedOnRead = false; -*/ + int written = SSL_write(_ssl, buffer, length); switch (SSL_get_error(_ssl, written)) { case SSL_ERROR_NONE: @@ -232,14 +226,7 @@ bool SslClientConnection::read (StringBuffer& stringBuffer) { if (_ssl == 0) { return false; } -/* - if (! ((_canRead && !_writeBlockedOnRead) || (_readBlockedOnWrite && _canWrite))) { - return false; - } - - _readBlocked = false; - _readBlockedOnWrite = false; -*/ + do { char buffer[READBUFFER_SIZE]; diff --git a/lib/SimpleHttpClient/SslClientConnection.h b/lib/SimpleHttpClient/SslClientConnection.h index e3d62a08a4..30bc8b3522 100644 --- a/lib/SimpleHttpClient/SslClientConnection.h +++ b/lib/SimpleHttpClient/SslClientConnection.h @@ -162,11 +162,6 @@ namespace triagens { SSL_CTX* _ctx; -/* - bool _writeBlockedOnRead; - bool _readBlockedOnWrite; - bool _readBlocked; -*/ //////////////////////////////////////////////////////////////////////////////// /// @} ////////////////////////////////////////////////////////////////////////////////