Files
digFTP/src/main.cpp
Wirlaburla 5ff35c26ab * Added plugin support
* Improved Filer (LocalFiler)
* Supports Explicit SSL/TLS with OpenSSL/Crypto
* Fixed most of the bugs that drove me nuts.
* Properly handled socket closes
* Improved socket cleanup
* Buffered file downloads
* Lots 'o errors!

I did it again! I made a single commit with everything!
2024-12-13 10:51:10 -06:00

522 lines
15 KiB
C++

#include <iostream>
#include <string>
#include <unistd.h>
#include <cstring>
#include <future>
#include <thread>
#include <chrono>
#include <stdio.h>
#include <stdlib.h>
#include <signal.h>
#include <mutex>
#include <memory>
#include <system_error>
#include <sys/ioctl.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <errno.h>
#include <fcntl.h>
#include "main.h"
#include "auth_manager.cpp"
#include "client.cpp"
#include "util.h"
using namespace std::chrono_literals;
// Per-slot connection state. Entry i of fdc belongs to the descriptor stored
// in fds[i]; slot 0 of fds is the listening socket and has no client.
struct ftpconn {
    Client* client;       // created by main() on accept, deleted by main() on cleanup
    std::thread* thread;  // worker running runClient() for this slot
    // Set by the worker when it finishes (and by main() on poll errors) to
    // request cleanup of the slot.
    // NOTE(review): read and written from two threads without synchronization.
    bool close = false;
};

// Held by runClient() while it checks cfd->client and calls Client::receive().
std::mutex client_mutex;

// Descriptor set handed to poll() by main().
pollfd fds[MAX_CLIENTS];

