mirror of https://gitee.com/bigwinds/arangodb
460 lines
14 KiB
C++
460 lines
14 KiB
C++
////////////////////////////////////////////////////////////////////////////////
|
|
/// DISCLAIMER
|
|
///
|
|
/// Copyright 2014-2016 ArangoDB GmbH, Cologne, Germany
|
|
/// Copyright 2004-2014 triAGENS GmbH, Cologne, Germany
|
|
///
|
|
/// Licensed under the Apache License, Version 2.0 (the "License");
|
|
/// you may not use this file except in compliance with the License.
|
|
/// You may obtain a copy of the License at
|
|
///
|
|
/// http://www.apache.org/licenses/LICENSE-2.0
|
|
///
|
|
/// Unless required by applicable law or agreed to in writing, software
|
|
/// distributed under the License is distributed on an "AS IS" BASIS,
|
|
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
/// See the License for the specific language governing permissions and
|
|
/// limitations under the License.
|
|
///
|
|
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
|
|
///
|
|
/// @author Jan Steemann
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#include <cstring>
|
|
#include "Endpoint.h"
|
|
|
|
#include "Basics/Exceptions.h"
|
|
#include "Basics/StringUtils.h"
|
|
#include "Basics/socket-utils.h"
|
|
#include "Endpoint/EndpointIpV4.h"
|
|
#include "Endpoint/EndpointIpV6.h"
|
|
#include "Endpoint/EndpointSrv.h"
|
|
#include "Logger/Logger.h"
|
|
|
|
#if ARANGODB_HAVE_DOMAIN_SOCKETS
|
|
#include "Endpoint/EndpointUnixDomain.h"
|
|
#endif
|
|
|
|
#include <algorithm>
|
|
|
|
using namespace arangodb;
|
|
using namespace arangodb::basics;
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief construct an endpoint from its already-parsed properties
///
/// All arguments are stored verbatim in the corresponding members. A freshly
/// constructed endpoint is flagged as not connected and its socket handle is
/// put into the invalid state via TRI_invalidatesocket().
////////////////////////////////////////////////////////////////////////////////

