diff --git a/.gitignore b/.gitignore index bd83400..c06455f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,7 @@ +# Build Directory build/ -*.bak + +# Project Files +*.cbp +*.depend +*.layout diff --git a/CMakeLists.txt b/CMakeLists.txt index c6aab64..f1a8054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,12 +1,107 @@ -CMAKE_MINIMUM_REQUIRED(VERSION 3.26) -PROJECT(digftp) +cmake_minimum_required(VERSION 3.26) -SET(CMAKE_CXX_STANDARD 20) -SET(CMAKE_CXX_STANDARD_REQUIRED True) -SET(CMAKE_CXX_FLAGS "-O3") -SET(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) -INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}) +# Project definition +project(digftp + VERSION 1.0.0 + DESCRIPTION "A modern FTP server with plugin support" + LANGUAGES C CXX +) -ADD_SUBDIRECTORY(src) -FILE(COPY conf DESTINATION ${CMAKE_BINARY_DIR}) -INSTALL(TARGETS ${PROJECT_NAME} RUNTIME DESTINATION bin) \ No newline at end of file +# Global settings +set(APPNAME "digFTP") +include(GNUInstallDirs) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +# C++ standard requirements +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + +# Build type configuration +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING "Choose the type of build" FORCE) +endif() + +# Version handling +if(NOT CMAKE_BUILD_TYPE MATCHES "Release") + if(NOT BUILD_VERSION) + find_package(Git QUIET) + if(GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") + execute_process( + COMMAND ${GIT_EXECUTABLE} rev-parse --short HEAD + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + OUTPUT_VARIABLE BUILD_VERSION + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + else() + set(BUILD_VERSION "unknown") + endif() + endif() + set(VERSION "${BUILD_VERSION}") +endif() + +# Compile Options +option(WITH_SSL "Compile with SSL support" ON) + +# Path variables +set(CONFIG_PATH "${CMAKE_INSTALL_SYSCONFDIR}/${PROJECT_NAME}") +set(PLUGIN_PATH "${CMAKE_INSTALL_LIBDIR}/${PROJECT_NAME}/plugins") + +# Configure build header +configure_file( + "${PROJECT_SOURCE_DIR}/src/build.h.in" + "${PROJECT_BINARY_DIR}/build.h" + @ONLY +) + +# Add compile definitions +add_compile_definitions( + PLUGIN_DIR="${CMAKE_INSTALL_PREFIX}/${PLUGIN_PATH}" + CONFIG_DIR="${CMAKE_INSTALL_PREFIX}/${CONFIG_PATH}" + WITH_SSL=${WITH_SSL} +) + +# Dependencies +find_package(Threads REQUIRED) +if(WITH_SSL) + find_package(OpenSSL REQUIRED) + include_directories(${OPENSSL_INCLUDE_DIR}) +endif() + +# Output directories +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") + +# Include directories +include_directories(${CMAKE_BINARY_DIR}) + +# Define plugins list +set(PLUGINS + auth_pam + auth_passdb + filer_local +) + +# Add main source directory which includes plugins +add_subdirectory(src) + +# Installation +install( + FILES + "${PROJECT_SOURCE_DIR}/conf/ftp.conf" + "${PROJECT_SOURCE_DIR}/conf/motd" + DESTINATION + "${CONFIG_PATH}" +) + +install( + TARGETS ${PROJECT_NAME} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) + +install( + TARGETS ${PLUGINS} + LIBRARY DESTINATION ${PLUGIN_PATH} +) \ No newline at end of file diff --git a/cmake/FindPAM.cmake b/cmake/FindPAM.cmake new file mode 100644 index 0000000..9f12000 --- /dev/null +++ b/cmake/FindPAM.cmake @@ -0,0 +1,37 @@ +# Find PAM headers +find_path(PAM_INCLUDE_DIR + NAMES security/pam_appl.h + PATHS /usr/include + /usr/local/include +) + +# Find PAM library +find_library(PAM_LIBRARY + NAMES pam + PATHS /usr/lib + /usr/lib64 + /usr/local/lib + /usr/local/lib64 +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(PAM + REQUIRED_VARS + PAM_LIBRARY + PAM_INCLUDE_DIR +) + +if(PAM_FOUND) + set(PAM_LIBRARIES ${PAM_LIBRARY}) + set(PAM_INCLUDE_DIRS ${PAM_INCLUDE_DIR}) + + if(NOT TARGET PAM::PAM) + add_library(PAM::PAM UNKNOWN IMPORTED) + set_target_properties(PAM::PAM PROPERTIES + IMPORTED_LOCATION "${PAM_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${PAM_INCLUDE_DIR}" + ) + endif() +endif() + +mark_as_advanced(PAM_INCLUDE_DIR PAM_LIBRARY) \ No newline at end of file diff --git a/conf/ftp.conf b/conf/ftp.conf index 0fa9ded..0f34f3f 100644 --- a/conf/ftp.conf +++ b/conf/ftp.conf @@ -1,18 +1,99 @@ -[main] -server_name=My FTP Server -motd_file=motd -auth_engine=noauth -[logging] -file=digftp.log -level=0 +## DigFTP Configuration File +## ========================= +## This is the primary configuration file. Here is where you'll find all the +## necessary options to configure your FTP server. + +## == String Formatting +## These may be used in certain options to be runtime replaced. +## %u Username +## %p Password +## %v Version +## %h Hostname + +## == Booleans +## true | false +## ----- | ----- +## 1 | 0 +## on | off +## yes | no +## Anything unrecognized will be treated as 'false'. + +## CORE OPTIONS +## These affect the very core of the server. +[core] +## The name of the server it sends to the client. +## Syntax: server_name= +server_name=digFTP %v + +## MotD command to post the output to clients after log in. +## WARNING: This method can be insecure. Use with caution. +## Syntax: motd_command= +#motd_command=cowsay -r Welcome %u. + +## MotD file to post to clients after log in. Overrides `motd_command`. +## Syntax: motd_file= +#motd_file=motd + +## MotD text to post to clients after log in. Overrides `motd_command` and +## `motd_file`. +## Syntax: motd_text= +#motd_text= + +## Path to digFTP plugins +## Syntax: plugin_path= +plugin_path=/usr/lib/digftp/plugins + +## Authentication Engine to use for logging in. +## Syntax: auth_engine= + +auth_engine=local + +## Filer engine to use for the clients file system. +## Syntax: file_engine= +## Possible values are: local +filer_engine=local + + [net] -listen_address=127.0.0.1 -control_port=21 -max_clients=255 -[noauth] -user_root=/home/%u/ -chroot=1 +## Network address and port to listen on. +## Syntax: listen=: +listen=127.0.0.1:21 + +## Whether to support SSL. Server must be compiled with WITH_SSL=ON. +## Syntax: ssl= +ssl=on + +[features] +utf8=off + +[ssl] +certificate=cert.pem +private_key=key.pem + +ciphers=ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH + +ssl_v2=no +ssl_v3=no +tls_v1=no +tls_v1_1=no +tls_v1_2=yes +tls_v1_3=yes +compression=yes +prefer_server_ciphers=yes + +[logging] +## Log messages to file. Logging has multiple options pertaining to each +## loglevel. All values are paths. +## Syntax: [console|critical|error|warning|info|debug|all]= +info=digftp.info.log +error=digftp.error.log +debug=digftp.debug.log + [passdb] -file=conf/passdb -user_root=/home/%u/ -case_sensitive=1 \ No newline at end of file +## The file for the passdb engine. +## Syntax: passdb= +file=passdb + +## Root of logged in user. Can use string formatting. +## Syntax: home_path= +home_path=/home/%u/ \ No newline at end of file diff --git a/conf/passdb b/conf/passdb deleted file mode 100644 index bd9aed6..0000000 --- a/conf/passdb +++ /dev/null @@ -1 +0,0 @@ -anonymous:anonymous@ \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2a42749..5937b98 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,4 +1,14 @@ -add_executable( - ${PROJECT_NAME} - main.cpp +# Main executable sources +set(SOURCES main.cpp) + +# Add main executable +add_executable(${PROJECT_NAME} ${SOURCES}) + +target_link_libraries(${PROJECT_NAME} + PRIVATE + Threads::Threads + ${OPENSSL_LIBRARIES} + ${CMAKE_DL_LIBS} ) + +add_subdirectory(plugins) \ No newline at end of file diff --git a/src/auth.h b/src/auth.h index c2f7485..c48af90 100644 --- a/src/auth.h +++ b/src/auth.h @@ -1,27 +1,72 @@ -#ifndef AUTHMETHODS_H -#define AUTHMETHODS_H +#ifndef AUTH_H +#define AUTH_H +#include +#include #include +#include +#include + +#include "globals.h" +#include "logger.h" +#include "util.h" + +typedef struct { + char username[255] = {0}; + char password[255] = {0}; + // Hold enough for IPv4 and IPv6 + char address[63] = {0}; + char hostname[255] = {0}; + char home_dir[PATH_MAX] = {0}; +} ClientAuthDetails; class Auth { public: - Auth() {}; + std::string user_directory = "/home/%u/"; - virtual void setOptions(ConfigSection* options) {}; + virtual ~Auth() {}; - virtual bool isChroot() { return false; }; + virtual bool initialize(const std::map& config) { return true; } - virtual bool check(std::string username, std::string password) { return false; }; + virtual bool isChroot() { return this->chroot; } - virtual std::string getUserDirectory(std::string username) { return ""; }; + virtual void setChroot(bool enable) { this->chroot = enable; } + + virtual bool isPasswordRequired() { return this->require_password; } + + virtual void setPasswordRequired(bool require) { this->require_password = require; } + + virtual bool authenticate(ClientAuthDetails* details) { return false; } + + virtual void setUserDirectory(ClientAuthDetails* details) { + std::string userdir = this->user_directory; + // Replace escaped var symbol with a dummy character + userdir = replace(userdir, std::string(CNF_PERCENTSYM_VAR), std::string(1, 0xFE)); + userdir = replace(userdir, std::string(CNF_USERNAME_VAR), std::string(details->username)); + userdir = replace(userdir, std::string(CNF_PASSWORD_VAR), std::string(details->password)); + userdir = replace(userdir, std::string(CNF_HOSTNAME_VAR), std::string(details->hostname)); + // Replace dummy character and return + userdir = replace(userdir, std::string(1, 0xFE), "%"); + + // Store the computed path in the details structure + strncpy(details->home_dir, userdir.c_str(), sizeof(details->home_dir) - 1); + details->home_dir[sizeof(details->home_dir) - 1] = '\0'; + } + + virtual uint64_t getTotalSpace() { return this->max_filespace; } + + virtual uint64_t getFreeSpace() { return this->getTotalSpace(); } + + virtual void setTotalSpace(uint64_t space) { this->max_filespace = space; } + + static void setLogger(Logger* log) { logger = log; } +private: + bool require_password = false; + uint64_t max_filespace = -1; + bool chroot = false; +protected: + static Logger* logger; }; -#include "auth/noauth.h" -#include "auth/passdb.h" - -Auth* getAuthByName(std::string name) { - if (name == "noauth") return new NoAuth(); - if (name == "passdb") return new PassdbAuth(); - return new Auth(); -} +Logger* Auth::logger = nullptr; #endif \ No newline at end of file diff --git a/src/auth/noauth.h b/src/auth/noauth.h deleted file mode 100644 index 3db77ca..0000000 --- a/src/auth/noauth.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _AUTH_NOAUTH -#define _AUTH_NOAUTH -#include "../util.h" - -class NoAuth : public Auth { -public: - NoAuth() {}; - - void setOptions(ConfigSection* options) { - this->options = options; - }; - - ConfigSection* getOptions() { - return options; - } - - bool isChroot() { - return options->getInt("chroot", 1) == 1; - }; - - bool check(std::string username, std::string password) { return true; }; - - std::string getUserDirectory(std::string username) { - std::string authDir = options->getValue("user_root", "/home/%u/"); - authDir = replace(authDir, "%u", username); - return authDir; - }; -private: - ConfigSection* options; -}; - -#endif \ No newline at end of file diff --git a/src/auth/passdb.h b/src/auth/passdb.h deleted file mode 100644 index 464b035..0000000 --- a/src/auth/passdb.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _AUTH_PASSDB -#define _AUTH_PASSDB -#include "../util.h" - -class PassdbAuth : public Auth { -public: - PassdbAuth() {}; - - void setOptions(ConfigSection* options) { - std::ifstream file(options->getValue("file", "passdb")); - std::string line; - while (std::getline(file, line)) { - if (line.length() == 0) continue; - else if (line.rfind("#", 0) == 0) continue; - else { - std::string username = line.substr(0, line.find(":")); - if (options->getInt("case_sensitive", 1) != 0) - username = toLower(username); - userlist[username] = line.substr(line.find(":")+1, line.length()); - } - } - userRoot = options->getValue("user_root", "/home/%u/"); - }; - - bool isChroot() { return true; }; - - bool check(std::string username, std::string password) { - if (userlist.find(username) == userlist.end()) return false; - if (userlist[username] != password) return false; - return true; - }; - - std::string getUserDirectory(std::string username) { - std::string authDir = userRoot; - authDir = replace(authDir, "%u", username); - return authDir; - }; -private: - std::map userlist; - std::string userRoot; -}; - -#endif \ No newline at end of file diff --git a/src/auth_manager.cpp b/src/auth_manager.cpp new file mode 100644 index 0000000..b2852ca --- /dev/null +++ b/src/auth_manager.cpp @@ -0,0 +1,124 @@ +#ifndef AUTH_MANAGER_H +#define AUTH_MANAGER_H + +#include +#include +#include +#include +#include "main.h" +#include "auth_plugin.h" + +class AuthManager { +public: + static AuthManager& getInstance() { + static AuthManager instance; + return instance; + } + + void setLogger(Logger* log) { + logger = log; + } + + Auth* createAuth(const std::string& type) { + auto it = plugins.find(type); + if (it != plugins.end()) { + return it->second.create(); + } + if (logger) logger->print(LOGLEVEL_ERROR, "Auth plugin type '%s' not found", type.c_str()); + return nullptr; + } + + bool loadPlugin(const std::string& path) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Loading auth plugin: %s", path.c_str()); + + void* handle = dlopen(path.c_str(), RTLD_LAZY); + if (!handle) { + if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load auth plugin %s: %s", + path.c_str(), dlerror()); + return false; + } + + // Load plugin functions + CreateAuthFunc create = (CreateAuthFunc)dlsym(handle, "createAuthPlugin"); + DestroyAuthFunc destroy = (DestroyAuthFunc)dlsym(handle, "destroyAuthPlugin"); + GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion"); + SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger"); + + const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName"); + const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription"); + const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion"); + + // Check each required function and log specific missing functions + if (!create || !destroy || !get_api_version || !setLogger || !get_name || !get_desc || !get_ver) { + if (logger) { + const char* missing = !create ? "createAuthPlugin" : + !destroy ? "destroyAuthPlugin" : + !get_api_version ? "getAPIVersion" : + !setLogger ? "setLogger" : + !get_name ? "getPluginName" : + !get_desc ? "getPluginDescription" : + "getPluginVersion"; + logger->print(LOGLEVEL_ERROR, "Invalid auth plugin %s: missing required function '%s'", + path.c_str(), missing); + } + dlclose(handle); + return false; + } + + // Check API version + int api_version = get_api_version(); + if (api_version != AUTH_PLUGIN_API_VERSION) { + if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible auth plugin API version in %s (got %d, expected %d)", + path.c_str(), api_version, AUTH_PLUGIN_API_VERSION); + dlclose(handle); + return false; + } + + std::string plugin_name = get_name(); + setLogger(logger); + + // Check if plugin is already loaded + if (plugins.find(plugin_name) != plugins.end()) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str()); + dlclose(handle); + return true; + } + + // Store plugin info + AuthPluginInfo plugin = { + plugin_name, + get_desc(), + get_ver(), + create, + destroy, + get_api_version, + setLogger, + handle + }; + + plugins[plugin.name] = plugin; + if (logger) logger->print(LOGLEVEL_INFO, "Loaded auth plugin: %s v%s", + plugin.name.c_str(), plugin.version.c_str()); + return true; + } + + void unloadPlugins() { + if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all auth plugins"); + for (auto& pair : plugins) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str()); + dlclose(pair.second.handle); + } + plugins.clear(); + } + + ~AuthManager() { + unloadPlugins(); + } + +private: + AuthManager() : logger(nullptr) {} // Singleton + std::map plugins; + Logger* logger; +}; + +#endif \ No newline at end of file diff --git a/src/auth_plugin.h b/src/auth_plugin.h new file mode 100644 index 0000000..0e475cf --- /dev/null +++ b/src/auth_plugin.h @@ -0,0 +1,37 @@ +#ifndef AUTH_PLUGIN_H +#define AUTH_PLUGIN_H +#include +#include "auth.h" + +#define AUTH_PLUGIN_API_VERSION 1 + +typedef Auth* (*CreateAuthFunc)(); +typedef void (*DestroyAuthFunc)(Auth*); +typedef int (*GetAPIVersionFunc)(); +typedef void (*SetLoggerFunc)(Logger*); + +struct AuthPluginInfo { + std::string name; + std::string description; + std::string version; + CreateAuthFunc create; + DestroyAuthFunc destroy; + GetAPIVersionFunc get_api_version; + SetLoggerFunc setLogger; + void* handle; +}; + +#define AUTH_PLUGIN_EXPORT extern "C" + +#define IMPLEMENT_AUTH_PLUGIN(classname, name, description, version) \ + extern "C" { \ + Auth* createAuthPlugin() { return new classname(); } \ + void destroyAuthPlugin(Auth* plugin) { delete plugin; } \ + const char* getPluginName() { return name; } \ + const char* getPluginDescription() { return description; } \ + const char* getPluginVersion() { return version; } \ + int getAPIVersion() { return AUTH_PLUGIN_API_VERSION; } \ + void setLogger(Logger* log) { Auth::setLogger(log); } \ + } + +#endif \ No newline at end of file diff --git a/src/build.h.in b/src/build.h.in new file mode 100644 index 0000000..137bd82 --- /dev/null +++ b/src/build.h.in @@ -0,0 +1,5 @@ +#ifndef BUILD_H +#define BUILD_H +#define APPNAME "@APPNAME@" +#define VERSION "@VERSION@" +#endif \ No newline at end of file diff --git a/src/client.cpp b/src/client.cpp index 3c61113..91f9f6f 100644 --- a/src/client.cpp +++ b/src/client.cpp @@ -8,67 +8,167 @@ #include #include #include -#include #include "main.h" -#include "server.h" -#include "filer.cpp" +#include "filer.h" #include "util.h" -namespace fs = std::filesystem; +#define FTP_STATE_CLOSE 0 +#define FTP_STATE_GUEST 1 +#define FTP_STATE_AUTHED 2 +#define FTP_STATE_ONDATA 3 class Client { public: int control_sock; - /* == State - * -1 - Delete - * 0 - Visitor / Unauthenticated - * 1 - Authenticated - * 2 - Waiting for data - */ - int state = 0; - std::thread thread; + + int getState() const { return state; } + bool isSecure() const { return is_secure; } + SSL* getSSL() const { return control_ssl; } + bool isHandshakeComplete() const { return ssl_handshake_complete; } + void setHandshakeComplete(bool complete) { ssl_handshake_complete = complete; } - Client(int &_sock) { - control_sock = _sock; - submit(230, config->getValue("main", "server_name", "digFTP Server")); - }; + Client(int _sock) : control_sock(_sock), control_ssl(nullptr), data_ssl(nullptr), + is_secure(false), ssl_handshake_complete(false), + cached_session(nullptr) { + if (!default_filer_factory) { + logger->print(LOGLEVEL_ERROR, "No filer factory available"); + return; + } + filer = default_filer_factory(); + submit(230, replace(server_name, CNF_VERSION_VAR, VERSION)); + } + + void addOption(std::string name, bool toggle) { + this->options[name] = toggle; + } int receive(std::string cmd, std::string argstr) { - if (cmd == "QUIT") { + if (control_sock <= 0) return 1; + + if (is_secure && control_ssl && !ssl_handshake_complete) { + int rc = SSL_accept(control_ssl); + if (rc <= 0) { + int err = SSL_get_error(control_ssl, rc); + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { + // Need more data, return without error + return 0; + } + // Fatal error + logger->print(LOGLEVEL_ERROR, "C(%i) SSL handshake failed with error: %d", control_sock, err); + ERR_print_errors_fp(stderr); + SSL_free(control_ssl); + control_ssl = nullptr; + is_secure = false; + submit(431, "SSL handshake failed"); + return 1; + } + // Handshake completed successfully + ssl_handshake_complete = true; + logger->print(LOGLEVEL_DEBUG, "C(%i) SSL handshake completed", control_sock); + return 0; + } + + if (cmd == "AUTH") { + if (argstr == "TLS" || argstr == "SSL") { + if (setupControlSSL()) { + submit(234, "AUTH TLS successful"); + is_secure = true; + return 0; + } else { + submit(431, "AUTH TLS failed"); + } + } else { + submit(504, "AUTH type not supported"); + } + return 0; + } else if (cmd == "PBSZ") { + if (!is_secure) { + submit(503, "PBSZ not allowed on insecure control connection"); + } else { + submit(200, "PBSZ=0"); + } + return 0; + } else if (cmd == "PROT") { + if (!is_secure) { + submit(503, "PROT not allowed on insecure control connection"); + } else if (argstr == "P") { + protect_data = true; + submit(200, "Protection level set to Private"); + } else if (argstr == "C") { + protect_data = false; + submit(200, "Protection level set to Clear"); + } else { + submit(504, "PROT level not supported"); + } + return 0; + } else if (cmd == "QUIT") { + state = FTP_STATE_CLOSE; submit(250, "Goodbye!"); - return -1; - } else if (state > 0) { + shutdown(control_sock, SHUT_RDWR); // Add immediate shutdown + return 1; // Signal thread to terminate + } else if (state >= FTP_STATE_AUTHED) { if (cmd == "SYST") { submit(215, "UNIX Type: L8"); - } else if (cmd == "PWD") { - submit(257, "'"+filer->cwd.string()+"'"); - } else if (cmd == "CWD") { - int rc = filer->traverse(argstr); - if (rc == 0) submit(250, "OK"); - else submit(431, "No such directory"); - } else if (cmd == "CDUP") { - int rc = filer->traverse(".."); - if (rc == 0) submit(250, "OK"); - else submit(431, "No such directory"); - } else if (cmd == "MKD") { - int rc = filer->createDirectory(argstr); - if (rc == 0) submit(257, "\""+filer->relPath(argstr).string()+"\" directory created"); - else if (rc == -1) submit(521, "Directory already exists"); - else submit(550, "Access Denied"); } else if (cmd == "TYPE") { - sscanf(argstr.c_str(), "%c", &(filer->type)); + char type = 0; + sscanf(argstr.c_str(), "%c", &(type)); + filer->setTransferMode(type); submit(226, "OK"); + } else if (cmd == "OPTS") { + std::vector args = parseArgs(argstr); + if (args.size() > 0) { + std::string name = toUpper(args[0]); + int rc = false; + + if (args.size() > 1) rc = toggleOption(name, toUpper(args[1])=="ON"); + else rc = toggleOption(name, true); + + if (rc == 0) submit(200, "OK"); + else if (rc == 1) submit(451, "Option not enabled or not recognized"); + else submit(550, "Unknown error"); + } else submit(550, "Malformed Request"); + } else if (cmd == "PWD") { + submit(257, "'"+filer->getCWD().string()+"'"); + } else if (cmd == "CWD") { + struct file_data fd = filer->traverse(argstr); + if (fd.error.code == 0) submit(250, "OK"); + else submit(431, std::string(fd.error.msg)); + } else if (cmd == "CDUP") { + struct file_data fd = filer->traverse(".."); + if (fd.error.code == 0) submit(250, "OK"); + else submit(431, std::string(fd.error.msg)); + } else if (cmd == "MKD") { + struct file_data fd = filer->createDirectory(argstr); + if (fd.error.code == 0) submit(257, "\""+std::string(fd.path)+"\" directory created"); + else if (fd.error.code == FilerStatusCodes::FileExists) submit(521, std::string(fd.error.msg)); + else submit(550, std::string(fd.error.msg)); + } else if (cmd == "SIZE") { + struct file_data fd = filer->fileSize(argstr); + if (fd.error.code == 0) submit(213, std::to_string(fd.size)); + else submit(550, std::string(fd.error.msg)); + } else if (cmd == "FEAT") { + submit("211-Extensions supported:"); + for (const auto &opt : options) { + submit(" "+toUpper(opt.first)); + } + submit(211, "END"); + } else if (cmd == "NOOP") { + submit(226, "OK"); + } else if (cmd == "DELE" || cmd == "RMD") { + struct file_data fd = filer->deleteFile(argstr); + if (fd.error.code == 0) submit(250, "OK"); + else submit(550, std::string(fd.error.msg)); } else if (cmd == "PASV") { if ((data_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) { perror("pasv socket() failed"); - return -1; + return 1; } if (setsockopt(data_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, (char *)&opt, sizeof(opt)) < 0) { perror("pasv setsockopt() failed"); close(data_fd); - return -1; + return 1; } uint16_t dataport; @@ -83,189 +183,474 @@ public: if (bind(data_fd, (struct sockaddr *)&data_address, sizeof(data_address)) < 0) { perror("pasv bind() failed"); close(data_fd); - return -1; + return 1; } if (listen(data_fd, 3) < 0) { perror("pasv listen() failed"); close(data_fd); - return -1; + return 1; } if (getsockname(data_fd, (struct sockaddr *)&data_address, &datalen) == 0) { - sscanf(config->getValue("net", "listen_address", "127.0.0.1").c_str(), "%d.%d.%d.%d", &netaddr[0], &netaddr[1], &netaddr[2], &netaddr[3]); dataport = ntohs(data_address.sin_port); memcpy(&netport[0], &dataport, 2); - logger->print(LOGLEVEL_DEBUG, "D(%i) PASV initialized: %u.%u.%u.%u:%u", data_fd, netaddr[0], netaddr[1], netaddr[2], netaddr[3], dataport); + logger->print(LOGLEVEL_DEBUG, "D(%i) PASV initialized: %u.%u.%u.%u:%u", data_fd, server_address[0], server_address[1], server_address[2], server_address[3], dataport); } else { perror("pasv getpeername() failed"); close(data_fd); - return -1; + return 1; } char* pasvok; asprintf( &pasvok, "Entering Passive Mode (%u,%u,%u,%u,%u,%u).", - netaddr[0], - netaddr[1], - netaddr[2], - netaddr[3], + server_address[0], + server_address[1], + server_address[2], + server_address[3], netport[1], netport[0] ); submit(227, std::string(pasvok)); free(pasvok); - if ((data_sock = accept(data_fd, NULL, NULL)) >= 0) { logger->print(LOGLEVEL_INFO, "D(%i) PASV accepted: %i", data_fd, data_sock); - state = 2; + + if (is_secure && protect_data) { + if (!setupDataSSL()) { + submit(425, "Can't setup secure data connection"); + data_close(); + return 0; + } + } + + state = FTP_STATE_ONDATA; } else { - perror("accept() failed"); - submit(425, "Unknown Error"); + logger->print(LOGLEVEL_ERROR, "D(%i) PASV accept failed: %s", data_fd, strerror(errno)); + submit(425, "Can't open data connection"); data_close(); } - } else if (cmd == "SIZE") { - int ret = filer->fileSize(argstr); - if (ret >= 0) submit(213, std::to_string(ret)); - else submit(550, "Access Denied "+std::to_string(ret)); - } else if (cmd == "FEAT") { - submit("211-Extensions supported:"); - submit(" UTF8"); - submit(" SIZE"); - submit(211, "END"); - } else if (cmd == "NOOP") { - submit(226, "OK"); - } else if (cmd == "DELE" || cmd == "RMD") { - int ret = filer->deleteFile(argstr); - if (ret == 0) { - submit(250, "OK"); - } else { - submit(550, "Access Denied "+std::to_string(ret)); + } else if (cmd == "RNFR") { + if (argstr.empty()) { + submit(501, "Syntax error in parameters or arguments."); + return 0; } - } else if (state == 2) { - if (data_sock <= 0) return -1; + + // Check if source file exists + struct file_data fd = filer->fileSize(argstr); + if (fd.error.code != 0) { + submit(550, std::string(fd.error.msg)); + rename_pending = false; + return 0; + } + + // Store the source path and mark rename as pending + rename_from = argstr; + rename_pending = true; + submit(350, "Ready for RNTO."); + } else if (cmd == "RNTO") { + if (!rename_pending) { + submit(503, "RNFR required first."); + return 0; + } + + if (argstr.empty()) { + submit(501, "Syntax error in parameters or arguments."); + rename_pending = false; + return 0; + } + + // Add rename functionality to Filer class + struct file_data fd = filer->renameFile(rename_from, argstr); + rename_pending = false; // Reset rename state + + if (fd.error.code == 0) { + submit(250, "Rename successful."); + } else { + submit(550, std::string(fd.error.msg)); + } + } else if (state == FTP_STATE_ONDATA) { + if (data_sock <= 0) return 1; if (cmd == "LIST" || cmd == "NLST") { submit(150, "Transferring"); std::string dirname = ""; if (argstr.find_first_of('/') != std::string::npos) dirname = argstr.substr(argstr.find_first_of('/')); - std::string output = filer->list(dirname); - char out[output.size()] = {0}; - strncpy(out, output.c_str(), output.size()); - data_submit(out, output.size()); + struct file_data fd = filer->list(dirname); + data_submit(fd.bin, fd.size); submit(226, "OK"); } else if (cmd == "RETR") { struct file_data fd = filer->readFile(argstr); - if (fd.ecode == 0) { + if (fd.error.code == 0 && fd.stream && fd.stream->is_open()) { submit(150, "Transferring"); - data_submit(fd.data, fd.size); - free(fd.data); - submit(226, "OK"); + + char buffer[8192]; + bool transfer_ok = true; + + while (!fd.stream->eof()) { + fd.stream->read(buffer, sizeof(buffer)); + size_t bytes_read = fd.stream->gcount(); + + if (bytes_read > 0) { + if (data_submit(buffer, bytes_read) != 0) { + submit(426, "Transfer failed"); + transfer_ok = false; + break; + } + } + + if (fd.stream->fail() && !fd.stream->eof()) { + submit(426, "Read error"); + transfer_ok = false; + break; + } + } + + if (transfer_ok) { + submit(226, "OK"); + } } else { - if (fd.ecode == -3) submit(550, "I refuse to transmit in ASCII mode!"); - else submit(550, "Access Denied"); + submit(550, std::string(fd.error.msg)); } + data_close(); } else if (cmd == "STOR") { unsigned char inbuf[BUFFERSIZE] = {0}; submit(150, "Transferring"); int psize; - // Write nothing to file to make sure it exists. - // Also so I don't have to write a more complex writer. - int ret = filer->writeFile(argstr, inbuf, 0); - while ((psize = recv(data_sock, inbuf, sizeof(inbuf), 0)) > 0) { - ret = filer->appendFile(argstr, inbuf, psize); - if (ret < 0) { - submit(550, "Access Denied "+std::to_string(ret)); - data_close(); - return -1; + struct file_data fd = filer->writeFile(argstr, inbuf, 0); + + while (true) { + if (is_secure && protect_data && data_ssl) { + psize = SSL_read(data_ssl, inbuf, sizeof(inbuf)); + if (psize <= 0) { + int err = SSL_get_error(data_ssl, psize); + if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { + continue; + } + break; + } + } else { + psize = recv(data_sock, inbuf, sizeof(inbuf), 0); + if (psize <= 0) break; } - for (int i = 0; i < BUFFERSIZE; i++) inbuf[i] = 0; + + fd = filer->writeFile(argstr, inbuf, psize, true); + if (fd.error.code != 0) { + submit(550, "Access Denied: "+std::string(fd.error.msg)); + data_close(); + break; + } + memset(inbuf, 0, BUFFERSIZE); } submit(226, "OK"); + } else { + submit(502, "Command not implemented"); } data_close(); } else { submit(502, "Command not implemented"); } } else { - state = 0; // If we reach here, force state to zero just in case. + state = FTP_STATE_GUEST; // If we reach here, force state to zero just in case. if (cmd == "USER") { - name = argstr; - submit(331, "Password required"); + if (argstr.length() >= sizeof(auth_data->username)) { + throw std::runtime_error("Username too long"); + } + std::strncpy(auth_data->username, argstr.c_str(), sizeof(auth_data->username) - 1); + auth_data->username[sizeof(auth_data->username) - 1] = '\0'; + if (auth->isPasswordRequired()) + submit(331, "Password required"); + else authedInit(); } else if (cmd == "PASS") { - if (auth->check(name, argstr)) { - logger->print(LOGLEVEL_INFO, "(%i) logged in as '%s'", control_sock, name.c_str()); - // We can now set the root safely (I hope). - filer->setRoot(auth->getUserDirectory(name)); - state = 1; - submit(230, "Login OK"); - } else { - name = ""; + std::strncpy(auth_data->password, argstr.c_str(), sizeof(auth_data->password) - 1); + auth_data->password[sizeof(auth_data->password) - 1] = '\0'; + if (auth->authenticate(auth_data)) authedInit(); + else { + auth_data->username[0] = {}; submit(530, "Invalid Credentials."); } - } else if (cmd == "AUTH") { - submit(502, "Command not implemented"); } else { submit(332, "Not Logged In!"); - return -1; + return 1; } } return 0; } + int toggleOption(std::string name, bool toggle) { + auto feat_it = options.find(name); + if (feat_it != options.end()) { + feat_it->second = toggle; + return 0; + } + return 1; + } + + bool getOption(std::string name) { + auto feat_it = options.find(name); + if (feat_it != options.end()) { + return feat_it->second; + } + return false; + } + int submit(std::string msg) { - std::string out = msg+"\r\n"; - int bytes = send(control_sock, out.c_str(), out.size(), 0); - if (bytes < 0) { - logger->print(LOGLEVEL_ERROR, "C(%i) !< %s", control_sock, msg.c_str()); - return 1; + std::string out = msg + "\r\n"; + if (fcntl(control_sock, F_GETFD) < 0) return 1; + + int result; + if (is_secure && control_ssl) { + result = SSL_write(control_ssl, out.c_str(), out.size()); + if (result <= 0) { + int err = SSL_get_error(control_ssl, result); + if (err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_WANT_READ) { + // The socket is not ready, main loop will handle it + return 0; + } + logger->print(LOGLEVEL_ERROR, "C(%i) SSL_write failed with error: %d", control_sock, err); + return 1; + } + } else { + result = send(control_sock, out.c_str(), out.size(), MSG_NOSIGNAL); } - logger->print(LOGLEVEL_DEBUG, "C(%i) << %s", control_sock, msg.c_str()); - return 0; + + if (result >= 0) { + logger->print(LOGLEVEL_DEBUG, "C(%i) << %s", control_sock, msg.c_str()); + return 0; + } + logger->print(LOGLEVEL_ERROR, "C(%i) !< %s", control_sock, msg.c_str()); + return 1; } - int submit(int code, std::string msg) { - std::string out = std::to_string(code)+" "+msg+"\r\n"; - int bytes = send(control_sock, out.c_str(), out.size(), 0); - if (bytes < 0) { - logger->print(LOGLEVEL_ERROR, "C(%i) !< %i %s", control_sock, code, msg.c_str()); - return 1; + void submit(int code, const std::string& msg) { + std::string response = std::to_string(code) + " " + msg + "\r\n"; + + if (is_secure && ssl_handshake_complete && control_ssl) { + // Send through SSL if secure connection is established + int written = SSL_write(control_ssl, response.c_str(), response.length()); + if (written <= 0) { + logger->print(LOGLEVEL_ERROR, "C(%i) SSL_write failed", control_sock); + return; + } + } else { + // Regular send for non-secure or pre-handshake messages + send(control_sock, response.c_str(), response.length(), 0); } - logger->print(LOGLEVEL_DEBUG, "C(%i) << %i %s", control_sock, code, msg.c_str()); - return 0; + + logger->print(LOGLEVEL_DEBUG, "C(%i) << %d %s", control_sock, code, msg.c_str()); } - int data_submit(char* out, int size) { - int bytes = send(data_sock, out, size, 0); - if (bytes < 0) { - logger->print(LOGLEVEL_DEBUG, "D(%i) !< %i", data_sock, size); - return 1; + int data_submit(char* out, size_t size) { + size_t total_sent = 0; + while (total_sent < size) { + int bytes; + if (is_secure && protect_data && data_ssl) { + bytes = SSL_write(data_ssl, out + total_sent, size - total_sent); + if (bytes <= 0) { + int ssl_err = SSL_get_error(data_ssl, bytes); + if (ssl_err == SSL_ERROR_WANT_WRITE || ssl_err == SSL_ERROR_WANT_READ) { + continue; // Need to retry + } + logger->print(LOGLEVEL_DEBUG, "D(%i) SSL write error: %d", data_sock, ssl_err); + return 1; + } + } else { + bytes = send(data_sock, out + total_sent, size - total_sent, 0); + if (bytes < 0) { + if (errno == EINTR) continue; + logger->print(LOGLEVEL_DEBUG, "D(%i) !< Error: %s", data_sock, strerror(errno)); + return 1; + } + } + total_sent += bytes; } - logger->print(LOGLEVEL_DEBUG, "D(%i) << %i", data_sock, size); return 0; } int data_close() { if (data_sock <= 0) return 0; logger->print(LOGLEVEL_DEBUG, "D(%i) Closing...", data_sock); + + if (data_ssl) { + SSL_shutdown(data_ssl); + SSL_free(data_ssl); + data_ssl = nullptr; + } + + shutdown(data_sock, SHUT_RDWR); close(data_sock); data_sock = -1; - close(data_fd); - data_fd = -1; - state = 1; + + if (data_fd >= 0) { + close(data_fd); + data_fd = -1; + } + + state = FTP_STATE_AUTHED; return 0; } + ~Client() { + if (cached_session) { + SSL_SESSION_free(cached_session); + cached_session = nullptr; + } + if (filer) delete filer; + if (control_ssl) { + SSL_shutdown(control_ssl); + SSL_free(control_ssl); + control_ssl = nullptr; + } + if (data_ssl) { + SSL_shutdown(data_ssl); + SSL_free(data_ssl); + data_ssl = nullptr; + } + } + private: - std::string name; - Filer* filer = new Filer(); + std::map options; + ClientAuthDetails* auth_data = new ClientAuthDetails{}; + Filer* filer = {}; const int opt = 1; + uint8_t flags = {}; + int state = FTP_STATE_GUEST; int data_fd; int data_sock; struct sockaddr_in data_address; + + std::string rename_from; + bool rename_pending = false; + + void authedInit() { + logger->print(LOGLEVEL_INFO, "C(%i) logging in as '%s'", control_sock, auth_data->username); + struct file_data fd; + if (auth->isChroot()) { + if ( + ((struct file_data)filer->setRoot(std::string(this->auth_data->home_dir))).error.code == 0 && + ((struct file_data)filer->setCWD("/")).error.code == 0 + ) { + state = FTP_STATE_AUTHED; + submit(230, "Login OK"); + logger->print(LOGLEVEL_INFO, "C(%i) Set chrooted root of '%s' to '%s'", control_sock, auth_data->username, filer->getRoot().c_str()); + return; + } + } else { + if ( + ((struct file_data)filer->setRoot("/")).error.code == 0 && + ((struct file_data)filer->setCWD(std::string(this->auth_data->home_dir))).error.code == 0 + ) { + state = FTP_STATE_AUTHED; + submit(230, "Login OK"); + logger->print(LOGLEVEL_INFO, "C(%i) Set home of '%s' to '%s'", control_sock, auth_data->username, filer->getCWD().c_str()); + return; + } + } + submit(530, "An error occured setting root and/or cwd"); + return; + } + + SSL* control_ssl; + SSL* data_ssl; + SSL_SESSION* cached_session = nullptr; + bool is_secure; + bool protect_data; + bool ssl_handshake_complete; + + bool setupControlSSL() { + control_ssl = SSL_new(SSLManager::getInstance().getContext()); + if (!control_ssl) { + logger->print(LOGLEVEL_ERROR, "C(%i) SSL_new failed", control_sock); + return false; + } + + if (!SSL_set_fd(control_ssl, control_sock)) { + logger->print(LOGLEVEL_ERROR, "C(%i) SSL_set_fd failed", control_sock); + SSL_free(control_ssl); + control_ssl = nullptr; + return false; + } + + ssl_handshake_complete = false; + logger->print(LOGLEVEL_DEBUG, "C(%i) SSL setup complete, handshake pending", control_sock); + + // Cache the session after successful handshake + if (cached_session) { + SSL_SESSION_free(cached_session); + cached_session = nullptr; + } + cached_session = SSL_get1_session(control_ssl); // Increment reference count + + return true; + } + + bool setupDataSSL() { + if (!is_secure || !protect_data) return true; + + data_ssl = SSL_new(SSLManager::getInstance().getContext()); + if (!data_ssl) { + logger->print(LOGLEVEL_ERROR, "C(%i) Data SSL_new failed", control_sock); + return false; + } + + // Set the cached session for reuse + if (cached_session) { + SSL_set_session(data_ssl, cached_session); + } + + SSL_set_accept_state(data_ssl); + if (!SSL_set_fd(data_ssl, data_sock)) { + logger->print(LOGLEVEL_ERROR, "C(%i) Data SSL_set_fd failed", control_sock); + SSL_free(data_ssl); + data_ssl = nullptr; + return false; + } + + // Set non-blocking mode for SSL handshake + int flags = fcntl(data_sock, F_GETFL, 0); + fcntl(data_sock, F_SETFL, flags | O_NONBLOCK); + + int ret; + while ((ret = SSL_accept(data_ssl)) <= 0) { + int ssl_err = SSL_get_error(data_ssl, ret); + if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { + // Need more data, wait a bit and retry + struct pollfd pfd; + pfd.fd = data_sock; + pfd.events = (ssl_err == SSL_ERROR_WANT_READ) ? POLLIN : POLLOUT; + + if (poll(&pfd, 1, 1000) <= 0) { + logger->print(LOGLEVEL_ERROR, "C(%i) Data SSL handshake timeout", control_sock); + SSL_free(data_ssl); + data_ssl = nullptr; + return false; + } + continue; + } + + logger->print(LOGLEVEL_ERROR, "C(%i) Data SSL handshake failed: %s", + control_sock, ERR_error_string(ERR_get_error(), nullptr)); + SSL_free(data_ssl); + data_ssl = nullptr; + return false; + } + + // Reset blocking mode + fcntl(data_sock, F_SETFL, flags); + + // Log whether session was reused + if (SSL_session_reused(data_ssl)) { + logger->print(LOGLEVEL_DEBUG, "C(%i) Data SSL session resumed", control_sock); + } else { + logger->print(LOGLEVEL_DEBUG, "C(%i) Data SSL new session", control_sock); + } + + return true; + } }; diff --git a/src/conf.h b/src/conf.h index 2e32dff..6dce477 100644 --- a/src/conf.h +++ b/src/conf.h @@ -5,6 +5,9 @@ #include #include +#include "util.h" +#include "build.h" + class ConfigSection { public: ConfigSection() {}; @@ -34,6 +37,11 @@ public: else return def; } + bool getBool(std::string name, bool def) { + if (this->options.find(name) != this->options.end()) return this->options[name]=="on" || this->options[name] == "1" || this->options[name] == "true" || this->options[name] == "yes"; + else return def; + } + void setValue(std::string name, std::string value) { this->options[name] = value; } @@ -41,6 +49,10 @@ public: void setInt(std::string name, int value) { this->options[name] = std::to_string(value); } + + void setBool(std::string name, bool value) { + this->options[name] = value?"true":"false"; + } private: std::map options; }; @@ -50,7 +62,7 @@ public: ConfigFile() {}; ConfigFile(std::string path) { std::ifstream file(path); - ConfigSection* section; + ConfigSection* section = new ConfigSection(); std::string line; while (std::getline(file, line)) { if (line.length() == 0) continue; @@ -62,7 +74,6 @@ public: section->setValue(line.substr(0, line.find("=")), line.substr(line.find("=")+1, line.length())); } } - printf("Configuration loaded: \"%s\"\n", path.c_str()); } std::map get() { @@ -70,7 +81,9 @@ public: } ConfigSection* get(std::string section) { - return this->sections[section]; + if (this->sections.find(section) == this->sections.end()) + return new ConfigSection(); + else return this->sections[section]; } std::vector getKeyValues() { @@ -82,23 +95,44 @@ public: } std::string getValue(std::string section, std::string name, std::string def) { - if (this->sections[section]->get().find(name) != this->sections[section]->get().end()) return this->sections[section]->getValue(name, def); + ConfigSection* cs = this->get(section); + if (cs->get().find(name) != cs->get().end()) return cs->getValue(name, def); else return def; } int getInt(std::string section, std::string name, int def) { - if (this->sections[section]->get().find(name) != this->sections[section]->get().end()) return this->sections[section]->getInt(name, def); + ConfigSection* cs = this->get(section); + if (cs->get().find(name) != cs->get().end()) return cs->getInt(name, def); + else return def; + } + + bool getBool(std::string section, std::string name, bool def) { + ConfigSection* cs = this->get(section); + if (cs->get().find(name) != cs->get().end()) return cs->getBool(name, def); else return def; } void setValue(std::string section, std::string name, std::string value) { + if (this->sections.find(section) == this->sections.end()) + this->sections[section] = new ConfigSection(); + this->sections[section]->setValue(name, value); } void setInt(std::string section, std::string name, int value) { + if (this->sections.find(section) == this->sections.end()) + this->sections[section] = new ConfigSection(); + this->sections[section]->setInt(name, value); } + void setBool(std::string section, std::string name, bool value) { + if (this->sections.find(section) == this->sections.end()) + this->sections[section] = new ConfigSection(); + + this->sections[section]->setBool(name, value); + } + void write(std::string path) { std::ofstream outfile(path); for (std::map::iterator si = this->sections.begin(); si != this->sections.end(); ++si) { diff --git a/src/filer.cpp b/src/filer.cpp deleted file mode 100644 index 442044c..0000000 --- a/src/filer.cpp +++ /dev/null @@ -1,232 +0,0 @@ -#include -#include -#include -#include -#include -#include - -/* === THE FILER === - * This class and its structs handle file operations for - * specified users. Its goal is an easy, consistent, and - * modular class to simply call on from inside Client. - * - * A Filer object has its defined root that it should never - * break out of. If it does, best hope we aren't running as - * root. - */ - -/* == STATUS CODES == - * - * -3 - Uhhhhhh - * -2 - Invalid Permissions - * -1 - File Not Found - * 0 - Generic Success - */ - -struct file_data { - char* data; - int size; - int ecode; -}; - -namespace fs = std::filesystem; - -class Filer { -public: - fs::path cwd; - char type = 'A'; - - Filer() {}; - - int traverse(std::string dir) { - fs::path ndir = fs::weakly_canonical(cwd / dir); - fs::path fdir = fullPath(dir); - logger->print(LOGLEVEL_INFO, "Traversing: %s", ndir.c_str()); - if (fdir.string().rfind(root.string(), 0) != 0) return -2; - else if (!fs::exists(fdir) || !fs::is_directory(fdir)) return -1; - else { - cwd = ndir; - return 0; - } - } - - int setRoot(std::string _root) { - fs::path froot = fs::weakly_canonical(_root); - logger->print(LOGLEVEL_INFO, "Setting root: %s", froot.c_str()); - if (!fs::exists(froot)) fs::create_directory(froot); - root = fs::absolute(froot); - cwd = "/"; - return 0; - } - - int createDirectory(std::string dir) { - fs::path ndir = fs::weakly_canonical(cwd / dir); - fs::path fdir = fullPath(dir); - if (fdir.string().rfind(root.string(), 0) != 0) return -2; - else if (fs::exists(fdir)) return -1; - else { - fs::create_directory(fdir); - return 0; - } - } - - fs::path fullPath() { - return fs::weakly_canonical(root.string()+"/"+cwd.string()); - } - - fs::path fullPath(std::string in) { - return fs::weakly_canonical(root.string()+"/"+(cwd / in).string()); - } - - fs::path relPath() { - return fs::weakly_canonical("/"+cwd.string()); - } - - fs::path relPath(std::string in) { - return fs::weakly_canonical("/"+(cwd / in).string()); - } - - int fileSize(std::string name) { - fs::path nfile = fs::weakly_canonical(cwd / name); - fs::path ffile = fullPath(name); - logger->print(LOGLEVEL_INFO, "Retreiving filesize: %s", nfile.c_str()); - if (ffile.string().rfind(root.string(), 0) != 0) return -2; - else if (!fs::exists(ffile)) return -1; - else if (type == 'A') return -3; - else { - std::ifstream infile(ffile, std::ios::in|std::ios::binary|std::ios::ate); - if (infile.is_open()) { - return infile.tellg(); - } else return -2; - } - return 0; - } - - int deleteFile(std::string name) { - fs::path nfile = fs::weakly_canonical(cwd / name); - fs::path ffile = fullPath(name); - logger->print(LOGLEVEL_INFO, "Deleting file: %s", nfile.c_str()); - if (ffile.string().rfind(root.string(), 0) != 0) return -2; - else if (!fs::exists(ffile)) return -1; - else { - if (fs::is_directory(ffile) && !fs::is_empty(ffile)) return -5; - else return fs::remove(ffile)?0:-4; - } - } - - file_data readFile(std::string name) { - struct file_data fd; - - fs::path nfile = fs::weakly_canonical(cwd / name); - fs::path ffile = fullPath(name); - logger->print(LOGLEVEL_INFO, "Retreiving file: %s", nfile.c_str()); - if (ffile.string().rfind(root.string(), 0) != 0) fd.ecode = -2; - else if (!fs::exists(ffile)) fd.ecode = -1; - else if (type != 'A') { - std::ifstream infile(ffile, std::ios::in|std::ios::binary|std::ios::ate); - if (infile.is_open()) { - fd.size = infile.tellg(); - fd.data = new char[fd.size]; - infile.seekg(0, std::ios::beg); - infile.read(fd.data, fd.size); - infile.close(); - fd.ecode = 0; - } else fd.ecode = -2; - } else fd.ecode = -3; - return fd; - } - - // Yes, there are two separate functions for essentially the same thing - // but with a different flag. Yes, I could've easily combined the two. - // No, I don't wanna. - int writeFile(std::string name, unsigned char* data, int size) { - fs::path nfile = fs::weakly_canonical(cwd / name); - fs::path ffile = fullPath(name); - logger->print(LOGLEVEL_INFO, "Storing file: %s", nfile.c_str()); - if (ffile.string().rfind(root.string(), 0) != 0) return -2; - else { - std::ofstream outfile(ffile, std::ios::out|std::ios::binary); - outfile.write((char *)data, size); - outfile.close(); - return size; - } - return 0; - } - - int appendFile(std::string name, unsigned char* data, int size) { - fs::path nfile = fs::weakly_canonical(cwd / name); - fs::path ffile = fullPath(name); - if (ffile.string().rfind(root.string(), 0) != 0) return -2; - else { - std::ofstream outfile(ffile, std::ios::out|std::ios::binary|std::ios::app); - outfile.write((char *)data, size); - outfile.close(); - return size; - } - return 0; - } - - std::string list(std::string path) { - fs::path fpath = fullPath(path); - std::ostringstream listStream; - // Not checking for pwd existence. If it doesn't exist and we - // got this far, we fucked up anyway. - for(fs::directory_entry const& p : fs::directory_iterator(fpath)) { - char *line; - - // Stole part of this from here: - // https://github.com/Siim/ftp/blob/master/handles.c#L154 - // It's amazing how simple shit is missing from std::filesystem - // Thanks boost! - struct stat fstat; - struct tm *time; - time_t rawtime; - char timebuff[80]; - - if (stat(p.path().c_str(), &fstat) == -1) { - return ""; - } - - /* Convert time_t to tm struct */ - rawtime = fstat.st_mtime; - time = localtime(&rawtime); - strftime(timebuff, 80, "%b %d %H:%M", time); - - // God should've smitten me before I wrote such attrocities. - fs::perms fperms = fs::status(p).permissions(); - asprintf( - &line, - "%c%c%c%c%c%c%c%c%c%c %u %4u %4u %12u %s %s\r\n", - p.is_directory()?'d':'-', - (fperms & fs::perms::owner_read) != fs::perms::none?'r':'-', - (fperms & fs::perms::owner_write) != fs::perms::none?'w':'-', - (fperms & fs::perms::owner_exec) != fs::perms::none?'x':'-', - (fperms & fs::perms::group_read) != fs::perms::none?'r':'-', - (fperms & fs::perms::group_write) != fs::perms::none?'w':'-', - (fperms & fs::perms::group_exec) != fs::perms::none?'x':'-', - (fperms & fs::perms::others_read) != fs::perms::none?'r':'-', - (fperms & fs::perms::others_write) != fs::perms::none?'w':'-', - (fperms & fs::perms::others_exec) != fs::perms::none?'x':'-', - fs::hard_link_count(p), - fstat.st_uid, - fstat.st_gid, - fstat.st_size, - timebuff, - p.path().filename().c_str() - ); - - listStream << std::string(line); - free(line); - } - return listStream.str(); - } - - // Gets a list of files and folders within current working directory. - // Outputs in the format of 'ls -lA'. - std::string list() { - return list(""); - } - -private: - fs::path root; -}; \ No newline at end of file diff --git a/src/filer.h b/src/filer.h new file mode 100644 index 0000000..80e7d9a --- /dev/null +++ b/src/filer.h @@ -0,0 +1,71 @@ +#ifndef FILER_H +#define FILER_H +#include +#include "logger.h" + +struct file_error { + uint8_t code = 0; + const char* msg = {}; +}; + +struct file_data { + ~file_data() { + delete[] bin; + } + std::shared_ptr stream = nullptr; + const char* path = nullptr; + char* bin = nullptr; + size_t size = 0; + file_error error = {}; +}; + +enum FilerStatusCodes : uint8_t { + Success = 0, + AccessDenied = 249, + FileExists = 250, + DirectoryNotEmpty = 251, + InvalidTransferMode = 252, + NotFound = 253, + NoPermission = 254, + Exception = 255 +}; + +class Filer { +public: + virtual ~Filer() {}; + + static void setLogger(Logger* log) { + logger = log; + } + + // Core operations + virtual file_data setRoot(std::string _root) = 0; + virtual file_data setCWD(std::string _cwd) = 0; + virtual std::filesystem::path getRoot() = 0; + virtual std::filesystem::path getCWD() = 0; + virtual std::filesystem::path resolvePath(const std::string& path) = 0; + virtual void setTransferMode(char type) { + this->transfer_mode = type; + } + + virtual char getTransferMode() { + return this->transfer_mode; + } + + // File operations + virtual file_data traverse(std::string dir) = 0; + virtual file_data createDirectory(std::string dir) = 0; + virtual file_data fileSize(std::string name) = 0; + virtual file_data deleteFile(std::string name) = 0; + virtual file_data readFile(std::string name) = 0; + virtual file_data writeFile(std::string name, unsigned char* data, int size, bool append = false) = 0; + virtual file_data renameFile(const std::string& from, const std::string& to) = 0; + virtual file_data list(std::string path = ".") = 0; +protected: + static Logger* logger; + char transfer_mode = 'I'; +}; + +Logger* Filer::logger = nullptr; + +#endif \ No newline at end of file diff --git a/src/filer_manager.cpp b/src/filer_manager.cpp new file mode 100644 index 0000000..9d51444 --- /dev/null +++ b/src/filer_manager.cpp @@ -0,0 +1,130 @@ +#ifndef FILER_MANAGER_H +#define FILER_MANAGER_H + +#include +#include +#include +#include +#include "filer_plugin.h" +#include "main.h" + +class FilerManager { +public: + static FilerManager& getInstance() { + static FilerManager instance; + return instance; + } + + void setLogger(Logger* log) { + logger = log; + } + + Filer* createFiler(const std::string& type) { + auto it = plugins.find(type); + if (it != plugins.end()) { + return it->second.create(); + } + if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str()); + return nullptr; + } + + bool loadPlugin(const std::string& path) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Loading filer plugin: %s", path.c_str()); + + void* handle = dlopen(path.c_str(), RTLD_LAZY); + if (!handle) { + if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load filer plugin %s: %s", + path.c_str(), dlerror()); + return false; + } + + // Load plugin functions + CreateFilerFunc create = (CreateFilerFunc)dlsym(handle, "createFilerPlugin"); + DestroyFilerFunc destroy = (DestroyFilerFunc)dlsym(handle, "destroyFilerPlugin"); + GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion"); + SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger"); + + const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName"); + const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription"); + const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion"); + + if (!create || !destroy || !get_api_version || !setLogger || + !get_name || !get_desc || !get_ver) { + if (logger) { + const char* missing = !create ? "createFilerPlugin" : + !destroy ? "destroyFilerPlugin" : + !get_api_version ? "getAPIVersion" : + !setLogger ? "setLogger" : + !get_name ? "getPluginName" : + !get_desc ? "getPluginDescription" : + "getPluginVersion"; + logger->print(LOGLEVEL_ERROR, "Invalid filer plugin %s: missing required function '%s'", + path.c_str(), missing); + } + dlclose(handle); + return false; + } + + int api_version = get_api_version(); + if (api_version != FILER_PLUGIN_API_VERSION) { + if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible filer plugin API version in %s (got %d, expected %d)", + path.c_str(), api_version, FILER_PLUGIN_API_VERSION); + dlclose(handle); + return false; + } + + std::string plugin_name = get_name(); + setLogger(logger); + + if (plugins.find(plugin_name) != plugins.end()) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str()); + dlclose(handle); + return true; + } + + FilerPluginInfo plugin = { + plugin_name, + get_desc(), + get_ver(), + create, + destroy, + get_api_version, + setLogger, + handle + }; + + plugins[plugin.name] = plugin; + if (logger) logger->print(LOGLEVEL_INFO, "Loaded filer plugin: %s v%s", + plugin.name.c_str(), plugin.version.c_str()); + return true; + } + + CreateFilerFunc getFactory(const std::string& type) { + auto it = plugins.find(type); + if (it != plugins.end()) { + return it->second.create; + } + if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str()); + return nullptr; + } + + void unloadPlugins() { + if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all filer plugins"); + for (auto& pair : plugins) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str()); + dlclose(pair.second.handle); + } + plugins.clear(); + } + + ~FilerManager() { + unloadPlugins(); + } + +private: + FilerManager() : logger(nullptr) {} + std::map plugins; + Logger* logger; +}; + +#endif \ No newline at end of file diff --git a/src/filer_plugin.h b/src/filer_plugin.h new file mode 100644 index 0000000..4f0fa3a --- /dev/null +++ b/src/filer_plugin.h @@ -0,0 +1,37 @@ +#ifndef FILER_PLUGIN_H +#define FILER_PLUGIN_H +#include +#include "filer.h" + +#define FILER_PLUGIN_API_VERSION 1 + +typedef Filer* (*CreateFilerFunc)(); +typedef void (*DestroyFilerFunc)(Filer*); +typedef int (*GetAPIVersionFunc)(); +typedef void (*SetLoggerFunc)(Logger*); + +struct FilerPluginInfo { + std::string name; + std::string description; + std::string version; + CreateFilerFunc create; + DestroyFilerFunc destroy; + GetAPIVersionFunc get_api_version; + SetLoggerFunc setLogger; + void* handle; +}; + +#define FILER_PLUGIN_EXPORT extern "C" + +#define IMPLEMENT_FILER_PLUGIN(classname, name, description, version) \ + extern "C" { \ + Filer* createFilerPlugin() { return new classname(); } \ + void destroyFilerPlugin(Filer* plugin) { delete plugin; } \ + const char* getPluginName() { return name; } \ + const char* getPluginDescription() { return description; } \ + const char* getPluginVersion() { return version; } \ + int getAPIVersion() { return FILER_PLUGIN_API_VERSION; } \ + void setLogger(Logger* log) { Filer::setLogger(log); } \ + } + +#endif \ No newline at end of file diff --git a/src/globals.h b/src/globals.h new file mode 100644 index 0000000..9a2e7bb --- /dev/null +++ b/src/globals.h @@ -0,0 +1,10 @@ +#ifndef GLOBALS_H +#define GLOBALS_H + +#define CNF_PERCENTSYM_VAR "%%" +#define CNF_VERSION_VAR "%v" +#define CNF_USERNAME_VAR "%u" +#define CNF_PASSWORD_VAR "%p" +#define CNF_HOSTNAME_VAR "%h" + +#endif \ No newline at end of file diff --git a/src/logger.h b/src/logger.h index 6c275df..b43a06e 100644 --- a/src/logger.h +++ b/src/logger.h @@ -1,42 +1,52 @@ #ifndef LOGGER_H #define LOGGER_H +#include #include +#include +#include +#define LOGLEVEL_CONSOLE 5 #define LOGLEVEL_CRITICAL 4 #define LOGLEVEL_ERROR 3 #define LOGLEVEL_WARNING 2 #define LOGLEVEL_INFO 1 #define LOGLEVEL_DEBUG 0 +#define LOGLEVEL_MAX -1 + +std::mutex logMutex; class Logger { + std::mutex log_mutex; public: - Logger(const char* path) { - logfile.open(path, std::ios::binary|std::ios::app); - }; + Logger() {}; - void setLevel(int level) { this->logLevel = level; } + void openFileOnLevel(int level, const char* path) { + if (strlen(path) > 0) + logfiles[level] = std::ofstream(path, std::ios::binary|std::ios::app); + } template void print(int level, const char* message, Args... args) { - if (level >= this->logLevel) { - char* prepared; - char* formatted; - asprintf(&prepared, "[%c] %s\n", logTypeChar(level), message); - asprintf(&formatted, prepared, args...); - fprintf(stdout, formatted); - if (logfile.is_open()) { - logfile.write(formatted, strlen(formatted)); - logfile.flush(); - } + std::lock_guard lock(log_mutex); + char* prepared; + asprintf(&prepared, "[%c] %s\n", logTypeChar(level), message); + char* formatted; + asprintf(&formatted, prepared, args...); + fprintf(stdout, formatted); + if (logMutex.try_lock() && logfiles[level].is_open()) { + logfiles[level].write(formatted, strlen(formatted)); + logfiles[level].flush(); + logMutex.unlock(); } }; void close() { - logfile.close(); + for (auto &f : logfiles) { + f.second.close(); + } } private: - std::ofstream logfile; - int logLevel = 0; + std::map logfiles; static const char logTypeChar(int level) { switch(level) { diff --git a/src/main.cpp b/src/main.cpp index fdab44a..cd57a1f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,6 +7,10 @@ #include #include #include +#include +#include +#include +#include #include #include #include @@ -16,199 +20,502 @@ #include #include "main.h" -#include "server.h" +#include "auth_manager.cpp" #include "client.cpp" #include "util.h" using namespace std::chrono_literals; +std::mutex client_mutex; -const uint16_t max_clients = config->getInt("net", "max_clients", 255); -const uint16_t cport = config->getInt("net", "control_port", 21); - -struct pollfd fds[65535]; -struct clientfd { +struct pollfd fds[MAX_CLIENTS]; +struct ftpconn { Client* client; + std::thread* thread; bool close = false; -} fdc[65535]; +} fdc[MAX_CLIENTS]; -bool run = true, compress_array = false; +void runClient(struct ftpconn* cfd) { + if (!cfd) { + logger->print(LOGLEVEL_ERROR, "Invalid connection handle"); + return; + } + + std::unique_lock lock(client_mutex); + if (!cfd->client) { + logger->print(LOGLEVEL_ERROR, "Invalid client handle"); + return; + } + + int client_sock = cfd->client->control_sock; + Client* client = cfd->client; + lock.unlock(); -void runClient(struct clientfd* cfd) { char inbuf[BUFFERSIZE]; - //printf("[d] C(%i) Initialized\n", cfd->client->control_sock); - // Loop as long as it is a valid file descriptor. - try { - //while (fcntl(cfd->client->control_sock, F_GETFD) != -1) { - while (true) { - if (cfd->client == nullptr) { break; } - int rc = recv(cfd->client->control_sock, inbuf, sizeof(inbuf), 0); - if (rc < 0) { - logger->print(LOGLEVEL_WARNING, "C(%i) Recieved empty packet", cfd->client->control_sock); - if (errno != EWOULDBLOCK) { - perror("recv() failed"); + logger->print(LOGLEVEL_DEBUG, "C(%i) Client initialized", client_sock); + + while (true) { + memset(inbuf, 0, BUFFERSIZE); + + if (fcntl(client_sock, F_GETFD) < 0) { + logger->print(LOGLEVEL_DEBUG, "C(%i) Socket closed", client_sock); + break; + } + + struct timeval tv; + tv.tv_sec = 60; + tv.tv_usec = 0; + + fd_set readfds, writefds; + FD_ZERO(&readfds); + FD_ZERO(&writefds); + FD_SET(client_sock, &readfds); + + // Add socket to writefds if SSL wants to write + if (client->isSecure() && client->getSSL()) { + FD_SET(client_sock, &writefds); + } + + int select_result = select(client_sock + 1, &readfds, &writefds, NULL, &tv); + if (select_result < 0) { + if (errno == EINTR) continue; + logger->print(LOGLEVEL_ERROR, "C(%i) Select failed: %s", client_sock, strerror(errno)); + break; + } + + if (select_result == 0) { + logger->print(LOGLEVEL_DEBUG, "C(%i) Connection timeout", client_sock); + break; + } + + int rc; + if (client->isSecure() && client->getSSL()) { + if (!client->isHandshakeComplete()) { + // Continue SSL handshake + ERR_clear_error(); // Clear any previous errors + int ret = SSL_accept(client->getSSL()); + if (ret <= 0) { + int ssl_err = SSL_get_error(client->getSSL(), ret); + if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { + continue; // Need more data for handshake + } + unsigned long err = ERR_get_error(); + char err_buf[256]; + ERR_error_string_n(err, err_buf, sizeof(err_buf)); + logger->print(LOGLEVEL_ERROR, "C(%i) SSL handshake failed with error: %d (%s)", + client_sock, ssl_err, err_buf); break; } + client->setHandshakeComplete(true); + logger->print(LOGLEVEL_DEBUG, "C(%i) SSL handshake completed", client_sock); continue; + } else { + // Normal SSL read after handshake + ERR_clear_error(); // Clear any previous errors + rc = SSL_read(client->getSSL(), inbuf, sizeof(inbuf) - 1); + if (rc <= 0) { + int ssl_err = SSL_get_error(client->getSSL(), rc); + if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { + continue; + } + if (ssl_err == SSL_ERROR_SYSCALL) { + unsigned long err = ERR_get_error(); + if (err == 0 && rc == 0) { + logger->print(LOGLEVEL_DEBUG, "C(%i) SSL connection closed", client_sock); + } else { + char err_buf[256]; + ERR_error_string_n(err, err_buf, sizeof(err_buf)); + logger->print(LOGLEVEL_ERROR, "C(%i) SSL_read syscall error: %s", + client_sock, err_buf); + } + } else { + unsigned long err = ERR_get_error(); + char err_buf[256]; + ERR_error_string_n(err, err_buf, sizeof(err_buf)); + logger->print(LOGLEVEL_ERROR, "C(%i) SSL_read failed with error: %d (%s)", + client_sock, ssl_err, err_buf); + } + break; + } } - - if (rc == 0 || cfd->client == nullptr) { - //printf("[d] C(%i) closed\n", cfd->client->control_sock); + } else { + rc = recv(client_sock, inbuf, sizeof(inbuf) - 1, 0); + if (rc <= 0) { + if (rc == 0) { + logger->print(LOGLEVEL_DEBUG, "C(%i) Client disconnected", client_sock); + } else { + logger->print(LOGLEVEL_ERROR, "C(%i) Recv failed: %s", client_sock, strerror(errno)); + } break; } - - std::string lin(inbuf); - int len = lin.find("\r\n", 0); - int cmdend = lin.find(" ", 0); - if (cmdend >= len || cmdend == std::string::npos) cmdend = len; - std::string cmd = toUpper(lin.substr(0, cmdend)); - std::string args = ""; - if (len > cmdend) args = lin.substr(cmdend+1, len-cmdend-1); - - logger->print(LOGLEVEL_DEBUG, "C(%i) >> '%s' '%s'", cfd->client->control_sock, cmd.c_str(), args.c_str()); - - if (cfd->client->receive(cmd, args) < 0) break; - inbuf[0] = '\0'; } - } catch (...) { - logger->print(LOGLEVEL_ERROR, "C(%i) Caught error!", cfd->client->control_sock); + + inbuf[rc] = '\0'; + if (rc >= 2 && inbuf[rc-2] == '\r' && inbuf[rc-1] == '\n') { + rc -= 2; + inbuf[rc] = '\0'; + } + + std::string input(inbuf, rc); + logger->print(LOGLEVEL_DEBUG, "C(%i) >> %s", client_sock, input.c_str()); + + std::string::size_type space_pos = input.find(" "); + std::string cmd = space_pos != std::string::npos ? + toUpper(input.substr(0, space_pos)) : toUpper(input); + std::string args = space_pos != std::string::npos ? + input.substr(space_pos + 1) : ""; + + lock.lock(); + if (!cfd->client) { + lock.unlock(); + break; + } + + int revc = client->receive(cmd, args); + lock.unlock(); + + if (revc != 0) break; } + + // Mark for cleanup + logger->print(LOGLEVEL_DEBUG, "C(%i) Client thread ending", client_sock); cfd->close = true; } +void initializeAuth() { + AuthManager& auth_manager = AuthManager::getInstance(); + auth_manager.setLogger(logger); + + // Load auth plugins from plugin directory + std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR); + + // Load any additional plugins from plugin directory + for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) { + if (entry.path().extension() == ".so" && + entry.path().filename().string().find("libauth_") == 0) { + auth_manager.loadPlugin(entry.path().string()); + } + } + + // Create auth instance based on config + std::string auth_type = config->getValue("core", "auth_engine", "pam"); + auth = auth_manager.createAuth(auth_type); + + if (!auth) { + logger->print(LOGLEVEL_CRITICAL, "Failed to create auth engine: %s", auth_type.c_str()); + exit(1); + } + + // Initialize auth plugin with config + std::map auth_config = config->get("auth_"+auth_type)->get(); + + if (!auth->initialize(auth_config)) { + logger->print(LOGLEVEL_CRITICAL, "Failed to initialize auth engine: %s", auth_type.c_str()); + exit(1); + } +} + +void initializeFiler() { + FilerManager& filer_manager = FilerManager::getInstance(); + filer_manager.setLogger(logger); + + std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR); + + for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) { + if (entry.path().extension() == ".so" && + entry.path().filename().string().find("libfiler_") == 0) { + filer_manager.loadPlugin(entry.path().string()); + } + } + + std::string filer_type = config->getValue("core", "filer_engine", "local"); + + default_filer_factory = filer_manager.getFactory(filer_type); + + if (!default_filer_factory) { + logger->print(LOGLEVEL_CRITICAL, "Failed to get filer factory for type: %s", filer_type.c_str()); + exit(1); + } + + std::map filer_config = config->get("filer_"+filer_type)->get(); + + // Test create using the factory + Filer* test_filer = default_filer_factory(); + if (!test_filer) { + logger->print(LOGLEVEL_CRITICAL, "Failed to create filer instance"); + exit(1); + } + delete test_filer; +} + int main(int argc , char *argv[]) { - logger = new Logger(config->getValue("logging", "file", "digftp.log").c_str()); - logger->setLevel(config->getInt("logging", "level", 0)); - std::string authType = config->getValue("main", "auth_engine", "plain"); - auth = getAuthByName(authType); - auth->setOptions(config->get(authType)); + printf("%s %s Copyright (C) 2024 Worlio LLC\n", APPNAME, VERSION); + printf("This program comes with ABSOLUTELY NO WARRANTY.\n"); + printf("This is free software, and you are welcome to redistribute it under certain conditions.\n\n"); + + // SIGNALS + signal(SIGPIPE, SIG_IGN); + + config = new ConfigFile(concatPath(std::string(CONFIG_DIR), "ftp.conf")); + server_name = config->getValue("core", "server_name", "digFTP %v"); + + sscanf( + config->getValue("net", "listen", "127.0.0.1:21").c_str(), + "%u.%u.%u.%u:%u", + &server_address[0], + &server_address[1], + &server_address[2], + &server_address[3], + &server_port + ); + + logger = new Logger(); + std::string mainLogFile = config->getValue("logging", "all", ""); + logger->openFileOnLevel(LOGLEVEL_MAX, mainLogFile.c_str()); + logger->openFileOnLevel(LOGLEVEL_DEBUG, config->getValue("logging", "debug", mainLogFile).c_str()); + logger->openFileOnLevel(LOGLEVEL_INFO, config->getValue("logging", "info", mainLogFile).c_str()); + logger->openFileOnLevel(LOGLEVEL_WARNING, config->getValue("logging", "warning", mainLogFile).c_str()); + logger->openFileOnLevel(LOGLEVEL_ERROR, config->getValue("logging", "error", mainLogFile).c_str()); + logger->openFileOnLevel(LOGLEVEL_CRITICAL, config->getValue("logging", "critical", mainLogFile).c_str()); + + if (config->getBool("net", "ssl", false)) { + std::string cert_file = config->getValue("ssl", "certificate", "cert.pem"); + if (cert_file[0] != '/') + cert_file = concatPath(std::string(CONFIG_DIR), cert_file); + logger->print(LOGLEVEL_INFO, "Using certificate file: %s", cert_file.c_str()); + std::string key_file = config->getValue("ssl", "private_key", "key.pem"); + if (key_file[0] != '/') + key_file = concatPath(std::string(CONFIG_DIR), key_file); + logger->print(LOGLEVEL_INFO, "Using private key file: %s", key_file.c_str()); + if (!SSLManager::getInstance().initialize(cert_file, key_file)) { + logger->print(LOGLEVEL_CRITICAL, "Failed to initialize SSL"); + return 1; + } + // set configured ciphers + SSLManager::getInstance().setCiphers(config->getValue("ssl", "ciphers", "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH").c_str()); + // handle configured flags + bool ssl_flags = SSL_OP_NO_TICKET; + if (!config->getBool("ssl", "ssl_v2", false)) + ssl_flags+=SSL_OP_NO_SSLv2; + if (!config->getBool("ssl", "ssl_v3", false)) + ssl_flags+=SSL_OP_NO_SSLv3; + if (!config->getBool("ssl", "tls_v1_0", false)) + ssl_flags+=SSL_OP_NO_TLSv1; + if (!config->getBool("ssl", "tls_v1_1", false)) + ssl_flags+=SSL_OP_NO_TLSv1_1; + if (!config->getBool("ssl", "tls_v1_2", true)) + ssl_flags+=SSL_OP_NO_TLSv1_2; + if (!config->getBool("ssl", "tls_v1_3", true)) + ssl_flags+=SSL_OP_NO_TLSv1_3; + + if ((ssl_flags & SSL_OP_NO_SSLv2) && + (ssl_flags & SSL_OP_NO_SSLv3) && + (ssl_flags & SSL_OP_NO_TLSv1) && + (ssl_flags & SSL_OP_NO_TLSv1_1) && + (ssl_flags & SSL_OP_NO_TLSv1_2) && + (ssl_flags & SSL_OP_NO_TLSv1_3)) { + logger->print(LOGLEVEL_WARNING, "All SSL/TLS protocols disabled. You're a mad man!"); + } + + if (!config->getBool("ssl", "compression", true)) + ssl_flags+=SSL_OP_NO_COMPRESSION; + if (config->getBool("ssl", "prefer_server_ciphers", true)) + ssl_flags+=SSL_OP_CIPHER_SERVER_PREFERENCE; + SSLManager::getInstance().setFlags(ssl_flags); + } + + initializeAuth(); + initializeFiler(); int opt = 1, master_socket = -1, newsock = -1, nfds = 1, - current_size = 0; - + current_size = 0, + src = 0; + runServer = true; + runCompression = false; + struct sockaddr_in ctrl_address; - + if ((master_socket = socket(AF_INET, SOCK_STREAM, 0)) < 0) { - perror("socket() failed"); - exit(-1); + logger->print(LOGLEVEL_CRITICAL, "Failed creating socket"); + return master_socket; } - - if (setsockopt(master_socket, SOL_SOCKET, SO_REUSEADDR, (char *)&opt, sizeof(opt)) < 0) { - perror("setsockopt() failed"); + + if ((src = setsockopt(master_socket, SOL_SOCKET, SO_REUSEADDR, (char *)&opt, sizeof(opt))) < 0) { + logger->print(LOGLEVEL_CRITICAL, "Unable to configure socket"); close(master_socket); - exit(-1); + return src; } - - if (ioctl(master_socket, FIONBIO, (char *)&opt) < 0) { - perror("ioctl() failed"); + + if ((src = ioctl(master_socket, FIONBIO, (char *)&opt)) < 0) { + logger->print(LOGLEVEL_CRITICAL, "Unable to read socket"); close(master_socket); - exit(-1); + return src; } - + ctrl_address.sin_family = AF_INET; ctrl_address.sin_addr.s_addr = INADDR_ANY; - ctrl_address.sin_port = htons(cport); - - if (bind(master_socket, (struct sockaddr *)&ctrl_address, sizeof(ctrl_address))<0 < 0) { - perror("bind() failed"); + ctrl_address.sin_port = htons(server_port); + + if ((src = bind(master_socket, (struct sockaddr *)&ctrl_address, sizeof(ctrl_address))) < 0) { + logger->print( + LOGLEVEL_CRITICAL, + "Bind to %i.%i.%i.%i:%i failed", + &server_address[0], + &server_address[1], + &server_address[2], + &server_address[3], + server_port + ); close(master_socket); - exit(-1); + return src; } - - if (listen(master_socket, 3) < 0) { - perror("listen() failed"); + + if ((src = listen(master_socket, 3)) < 0) { + logger->print(LOGLEVEL_CRITICAL, "Unable to listen to socket"); close(master_socket); - exit(-1); + return src; } - + memset(fds, 0, sizeof(fds)); memset(fdc, 0, sizeof(fdc)); - + fds[0].fd = master_socket; fds[0].events = POLLIN; - + logger->print(LOGLEVEL_INFO, "Server started."); - while (run) { + while (runServer) { int pc = poll(fds, nfds, -1); - + if (pc < 0) { - perror("poll() failed"); + if (errno == EINTR) continue; + logger->print(LOGLEVEL_CRITICAL, "Connection poll faced a fatal error"); break; } - - if (pc == 0) { - perror("poll() timed out\n"); - break; - } - + + if (pc == 0) continue; + current_size = nfds; for (int i = 0; i < current_size; i++) { - if(fds[i].revents == 0) + if (fds[i].revents == 0) continue; - - if(fds[i].revents != POLLIN) { - logger->print(LOGLEVEL_ERROR, "C(%i) Error! revents = %d", fds[i].fd, fds[i].revents); - goto conn_close; + + // Handle poll errors properly without skipping cleanup + if (fds[i].revents != POLLIN) { + if (fds[i].fd != master_socket) { + logger->print(LOGLEVEL_ERROR, "Poll error on fd %d", fds[i].fd); + fdc[i].close = true; + } } - if (fds[i].fd == master_socket) { + + if (fds[i].fd == master_socket && fds[i].revents == POLLIN) { + // Handle new connections do { newsock = accept(master_socket, NULL, NULL); if (newsock < 0) { if (errno != EWOULDBLOCK) { - perror("accept() failed"); - run = false; + logger->print(LOGLEVEL_ERROR, "accept() failed: %s", strerror(errno)); + runServer = false; } break; } - logger->print(LOGLEVEL_DEBUG, "C(%i) Accepted client", newsock); - fds[nfds].fd = newsock; - fds[nfds].events = POLLIN; - fdc[nfds].close = false; - fdc[nfds].client = new Client(newsock); - fdc[nfds].client->thread = std::thread(runClient, &fdc[nfds]); - nfds++; - } while (newsock != -1); - } else { - if (fdc[i].close) { - conn_close: - logger->print(LOGLEVEL_DEBUG, "C(%i) Deleting client...", fds[i].fd); - close(fds[i].fd); - fds[i].fd = -1; - if (fdc[i].client->thread.joinable()) - fdc[i].client->thread.detach(); - fdc[i].client = nullptr; - fdc[i].close = false; - compress_array = true; - } else { - fds[i].revents = 0; - } - } - } - if (compress_array) { - compress_array = false; - for (int i = 0; i < nfds; i++) { - if (fds[i].fd == -1) { - for(int j = i; j < nfds; j++) { - if (fds[j].fd == -1) { - logger->print(LOGLEVEL_DEBUG, "Compressing: %i(fd:%i) <= %i(fd:%i)", j, fds[j].fd, j+1, fds[j+1].fd); - fds[j].fd = fds[j+1].fd; - fds[j].revents = fds[j+1].revents; - fds[j+1].fd = -1; - logger->print(LOGLEVEL_DEBUG, "Reinitialized fds of %i", j+1); - fdc[j].client = fdc[j+1].client; - fdc[j].close = fdc[j+1].close; - fdc[j+1] = {}; - logger->print(LOGLEVEL_DEBUG, "Reinitialized fdc of %i", j+1); + + // Find first available slot + int slot = -1; + for (int j = 1; j < MAX_CLIENTS; j++) { + if (fds[j].fd <= 0) { // Changed from < 0 to <= 0 + slot = j; + break; + } + } + + if (slot < 0) { + logger->print(LOGLEVEL_ERROR, "No free slots available"); + close(newsock); + continue; + } + + // Set non-blocking mode + int flags = fcntl(newsock, F_GETFL, 0); + fcntl(newsock, F_SETFL, flags | O_NONBLOCK); + + logger->print(LOGLEVEL_DEBUG, "C(%i) Accepted client in slot %d", newsock, slot); + + fds[slot].fd = newsock; + fds[slot].events = POLLIN; + fds[slot].revents = 0; + + fdc[slot].client = new Client(newsock); + fdc[slot].client->addOption("SIZE", true); + fdc[slot].client->addOption("UTF8", config->getBool("features", "utf8", true)); + fdc[slot].thread = new std::thread(runClient, &fdc[slot]); + fdc[slot].close = false; + + if (slot >= nfds) { + nfds = slot + 1; + } + + } while (newsock != -1); + } + + // Handle cleanup for any connections marked for closing + if (fds[i].fd != master_socket && (fdc[i].close || fds[i].revents != POLLIN)) { + int fd = fds[i].fd; + logger->print(LOGLEVEL_DEBUG, "C(%i) Cleaning up client in slot %d", fd, i); + + // Close socket + if (fd > 0) { + shutdown(fd, SHUT_RDWR); + close(fd); + } + + // Clean up thread + if (fdc[i].thread) { + if (fdc[i].thread->joinable()) { + fdc[i].thread->join(); + } + delete fdc[i].thread; + } + + // Clean up client + delete fdc[i].client; + + // Reset slot + fds[i].fd = -1; + fds[i].events = 0; + fds[i].revents = 0; + memset(&fdc[i], 0, sizeof(struct ftpconn)); + + logger->print(LOGLEVEL_DEBUG, "C(%i) Cleanup completed", fd); + + // Recalculate nfds if needed + if (i == nfds - 1) { + for (int j = nfds - 1; j >= 0; j--) { + if (fds[j].fd != -1) { + nfds = j + 1; + break; } } - i--; - nfds--; } } } } + + // Cleanup + for (int i = 0; i < nfds; i++) { + if (fds[i].fd >= 0) { + close(fds[i].fd); + if (fdc[i].thread && fdc[i].thread->joinable()) { + fdc[i].thread->join(); + } + delete fdc[i].thread; + delete fdc[i].client; + } + } + + close(master_socket); + logger->print(LOGLEVEL_INFO, "Server closing..."); logger->close(); return 0; -} \ No newline at end of file +} diff --git a/src/main.h b/src/main.h index 631cd37..0cfea10 100644 --- a/src/main.h +++ b/src/main.h @@ -1,17 +1,27 @@ #ifndef MAIN_H #define MAIN_H +#define MAX_CLIENTS 4096 +#define BUFFERSIZE 1024*512 -#define CONF_DIR "conf" -#define LOG_DIR "log" +#include +#include "globals.h" +#include "build.h" #include "conf.h" #include "logger.h" +#include "ssl.h" +#include "auth_manager.cpp" +#include "filer_manager.cpp" -#include "auth.h" -#include "auth/noauth.h" - -ConfigFile* config = new ConfigFile("conf/ftp.conf"); +unsigned char* server_address = new unsigned char[127]; +uint16_t server_port = 21; +std::string server_name; +std::string motd; +ConfigFile* config; Auth* auth; Logger* logger; +static CreateFilerFunc default_filer_factory = nullptr; +bool runServer; +bool runCompression; #endif \ No newline at end of file diff --git a/src/plugins/CMakeLists.txt b/src/plugins/CMakeLists.txt new file mode 100644 index 0000000..07e9418 --- /dev/null +++ b/src/plugins/CMakeLists.txt @@ -0,0 +1,14 @@ +foreach(plugin ${PLUGINS}) + add_library(${plugin} MODULE + ${plugin}/${plugin}.cpp + ) + + target_include_directories(${plugin} PRIVATE ${PROJECT_SOURCE_DIR}/src) + + set_target_properties(${plugin} PROPERTIES + PREFIX "lib" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/plugins" + ) + + add_subdirectory(${plugin}) +endforeach() \ No newline at end of file diff --git a/src/plugins/auth_pam/CMakeLists.txt b/src/plugins/auth_pam/CMakeLists.txt new file mode 100644 index 0000000..6fc36cb --- /dev/null +++ b/src/plugins/auth_pam/CMakeLists.txt @@ -0,0 +1,14 @@ +# Find PAM package +find_package(PAM REQUIRED) + +# For each plugin that needs PAM +target_link_libraries(${plugin} + PRIVATE + PAM::PAM ${CMAKE_DL_LIBS} +) + +# Add PAM include directories +target_include_directories(${plugin} + PRIVATE + ${PAM_INCLUDE_DIR} +) \ No newline at end of file diff --git a/src/plugins/auth_pam/auth_pam.cpp b/src/plugins/auth_pam/auth_pam.cpp new file mode 100644 index 0000000..e3930a4 --- /dev/null +++ b/src/plugins/auth_pam/auth_pam.cpp @@ -0,0 +1,131 @@ +#include "auth_plugin.h" +#include +#include +#include +#include +#include +#include + +class PAMAuthPlugin : public Auth { +private: + std::string service_name; + uid_t user_uid; + gid_t user_gid; + + // PAM Conversation function + static int pam_conv_func(int num_msg, const struct pam_message **msg, + struct pam_response **resp, void *appdata_ptr) { + const char* password = static_cast(appdata_ptr); + if (!password) return PAM_CONV_ERR; + + // Allocate memory for responses + *resp = static_cast(calloc(num_msg, sizeof(struct pam_response))); + if (*resp == nullptr) return PAM_CONV_ERR; + + // Handle messages + for (int i = 0; i < num_msg; i++) { + if (msg[i]->msg_style == PAM_PROMPT_ECHO_OFF) { + (*resp)[i].resp = strdup(password); + if ((*resp)[i].resp == nullptr) { + for (int j = 0; j < i; j++) { + free((*resp)[j].resp); + } + free(*resp); + *resp = nullptr; + return PAM_CONV_ERR; + } + (*resp)[i].resp_retcode = 0; + } else { + (*resp)[i].resp = nullptr; + (*resp)[i].resp_retcode = 0; + } + } + + return PAM_SUCCESS; + } + +public: + PAMAuthPlugin() : service_name("ftpd") {} + + virtual bool initialize(const std::map& config) override { + // Get PAM service name with default + auto service_it = config.find("pam_service"); + service_name = (service_it != config.end()) ? service_it->second : "ftpd"; + + auto chroot_it = config.find("chroot"); + this->setChroot((chroot_it != config.end()) ? + (chroot_it->second == "on" || chroot_it->second == "true" || chroot_it->second == "yes") : + true); + + return true; + } + + virtual bool authenticate(ClientAuthDetails* auth_data) override { + if (!auth_data || !auth_data->username[0] || !auth_data->password[0]) { + logger->print(LOGLEVEL_ERROR, "auth_pam: Cannot use empty auth data"); + return false; + } + + struct pam_conv conv; + conv.conv = pam_conv_func; + conv.appdata_ptr = static_cast(auth_data->password); + + pam_handle_t* pamh = nullptr; + + // Start PAM session + int retval = pam_start(service_name.c_str(), auth_data->username, &conv, &pamh); + if (retval != PAM_SUCCESS) { + logger->print(LOGLEVEL_ERROR, "auth_pam: Failed to start PAM: %s", + pamh ? pam_strerror(pamh, retval) : "Unknown error"); + if (pamh) pam_end(pamh, retval); + return false; + } + + // Authenticate user + retval = pam_authenticate(pamh, 0); + if (retval != PAM_SUCCESS) { + logger->print(LOGLEVEL_ERROR, "auth_pam: Authentication failed: %s", + pam_strerror(pamh, retval)); + pam_end(pamh, retval); + return false; + } + + // Check account validity + retval = pam_acct_mgmt(pamh, 0); + if (retval != PAM_SUCCESS) { + logger->print(LOGLEVEL_ERROR, "auth_pam: Account validation failed: %s", + pam_strerror(pamh, retval)); + pam_end(pamh, retval); + return false; + } + + // End PAM session + pam_end(pamh, PAM_SUCCESS); + + // Get user info + struct passwd* pw = getpwnam(auth_data->username); + if (!pw) { + logger->print(LOGLEVEL_ERROR, "auth_pam: Failed to get user info for %s", + auth_data->username); + return false; + } + + // Store user info + user_uid = pw->pw_uid; + user_gid = pw->pw_gid; + + // Set home directory + if (pw->pw_dir) { + strncpy(auth_data->home_dir, pw->pw_dir, sizeof(auth_data->home_dir) - 1); + auth_data->home_dir[sizeof(auth_data->home_dir) - 1] = '\0'; + } + + return true; + } + + virtual bool isPasswordRequired() override { + return true; + } +}; + +IMPLEMENT_AUTH_PLUGIN(PAMAuthPlugin, "pam", "PAM-based local authentication", "1.0.0") \ No newline at end of file diff --git a/src/plugins/auth_passdb/CMakeLists.txt b/src/plugins/auth_passdb/CMakeLists.txt new file mode 100644 index 0000000..640b71d --- /dev/null +++ b/src/plugins/auth_passdb/CMakeLists.txt @@ -0,0 +1,4 @@ +target_link_libraries(${plugin} + PRIVATE + ${CMAKE_DL_LIBS} +) \ No newline at end of file diff --git a/src/plugins/auth_passdb/auth_passdb.cpp b/src/plugins/auth_passdb/auth_passdb.cpp new file mode 100644 index 0000000..b5a3e66 --- /dev/null +++ b/src/plugins/auth_passdb/auth_passdb.cpp @@ -0,0 +1,75 @@ +#include "auth_plugin.h" +#include +#include +#include +#include +#include + +class PassdbAuthPlugin : public Auth { +private: + std::string db_file = "passdb"; + bool case_sensitive; +public: + PassdbAuthPlugin() : db_file("passdb"), case_sensitive(true) {} + + virtual bool initialize(const std::map& config) override { + auto file_it = config.find("file"); + if (file_it != config.end()) { + db_file = file_it->second; + } + + auto home_it = config.find("home_path"); + if (home_it != config.end()) { + user_directory = home_it->second; + } + + auto case_it = config.find("case_sensitive"); + if (case_it != config.end()) { + case_sensitive = (case_it->second == "on" || case_it->second == "true" || case_it->second == "yes"); + } + + auto chroot_it = config.find("chroot"); + this->setChroot((chroot_it != config.end()) ? + (chroot_it->second == "on" || chroot_it->second == "true" || chroot_it->second == "yes") : + true); + + return true; + } + + virtual bool authenticate(ClientAuthDetails* auth_data) override { + std::ifstream file(db_file); + std::string line; + + while (std::getline(file, line)) { + size_t sep = line.find(':'); + if (sep == std::string::npos) continue; + + std::string user = line.substr(0, sep); + std::string pass = line.substr(sep + 1); + + bool authenticated = false; + if (!case_sensitive) { + std::string auth_user = auth_data->username; + std::transform(user.begin(), user.end(), user.begin(), ::tolower); + std::transform(auth_user.begin(), auth_user.end(), auth_user.begin(), ::tolower); + + authenticated = (user == auth_user && pass == auth_data->password); + } else { + authenticated = (user == auth_data->username && pass == auth_data->password); + } + + if (authenticated) { + // Only set the home directory after successful authentication + setUserDirectory(auth_data); + return true; + } + } + return false; + } + + virtual bool isPasswordRequired() override { + return true; + } +}; + +IMPLEMENT_AUTH_PLUGIN(PassdbAuthPlugin, "passdb", "Password database authentication", "1.0.0") \ No newline at end of file diff --git a/src/plugins/filer_local/CMakeLists.txt b/src/plugins/filer_local/CMakeLists.txt new file mode 100644 index 0000000..640b71d --- /dev/null +++ b/src/plugins/filer_local/CMakeLists.txt @@ -0,0 +1,4 @@ +target_link_libraries(${plugin} + PRIVATE + ${CMAKE_DL_LIBS} +) \ No newline at end of file diff --git a/src/plugins/filer_local/filer_local.cpp b/src/plugins/filer_local/filer_local.cpp new file mode 100644 index 0000000..9b86509 --- /dev/null +++ b/src/plugins/filer_local/filer_local.cpp @@ -0,0 +1,374 @@ +#include "filer_plugin.h" +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +class LocalFiler : public Filer { +public: + LocalFiler() {} + + file_data setRoot(std::string _root) { + struct file_data fd; + try { + // Always convert to absolute path + fs::path new_root = fs::absolute(fs::weakly_canonical(_root)); + fd.path = new_root.string().c_str(); + + // Create directory if it doesn't exist + if (!fs::exists(new_root)) { + if (!fs::create_directories(new_root)) { + fd.error = {FilerStatusCodes::NoPermission, "Failed to create root directory"}; + return fd; + } + } + + // Set the absolute root path + root = new_root; + // Reset CWD to root-relative path + cwd = "/"; + + } catch (const fs::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, "setRoot error: %s", ex.what()); + fd.error = {FilerStatusCodes::Exception, ex.what()}; + } + + return fd; + } + + file_data setCWD(std::string _cwd) { + struct file_data fd; + try { + // Always convert to absolute path + fs::path new_cwd = fs::weakly_canonical(_cwd); + fd.path = new_cwd.string().c_str(); + + // Create directory if it doesn't exist + if (!fs::exists(root / new_cwd)) { + if (!fs::create_directories(root / new_cwd)) { + fd.error = {FilerStatusCodes::NoPermission, "Failed to create cwd directory"}; + return fd; + } + } + + cwd = new_cwd; + } catch (const fs::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, "setCWD error: %s", ex.what()); + fd.error = {FilerStatusCodes::Exception, ex.what()}; + } + + return fd; + } + + fs::path getRoot() { + return this->root; + } + + fs::path getCWD() { + return this->cwd; + } + + fs::path resolvePath(const std::string& path) { + try { + if (path.empty() || path == ".") { + return root / cwd.relative_path(); + } + + fs::path target_path; + if (path[0] == '/') { + // Absolute path relative to FTP root + target_path = root / path.substr(1); + } else { + // Relative to current directory + target_path = root / cwd.relative_path() / path; + } + + // Remove .. and . components + target_path = fs::weakly_canonical(target_path); + + // Make sure we haven't escaped the root + if (!target_path.string().starts_with(root.string())) { + return root / cwd.relative_path(); + } + + return target_path; + } catch (const std::exception& e) { + logger->print(LOGLEVEL_ERROR, "Path resolution error: %s", e.what()); + return root / cwd.relative_path(); + } + } + + file_data traverse(std::string dir) { + struct file_data fd; + try { + // Handle special cases + if (dir.empty() || dir == ".") { + fd.path = resolvePath(".").string().c_str(); + return fd; + } + + fs::path requested_path = resolvePath(dir); + + // Verify the requested path exists and is within root + if (!fs::exists(requested_path)) { + fd.error = file_error{FilerStatusCodes::NotFound, "Directory Not Found"}; + return fd; + } + + if (!fs::is_directory(requested_path)) { + fd.error = file_error{FilerStatusCodes::NotFound, "Not a Directory"}; + return fd; + } + + // Make sure path is within root directory + fs::path rel_path = fs::relative(requested_path, root); + if (rel_path.string().find("..") == 0) { + fd.error = file_error{FilerStatusCodes::NoPermission, "Invalid Permissions"}; + return fd; + } + + // Update current working directory relative to root + cwd = "/" + rel_path.string(); + fd.path = requested_path.string().c_str(); + + } catch (const fs::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, "Path fallback: %s", ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } + + file_data createDirectory(std::string dir) { + struct file_data fd; + try { + fs::path resolved = resolvePath(dir); + fd.path = resolved.string().c_str(); + + if (fs::exists(resolved)) { + fd.error = file_error{FilerStatusCodes::FileExists, "Directory Already Exists"}; + return fd; + } + + fs::create_directory(resolved); + } catch (const std::filesystem::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } + + file_data fileSize(std::string name) { + struct file_data fd; + try { + fs::path resolved = resolvePath(name); + fd.path = resolved.string().c_str(); + + if (!fs::exists(resolved)) { + fd.error = file_error{FilerStatusCodes::NotFound, "File Not Found"}; + return fd; + } + + if (type == 'A') { + fd.error = file_error{FilerStatusCodes::InvalidTransferMode, "Refusing to transfer in ASCII mode"}; + return fd; + } + + std::ifstream infile(resolved, std::ios::in|std::ios::binary|std::ios::ate); + if (infile.is_open()) { + fd.size = infile.tellg(); + } else { + fd.error = file_error{FilerStatusCodes::NoPermission, "Unable to open file"}; + } + } catch (const std::filesystem::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } + + file_data deleteFile(std::string name) { + struct file_data fd; + try { + fs::path resolved = resolvePath(name); + fd.path = resolved.string().c_str(); + + if (!fs::exists(resolved)) { + fd.error = file_error{FilerStatusCodes::NotFound, "File Not Found"}; + return fd; + } + + if (fs::is_directory(resolved) && !fs::is_empty(resolved)) { + fd.error = file_error{FilerStatusCodes::DirectoryNotEmpty, "Directory not empty"}; + return fd; + } + + if (!fs::remove(resolved)) { + fd.error = file_error{FilerStatusCodes::NoPermission, "Unable to delete file"}; + } + } catch (const std::filesystem::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } + + file_data readFile(std::string name) { + struct file_data fd; + try { + fs::path resolved = resolvePath(name); + fd.path = resolved.string().c_str(); + + if (!fs::exists(resolved)) { + fd.error = file_error{FilerStatusCodes::NotFound, "File Not Found"}; + return fd; + } + + if (type != 'A') { + // Create a shared_ptr to manage the ifstream + fd.stream = std::make_shared(resolved, std::ios::in|std::ios::binary); + + if (fd.stream && fd.stream->is_open()) { + // Get file size + fd.stream->seekg(0, std::ios::end); + fd.size = fd.stream->tellg(); + fd.stream->seekg(0, std::ios::beg); + } else { + fd.error = file_error{FilerStatusCodes::NoPermission, "Unable to open file"}; + } + } else { + fd.error = file_error{FilerStatusCodes::InvalidTransferMode, "Refusing to transfer in ASCII mode"}; + } + } catch (const std::filesystem::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } + + file_data writeFile(std::string name, unsigned char* data, int size, bool append = false) { + struct file_data fd; + try { + fs::path resolved = resolvePath(name); + fd.path = resolved.string().c_str(); + + std::ios_base::openmode omode = std::ios::out|std::ios::binary; + if (append) omode |= std::ios::app; + + std::ofstream outfile(resolved, omode); + if (outfile.is_open()) { + outfile.write((char *)data, size); + outfile.close(); + fd.size = size; + } else { + fd.error = file_error{FilerStatusCodes::NoPermission, "Unable to open file"}; + } + } catch (const std::filesystem::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } + + struct file_data renameFile(const std::string& from, const std::string& to) { + struct file_data fd; + std::filesystem::path src_path = resolvePath(from); + std::filesystem::path dst_path = resolvePath(to); + + try { + // Check if destination already exists + if (std::filesystem::exists(dst_path)) { + fd.error = {FilerStatusCodes::FileExists, "Destination file already exists"}; + return fd; + } + + // Perform the rename operation + std::filesystem::rename(src_path, dst_path); + fd.error = {0, "OK"}; + + } catch (const std::filesystem::filesystem_error& e) { + fd.error = {FilerStatusCodes::AccessDenied, e.what()}; + } + + return fd; + } + + file_data list(std::string path = ".") { + struct file_data fd; + try { + fs::path resolved = resolvePath(path); + fd.path = resolved.string().c_str(); + + if (!fs::exists(resolved)) { + fd.error = file_error{FilerStatusCodes::NotFound, "File Not Found"}; + return fd; + } + + std::ostringstream listStream; + for(const auto& p : fs::directory_iterator(resolved, + fs::directory_options::follow_directory_symlink | + fs::directory_options::skip_permission_denied)) { + + struct stat fstat; + struct tm *time; + time_t rawtime; + char timebuff[80]; + + if (stat(p.path().c_str(), &fstat) == -1) { + fd.error = file_error{FilerStatusCodes::NoPermission, "Unable to stat file"}; + break; + } + + /* Convert time_t to tm struct */ + rawtime = fstat.st_mtime; + time = localtime(&rawtime); + strftime(timebuff, 80, "%b %d %H:%M", time); + + // God should've smitten me before I wrote such attrocities. + char* line; + fs::perms fperms = fs::status(p).permissions(); + asprintf( + &line, + "%c%c%c%c%c%c%c%c%c%c %4u %4u %4u %12u %s %s\r\n", + p.is_directory()?'d':'-', + (fperms & fs::perms::owner_read) != fs::perms::none?'r':'-', + (fperms & fs::perms::owner_write) != fs::perms::none?'w':'-', + (fperms & fs::perms::owner_exec) != fs::perms::none?'x':'-', + (fperms & fs::perms::group_read) != fs::perms::none?'r':'-', + (fperms & fs::perms::group_write) != fs::perms::none?'w':'-', + (fperms & fs::perms::group_exec) != fs::perms::none?'x':'-', + (fperms & fs::perms::others_read) != fs::perms::none?'r':'-', + (fperms & fs::perms::others_write) != fs::perms::none?'w':'-', + (fperms & fs::perms::others_exec) != fs::perms::none?'x':'-', + fs::hard_link_count(p), + fstat.st_uid, + fstat.st_gid, + fstat.st_size, + timebuff, + p.path().filename().c_str() + ); + + listStream << std::string(line); + free(line); + } + + fd.size = listStream.tellp(); + fd.bin = new char[fd.size]; + std::strncpy(fd.bin, listStream.str().c_str(), fd.size); + + } catch (const std::filesystem::filesystem_error& ex) { + logger->print(LOGLEVEL_ERROR, ex.what()); + fd.error = file_error{FilerStatusCodes::Exception, ex.what()}; + } + return fd; + } +private: + fs::path root; + fs::path cwd; + char type = 'I'; +}; + +IMPLEMENT_FILER_PLUGIN(LocalFiler, "local", "Local filesystem implementation", "1.0.0") \ No newline at end of file diff --git a/src/server.h b/src/server.h deleted file mode 100644 index 949cb53..0000000 --- a/src/server.h +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SERVER_H -#define SERVER_H - -#define APPNAME "digFTP" -#define APPVER "0.0" -#define BUFFERSIZE 1024*512 - -#endif \ No newline at end of file diff --git a/src/ssl.h b/src/ssl.h new file mode 100644 index 0000000..558716a --- /dev/null +++ b/src/ssl.h @@ -0,0 +1,77 @@ +#ifndef SSL_H +#define SSL_H +#include +#include + +class SSLManager { +public: + static SSLManager& getInstance() { + static SSLManager instance; + return instance; + } + + bool initialize(const std::string& cert_file, const std::string& key_file) { + // Initialize OpenSSL + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_all_algorithms(); + + // Create SSL context + ctx = SSL_CTX_new(TLS_server_method()); + if (!ctx) { + ERR_print_errors_fp(stderr); + return false; + } + + // Set SSL session caching + SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_SERVER); + SSL_CTX_set_timeout(ctx, 300); + + // Set the certificate and private key + if (SSL_CTX_use_certificate_file(ctx, cert_file.c_str(), SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + return false; + } + + if (SSL_CTX_use_PrivateKey_file(ctx, key_file.c_str(), SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + return false; + } + + // Verify private key + if (!SSL_CTX_check_private_key(ctx)) { + fprintf(stderr, "Private key does not match the public certificate\n"); + return false; + } + + return true; + } + + bool setFlags(int flags) { + if (SSL_CTX_set_options(ctx, flags) != 1) { + ERR_print_errors_fp(stderr); + return false; + } + return true; + } + + bool setCiphers(const char* ciphers) { + if (SSL_CTX_set_cipher_list(ctx, ciphers) != 1) { + ERR_print_errors_fp(stderr); + return false; + } + return true; + } + + SSL_CTX* getContext() { return ctx; } + + ~SSLManager() { + if (ctx) SSL_CTX_free(ctx); + EVP_cleanup(); + } + +private: + SSLManager() : ctx(nullptr) {} + SSL_CTX* ctx; +}; +#endif \ No newline at end of file diff --git a/src/util.h b/src/util.h index 678e40c..512eff8 100644 --- a/src/util.h +++ b/src/util.h @@ -8,6 +8,7 @@ #include #include #include +#include std::string replace(std::string subject, const std::string& search, const std::string& replace) { size_t pos = 0; @@ -21,18 +22,18 @@ std::string replace(std::string subject, const std::string& search, const std::s template void split(const std::string &s, char delim, Out result, int limit) { int it = 0; - std::istringstream iss(s); - std::string item; - while (std::getline(iss, item, delim) && it <= limit) { - *result++ = item; - if (limit > 0) it++; - } + std::istringstream iss(s); + std::string item; + while (std::getline(iss, item, delim) && it <= limit) { + *result++ = item; + if (limit > 0) it++; + } } std::vector split(const std::string &s, char delim, int limit) { - std::vector elems; - split(s, delim, std::back_inserter(elems), limit); - return elems; + std::vector elems; + split(s, delim, std::back_inserter(elems), limit); + return elems; } static std::string toLower(std::string str) { @@ -45,6 +46,93 @@ static std::string toUpper(std::string str) { return str; } +static std::vector parseArgs(std::string line) { + if (line.size() == 0) { + return {}; + } + + int state = 0; + std::vector result; + std::string current; + bool lastTokenHasBeenQuoted = false; + bool lastTokenWasEscaped = false; + + for (int i = 0; i < line.size(); i++) { + char nextTok = line[i]; + switch (state) { + case 1: + if (nextTok == '\\') { + lastTokenWasEscaped = true; + } else if (nextTok == '\'' && !lastTokenWasEscaped) { + lastTokenHasBeenQuoted = true; + state = 0; + } else { + if (lastTokenWasEscaped) { + if (nextTok == 't') nextTok = '\t'; + if (nextTok == 'b') nextTok = '\b'; + if (nextTok == 'n') nextTok = '\n'; + if (nextTok == 'r') nextTok = '\r'; + if (nextTok == 'f') nextTok = '\f'; + } + current.push_back(nextTok); + lastTokenWasEscaped = false; + } + break; + case 2: + if (nextTok == '\\') { + lastTokenWasEscaped = true; + } else if (nextTok == '\"' && !lastTokenWasEscaped) { + lastTokenHasBeenQuoted = true; + state = 0; + } else { + if (lastTokenWasEscaped) { + if (nextTok == 't') nextTok = '\t'; + if (nextTok == 'b') nextTok = '\b'; + if (nextTok == 'n') nextTok = '\n'; + if (nextTok == 'r') nextTok = '\r'; + if (nextTok == 'f') nextTok = '\f'; + } + current.push_back(nextTok); + lastTokenWasEscaped = false; + } + break; + default: + switch (nextTok) { + case '\\': lastTokenWasEscaped = true; break; + case '\'': state = 1; break; + case '\"': state = 2; break; + case ' ': + if (!lastTokenWasEscaped && (lastTokenHasBeenQuoted || current.length() != 0)) { + result.push_back(current); + current = ""; + } + break; + default: + if (lastTokenWasEscaped) { + if (nextTok == 't') nextTok = '\t'; + if (nextTok == 'b') nextTok = '\b'; + if (nextTok == 'n') nextTok = '\n'; + if (nextTok == 'r') nextTok = '\r'; + if (nextTok == 'f') nextTok = '\f'; + lastTokenWasEscaped = false; + } + current.push_back(nextTok); + break; + } + lastTokenHasBeenQuoted = false; + break; + } + } + if (lastTokenHasBeenQuoted || current.size() != 0) { + result.push_back(current); + } + return result; +} + +std::string concatPath(std::string dir1, std::string dir2) { + return std::filesystem::weakly_canonical(dir1+"/"+dir2).string(); +} + static char* trim(char *str) { char *end; while(isspace(*str))