#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "main.h" #include "auth_manager.cpp" #include "client.cpp" #include "util.h" using namespace std::chrono_literals; std::mutex client_mutex; struct pollfd fds[MAX_CLIENTS]; struct ftpconn { Client* client; std::thread* thread; bool close = false; } fdc[MAX_CLIENTS]; void runClient(struct ftpconn* cfd) { if (!cfd) { logger->print(LOGLEVEL_ERROR, "Invalid connection handle"); return; } std::unique_lock lock(client_mutex); if (!cfd->client) { logger->print(LOGLEVEL_ERROR, "Invalid client handle"); return; } int client_sock = cfd->client->control_sock; Client* client = cfd->client; lock.unlock(); char inbuf[BUFFERSIZE]; logger->print(LOGLEVEL_DEBUG, "C(%i) Client initialized", client_sock); while (true) { memset(inbuf, 0, BUFFERSIZE); if (fcntl(client_sock, F_GETFD) < 0) { logger->print(LOGLEVEL_DEBUG, "C(%i) Socket closed", client_sock); break; } struct timeval tv; tv.tv_sec = 60; tv.tv_usec = 0; fd_set readfds, writefds; FD_ZERO(&readfds); FD_ZERO(&writefds); FD_SET(client_sock, &readfds); // Add socket to writefds if SSL wants to write if (client->isSecure() && client->getSSL()) { FD_SET(client_sock, &writefds); } int select_result = select(client_sock + 1, &readfds, &writefds, NULL, &tv); if (select_result < 0) { if (errno == EINTR) continue; logger->print(LOGLEVEL_ERROR, "C(%i) Select failed: %s", client_sock, strerror(errno)); break; } if (select_result == 0) { logger->print(LOGLEVEL_DEBUG, "C(%i) Connection timeout", client_sock); break; } int rc; if (client->isSecure() && client->getSSL()) { if (!client->isHandshakeComplete()) { // Continue SSL handshake ERR_clear_error(); // Clear any previous errors int ret = SSL_accept(client->getSSL()); if (ret <= 0) { int ssl_err = SSL_get_error(client->getSSL(), ret); if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { continue; // Need more data for handshake } unsigned long err = ERR_get_error(); char err_buf[256]; ERR_error_string_n(err, err_buf, sizeof(err_buf)); logger->print(LOGLEVEL_ERROR, "C(%i) SSL handshake failed with error: %d (%s)", client_sock, ssl_err, err_buf); break; } client->setHandshakeComplete(true); logger->print(LOGLEVEL_DEBUG, "C(%i) SSL handshake completed", client_sock); continue; } else { // Normal SSL read after handshake ERR_clear_error(); // Clear any previous errors rc = SSL_read(client->getSSL(), inbuf, sizeof(inbuf) - 1); if (rc <= 0) { int ssl_err = SSL_get_error(client->getSSL(), rc); if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { continue; } if (ssl_err == SSL_ERROR_SYSCALL) { unsigned long err = ERR_get_error(); if (err == 0 && rc == 0) { logger->print(LOGLEVEL_DEBUG, "C(%i) SSL connection closed", client_sock); } else { char err_buf[256]; ERR_error_string_n(err, err_buf, sizeof(err_buf)); logger->print(LOGLEVEL_ERROR, "C(%i) SSL_read syscall error: %s", client_sock, err_buf); } } else { unsigned long err = ERR_get_error(); char err_buf[256]; ERR_error_string_n(err, err_buf, sizeof(err_buf)); logger->print(LOGLEVEL_ERROR, "C(%i) SSL_read failed with error: %d (%s)", client_sock, ssl_err, err_buf); } break; } } } else { rc = recv(client_sock, inbuf, sizeof(inbuf) - 1, 0); if (rc <= 0) { if (rc == 0) { logger->print(LOGLEVEL_DEBUG, "C(%i) Client disconnected", client_sock); } else { logger->print(LOGLEVEL_ERROR, "C(%i) Recv failed: %s", client_sock, strerror(errno)); } break; } } inbuf[rc] = '\0'; if (rc >= 2 && inbuf[rc-2] == '\r' && inbuf[rc-1] == '\n') { rc -= 2; inbuf[rc] = '\0'; } std::string input(inbuf, rc); logger->print(LOGLEVEL_DEBUG, "C(%i) >> %s", client_sock, input.c_str()); std::string::size_type space_pos = input.find(" "); std::string cmd = space_pos != std::string::npos ? toUpper(input.substr(0, space_pos)) : toUpper(input); std::string args = space_pos != std::string::npos ? input.substr(space_pos + 1) : ""; lock.lock(); if (!cfd->client) { lock.unlock(); break; } int revc = client->receive(cmd, args); lock.unlock(); if (revc != 0) break; } // Mark for cleanup logger->print(LOGLEVEL_DEBUG, "C(%i) Client thread ending", client_sock); cfd->close = true; } void initializeAuth() { AuthManager& auth_manager = AuthManager::getInstance(); auth_manager.setLogger(logger); // Load auth plugins from plugin directory std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR); // Load any additional plugins from plugin directory for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) { if (entry.path().extension() == ".so" && entry.path().filename().string().find("libauth_") == 0) { auth_manager.loadPlugin(entry.path().string()); } } // Create auth instance based on config std::string auth_type = config->getValue("core", "auth_engine", "pam"); auth = auth_manager.createAuth(auth_type); if (!auth) { logger->print(LOGLEVEL_CRITICAL, "Failed to create auth engine: %s", auth_type.c_str()); exit(1); } // Initialize auth plugin with config std::map auth_config = config->get("auth_"+auth_type)->get(); if (!auth->initialize(auth_config)) { logger->print(LOGLEVEL_CRITICAL, "Failed to initialize auth engine: %s", auth_type.c_str()); exit(1); } } void initializeFiler() { FilerManager& filer_manager = FilerManager::getInstance(); filer_manager.setLogger(logger); std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR); for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) { if (entry.path().extension() == ".so" && entry.path().filename().string().find("libfiler_") == 0) { filer_manager.loadPlugin(entry.path().string()); } } std::string filer_type = config->getValue("core", "filer_engine", "local"); default_filer_factory = filer_manager.getFactory(filer_type); if (!default_filer_factory) { logger->print(LOGLEVEL_CRITICAL, "Failed to get filer factory for type: %s", filer_type.c_str()); exit(1); } std::map filer_config = config->get("filer_"+filer_type)->get(); // Test create using the factory Filer* test_filer = default_filer_factory(); if (!test_filer) { logger->print(LOGLEVEL_CRITICAL, "Failed to create filer instance"); exit(1); } delete test_filer; } int main(int argc , char *argv[]) { printf("%s %s Copyright (C) 2024 Worlio LLC\n", APPNAME, VERSION); printf("This program comes with ABSOLUTELY NO WARRANTY.\n"); printf("This is free software, and you are welcome to redistribute it under certain conditions.\n\n"); // SIGNALS signal(SIGPIPE, SIG_IGN); config = new ConfigFile(concatPath(std::string(CONFIG_DIR), "ftp.conf")); server_name = config->getValue("core", "server_name", "digFTP %v"); sscanf( config->getValue("net", "listen", "127.0.0.1:21").c_str(), "%u.%u.%u.%u:%u", &server_address[0], &server_address[1], &server_address[2], &server_address[3], &server_port ); logger = new Logger(); std::string mainLogFile = config->getValue("logging", "all", ""); logger->openFileOnLevel(LOGLEVEL_MAX, mainLogFile.c_str()); logger->openFileOnLevel(LOGLEVEL_DEBUG, config->getValue("logging", "debug", mainLogFile).c_str()); logger->openFileOnLevel(LOGLEVEL_INFO, config->getValue("logging", "info", mainLogFile).c_str()); logger->openFileOnLevel(LOGLEVEL_WARNING, config->getValue("logging", "warning", mainLogFile).c_str()); logger->openFileOnLevel(LOGLEVEL_ERROR, config->getValue("logging", "error", mainLogFile).c_str()); logger->openFileOnLevel(LOGLEVEL_CRITICAL, config->getValue("logging", "critical", mainLogFile).c_str()); if (config->getBool("net", "ssl", false)) { std::string cert_file = config->getValue("ssl", "certificate", "cert.pem"); if (cert_file[0] != '/') cert_file = concatPath(std::string(CONFIG_DIR), cert_file); logger->print(LOGLEVEL_INFO, "Using certificate file: %s", cert_file.c_str()); std::string key_file = config->getValue("ssl", "private_key", "key.pem"); if (key_file[0] != '/') key_file = concatPath(std::string(CONFIG_DIR), key_file); logger->print(LOGLEVEL_INFO, "Using private key file: %s", key_file.c_str()); if (!SSLManager::getInstance().initialize(cert_file, key_file)) { logger->print(LOGLEVEL_CRITICAL, "Failed to initialize SSL"); return 1; } // set configured ciphers SSLManager::getInstance().setCiphers(config->getValue("ssl", "ciphers", "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH").c_str()); // handle configured flags bool ssl_flags = SSL_OP_NO_TICKET; if (!config->getBool("ssl", "ssl_v2", false)) ssl_flags+=SSL_OP_NO_SSLv2; if (!config->getBool("ssl", "ssl_v3", false)) ssl_flags+=SSL_OP_NO_SSLv3; if (!config->getBool("ssl", "tls_v1_0", false)) ssl_flags+=SSL_OP_NO_TLSv1; if (!config->getBool("ssl", "tls_v1_1", false)) ssl_flags+=SSL_OP_NO_TLSv1_1; if (!config->getBool("ssl", "tls_v1_2", true)) ssl_flags+=SSL_OP_NO_TLSv1_2; if (!config->getBool("ssl", "tls_v1_3", true)) ssl_flags+=SSL_OP_NO_TLSv1_3; if ((ssl_flags & SSL_OP_NO_SSLv2) && (ssl_flags & SSL_OP_NO_SSLv3) && (ssl_flags & SSL_OP_NO_TLSv1) && (ssl_flags & SSL_OP_NO_TLSv1_1) && (ssl_flags & SSL_OP_NO_TLSv1_2) && (ssl_flags & SSL_OP_NO_TLSv1_3)) { logger->print(LOGLEVEL_WARNING, "All SSL/TLS protocols disabled. You're a mad man!"); } if (!config->getBool("ssl", "compression", true)) ssl_flags+=SSL_OP_NO_COMPRESSION; if (config->getBool("ssl", "prefer_server_ciphers", true)) ssl_flags+=SSL_OP_CIPHER_SERVER_PREFERENCE; SSLManager::getInstance().setFlags(ssl_flags); } initializeAuth(); initializeFiler(); int opt = 1, master_socket = -1, newsock = -1, nfds = 1, current_size = 0, src = 0; runServer = true; runCompression = false; struct sockaddr_in ctrl_address; if ((master_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) { logger->print(LOGLEVEL_CRITICAL, "Failed creating socket"); return master_socket; } if ((src = setsockopt(master_socket, SOL_SOCKET, SO_REUSEADDR, (char *)&opt, sizeof(opt))) < 0) { logger->print(LOGLEVEL_CRITICAL, "Unable to configure socket"); close(master_socket); return src; } if ((src = ioctl(master_socket, FIONBIO, (char *)&opt)) < 0) { logger->print(LOGLEVEL_CRITICAL, "Unable to read socket"); close(master_socket); return src; } ctrl_address.sin_family = AF_INET; ctrl_address.sin_addr.s_addr = INADDR_ANY; ctrl_address.sin_port = htons(server_port); if ((src = bind(master_socket, (struct sockaddr *)&ctrl_address, sizeof(ctrl_address))) < 0) { logger->print( LOGLEVEL_CRITICAL, "Bind to %i.%i.%i.%i:%i failed", &server_address[0], &server_address[1], &server_address[2], &server_address[3], server_port ); close(master_socket); return src; } if ((src = listen(master_socket, 3)) < 0) { logger->print(LOGLEVEL_CRITICAL, "Unable to listen to socket"); close(master_socket); return src; } memset(fds, 0, sizeof(fds)); memset(fdc, 0, sizeof(fdc)); fds[0].fd = master_socket; fds[0].events = POLLIN; logger->print(LOGLEVEL_INFO, "Server started."); while (runServer) { int pc = poll(fds, nfds, -1); if (pc < 0) { if (errno == EINTR) continue; logger->print(LOGLEVEL_CRITICAL, "Connection poll faced a fatal error"); break; } if (pc == 0) continue; current_size = nfds; for (int i = 0; i < current_size; i++) { if (fds[i].revents == 0) continue; // Handle poll errors properly without skipping cleanup if (fds[i].revents != POLLIN) { if (fds[i].fd != master_socket) { logger->print(LOGLEVEL_ERROR, "Poll error on fd %d", fds[i].fd); fdc[i].close = true; } } if (fds[i].fd == master_socket && fds[i].revents == POLLIN) { // Handle new connections do { newsock = accept(master_socket, NULL, NULL); if (newsock < 0) { if (errno != EWOULDBLOCK) { logger->print(LOGLEVEL_ERROR, "accept() failed: %s", strerror(errno)); runServer = false; } break; } // Find first available slot int slot = -1; for (int j = 1; j < MAX_CLIENTS; j++) { if (fds[j].fd <= 0) { // Changed from < 0 to <= 0 slot = j; break; } } if (slot < 0) { logger->print(LOGLEVEL_ERROR, "No free slots available"); close(newsock); continue; } // Set non-blocking mode int flags = fcntl(newsock, F_GETFL, 0); fcntl(newsock, F_SETFL, flags | O_NONBLOCK); logger->print(LOGLEVEL_DEBUG, "C(%i) Accepted client in slot %d", newsock, slot); fds[slot].fd = newsock; fds[slot].events = POLLIN; fds[slot].revents = 0; fdc[slot].client = new Client(newsock); fdc[slot].client->addOption("SIZE", true); fdc[slot].client->addOption("UTF8", config->getBool("features", "utf8", true)); fdc[slot].thread = new std::thread(runClient, &fdc[slot]); fdc[slot].close = false; if (slot >= nfds) { nfds = slot + 1; } } while (newsock != -1); } // Handle cleanup for any connections marked for closing if (fds[i].fd != master_socket && (fdc[i].close || fds[i].revents != POLLIN)) { int fd = fds[i].fd; logger->print(LOGLEVEL_DEBUG, "C(%i) Cleaning up client in slot %d", fd, i); // Close socket if (fd > 0) { shutdown(fd, SHUT_RDWR); close(fd); } // Clean up thread if (fdc[i].thread) { if (fdc[i].thread->joinable()) { fdc[i].thread->join(); } delete fdc[i].thread; } // Clean up client delete fdc[i].client; // Reset slot fds[i].fd = -1; fds[i].events = 0; fds[i].revents = 0; memset(&fdc[i], 0, sizeof(struct ftpconn)); logger->print(LOGLEVEL_DEBUG, "C(%i) Cleanup completed", fd); // Recalculate nfds if needed if (i == nfds - 1) { for (int j = nfds - 1; j >= 0; j--) { if (fds[j].fd != -1) { nfds = j + 1; break; } } } } } } // Cleanup for (int i = 0; i < nfds; i++) { if (fds[i].fd >= 0) { close(fds[i].fd); if (fdc[i].thread && fdc[i].thread->joinable()) { fdc[i].thread->join(); } delete fdc[i].thread; delete fdc[i].client; } } close(master_socket); logger->print(LOGLEVEL_INFO, "Server closing..."); logger->close(); return 0; }