Unified plugin handling

This commit is contained in:
2024-12-14 10:09:47 -06:00
parent 3ec6522741
commit f4d44043fd
14 changed files with 309 additions and 411 deletions

View File

@@ -10,6 +10,8 @@
#include "logger.h"
#include "util.h"
#define AUTH_PLUGIN_API_VERSION 1
typedef struct {
char username[255] = {0};
char password[255] = {0};
@@ -57,16 +59,10 @@ public:
virtual uint64_t getFreeSpace() { return this->getTotalSpace(); }
virtual void setTotalSpace(uint64_t space) { this->max_filespace = space; }
static void setLogger(Logger* log) { logger = log; }
private:
bool require_password = false;
uint64_t max_filespace = -1;
bool chroot = false;
protected:
static Logger* logger;
};
Logger* Auth::logger = nullptr;
#endif

View File

@@ -1,124 +0,0 @@
#ifndef AUTH_MANAGER_H
#define AUTH_MANAGER_H
#include <map>
#include <vector>
#include <string>
#include <dlfcn.h>
#include "main.h"
#include "auth_plugin.h"
class AuthManager {
public:
static AuthManager& getInstance() {
static AuthManager instance;
return instance;
}
void setLogger(Logger* log) {
logger = log;
}
Auth* createAuth(const std::string& type) {
auto it = plugins.find(type);
if (it != plugins.end()) {
return it->second.create();
}
if (logger) logger->print(LOGLEVEL_ERROR, "Auth plugin type '%s' not found", type.c_str());
return nullptr;
}
bool loadPlugin(const std::string& path) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Loading auth plugin: %s", path.c_str());
void* handle = dlopen(path.c_str(), RTLD_LAZY);
if (!handle) {
if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load auth plugin %s: %s",
path.c_str(), dlerror());
return false;
}
// Load plugin functions
CreateAuthFunc create = (CreateAuthFunc)dlsym(handle, "createAuthPlugin");
DestroyAuthFunc destroy = (DestroyAuthFunc)dlsym(handle, "destroyAuthPlugin");
GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion");
SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger");
const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName");
const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription");
const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion");
// Check each required function and log specific missing functions
if (!create || !destroy || !get_api_version || !setLogger || !get_name || !get_desc || !get_ver) {
if (logger) {
const char* missing = !create ? "createAuthPlugin" :
!destroy ? "destroyAuthPlugin" :
!get_api_version ? "getAPIVersion" :
!setLogger ? "setLogger" :
!get_name ? "getPluginName" :
!get_desc ? "getPluginDescription" :
"getPluginVersion";
logger->print(LOGLEVEL_ERROR, "Invalid auth plugin %s: missing required function '%s'",
path.c_str(), missing);
}
dlclose(handle);
return false;
}
// Check API version
int api_version = get_api_version();
if (api_version != AUTH_PLUGIN_API_VERSION) {
if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible auth plugin API version in %s (got %d, expected %d)",
path.c_str(), api_version, AUTH_PLUGIN_API_VERSION);
dlclose(handle);
return false;
}
std::string plugin_name = get_name();
setLogger(logger);
// Check if plugin is already loaded
if (plugins.find(plugin_name) != plugins.end()) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str());
dlclose(handle);
return true;
}
// Store plugin info
AuthPluginInfo plugin = {
plugin_name,
get_desc(),
get_ver(),
create,
destroy,
get_api_version,
setLogger,
handle
};
plugins[plugin.name] = plugin;
if (logger) logger->print(LOGLEVEL_INFO, "Loaded auth plugin: %s v%s",
plugin.name.c_str(), plugin.version.c_str());
return true;
}
void unloadPlugins() {
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all auth plugins");
for (auto& pair : plugins) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str());
dlclose(pair.second.handle);
}
plugins.clear();
}
~AuthManager() {
unloadPlugins();
}
private:
AuthManager() : logger(nullptr) {} // Singleton
std::map<std::string, AuthPluginInfo> plugins;
Logger* logger;
};
#endif

