1
0
Fork 0

added --server.database option for client tools

This commit is contained in:
Jan Steemann 2013-09-02 13:02:53 +02:00
parent 41acd59698
commit 42b8dfca49
26 changed files with 485 additions and 160 deletions

View File

@ -1113,10 +1113,10 @@ static bool handleUserDatabase (TRI_doc_mptr_t const* document,
return true;
}
string dbName = doc.getStringValue("name", "");
string dbPath = doc.getStringValue("path", "");
string databaseName = doc.getStringValue("name", "");
string databasePath = doc.getStringValue("path", "");
int res = VocbaseManager::manager.canAddVocbase(dbName, dbPath, false);
int res = VocbaseManager::manager.canAddVocbase(databaseName, databasePath, false);
if (res != TRI_ERROR_NO_ERROR) {
LOGGER_ERROR("cannot load database: " << string(TRI_errno_string(res)));
@ -1146,15 +1146,15 @@ static bool handleUserDatabase (TRI_doc_mptr_t const* document,
systemDefaults->authenticateSystemOnly);
// open/load database
TRI_vocbase_t* userVocbase = TRI_OpenVocBase(dbPath.c_str(), dbName.c_str(), &defaults);
TRI_vocbase_t* userVocbase = TRI_OpenVocBase(databasePath.c_str(), databaseName.c_str(), &defaults);
if (userVocbase) {
VocbaseManager::manager.addUserVocbase(userVocbase);
LOGGER_INFO("loaded database '" << dbName << "' from '" << dbPath << "'");
LOGGER_INFO("loaded database '" << databaseName << "' from '" << databasePath << "'");
}
else {
LOGGER_ERROR("unable to load database '" << dbName << "' from '" << dbPath << "'");
LOGGER_ERROR("unable to load database '" << databaseName << "' from '" << databasePath << "'");
}
return true;
@ -1228,26 +1228,26 @@ static bool handleEnpoint (TRI_doc_mptr_t const* document,
TRI_json_t const* json = doc.getJson();
vector<std::string> dbNames;
vector<std::string> databaseNames;
if (JsonHelper::isList(json)) {
for (size_t i = 0; i < json->_value._objects._length; ++i) {
TRI_json_t const* e = (TRI_json_t const*) TRI_AtVector(&json->_value._objects, i);
if (JsonHelper::isString(e)) {
const string dbName = JsonHelper::getStringValue(e, "");
const string databaseName = JsonHelper::getStringValue(e, "");
if (! dbName.empty()) {
dbNames.push_back(dbName);
if (! databaseName.empty()) {
databaseNames.push_back(databaseName);
}
}
}
}
else {
dbNames.push_back(TRI_VOC_SYSTEM_DATABASE);
databaseNames.push_back(TRI_VOC_SYSTEM_DATABASE);
}
VocbaseManager::manager.addEndpoint(endpoint, dbNames);
VocbaseManager::manager.addEndpoint(endpoint, databaseNames);
return true;
}

View File

@ -96,15 +96,15 @@ void VocbaseContext::setRequestUserByName (string const& name) {
/// @brief checks the authentication
////////////////////////////////////////////////////////////////////////////////
bool VocbaseContext::authenticate () {
HttpResponse::HttpResponseCode VocbaseContext::authenticate () {
if (! _vocbase) {
// no vocbase known
return true;
return HttpResponse::NOT_FOUND;
}
if (! _vocbase->_requireAuthentication) {
// no authentication required at all
return true;
return HttpResponse::OK;
}
if (_vocbase->_authenticateSystemOnly) {
@ -114,10 +114,10 @@ bool VocbaseContext::authenticate () {
if (path != 0) {
// check if path starts with /_
if (*path != '/') {
return true;
return HttpResponse::OK;
}
if (*path != '\0' && *(path + 1) != '_') {
return true;
return HttpResponse::OK;
}
}
}

View File

@ -30,6 +30,7 @@
#include "VocBase/vocbase.h"
#include "Rest/HttpRequest.h"
#include "Rest/HttpResponse.h"
#include "Rest/RequestContext.h"
#include <map>
#include <string>
@ -127,7 +128,7 @@ namespace triagens {
/// @brief checks the authentication
////////////////////////////////////////////////////////////////////////////////
bool authenticate ();
rest::HttpResponse::HttpResponseCode authenticate ();
////////////////////////////////////////////////////////////////////////////////
/// @}

View File

@ -258,12 +258,12 @@ void VocbaseManager::initializeFoxx (TRI_vocbase_t* vocbase,
////////////////////////////////////////////////////////////////////////////////
bool VocbaseManager::addEndpoint (std::string const& name,
std::vector<std::string> const& dbNames) {
std::vector<std::string> const& databaseNames) {
if (_endpointServer) {
{
WRITE_LOCKER(_rwLock);
_endpoints[name] = dbNames;
_endpoints[name] = databaseNames;
}
return _endpointServer->addEndpoint(name);
@ -280,7 +280,7 @@ TRI_vocbase_t* VocbaseManager::lookupVocbaseByHttpRequest (triagens::rest::HttpR
TRI_vocbase_t* vocbase = 0;
// get database name from request
string requestedName = request->dbName();
string requestedName = request->databaseName();
if (requestedName.empty()) {
// no name set in request, use system database name as a fallback
@ -321,9 +321,9 @@ TRI_vocbase_t* VocbaseManager::lookupVocbaseByHttpRequest (triagens::rest::HttpR
}
// we have a user-defined mapping for the endpoint
const vector<string>& dbNames = (*it2).second;
const vector<string>& databaseNames = (*it2).second;
if (dbNames.size() == 0) {
if (databaseNames.size() == 0) {
// list of database names is specified but empty. this means no-one will get access
return 0;
}
@ -331,7 +331,7 @@ TRI_vocbase_t* VocbaseManager::lookupVocbaseByHttpRequest (triagens::rest::HttpR
// finally check if the requested database is in the list of allowed databases for the endpoint
vector<string>::const_iterator it3;
for (it3 = dbNames.begin(); it3 != dbNames.end(); ++it3) {
for (it3 = databaseNames.begin(); it3 != databaseNames.end(); ++it3) {
if (requestedName == *it3) {
return vocbase;
}
@ -345,80 +345,76 @@ TRI_vocbase_t* VocbaseManager::lookupVocbaseByHttpRequest (triagens::rest::HttpR
/// @brief authenticate a request
////////////////////////////////////////////////////////////////////////////////
bool VocbaseManager::authenticate (TRI_vocbase_t* vocbase,
triagens::rest::HttpRequest* request) {
if (! vocbase) {
// unknown vocbase
return false;
}
HttpResponse::HttpResponseCode VocbaseManager::authenticate (TRI_vocbase_t* vocbase,
triagens::rest::HttpRequest* request) {
assert(vocbase != 0);
std::map<TRI_vocbase_t*, std::map<std::string, std::string> >::iterator mapIter;
bool found;
char const* auth = request->header("authorization", found);
if (found) {
if (! TRI_CaseEqualString2(auth, "basic ", 6)) {
return false;
}
auth += 6;
while (*auth == ' ') {
++auth;
}
{
READ_LOCKER(_rwLock);
mapIter = _authCache.find(vocbase);
if (mapIter == _authCache.end()) {
// unknown vocbase
return false;
}
map<string,string>::iterator i = mapIter->second.find(auth);
if (i != mapIter->second.end()) {
request->setUser(i->second);
return true;
}
}
string up = StringUtils::decodeBase64(auth);
std::string::size_type n = up.find(':', 0);
if (n == std::string::npos || n == 0 || n + 1 > up.size()) {
LOGGER_TRACE("invalid authentication data found, cannot extract username/password");
return false;
}
const string username = up.substr(0, n);
LOGGER_TRACE("checking authentication for user '" << username << "'");
bool res = TRI_CheckAuthenticationAuthInfo2(vocbase, username.c_str(), up.substr(n + 1).c_str());
if (res) {
WRITE_LOCKER(_rwLock);
mapIter = _authCache.find(vocbase);
if (mapIter == _authCache.end()) {
// unknown vocbase
return false;
}
mapIter->second[auth] = username;
request->setUser(username);
// TODO: create a user object for the VocbaseContext
}
return res;
if (! found || ! TRI_CaseEqualString2(auth, "basic ", 6)) {
return HttpResponse::UNAUTHORIZED;
}
return false;
// skip over "basic "
auth += 6;
while (*auth == ' ') {
++auth;
}
{
READ_LOCKER(_rwLock);
mapIter = _authCache.find(vocbase);
if (mapIter == _authCache.end()) {
// unknown vocbase
return HttpResponse::NOT_FOUND;
}
map<string, string>::iterator i = mapIter->second.find(auth);
if (i != mapIter->second.end()) {
request->setUser(i->second);
return HttpResponse::OK;
}
}
string up = StringUtils::decodeBase64(auth);
std::string::size_type n = up.find(':', 0);
if (n == std::string::npos || n == 0 || n + 1 > up.size()) {
LOGGER_TRACE("invalid authentication data found, cannot extract username/password");
return HttpResponse::BAD;
}
const string username = up.substr(0, n);
LOGGER_TRACE("checking authentication for user '" << username << "'");
bool res = TRI_CheckAuthenticationAuthInfo2(vocbase, username.c_str(), up.substr(n + 1).c_str());
if (! res) {
return HttpResponse::UNAUTHORIZED;
}
WRITE_LOCKER(_rwLock);
mapIter = _authCache.find(vocbase);
if (mapIter == _authCache.end()) {
// unknown vocbase
return HttpResponse::UNAUTHORIZED;
}
mapIter->second[auth] = username;
request->setUser(username);
// TODO: create a user object for the VocbaseContext
return HttpResponse::OK;
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -32,10 +32,11 @@
#include "BasicsC/win-utils.h"
#endif
#include "VocBase/vocbase.h"
#include "Rest/HttpRequest.h"
#include "Basics/ReadLocker.h"
#include "Basics/WriteLocker.h"
#include "Rest/HttpRequest.h"
#include "Rest/HttpResponse.h"
#include "VocBase/vocbase.h"
#include <map>
#include <string>
@ -182,8 +183,8 @@ namespace triagens {
/// @brief authenticate a request
////////////////////////////////////////////////////////////////////////////////
bool authenticate (TRI_vocbase_t*,
triagens::rest::HttpRequest*);
rest::HttpResponse::HttpResponseCode authenticate (TRI_vocbase_t*,
triagens::rest::HttpRequest*);
////////////////////////////////////////////////////////////////////////////////
/// @brief reload auth info

View File

@ -105,6 +105,7 @@ ArangoClient::ArangoClient ()
_disableAuthentication(false),
_endpointString(),
_endpointServer(0),
_databaseName("_system"),
_username("root"),
_password(""),
_hasPassword(false),
@ -233,7 +234,8 @@ void ArangoClient::setupServer (ProgramOptionsDescription& description) {
ProgramOptionsDescription clientOptions("CLIENT options");
clientOptions
("server.disable-authentication", &_disableAuthentication, "disable authentication")
("server.database", &_databaseName, "database name to use when connecting")
("server.disable-authentication", &_disableAuthentication, "disable authentication (will disable password prompt)")
("server.endpoint", &_endpointString, "endpoint to connect to, use 'none' to start without a server")
("server.username", &_username, "username to use when connecting")
("server.password", &_password, "password to use when connecting (leave empty for prompt)")
@ -675,6 +677,14 @@ Endpoint* ArangoClient::endpointServer() const {
return _endpointServer;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief database name
////////////////////////////////////////////////////////////////////////////////
string const& ArangoClient::databaseName () const {
return _databaseName;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief user to send to endpoint
////////////////////////////////////////////////////////////////////////////////
@ -691,6 +701,14 @@ string const& ArangoClient::password () const {
return _password;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief set database name
////////////////////////////////////////////////////////////////////////////////
void ArangoClient::setDatabaseName (string const& databaseName) {
_databaseName = databaseName;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief set username
////////////////////////////////////////////////////////////////////////////////

View File

@ -343,6 +343,18 @@ namespace triagens {
triagens::rest::Endpoint* endpointServer() const;
////////////////////////////////////////////////////////////////////////////////
/// @brief database name
////////////////////////////////////////////////////////////////////////////////
string const& databaseName () const;
////////////////////////////////////////////////////////////////////////////////
/// @brief set database name
////////////////////////////////////////////////////////////////////////////////
void setDatabaseName (string const&);
////////////////////////////////////////////////////////////////////////////////
/// @brief user to send to endpoint
////////////////////////////////////////////////////////////////////////////////
@ -518,6 +530,12 @@ namespace triagens {
triagens::rest::Endpoint* _endpointServer;
////////////////////////////////////////////////////////////////////////////////
/// @brief database name
////////////////////////////////////////////////////////////////////////////////
string _databaseName;
////////////////////////////////////////////////////////////////////////////////
/// @brief user to send to endpoint
////////////////////////////////////////////////////////////////////////////////

View File

@ -78,6 +78,7 @@ namespace triagens {
const unsigned long batchSize,
BenchmarkCounter<unsigned long>* operationsCounter,
Endpoint* endpoint,
const string& databaseName,
const string& username,
const string& password,
double requestTimeout,
@ -91,6 +92,7 @@ namespace triagens {
_warningCount(0),
_operationsCounter(operationsCounter),
_endpoint(endpoint),
_databaseName(databaseName),
_username(username),
_password(password),
_requestTimeout(requestTimeout),
@ -146,6 +148,13 @@ namespace triagens {
}
_client = new SimpleHttpClient(_connection, 10.0, true);
if (_client == 0) {
LOGGER_FATAL_AND_EXIT("out of memory");
}
_client->setLocationRewriter(this, &rewriteLocation);
_client->setUserNamePassword("/", _username, _password);
// test the connection
@ -213,6 +222,28 @@ namespace triagens {
private:
////////////////////////////////////////////////////////////////////////////////
/// @brief request location rewriter (injects database name)
////////////////////////////////////////////////////////////////////////////////
static string rewriteLocation (void* data, const string& location) {
BenchmarkThread* t = static_cast<BenchmarkThread*>(data);
assert(t != 0);
if (location.substr(0, 5) == "/_db/") {
// location already contains /_db/
return location;
}
if (location[0] == '/') {
return "/_db/" + t->_databaseName + location;
}
else {
return "/_db/" + t->_databaseName + "/" + location;
}
}
////////////////////////////////////////////////////////////////////////////////
/// @brief execute a batch request with numOperations parts
////////////////////////////////////////////////////////////////////////////////
@ -442,6 +473,12 @@ namespace triagens {
Endpoint* _endpoint;
////////////////////////////////////////////////////////////////////////////////
/// @brief database name
////////////////////////////////////////////////////////////////////////////////
const string _databaseName;
////////////////////////////////////////////////////////////////////////////////
/// @brief HTTP username
////////////////////////////////////////////////////////////////////////////////
@ -470,7 +507,7 @@ namespace triagens {
/// @brief underlying client
////////////////////////////////////////////////////////////////////////////////
triagens::httpclient::SimpleClient* _client;
triagens::httpclient::SimpleHttpClient* _client;
////////////////////////////////////////////////////////////////////////////////
/// @brief connection to the server

View File

@ -290,6 +290,7 @@ int main (int argc, char* argv[]) {
(unsigned long) BatchSize,
&operationsCounter,
endpoint,
BaseClient.databaseName(),
BaseClient.username(),
BaseClient.password(),
BaseClient.requestTimeout(),
@ -348,7 +349,7 @@ int main (int argc, char* argv[]) {
cout << endl;
cout << "Total number of operations: " << Operations << ", batch size: " << BatchSize << ", concurrency level (threads): " << Concurrency << endl;
cout << "Test case: " << TestCase << ", complexity: " << Complexity << ", collection: '" << Collection << "'" << endl;
cout << "Test case: " << TestCase << ", complexity: " << Complexity << ", database: '" << BaseClient.databaseName() << "', collection: '" << Collection << "'" << endl;
cout << "Total request/response duration (sum of all threads): " << fixed << requestTime << " s" << endl;
cout << "Request/response duration (per thread): " << fixed << (requestTime / (double) Concurrency) << " s" << endl;
cout << "Time needed per operation: " << fixed << (time / Operations) << " s" << endl;

View File

@ -61,6 +61,7 @@ using namespace std;
////////////////////////////////////////////////////////////////////////////////
V8ClientConnection::V8ClientConnection (Endpoint* endpoint,
string databaseName,
const string& username,
const string& password,
double requestTimeout,
@ -68,6 +69,7 @@ V8ClientConnection::V8ClientConnection (Endpoint* endpoint,
size_t numRetries,
bool warn)
: _connection(0),
_databaseName(databaseName),
_lastHttpReturnCode(0),
_lastErrorMessage(""),
_client(0),
@ -83,9 +85,11 @@ V8ClientConnection::V8ClientConnection (Endpoint* endpoint,
_client = new SimpleHttpClient(_connection, requestTimeout, warn);
if (_client == 0) {
throw "out of memory";
LOGGER_FATAL_AND_EXIT("out of memory");
}
_client->setLocationRewriter(this, &rewriteLocation);
_client->setUserNamePassword("/", username, password);
// connect to server and get version number
@ -167,6 +171,28 @@ V8ClientConnection::~V8ClientConnection () {
/// @{
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// @brief request location rewriter (injects database name)
////////////////////////////////////////////////////////////////////////////////
string V8ClientConnection::rewriteLocation (void* data, const string& location) {
V8ClientConnection* c = static_cast<V8ClientConnection*>(data);
assert(c != 0);
if (location.substr(0, 5) == "/_db/") {
// location already contains /_db/
return location;
}
if (location[0] == '/') {
return "/_db/" + c->_databaseName + location;
}
else {
return "/_db/" + c->_databaseName + "/" + location;
}
}
////////////////////////////////////////////////////////////////////////////////
/// @brief returns true if it is connected
////////////////////////////////////////////////////////////////////////////////
@ -175,6 +201,22 @@ bool V8ClientConnection::isConnected () {
return _connection->isConnected();
}
////////////////////////////////////////////////////////////////////////////////
/// @brief returns the current database name
////////////////////////////////////////////////////////////////////////////////
const string& V8ClientConnection::getDatabaseName () {
return _databaseName;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief set the current database name
////////////////////////////////////////////////////////////////////////////////
void V8ClientConnection::setDatabaseName (const string& databaseName) {
_databaseName = databaseName;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief returns the version and build number of the arango server
////////////////////////////////////////////////////////////////////////////////

View File

@ -91,6 +91,7 @@ namespace triagens {
////////////////////////////////////////////////////////////////////////////////
V8ClientConnection (triagens::rest::Endpoint*,
string,
const string&,
const string&,
double,
@ -119,12 +120,30 @@ namespace triagens {
public:
////////////////////////////////////////////////////////////////////////////////
/// @brief request location rewriter (injects database name)
////////////////////////////////////////////////////////////////////////////////
static string rewriteLocation (void*, const string&);
////////////////////////////////////////////////////////////////////////////////
/// @brief returns true if it is connected
////////////////////////////////////////////////////////////////////////////////
bool isConnected ();
////////////////////////////////////////////////////////////////////////////////
/// @brief returns the current database name
////////////////////////////////////////////////////////////////////////////////
const string& getDatabaseName ();
////////////////////////////////////////////////////////////////////////////////
/// @brief set the current database name
////////////////////////////////////////////////////////////////////////////////
void setDatabaseName (const string&);
////////////////////////////////////////////////////////////////////////////////
/// @brief returns the version and build number of the arango server
////////////////////////////////////////////////////////////////////////////////
@ -344,6 +363,12 @@ namespace triagens {
triagens::httpclient::GeneralClientConnection* _connection;
////////////////////////////////////////////////////////////////////////////////
/// @brief database name
////////////////////////////////////////////////////////////////////////////////
std::string _databaseName;
////////////////////////////////////////////////////////////////////////////////
/// @brief server version
////////////////////////////////////////////////////////////////////////////////

View File

@ -284,10 +284,11 @@ int main (int argc, char* argv[]) {
if (BaseClient.endpointServer() == 0) {
cerr << "invalid value for --server.endpoint ('" << BaseClient.endpointString() << "')" << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
ClientConnection = new V8ClientConnection(BaseClient.endpointServer(),
BaseClient.databaseName(),
BaseClient.username(),
BaseClient.password(),
BaseClient.requestTimeout(),
@ -298,7 +299,7 @@ int main (int argc, char* argv[]) {
if (! ClientConnection->isConnected() || ClientConnection->getLastHttpReturnCode() != HttpResponse::OK) {
cerr << "Could not connect to endpoint " << BaseClient.endpointServer()->getSpecification() << endl;
cerr << "Error message: '" << ClientConnection->getErrorMessage() << "'" << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
// successfully connected
@ -306,6 +307,7 @@ int main (int argc, char* argv[]) {
<< "' Version " << ClientConnection->getVersion() << endl;
cout << "----------------------------------------" << endl;
cout << "database: " << BaseClient.databaseName() << endl;
cout << "collection: " << CollectionName << endl;
cout << "create: " << (CreateCollection ? "yes" : "no") << endl;
cout << "file: " << FileName << endl;
@ -333,7 +335,7 @@ int main (int argc, char* argv[]) {
}
else {
cerr << "Wrong length of quote character." << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
// separator
@ -342,24 +344,24 @@ int main (int argc, char* argv[]) {
}
else {
cerr << "Separator must be exactly one character." << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
// collection name
if (CollectionName == "") {
cerr << "collection name is missing." << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
// filename
if (FileName == "") {
cerr << "file name is missing." << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
if (FileName != "-" && ! FileUtils::isRegularFile(FileName)) {
cerr << "file '" << FileName << "' is not a regular file." << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
// progress
@ -389,7 +391,7 @@ int main (int argc, char* argv[]) {
else {
cerr << "Wrong type '" << TypeImport << "'." << endl;
TRI_EXIT_FUNCTION(EXIT_FAILURE,NULL);
TRI_EXIT_FUNCTION(EXIT_FAILURE, NULL);
}
cout << endl;

View File

@ -386,6 +386,7 @@ static v8::Handle<v8::Value> JS_compare_string (v8::Arguments const& argv) {
static V8ClientConnection* CreateConnection () {
return new V8ClientConnection(BaseClient.endpointServer(),
BaseClient.databaseName(),
BaseClient.username(),
BaseClient.password(),
BaseClient.requestTimeout(),
@ -554,22 +555,24 @@ static v8::Handle<v8::Value> ClientConnection_reconnect (v8::Arguments const& ar
TRI_V8_EXCEPTION_INTERNAL(scope, "connection class corrupted");
}
if (argv.Length() < 1) {
TRI_V8_EXCEPTION_USAGE(scope, "reconnect(<endpoint>[, <username>, <password>])");
if (argv.Length() < 2) {
TRI_V8_EXCEPTION_USAGE(scope, "reconnect(<endpoint>, <databasename>, [, <username>, <password>])");
}
string definition = TRI_ObjectToString(argv[0]);
string databaseName = TRI_ObjectToString(argv[1]);
string username;
if (argv.Length() < 2) {
if (argv.Length() < 3) {
username = BaseClient.username();
}
else {
username = TRI_ObjectToString(argv[1]);
username = TRI_ObjectToString(argv[2]);
}
string password;
if (argv.Length() < 3) {
if (argv.Length() < 4) {
cout << "Please specify a password: " << flush;
// now prompt for it
@ -584,16 +587,18 @@ static v8::Handle<v8::Value> ClientConnection_reconnect (v8::Arguments const& ar
cout << "\n";
}
else {
password = TRI_ObjectToString(argv[2]);
password = TRI_ObjectToString(argv[3]);
}
const string oldDefinition = BaseClient.endpointString();
const string oldUsername = BaseClient.username();
const string oldPassword = BaseClient.password();
const string oldDefinition = BaseClient.endpointString();
const string oldDatabaseName = BaseClient.databaseName();
const string oldUsername = BaseClient.username();
const string oldPassword = BaseClient.password();
delete connection;
BaseClient.setEndpointString(definition);
BaseClient.setDatabaseName(databaseName);
BaseClient.setUsername(username);
BaseClient.setPassword(password);
@ -601,6 +606,7 @@ static v8::Handle<v8::Value> ClientConnection_reconnect (v8::Arguments const& ar
BaseClient.createEndpoint();
if (BaseClient.endpointServer() == 0) {
BaseClient.setEndpointString(oldDefinition);
BaseClient.setDatabaseName(oldDatabaseName);
BaseClient.setUsername(oldUsername);
BaseClient.setPassword(oldPassword);
BaseClient.createEndpoint();
@ -644,6 +650,7 @@ static v8::Handle<v8::Value> ClientConnection_reconnect (v8::Arguments const& ar
// rollback
BaseClient.setEndpointString(oldDefinition);
BaseClient.setDatabaseName(oldDatabaseName);
BaseClient.setUsername(oldUsername);
BaseClient.setPassword(oldPassword);
BaseClient.createEndpoint();
@ -1173,6 +1180,50 @@ static v8::Handle<v8::Value> ClientConnection_getVersion (v8::Arguments const& a
return scope.Close(v8::String::New(connection->getVersion().c_str()));
}
////////////////////////////////////////////////////////////////////////////////
/// @brief ClientConnection method "getDatabaseName"
////////////////////////////////////////////////////////////////////////////////
static v8::Handle<v8::Value> ClientConnection_getDatabaseName (v8::Arguments const& argv) {
v8::HandleScope scope;
// get the connection
V8ClientConnection* connection = TRI_UnwrapClass<V8ClientConnection>(argv.Holder(), WRAP_TYPE_CONNECTION);
if (connection == 0) {
TRI_V8_EXCEPTION_INTERNAL(scope, "connection class corrupted");
}
if (argv.Length() != 0) {
TRI_V8_EXCEPTION_USAGE(scope, "getDatabaseName()");
}
return scope.Close(v8::String::New(connection->getDatabaseName().c_str()));
}
////////////////////////////////////////////////////////////////////////////////
/// @brief ClientConnection method "setDatabaseName"
////////////////////////////////////////////////////////////////////////////////
static v8::Handle<v8::Value> ClientConnection_setDatabaseName (v8::Arguments const& argv) {
v8::HandleScope scope;
// get the connection
V8ClientConnection* connection = TRI_UnwrapClass<V8ClientConnection>(argv.Holder(), WRAP_TYPE_CONNECTION);
if (connection == 0) {
TRI_V8_EXCEPTION_INTERNAL(scope, "connection class corrupted");
}
if (argv.Length() != 1 || ! argv[0]->IsString()) {
TRI_V8_EXCEPTION_USAGE(scope, "setDatabaseName(<name>)");
}
connection->setDatabaseName(TRI_ObjectToString(argv[0]));
return scope.Close(v8::True());
}
////////////////////////////////////////////////////////////////////////////////
/// @brief executes the shell
////////////////////////////////////////////////////////////////////////////////
@ -1646,6 +1697,8 @@ int main (int argc, char* argv[]) {
connection_proto->Set("reconnect", v8::FunctionTemplate::New(ClientConnection_reconnect));
connection_proto->Set("toString", v8::FunctionTemplate::New(ClientConnection_toString));
connection_proto->Set("getVersion", v8::FunctionTemplate::New(ClientConnection_getVersion));
connection_proto->Set("getDatabaseName", v8::FunctionTemplate::New(ClientConnection_getDatabaseName));
connection_proto->Set("setDatabaseName", v8::FunctionTemplate::New(ClientConnection_setDatabaseName));
connection_proto->SetCallAsFunctionHandler(ClientConnection_ConstructorCallback);
v8::Handle<v8::ObjectTemplate> connection_inst = connection_templ->InstanceTemplate();

View File

@ -263,8 +263,9 @@ ArangoDatabase.prototype.toString = function () {
/// @brief return all collections from the database
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._collections = function () {
var requestResult = this._connection.GET(this._collectionurl());
ArangoDatabase.prototype._collections = function (excludeSystem) {
var append = (excludeSystem ? "?excludeSystem=true" : "");
var requestResult = this._connection.GET(this._collectionurl() + append);
arangosh.checkRequestResult(requestResult);
@ -452,8 +453,8 @@ ArangoDatabase.prototype._flushCache = function () {
/// @brief query the database properties
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._queryProperties = function () {
if (this._properties === null) {
ArangoDatabase.prototype._queryProperties = function (force) {
if (force || this._properties === null) {
var requestResult = this._connection.GET("/_api/current-database");
arangosh.checkRequestResult(requestResult);
@ -844,6 +845,37 @@ ArangoDatabase.prototype._listDatabases = function () {
return requestResult.result;
};
////////////////////////////////////////////////////////////////////////////////
/// @brief uses a database
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._useDatabase = function (name) {
var old = this._connection.getDatabaseName();
this._connection.setDatabaseName(name);
try {
// re-query properties
this._queryProperties(true);
}
catch (err) {
this._connection.setDatabaseName(old);
if (err.hasOwnProperty("errorNum")) {
throw err;
}
throw new ArangoError({
error: true,
code: internal.errors.ERROR_BAD_PARAMETER.code,
errorNum: internal.errors.ERROR_BAD_PARAMETER.code,
errorMessage: "cannot use database '" + name + "'"
});
}
return true;
};
////////////////////////////////////////////////////////////////////////////////
/// @}
////////////////////////////////////////////////////////////////////////////////

View File

@ -262,8 +262,9 @@ ArangoDatabase.prototype.toString = function () {
/// @brief return all collections from the database
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._collections = function () {
var requestResult = this._connection.GET(this._collectionurl());
ArangoDatabase.prototype._collections = function (excludeSystem) {
var append = (excludeSystem ? "?excludeSystem=true" : "");
var requestResult = this._connection.GET(this._collectionurl() + append);
arangosh.checkRequestResult(requestResult);
@ -451,8 +452,8 @@ ArangoDatabase.prototype._flushCache = function () {
/// @brief query the database properties
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._queryProperties = function () {
if (this._properties === null) {
ArangoDatabase.prototype._queryProperties = function (force) {
if (force || this._properties === null) {
var requestResult = this._connection.GET("/_api/current-database");
arangosh.checkRequestResult(requestResult);
@ -843,6 +844,43 @@ ArangoDatabase.prototype._listDatabases = function () {
return requestResult.result;
};
////////////////////////////////////////////////////////////////////////////////
/// @brief uses a database
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._useDatabase = function (name) {
var old = this._connection.getDatabaseName();
// no change
if (name === old) {
return true;
}
this._connection.setDatabaseName(name);
try {
// re-query properties
this._queryProperties(true);
this._flushCache();
}
catch (err) {
this._connection.setDatabaseName(old);
if (err.hasOwnProperty("errorNum")) {
throw err;
}
throw new ArangoError({
error: true,
code: internal.errors.ERROR_BAD_PARAMETER.code,
errorNum: internal.errors.ERROR_BAD_PARAMETER.code,
errorMessage: "cannot use database '" + name + "'"
});
}
return true;
};
////////////////////////////////////////////////////////////////////////////////
/// @}
////////////////////////////////////////////////////////////////////////////////

View File

@ -47,7 +47,7 @@ function AuthSuite () {
////////////////////////////////////////////////////////////////////////////////
setUp : function () {
arango.reconnect(arango.getEndpoint(), "root", "");
arango.reconnect(arango.getEndpoint(), db._name(), "root", "");
try {
users.remove("hackers@arangodb.org");
@ -76,21 +76,21 @@ function AuthSuite () {
users.save("hackers@arangodb.org", "foobar");
users.reload();
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "foobar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "foobar");
// this will issue a request using the new user
assertTrue(db._collections().length > 0);
// double check with wrong passwords
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "foobar2");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "foobar2");
fail();
}
catch (err1) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "");
fail();
}
catch (err2) {
@ -105,14 +105,14 @@ function AuthSuite () {
users.save("hackers@arangodb.org", "");
users.reload();
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "");
// this will issue a request using the new user
assertTrue(db._collections().length > 0);
// double check with wrong password
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "foobar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "foobar");
fail();
}
catch (err1) {
@ -127,28 +127,28 @@ function AuthSuite () {
users.save("hackers@arangodb.org", "FooBar");
users.reload();
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "FooBar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "FooBar");
// this will issue a request using the new user
assertTrue(db._collections().length > 0);
// double check with wrong passwords
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "Foobar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "Foobar");
fail();
}
catch (err1) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "foobar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "foobar");
fail();
}
catch (err2) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "FOOBAR");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "FOOBAR");
fail();
}
catch (err3) {
@ -163,28 +163,28 @@ function AuthSuite () {
users.save("hackers@arangodb.org", "fuxx::bar");
users.reload();
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "fuxx::bar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "fuxx::bar");
// this will issue a request using the new user
assertTrue(db._collections().length > 0);
// double check with wrong passwords
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "fuxx");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "fuxx");
fail();
}
catch (err1) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "bar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "bar");
fail();
}
catch (err2) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "");
fail();
}
catch (err3) {
@ -199,28 +199,28 @@ function AuthSuite () {
users.save("hackers@arangodb.org", ":\\abc'def:foobar@04. x-a");
users.reload();
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", ":\\abc'def:foobar@04. x-a");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", ":\\abc'def:foobar@04. x-a");
// this will issue a request using the new user
assertTrue(db._collections().length > 0);
// double check with wrong passwords
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "foobar");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "foobar");
fail();
}
catch (err1) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "\\abc'def: x-a");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "\\abc'def: x-a");
fail();
}
catch (err2) {
}
try {
arango.reconnect(arango.getEndpoint(), "hackers@arangodb.org", "");
arango.reconnect(arango.getEndpoint(), db._name(), "hackers@arangodb.org", "");
fail();
}
catch (err3) {

View File

@ -151,6 +151,14 @@ ArangoDatabase.prototype._listDatabases = function () {
return LIST_DATABASES();
}
////////////////////////////////////////////////////////////////////////////////
/// @brief use a different database
////////////////////////////////////////////////////////////////////////////////
ArangoDatabase.prototype._useDatabase = function (name) {
return USE_DATABASES(name);
}
////////////////////////////////////////////////////////////////////////////////
/// @}
////////////////////////////////////////////////////////////////////////////////

View File

@ -57,11 +57,11 @@ function ReplicationSuite () {
var replicatorPassword = "replicator-password";
var connectToMaster = function () {
arango.reconnect(masterEndpoint, replicatorUser, replicatorPassword);
arango.reconnect(masterEndpoint, db._name(), replicatorUser, replicatorPassword);
};
var connectToSlave = function () {
arango.reconnect(slaveEndpoint, "root", "");
arango.reconnect(slaveEndpoint, db._name(), "root", "");
};
var collectionChecksum = function (name) {

View File

@ -463,11 +463,11 @@ namespace triagens {
// authenticate
// .............................................................................
bool auth = this->_server->getHandlerFactory()->authenticateRequest(this->_request);
HttpResponse::HttpResponseCode authResult = this->_server->getHandlerFactory()->authenticateRequest(this->_request);
// authenticated
// or an HTTP OPTIONS request. OPTIONS requests currently go unauthenticated
if (auth || this->_requestType == HttpRequest::HTTP_REQUEST_OPTIONS) {
if (authResult == HttpResponse::OK || this->_requestType == HttpRequest::HTTP_REQUEST_OPTIONS) {
// handle HTTP OPTIONS requests directly
if (this->_requestType == HttpRequest::HTTP_REQUEST_OPTIONS) {
@ -539,6 +539,14 @@ namespace triagens {
}
}
// not found
else if (authResult == HttpResponse::NOT_FOUND) {
HttpResponse response(HttpResponse::NOT_FOUND);
this->handleResponse(&response);
this->resetState();
}
// not authenticated
else {
const string realm = "basic realm=\"" + this->_server->getHandlerFactory()->authenticationRealm(this->_request) + "\"";

View File

@ -143,15 +143,19 @@ pair<size_t, size_t> HttpHandlerFactory::sizeRestrictions () const {
/// disabled authentication etc.
////////////////////////////////////////////////////////////////////////////////
bool HttpHandlerFactory::authenticateRequest (HttpRequest* request) {
HttpResponse::HttpResponseCode HttpHandlerFactory::authenticateRequest (HttpRequest* request) {
RequestContext* rc = request->getRequestContext();
if (! rc) {
if (! setRequestContext(request)) {
return false;
return HttpResponse::NOT_FOUND;
}
}
if (! rc) {
return HttpResponse::NOT_FOUND;
}
return rc->authenticate();
}

View File

@ -32,6 +32,7 @@
#include "Basics/Mutex.h"
#include "Basics/ReadWriteLock.h"
#include "Rest/HttpResponse.h"
// -----------------------------------------------------------------------------
// --SECTION-- forward declarations
@ -183,7 +184,7 @@ namespace triagens {
/// @brief authenticates a new request, wrapper method
////////////////////////////////////////////////////////////////////////////////
virtual bool authenticateRequest (HttpRequest * request);
virtual HttpResponse::HttpResponseCode authenticateRequest (HttpRequest * request);
////////////////////////////////////////////////////////////////////////////////
/// @brief set request context, wrapper method

View File

@ -71,7 +71,7 @@ HttpRequest::HttpRequest (ConnectionInfo const& info, char const* header, size_t
_prefix(),
_suffix(),
_version(HTTP_UNKNOWN),
_dbName(),
_databaseName(),
_user(),
_requestContext(0) {
@ -103,7 +103,7 @@ HttpRequest::HttpRequest ()
_prefix(),
_suffix(),
_version(HTTP_UNKNOWN),
_dbName(),
_databaseName(),
_user(),
_requestContext(0) {
}
@ -589,8 +589,8 @@ void HttpRequest::setRequestType (HttpRequestType newType) {
/// @brief returns the database name
////////////////////////////////////////////////////////////////////////////////
string const& HttpRequest::dbName () const {
return _dbName;
string const& HttpRequest::databaseName () const {
return _databaseName;
}
////////////////////////////////////////////////////////////////////////////////
@ -834,7 +834,7 @@ void HttpRequest::parseHeader (char* ptr, size_t length) {
++q;
}
_dbName = string(pathBegin, q - pathBegin);
_databaseName = string(pathBegin, q - pathBegin);
pathBegin = q;
}

View File

@ -226,7 +226,7 @@ namespace triagens {
/// @brief returns the database name
////////////////////////////////////////////////////////////////////////////////
std::string const& dbName () const;
std::string const& databaseName () const;
////////////////////////////////////////////////////////////////////////////////
/// @brief returns the authenticated user
@ -658,7 +658,7 @@ namespace triagens {
/// @brief database name
////////////////////////////////////////////////////////////////////////////////
string _dbName;
string _databaseName;
////////////////////////////////////////////////////////////////////////////////
/// @brief authenticated user

View File

@ -30,6 +30,7 @@
#include "Rest/RequestUser.h"
#include "Rest/HttpRequest.h"
#include "Rest/HttpResponse.h"
namespace triagens {
namespace rest {
@ -97,7 +98,7 @@ namespace triagens {
/// @brief authenticate user
////////////////////////////////////////////////////////////////////////////////
virtual bool authenticate () = 0;
virtual HttpResponse::HttpResponseCode authenticate () = 0;
////////////////////////////////////////////////////////////////////////////////
/// @}

View File

@ -52,7 +52,14 @@ namespace triagens {
SimpleHttpClient::SimpleHttpClient (GeneralClientConnection* connection,
double requestTimeout,
bool warn) :
SimpleClient(connection, requestTimeout, warn), _result(0), _maxPacketSize(128 * 1024 * 1024) {
SimpleClient(connection, requestTimeout, warn),
_locationRewriter(),
_result(0),
_maxPacketSize(128 * 1024 * 1024) {
// waiting for C++11...
_locationRewriter.func = 0;
_locationRewriter.data = 0;
}
SimpleHttpClient::~SimpleHttpClient () {
@ -73,8 +80,8 @@ namespace triagens {
_result = new SimpleHttpResult;
_errorMessage = "";
// set body to all connections
setRequest(method, location, body, bodyLength, headerFields);
// set body
setRequest(method, rewriteLocation(location), body, bodyLength, headerFields);
double endTime = now() + _requestTimeout;
double remainingTime = _requestTimeout;

View File

@ -91,6 +91,16 @@ namespace triagens {
const string& username,
const string& password);
////////////////////////////////////////////////////////////////////////////////
/// @brief allows rewriting locations
////////////////////////////////////////////////////////////////////////////////
void setLocationRewriter (void* data,
std::string (*func)(void*, const std::string&)) {
_locationRewriter.data = data;
_locationRewriter.func = func;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief reset state
////////////////////////////////////////////////////////////////////////////////
@ -99,6 +109,18 @@ namespace triagens {
private:
////////////////////////////////////////////////////////////////////////////////
/// @brief rewrite a location URL
////////////////////////////////////////////////////////////////////////////////
string rewriteLocation (const string& location) {
if (_locationRewriter.func != 0) {
return _locationRewriter.func(_locationRewriter.data, location);
}
return location;
}
////////////////////////////////////////////////////////////////////////////////
/// @brief get the result
/// the caller has to delete the result object
@ -147,6 +169,16 @@ namespace triagens {
bool readChunkedBody ();
private:
////////////////////////////////////////////////////////////////////////////////
/// @brief struct for rewriting location URLs
////////////////////////////////////////////////////////////////////////////////
struct {
void* data;
std::string (*func)(void*, const std::string&);
}
_locationRewriter;
uint32_t _nextChunkedSize;