diff --git a/src/hielf.c b/src/hielf.c index 95259a5..5d3f054 100644 --- a/src/hielf.c +++ b/src/hielf.c @@ -8,119 +8,46 @@ #include #include -void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) { +void *hi_elf_find_dynseg(void *module_base, size_t n, Elf **elf_ret) { - Elf *elf = elf_memory(module_base, n); + Elf *elf = NULL; - size_t phdrnum = 0; - int err = elf_getphdrnum(elf, &phdrnum); - if (err == -1) { - log_error("%s\n", elf_errmsg(elf_errno())); - return NULL; - } - - void *dyn_addr = NULL; - - for (size_t i=0; i < phdrnum; ++i) { - GElf_Phdr phdr; - if (gelf_getphdr(elf, i, &phdr) != &phdr) { - log_error("Failed to find program headers: %s\n", elf_errmsg(elf_errno())); - continue; - } - - if (phdr.p_type == PT_DYNAMIC) { - dyn_addr = (void*)((uptr)module_base + phdr.p_vaddr); - break; - } - } - - if (elf_ret) { - *elf_ret = elf; - } else { - elf_end(elf); - } - - return dyn_addr; -} - -void hi_elf_print_module_from_memory(void *address, size_t size) { - Elf *elf = elf_memory(address, size); - - int err = elf_errno(); - if (err) { - log_error("%s\n", elf_errmsg(err)); - return; + if (elf_ret && *elf_ret) { + elf = *elf_ret; + } else { + elf = elf_memory(module_base, n); } size_t phdrnum = 0; - err = elf_getphdrnum(elf, &phdrnum); - Elf64_Phdr *phdr = elf64_getphdr(elf); + int err = elf_getphdrnum(elf, &phdrnum); + if (err == -1) { + log_error("%s\n", elf_errmsg(elf_errno())); + return NULL; + } + + void *dyn_addr = NULL; for (size_t i = 0; i < phdrnum; ++i) { - Elf64_Phdr *p = phdr + i; - log_debug("segment type: %s\n", hi_elf_segtostr(p->p_type)); + GElf_Phdr phdr; + if (gelf_getphdr(elf, i, &phdr) != &phdr) { + log_error("Failed to find program headers: %s\n", + elf_errmsg(elf_errno())); + continue; + } - size_t segment_size = p->p_memsz; - void *segment_start = (void*)((uptr)address + p->p_vaddr); - void *segment_end = (void*)((uptr)address + p->p_vaddr + segment_size); - - void *strtab = 0; - void *symtab = 0; - size_t strsz = 0; - size_t syment = 0; - void *rela = 0; - size_t relasz = 0; - size_t relaent = 0; - size_t relacount = 0; - void *pltgot = 0; - size_t pltrelsz = 0; - void *pltrel = 0; - - if (p->p_type == PT_DYNAMIC) { - Elf64_Dyn *dyn = (Elf64_Dyn *)segment_start; - while ((void *)dyn < segment_end) { - log_debug(" dyn type: %s\n", hi_elf_dyntostr(dyn->d_tag)); - - if (dyn->d_tag == DT_STRTAB) { - strtab = (void *)dyn->d_un.d_ptr; - } else if (dyn->d_tag == DT_SYMTAB) { - symtab = (void *)dyn->d_un.d_ptr; - } else if (dyn->d_tag == DT_STRSZ) { - strsz = dyn->d_un.d_val; - } else if (dyn->d_tag == DT_SYMENT) { - syment = dyn->d_un.d_val; - } else if (dyn->d_tag == DT_RELA) { - rela = (void *)dyn->d_un.d_ptr; - } else if (dyn->d_tag == DT_RELASZ) { - relasz = dyn->d_un.d_val; - } else if (dyn->d_tag == DT_RELAENT) { - relaent = dyn->d_un.d_val; - } else if (dyn->d_tag == DT_RELACOUNT) { - relacount = dyn->d_un.d_val; - } else if (dyn->d_tag == DT_PLTGOT) { - pltgot = (void *)dyn->d_un.d_ptr; - } else if (dyn->d_tag == DT_PLTRELSZ) { - pltrelsz = dyn->d_un.d_val; - } else if (dyn->d_tag == DT_PLTREL) { - pltrel = (void *)dyn->d_un.d_val; - } - ++dyn; - } - log_debug("\nstrtab: %p\n" - "symtab: %p\n" - "strsz: %zu\n" - "syment: %zu\n" - "rela: %p\n" - "relasz: %zu\n" - "relaent: %zu\n" - "relacount: %zu\n" - "pltgot: %p\n" - "pltrelsz: %zu\n" - "pltrel: %p\n", - strtab, symtab, strsz, syment, rela, relasz, relaent, - relacount, pltgot, pltrelsz, pltrel); + if (phdr.p_type == PT_DYNAMIC) { + dyn_addr = (void *)((uptr)module_base + phdr.p_vaddr); + break; } } + + if (elf_ret) { + *elf_ret = elf; + } else { + elf_end(elf); + } + + return dyn_addr; } #define TYPE_TO_STR(type) \ diff --git a/src/hielf.h b/src/hielf.h index 4a6d9f7..9b9db22 100644 --- a/src/hielf.h +++ b/src/hielf.h @@ -1,11 +1,10 @@ #pragma once #include -#include #include +#include -void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf); +void *hi_elf_find_dynseg(void *module_base, size_t n, Elf **elf); const char *hi_elf_dyntostr(unsigned type); const char *hi_elf_segtostr(unsigned type); - diff --git a/src/hiload.c b/src/hiload.c index b11345d..e7bef7e 100644 --- a/src/hiload.c +++ b/src/hiload.c @@ -6,7 +6,6 @@ #include "logger.h" #include "memmap.h" #include "moduler.h" -#include "symbols.h" #include "vector.h" #include @@ -66,9 +65,6 @@ static int initial_gather_callback(struct dl_phdr_info *info, size_t size, module.original_name = strdup(modpath); module.name = module.original_name; - log_debugv("dli_fname: %s\n", dl_info.dli_fname); - log_debugv("dli_sname: %s\n", dl_info.dli_sname); - } else { // I don't know when this could happen since we're passing the info straight // from the info struct @@ -161,7 +157,7 @@ static ModuleData *module_get(const char *path, for (size_t i = 0; i < vector_size(modules); i++) { ModuleData *module = &vector_at(modules, i); - const char *name = module->name; + const char *name = module->original_name; if (strcmp(name, path) == 0) { return module; } @@ -232,10 +228,13 @@ int hi_init(size_t n, const char **enabled_modules) { return 1; } + moduler_init(&context.modules); + return 0; } void hi_deinit() { + moduler_deinit(); module_data_free(&context.modules); filewatcher_destroy(context.filewatcher); log_term(); diff --git a/src/moduler.c b/src/moduler.c index 18bb771..3275973 100644 --- a/src/moduler.c +++ b/src/moduler.c @@ -8,6 +8,7 @@ #include "memmap.h" #include "symbols.h" #include "types.h" +#include "vector.h" #include #include @@ -25,6 +26,28 @@ typedef struct { MemorySpan memreg; } PatchData; +typedef struct { + /// The original module name + const char *modname; + /// The associated symbols for that module + VectorSymbol symbols; +} SymbolData; + +vector_def(SymbolData, SymbolData); + +// TODO: Stop being lazy and put this with the rest of the data +static VectorSymbolData symbol_data; + +static VectorSymbol *symdat_get(const char *modname) { + for (size_t i = 0; i < vector_size(&symbol_data); ++i) { + SymbolData *symdat = &vector_at(&symbol_data, i); + if (strcmp(modname, symdat->modname) == 0) { + return &symdat->symbols; + } + } + return NULL; +} + static void *adjust_if_relative(void *ptr, void *module_base) { uptr p = (uptr)ptr; if (p && (p < (uptr)module_base)) { @@ -36,7 +59,7 @@ static void *adjust_if_relative(void *ptr, void *module_base) { static HiResult moduler_collect_symbols(VectorSymbol *symbols, const char *module_name, void *module_base) { - symbol_clear(symbols); + symbols_clear(symbols); HiResult ret = HI_FAIL; @@ -181,10 +204,9 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols, size_t relasz = 0; Elf *elf = NULL; - ElfW(Dyn) *dyn_sct = - hi_elf_find_dynamic_segment(module_base, module_size, &elf); + ElfW(Dyn) *dyn = hi_elf_find_dynseg(module_base, module_size, &elf); - for (ElfW(Dyn) *d = dyn_sct; d->d_tag != DT_NULL; d++) { + for (ElfW(Dyn) *d = dyn; d->d_tag != DT_NULL; d++) { log_debugv("%s\n", hi_elf_dyntostr(d->d_tag)); switch (d->d_tag) { @@ -247,12 +269,12 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols, // Check if this is a symbol we want to patch for (size_t j = 0; j < num_symbols; j++) { - Symbol *sym = &vector_at(psymbols, j); - if (strcmp(sym->name, name) == 0) { - sym->got_entry = got_entry; - sym->orig_address = *got_entry; // Save the original function + Symbol *s = &vector_at(psymbols, j); + if (strcmp(s->name, name) == 0) { + s->got_entry = got_entry; + s->orig_address = *got_entry; // Save the original function - *got_entry = sym->address; + *got_entry = s->address; log_debug("Found GOT entry for '%s' at %p (points to %p)\n", name, got_entry, *got_entry); @@ -280,15 +302,15 @@ static PatchData moduler_create_patch(ModuleData *module) { } char filename[512]; - size_t written = - string_concat_buf(sizeof(filename), filename, module->name, file_append); + size_t written = string_concat_buf(sizeof(filename), filename, + module->original_name, file_append); if (written == 0) { - log_error("Failed to concat %s and %s\n", module->name, ".patch"); + log_error("Failed to concat %s and %s\n", module->original_name, ".patch"); return (PatchData){0}; } - file_copy(module->name, filename); + file_copy(module->original_name, filename); PatchData data = {.filename = strdup(filename)}; return data; @@ -313,7 +335,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) { // Load patch dlerror(); // clear any previous errors - log_debug("Opening: %s\n", patch.filename); + log_debug("Opening patch: %s\n", patch.filename); void *new_handle = dlopen(patch.filename, RTLD_LAZY); if (!new_handle) { log_error("Couldn't load: %s\n", dlerror()); @@ -333,7 +355,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) { void *patch_base = (void *)patch.memreg.start; VectorSymbol patch_symbols; - symbol_init(&patch_symbols); + symbols_init(&patch_symbols); HiResult ret = moduler_collect_symbols(&patch_symbols, patch.filename, patch_base); @@ -353,44 +375,72 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) { moduler_patch_functions(&patch_symbols, module_memory); // This relies on the patch filename being different only by the append - if (strncmp(mod.name, patch.filename, strlen(mod.name)) == 0) { + if (strncmp(mod.original_name, patch.filename, strlen(mod.original_name)) == + 0) { + // If patch is for the same module, also collect local object symbols for // coping those over. - - VectorSymbol module_symbols; - symbol_init(&module_symbols); - ret = moduler_collect_symbols(&module_symbols, mod.name, - (void *)mod.address); - if (!HIOK(ret)) { - log_error("Failed to gather symbols for %s\n", mod.name); - symbol_term(&module_symbols); - continue; - } + VectorSymbol *module_symbols = symdat_get(mod.original_name); // Copy old data to new data location. Breaks if layout changes, // e.g. struct fields are moved around - for (size_t j = 0; j < vector_size(&module_symbols); ++j) { - Symbol *sym = &vector_at(&module_symbols, j); - if (sym->type == HI_SYMBOL_TYPE_OBJECT) { - Symbol *ps = symbol_find(&patch_symbols, sym); - if (ps) { - size_t copy_size = MIN(sym->size, ps->size); - memcpy(ps->address, sym->address, copy_size); - log_debug("Copied data for symbol: %s\n", sym->name); + // Also update our symbol cache for the module. We need to add new + // symbols, but don't care much about deleting old ones + for (size_t j = 0; j < vector_size(&patch_symbols); ++j) { + Symbol *psym = &vector_at(&patch_symbols, j); + if (psym->type == HI_SYMBOL_TYPE_OBJECT) { + Symbol *msym = symbol_find(module_symbols, psym); + if (msym) { + size_t copy_size = MIN(psym->size, msym->size); + memcpy(psym->address, msym->address, copy_size); + log_debug("Copied data for symbol: %s\n", psym->name); + + // Maintain our current symbol references with the new address + msym->address = psym->address; + } else { + // A new symbol has been added in the patch + symbols_add(module_symbols, symbol_copy(psym)); + log_debug("Found new symbol from patch: %s\n", psym->name); } } } - symbol_term(&module_symbols); } - - module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY); } + symbols_term(&patch_symbols); - free((char *)patch.filename); - symbol_term(&patch_symbols); - + module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY); dlclose(module->dlhandle); + + // Update module reference + if (module->name != module->original_name) + free((char *)module->name); + module->name = patch.filename; + module->address = patch.memreg.start; module->dlhandle = patch.dlhandle; return HI_OK; } + +void moduler_init(const VectorModuleData *modules) { + vector_init(&symbol_data); + + ModuleData mod = {0}; + vector_foreach(modules, mod) { + SymbolData symdata = {0}; + symdata.modname = mod.original_name; + vector_add(&symbol_data, symdata); + + if (modinfo_has(mod.info, HI_MODULE_STATE_PATCHABLE)) { + moduler_collect_symbols(&vector_last(&symbol_data).symbols, + mod.original_name, (void *)mod.address); + } + + vector_init(&symdata.symbols); + } +} + +void moduler_deinit() { + SymbolData symdata = {0}; + vector_foreach(&symbol_data, symdata) { symbols_term(&symdata.symbols); } + vector_term(&symbol_data); +} diff --git a/src/moduler.h b/src/moduler.h index 098dbed..978b52f 100644 --- a/src/moduler.h +++ b/src/moduler.h @@ -45,3 +45,5 @@ static inline bool modinfo_has(ModuleInfo flags, ModuleFlags flag) { #define HI_MODINFO_CLEAR(info, flag) ((info) &= ~flag) HiResult moduler_reload(VectorModuleData *modules, size_t modindx); +void moduler_init(const VectorModuleData *modules); +void moduler_deinit(void); diff --git a/src/symbols.h b/src/symbols.h index 8b402e6..8c4228e 100644 --- a/src/symbols.h +++ b/src/symbols.h @@ -33,23 +33,29 @@ vector_def(Symbol, Symbol); static inline void symbol_free(Symbol *symbol) { free((char *)symbol->name); } -static inline void symbol_init(VectorSymbol *symbols) { vector_init(symbols); } -static inline void symbol_clear(VectorSymbol *symbols) { +static inline void symbols_init(VectorSymbol *symbols) { vector_init(symbols); } +static inline void symbols_clear(VectorSymbol *symbols) { for (size_t i = 0; i < vector_size(symbols); ++i) { symbol_free(&vector_at(symbols, i)); } vector_clear(symbols); } -static inline void symbol_term(VectorSymbol *symbols) { - symbol_clear(symbols); +static inline void symbols_term(VectorSymbol *symbols) { + symbols_clear(symbols); vector_term(symbols); } -static inline void symbol_add(VectorSymbol *symbols, Symbol symbol) { +static inline void symbols_add(VectorSymbol *symbols, Symbol symbol) { vector_add(symbols, symbol); } +static inline Symbol symbol_copy(const Symbol *symbol) { + Symbol sym = *symbol; + sym.name = strdup(sym.name); + return sym; +} + Symbol *symbol_find(VectorSymbol *symbols, Symbol *symbol); SymbolBind symbol_bind_from_efibind(u32 efi_bind);