preload symbols and use the data instead

This commit is contained in:
2025-05-05 23:17:18 +03:00
parent 7c1145617a
commit b5246d2eb8
6 changed files with 139 additions and 156 deletions

View File

@@ -8,6 +8,7 @@
#include "memmap.h"
#include "symbols.h"
#include "types.h"
#include "vector.h"
#include <dlfcn.h>
#include <errno.h>
@@ -25,6 +26,28 @@ typedef struct {
MemorySpan memreg;
} PatchData;
typedef struct {
/// The original module name
const char *modname;
/// The associated symbols for that module
VectorSymbol symbols;
} SymbolData;
vector_def(SymbolData, SymbolData);
// TODO: Stop being lazy and put this with the rest of the data
static VectorSymbolData symbol_data;
static VectorSymbol *symdat_get(const char *modname) {
for (size_t i = 0; i < vector_size(&symbol_data); ++i) {
SymbolData *symdat = &vector_at(&symbol_data, i);
if (strcmp(modname, symdat->modname) == 0) {
return &symdat->symbols;
}
}
return NULL;
}
static void *adjust_if_relative(void *ptr, void *module_base) {
uptr p = (uptr)ptr;
if (p && (p < (uptr)module_base)) {
@@ -36,7 +59,7 @@ static void *adjust_if_relative(void *ptr, void *module_base) {
static HiResult moduler_collect_symbols(VectorSymbol *symbols,
const char *module_name,
void *module_base) {
symbol_clear(symbols);
symbols_clear(symbols);
HiResult ret = HI_FAIL;
@@ -181,10 +204,9 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols,
size_t relasz = 0;
Elf *elf = NULL;
ElfW(Dyn) *dyn_sct =
hi_elf_find_dynamic_segment(module_base, module_size, &elf);
ElfW(Dyn) *dyn = hi_elf_find_dynseg(module_base, module_size, &elf);
for (ElfW(Dyn) *d = dyn_sct; d->d_tag != DT_NULL; d++) {
for (ElfW(Dyn) *d = dyn; d->d_tag != DT_NULL; d++) {
log_debugv("%s\n", hi_elf_dyntostr(d->d_tag));
switch (d->d_tag) {
@@ -247,12 +269,12 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols,
// Check if this is a symbol we want to patch
for (size_t j = 0; j < num_symbols; j++) {
Symbol *sym = &vector_at(psymbols, j);
if (strcmp(sym->name, name) == 0) {
sym->got_entry = got_entry;
sym->orig_address = *got_entry; // Save the original function
Symbol *s = &vector_at(psymbols, j);
if (strcmp(s->name, name) == 0) {
s->got_entry = got_entry;
s->orig_address = *got_entry; // Save the original function
*got_entry = sym->address;
*got_entry = s->address;
log_debug("Found GOT entry for '%s' at %p (points to %p)\n", name,
got_entry, *got_entry);
@@ -280,15 +302,15 @@ static PatchData moduler_create_patch(ModuleData *module) {
}
char filename[512];
size_t written =
string_concat_buf(sizeof(filename), filename, module->name, file_append);
size_t written = string_concat_buf(sizeof(filename), filename,
module->original_name, file_append);
if (written == 0) {
log_error("Failed to concat %s and %s\n", module->name, ".patch");
log_error("Failed to concat %s and %s\n", module->original_name, ".patch");
return (PatchData){0};
}
file_copy(module->name, filename);
file_copy(module->original_name, filename);
PatchData data = {.filename = strdup(filename)};
return data;
@@ -313,7 +335,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
// Load patch
dlerror(); // clear any previous errors
log_debug("Opening: %s\n", patch.filename);
log_debug("Opening patch: %s\n", patch.filename);
void *new_handle = dlopen(patch.filename, RTLD_LAZY);
if (!new_handle) {
log_error("Couldn't load: %s\n", dlerror());
@@ -333,7 +355,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
void *patch_base = (void *)patch.memreg.start;
VectorSymbol patch_symbols;
symbol_init(&patch_symbols);
symbols_init(&patch_symbols);
HiResult ret =
moduler_collect_symbols(&patch_symbols, patch.filename, patch_base);
@@ -353,44 +375,72 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
moduler_patch_functions(&patch_symbols, module_memory);
// This relies on the patch filename being different only by the append
if (strncmp(mod.name, patch.filename, strlen(mod.name)) == 0) {
if (strncmp(mod.original_name, patch.filename, strlen(mod.original_name)) ==
0) {
// If patch is for the same module, also collect local object symbols for
// coping those over.
VectorSymbol module_symbols;
symbol_init(&module_symbols);
ret = moduler_collect_symbols(&module_symbols, mod.name,
(void *)mod.address);
if (!HIOK(ret)) {
log_error("Failed to gather symbols for %s\n", mod.name);
symbol_term(&module_symbols);
continue;
}
VectorSymbol *module_symbols = symdat_get(mod.original_name);
// Copy old data to new data location. Breaks if layout changes,
// e.g. struct fields are moved around
for (size_t j = 0; j < vector_size(&module_symbols); ++j) {
Symbol *sym = &vector_at(&module_symbols, j);
if (sym->type == HI_SYMBOL_TYPE_OBJECT) {
Symbol *ps = symbol_find(&patch_symbols, sym);
if (ps) {
size_t copy_size = MIN(sym->size, ps->size);
memcpy(ps->address, sym->address, copy_size);
log_debug("Copied data for symbol: %s\n", sym->name);
// Also update our symbol cache for the module. We need to add new
// symbols, but don't care much about deleting old ones
for (size_t j = 0; j < vector_size(&patch_symbols); ++j) {
Symbol *psym = &vector_at(&patch_symbols, j);
if (psym->type == HI_SYMBOL_TYPE_OBJECT) {
Symbol *msym = symbol_find(module_symbols, psym);
if (msym) {
size_t copy_size = MIN(psym->size, msym->size);
memcpy(psym->address, msym->address, copy_size);
log_debug("Copied data for symbol: %s\n", psym->name);
// Maintain our current symbol references with the new address
msym->address = psym->address;
} else {
// A new symbol has been added in the patch
symbols_add(module_symbols, symbol_copy(psym));
log_debug("Found new symbol from patch: %s\n", psym->name);
}
}
}
symbol_term(&module_symbols);
}
module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY);
}
symbols_term(&patch_symbols);
free((char *)patch.filename);
symbol_term(&patch_symbols);
module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY);
dlclose(module->dlhandle);
// Update module reference
if (module->name != module->original_name)
free((char *)module->name);
module->name = patch.filename;
module->address = patch.memreg.start;
module->dlhandle = patch.dlhandle;
return HI_OK;
}
void moduler_init(const VectorModuleData *modules) {
vector_init(&symbol_data);
ModuleData mod = {0};
vector_foreach(modules, mod) {
SymbolData symdata = {0};
symdata.modname = mod.original_name;
vector_add(&symbol_data, symdata);
if (modinfo_has(mod.info, HI_MODULE_STATE_PATCHABLE)) {
moduler_collect_symbols(&vector_last(&symbol_data).symbols,
mod.original_name, (void *)mod.address);
}
vector_init(&symdata.symbols);
}
}
void moduler_deinit() {
SymbolData symdata = {0};
vector_foreach(&symbol_data, symdata) { symbols_term(&symdata.symbols); }
vector_term(&symbol_data);
}