View File

@@ -1,37 +0,0 @@
#ifndef AUTH_PLUGIN_H
#define AUTH_PLUGIN_H
#include <string>
#include "auth.h"
#define AUTH_PLUGIN_API_VERSION 1
typedef Auth* (*CreateAuthFunc)();
typedef void (*DestroyAuthFunc)(Auth*);
typedef int (*GetAPIVersionFunc)();
typedef void (*SetLoggerFunc)(Logger*);
struct AuthPluginInfo {
std::string name;
std::string description;
std::string version;
CreateAuthFunc create;
DestroyAuthFunc destroy;
GetAPIVersionFunc get_api_version;
SetLoggerFunc setLogger;
void* handle;
};
#define AUTH_PLUGIN_EXPORT extern "C"
#define IMPLEMENT_AUTH_PLUGIN(classname, name, description, version) \
extern "C" { \
Auth* createAuthPlugin() { return new classname(); } \
void destroyAuthPlugin(Auth* plugin) { delete plugin; } \
const char* getPluginName() { return name; } \
const char* getPluginDescription() { return description; } \
const char* getPluginVersion() { return version; } \
int getAPIVersion() { return AUTH_PLUGIN_API_VERSION; } \
void setLogger(Logger* log) { Auth::setLogger(log); } \
}
#endif

View File

@@ -3,6 +3,8 @@
#include <filesystem>
#include "logger.h"
#define FILER_PLUGIN_API_VERSION 1
struct file_error {
uint8_t code = 0;
const char* msg = {};
@@ -34,11 +36,6 @@ class Filer {
public:
virtual ~Filer() {};
static void setLogger(Logger* log) {
logger = log;
}
// Core operations
virtual file_data setRoot(std::string _root) = 0;
virtual file_data setCWD(std::string _cwd) = 0;
virtual std::filesystem::path getRoot() = 0;
@@ -52,7 +49,6 @@ public:
return this->transfer_mode;
}
// File operations
virtual file_data traverse(std::string dir) = 0;
virtual file_data createDirectory(std::string dir) = 0;
virtual file_data fileSize(std::string name) = 0;
@@ -61,11 +57,8 @@ public:
virtual file_data writeFile(std::string name, unsigned char* data, int size, bool append = false) = 0;
virtual file_data renameFile(const std::string& from, const std::string& to) = 0;
virtual file_data list(std::string path = ".") = 0;
protected:
static Logger* logger;
private:
char transfer_mode = 'I';
};
Logger* Filer::logger = nullptr;
#endif

View File

