Unified plugin handling
This commit is contained in:
@@ -10,6 +10,8 @@
|
||||
#include "logger.h"
|
||||
#include "util.h"
|
||||
|
||||
#define AUTH_PLUGIN_API_VERSION 1
|
||||
|
||||
typedef struct {
|
||||
char username[255] = {0};
|
||||
char password[255] = {0};
|
||||
@@ -57,16 +59,10 @@ public:
|
||||
virtual uint64_t getFreeSpace() { return this->getTotalSpace(); }
|
||||
|
||||
virtual void setTotalSpace(uint64_t space) { this->max_filespace = space; }
|
||||
|
||||
static void setLogger(Logger* log) { logger = log; }
|
||||
private:
|
||||
bool require_password = false;
|
||||
uint64_t max_filespace = -1;
|
||||
bool chroot = false;
|
||||
protected:
|
||||
static Logger* logger;
|
||||
};
|
||||
|
||||
Logger* Auth::logger = nullptr;
|
||||
|
||||
#endif
|
||||
@@ -1,124 +0,0 @@
|
||||
#ifndef AUTH_MANAGER_H
|
||||
#define AUTH_MANAGER_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <dlfcn.h>
|
||||
#include "main.h"
|
||||
#include "auth_plugin.h"
|
||||
|
||||
class AuthManager {
|
||||
public:
|
||||
static AuthManager& getInstance() {
|
||||
static AuthManager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void setLogger(Logger* log) {
|
||||
logger = log;
|
||||
}
|
||||
|
||||
Auth* createAuth(const std::string& type) {
|
||||
auto it = plugins.find(type);
|
||||
if (it != plugins.end()) {
|
||||
return it->second.create();
|
||||
}
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Auth plugin type '%s' not found", type.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool loadPlugin(const std::string& path) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Loading auth plugin: %s", path.c_str());
|
||||
|
||||
void* handle = dlopen(path.c_str(), RTLD_LAZY);
|
||||
if (!handle) {
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load auth plugin %s: %s",
|
||||
path.c_str(), dlerror());
|
||||
return false;
|
||||
}
|
||||
|
||||
// Load plugin functions
|
||||
CreateAuthFunc create = (CreateAuthFunc)dlsym(handle, "createAuthPlugin");
|
||||
DestroyAuthFunc destroy = (DestroyAuthFunc)dlsym(handle, "destroyAuthPlugin");
|
||||
GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion");
|
||||
SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger");
|
||||
|
||||
const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName");
|
||||
const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription");
|
||||
const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion");
|
||||
|
||||
// Check each required function and log specific missing functions
|
||||
if (!create || !destroy || !get_api_version || !setLogger || !get_name || !get_desc || !get_ver) {
|
||||
if (logger) {
|
||||
const char* missing = !create ? "createAuthPlugin" :
|
||||
!destroy ? "destroyAuthPlugin" :
|
||||
!get_api_version ? "getAPIVersion" :
|
||||
!setLogger ? "setLogger" :
|
||||
!get_name ? "getPluginName" :
|
||||
!get_desc ? "getPluginDescription" :
|
||||
"getPluginVersion";
|
||||
logger->print(LOGLEVEL_ERROR, "Invalid auth plugin %s: missing required function '%s'",
|
||||
path.c_str(), missing);
|
||||
}
|
||||
dlclose(handle);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check API version
|
||||
int api_version = get_api_version();
|
||||
if (api_version != AUTH_PLUGIN_API_VERSION) {
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible auth plugin API version in %s (got %d, expected %d)",
|
||||
path.c_str(), api_version, AUTH_PLUGIN_API_VERSION);
|
||||
dlclose(handle);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string plugin_name = get_name();
|
||||
setLogger(logger);
|
||||
|
||||
// Check if plugin is already loaded
|
||||
if (plugins.find(plugin_name) != plugins.end()) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str());
|
||||
dlclose(handle);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Store plugin info
|
||||
AuthPluginInfo plugin = {
|
||||
plugin_name,
|
||||
get_desc(),
|
||||
get_ver(),
|
||||
create,
|
||||
destroy,
|
||||
get_api_version,
|
||||
setLogger,
|
||||
handle
|
||||
};
|
||||
|
||||
plugins[plugin.name] = plugin;
|
||||
if (logger) logger->print(LOGLEVEL_INFO, "Loaded auth plugin: %s v%s",
|
||||
plugin.name.c_str(), plugin.version.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
void unloadPlugins() {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all auth plugins");
|
||||
for (auto& pair : plugins) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str());
|
||||
dlclose(pair.second.handle);
|
||||
}
|
||||
plugins.clear();
|
||||
}
|
||||
|
||||
~AuthManager() {
|
||||
unloadPlugins();
|
||||
}
|
||||
|
||||
private:
|
||||
AuthManager() : logger(nullptr) {} // Singleton
|
||||
std::map<std::string, AuthPluginInfo> plugins;
|
||||
Logger* logger;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,37 +0,0 @@
|
||||
#ifndef AUTH_PLUGIN_H
|
||||
#define AUTH_PLUGIN_H
|
||||
#include <string>
|
||||
#include "auth.h"
|
||||
|
||||
#define AUTH_PLUGIN_API_VERSION 1
|
||||
|
||||
typedef Auth* (*CreateAuthFunc)();
|
||||
typedef void (*DestroyAuthFunc)(Auth*);
|
||||
typedef int (*GetAPIVersionFunc)();
|
||||
typedef void (*SetLoggerFunc)(Logger*);
|
||||
|
||||
struct AuthPluginInfo {
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string version;
|
||||
CreateAuthFunc create;
|
||||
DestroyAuthFunc destroy;
|
||||
GetAPIVersionFunc get_api_version;
|
||||
SetLoggerFunc setLogger;
|
||||
void* handle;
|
||||
};
|
||||
|
||||
#define AUTH_PLUGIN_EXPORT extern "C"
|
||||
|
||||
#define IMPLEMENT_AUTH_PLUGIN(classname, name, description, version) \
|
||||
extern "C" { \
|
||||
Auth* createAuthPlugin() { return new classname(); } \
|
||||
void destroyAuthPlugin(Auth* plugin) { delete plugin; } \
|
||||
const char* getPluginName() { return name; } \
|
||||
const char* getPluginDescription() { return description; } \
|
||||
const char* getPluginVersion() { return version; } \
|
||||
int getAPIVersion() { return AUTH_PLUGIN_API_VERSION; } \
|
||||
void setLogger(Logger* log) { Auth::setLogger(log); } \
|
||||
}
|
||||
|
||||
#endif
|
||||
13
src/filer.h
13
src/filer.h
@@ -3,6 +3,8 @@
|
||||
#include <filesystem>
|
||||
#include "logger.h"
|
||||
|
||||
#define FILER_PLUGIN_API_VERSION 1
|
||||
|
||||
struct file_error {
|
||||
uint8_t code = 0;
|
||||
const char* msg = {};
|
||||
@@ -34,11 +36,6 @@ class Filer {
|
||||
public:
|
||||
virtual ~Filer() {};
|
||||
|
||||
static void setLogger(Logger* log) {
|
||||
logger = log;
|
||||
}
|
||||
|
||||
// Core operations
|
||||
virtual file_data setRoot(std::string _root) = 0;
|
||||
virtual file_data setCWD(std::string _cwd) = 0;
|
||||
virtual std::filesystem::path getRoot() = 0;
|
||||
@@ -52,7 +49,6 @@ public:
|
||||
return this->transfer_mode;
|
||||
}
|
||||
|
||||
// File operations
|
||||
virtual file_data traverse(std::string dir) = 0;
|
||||
virtual file_data createDirectory(std::string dir) = 0;
|
||||
virtual file_data fileSize(std::string name) = 0;
|
||||
@@ -61,11 +57,8 @@ public:
|
||||
virtual file_data writeFile(std::string name, unsigned char* data, int size, bool append = false) = 0;
|
||||
virtual file_data renameFile(const std::string& from, const std::string& to) = 0;
|
||||
virtual file_data list(std::string path = ".") = 0;
|
||||
protected:
|
||||
static Logger* logger;
|
||||
private:
|
||||
char transfer_mode = 'I';
|
||||
};
|
||||
|
||||
Logger* Filer::logger = nullptr;
|
||||
|
||||
#endif
|
||||
@@ -1,130 +0,0 @@
|
||||
#ifndef FILER_MANAGER_H
|
||||
#define FILER_MANAGER_H
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <dlfcn.h>
|
||||
#include "filer_plugin.h"
|
||||
#include "main.h"
|
||||
|
||||
class FilerManager {
|
||||
public:
|
||||
static FilerManager& getInstance() {
|
||||
static FilerManager instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void setLogger(Logger* log) {
|
||||
logger = log;
|
||||
}
|
||||
|
||||
Filer* createFiler(const std::string& type) {
|
||||
auto it = plugins.find(type);
|
||||
if (it != plugins.end()) {
|
||||
return it->second.create();
|
||||
}
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool loadPlugin(const std::string& path) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Loading filer plugin: %s", path.c_str());
|
||||
|
||||
void* handle = dlopen(path.c_str(), RTLD_LAZY);
|
||||
if (!handle) {
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load filer plugin %s: %s",
|
||||
path.c_str(), dlerror());
|
||||
return false;
|
||||
}
|
||||
|
||||
// Load plugin functions
|
||||
CreateFilerFunc create = (CreateFilerFunc)dlsym(handle, "createFilerPlugin");
|
||||
DestroyFilerFunc destroy = (DestroyFilerFunc)dlsym(handle, "destroyFilerPlugin");
|
||||
GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion");
|
||||
SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger");
|
||||
|
||||
const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName");
|
||||
const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription");
|
||||
const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion");
|
||||
|
||||
if (!create || !destroy || !get_api_version || !setLogger ||
|
||||
!get_name || !get_desc || !get_ver) {
|
||||
if (logger) {
|
||||
const char* missing = !create ? "createFilerPlugin" :
|
||||
!destroy ? "destroyFilerPlugin" :
|
||||
!get_api_version ? "getAPIVersion" :
|
||||
!setLogger ? "setLogger" :
|
||||
!get_name ? "getPluginName" :
|
||||
!get_desc ? "getPluginDescription" :
|
||||
"getPluginVersion";
|
||||
logger->print(LOGLEVEL_ERROR, "Invalid filer plugin %s: missing required function '%s'",
|
||||
path.c_str(), missing);
|
||||
}
|
||||
dlclose(handle);
|
||||
return false;
|
||||
}
|
||||
|
||||
int api_version = get_api_version();
|
||||
if (api_version != FILER_PLUGIN_API_VERSION) {
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible filer plugin API version in %s (got %d, expected %d)",
|
||||
path.c_str(), api_version, FILER_PLUGIN_API_VERSION);
|
||||
dlclose(handle);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string plugin_name = get_name();
|
||||
setLogger(logger);
|
||||
|
||||
if (plugins.find(plugin_name) != plugins.end()) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str());
|
||||
dlclose(handle);
|
||||
return true;
|
||||
}
|
||||
|
||||
FilerPluginInfo plugin = {
|
||||
plugin_name,
|
||||
get_desc(),
|
||||
get_ver(),
|
||||
create,
|
||||
destroy,
|
||||
get_api_version,
|
||||
setLogger,
|
||||
handle
|
||||
};
|
||||
|
||||
plugins[plugin.name] = plugin;
|
||||
if (logger) logger->print(LOGLEVEL_INFO, "Loaded filer plugin: %s v%s",
|
||||
plugin.name.c_str(), plugin.version.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
CreateFilerFunc getFactory(const std::string& type) {
|
||||
auto it = plugins.find(type);
|
||||
if (it != plugins.end()) {
|
||||
return it->second.create;
|
||||
}
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void unloadPlugins() {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all filer plugins");
|
||||
for (auto& pair : plugins) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str());
|
||||
dlclose(pair.second.handle);
|
||||
}
|
||||
plugins.clear();
|
||||
}
|
||||
|
||||
~FilerManager() {
|
||||
unloadPlugins();
|
||||
}
|
||||
|
||||
private:
|
||||
FilerManager() : logger(nullptr) {}
|
||||
std::map<std::string, FilerPluginInfo> plugins;
|
||||
Logger* logger;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,37 +0,0 @@
|
||||
#ifndef FILER_PLUGIN_H
|
||||
#define FILER_PLUGIN_H
|
||||
#include <string>
|
||||
#include "filer.h"
|
||||
|
||||
#define FILER_PLUGIN_API_VERSION 1
|
||||
|
||||
typedef Filer* (*CreateFilerFunc)();
|
||||
typedef void (*DestroyFilerFunc)(Filer*);
|
||||
typedef int (*GetAPIVersionFunc)();
|
||||
typedef void (*SetLoggerFunc)(Logger*);
|
||||
|
||||
struct FilerPluginInfo {
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string version;
|
||||
CreateFilerFunc create;
|
||||
DestroyFilerFunc destroy;
|
||||
GetAPIVersionFunc get_api_version;
|
||||
SetLoggerFunc setLogger;
|
||||
void* handle;
|
||||
};
|
||||
|
||||
#define FILER_PLUGIN_EXPORT extern "C"
|
||||
|
||||
#define IMPLEMENT_FILER_PLUGIN(classname, name, description, version) \
|
||||
extern "C" { \
|
||||
Filer* createFilerPlugin() { return new classname(); } \
|
||||
void destroyFilerPlugin(Filer* plugin) { delete plugin; } \
|
||||
const char* getPluginName() { return name; } \
|
||||
const char* getPluginDescription() { return description; } \
|
||||
const char* getPluginVersion() { return version; } \
|
||||
int getAPIVersion() { return FILER_PLUGIN_API_VERSION; } \
|
||||
void setLogger(Logger* log) { Filer::setLogger(log); } \
|
||||
}
|
||||
|
||||
#endif
|
||||
78
src/main.cpp
78
src/main.cpp
@@ -20,7 +20,6 @@
|
||||
#include <fcntl.h>
|
||||
|
||||
#include "main.h"
|
||||
#include "auth_manager.cpp"
|
||||
#include "client.cpp"
|
||||
#include "util.h"
|
||||
|
||||
@@ -181,70 +180,46 @@ void runClient(struct ftpconn* cfd) {
|
||||
cfd->close = true;
|
||||
}
|
||||
|
||||
void initializeAuth() {
|
||||
AuthManager& auth_manager = AuthManager::getInstance();
|
||||
auth_manager.setLogger(logger);
|
||||
|
||||
// Load auth plugins from plugin directory
|
||||
void initializePlugins() {
|
||||
std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR);
|
||||
|
||||
// Load any additional plugins from plugin directory
|
||||
// Load auth plugins
|
||||
auto& auth_manager = PluginManager<Auth>::getInstance();
|
||||
auth_manager.setLogger(logger);
|
||||
|
||||
// Load filer plugins
|
||||
auto& filer_manager = PluginManager<Filer>::getInstance();
|
||||
filer_manager.setLogger(logger);
|
||||
|
||||
for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) {
|
||||
if (entry.path().extension() == ".so" &&
|
||||
entry.path().filename().string().find("libauth_") == 0) {
|
||||
auth_manager.loadPlugin(entry.path().string());
|
||||
if (entry.path().extension() == ".so") {
|
||||
const std::string& filename = entry.path().filename().string();
|
||||
if (filename.find(PluginTraits<Auth>::pluginPrefix()) == 0) {
|
||||
auth_manager.loadPlugin(entry.path().string());
|
||||
}
|
||||
else if (filename.find(PluginTraits<Filer>::pluginPrefix()) == 0) {
|
||||
filer_manager.loadPlugin(entry.path().string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create auth instance based on config
|
||||
// Initialize auth
|
||||
std::string auth_type = config->getValue("core", "auth_engine", "pam");
|
||||
auth = auth_manager.createAuth(auth_type);
|
||||
auth = auth_manager.createPlugin(auth_type);
|
||||
|
||||
if (!auth) {
|
||||
logger->print(LOGLEVEL_CRITICAL, "Failed to create auth engine: %s", auth_type.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
// Initialize auth plugin with config
|
||||
std::map<std::string, std::string> auth_config = config->get("auth_"+auth_type)->get();
|
||||
|
||||
if (!auth->initialize(auth_config)) {
|
||||
logger->print(LOGLEVEL_CRITICAL, "Failed to initialize auth engine: %s", auth_type.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
void initializeFiler() {
|
||||
FilerManager& filer_manager = FilerManager::getInstance();
|
||||
filer_manager.setLogger(logger);
|
||||
|
||||
std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR);
|
||||
|
||||
for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) {
|
||||
if (entry.path().extension() == ".so" &&
|
||||
entry.path().filename().string().find("libfiler_") == 0) {
|
||||
filer_manager.loadPlugin(entry.path().string());
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize filer
|
||||
std::string filer_type = config->getValue("core", "filer_engine", "local");
|
||||
|
||||
default_filer_factory = filer_manager.getFactory(filer_type);
|
||||
|
||||
if (!default_filer_factory) {
|
||||
logger->print(LOGLEVEL_CRITICAL, "Failed to get filer factory for type: %s", filer_type.c_str());
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> filer_config = config->get("filer_"+filer_type)->get();
|
||||
|
||||
// Test create using the factory
|
||||
Filer* test_filer = default_filer_factory();
|
||||
if (!test_filer) {
|
||||
logger->print(LOGLEVEL_CRITICAL, "Failed to create filer instance");
|
||||
exit(1);
|
||||
}
|
||||
delete test_filer;
|
||||
}
|
||||
|
||||
int main(int argc , char *argv[]) {
|
||||
@@ -307,14 +282,10 @@ int main(int argc , char *argv[]) {
|
||||
if (!config->getBool("ssl", "tls_v1_3", true))
|
||||
ssl_flags+=SSL_OP_NO_TLSv1_3;
|
||||
|
||||
if ((ssl_flags & SSL_OP_NO_SSLv2) &&
|
||||
(ssl_flags & SSL_OP_NO_SSLv3) &&
|
||||
(ssl_flags & SSL_OP_NO_TLSv1) &&
|
||||
(ssl_flags & SSL_OP_NO_TLSv1_1) &&
|
||||
(ssl_flags & SSL_OP_NO_TLSv1_2) &&
|
||||
(ssl_flags & SSL_OP_NO_TLSv1_3)) {
|
||||
if ((ssl_flags & SSL_OP_NO_SSLv2) && (ssl_flags & SSL_OP_NO_SSLv3) &&
|
||||
(ssl_flags & SSL_OP_NO_TLSv1) && (ssl_flags & SSL_OP_NO_TLSv1_1) &&
|
||||
(ssl_flags & SSL_OP_NO_TLSv1_2) && (ssl_flags & SSL_OP_NO_TLSv1_3))
|
||||
logger->print(LOGLEVEL_WARNING, "All SSL/TLS protocols disabled. You're a mad man!");
|
||||
}
|
||||
|
||||
if (!config->getBool("ssl", "compression", true))
|
||||
ssl_flags+=SSL_OP_NO_COMPRESSION;
|
||||
@@ -323,8 +294,7 @@ int main(int argc , char *argv[]) {
|
||||
SSLManager::getInstance().setFlags(ssl_flags);
|
||||
}
|
||||
|
||||
initializeAuth();
|
||||
initializeFiler();
|
||||
initializePlugins();
|
||||
|
||||
int opt = 1,
|
||||
master_socket = -1,
|
||||
|
||||
@@ -10,8 +10,10 @@
|
||||
#include "conf.h"
|
||||
#include "logger.h"
|
||||
#include "ssl.h"
|
||||
#include "auth_manager.cpp"
|
||||
#include "filer_manager.cpp"
|
||||
#include "auth.h"
|
||||
#include "filer.h"
|
||||
#include "plugin_traits.h"
|
||||
#include "plugin_manager.h"
|
||||
|
||||
unsigned char* server_address = new unsigned char[127];
|
||||
uint16_t server_port = 21;
|
||||
@@ -20,7 +22,7 @@ std::string motd;
|
||||
ConfigFile* config;
|
||||
Auth* auth;
|
||||
Logger* logger;
|
||||
static CreateFilerFunc default_filer_factory = nullptr;
|
||||
static typename PluginTraits<Filer>::CreateFunc default_filer_factory = nullptr;
|
||||
bool runServer;
|
||||
bool runCompression;
|
||||
|
||||
|
||||
49
src/plugin.h
Normal file
49
src/plugin.h
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef PLUGIN_H
|
||||
#define PLUGIN_H
|
||||
|
||||
#include <string>
|
||||
#include "logger.h"
|
||||
|
||||
typedef int (*GetAPIVersionFunc)();
|
||||
typedef void (*SetLoggerFunc)(Logger*);
|
||||
|
||||
class IPlugin {
|
||||
public:
|
||||
virtual ~IPlugin() = default;
|
||||
static void setLogger(Logger* log) { logger = log; }
|
||||
protected:
|
||||
static Logger* logger;
|
||||
};
|
||||
|
||||
Logger* IPlugin::logger = nullptr;
|
||||
|
||||
struct PluginInfo {
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string version;
|
||||
GetAPIVersionFunc get_api_version;
|
||||
SetLoggerFunc setLogger;
|
||||
void* handle;
|
||||
int api_version;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct PluginTraits {
|
||||
static constexpr int API_VERSION = 1;
|
||||
using CreateFunc = void* (*)();
|
||||
using DestroyFunc = void (*)(void*);
|
||||
using PluginInfoType = PluginInfo;
|
||||
};
|
||||
|
||||
#define IMPLEMENT_PLUGIN(PluginClass, BaseClass, name, description, version) \
|
||||
extern "C" { \
|
||||
BaseClass* create##BaseClass##Plugin() { return new PluginClass(); } \
|
||||
void destroy##BaseClass##Plugin(BaseClass* plugin) { delete plugin; } \
|
||||
const char* getPluginName() { return name; } \
|
||||
const char* getPluginDescription() { return description; } \
|
||||
const char* getPluginVersion() { return version; } \
|
||||
int getAPIVersion() { return PluginTraits<BaseClass>::API_VERSION; } \
|
||||
void setLogger(Logger* log) { IPlugin::setLogger(log); } \
|
||||
}
|
||||
|
||||
#endif
|
||||
166
src/plugin_manager.h
Normal file
166
src/plugin_manager.h
Normal file
@@ -0,0 +1,166 @@
|
||||
#ifndef PLUGIN_MANAGER_H
|
||||
#define PLUGIN_MANAGER_H
|
||||
|
||||
#include <map>
|
||||
#include <dlfcn.h>
|
||||
#include <filesystem>
|
||||
#include "plugin.h"
|
||||
|
||||
template<typename T>
|
||||
class PluginManager {
|
||||
public:
|
||||
static PluginManager<T>& getInstance() {
|
||||
static PluginManager<T> instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void setLogger(Logger* log) {
|
||||
logger = log;
|
||||
IPlugin::setLogger(log);
|
||||
}
|
||||
|
||||
T* createPlugin(const std::string& type) {
|
||||
auto it = plugins.find(type);
|
||||
if (it != plugins.end()) {
|
||||
return it->second.create();
|
||||
}
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Plugin type '%s' not found", type.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool loadPlugin(const std::string& path) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Loading plugin: %s", path.c_str());
|
||||
|
||||
void* handle = dlopen(path.c_str(), RTLD_LAZY);
|
||||
if (!handle) {
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load plugin %s: %s",
|
||||
path.c_str(), dlerror());
|
||||
return false;
|
||||
}
|
||||
|
||||
// Load common functions
|
||||
auto create = (typename PluginTraits<T>::CreateFunc)dlsym(handle, PluginTraits<T>::createFuncName());
|
||||
auto destroy = (typename PluginTraits<T>::DestroyFunc)dlsym(handle, PluginTraits<T>::destroyFuncName());
|
||||
auto get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion");
|
||||
auto setLogger = (SetLoggerFunc)dlsym(handle, "setLogger");
|
||||
auto get_name = (const char* (*)())dlsym(handle, "getPluginName");
|
||||
auto get_desc = (const char* (*)())dlsym(handle, "getPluginDescription");
|
||||
auto get_ver = (const char* (*)())dlsym(handle, "getPluginVersion");
|
||||
|
||||
if (!checkRequiredFunctions(create, destroy, get_api_version, setLogger, get_name, get_desc, get_ver)) {
|
||||
dlclose(handle);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!checkAPIVersion(get_api_version(), path)) {
|
||||
dlclose(handle);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string plugin_name = get_name();
|
||||
setLogger(logger);
|
||||
|
||||
if (plugins.find(plugin_name) != plugins.end()) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str());
|
||||
dlclose(handle);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto plugin_info = createPluginInfo(plugin_name, get_desc(), get_ver(), create, destroy,
|
||||
get_api_version, setLogger, handle);
|
||||
|
||||
plugins[plugin_name] = plugin_info;
|
||||
if (logger) logger->print(LOGLEVEL_INFO, "Loaded plugin: %s v%s",
|
||||
plugin_name.c_str(), plugin_info.version.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
typename PluginTraits<T>::CreateFunc getFactory(const std::string& type) {
|
||||
auto it = plugins.find(type);
|
||||
if (it != plugins.end()) {
|
||||
return it->second.create;
|
||||
}
|
||||
if (logger) logger->print(LOGLEVEL_ERROR, "Plugin type '%s' not found", type.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void unloadPlugins() {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all plugins");
|
||||
for (auto& pair : plugins) {
|
||||
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str());
|
||||
dlclose(pair.second.handle);
|
||||
}
|
||||
plugins.clear();
|
||||
}
|
||||
|
||||
~PluginManager() {
|
||||
unloadPlugins();
|
||||
}
|
||||
|
||||
private:
|
||||
PluginManager() : logger(nullptr) {}
|
||||
std::map<std::string, typename PluginTraits<T>::PluginInfoType> plugins;
|
||||
Logger* logger;
|
||||
|
||||
bool checkRequiredFunctions(
|
||||
typename PluginTraits<T>::CreateFunc create,
|
||||
typename PluginTraits<T>::DestroyFunc destroy,
|
||||
GetAPIVersionFunc get_api_version,
|
||||
SetLoggerFunc setLogger,
|
||||
const char* (*get_name)(),
|
||||
const char* (*get_desc)(),
|
||||
const char* (*get_ver)()
|
||||
) {
|
||||
if (!create || !destroy || !get_api_version || !setLogger ||
|
||||
!get_name || !get_desc || !get_ver) {
|
||||
if (logger) {
|
||||
const char* missing = !create ? PluginTraits<T>::createFuncName() :
|
||||
!destroy ? PluginTraits<T>::destroyFuncName() :
|
||||
!get_api_version ? "getAPIVersion" :
|
||||
!setLogger ? "setLogger" :
|
||||
!get_name ? "getPluginName" :
|
||||
!get_desc ? "getPluginDescription" :
|
||||
"getPluginVersion";
|
||||
logger->print(LOGLEVEL_ERROR, "Missing required function: %s", missing);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool checkAPIVersion(int version, const std::string& path) {
|
||||
if (version != PluginTraits<T>::API_VERSION) {
|
||||
if (logger) {
|
||||
logger->print(LOGLEVEL_ERROR,
|
||||
"Incompatible plugin API version in %s (got %d, expected %d)",
|
||||
path.c_str(), version, PluginTraits<T>::API_VERSION);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
typename PluginTraits<T>::PluginInfoType createPluginInfo(
|
||||
const std::string& name,
|
||||
const char* description,
|
||||
const char* version,
|
||||
typename PluginTraits<T>::CreateFunc create,
|
||||
typename PluginTraits<T>::DestroyFunc destroy,
|
||||
GetAPIVersionFunc get_api_version,
|
||||
SetLoggerFunc setLogger,
|
||||
void* handle
|
||||
) {
|
||||
typename PluginTraits<T>::PluginInfoType info;
|
||||
info.name = name;
|
||||
info.description = description;
|
||||
info.version = version;
|
||||
info.create = create;
|
||||
info.destroy = destroy;
|
||||
info.get_api_version = get_api_version;
|
||||
info.setLogger = setLogger;
|
||||
info.handle = handle;
|
||||
return info;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
48
src/plugin_traits.h
Normal file
48
src/plugin_traits.h
Normal file
@@ -0,0 +1,48 @@
|
||||
#ifndef PLUGIN_TRAITS_H
|
||||
#define PLUGIN_TRAITS_H
|
||||
|
||||
#include "plugin.h"
|
||||
#include "auth.h"
|
||||
#include "filer.h"
|
||||
|
||||
struct AuthPluginInfo : public PluginInfo {
|
||||
typedef Auth* (*CreateFunc)();
|
||||
typedef void (*DestroyFunc)(Auth*);
|
||||
|
||||
CreateFunc create;
|
||||
DestroyFunc destroy;
|
||||
};
|
||||
|
||||
struct FilerPluginInfo : public PluginInfo {
|
||||
typedef Filer* (*CreateFunc)();
|
||||
typedef void (*DestroyFunc)(Filer*);
|
||||
|
||||
CreateFunc create;
|
||||
DestroyFunc destroy;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PluginTraits<Auth> {
|
||||
static constexpr int API_VERSION = AUTH_PLUGIN_API_VERSION;
|
||||
using CreateFunc = Auth* (*)();
|
||||
using DestroyFunc = void (*)(Auth*);
|
||||
using PluginInfoType = AuthPluginInfo;
|
||||
|
||||
static const char* createFuncName() { return "createAuthPlugin"; }
|
||||
static const char* destroyFuncName() { return "destroyAuthPlugin"; }
|
||||
static const char* pluginPrefix() { return "libauth_"; }
|
||||
};
|
||||
|
||||
template<>
|
||||
struct PluginTraits<Filer> {
|
||||
static constexpr int API_VERSION = FILER_PLUGIN_API_VERSION;
|
||||
using CreateFunc = Filer* (*)();
|
||||
using DestroyFunc = void (*)(Filer*);
|
||||
using PluginInfoType = FilerPluginInfo;
|
||||
|
||||
static const char* createFuncName() { return "createFilerPlugin"; }
|
||||
static const char* destroyFuncName() { return "destroyFilerPlugin"; }
|
||||
static const char* pluginPrefix() { return "libfiler_"; }
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "auth_plugin.h"
|
||||
#include "plugin.h"
|
||||
#include "auth.h"
|
||||
#include <security/pam_appl.h>
|
||||
#include <pwd.h>
|
||||
#include <grp.h>
|
||||
@@ -6,7 +7,7 @@
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
class PAMAuthPlugin : public Auth {
|
||||
class PAMAuthPlugin : public IPlugin, public Auth {
|
||||
private:
|
||||
std::string service_name;
|
||||
uid_t user_uid;
|
||||
@@ -128,4 +129,4 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
IMPLEMENT_AUTH_PLUGIN(PAMAuthPlugin, "pam", "PAM-based local authentication", "1.0.0")
|
||||
IMPLEMENT_PLUGIN(PAMAuthPlugin, Auth, "pam", "PAM-based local authentication", "1.0.0")
|
||||
@@ -1,11 +1,12 @@
|
||||
#include "auth_plugin.h"
|
||||
#include "plugin.h"
|
||||
#include "auth.h"
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
class PassdbAuthPlugin : public Auth {
|
||||
class PassdbAuthPlugin : public IPlugin, public Auth {
|
||||
private:
|
||||
std::string db_file = "passdb";
|
||||
bool case_sensitive;
|
||||
@@ -59,7 +60,6 @@ public:
|
||||
}
|
||||
|
||||
if (authenticated) {
|
||||
// Only set the home directory after successful authentication
|
||||
setUserDirectory(auth_data);
|
||||
return true;
|
||||
}
|
||||
@@ -72,4 +72,4 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
IMPLEMENT_AUTH_PLUGIN(PassdbAuthPlugin, "passdb", "Password database authentication", "1.0.0")
|
||||
IMPLEMENT_PLUGIN(PassdbAuthPlugin, Auth, "passdb", "Password database authentication", "1.0.0")
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "filer_plugin.h"
|
||||
#include "plugin.h"
|
||||
#include "filer.h"
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
@@ -8,7 +9,7 @@
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
class LocalFiler : public Filer {
|
||||
class LocalFiler : public IPlugin, public Filer {
|
||||
public:
|
||||
LocalFiler() {}
|
||||
|
||||
@@ -371,4 +372,4 @@ private:
|
||||
char type = 'I';
|
||||
};
|
||||
|
||||
IMPLEMENT_FILER_PLUGIN(LocalFiler, "local", "Local filesystem implementation", "1.0.0")
|
||||
IMPLEMENT_PLUGIN(LocalFiler, Filer, "local", "Local filesystem implementation", "1.0.0")
|
||||
Reference in New Issue
Block a user