00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef _PASSENGER_MESSAGE_SERVER_H_
00026 #define _PASSENGER_MESSAGE_SERVER_H_
00027
00028 #include <string>
00029 #include <vector>
00030
00031 #include <boost/shared_ptr.hpp>
00032 #include <boost/thread.hpp>
00033 #include <oxt/system_calls.hpp>
00034 #include <oxt/dynamic_thread_group.hpp>
00035
00036 #include <sys/types.h>
00037 #include <sys/stat.h>
00038 #include <sys/un.h>
00039 #include <unistd.h>
00040 #include <cerrno>
00041 #include <cassert>
00042
00043 #include "Account.h"
00044 #include "AccountsDatabase.h"
00045 #include "Constants.h"
00046 #include "FileDescriptor.h"
00047 #include "MessageChannel.h"
00048 #include "Logging.h"
00049 #include "Exceptions.h"
00050 #include "Utils/StrIntUtils.h"
00051 #include "Utils/IOUtils.h"
00052
00053 namespace Passenger {
00054
00055 using namespace std;
00056 using namespace boost;
00057 using namespace oxt;
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158 class MessageServer {
00159 public:
/**
 * Stack size that mainLoop() passes to threadGroup.create_thread() for
 * every client handling thread: 1024 * 96 on FreeBSD, 1024 * 64 elsewhere.
 * NOTE(review): the reason FreeBSD gets the larger value is not visible in
 * this file; the unit is presumably bytes -- confirm against
 * oxt::dynamic_thread_group.
 */
#ifdef __FreeBSD__
static const unsigned int CLIENT_THREAD_STACK_SIZE = 1024 * 96;
#else
static const unsigned int CLIENT_THREAD_STACK_SIZE = 1024 * 64;
#endif
00167
00168
00169 class ClientContext {
00170 public:
00171 virtual ~ClientContext() { }
00172 };
00173
00174 typedef shared_ptr<ClientContext> ClientContextPtr;
00175
00176
00177
00178
00179
00180 class CommonClientContext: public ClientContext {
00181 public:
00182
00183 FileDescriptor fd;
00184
00185
00186 MessageChannel channel;
00187
00188
00189 AccountPtr account;
00190
00191
00192 CommonClientContext(FileDescriptor &theFd, AccountPtr &theAccount)
00193 : fd(theFd), channel(theFd), account(theAccount)
00194 { }
00195
00196
00197 string name() {
00198 return toString(channel.filenum());
00199 }
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210 void requireRights(Account::Rights rights) {
00211 if (!account->hasRights(rights)) {
00212 P_TRACE(2, "Security error: insufficient rights to execute this command.");
00213 channel.write("SecurityException", "Insufficient rights to execute this command.", NULL);
00214 throw SecurityException("Insufficient rights to execute this command.");
00215 } else {
00216 channel.write("Passed security", NULL);
00217 }
00218 }
00219 };
00220
00221
00222
00223
00224
00225
00226
00227
00228 class Handler {
00229 public:
00230 virtual ~Handler() { }
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241 virtual ClientContextPtr newClient(CommonClientContext &context) {
00242 return ClientContextPtr();
00243 }
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256 virtual void clientDisconnected(MessageServer::CommonClientContext &context,
00257 MessageServer::ClientContextPtr &handlerSpecificContext)
00258 { }
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272 virtual bool processMessage(CommonClientContext &commonContext,
00273 ClientContextPtr &handlerSpecificContext,
00274 const vector<string> &args) = 0;
00275 };
00276
00277 typedef shared_ptr<Handler> HandlerPtr;
00278
00279 protected:
00280
00281 string socketFilename;
00282
00283
00284 AccountsDatabasePtr accountsDatabase;
00285
00286
00287 vector<HandlerPtr> handlers;
00288
00289
00290
00291
00292
00293
00294 unsigned long long loginTimeout;
00295
00296
00297 dynamic_thread_group threadGroup;
00298
00299
00300
00301
00302 int serverFd;
00303
00304
00305
00306 struct DisconnectEventBroadcastGuard {
00307 vector<HandlerPtr> &handlers;
00308 CommonClientContext &commonContext;
00309 vector<ClientContextPtr> &handlerSpecificContexts;
00310
00311 DisconnectEventBroadcastGuard(vector<HandlerPtr> &_handlers,
00312 CommonClientContext &_commonContext,
00313 vector<ClientContextPtr> &_handlerSpecificContexts)
00314 : handlers(_handlers),
00315 commonContext(_commonContext),
00316 handlerSpecificContexts(_handlerSpecificContexts)
00317 { }
00318
00319 ~DisconnectEventBroadcastGuard() {
00320 vector<HandlerPtr>::iterator handler_iter;
00321 vector<ClientContextPtr>::iterator context_iter;
00322
00323 for (handler_iter = handlers.begin(), context_iter = handlerSpecificContexts.begin();
00324 handler_iter != handlers.end();
00325 handler_iter++, context_iter++) {
00326 (*handler_iter)->clientDisconnected(commonContext, *context_iter);
00327 }
00328 }
00329 };
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339
00340 void startListening() {
00341 TRACE_POINT();
00342 int ret;
00343
00344 serverFd = createUnixServer(socketFilename.c_str());
00345 do {
00346 ret = chmod(socketFilename.c_str(),
00347 S_ISVTX |
00348 S_IRUSR | S_IWUSR | S_IXUSR |
00349 S_IRGRP | S_IWGRP | S_IXGRP |
00350 S_IROTH | S_IWOTH | S_IXOTH);
00351 } while (ret == -1 && errno == EINTR);
00352 }
00353
00354
00355
00356
00357
00358
00359 AccountPtr authenticate(FileDescriptor &client) {
00360 MessageChannel channel(client);
00361 string username, password;
00362 MemZeroGuard passwordGuard(password);
00363 unsigned long long timeout = loginTimeout;
00364
00365 try {
00366 channel.write("version", "1", NULL);
00367
00368 try {
00369 if (!channel.readScalar(username, MESSAGE_SERVER_MAX_USERNAME_SIZE, &timeout)) {
00370 return AccountPtr();
00371 }
00372 } catch (const SecurityException &) {
00373 channel.write("The supplied username is too long.", NULL);
00374 return AccountPtr();
00375 }
00376
00377 try {
00378 if (!channel.readScalar(password, MESSAGE_SERVER_MAX_PASSWORD_SIZE, &timeout)) {
00379 return AccountPtr();
00380 }
00381 } catch (const SecurityException &) {
00382 channel.write("The supplied password is too long.", NULL);
00383 return AccountPtr();
00384 }
00385
00386 AccountPtr account = accountsDatabase->authenticate(username, password);
00387 passwordGuard.zeroNow();
00388 if (account == NULL) {
00389 channel.write("Invalid username or password.", NULL);
00390 return AccountPtr();
00391 } else {
00392 channel.write("ok", NULL);
00393 return account;
00394 }
00395 } catch (const SystemException &) {
00396 return AccountPtr();
00397 } catch (const TimeoutException &) {
00398 return AccountPtr();
00399 }
00400 }
00401
00402 void broadcastNewClientEvent(CommonClientContext &context,
00403 vector<ClientContextPtr> &handlerSpecificContexts) {
00404 vector<HandlerPtr>::iterator it;
00405
00406 for (it = handlers.begin(); it != handlers.end(); it++) {
00407 handlerSpecificContexts.push_back((*it)->newClient(context));
00408 }
00409 }
00410
00411 bool processMessage(CommonClientContext &commonContext,
00412 vector<ClientContextPtr> &handlerSpecificContexts,
00413 const vector<string> &args) {
00414 vector<HandlerPtr>::iterator handler_iter;
00415 vector<ClientContextPtr>::iterator context_iter;
00416
00417 for (handler_iter = handlers.begin(), context_iter = handlerSpecificContexts.begin();
00418 handler_iter != handlers.end();
00419 handler_iter++, context_iter++) {
00420 if ((*handler_iter)->processMessage(commonContext, *context_iter, args)) {
00421 return true;
00422 }
00423 }
00424 return false;
00425 }
00426
00427 void processUnknownMessage(CommonClientContext &commonContext, const vector<string> &args) {
00428 TRACE_POINT();
00429 string name;
00430 if (args.empty()) {
00431 name = "(null)";
00432 } else {
00433 name = args[0];
00434 }
00435 P_TRACE(2, "A MessageServer client sent an invalid command: "
00436 << name << " (" << args.size() << " elements)");
00437 }
00438
00439
00440
00441
00442 void clientHandlingMainLoop(FileDescriptor &client) {
00443 TRACE_POINT();
00444 vector<string> args;
00445
00446 P_TRACE(4, "MessageServer client thread " << (int) client << " started.");
00447
00448 try {
00449 AccountPtr account(authenticate(client));
00450 if (account == NULL) {
00451 P_TRACE(4, "MessageServer client thread " << (int) client << " exited.");
00452 return;
00453 }
00454
00455 CommonClientContext commonContext(client, account);
00456 vector<ClientContextPtr> handlerSpecificContexts;
00457 broadcastNewClientEvent(commonContext, handlerSpecificContexts);
00458 DisconnectEventBroadcastGuard dguard(handlers, commonContext, handlerSpecificContexts);
00459
00460 while (!this_thread::interruption_requested()) {
00461 UPDATE_TRACE_POINT();
00462 if (!commonContext.channel.read(args)) {
00463
00464 break;
00465 }
00466
00467 P_TRACE(4, "MessageServer client " << commonContext.name() <<
00468 ": received message: " << toString(args));
00469
00470 UPDATE_TRACE_POINT();
00471 if (!processMessage(commonContext, handlerSpecificContexts, args)) {
00472 processUnknownMessage(commonContext, args);
00473 break;
00474 }
00475 args.clear();
00476 }
00477
00478 P_TRACE(4, "MessageServer client thread " << (int) client << " exited.");
00479 client.close();
00480 } catch (const boost::thread_interrupted &) {
00481 P_TRACE(2, "MessageServer client thread " << (int) client << " interrupted.");
00482 } catch (const tracable_exception &e) {
00483 P_TRACE(2, "An error occurred in a MessageServer client thread " << (int) client << ":\n"
00484 << " message: " << toString(args) << "\n"
00485 << " exception: " << e.what() << "\n"
00486 << " backtrace:\n" << e.backtrace());
00487 }
00488 }
00489
00490 public:
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503 MessageServer(const string &socketFilename, AccountsDatabasePtr accountsDatabase) {
00504 this->socketFilename = socketFilename;
00505 this->accountsDatabase = accountsDatabase;
00506 loginTimeout = 2000;
00507 startListening();
00508 }
00509
00510 ~MessageServer() {
00511 this_thread::disable_syscall_interruption dsi;
00512 syscalls::close(serverFd);
00513 syscalls::unlink(socketFilename.c_str());
00514 }
00515
00516 string getSocketFilename() const {
00517 return socketFilename;
00518 }
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529 void mainLoop() {
00530 TRACE_POINT();
00531 while (true) {
00532 this_thread::interruption_point();
00533 sockaddr_un addr;
00534 socklen_t len = sizeof(addr);
00535 FileDescriptor fd;
00536
00537 UPDATE_TRACE_POINT();
00538 fd = syscalls::accept(serverFd, (struct sockaddr *) &addr, &len);
00539 if (fd == -1) {
00540 throw SystemException("Unable to accept a new client", errno);
00541 }
00542
00543 UPDATE_TRACE_POINT();
00544 this_thread::disable_interruption di;
00545 this_thread::disable_syscall_interruption dsi;
00546
00547 function<void ()> func(boost::bind(&MessageServer::clientHandlingMainLoop,
00548 this, fd));
00549 string name = "MessageServer client thread ";
00550 name.append(toString(fd));
00551 threadGroup.create_thread(func, name, CLIENT_THREAD_STACK_SIZE);
00552 }
00553 }
00554
00555
00556
00557
00558
00559
00560 void addHandler(HandlerPtr handler) {
00561 handlers.push_back(handler);
00562 }
00563
00564
00565
00566
00567
00568
00569
00570
00571 void setLoginTimeout(unsigned long long timeout) {
00572 assert(timeout != 0);
00573 loginTimeout = timeout;
00574 }
00575 };
00576
00577 typedef shared_ptr<MessageServer> MessageServerPtr;
00578
00579 }
00580
00581 #endif