Endpoint::Endpoint(DomainType domainType, EndpointType type,
                   TransportType transport, EncryptionType encryption,
                   std::string const& specification, int listenBacklog)
    : _domainType(domainType),
      _type(type),
      _transport(transport),
      _encryption(encryption),
      _specification(specification),
      _listenBacklog(listenBacklog),
      _connected(false) {
  // no socket has been opened yet
  TRI_invalidatesocket(&_socket);
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
/// @brief translate an endpoint specification into a URI
///
/// The scheme prefix of the endpoint is rewritten as follows, everything
/// after the prefix is copied unchanged:
///   http+tcp://  -> http://
///   http+ssl://  -> https://
///   tcp://       -> http://
///   ssl://       -> https://
///   unix://      -> unix://   (i.e. returned as is)
///   http+unix:// -> unix://
///
/// @return the rewritten string, or an empty string if the endpoint starts
///         with none of the prefixes above
////////////////////////////////////////////////////////////////////////////////

std::string Endpoint::uriForm(std::string const& endpoint) {
  struct SchemeRewrite {
    char const* from;  // prefix to look for in the endpoint
    char const* to;    // prefix to emit instead
  };

  // checked in order; the first matching entry wins
  static SchemeRewrite const rewrites[] = {
      {"http+tcp://", "http://"},
      {"http+ssl://", "https://"},
      {"tcp://", "http://"},
      {"ssl://", "https://"},
      {"unix://", "unix://"},
      {"http+unix://", "unix://"},
  };

  for (auto const& rewrite : rewrites) {
    if (StringUtils::isPrefix(endpoint, rewrite.from)) {
      return rewrite.to + endpoint.substr(std::strlen(rewrite.from));
    }
  }

  // unknown scheme
  return std::string();
}
|
|
|
|
std::string Endpoint::unifiedForm(std::string const& specification) {
|
|
static std::string illegal;
|
|
|
|
if (specification.size() < 7) {
|
|
return illegal;
|
|
}
|
|
|
|
TransportType protocol = TransportType::HTTP;
|
|
|
|
std::string prefix("http+");
|
|
std::string const localName("localhost");
|
|
std::string const localIP("127.0.0.1");
|
|
|
|
std::string copy = specification;
|
|
StringUtils::trimInPlace(copy);
|
|
|
|
if (specification.back() == '/') {
|
|
// address ends with a slash => remove
|
|
copy.pop_back();
|
|
}
|
|
|
|
size_t pos = copy.find("://");
|
|
if (pos == std::string::npos) {
|
|
return illegal;
|
|
}
|
|
// lowercase schema for prefix-checks
|
|
std::string schema = StringUtils::tolower(copy.substr(0, pos + 3));
|
|
|
|
// read protocol from string
|
|
if (StringUtils::isPrefix(schema, "http+") ||
|
|
StringUtils::isPrefix(schema, "http@")) {
|
|
protocol = TransportType::HTTP;
|
|
prefix = "http+";
|
|
copy = copy.substr(5);
|
|
schema = schema.substr(5);
|
|
}
|
|
|
|
if (StringUtils::isPrefix(schema, "vst+")) {
|
|
protocol = TransportType::VST;
|
|
prefix = "vst+";
|
|
copy = copy.substr(4);
|
|
schema = schema.substr(4);
|
|
}
|
|
|
|
if (StringUtils::isPrefix(schema, "unix://")) {
|
|
#if ARANGODB_HAVE_DOMAIN_SOCKETS
|
|
return prefix + schema + copy.substr(7);
|
|
#else
|
|
// no unix socket for windows
|
|
return illegal;
|
|
#endif
|
|
}
|
|
|
|
if (StringUtils::isPrefix(schema, "srv://")) {
|
|
#ifndef _WIN32
|
|
return prefix + schema + copy.substr(6);
|
|
#else
|
|
return illegal;
|
|
#endif
|
|
}
|
|
|
|
// strip tcp:// or ssl://
|
|
if (StringUtils::isPrefix(schema, "ssl://")) {
|
|
prefix.append("ssl://");
|
|
} else if (StringUtils::isPrefix(schema, "tcp://")) {
|
|
prefix.append("tcp://");
|
|
} else {
|
|
return illegal;
|
|
}
|
|
copy = StringUtils::tolower(copy.substr(6, copy.length()));
|
|
|
|
// handle tcp or ssl
|
|
size_t found;
|
|
if (copy[0] == '[') {
|
|
// ipv6
|
|
found = copy.find("]:", 1);
|
|
if (found != std::string::npos && found > 2 && found + 2 < copy.size()) {
|
|
// hostname and port (e.g. [address]:port)
|
|
return prefix + copy;
|
|
}
|
|
|
|
found = copy.find("]", 1);
|
|
if (found != std::string::npos && found > 2 && found + 1 == copy.size()) {
|
|
// hostname only (e.g. [address])
|
|
if (protocol == TransportType::VST) {
|
|
return prefix + copy + ":" + StringUtils::itoa(EndpointIp::_defaultPortVst);
|
|
} else {
|
|
return prefix + copy + ":" + StringUtils::itoa(EndpointIp::_defaultPortHttp);
|
|
}
|
|
}
|
|
|
|
// invalid address specification
|
|
return illegal;
|
|
}
|
|
|
|
// Replace localhost with 127.0.0.1
|
|
found = copy.find(localName);
|
|
if (found != std::string::npos) {
|
|
copy.replace(found, localName.length(), localIP);
|
|
}
|
|
|
|
// ipv4
|
|
found = copy.find(':');
|
|
if (found != std::string::npos && found + 1 < copy.size()) {
|
|
// hostname and port
|
|
return prefix + copy;
|
|
}
|
|
|
|
// hostname only
|
|
if (protocol == TransportType::HTTP) {
|
|
return prefix + copy + ":" + StringUtils::itoa(EndpointIp::_defaultPortHttp);
|
|
} else {
|
|
return prefix + copy + ":" + StringUtils::itoa(EndpointIp::_defaultPortVst);
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief create a server endpoint object from a string value
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
Endpoint* Endpoint::serverFactory(std::string const& specification,
|
|
int listenBacklog, bool reuseAddress) {
|
|
return Endpoint::factory(EndpointType::SERVER, specification, listenBacklog, reuseAddress);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief create a client endpoint object from a string value
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
Endpoint* Endpoint::clientFactory(std::string const& specification) {
|
|
return Endpoint::factory(EndpointType::CLIENT, specification, 0, false);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief create an endpoint object from a string value
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
Endpoint* Endpoint::factory(const Endpoint::EndpointType type, std::string const& specification,
|
|
int listenBacklog, bool reuseAddress) {
|
|
if (specification.size() < 7) {
|
|
return nullptr;
|
|
}
|
|
|
|
if (listenBacklog > 0 && type == EndpointType::CLIENT) {
|
|
// backlog is only allowed for server endpoints
|
|
TRI_ASSERT(false);
|
|
}
|
|
|
|
if (listenBacklog == 0 && type == EndpointType::SERVER) {
|
|
// use some default value
|
|
listenBacklog = 10;
|
|
}
|
|
|
|
std::string copy = unifiedForm(specification);
|
|
TransportType protocol = TransportType::HTTP;
|
|
|
|
if (StringUtils::isPrefix(copy, "http+")) {
|
|
copy = copy.substr(5);
|
|
} else {
|
|
// invalid protocol
|
|
return nullptr;
|
|
}
|
|
|
|
EncryptionType encryption = EncryptionType::NONE;
|
|
|
|
if (StringUtils::isPrefix(copy, "unix://")) {
|
|
#if ARANGODB_HAVE_DOMAIN_SOCKETS
|
|
return new EndpointUnixDomain(type, listenBacklog, copy.substr(7));
|
|
#else
|
|
// no unix socket for windows
|
|
return nullptr;
|
|
#endif
|
|
}
|
|
|
|
if (StringUtils::isPrefix(copy, "srv://")) {
|
|
if (type != EndpointType::CLIENT) {
|
|
return nullptr;
|
|
}
|
|
|
|
#ifndef _WIN32
|
|
return new EndpointSrv(copy.substr(6));
|
|
#else
|
|
return nullptr;
|
|
#endif
|
|
}
|
|
|
|
if (StringUtils::isPrefix(copy, "ssl://")) {
|
|
encryption = EncryptionType::SSL;
|
|
} else if (!StringUtils::isPrefix(copy, "tcp://")) {
|
|
// invalid type
|
|
return nullptr;
|
|
}
|
|
|
|
// tcp or ssl
|
|
copy = copy.substr(6);
|
|
uint16_t defaultPort = EndpointIp::_defaultPortHttp;
|
|
size_t found;
|
|
|
|
if (copy[0] == '[') {
|
|
found = copy.find("]:", 1);
|
|
|
|
// hostname and port (e.g. [address]:port)
|
|
if (found != std::string::npos && found > 2 && found + 2 < copy.size()) {
|
|
int64_t value = StringUtils::int64(copy.substr(found + 2));
|
|
// check port over-/underrun
|
|
if (value < (std::numeric_limits<uint16_t>::min)() ||
|
|
value > (std::numeric_limits<uint16_t>::max)()) {
|
|
LOG_TOPIC("7ccf9", ERR, arangodb::Logger::FIXME)
|
|
<< "specified port number '" << value << "' is outside the allowed range";
|
|
return nullptr;
|
|
}
|
|
uint16_t port = static_cast<uint16_t>(value);
|
|
std::string host = copy.substr(1, found - 1);
|
|
|
|
return new EndpointIpV6(type, protocol, encryption, listenBacklog,
|
|
reuseAddress, host, port);
|
|
}
|
|
|
|
found = copy.find("]", 1);
|
|
|
|
// hostname only (e.g. [address])
|
|
if (found != std::string::npos && found > 2 && found + 1 == copy.size()) {
|
|
std::string host = copy.substr(1, found - 1);
|
|
|
|
return new EndpointIpV6(type, protocol, encryption, listenBacklog,
|
|
reuseAddress, host, defaultPort);
|
|
}
|
|
|
|
// invalid address specification
|
|
return nullptr;
|
|
}
|
|
|
|
// ipv4
|
|
found = copy.find(':');
|
|
|
|
// hostname and port
|
|
if (found != std::string::npos && found + 1 < copy.size()) {
|
|
int64_t value = StringUtils::int64(copy.substr(found + 1));
|
|
// check port over-/underrun
|
|
if (value < (std::numeric_limits<uint16_t>::min)() ||
|
|
value > (std::numeric_limits<uint16_t>::max)()) {
|
|
LOG_TOPIC("127ce", ERR, arangodb::Logger::FIXME)
|
|
<< "specified port number '" << value << "' is outside the allowed range";
|
|
return nullptr;
|
|
}
|
|
uint16_t port = static_cast<uint16_t>(value);
|
|
std::string host = copy.substr(0, found);
|
|
|
|
return new EndpointIpV4(type, protocol, encryption, listenBacklog,
|
|
reuseAddress, host, port);
|
|
}
|
|
|
|
// hostname only
|
|
return new EndpointIpV4(type, protocol, encryption, listenBacklog,
|
|
reuseAddress, copy, defaultPort);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief return the default endpoint (http/vstream)
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
std::string const Endpoint::defaultEndpoint(TransportType type) {
|
|
switch (type) {
|
|
case TransportType::HTTP:
|
|
return "http+tcp://" + std::string(EndpointIp::_defaultHost) + ":" +
|
|
StringUtils::itoa(EndpointIp::_defaultPortHttp);
|
|
|
|
case TransportType::VST:
|
|
return "vst+tcp://" + std::string(EndpointIp::_defaultHost) + ":" +
|
|
StringUtils::itoa(EndpointIp::_defaultPortVst);
|
|
|
|
default: {
|
|
THROW_ARANGO_EXCEPTION_MESSAGE(TRI_ERROR_INTERNAL,
|
|
"invalid transport type");
|
|
}
|
|
}
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief compare two endpoints
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
bool Endpoint::operator==(Endpoint const& that) const {
|
|
return specification() == that.specification();
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief set socket timeout
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
bool Endpoint::setTimeout(TRI_socket_t s, double timeout) {
|
|
return TRI_setsockopttimeout(s, timeout);
|
|
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
/// @brief set common socket flags
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
bool Endpoint::setSocketFlags(TRI_socket_t s) {
|
|
if (_encryption == EncryptionType::SSL && _type == EndpointType::CLIENT) {
|
|
// SSL client endpoints are not set to non-blocking
|
|
return true;
|
|
}
|
|
|
|
// set to non-blocking, executed for both client and server endpoints
|
|
bool ok = TRI_SetNonBlockingSocket(s);
|
|
|
|
if (!ok) {
|
|
LOG_TOPIC("572b6", ERR, arangodb::Logger::FIXME) << "cannot switch to non-blocking: " << errno
|
|
<< " (" << strerror(errno) << ")";
|
|
|
|
return false;
|
|
}
|
|
|
|
// set close-on-exec flag, executed for both client and server endpoints
|
|
ok = TRI_SetCloseOnExecSocket(s);
|
|
|
|
if (!ok) {
|
|
LOG_TOPIC("1ef8e", ERR, arangodb::Logger::FIXME)
|
|
<< "cannot set close-on-exit: " << errno << " (" << strerror(errno) << ")";
|
|
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& stream, arangodb::Endpoint::TransportType type) {
|
|
switch (type) {
|
|
case arangodb::Endpoint::TransportType::HTTP:
|
|
stream << "http";
|
|
break;
|
|
case arangodb::Endpoint::TransportType::VST:
|
|
stream << "vst";
|
|
break;
|
|
}
|
|
return stream;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& stream, arangodb::Endpoint::EndpointType type) {
|
|
switch (type) {
|
|
case arangodb::Endpoint::EndpointType::SERVER:
|
|
stream << "server";
|
|
break;
|
|
case arangodb::Endpoint::EndpointType::CLIENT:
|
|
stream << "client";
|
|
break;
|
|
}
|
|
return stream;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& stream, arangodb::Endpoint::EncryptionType type) {
|
|
switch (type) {
|
|
case arangodb::Endpoint::EncryptionType::NONE:
|
|
stream << "none";
|
|
break;
|
|
case arangodb::Endpoint::EncryptionType::SSL:
|
|
stream << "ssl";
|
|
break;
|
|
}
|
|
return stream;
|
|
}
|
|
|
|
std::ostream& operator<<(std::ostream& stream, arangodb::Endpoint::DomainType type) {
|
|
switch (type) {
|
|
case arangodb::Endpoint::DomainType::UNIX:
|
|
stream << "unix";
|
|
break;
|
|
case arangodb::Endpoint::DomainType::IPV4:
|
|
stream << "ipv4";
|
|
break;
|
|
case arangodb::Endpoint::DomainType::IPV6:
|
|
stream << "ipv6";
|
|
break;
|
|
case arangodb::Endpoint::DomainType::SRV:
|
|
stream << "srv";
|
|
break;
|
|
case arangodb::Endpoint::DomainType::UNKNOWN:
|
|
stream << "unknown";
|
|
break;
|
|
}
|
|
return stream;
|
|
}
|