@@ -1,130 +0,0 @@
#ifndef FILER_MANAGER_H
#define FILER_MANAGER_H
#include <map>
#include <vector>
#include <string>
#include <dlfcn.h>
#include "filer_plugin.h"
#include "main.h"
class FilerManager {
public:
static FilerManager& getInstance() {
static FilerManager instance;
return instance;
}
void setLogger(Logger* log) {
logger = log;
}
Filer* createFiler(const std::string& type) {
auto it = plugins.find(type);
if (it != plugins.end()) {
return it->second.create();
}
if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str());
return nullptr;
}
bool loadPlugin(const std::string& path) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Loading filer plugin: %s", path.c_str());
void* handle = dlopen(path.c_str(), RTLD_LAZY);
if (!handle) {
if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load filer plugin %s: %s",
path.c_str(), dlerror());
return false;
}
// Load plugin functions
CreateFilerFunc create = (CreateFilerFunc)dlsym(handle, "createFilerPlugin");
DestroyFilerFunc destroy = (DestroyFilerFunc)dlsym(handle, "destroyFilerPlugin");
GetAPIVersionFunc get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion");
SetLoggerFunc setLogger = (SetLoggerFunc)dlsym(handle, "setLogger");
const char* (*get_name)() = (const char* (*)())dlsym(handle, "getPluginName");
const char* (*get_desc)() = (const char* (*)())dlsym(handle, "getPluginDescription");
const char* (*get_ver)() = (const char* (*)())dlsym(handle, "getPluginVersion");
if (!create || !destroy || !get_api_version || !setLogger ||
!get_name || !get_desc || !get_ver) {
if (logger) {
const char* missing = !create ? "createFilerPlugin" :
!destroy ? "destroyFilerPlugin" :
!get_api_version ? "getAPIVersion" :
!setLogger ? "setLogger" :
!get_name ? "getPluginName" :
!get_desc ? "getPluginDescription" :
"getPluginVersion";
logger->print(LOGLEVEL_ERROR, "Invalid filer plugin %s: missing required function '%s'",
path.c_str(), missing);
}
dlclose(handle);
return false;
}
int api_version = get_api_version();
if (api_version != FILER_PLUGIN_API_VERSION) {
if (logger) logger->print(LOGLEVEL_ERROR, "Incompatible filer plugin API version in %s (got %d, expected %d)",
path.c_str(), api_version, FILER_PLUGIN_API_VERSION);
dlclose(handle);
return false;
}
std::string plugin_name = get_name();
setLogger(logger);
if (plugins.find(plugin_name) != plugins.end()) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str());
dlclose(handle);
return true;
}
FilerPluginInfo plugin = {
plugin_name,
get_desc(),
get_ver(),
create,
destroy,
get_api_version,
setLogger,
handle
};
plugins[plugin.name] = plugin;
if (logger) logger->print(LOGLEVEL_INFO, "Loaded filer plugin: %s v%s",
plugin.name.c_str(), plugin.version.c_str());
return true;
}
CreateFilerFunc getFactory(const std::string& type) {
auto it = plugins.find(type);
if (it != plugins.end()) {
return it->second.create;
}
if (logger) logger->print(LOGLEVEL_ERROR, "Filer plugin type '%s' not found", type.c_str());
return nullptr;
}
void unloadPlugins() {
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all filer plugins");
for (auto& pair : plugins) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str());
dlclose(pair.second.handle);
}
plugins.clear();
}
~FilerManager() {
unloadPlugins();
}
private:
FilerManager() : logger(nullptr) {}
std::map<std::string, FilerPluginInfo> plugins;
Logger* logger;
};
#endif

View File

@@ -1,37 +0,0 @@
#ifndef FILER_PLUGIN_H
#define FILER_PLUGIN_H
#include <string>
#include "filer.h"
#define FILER_PLUGIN_API_VERSION 1
typedef Filer* (*CreateFilerFunc)();
typedef void (*DestroyFilerFunc)(Filer*);
typedef int (*GetAPIVersionFunc)();
typedef void (*SetLoggerFunc)(Logger*);
struct FilerPluginInfo {
std::string name;
std::string description;
std::string version;
CreateFilerFunc create;
DestroyFilerFunc destroy;
GetAPIVersionFunc get_api_version;
SetLoggerFunc setLogger;
void* handle;
};
#define FILER_PLUGIN_EXPORT extern "C"
#define IMPLEMENT_FILER_PLUGIN(classname, name, description, version) \
extern "C" { \
Filer* createFilerPlugin() { return new classname(); } \
void destroyFilerPlugin(Filer* plugin) { delete plugin; } \
const char* getPluginName() { return name; } \
const char* getPluginDescription() { return description; } \
const char* getPluginVersion() { return version; } \
int getAPIVersion() { return FILER_PLUGIN_API_VERSION; } \
void setLogger(Logger* log) { Filer::setLogger(log); } \
}
#endif

View File

