mirror of https://gitee.com/bigwinds/arangodb
257 lines
5.8 KiB
C++
257 lines
5.8 KiB
C++
////////////////////////////////////////////////////////////////////////////////
/// DISCLAIMER
///
/// Copyright 2016 ArangoDB GmbH, Cologne, Germany
///
/// Licensed under the Apache License, Version 2.0 (the "License");
/// you may not use this file except in compliance with the License.
/// You may obtain a copy of the License at
///
///     http://www.apache.org/licenses/LICENSE-2.0
///
/// Unless required by applicable law or agreed to in writing, software
/// distributed under the License is distributed on an "AS IS" BASIS,
/// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
/// See the License for the specific language governing permissions and
/// limitations under the License.
///
/// Copyright holder is ArangoDB GmbH, Cologne, Germany
///
/// @author Dr. Frank Celler
/// @author Copyright 2016, ArangoDB GmbH, Cologne, Germany
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
#include "EndpointSrv.h"
|
|
|
|
#ifndef _WIN32
|
|
|
|
#define BIND_4_COMPAT 1 // LINUX
|
|
#define BIND_8_COMPAT 1 // MACOSX
|
|
|
|
#include <arpa/inet.h>
|
|
#include <arpa/nameser.h>
|
|
#include <netinet/in.h>
|
|
#include <resolv.h>
|
|
|
|
#include "Basics/StringUtils.h"
|
|
#include "Endpoint/EndpointIp.h"
|
|
#include "Logger/Logger.h"
|
|
|
|
using namespace arangodb;
|
|
using namespace arangodb::basics;
|
|
|
|
// Size in bytes of the buffer that receives a DNS reply: PACKETSZ when that
// is larger than 1024, otherwise 1024.
#if defined(PACKETSZ) && (PACKETSZ > 1024)
#define MAXPACKET PACKETSZ
#else
#define MAXPACKET 1024
#endif
|
|
|
|
union QueryBuffer {
|
|
::HEADER header;
|
|
unsigned char buffer[MAXPACKET];
|
|
};
|
|
|
|
/// One decoded DNS SRV record.
///
/// All numeric members are value-initialized so that a default-constructed
/// record never carries indeterminate values.
struct SrvRecord {
  /// priority field of the record; records are tried in ascending order
  int priority = 0;

  /// weight field of the record; used to order records of equal priority
  int weight = 0;

  /// port on which the service is offered
  int port = 0;