// Connection state parallel to fds.
ftpconn fdc[MAX_CLIENTS];
void runClient(struct ftpconn* cfd) {
if (!cfd) {
logger->print(LOGLEVEL_ERROR, "Invalid connection handle");
return;
}
std::unique_lock<std::mutex> lock(client_mutex);
if (!cfd->client) {
logger->print(LOGLEVEL_ERROR, "Invalid client handle");
return;
}
int client_sock = cfd->client->control_sock;
Client* client = cfd->client;
lock.unlock();
char inbuf[BUFFERSIZE];
logger->print(LOGLEVEL_DEBUG, "C(%i) Client initialized", client_sock);
while (true) {
memset(inbuf, 0, BUFFERSIZE);
if (fcntl(client_sock, F_GETFD) < 0) {
logger->print(LOGLEVEL_DEBUG, "C(%i) Socket closed", client_sock);
break;
}
struct timeval tv;
tv.tv_sec = 60;
tv.tv_usec = 0;
fd_set readfds, writefds;
FD_ZERO(&readfds);
FD_ZERO(&writefds);
FD_SET(client_sock, &readfds);
// Add socket to writefds if SSL wants to write
if (client->isSecure() && client->getSSL()) {
FD_SET(client_sock, &writefds);
}
int select_result = select(client_sock + 1, &readfds, &writefds, NULL, &tv);
if (select_result < 0) {
if (errno == EINTR) continue;
logger->print(LOGLEVEL_ERROR, "C(%i) Select failed: %s", client_sock, strerror(errno));
break;
}
if (select_result == 0) {
logger->print(LOGLEVEL_DEBUG, "C(%i) Connection timeout", client_sock);
break;
}
int rc;
if (client->isSecure() && client->getSSL()) {
if (!client->isHandshakeComplete()) {
// Continue SSL handshake
ERR_clear_error(); // Clear any previous errors
int ret = SSL_accept(client->getSSL());
if (ret <= 0) {
int ssl_err = SSL_get_error(client->getSSL(), ret);
if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
continue; // Need more data for handshake
}
unsigned long err = ERR_get_error();
char err_buf[256];
ERR_error_string_n(err, err_buf, sizeof(err_buf));
logger->print(LOGLEVEL_ERROR, "C(%i) SSL handshake failed with error: %d (%s)",
client_sock, ssl_err, err_buf);
break;
}
client->setHandshakeComplete(true);
logger->print(LOGLEVEL_DEBUG, "C(%i) SSL handshake completed", client_sock);
continue;
} else {
// Normal SSL read after handshake
ERR_clear_error(); // Clear any previous errors
rc = SSL_read(client->getSSL(), inbuf, sizeof(inbuf) - 1);
if (rc <= 0) {
int ssl_err = SSL_get_error(client->getSSL(), rc);
if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
continue;
}
if (ssl_err == SSL_ERROR_SYSCALL) {
unsigned long err = ERR_get_error();
if (err == 0 && rc == 0) {
logger->print(LOGLEVEL_DEBUG, "C(%i) SSL connection closed", client_sock);
} else {
char err_buf[256];
ERR_error_string_n(err, err_buf, sizeof(err_buf));
logger->print(LOGLEVEL_ERROR, "C(%i) SSL_read syscall error: %s",
client_sock, err_buf);
}
} else {
unsigned long err = ERR_get_error();
char err_buf[256];
ERR_error_string_n(err, err_buf, sizeof(err_buf));
logger->print(LOGLEVEL_ERROR, "C(%i) SSL_read failed with error: %d (%s)",
client_sock, ssl_err, err_buf);
}
break;
}
}
} else {
rc = recv(client_sock, inbuf, sizeof(inbuf) - 1, 0);
if (rc <= 0) {
if (rc == 0) {
logger->print(LOGLEVEL_DEBUG, "C(%i) Client disconnected", client_sock);
} else {
logger->print(LOGLEVEL_ERROR, "C(%i) Recv failed: %s", client_sock, strerror(errno));
}
break;
}
}
inbuf[rc] = '\0';
if (rc >= 2 && inbuf[rc-2] == '\r' && inbuf[rc-1] == '\n') {
rc -= 2;
inbuf[rc] = '\0';
}
std::string input(inbuf, rc);
logger->print(LOGLEVEL_DEBUG, "C(%i) >> %s", client_sock, input.c_str());
std::string::size_type space_pos = input.find(" ");
std::string cmd = space_pos != std::string::npos ?
toUpper(input.substr(0, space_pos)) : toUpper(input);
std::string args = space_pos != std::string::npos ?
input.substr(space_pos + 1) : "";
lock.lock();
if (!cfd->client) {
lock.unlock();
break;
}
int revc = client->receive(cmd, args);
lock.unlock();
if (revc != 0) break;
}
// Mark for cleanup
logger->print(LOGLEVEL_DEBUG, "C(%i) Client thread ending", client_sock);
cfd->close = true;
}
void initializeAuth() {
AuthManager& auth_manager = AuthManager::getInstance();
auth_manager.setLogger(logger);
// Load auth plugins from plugin directory
std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR);
// Load any additional plugins from plugin directory
for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) {
if (entry.path().extension() == ".so" &&
entry.path().filename().string().find("libauth_") == 0) {
auth_manager.loadPlugin(entry.path().string());
}
}
// Create auth instance based on config
std::string auth_type = config->getValue("core", "auth_engine", "pam");
auth = auth_manager.createAuth(auth_type);
if (!auth) {
logger->print(LOGLEVEL_CRITICAL, "Failed to create auth engine: %s", auth_type.c_str());
exit(1);
}
// Initialize auth plugin with config
std::map<std::string, std::string> auth_config = config->get("auth_"+auth_type)->get();
if (!auth->initialize(auth_config)) {
logger->print(LOGLEVEL_CRITICAL, "Failed to initialize auth engine: %s", auth_type.c_str());
exit(1);
}
}
void initializeFiler() {
FilerManager& filer_manager = FilerManager::getInstance();
filer_manager.setLogger(logger);
std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR);
for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) {
if (entry.path().extension() == ".so" &&
entry.path().filename().string().find("libfiler_") == 0) {
filer_manager.loadPlugin(entry.path().string());
}
}
std::string filer_type = config->getValue("core", "filer_engine", "local");
default_filer_factory = filer_manager.getFactory(filer_type);
if (!default_filer_factory) {
logger->print(LOGLEVEL_CRITICAL, "Failed to get filer factory for type: %s", filer_type.c_str());
exit(1);
}
std::map<std::string, std::string> filer_config = config->get("filer_"+filer_type)->get();
// Test create using the factory
Filer* test_filer = default_filer_factory();
if (!test_filer) {
logger->print(LOGLEVEL_CRITICAL, "Failed to create filer instance");
exit(1);
}
delete test_filer;
}
int main(int argc , char *argv[]) {
printf("%s %s Copyright (C) 2024 Worlio LLC\n", APPNAME, VERSION);
printf("This program comes with ABSOLUTELY NO WARRANTY.\n");
printf("This is free software, and you are welcome to redistribute it under certain conditions.\n\n");
// SIGNALS
signal(SIGPIPE, SIG_IGN);
config = new ConfigFile(concatPath(std::string(CONFIG_DIR), "ftp.conf"));
server_name = config->getValue("core", "server_name", "digFTP %v");
sscanf(
config->getValue("net", "listen", "127.0.0.1:21").c_str(),
"%u.%u.%u.%u:%u",
&server_address[0],
&server_address[1],
&server_address[2],
&server_address[3],
&server_port
);
logger = new Logger();
std::string mainLogFile = config->getValue("logging", "all", "");
logger->openFileOnLevel(LOGLEVEL_MAX, mainLogFile.c_str());
logger->openFileOnLevel(LOGLEVEL_DEBUG, config->getValue("logging", "debug", mainLogFile).c_str());
logger->openFileOnLevel(LOGLEVEL_INFO, config->getValue("logging", "info", mainLogFile).c_str());
logger->openFileOnLevel(LOGLEVEL_WARNING, config->getValue("logging", "warning", mainLogFile).c_str());
logger->openFileOnLevel(LOGLEVEL_ERROR, config->getValue("logging", "error", mainLogFile).c_str());
logger->openFileOnLevel(LOGLEVEL_CRITICAL, config->getValue("logging", "critical", mainLogFile).c_str());
if (config->getBool("net", "ssl", false)) {
std::string cert_file = config->getValue("ssl", "certificate", "cert.pem");
if (cert_file[0] != '/')
cert_file = concatPath(std::string(CONFIG_DIR), cert_file);
logger->print(LOGLEVEL_INFO, "Using certificate file: %s", cert_file.c_str());
std::string key_file = config->getValue("ssl", "private_key", "key.pem");
if (key_file[0] != '/')
key_file = concatPath(std::string(CONFIG_DIR), key_file);
logger->print(LOGLEVEL_INFO, "Using private key file: %s", key_file.c_str());
if (!SSLManager::getInstance().initialize(cert_file, key_file)) {
logger->print(LOGLEVEL_CRITICAL, "Failed to initialize SSL");
return 1;
}
// set configured ciphers
SSLManager::getInstance().setCiphers(config->getValue("ssl", "ciphers", "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH").c_str());
// handle configured flags
bool ssl_flags = SSL_OP_NO_TICKET;
if (!config->getBool("ssl", "ssl_v2", false))
ssl_flags+=SSL_OP_NO_SSLv2;
if (!config->getBool("ssl", "ssl_v3", false))
ssl_flags+=SSL_OP_NO_SSLv3;
if (!config->getBool("ssl", "tls_v1_0", false))
ssl_flags+=SSL_OP_NO_TLSv1;
if (!config->getBool("ssl", "tls_v1_1", false))
ssl_flags+=SSL_OP_NO_TLSv1_1;
if (!config->getBool("ssl", "tls_v1_2", true))
ssl_flags+=SSL_OP_NO_TLSv1_2;
if (!config->getBool("ssl", "tls_v1_3", true))
ssl_flags+=SSL_OP_NO_TLSv1_3;
if ((ssl_flags & SSL_OP_NO_SSLv2) &&
(ssl_flags & SSL_OP_NO_SSLv3) &&
(ssl_flags & SSL_OP_NO_TLSv1) &&
(ssl_flags & SSL_OP_NO_TLSv1_1) &&
(ssl_flags & SSL_OP_NO_TLSv1_2) &&
(ssl_flags & SSL_OP_NO_TLSv1_3)) {
logger->print(LOGLEVEL_WARNING, "All SSL/TLS protocols disabled. You're a mad man!");
}
if (!config->getBool("ssl", "compression", true))
ssl_flags+=SSL_OP_NO_COMPRESSION;
if (config->getBool("ssl", "prefer_server_ciphers", true))
ssl_flags+=SSL_OP_CIPHER_SERVER_PREFERENCE;
SSLManager::getInstance().setFlags(ssl_flags);
}
initializeAuth();
initializeFiler();
int opt = 1,
master_socket = -1,
newsock = -1,
nfds = 1,
current_size = 0,
src = 0;
runServer = true;
runCompression = false;
struct sockaddr_in ctrl_address;
if ((master_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
logger->print(LOGLEVEL_CRITICAL, "Failed creating socket");
return master_socket;
}
if ((src = setsockopt(master_socket, SOL_SOCKET, SO_REUSEADDR, (char *)&opt, sizeof(opt))) < 0) {
logger->print(LOGLEVEL_CRITICAL, "Unable to configure socket");
close(master_socket);
return src;
}
if ((src = ioctl(master_socket, FIONBIO, (char *)&opt)) < 0) {
logger->print(LOGLEVEL_CRITICAL, "Unable to read socket");
close(master_socket);
return src;
}
ctrl_address.sin_family = AF_INET;
ctrl_address.sin_addr.s_addr = INADDR_ANY;
ctrl_address.sin_port = htons(server_port);
if ((src = bind(master_socket, (struct sockaddr *)&ctrl_address, sizeof(ctrl_address))) < 0) {
logger->print(
LOGLEVEL_CRITICAL,
"Bind to %i.%i.%i.%i:%i failed",
&server_address[0],
&server_address[1],
&server_address[2],
&server_address[3],
server_port
);
close(master_socket);
return src;
}
if ((src = listen(master_socket, 3)) < 0) {
logger->print(LOGLEVEL_CRITICAL, "Unable to listen to socket");
close(master_socket);
return src;
}
memset(fds, 0, sizeof(fds));
memset(fdc, 0, sizeof(fdc));
fds[0].fd = master_socket;
fds[0].events = POLLIN;
logger->print(LOGLEVEL_INFO, "Server started.");
while (runServer) {
int pc = poll(fds, nfds, -1);
if (pc < 0) {
if (errno == EINTR) continue;
logger->print(LOGLEVEL_CRITICAL, "Connection poll faced a fatal error");
break;
}
if (pc == 0) continue;
current_size = nfds;
for (int i = 0; i < current_size; i++) {
if (fds[i].revents == 0)
continue;
// Handle poll errors properly without skipping cleanup
if (fds[i].revents != POLLIN) {
if (fds[i].fd != master_socket) {
logger->print(LOGLEVEL_ERROR, "Poll error on fd %d", fds[i].fd);
fdc[i].close = true;
}
}
if (fds[i].fd == master_socket && fds[i].revents == POLLIN) {
// Handle new connections
do {
newsock = accept(master_socket, NULL, NULL);
if (newsock < 0) {
if (errno != EWOULDBLOCK) {
logger->print(LOGLEVEL_ERROR, "accept() failed: %s", strerror(errno));
runServer = false;
}
break;
}
// Find first available slot
int slot = -1;
for (int j = 1; j < MAX_CLIENTS; j++) {
if (fds[j].fd <= 0) { // Changed from < 0 to <= 0
slot = j;
break;
}
}
if (slot < 0) {
logger->print(LOGLEVEL_ERROR, "No free slots available");
close(newsock);
continue;
}
// Set non-blocking mode
int flags = fcntl(newsock, F_GETFL, 0);
fcntl(newsock, F_SETFL, flags | O_NONBLOCK);
logger->print(LOGLEVEL_DEBUG, "C(%i) Accepted client in slot %d", newsock, slot);
fds[slot].fd = newsock;
fds[slot].events = POLLIN;
fds[slot].revents = 0;
fdc[slot].client = new Client(newsock);
fdc[slot].client->addOption("SIZE", true);
fdc[slot].client->addOption("UTF8", config->getBool("features", "utf8", true));
fdc[slot].thread = new std::thread(runClient, &fdc[slot]);
fdc[slot].close = false;
if (slot >= nfds) {
nfds = slot + 1;
}
} while (newsock != -1);
}
// Handle cleanup for any connections marked for closing
if (fds[i].fd != master_socket && (fdc[i].close || fds[i].revents != POLLIN)) {
int fd = fds[i].fd;
logger->print(LOGLEVEL_DEBUG, "C(%i) Cleaning up client in slot %d", fd, i);
// Close socket
if (fd > 0) {
shutdown(fd, SHUT_RDWR);
close(fd);
}
// Clean up thread
if (fdc[i].thread) {
if (fdc[i].thread->joinable()) {
fdc[i].thread->join();
}
delete fdc[i].thread;
}
// Clean up client
delete fdc[i].client;
// Reset slot
fds[i].fd = -1;
fds[i].events = 0;
fds[i].revents = 0;
memset(&fdc[i], 0, sizeof(struct ftpconn));
logger->print(LOGLEVEL_DEBUG, "C(%i) Cleanup completed", fd);
// Recalculate nfds if needed
if (i == nfds - 1) {
for (int j = nfds - 1; j >= 0; j--) {
if (fds[j].fd != -1) {
nfds = j + 1;
break;
}
}
}
}
}
}
// Cleanup
for (int i = 0; i < nfds; i++) {
if (fds[i].fd >= 0) {
close(fds[i].fd);
if (fdc[i].thread && fdc[i].thread->joinable()) {
fdc[i].thread->join();
}
delete fdc[i].thread;
delete fdc[i].client;
}
}
close(master_socket);
logger->print(LOGLEVEL_INFO, "Server closing...");
logger->close();
return 0;
}