@@ -20,7 +20,6 @@
#include <fcntl.h>
#include "main.h"
#include "auth_manager.cpp"
#include "client.cpp"
#include "util.h"
@@ -181,70 +180,46 @@ void runClient(struct ftpconn* cfd) {
cfd->close = true;
}
void initializeAuth() {
AuthManager& auth_manager = AuthManager::getInstance();
auth_manager.setLogger(logger);
// Load auth plugins from plugin directory
void initializePlugins() {
std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR);
// Load any additional plugins from plugin directory
// Load auth plugins
auto& auth_manager = PluginManager<Auth>::getInstance();
auth_manager.setLogger(logger);
// Load filer plugins
auto& filer_manager = PluginManager<Filer>::getInstance();
filer_manager.setLogger(logger);
for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) {
if (entry.path().extension() == ".so" &&
entry.path().filename().string().find("libauth_") == 0) {
auth_manager.loadPlugin(entry.path().string());
if (entry.path().extension() == ".so") {
const std::string& filename = entry.path().filename().string();
if (filename.find(PluginTraits<Auth>::pluginPrefix()) == 0) {
auth_manager.loadPlugin(entry.path().string());
}
else if (filename.find(PluginTraits<Filer>::pluginPrefix()) == 0) {
filer_manager.loadPlugin(entry.path().string());
}
}
}
// Create auth instance based on config
// Initialize auth
std::string auth_type = config->getValue("core", "auth_engine", "pam");
auth = auth_manager.createAuth(auth_type);
auth = auth_manager.createPlugin(auth_type);
if (!auth) {
logger->print(LOGLEVEL_CRITICAL, "Failed to create auth engine: %s", auth_type.c_str());
exit(1);
}
// Initialize auth plugin with config
std::map<std::string, std::string> auth_config = config->get("auth_"+auth_type)->get();
if (!auth->initialize(auth_config)) {
logger->print(LOGLEVEL_CRITICAL, "Failed to initialize auth engine: %s", auth_type.c_str());
exit(1);
}
}
void initializeFiler() {
FilerManager& filer_manager = FilerManager::getInstance();
filer_manager.setLogger(logger);
std::string plugin_dir = config->getValue("core", "plugin_path", PLUGIN_DIR);
for (const auto& entry : std::filesystem::directory_iterator(plugin_dir)) {
if (entry.path().extension() == ".so" &&
entry.path().filename().string().find("libfiler_") == 0) {
filer_manager.loadPlugin(entry.path().string());
}
}
// Initialize filer
std::string filer_type = config->getValue("core", "filer_engine", "local");
default_filer_factory = filer_manager.getFactory(filer_type);
if (!default_filer_factory) {
logger->print(LOGLEVEL_CRITICAL, "Failed to get filer factory for type: %s", filer_type.c_str());
exit(1);
}
std::map<std::string, std::string> filer_config = config->get("filer_"+filer_type)->get();
// Test create using the factory
Filer* test_filer = default_filer_factory();
if (!test_filer) {
logger->print(LOGLEVEL_CRITICAL, "Failed to create filer instance");
exit(1);
}
delete test_filer;
}
int main(int argc , char *argv[]) {
@@ -307,14 +282,10 @@ int main(int argc , char *argv[]) {
if (!config->getBool("ssl", "tls_v1_3", true))
ssl_flags+=SSL_OP_NO_TLSv1_3;
if ((ssl_flags & SSL_OP_NO_SSLv2) &&
(ssl_flags & SSL_OP_NO_SSLv3) &&
(ssl_flags & SSL_OP_NO_TLSv1) &&
(ssl_flags & SSL_OP_NO_TLSv1_1) &&
(ssl_flags & SSL_OP_NO_TLSv1_2) &&
(ssl_flags & SSL_OP_NO_TLSv1_3)) {
if ((ssl_flags & SSL_OP_NO_SSLv2) && (ssl_flags & SSL_OP_NO_SSLv3) &&
(ssl_flags & SSL_OP_NO_TLSv1) && (ssl_flags & SSL_OP_NO_TLSv1_1) &&
(ssl_flags & SSL_OP_NO_TLSv1_2) && (ssl_flags & SSL_OP_NO_TLSv1_3))
logger->print(LOGLEVEL_WARNING, "All SSL/TLS protocols disabled. You're a mad man!");
}
if (!config->getBool("ssl", "compression", true))
ssl_flags+=SSL_OP_NO_COMPRESSION;
@@ -323,8 +294,7 @@ int main(int argc , char *argv[]) {
SSLManager::getInstance().setFlags(ssl_flags);
}
initializeAuth();
initializeFiler();
initializePlugins();
int opt = 1,
master_socket = -1,

View File

@@ -10,8 +10,10 @@
#include "conf.h"
#include "logger.h"
#include "ssl.h"
#include "auth_manager.cpp"
#include "filer_manager.cpp"
#include "auth.h"
#include "filer.h"
#include "plugin_traits.h"
#include "plugin_manager.h"
unsigned char* server_address = new unsigned char[127];
uint16_t server_port = 21;
@@ -20,7 +22,7 @@ std::string motd;
ConfigFile* config;
Auth* auth;
Logger* logger;
static CreateFilerFunc default_filer_factory = nullptr;
static typename PluginTraits<Filer>::CreateFunc default_filer_factory = nullptr;
bool runServer;
bool runCompression;

49
src/plugin.h Normal file
View File

@@ -0,0 +1,49 @@
#ifndef PLUGIN_H
#define PLUGIN_H
#include <string>
#include "logger.h"
typedef int (*GetAPIVersionFunc)();
typedef void (*SetLoggerFunc)(Logger*);
class IPlugin {
public:
virtual ~IPlugin() = default;
static void setLogger(Logger* log) { logger = log; }
protected:
static Logger* logger;
};
Logger* IPlugin::logger = nullptr;
struct PluginInfo {
std::string name;
std::string description;
std::string version;
GetAPIVersionFunc get_api_version;
SetLoggerFunc setLogger;
void* handle;
int api_version;
};
template<typename T>
struct PluginTraits {
static constexpr int API_VERSION = 1;
using CreateFunc = void* (*)();
using DestroyFunc = void (*)(void*);
using PluginInfoType = PluginInfo;
};
#define IMPLEMENT_PLUGIN(PluginClass, BaseClass, name, description, version) \
extern "C" { \
BaseClass* create##BaseClass##Plugin() { return new PluginClass(); } \
void destroy##BaseClass##Plugin(BaseClass* plugin) { delete plugin; } \
const char* getPluginName() { return name; } \
const char* getPluginDescription() { return description; } \
const char* getPluginVersion() { return version; } \
int getAPIVersion() { return PluginTraits<BaseClass>::API_VERSION; } \
void setLogger(Logger* log) { IPlugin::setLogger(log); } \
}
#endif

166
src/plugin_manager.h Normal file
View File

@@ -0,0 +1,166 @@
#ifndef PLUGIN_MANAGER_H
#define PLUGIN_MANAGER_H
#include <map>
#include <dlfcn.h>
#include <filesystem>
#include "plugin.h"
template<typename T>
class PluginManager {
public:
static PluginManager<T>& getInstance() {
static PluginManager<T> instance;
return instance;
}
void setLogger(Logger* log) {
logger = log;
IPlugin::setLogger(log);
}
T* createPlugin(const std::string& type) {
auto it = plugins.find(type);
if (it != plugins.end()) {
return it->second.create();
}
if (logger) logger->print(LOGLEVEL_ERROR, "Plugin type '%s' not found", type.c_str());
return nullptr;
}
bool loadPlugin(const std::string& path) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Loading plugin: %s", path.c_str());
void* handle = dlopen(path.c_str(), RTLD_LAZY);
if (!handle) {
if (logger) logger->print(LOGLEVEL_ERROR, "Failed to load plugin %s: %s",
path.c_str(), dlerror());
return false;
}
// Load common functions
auto create = (typename PluginTraits<T>::CreateFunc)dlsym(handle, PluginTraits<T>::createFuncName());
auto destroy = (typename PluginTraits<T>::DestroyFunc)dlsym(handle, PluginTraits<T>::destroyFuncName());
auto get_api_version = (GetAPIVersionFunc)dlsym(handle, "getAPIVersion");
auto setLogger = (SetLoggerFunc)dlsym(handle, "setLogger");
auto get_name = (const char* (*)())dlsym(handle, "getPluginName");
auto get_desc = (const char* (*)())dlsym(handle, "getPluginDescription");
auto get_ver = (const char* (*)())dlsym(handle, "getPluginVersion");
if (!checkRequiredFunctions(create, destroy, get_api_version, setLogger, get_name, get_desc, get_ver)) {
dlclose(handle);
return false;
}
if (!checkAPIVersion(get_api_version(), path)) {
dlclose(handle);
return false;
}
std::string plugin_name = get_name();
setLogger(logger);
if (plugins.find(plugin_name) != plugins.end()) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Plugin %s is already loaded", plugin_name.c_str());
dlclose(handle);
return true;
}
auto plugin_info = createPluginInfo(plugin_name, get_desc(), get_ver(), create, destroy,
get_api_version, setLogger, handle);
plugins[plugin_name] = plugin_info;
if (logger) logger->print(LOGLEVEL_INFO, "Loaded plugin: %s v%s",
plugin_name.c_str(), plugin_info.version.c_str());
return true;
}
typename PluginTraits<T>::CreateFunc getFactory(const std::string& type) {
auto it = plugins.find(type);
if (it != plugins.end()) {
return it->second.create;
}
if (logger) logger->print(LOGLEVEL_ERROR, "Plugin type '%s' not found", type.c_str());
return nullptr;
}
void unloadPlugins() {
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading all plugins");
for (auto& pair : plugins) {
if (logger) logger->print(LOGLEVEL_DEBUG, "Unloading plugin: %s", pair.first.c_str());
dlclose(pair.second.handle);
}
plugins.clear();
}
~PluginManager() {
unloadPlugins();
}
private:
PluginManager() : logger(nullptr) {}
std::map<std::string, typename PluginTraits<T>::PluginInfoType> plugins;
Logger* logger;
bool checkRequiredFunctions(
typename PluginTraits<T>::CreateFunc create,
typename PluginTraits<T>::DestroyFunc destroy,
GetAPIVersionFunc get_api_version,
SetLoggerFunc setLogger,
const char* (*get_name)(),
const char* (*get_desc)(),
const char* (*get_ver)()
) {
if (!create || !destroy || !get_api_version || !setLogger ||
!get_name || !get_desc || !get_ver) {
if (logger) {
const char* missing = !create ? PluginTraits<T>::createFuncName() :
!destroy ? PluginTraits<T>::destroyFuncName() :
!get_api_version ? "getAPIVersion" :
!setLogger ? "setLogger" :
!get_name ? "getPluginName" :
!get_desc ? "getPluginDescription" :
"getPluginVersion";
logger->print(LOGLEVEL_ERROR, "Missing required function: %s", missing);
}
return false;
}
return true;
}
bool checkAPIVersion(int version, const std::string& path) {
if (version != PluginTraits<T>::API_VERSION) {
if (logger) {
logger->print(LOGLEVEL_ERROR,
"Incompatible plugin API version in %s (got %d, expected %d)",
path.c_str(), version, PluginTraits<T>::API_VERSION);
}
return false;
}
return true;
}
typename PluginTraits<T>::PluginInfoType createPluginInfo(
const std::string& name,
const char* description,
const char* version,
typename PluginTraits<T>::CreateFunc create,
typename PluginTraits<T>::DestroyFunc destroy,
GetAPIVersionFunc get_api_version,
SetLoggerFunc setLogger,
void* handle
) {
typename PluginTraits<T>::PluginInfoType info;
info.name = name;
info.description = description;
info.version = version;
info.create = create;
info.destroy = destroy;
info.get_api_version = get_api_version;
info.setLogger = setLogger;
info.handle = handle;
return info;
}
};
#endif

