From f4d44043fd3b31160cfa073f476c253a4787d784 Mon Sep 17 00:00:00 2001 From: Wirlaburla Date: Sat, 14 Dec 2024 10:09:47 -0600 Subject: [PATCH] Unified plugin handling --- src/auth.h | 8 +- src/auth_manager.cpp | 124 ------------------ src/auth_plugin.h | 37 ------ src/filer.h | 13 +- src/filer_manager.cpp | 130 ------------------- src/filer_plugin.h | 37 ------ src/main.cpp | 78 ++++------- src/main.h | 8 +- src/plugin.h | 49 +++++++ src/plugin_manager.h | 166 ++++++++++++++++++++++++ src/plugin_traits.h | 48 +++++++ src/plugins/auth_pam/auth_pam.cpp | 7 +- src/plugins/auth_passdb/auth_passdb.cpp | 8 +- src/plugins/filer_local/filer_local.cpp | 7 +- 14 files changed, 309 insertions(+), 411 deletions(-) delete mode 100644 src/auth_manager.cpp delete mode 100644 src/auth_plugin.h delete mode 100644 src/filer_manager.cpp delete mode 100644 src/filer_plugin.h create mode 100644 src/plugin.h create mode 100644 src/plugin_manager.h create mode 100644 src/plugin_traits.h diff --git a/src/auth.h b/src/auth.h index c48af90..8240a09 100644 --- a/src/auth.h +++ b/src/auth.h @@ -10,6 +10,8 @@ #include "logger.h" #include "util.h" +#define AUTH_PLUGIN_API_VERSION 1 + typedef struct { char username[255] = {0}; char password[255] = {0}; @@ -57,16 +59,10 @@ public: virtual uint64_t getFreeSpace() { return this->getTotalSpace(); } virtual void setTotalSpace(uint64_t space) { this->max_filespace = space; } - - static void setLogger(Logger* log) { logger = log; } private: bool require_password = false; uint64_t max_filespace = -1; bool chroot = false; -protected: - static Logger* logger; }; -Logger* Auth::logger = nullptr; - #endif \ No newline at end of file diff --git a/src/auth_manager.cpp b/src/auth_manager.cpp deleted file mode 100644 index b2852ca..0000000 --- a/src/auth_manager.cpp +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef AUTH_MANAGER_H -#define AUTH_MANAGER_H - -#include -#include -#include -#include -#include "main.h" -#include "auth_plugin.h" - -class AuthManager { -public: - static AuthManager& getInstance() { - static AuthManager instance; - return instance; - } - - void setLogger(Logger* log) { - logger = log; - } - - Auth* createAuth(const std::string& type) { - auto it = plugins.find(type); - if (it != plugins.end()) { - return it->second.create(); - } - if (logger) logger->print(LOGLEVEL_ERROR, "Auth plugin type '%s' not found", type.c_str()); - return nullptr; - } - - bool loadPlugin(const std::string& path) { - if (logger) logger->print(LOGLEVEL_DEBUG, "Loading auth plugin: %s", path.c_str()); - - void* handle = dlopen(path.c_str(), RTLD_LAZY); - if (!handle) { - if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load auth plugin %s: %s", - path.c_str(), dlerror()); - return false; - } - - // Load plugin functions - CreateAuthFunc create = (CreateAuthFunc)dlsym(handle, "createAuthPlugin"); - DestroyAuthFunc destroy = (DestroyAuthFunc)dlsym(handle, "destroyAuthPlugin"); - GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion"); - SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger"); - - const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName"); - const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription"); - const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion"); - - // Check each required function and log specific missing functions - if (!create || !destroy || !get_api_version || !setLogger || !get_name || !get_desc || !get_ver) { - if (logger) { - const char* missing = !create ? "createAuthPlugin" : - !destroy ? "destroyAuthPlugin" : - !get_api_version ? "getAPIVersion" : - !setLogger ? "setLogger" : - !get_name ? "getPluginName" : - !get_desc ? "getPluginDescription" : - "getPluginVersion"; - logger->print(LOGLEVEL_ERROR, "Invalid auth plugin %s: missing required function '%s'", - path.c_str(), missing); - } - dlclose(handle); - return false; - } - - // Check API version - int api_version = get_api_version(); - if (api_version != AUTH_PLUGIN_API_VERSION) { - if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible auth plugin API version in %s (got %d, expected %d)", - path.c_str(), api_version, AUTH_PLUGIN_API_VERSION); - dlclose(handle); - return false; - } - - std::string plugin_name = get_name(); - setLogger(logger); - - // Check if plugin is already loaded - if (plugins.find(plugin_name) != plugins.end()) { - if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str()); - dlclose(handle); - return true; - } - - // Store plugin info - AuthPluginInfo plugin = { - plugin_name, - get_desc(), - get_ver(), - create, - destroy, - get_api_version, - setLogger, - handle - }; - - plugins[plugin.name] = plugin; - if (logger) logger->print(LOGLEVEL_INFO, "Loaded auth plugin: %s v%s", - plugin.name.c_str(), plugin.version.c_str()); - return true; - } - - void unloadPlugins() { - if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all auth plugins"); - for (auto& pair : plugins) { - if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str()); - dlclose(pair.second.handle); - } - plugins.clear(); - } - - ~AuthManager() { - unloadPlugins(); - } - -private: - AuthManager() : logger(nullptr) {} // Singleton - std::map plugins; - Logger* logger; -}; - -#endif \ No newline at end of file diff --git a/src/auth_plugin.h b/src/auth_plugin.h deleted file mode 100644 index 0e475cf..0000000 --- a/src/auth_plugin.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef AUTH_PLUGIN_H -#define AUTH_PLUGIN_H -#include -#include "auth.h" - -#define AUTH_PLUGIN_API_VERSION 1 - -typedef Auth* (*CreateAuthFunc)(); -typedef void (*DestroyAuthFunc)(Auth*); -typedef int (*GetAPIVersionFunc)(); -typedef void (*SetLoggerFunc)(Logger*); - -struct AuthPluginInfo { - std::string name; - std::string description; - std::string version; - CreateAuthFunc create; - DestroyAuthFunc destroy; - GetAPIVersionFunc get_api_version; - SetLoggerFunc setLogger; - void* handle; -}; - -#define AUTH_PLUGIN_EXPORT extern "C" - -#define IMPLEMENT_AUTH_PLUGIN(classname, name, description, version) \ - extern "C" { \ - Auth* createAuthPlugin() { return new classname(); } \ - void destroyAuthPlugin(Auth* plugin) { delete plugin; } \ - const char* getPluginName() { return name; } \ - const char* getPluginDescription() { return description; } \ - const char* getPluginVersion() { return version; } \ - int getAPIVersion() { return AUTH_PLUGIN_API_VERSION; } \ - void setLogger(Logger* log) { Auth::setLogger(log); } \ - } - -#endif \ No newline at end of file diff --git a/src/filer.h b/src/filer.h index 80e7d9a..36a67b2 100644 --- a/src/filer.h +++ b/src/filer.h @@ -3,6 +3,8 @@ #include #include "logger.h" +#define FILER_PLUGIN_API_VERSION 1 + struct file_error { uint8_t code = 0; const char* msg = {}; @@ -34,11 +36,6 @@ class Filer { public: virtual ~Filer() {}; - static void setLogger(Logger* log) { - logger = log; - } - - // Core operations virtual file_data setRoot(std::string _root) = 0; virtual file_data setCWD(std::string _cwd) = 0; virtual std::filesystem::path getRoot() = 0; @@ -52,7 +49,6 @@ public: return this->transfer_mode; } - // File operations virtual file_data traverse(std::string dir) = 0; virtual file_data createDirectory(std::string dir) = 0; virtual file_data fileSize(std::string name) = 0; @@ -61,11 +57,8 @@ public: virtual file_data writeFile(std::string name, unsigned char* data, int size, bool append = false) = 0; virtual file_data renameFile(const std::string& from, const std::string& to) = 0; virtual file_data list(std::string path = ".") = 0; -protected: - static Logger* logger; +private: char transfer_mode = 'I'; }; -Logger* Filer::logger = nullptr; - #endif \ No newline at end of file diff --git a/src/filer_manager.cpp b/src/filer_manager.cpp deleted file mode 100644 index 9d51444..0000000 --- a/src/filer_manager.cpp +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef FILER_MANAGER_H -#define FILER_MANAGER_H - -#include -#include -#include -#include -#include "filer_plugin.h" -#include "main.h" - -class FilerManager { -public: - static FilerManager& getInstance() { - static FilerManager instance; - return instance; - } - - void setLogger(Logger* log) { - logger = log; - } - - Filer* createFiler(const std::string& type) { - auto it = plugins.find(type); - if (it != plugins.end()) { - return it->second.create(); - } - if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str()); - return nullptr; - } - - bool loadPlugin(const std::string& path) { - if (logger) logger->print(LOGLEVEL_DEBUG, "Loading filer plugin: %s", path.c_str()); - - void* handle = dlopen(path.c_str(), RTLD_LAZY); - if (!handle) { - if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load filer plugin %s: %s", - path.c_str(), dlerror()); - return false; - } - - // Load plugin functions - CreateFilerFunc create = (CreateFilerFunc)dlsym(handle, "createFilerPlugin"); - DestroyFilerFunc destroy = (DestroyFilerFunc)dlsym(handle, "destroyFilerPlugin"); - GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion"); - SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger"); - - const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName"); - const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription"); - const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion"); - - if (!create || !destroy || !get_api_version || !setLogger || - !get_name || !get_desc || !get_ver) { - if (logger) { - const char* missing = !create ? "createFilerPlugin" : - !destroy ? "destroyFilerPlugin" : - !get_api_version ? "getAPIVersion" : - !setLogger ? "setLogger" : - !get_name ? "getPluginName" : - !get_desc ? "getPluginDescription" : - "getPluginVersion"; - logger->print(LOGLEVEL_ERROR, "Invalid filer plugin %s: missing required function '%s'", - path.c_str(), missing); - } - dlclose(handle); - return false; - } - - int api_version = get_api_version(); - if (api_version != FILER_PLUGIN_API_VERSION) { - if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible filer plugin API version in %s (got %d, expected %d)", - path.c_str(), api_version, FILER_PLUGIN_API_VERSION); - dlclose(handle); - return false; - } - - std::string plugin_name = get_name(); - setLogger(logger); - - if (plugins.find(plugin_name) != plugins.end()) { - if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str()); - dlclose(handle); - return true; - } - - FilerPluginInfo plugin = { - plugin_name, - get_desc(), - get_ver(), - create, - destroy, - get_api_version, - setLogger, - handle - }; - - plugins[plugin.name] = plugin; - if (logger) logger->print(LOGLEVEL_INFO, "Loaded filer plugin: %s v%s", - plugin.name.c_str(), plugin.version.c_str()); - return true; - } - - CreateFilerFunc getFactory(const std::string& type) { - auto it = plugins.find(type); - if (it != plugins.end()) { - return it->second.create; - } - if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str()); - return nullptr; - } - - void unloadPlugins() { - if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all filer plugins"); - for (auto& pair : plugins) { - if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str()); - dlclose(pair.second.handle); - } - plugins.clear(); - } - - ~FilerManager() { - unloadPlugins(); - } - -private: - FilerManager() : logger(nullptr) {} - std::map plugins; - Logger* logger; -}; - -#endif \ No newline at end of file diff --git a/src/filer_plugin.h b/src/filer_plugin.h deleted file mode 100644 index 4f0fa3a..0000000 --- a/src/filer_plugin.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef FILER_PLUGIN_H -#define FILER_PLUGIN_H -#include -#include "filer.h" - -#define FILER_PLUGIN_API_VERSION 1 - -typedef Filer* (*CreateFilerFunc)(); -typedef void (*DestroyFilerFunc)(Filer*); -typedef int (*GetAPIVersionFunc)(); -typedef void (*SetLoggerFunc)(Logger*); - -struct FilerPluginInfo { - std::string name; - std::string description; - std::string version; - CreateFilerFunc create; - DestroyFilerFunc destroy; - GetAPIVersionFunc get_api_version; - SetLoggerFunc setLogger; - void* handle; -}; - -#define FILER_PLUGIN_EXPORT extern "C" - -#define IMPLEMENT_FILER_PLUGIN(classname, name, description, version) \ - extern "C" { \ - Filer* createFilerPlugin() { return new classname(); } \ - void destroyFilerPlugin(Filer* plugin) { delete plugin; } \ - const char* getPluginName() { return name; } \ - const char* getPluginDescription() { return description; } \ - const char* getPluginVersion() { return version; } \ - int getAPIVersion() { return FILER_PLUGIN_API_VERSION; } \ - void setLogger(Logger* log) { Filer::setLogger(log); } \ - } - -#endif \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index cd57a1f..c91e5a1 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -20,7 +20,6 @@ #include #include "main.h" -#include "auth_manager.cpp" #include "client.cpp" #include "util.h" @@ -181,70 +180,46 @@ void runClient(struct ftpconn* cfd) { cfd->close = true; } -void initializeAuth() { - AuthManager& auth_manager = AuthManager::getInstance(); - auth_manager.setLogger(logger); - - // Load auth plugins from plugin directory +void initializePlugins() { std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR); - // Load any additional plugins from plugin directory + // Load auth plugins + auto& auth_manager = PluginManager::getInstance(); + auth_manager.setLogger(logger); + + // Load filer plugins + auto& filer_manager = PluginManager::getInstance(); + filer_manager.setLogger(logger); + for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) { - if (entry.path().extension() == ".so" && - entry.path().filename().string().find("libauth_") == 0) { - auth_manager.loadPlugin(entry.path().string()); + if (entry.path().extension() == ".so") { + const std::string& filename = entry.path().filename().string(); + if (filename.find(PluginTraits::pluginPrefix()) == 0) { + auth_manager.loadPlugin(entry.path().string()); + } + else if (filename.find(PluginTraits::pluginPrefix()) == 0) { + filer_manager.loadPlugin(entry.path().string()); + } } } - // Create auth instance based on config + // Initialize auth std::string auth_type = config->getValue("core", "auth_engine", "pam"); - auth = auth_manager.createAuth(auth_type); + auth = auth_manager.createPlugin(auth_type); if (!auth) { logger->print(LOGLEVEL_CRITICAL, "Failed to create auth engine: %s", auth_type.c_str()); exit(1); } - // Initialize auth plugin with config - std::map auth_config = config->get("auth_"+auth_type)->get(); - - if (!auth->initialize(auth_config)) { - logger->print(LOGLEVEL_CRITICAL, "Failed to initialize auth engine: %s", auth_type.c_str()); - exit(1); - } -} - -void initializeFiler() { - FilerManager& filer_manager = FilerManager::getInstance(); - filer_manager.setLogger(logger); - - std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR); - - for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) { - if (entry.path().extension() == ".so" && - entry.path().filename().string().find("libfiler_") == 0) { - filer_manager.loadPlugin(entry.path().string()); - } - } - + // Initialize filer std::string filer_type = config->getValue("core", "filer_engine", "local"); - default_filer_factory = filer_manager.getFactory(filer_type); if (!default_filer_factory) { logger->print(LOGLEVEL_CRITICAL, "Failed to get filer factory for type: %s", filer_type.c_str()); exit(1); } - - std::map filer_config = config->get("filer_"+filer_type)->get(); - - // Test create using the factory - Filer* test_filer = default_filer_factory(); - if (!test_filer) { - logger->print(LOGLEVEL_CRITICAL, "Failed to create filer instance"); - exit(1); - } - delete test_filer; } int main(int argc , char *argv[]) { @@ -307,14 +282,10 @@ int main(int argc , char *argv[]) { if (!config->getBool("ssl", "tls_v1_3", true)) ssl_flags+=SSL_OP_NO_TLSv1_3; - if ((ssl_flags & SSL_OP_NO_SSLv2) && - (ssl_flags & SSL_OP_NO_SSLv3) && - (ssl_flags & SSL_OP_NO_TLSv1) && - (ssl_flags & SSL_OP_NO_TLSv1_1) && - (ssl_flags & SSL_OP_NO_TLSv1_2) && - (ssl_flags & SSL_OP_NO_TLSv1_3)) { + if ((ssl_flags & SSL_OP_NO_SSLv2) && (ssl_flags & SSL_OP_NO_SSLv3) && + (ssl_flags & SSL_OP_NO_TLSv1) && (ssl_flags & SSL_OP_NO_TLSv1_1) && + (ssl_flags & SSL_OP_NO_TLSv1_2) && (ssl_flags & SSL_OP_NO_TLSv1_3)) logger->print(LOGLEVEL_WARNING, "All SSL/TLS protocols disabled. You're a mad man!"); - } if (!config->getBool("ssl", "compression", true)) ssl_flags+=SSL_OP_NO_COMPRESSION; @@ -323,8 +294,7 @@ int main(int argc , char *argv[]) { SSLManager::getInstance().setFlags(ssl_flags); } - initializeAuth(); - initializeFiler(); + initializePlugins(); int opt = 1, master_socket = -1, diff --git a/src/main.h b/src/main.h index 0cfea10..aa2ae9f 100644 --- a/src/main.h +++ b/src/main.h @@ -10,8 +10,10 @@ #include "conf.h" #include "logger.h" #include "ssl.h" -#include "auth_manager.cpp" -#include "filer_manager.cpp" +#include "auth.h" +#include "filer.h" +#include "plugin_traits.h" +#include "plugin_manager.h" unsigned char* server_address = new unsigned char[127]; uint16_t server_port = 21; @@ -20,7 +22,7 @@ std::string motd; ConfigFile* config; Auth* auth; Logger* logger; -static CreateFilerFunc default_filer_factory = nullptr; +static typename PluginTraits::CreateFunc default_filer_factory = nullptr; bool runServer; bool runCompression; diff --git a/src/plugin.h b/src/plugin.h new file mode 100644 index 0000000..787cbf9 --- /dev/null +++ b/src/plugin.h @@ -0,0 +1,49 @@ +#ifndef PLUGIN_H +#define PLUGIN_H + +#include +#include "logger.h" + +typedef int (*GetAPIVersionFunc)(); +typedef void (*SetLoggerFunc)(Logger*); + +class IPlugin { +public: + virtual ~IPlugin() = default; + static void setLogger(Logger* log) { logger = log; } +protected: + static Logger* logger; +}; + +Logger* IPlugin::logger = nullptr; + +struct PluginInfo { + std::string name; + std::string description; + std::string version; + GetAPIVersionFunc get_api_version; + SetLoggerFunc setLogger; + void* handle; + int api_version; +}; + +template +struct PluginTraits { + static constexpr int API_VERSION = 1; + using CreateFunc = void* (*)(); + using DestroyFunc = void (*)(void*); + using PluginInfoType = PluginInfo; +}; + +#define IMPLEMENT_PLUGIN(PluginClass, BaseClass, name, description, version) \ + extern "C" { \ + BaseClass* create##BaseClass##Plugin() { return new PluginClass(); } \ + void destroy##BaseClass##Plugin(BaseClass* plugin) { delete plugin; } \ + const char* getPluginName() { return name; } \ + const char* getPluginDescription() { return description; } \ + const char* getPluginVersion() { return version; } \ + int getAPIVersion() { return PluginTraits::API_VERSION; } \ + void setLogger(Logger* log) { IPlugin::setLogger(log); } \ + } + +#endif \ No newline at end of file diff --git a/src/plugin_manager.h b/src/plugin_manager.h new file mode 100644 index 0000000..8cdc7a3 --- /dev/null +++ b/src/plugin_manager.h @@ -0,0 +1,166 @@ +#ifndef PLUGIN_MANAGER_H +#define PLUGIN_MANAGER_H + +#include +#include +#include +#include "plugin.h" + +template +class PluginManager { +public: + static PluginManager& getInstance() { + static PluginManager instance; + return instance; + } + + void setLogger(Logger* log) { + logger = log; + IPlugin::setLogger(log); + } + + T* createPlugin(const std::string& type) { + auto it = plugins.find(type); + if (it != plugins.end()) { + return it->second.create(); + } + if (logger) logger->print(LOGLEVEL_ERROR, "Plugin type '%s' not found", type.c_str()); + return nullptr; + } + + bool loadPlugin(const std::string& path) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Loading plugin: %s", path.c_str()); + + void* handle = dlopen(path.c_str(), RTLD_LAZY); + if (!handle) { + if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load plugin %s: %s", + path.c_str(), dlerror()); + return false; + } + + // Load common functions + auto create = (typename PluginTraits::CreateFunc)dlsym(handle, PluginTraits::createFuncName()); + auto destroy = (typename PluginTraits::DestroyFunc)dlsym(handle, PluginTraits::destroyFuncName()); + auto get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion"); + auto setLogger = (SetLoggerFunc)dlsym(handle, "setLogger"); + auto get_name = (const char* (*)())dlsym(handle, "getPluginName"); + auto get_desc = (const char* (*)())dlsym(handle, "getPluginDescription"); + auto get_ver = (const char* (*)())dlsym(handle, "getPluginVersion"); + + if (!checkRequiredFunctions(create, destroy, get_api_version, setLogger, get_name, get_desc, get_ver)) { + dlclose(handle); + return false; + } + + if (!checkAPIVersion(get_api_version(), path)) { + dlclose(handle); + return false; + } + + std::string plugin_name = get_name(); + setLogger(logger); + + if (plugins.find(plugin_name) != plugins.end()) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str()); + dlclose(handle); + return true; + } + + auto plugin_info = createPluginInfo(plugin_name, get_desc(), get_ver(), create, destroy, + get_api_version, setLogger, handle); + + plugins[plugin_name] = plugin_info; + if (logger) logger->print(LOGLEVEL_INFO, "Loaded plugin: %s v%s", + plugin_name.c_str(), plugin_info.version.c_str()); + return true; + } + + typename PluginTraits::CreateFunc getFactory(const std::string& type) { + auto it = plugins.find(type); + if (it != plugins.end()) { + return it->second.create; + } + if (logger) logger->print(LOGLEVEL_ERROR, "Plugin type '%s' not found", type.c_str()); + return nullptr; + } + + void unloadPlugins() { + if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all plugins"); + for (auto& pair : plugins) { + if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str()); + dlclose(pair.second.handle); + } + plugins.clear(); + } + + ~PluginManager() { + unloadPlugins(); + } + +private: + PluginManager() : logger(nullptr) {} + std::map::PluginInfoType> plugins; + Logger* logger; + + bool checkRequiredFunctions( + typename PluginTraits::CreateFunc create, + typename PluginTraits::DestroyFunc destroy, + GetAPIVersionFunc get_api_version, + SetLoggerFunc setLogger, + const char* (*get_name)(), + const char* (*get_desc)(), + const char* (*get_ver)() + ) { + if (!create || !destroy || !get_api_version || !setLogger || + !get_name || !get_desc || !get_ver) { + if (logger) { + const char* missing = !create ? PluginTraits::createFuncName() : + !destroy ? PluginTraits::destroyFuncName() : + !get_api_version ? "getAPIVersion" : + !setLogger ? "setLogger" : + !get_name ? "getPluginName" : + !get_desc ? "getPluginDescription" : + "getPluginVersion"; + logger->print(LOGLEVEL_ERROR, "Missing required function: %s", missing); + } + return false; + } + return true; + } + + bool checkAPIVersion(int version, const std::string& path) { + if (version != PluginTraits::API_VERSION) { + if (logger) { + logger->print(LOGLEVEL_ERROR, + "Incompatible plugin API version in %s (got %d, expected %d)", + path.c_str(), version, PluginTraits::API_VERSION); + } + return false; + } + return true; + } + + typename PluginTraits::PluginInfoType createPluginInfo( + const std::string& name, + const char* description, + const char* version, + typename PluginTraits::CreateFunc create, + typename PluginTraits::DestroyFunc destroy, + GetAPIVersionFunc get_api_version, + SetLoggerFunc setLogger, + void* handle + ) { + typename PluginTraits::PluginInfoType info; + info.name = name; + info.description = description; + info.version = version; + info.create = create; + info.destroy = destroy; + info.get_api_version = get_api_version; + info.setLogger = setLogger; + info.handle = handle; + return info; + } +}; + +#endif \ No newline at end of file diff --git a/src/plugin_traits.h b/src/plugin_traits.h new file mode 100644 index 0000000..054096b --- /dev/null +++ b/src/plugin_traits.h @@ -0,0 +1,48 @@ +#ifndef PLUGIN_TRAITS_H +#define PLUGIN_TRAITS_H + +#include "plugin.h" +#include "auth.h" +#include "filer.h" + +struct AuthPluginInfo : public PluginInfo { + typedef Auth* (*CreateFunc)(); + typedef void (*DestroyFunc)(Auth*); + + CreateFunc create; + DestroyFunc destroy; +}; + +struct FilerPluginInfo : public PluginInfo { + typedef Filer* (*CreateFunc)(); + typedef void (*DestroyFunc)(Filer*); + + CreateFunc create; + DestroyFunc destroy; +}; + +template<> +struct PluginTraits { + static constexpr int API_VERSION = AUTH_PLUGIN_API_VERSION; + using CreateFunc = Auth* (*)(); + using DestroyFunc = void (*)(Auth*); + using PluginInfoType = AuthPluginInfo; + + static const char* createFuncName() { return "createAuthPlugin"; } + static const char* destroyFuncName() { return "destroyAuthPlugin"; } + static const char* pluginPrefix() { return "libauth_"; } +}; + +template<> +struct PluginTraits { + static constexpr int API_VERSION = FILER_PLUGIN_API_VERSION; + using CreateFunc = Filer* (*)(); + using DestroyFunc = void (*)(Filer*); + using PluginInfoType = FilerPluginInfo; + + static const char* createFuncName() { return "createFilerPlugin"; } + static const char* destroyFuncName() { return "destroyFilerPlugin"; } + static const char* pluginPrefix() { return "libfiler_"; } +}; + +#endif \ No newline at end of file diff --git a/src/plugins/auth_pam/auth_pam.cpp b/src/plugins/auth_pam/auth_pam.cpp index e3930a4..f24d9db 100644 --- a/src/plugins/auth_pam/auth_pam.cpp +++ b/src/plugins/auth_pam/auth_pam.cpp @@ -1,4 +1,5 @@ -#include "auth_plugin.h" +#include "plugin.h" +#include "auth.h" #include #include #include @@ -6,7 +7,7 @@ #include #include -class PAMAuthPlugin : public Auth { +class PAMAuthPlugin : public IPlugin, public Auth { private: std::string service_name; uid_t user_uid; @@ -128,4 +129,4 @@ public: } }; -IMPLEMENT_AUTH_PLUGIN(PAMAuthPlugin, "pam", "PAM-based local authentication", "1.0.0") \ No newline at end of file +IMPLEMENT_PLUGIN(PAMAuthPlugin, Auth, "pam", "PAM-based local authentication", "1.0.0") \ No newline at end of file diff --git a/src/plugins/auth_passdb/auth_passdb.cpp b/src/plugins/auth_passdb/auth_passdb.cpp index b5a3e66..da84dd2 100644 --- a/src/plugins/auth_passdb/auth_passdb.cpp +++ b/src/plugins/auth_passdb/auth_passdb.cpp @@ -1,11 +1,12 @@ -#include "auth_plugin.h" +#include "plugin.h" +#include "auth.h" #include #include #include #include #include -class PassdbAuthPlugin : public Auth { +class PassdbAuthPlugin : public IPlugin, public Auth { private: std::string db_file = "passdb"; bool case_sensitive; @@ -59,7 +60,6 @@ public: } if (authenticated) { - // Only set the home directory after successful authentication setUserDirectory(auth_data); return true; } @@ -72,4 +72,4 @@ public: } }; -IMPLEMENT_AUTH_PLUGIN(PassdbAuthPlugin, "passdb", "Password database authentication", "1.0.0") \ No newline at end of file +IMPLEMENT_PLUGIN(PassdbAuthPlugin, Auth, "passdb", "Password database authentication", "1.0.0") \ No newline at end of file diff --git a/src/plugins/filer_local/filer_local.cpp b/src/plugins/filer_local/filer_local.cpp index 9b86509..a4eb678 100644 --- a/src/plugins/filer_local/filer_local.cpp +++ b/src/plugins/filer_local/filer_local.cpp @@ -1,4 +1,5 @@ -#include "filer_plugin.h" +#include "plugin.h" +#include "filer.h" #include #include #include @@ -8,7 +9,7 @@ namespace fs = std::filesystem; -class LocalFiler : public Filer { +class LocalFiler : public IPlugin, public Filer { public: LocalFiler() {} @@ -371,4 +372,4 @@ private: char type = 'I'; }; -IMPLEMENT_FILER_PLUGIN(LocalFiler, "local", "Local filesystem implementation", "1.0.0") \ No newline at end of file +IMPLEMENT_PLUGIN(LocalFiler, Filer, "local", "Local filesystem implementation", "1.0.0") \ No newline at end of file