preload symbols and use the data instead

This commit is contained in:
2025-05-05 23:17:18 +03:00
parent 7c1145617a
commit b5246d2eb8
6 changed files with 139 additions and 156 deletions

View File

@@ -8,9 +8,15 @@
#include <libelf.h>
#include <unistd.h>
void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) {
void *hi_elf_find_dynseg(void *module_base, size_t n, Elf **elf_ret) {
Elf *elf = elf_memory(module_base, n);
Elf *elf = NULL;
if (elf_ret && *elf_ret) {
elf = *elf_ret;
} else {
elf = elf_memory(module_base, n);
}
size_t phdrnum = 0;
int err = elf_getphdrnum(elf, &phdrnum);
@@ -21,15 +27,16 @@ void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) {
void *dyn_addr = NULL;
for (size_t i=0; i < phdrnum; ++i) {
for (size_t i = 0; i < phdrnum; ++i) {
GElf_Phdr phdr;
if (gelf_getphdr(elf, i, &phdr) != &phdr) {
log_error("Failed to find program headers: %s\n", elf_errmsg(elf_errno()));
log_error("Failed to find program headers: %s\n",
elf_errmsg(elf_errno()));
continue;
}
if (phdr.p_type == PT_DYNAMIC) {
dyn_addr = (void*)((uptr)module_base + phdr.p_vaddr);
dyn_addr = (void *)((uptr)module_base + phdr.p_vaddr);
break;
}
}
@@ -43,86 +50,6 @@ void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) {
return dyn_addr;
}
void hi_elf_print_module_from_memory(void *address, size_t size) {
Elf *elf = elf_memory(address, size);
int err = elf_errno();
if (err) {
log_error("%s\n", elf_errmsg(err));
return;
}
size_t phdrnum = 0;
err = elf_getphdrnum(elf, &phdrnum);
Elf64_Phdr *phdr = elf64_getphdr(elf);
for (size_t i = 0; i < phdrnum; ++i) {
Elf64_Phdr *p = phdr + i;
log_debug("segment type: %s\n", hi_elf_segtostr(p->p_type));
size_t segment_size = p->p_memsz;
void *segment_start = (void*)((uptr)address + p->p_vaddr);
void *segment_end = (void*)((uptr)address + p->p_vaddr + segment_size);
void *strtab = 0;
void *symtab = 0;
size_t strsz = 0;
size_t syment = 0;
void *rela = 0;
size_t relasz = 0;
size_t relaent = 0;
size_t relacount = 0;
void *pltgot = 0;
size_t pltrelsz = 0;
void *pltrel = 0;
if (p->p_type == PT_DYNAMIC) {
Elf64_Dyn *dyn = (Elf64_Dyn *)segment_start;
while ((void *)dyn < segment_end) {
log_debug(" dyn type: %s\n", hi_elf_dyntostr(dyn->d_tag));
if (dyn->d_tag == DT_STRTAB) {
strtab = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_SYMTAB) {
symtab = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_STRSZ) {
strsz = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_SYMENT) {
syment = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_RELA) {
rela = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_RELASZ) {
relasz = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_RELAENT) {
relaent = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_RELACOUNT) {
relacount = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_PLTGOT) {
pltgot = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_PLTRELSZ) {
pltrelsz = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_PLTREL) {
pltrel = (void *)dyn->d_un.d_val;
}
++dyn;
}
log_debug("\nstrtab: %p\n"
"symtab: %p\n"
"strsz: %zu\n"
"syment: %zu\n"
"rela: %p\n"
"relasz: %zu\n"
"relaent: %zu\n"
"relacount: %zu\n"
"pltgot: %p\n"
"pltrelsz: %zu\n"
"pltrel: %p\n",
strtab, symtab, strsz, syment, rela, relasz, relaent,
relacount, pltgot, pltrelsz, pltrel);
}
}
}
// Expands to a `case` label for `type` whose body returns the macro argument
// spelled out as a string literal (via the `#` stringification operator).
// Only valid inside a `switch` in a function that returns a string;
// presumably used by the *tostr helpers in this file -- confirm at the use
// sites. Note there is no trailing semicolon: the caller supplies it.
#define TYPE_TO_STR(type) \
case (type): \
return #type

View File

@@ -1,11 +1,10 @@
#pragma once
#include <elf.h>
#include <libelf.h>
#include <gelf.h>
#include <libelf.h>
void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf);
void *hi_elf_find_dynseg(void *module_base, size_t n, Elf **elf);
const char *hi_elf_dyntostr(unsigned type);
const char *hi_elf_segtostr(unsigned type);

View File

@@ -6,7 +6,6 @@
#include "logger.h"
#include "memmap.h"
#include "moduler.h"
#include "symbols.h"
#include "vector.h"
#include <assert.h>
@@ -66,9 +65,6 @@ static int initial_gather_callback(struct dl_phdr_info *info, size_t size,
module.original_name = strdup(modpath);
module.name = module.original_name;
log_debugv("dli_fname: %s\n", dl_info.dli_fname);
log_debugv("dli_sname: %s\n", dl_info.dli_sname);
} else {
// I don't know when this could happen since we're passing the info straight
// from the info struct
@@ -161,7 +157,7 @@ static ModuleData *module_get(const char *path,
for (size_t i = 0; i < vector_size(modules); i++) {
ModuleData *module = &vector_at(modules, i);
const char *name = module->name;
const char *name = module->original_name;
if (strcmp(name, path) == 0) {
return module;
}
@@ -232,10 +228,13 @@ int hi_init(size_t n, const char **enabled_modules) {
return 1;
}
moduler_init(&context.modules);
return 0;
}
void hi_deinit() {
moduler_deinit();
module_data_free(&context.modules);
filewatcher_destroy(context.filewatcher);
log_term();

View File

@@ -8,6 +8,7 @@
#include "memmap.h"
#include "symbols.h"
#include "types.h"
#include "vector.h"
#include <dlfcn.h>
#include <errno.h>
@@ -25,6 +26,28 @@ typedef struct {
MemorySpan memreg;
} PatchData;
/// Cache entry pairing one module with the symbols collected for it.
typedef struct {
/// The original module name. Used as the lookup key (compared with strcmp).
/// NOTE(review): appears to be a borrowed pointer to the module's
/// original_name rather than an owned copy -- confirm before freeing it.
const char *modname;
/// The associated symbols for that module
VectorSymbol symbols;
} SymbolData;
vector_def(SymbolData, SymbolData);
// TODO: Stop being lazy and put this with the rest of the data
static VectorSymbolData symbol_data;
/// Looks up the cached symbols of the module called `modname`.
///
/// Performs a linear scan over `symbol_data`, comparing names with strcmp.
///
/// @param modname Module name to search for.
/// @return Pointer to the matching entry's symbol vector, or NULL when no
///         entry has that name.
static VectorSymbol *symdat_get(const char *modname) {
  VectorSymbol *found = NULL;
  size_t idx = 0;

  while (found == NULL && idx < vector_size(&symbol_data)) {
    SymbolData *entry = &vector_at(&symbol_data, idx);
    if (strcmp(modname, entry->modname) == 0) {
      found = &entry->symbols;
    }
    idx += 1;
  }

  return found;
}
static void *adjust_if_relative(void *ptr, void *module_base) {
uptr p = (uptr)ptr;
if (p && (p < (uptr)module_base)) {
@@ -36,7 +59,7 @@ static void *adjust_if_relative(void *ptr, void *module_base) {
static HiResult moduler_collect_symbols(VectorSymbol *symbols,
const char *module_name,
void *module_base) {
symbol_clear(symbols);
symbols_clear(symbols);
HiResult ret = HI_FAIL;
@@ -181,10 +204,9 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols,
size_t relasz = 0;
Elf *elf = NULL;
ElfW(Dyn) *dyn_sct =
hi_elf_find_dynamic_segment(module_base, module_size, &elf);
ElfW(Dyn) *dyn = hi_elf_find_dynseg(module_base, module_size, &elf);
for (ElfW(Dyn) *d = dyn_sct; d->d_tag != DT_NULL; d++) {
for (ElfW(Dyn) *d = dyn; d->d_tag != DT_NULL; d++) {
log_debugv("%s\n", hi_elf_dyntostr(d->d_tag));
switch (d->d_tag) {
@@ -247,12 +269,12 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols,
// Check if this is a symbol we want to patch
for (size_t j = 0; j < num_symbols; j++) {
Symbol *sym = &vector_at(psymbols, j);
if (strcmp(sym->name, name) == 0) {
sym->got_entry = got_entry;
sym->orig_address = *got_entry; // Save the original function
Symbol *s = &vector_at(psymbols, j);
if (strcmp(s->name, name) == 0) {
s->got_entry = got_entry;
s->orig_address = *got_entry; // Save the original function
*got_entry = sym->address;
*got_entry = s->address;
log_debug("Found GOT entry for '%s' at %p (points to %p)\n", name,
got_entry, *got_entry);
@@ -280,15 +302,15 @@ static PatchData moduler_create_patch(ModuleData *module) {
}
char filename[512];
size_t written =
string_concat_buf(sizeof(filename), filename, module->name, file_append);
size_t written = string_concat_buf(sizeof(filename), filename,
module->original_name, file_append);
if (written == 0) {
log_error("Failed to concat %s and %s\n", module->name, ".patch");
log_error("Failed to concat %s and %s\n", module->original_name, ".patch");
return (PatchData){0};
}
file_copy(module->name, filename);
file_copy(module->original_name, filename);
PatchData data = {.filename = strdup(filename)};
return data;
@@ -313,7 +335,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
// Load patch
dlerror(); // clear any previous errors
log_debug("Opening: %s\n", patch.filename);
log_debug("Opening patch: %s\n", patch.filename);
void *new_handle = dlopen(patch.filename, RTLD_LAZY);
if (!new_handle) {
log_error("Couldn't load: %s\n", dlerror());
@@ -333,7 +355,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
void *patch_base = (void *)patch.memreg.start;
VectorSymbol patch_symbols;
symbol_init(&patch_symbols);
symbols_init(&patch_symbols);
HiResult ret =
moduler_collect_symbols(&patch_symbols, patch.filename, patch_base);
@@ -353,44 +375,72 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
moduler_patch_functions(&patch_symbols, module_memory);
// This relies on the patch filename being different only by the append
if (strncmp(mod.name, patch.filename, strlen(mod.name)) == 0) {
if (strncmp(mod.original_name, patch.filename, strlen(mod.original_name)) ==
0) {
// If patch is for the same module, also collect local object symbols for
// copying those over.
VectorSymbol module_symbols;
symbol_init(&module_symbols);
ret = moduler_collect_symbols(&module_symbols, mod.name,
(void *)mod.address);
if (!HIOK(ret)) {
log_error("Failed to gather symbols for %s\n", mod.name);
symbol_term(&module_symbols);
continue;
}
VectorSymbol *module_symbols = symdat_get(mod.original_name);
// Copy old data to new data location. Breaks if layout changes,
// e.g. struct fields are moved around
for (size_t j = 0; j < vector_size(&module_symbols); ++j) {
Symbol *sym = &vector_at(&module_symbols, j);
if (sym->type == HI_SYMBOL_TYPE_OBJECT) {
Symbol *ps = symbol_find(&patch_symbols, sym);
if (ps) {
size_t copy_size = MIN(sym->size, ps->size);
memcpy(ps->address, sym->address, copy_size);
log_debug("Copied data for symbol: %s\n", sym->name);
// Also update our symbol cache for the module. We need to add new
// symbols, but don't care much about deleting old ones
for (size_t j = 0; j < vector_size(&patch_symbols); ++j) {
Symbol *psym = &vector_at(&patch_symbols, j);
if (psym->type == HI_SYMBOL_TYPE_OBJECT) {
Symbol *msym = symbol_find(module_symbols, psym);
if (msym) {
size_t copy_size = MIN(psym->size, msym->size);
memcpy(psym->address, msym->address, copy_size);
log_debug("Copied data for symbol: %s\n", psym->name);
// Maintain our current symbol references with the new address
msym->address = psym->address;
} else {
// A new symbol has been added in the patch
symbols_add(module_symbols, symbol_copy(psym));
log_debug("Found new symbol from patch: %s\n", psym->name);
}
}
}
symbol_term(&module_symbols);
}
}
symbols_term(&patch_symbols);
module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY);
}
free((char *)patch.filename);
symbol_term(&patch_symbols);
dlclose(module->dlhandle);
// Update module reference
if (module->name != module->original_name)
free((char *)module->name);
module->name = patch.filename;
module->address = patch.memreg.start;
module->dlhandle = patch.dlhandle;
return HI_OK;
}
/// Builds the per-module symbol cache.
///
/// Adds one SymbolData entry to `symbol_data` for every module in `modules`,
/// keyed by the module's original name. For modules flagged
/// HI_MODULE_STATE_PATCHABLE the entry's symbols are collected immediately.
///
/// @param modules Modules to create cache entries for.
void moduler_init(const VectorModuleData *modules) {
  vector_init(&symbol_data);

  ModuleData mod = {0};
  vector_foreach(modules, mod) {
    SymbolData symdata = {0};
    symdata.modname = mod.original_name;
    // Initialize the symbol vector BEFORE adding the entry: `symdata` is
    // handed to vector_add by value, so initializing the local afterwards
    // would never reach the stored element that gets filled below.
    vector_init(&symdata.symbols);
    vector_add(&symbol_data, symdata);

    if (modinfo_has(mod.info, HI_MODULE_STATE_PATCHABLE)) {
      moduler_collect_symbols(&vector_last(&symbol_data).symbols,
                              mod.original_name, (void *)mod.address);
    }
  }
}
void moduler_deinit() {
SymbolData symdata = {0};
vector_foreach(&symbol_data, symdata) { symbols_term(&symdata.symbols); }
vector_term(&symbol_data);
}

View File

@@ -45,3 +45,5 @@ static inline bool modinfo_has(ModuleInfo flags, ModuleFlags flag) {
// Clears the bits of `flag` in the lvalue `info` and yields the new value.
// `flag` is parenthesized so that compound arguments such as `A | B` are
// complemented as a whole rather than expanding to `~A | B`.
#define HI_MODINFO_CLEAR(info, flag) ((info) &= ~(flag))
HiResult moduler_reload(VectorModuleData *modules, size_t modindx);
void moduler_init(const VectorModuleData *modules);
void moduler_deinit(void);

View File

@@ -33,23 +33,29 @@ vector_def(Symbol, Symbol);
/// Frees the name string of `symbol`. The Symbol object itself is not freed.
static inline void symbol_free(Symbol *symbol) {
  char *owned_name = (char *)symbol->name;
  free(owned_name);
}
static inline void symbol_init(VectorSymbol *symbols) { vector_init(symbols); }
static inline void symbol_clear(VectorSymbol *symbols) {
/// Prepares `symbols` for use by forwarding to vector_init.
static inline void symbols_init(VectorSymbol *symbols) {
  vector_init(symbols);
}
/// Frees the name of every symbol stored in `symbols` and then clears the
/// vector via vector_clear.
static inline void symbols_clear(VectorSymbol *symbols) {
  size_t idx = 0;
  while (idx < vector_size(symbols)) {
    Symbol *entry = &vector_at(symbols, idx);
    symbol_free(entry);
    idx += 1;
  }

  vector_clear(symbols);
}
static inline void symbol_term(VectorSymbol *symbols) {
symbol_clear(symbols);
/// Destroys `symbols`: first clears it with symbols_clear (which frees the
/// stored symbol names), then terminates the vector with vector_term.
static inline void symbols_term(VectorSymbol *symbols) {
  symbols_clear(symbols); // drop per-symbol data first

  vector_term(symbols);
}
static inline void symbol_add(VectorSymbol *symbols, Symbol symbol) {
/// Adds `symbol` to `symbols` by forwarding to vector_add. The symbol is
/// passed by value, so its `name` pointer is stored as-is (not duplicated).
static inline void symbols_add(VectorSymbol *symbols, Symbol symbol)
{
  vector_add(symbols, symbol);
}
/// Returns a copy of `*symbol` whose `name` is a fresh strdup() of the
/// original name; every other field is copied verbatim.
static inline Symbol symbol_copy(const Symbol *symbol) {
  Symbol duplicate = *symbol;
  char *name_copy = strdup(symbol->name);
  duplicate.name = name_copy;
  return duplicate;
}
Symbol *symbol_find(VectorSymbol *symbols, Symbol *symbol);
SymbolBind symbol_bind_from_efibind(u32 efi_bind);