48
src/plugin_traits.h Normal file
View File

@@ -0,0 +1,48 @@
#ifndef PLUGIN_TRAITS_H
#define PLUGIN_TRAITS_H
#include "plugin.h"
#include "auth.h"
#include "filer.h"
struct AuthPluginInfo : public PluginInfo {
typedef Auth* (*CreateFunc)();
typedef void (*DestroyFunc)(Auth*);
CreateFunc create;
DestroyFunc destroy;
};
struct FilerPluginInfo : public PluginInfo {
typedef Filer* (*CreateFunc)();
typedef void (*DestroyFunc)(Filer*);
CreateFunc create;
DestroyFunc destroy;
};
template<>
struct PluginTraits<Auth> {
static constexpr int API_VERSION = AUTH_PLUGIN_API_VERSION;
using CreateFunc = Auth* (*)();
using DestroyFunc = void (*)(Auth*);
using PluginInfoType = AuthPluginInfo;
static const char* createFuncName() { return "createAuthPlugin"; }
static const char* destroyFuncName() { return "destroyAuthPlugin"; }
static const char* pluginPrefix() { return "libauth_"; }
};
template<>
struct PluginTraits<Filer> {
static constexpr int API_VERSION = FILER_PLUGIN_API_VERSION;
using CreateFunc = Filer* (*)();
using DestroyFunc = void (*)(Filer*);
using PluginInfoType = FilerPluginInfo;
static const char* createFuncName() { return "createFilerPlugin"; }
static const char* destroyFuncName() { return "destroyFilerPlugin"; }
static const char* pluginPrefix() { return "libfiler_"; }
};
#endif

