//////////////////////////////////////////////////////////////////////////////// /// DISCLAIMER /// /// Copyright 2014-2019 ArangoDB GmbH, Cologne, Germany /// Copyright 2004-2014 triAGENS GmbH, Cologne, Germany /// /// Licensed under the Apache License, Version 2.0 (the "License"); /// you may not use this file except in compliance with the License. /// You may obtain a copy of the License at /// /// http://www.apache.org/licenses/LICENSE-2.0 /// /// Unless required by applicable law or agreed to in writing, software /// distributed under the License is distributed on an "AS IS" BASIS, /// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. /// See the License for the specific language governing permissions and /// limitations under the License. /// /// Copyright holder is ArangoDB GmbH, Cologne, Germany /// /// @author Jan Christoph Uhde /// @author Simon Grätzer //////////////////////////////////////////////////////////////////////////////// #include "VstCommTask.h" #include "Basics/HybridLogicalClock.h" #include "Basics/Result.h" #include "Basics/ScopeGuard.h" #include "Basics/VelocyPackHelper.h" #include "Basics/tryEmplaceHelper.h" #include "Cluster/ServerState.h" #include "GeneralServer/AuthenticationFeature.h" #include "GeneralServer/GeneralServer.h" #include "GeneralServer/GeneralServerFeature.h" #include "GeneralServer/RestHandler.h" #include "Logger/LogMacros.h" #include "Logger/Logger.h" #include "Logger/LoggerFeature.h" #include "Logger/LoggerStream.h" #include "Meta/conversion.h" #include "RestServer/ServerFeature.h" #include "Statistics/ConnectionStatistics.h" #include "Statistics/RequestStatistics.h" #include #include #include using namespace arangodb; using namespace arangodb::basics; using namespace arangodb::rest; template VstCommTask::VstCommTask(GeneralServer& server, ConnectionInfo info, std::unique_ptr> so, fuerte::vst::VSTVersion v) : GeneralCommTask(server, "VstCommTask", std::move(info), std::move(so)), _writing(false), _authorized(!this->_auth->isActive()), _authMethod(rest::AuthenticationMethod::NONE), _vstVersion(v) {} template VstCommTask::~VstCommTask() { ResponseItem* tmp = nullptr; while (_writeQueue.pop(tmp)) { delete tmp; } } /// @brief send simple response including response body template void VstCommTask::addSimpleResponse(rest::ResponseCode code, rest::ContentType respType, uint64_t messageId, velocypack::Buffer&& buffer) { auto resp = std::make_unique(code, messageId); TRI_ASSERT(respType == rest::ContentType::VPACK); // or not ? resp->setContentType(respType); try { if (!buffer.empty()) { resp->setPayload(std::move(buffer), true, VPackOptions::Defaults); } sendResponse(std::move(resp), this->stealStatistics(messageId)); } catch (...) { this->close(); } } template bool VstCommTask::readCallback(asio_ns::error_code ec) { using namespace fuerte; if (ec) { if (ec != asio_ns::error::misc_errors::eof) { LOG_TOPIC("495fe", INFO, Logger::REQUESTS) << "Error while reading from socket: '" << ec.message() << "'"; } this->close(); return false; } // Inspect the data we've received so far. auto recvBuffs = this->_protocol->buffer.data(); // no copy auto cursor = asio_ns::buffer_cast(recvBuffs); auto available = asio_ns::buffer_size(recvBuffs); // TODO technically buffer_cast is deprecated size_t parsedBytes = 0; while (true) { vst::Chunk chunk; vst::parser::ChunkState state = vst::parser::ChunkState::Invalid; if (_vstVersion == vst::VST1_1) { state = vst::parser::readChunkVST1_1(chunk, cursor, available); } else if (_vstVersion == vst::VST1_0) { state = vst::parser::readChunkVST1_0(chunk, cursor, available); } if (vst::parser::ChunkState::Incomplete == state) { break; } else if (vst::parser::ChunkState::Invalid == state) { // actually should never happen this->close(); return false; // stop read loop } // move cursors cursor += chunk.header.chunkLength(); available -= chunk.header.chunkLength(); parsedBytes += chunk.header.chunkLength(); // Process chunk if (!processChunk(chunk)) { this->close(); return false; // stop read loop } } // Remove consumed data from receive buffer. this->_protocol->buffer.consume(parsedBytes); return true; // continue read loop } // Process the given incoming chunk. template bool VstCommTask::processChunk(fuerte::vst::Chunk const& chunk) { if (chunk.body.size() > CommTask::MaximalBodySize) { LOG_TOPIC("695ef", WARN, Logger::REQUESTS) << "\"vst-request\"; chunk is too big for server, \"" << chunk.body.size() << "\", this=" << this; return false; // close connection } else if (chunk.body.size() == 0) { LOG_TOPIC("695ff", WARN, Logger::REQUESTS) << "\"vst-request\"; chunk was empty, this=" << this; return false; // close connection } if (chunk.header.isFirst()) { RequestStatistics* stat = this->acquireStatistics(chunk.header.messageID()); RequestStatistics::SET_READ_START(stat, TRI_microtime()); // single chunk optimization if (chunk.header.numberOfChunks() == 1) { VPackBuffer buffer; // TODO lease buffers ? buffer.append(reinterpret_cast(chunk.body.data()), chunk.body.size()); return processMessage(std::move(buffer), chunk.header.messageID()); } } // Find stored message for this chunk. auto it = _messages.try_emplace( chunk.header.messageID(), arangodb::lazyConstruct([&]{ return std::make_unique(); })).first; Message* msg = it->second.get(); // returns false if message gets too big if (!msg->addChunk(chunk)) { LOG_TOPIC("695fd", WARN, Logger::REQUESTS) << "\"vst-request\"; chunk contents have become larger than " << "allowed, this=" << this; return false; // close connection } if (!msg->assemble()) { return true; // wait for more chunks } //this->_proto->timer.cancel(); auto guard = scopeGuard([&] { _messages.erase(chunk.header.messageID()); }); return processMessage(std::move(msg->buffer), chunk.header.messageID()); } /// process a VST message template bool VstCommTask::processMessage(velocypack::Buffer buffer, uint64_t messageId) { using namespace fuerte; auto ptr = buffer.data(); auto len = buffer.byteSize(); // first part of the buffer contains the response buffer std::size_t headerLength = 0; MessageType mt = MessageType::Undefined; try { mt = vst::parser::validateAndExtractMessageType(ptr, len, headerLength); } catch(std::exception const& e) { LOG_TOPIC("6479a", ERR, Logger::REQUESTS) << "\"vst-request\"; invalid message: '" << e.what() << "'"; // error is handled below } RequestStatistics* stat = this->statistics(messageId); RequestStatistics::SET_READ_END(stat); RequestStatistics::ADD_RECEIVED_BYTES(stat, buffer.size()); // handle request types if (mt == MessageType::Authentication) { // auth handleAuthHeader(VPackSlice(buffer.data()), messageId); // Separate superuser traffic: // Note that currently, velocystream traffic will never come from // a forwarding, since we always forward with HTTP. if (_authMethod != AuthenticationMethod::NONE && _authorized && this->_authToken._username.empty()) { RequestStatistics::SET_SUPERUSER(stat); } } else if (mt == MessageType::Request) { // request VPackSlice header(buffer.data()); // the handler will take ownership of this pointer auto req = std::make_unique(this->_connectionInfo, std::move(buffer), /*payloadOffset*/headerLength, messageId); req->setAuthenticated(_authorized); req->setUser(this->_authToken._username); req->setAuthenticationMethod(_authMethod); if (_authorized && this->_auth->userManager() != nullptr) { // if we don't call checkAuthentication we need to refresh this->_auth->userManager()->refreshUser(this->_authToken._username); } // Separate superuser traffic: // Note that currently, velocystream traffic will never come from // a forwarding, since we always forward with HTTP. if (_authMethod != AuthenticationMethod::NONE && _authorized && this->_authToken._username.empty()) { RequestStatistics::SET_SUPERUSER(stat); } LOG_TOPIC("92fd6", DEBUG, Logger::REQUESTS) << "\"vst-request-begin\",\"" << (void*)this << "\",\"" << this->_connectionInfo.clientAddress << "\",\"" << VstRequest::translateMethod(req->requestType()) << "\",\"" << (req->databaseName().empty() ? "" : "/_db/" + req->databaseName()) << (Logger::logRequestParameters() ? req->fullUrl() : req->requestPath()) << "\""; CommTask::Flow cont = this->prepareExecution(*req.get()); if (cont == CommTask::Flow::Continue) { auto resp = std::make_unique(rest::ResponseCode::SERVER_ERROR, messageId); resp->setContentTypeRequested(req->contentTypeResponse()); this->executeRequest(std::move(req), std::move(resp)); } } else { // not supported on server LOG_TOPIC("b5073", ERR, Logger::REQUESTS) << "\"vst-request-header\",\"" << (void*)this << "/" << messageId << "\"" << " is unsupported"; addSimpleResponse(rest::ResponseCode::BAD, rest::ContentType::VPACK, messageId, VPackBuffer()); } return true; } template void VstCommTask::sendResponse(std::unique_ptr baseRes, RequestStatistics* stat) { using namespace fuerte; #ifdef ARANGODB_ENABLE_MAINTAINER_MODE VstResponse& response = dynamic_cast(*baseRes); #else VstResponse& response = static_cast(*baseRes); #endif this->finishExecution(*baseRes); auto resItem = std::make_unique(); response.writeMessageHeader(resItem->metadata); resItem->response = std::move(baseRes); RequestStatistics::SET_WRITE_START(stat); resItem->stat = stat; asio_ns::const_buffer payload; if (response.generateBody()) { payload = asio_ns::buffer(response.payload().data(), response.payload().size()); } vst::message::prepareForNetwork(_vstVersion, response.messageId(), resItem->metadata, payload, resItem->buffers); if (stat != nullptr) { LOG_TOPIC("cf80d", TRACE, Logger::REQUESTS) << "\"vst-request-statistics\",\"" << (void*)this << "\",\"" << static_cast(response.responseCode()) << "," << this->_connectionInfo.clientAddress << "\"," << stat->timingsCsv(); } double const totalTime = RequestStatistics::ELAPSED_SINCE_READ_START(stat); // and give some request information LOG_TOPIC("92fd7", DEBUG, Logger::REQUESTS) << "\"vst-request-end\",\"" << (void*)this << "/" << response.messageId() << "\",\"" << this->_connectionInfo.clientAddress << "\",\"" << static_cast(response.responseCode()) << "," << "\"," << Logger::FIXED(totalTime, 6); while (true) { if (_writeQueue.push(resItem.get())) { break; } cpu_relax(); } resItem.release(); bool expected = _writing.load(); if (false == expected) { if (_writing.compare_exchange_strong(expected, true)) { // we managed to start writing if constexpr (SocketType::Ssl == T) { this->_protocol->context.io_context.post([self = this->shared_from_this()]() { static_cast*>(self.get())->doWrite(); }); } else { doWrite(); } } } } template void VstCommTask::doWrite() { TRI_ASSERT(_writing.load() == true); while(true) { // loop instead of using recursion ResponseItem* tmp = nullptr; if (!_writeQueue.pop(tmp)) { // careful now, we need to consider that someone queues // a new request item _writing.store(false); if (_writeQueue.empty()) { break; // done, someone else may restart } bool expected = false; // may fail in a race if (_writing.compare_exchange_strong(expected, true)) { continue; // we re-start writing } TRI_ASSERT(expected == true); break; // someone else restarted writing } TRI_ASSERT(tmp != nullptr); std::unique_ptr item(tmp); auto& buffers = item->buffers; asio_ns::async_write(this->_protocol->socket, buffers, [self(CommTask::shared_from_this()), rsp(std::move(item))] (asio_ns::error_code ec, size_t transferred) { auto* thisPtr = static_cast*>(self.get()); RequestStatistics::SET_WRITE_END(rsp->stat); RequestStatistics::ADD_SENT_BYTES(rsp->stat, rsp->buffers[0].size() + rsp->buffers[1].size()); if (ec) { LOG_TOPIC("5c6b4", INFO, arangodb::Logger::REQUESTS) << "asio write error: '" << ec.message() << "'"; thisPtr->close(); } else { thisPtr->doWrite(); // write next one } if (rsp->stat != nullptr) { rsp->stat->release(); } }); break; // done } } template void VstCommTask::handleAuthHeader(VPackSlice header, uint64_t mId) { std::string authString; std::string user = ""; _authorized = false; _authMethod = AuthenticationMethod::NONE; std::string encryption = header.at(2).copyString(); if (encryption == "jwt") { // doing JWT authString = header.at(3).copyString(); _authMethod = AuthenticationMethod::JWT; } else if (encryption == "plain") { user = header.at(3).copyString(); std::string pass = header.at(4).copyString(); authString = basics::StringUtils::encodeBase64(user + ":" + pass); _authMethod = AuthenticationMethod::BASIC; } else { LOG_TOPIC("01f44", WARN, Logger::REQUESTS) << "Unknown VST encryption type"; } this->_authToken = this->_auth->tokenCache().checkAuthentication(_authMethod, authString); _authorized = this->_authToken.authenticated(); if (this->_authorized || !this->_auth->isActive()) { // simon: drivers expect a response for their auth request this->addErrorResponse(ResponseCode::OK, rest::ContentType::VPACK, mId, TRI_ERROR_NO_ERROR, "auth successful"); } else { this->_authToken = auth::TokenCache::Entry::Unauthenticated(); this->addErrorResponse(rest::ResponseCode::UNAUTHORIZED, rest::ContentType::VPACK, mId, TRI_ERROR_HTTP_UNAUTHORIZED); } } template std::unique_ptr VstCommTask::createResponse(rest::ResponseCode responseCode, uint64_t messageId) { return std::make_unique(responseCode, messageId); } /// @brief add chunk to this message /// @return false if the message size is too big template bool VstCommTask::Message::addChunk(fuerte::vst::Chunk const& chunk) { if (chunk.header.isFirst()) { // only the first chunk has the message length // and number of chunks (in VST/1.0) expectedChunks = chunk.header.numberOfChunks(); expectedMsgSize = chunk.header.messageLength(); chunks.reserve(expectedChunks); TRI_ASSERT(buffer.empty()); if (buffer.capacity() < chunk.header.messageLength()) { buffer.reserve(chunk.header.messageLength() - buffer.capacity()); } } // verify total message body size limit size_t newSize = buffer.size() + chunk.body.size(); if (newSize > CommTask::MaximalBodySize || (expectedMsgSize != 0 && expectedMsgSize < newSize)) { return false; // error } uint8_t const* begin = reinterpret_cast(chunk.body.data()); size_t offset = buffer.size(); buffer.append(begin, chunk.body.size()); // Add chunk to index list chunks.push_back(ChunkInfo{chunk.header.index(), offset, chunk.body.size()}); return true; } /// assemble message, if true result is in _buffer template bool VstCommTask::Message::assemble() { if (expectedChunks == 0 || chunks.size() < expectedChunks) { return false; } // fast-path: chunks received in-order bool reject = false; for (size_t i = 0; i < expectedChunks; i++) { if (chunks[i].index != i) { reject = true; break; } } if (!reject) { // fast-path, chunks are in order return true; } // We now have all chunks. Sort them by index. std::sort(chunks.begin(), chunks.end(), [](auto const& a, auto const& b) { return a.index < b.index; }); // Combine chunk content VPackBuffer cp(std::move(buffer)); buffer.clear(); for (ChunkInfo const& info : chunks) { buffer.append(cp.data() + info.offset, info.size); } return true; } template class arangodb::rest::VstCommTask; template class arangodb::rest::VstCommTask; #ifndef _WIN32 template class arangodb::rest::VstCommTask; #endif