diff --git a/src/moduler/moduler.c b/src/moduler/moduler.c index 28100d4..582fc6d 100644 --- a/src/moduler/moduler.c +++ b/src/moduler/moduler.c @@ -19,8 +19,15 @@ #include #include #include +#include #include +typedef struct { + const char *filename; + void *dlhandle; + MemoryRegionSpan memreg; +} PatchData; + static void *adjust_if_relative(void *ptr, void *module_base) { uptr p = (uptr)ptr; if (p && (p < (uptr)module_base)) { @@ -131,8 +138,8 @@ static HiloadResult gather_patchable_symbols(struct sc_array_sym *symbols, } void *sym_addr = (void *)((uintptr_t)module_base + sym.st_value); - HiSymbolBind binding = symbol_bind_from_efi(GELF_ST_BIND(sym.st_info)); - HiSymbolType type = symbol_type_from_efi(GELF_ST_TYPE(sym.st_info)); + HiSymbolBind binding = symbol_bind_from_efibind(GELF_ST_BIND(sym.st_info)); + HiSymbolType type = symbol_type_from_efitype(GELF_ST_TYPE(sym.st_info)); size_t size = sym.st_size; // Gather global symbols and local object symbols. Local functions are @@ -268,6 +275,36 @@ static HiloadResult moduler_apply_module_patch(HiSymbols *psymbols, return HILOAD_OK; } +PatchData moduler_create_patch(HiModuleData *module) { + + time_t now = time(NULL); + struct tm *t = localtime(&now); + if (t == NULL) { + return (PatchData){0}; + } + + char file_append[32]; + if (strftime(file_append, sizeof(file_append), ".%Y%m%d%H%M%S.patch", t) == 0) { + log_error("Failed to create patch filename.\n"); + return (PatchData){0}; + } + + char filename[512]; + size_t written = + hi_strncat_buf(sizeof(filename), filename, module->name, file_append); + + + if (written == 0) { + log_error("Failed to concat %s and %s\n", module->name, ".patch"); + return (PatchData){0}; + } + + hi_file_copy(module->name, filename); + + PatchData data = {.filename = strdup(filename)}; + return data; +} + HiloadResult moduler_reload(HiModuleArray *modules, HiModuleData *module, struct sc_array_memreg *memregs) { @@ -279,40 +316,36 @@ HiloadResult moduler_reload(HiModuleArray *modules, HiModuleData *module, return HILOAD_OK; } - dlerror(); // clear old errors - char patch_filename[512]; - size_t written = hi_strncat_buf(sizeof(patch_filename), patch_filename, - module->name, ".patch"); - if (written == 0) { - log_error("Failed to concat %s and %s\n", module->name, ".patch"); + PatchData patch = moduler_create_patch(module); + if (!patch.filename) { + log_error("Couldn't create patch for %s\n", module->name); return HILOAD_FAIL; } - hi_file_copy(module->name, patch_filename); - // Load patch - log_debug("Opening: %s\n", patch_filename); - void *new_handle = dlopen(patch_filename, RTLD_LAZY); + dlerror(); // clear any previous errors + log_debug("Opening: %s\n", patch.filename); + void *new_handle = dlopen(patch.filename, RTLD_LAZY); if (!new_handle) { log_error("Couldn't load: %s\n", dlerror()); module->info = hi_modinfo_clear(module->info, HI_MODULE_STATE_DIRTY); return HILOAD_FAIL; } + patch.dlhandle = new_handle; // refresh cache read_memory_maps_self(memregs); - MemoryRegionSpan patch_memory = - memory_get_module_span(memregs, patch_filename); - void *patch_base = (void *)patch_memory.region_start; + patch.memreg = memory_get_module_span(memregs, patch.filename); + void *patch_base = (void *)patch.memreg.region_start; HiSymbols patch_symbols; symbol_init_symbols(&patch_symbols); HiloadResult ret = - gather_patchable_symbols(&patch_symbols, patch_filename, patch_base); + gather_patchable_symbols(&patch_symbols, patch.filename, patch_base); if (!HIOK(ret)) { - log_error("Failed to gather symbols for %s\n", patch_filename); + log_error("Failed to gather symbols for %s\n", patch.filename); return HILOAD_FAIL; } @@ -327,7 +360,7 @@ HiloadResult moduler_reload(HiModuleArray *modules, HiModuleData *module, // If patch is for the same module, also collect local object symbols for // coping those over - if (strncmp(mod.name, patch_filename, strlen(mod.name)) == 0) { + if (strncmp(mod.name, patch.filename, strlen(mod.name)) == 0) { HiSymbols module_symbols; symbol_init_symbols(&module_symbols); @@ -347,7 +380,7 @@ HiloadResult moduler_reload(HiModuleArray *modules, HiModuleData *module, if (ps) { if (ps->size >= sym->size) { memcpy(ps->address, sym->address, sym->size); - } else { + } else { memcpy(ps->address, sym->address, ps->size); } log_debug("Copied data for symbol: %s\n", sym->name); @@ -360,7 +393,11 @@ HiloadResult moduler_reload(HiModuleArray *modules, HiModuleData *module, module->info = hi_modinfo_clear(module->info, HI_MODULE_STATE_DIRTY); } + free((char*)patch.filename); symbol_term_symbols(&patch_symbols); + dlclose(module->dlhandle); + module->dlhandle = patch.dlhandle; + return HILOAD_OK; } diff --git a/src/symbols.c b/src/symbols.c index 9d1af04..b759d65 100644 --- a/src/symbols.c +++ b/src/symbols.c @@ -15,7 +15,7 @@ HiSymbol *symbol_find(HiSymbols *symbols, HiSymbol *symbol) { return NULL; } -HiSymbolBind symbol_bind_from_efi(u32 efi_bind) { +HiSymbolBind symbol_bind_from_efibind(u32 efi_bind) { // clang-format off switch (efi_bind) { case STB_LOCAL: return HI_SYMBOL_BIND_LOCAL; @@ -26,7 +26,7 @@ HiSymbolBind symbol_bind_from_efi(u32 efi_bind) { return ~0u; } -HiSymbolType symbol_type_from_efi(u32 efi_type) { +HiSymbolType symbol_type_from_efitype(u32 efi_type) { // clang-format off switch (efi_type) { case STT_NOTYPE: return HI_SYMBOL_TYPE_NOTYPE; /* Symbol type is unspecified */ diff --git a/src/symbols.h b/src/symbols.h index c6cb9d7..46f49ad 100644 --- a/src/symbols.h +++ b/src/symbols.h @@ -52,6 +52,6 @@ static inline void symbol_term_symbols(HiSymbols *symbols) { HiSymbol *symbol_find(HiSymbols *symbols, HiSymbol *symbol); -HiSymbolBind symbol_bind_from_efi(u32 efi_bind); -HiSymbolType symbol_type_from_efi(u32 efi_type); +HiSymbolBind symbol_bind_from_efibind(u32 efi_bind); +HiSymbolType symbol_type_from_efitype(u32 efi_type); #endif // SYMBOLS_H_ diff --git a/test/manual/minimal.cpp b/test/manual/minimal.cpp index f2f23c0..14b69e6 100644 --- a/test/manual/minimal.cpp +++ b/test/manual/minimal.cpp @@ -18,7 +18,7 @@ int main(int argc, char *argv[]) { while (modified != 0) { modified = minimal_lib::getNewValue(5); - printf("getNewValue(5): %d\n", modified); + printf("getNewValue: %d\n", modified); std::this_thread::sleep_for(std::chrono::seconds(1)); printf("otherValue: %d\n", minimal_lib::otherValue++); diff --git a/test/manual/minimal_lib.cpp b/test/manual/minimal_lib.cpp index a0fccf0..b72e1c2 100644 --- a/test/manual/minimal_lib.cpp +++ b/test/manual/minimal_lib.cpp @@ -1,10 +1,11 @@ #include "minimal_lib.h" +#include namespace minimal_lib { int getNewValue(int x) { static int value = 0; - value = value + 5; + value = value + x; return value; }