View File

@@ -1,4 +1,5 @@
#include "auth_plugin.h"
#include "plugin.h"
#include "auth.h"
#include <security/pam_appl.h>
#include <pwd.h>
#include <grp.h>
@@ -6,7 +7,7 @@
#include <cstring>
#include <memory>
class PAMAuthPlugin : public Auth {
class PAMAuthPlugin : public IPlugin, public Auth {
private:
std::string service_name;
uid_t user_uid;
@@ -128,4 +129,4 @@ public:
}
};
IMPLEMENT_AUTH_PLUGIN(PAMAuthPlugin, "pam", "PAM-based local authentication", "1.0.0")
IMPLEMENT_PLUGIN(PAMAuthPlugin, Auth, "pam", "PAM-based local authentication", "1.0.0")

View File

@@ -1,11 +1,12 @@
#include "auth_plugin.h"
#include "plugin.h"
#include "auth.h"
#include <fstream>
#include <string>
#include <map>
#include <algorithm>
#include <cstring>
class PassdbAuthPlugin : public Auth {
class PassdbAuthPlugin : public IPlugin, public Auth {
private:
std::string db_file = "passdb";
bool case_sensitive;
@@ -59,7 +60,6 @@ public:
}
if (authenticated) {
// Only set the home directory after successful authentication
setUserDirectory(auth_data);
return true;
}
@@ -72,4 +72,4 @@ public:
}
};
IMPLEMENT_AUTH_PLUGIN(PassdbAuthPlugin, "passdb", "Password database authentication", "1.0.0")
IMPLEMENT_PLUGIN(PassdbAuthPlugin, Auth, "passdb", "Password database authentication", "1.0.0")

View File

@@ -1,4 +1,5 @@
#include "filer_plugin.h"
#include "plugin.h"
#include "filer.h"
#include <iostream>
#include <fstream>
#include <string>
@@ -8,7 +9,7 @@
namespace fs = std::filesystem;
class LocalFiler : public Filer {
class LocalFiler : public IPlugin, public Filer {
public:
LocalFiler() {}
@@ -371,4 +372,4 @@ private:
char type = 'I';
};
IMPLEMENT_FILER_PLUGIN(LocalFiler, "local", "Local filesystem implementation", "1.0.0")
IMPLEMENT_PLUGIN(LocalFiler, Filer, "local", "Local filesystem implementation", "1.0.0")