  /// host name of the target offering the service
  std::string name;
};
|
|
|
|
static std::vector<SrvRecord> srvRecords(std::string const& specification) {
|
|
res_init();
|
|
|
|
char const* dname = specification.c_str();
|
|
int nclass = ns_c_in;
|
|
int type = ns_t_srv;
|
|
|
|
QueryBuffer answer;
|
|
int anslen = sizeof(answer);
|
|
|
|
int n = res_search(dname, nclass, type, answer.buffer, anslen);
|
|
|
|
std::vector<SrvRecord> services;
|
|
|
|
if (n != -1) {
|
|
HEADER* hp = &answer.header;
|
|
|
|
int qdcount = ntohs(hp->qdcount);
|
|
int ancount = ntohs(hp->ancount);
|
|
|
|
unsigned char* msg = answer.buffer;
|
|
unsigned char* eom = msg + n;
|
|
unsigned char* cp = msg + sizeof(HEADER);
|
|
|
|
unsigned char hostbuf[256];
|
|
|
|
while (0 < qdcount-- && cp < eom) {
|
|
n = dn_expand(msg, eom, cp, (char*)hostbuf, 256);
|
|
|
|
if (n < 0) {
|
|
LOG_TOPIC(WARN, arangodb::Logger::FIXME) << "DNS record for '" << specification << "' is corrupt";
|
|
return {};
|
|
}
|
|
|
|
cp += n + QFIXEDSZ;
|
|
}
|
|
|
|
// loop through the answer buffer and extract SRV records
|
|
while (0 < ancount-- && cp < eom) {
|
|
n = dn_expand(msg, eom, cp, (char*)hostbuf, 256);
|
|
|
|
if (n < 0) {
|
|
LOG_TOPIC(WARN, arangodb::Logger::FIXME) << "DNS record for '" << specification << "' is corrupt";
|
|
return {};
|
|
}
|
|
|
|
cp += n;
|
|
|
|
int type;
|
|
GETSHORT(type, cp);
|
|
|
|
int nclass;
|
|
GETSHORT(nclass, cp);
|
|
|
|
int ttl;
|
|
GETLONG(ttl, cp);
|
|
|
|
int dlen;
|
|
GETSHORT(dlen, cp);
|
|
|
|
int priority;
|
|
GETSHORT(priority, cp);
|
|
|
|
int weight;
|
|
GETSHORT(weight, cp);
|
|
|
|
int port;
|
|
GETSHORT(port, cp);
|
|
|
|
n = dn_expand(msg, eom, cp, (char*)hostbuf, 256);
|
|
|
|
if (n < 0) {
|
|
LOG_TOPIC(WARN, arangodb::Logger::FIXME) << "DNS record for '" << specification << "' is corrupt";
|
|
break;
|
|
}
|
|
|
|
cp += n;
|
|
|
|
LOG_TOPIC(TRACE, arangodb::Logger::FIXME) << "DNS record for '" << specification << "': type " << type
|
|
<< ", class " << nclass << ", ttl " << ttl << ", len " << dlen
|
|
<< ", prio " << priority << ", weight " << weight << ", port "
|
|
<< port << ", host '" << hostbuf << "'";
|
|
|
|
if (type != T_SRV) {
|
|
continue;
|
|
}
|
|
|
|
SrvRecord srv;
|
|
|
|
srv.weight = weight;
|
|
srv.priority = priority;
|
|
srv.port = port;
|
|
srv.name = (char*)hostbuf;
|
|
|
|
services.push_back(srv);
|
|
}
|
|
} else {
|
|
LOG_TOPIC(WARN, arangodb::Logger::FIXME) << "DNS record for '" << specification << "' not found";
|
|
}
|
|
|
|
std::sort(services.begin(), services.end(),
|
|
[](SrvRecord const& lhs, SrvRecord const& rhs) {
|
|
if (lhs.priority != rhs.priority) {
|
|
return lhs.priority < rhs.priority;
|
|
}
|
|
|
|
return lhs.weight > rhs.weight;
|
|
});
|
|
|
|
return services;
|
|
}
|
|
|
|
#else
|
|
|
|
static std::vector<SrvRecord> srvRecords(std::string const& specification) {
|
|
return {};
|
|
}
|
|
|
|
#endif
|
|
|
|
/// Creates a client endpoint of domain type SRV that uses HTTP transport
/// without encryption. `specification` is handed to the base class and is
/// the name later looked up by connect().
EndpointSrv::EndpointSrv(std::string const& specification)
    : Endpoint(DomainType::SRV,
               EndpointType::CLIENT,
               TransportType::HTTP,
               EncryptionType::NONE,
               specification,
               0) {
}
|
|
|
|
/// Destructor; performs no explicit cleanup of its own.
EndpointSrv::~EndpointSrv() {
}
|
|
|
|
bool EndpointSrv::isConnected() const {
|
|
if (_endpoint != nullptr) {
|
|
return _endpoint->isConnected();
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Resolves the SRV records for this endpoint's specification and tries the
/// advertised targets one after another, in the order returned by
/// srvRecords(), until one of them connects.
///
/// @param connectTimeout  forwarded unchanged to the wrapped endpoint
/// @param requestTimeout  forwarded unchanged to the wrapped endpoint
/// @return the socket of the first target that connected, or an invalidated
///         socket if there were no records or no target could be reached
TRI_socket_t EndpointSrv::connect(double connectTimeout,
                                  double requestTimeout) {
  auto const services = srvRecords(_specification);

  for (auto const& service : services) {
    std::string const spec =
        "tcp://" + service.name + ":" + StringUtils::itoa(service.port);

    _endpoint.reset(Endpoint::clientFactory(spec));

    if (_endpoint == nullptr) {
      // the factory did not produce an endpoint for this target (the name
      // comes straight from DNS); try the next record instead of
      // dereferencing a null pointer
      continue;
    }

    TRI_socket_t res = _endpoint->connect(connectTimeout, requestTimeout);

    if (_endpoint->isConnected()) {
      return res;
    }
  }

  // nothing was reachable: hand back a socket that is marked invalid
  TRI_socket_t res;
  TRI_invalidatesocket(&res);
  return res;
}
|
|
|
|
void EndpointSrv::disconnect() {
|
|
if (_endpoint != nullptr) {
|
|
_endpoint->disconnect();
|
|
}
|
|
}
|
|
|
|
/// Always returns false; the socket argument is ignored.
bool EndpointSrv::initIncoming(TRI_socket_t) {
  return false;
}
|
|
|
|
int EndpointSrv::domain() const {
|
|
if (_endpoint != nullptr) {
|
|
return _endpoint->domain();
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
int EndpointSrv::port() const {
|
|
if (_endpoint != nullptr) {
|
|
return _endpoint->port();
|
|
}
|
|
|
|
return -1;
|
|
}
|
|
|
|
std::string EndpointSrv::host() const {
|
|
if (_endpoint != nullptr) {
|
|
return _endpoint->host();
|
|
}
|
|
|
|
return "";
|
|
}
|
|
|
|
std::string EndpointSrv::hostAndPort() const {
|
|
if (_endpoint != nullptr) {
|
|
return _endpoint->hostAndPort();
|
|
}
|
|
|
|
return "";
|
|
}
|