preload symbols and use the data instead

This commit is contained in:
2025-05-05 23:17:18 +03:00
parent 7c1145617a
commit b5246d2eb8
6 changed files with 139 additions and 156 deletions

View File

@@ -8,9 +8,15 @@
#include <libelf.h> #include <libelf.h>
#include <unistd.h> #include <unistd.h>
void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) { void *hi_elf_find_dynseg(void *module_base, size_t n, Elf **elf_ret) {
Elf *elf = elf_memory(module_base, n); Elf *elf = NULL;
if (elf_ret && *elf_ret) {
elf = *elf_ret;
} else {
elf = elf_memory(module_base, n);
}
size_t phdrnum = 0; size_t phdrnum = 0;
int err = elf_getphdrnum(elf, &phdrnum); int err = elf_getphdrnum(elf, &phdrnum);
@@ -24,7 +30,8 @@ void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) {
for (size_t i = 0; i < phdrnum; ++i) { for (size_t i = 0; i < phdrnum; ++i) {
GElf_Phdr phdr; GElf_Phdr phdr;
if (gelf_getphdr(elf, i, &phdr) != &phdr) { if (gelf_getphdr(elf, i, &phdr) != &phdr) {
log_error("Failed to find program headers: %s\n", elf_errmsg(elf_errno())); log_error("Failed to find program headers: %s\n",
elf_errmsg(elf_errno()));
continue; continue;
} }
@@ -43,86 +50,6 @@ void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf_ret) {
return dyn_addr; return dyn_addr;
} }
void hi_elf_print_module_from_memory(void *address, size_t size) {
Elf *elf = elf_memory(address, size);
int err = elf_errno();
if (err) {
log_error("%s\n", elf_errmsg(err));
return;
}
size_t phdrnum = 0;
err = elf_getphdrnum(elf, &phdrnum);
Elf64_Phdr *phdr = elf64_getphdr(elf);
for (size_t i = 0; i < phdrnum; ++i) {
Elf64_Phdr *p = phdr + i;
log_debug("segment type: %s\n", hi_elf_segtostr(p->p_type));
size_t segment_size = p->p_memsz;
void *segment_start = (void*)((uptr)address + p->p_vaddr);
void *segment_end = (void*)((uptr)address + p->p_vaddr + segment_size);
void *strtab = 0;
void *symtab = 0;
size_t strsz = 0;
size_t syment = 0;
void *rela = 0;
size_t relasz = 0;
size_t relaent = 0;
size_t relacount = 0;
void *pltgot = 0;
size_t pltrelsz = 0;
void *pltrel = 0;
if (p->p_type == PT_DYNAMIC) {
Elf64_Dyn *dyn = (Elf64_Dyn *)segment_start;
while ((void *)dyn < segment_end) {
log_debug(" dyn type: %s\n", hi_elf_dyntostr(dyn->d_tag));
if (dyn->d_tag == DT_STRTAB) {
strtab = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_SYMTAB) {
symtab = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_STRSZ) {
strsz = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_SYMENT) {
syment = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_RELA) {
rela = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_RELASZ) {
relasz = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_RELAENT) {
relaent = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_RELACOUNT) {
relacount = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_PLTGOT) {
pltgot = (void *)dyn->d_un.d_ptr;
} else if (dyn->d_tag == DT_PLTRELSZ) {
pltrelsz = dyn->d_un.d_val;
} else if (dyn->d_tag == DT_PLTREL) {
pltrel = (void *)dyn->d_un.d_val;
}
++dyn;
}
log_debug("\nstrtab: %p\n"
"symtab: %p\n"
"strsz: %zu\n"
"syment: %zu\n"
"rela: %p\n"
"relasz: %zu\n"
"relaent: %zu\n"
"relacount: %zu\n"
"pltgot: %p\n"
"pltrelsz: %zu\n"
"pltrel: %p\n",
strtab, symtab, strsz, syment, rela, relasz, relaent,
relacount, pltgot, pltrelsz, pltrel);
}
}
}
#define TYPE_TO_STR(type) \ #define TYPE_TO_STR(type) \
case (type): \ case (type): \
return #type return #type

View File

@@ -1,11 +1,10 @@
#pragma once #pragma once
#include <elf.h> #include <elf.h>
#include <libelf.h>
#include <gelf.h> #include <gelf.h>
#include <libelf.h>
void *hi_elf_find_dynamic_segment(void *module_base, size_t n, Elf **elf); void *hi_elf_find_dynseg(void *module_base, size_t n, Elf **elf);
const char *hi_elf_dyntostr(unsigned type); const char *hi_elf_dyntostr(unsigned type);
const char *hi_elf_segtostr(unsigned type); const char *hi_elf_segtostr(unsigned type);

View File

@@ -6,7 +6,6 @@
#include "logger.h" #include "logger.h"
#include "memmap.h" #include "memmap.h"
#include "moduler.h" #include "moduler.h"
#include "symbols.h"
#include "vector.h" #include "vector.h"
#include <assert.h> #include <assert.h>
@@ -66,9 +65,6 @@ static int initial_gather_callback(struct dl_phdr_info *info, size_t size,
module.original_name = strdup(modpath); module.original_name = strdup(modpath);
module.name = module.original_name; module.name = module.original_name;
log_debugv("dli_fname: %s\n", dl_info.dli_fname);
log_debugv("dli_sname: %s\n", dl_info.dli_sname);
} else { } else {
// I don't know when this could happen since we're passing the info straight // I don't know when this could happen since we're passing the info straight
// from the info struct // from the info struct
@@ -161,7 +157,7 @@ static ModuleData *module_get(const char *path,
for (size_t i = 0; i < vector_size(modules); i++) { for (size_t i = 0; i < vector_size(modules); i++) {
ModuleData *module = &vector_at(modules, i); ModuleData *module = &vector_at(modules, i);
const char *name = module->name; const char *name = module->original_name;
if (strcmp(name, path) == 0) { if (strcmp(name, path) == 0) {
return module; return module;
} }
@@ -232,10 +228,13 @@ int hi_init(size_t n, const char **enabled_modules) {
return 1; return 1;
} }
moduler_init(&context.modules);
return 0; return 0;
} }
void hi_deinit() { void hi_deinit() {
moduler_deinit();
module_data_free(&context.modules); module_data_free(&context.modules);
filewatcher_destroy(context.filewatcher); filewatcher_destroy(context.filewatcher);
log_term(); log_term();

View File

@@ -8,6 +8,7 @@
#include "memmap.h" #include "memmap.h"
#include "symbols.h" #include "symbols.h"
#include "types.h" #include "types.h"
#include "vector.h"
#include <dlfcn.h> #include <dlfcn.h>
#include <errno.h> #include <errno.h>
@@ -25,6 +26,28 @@ typedef struct {
MemorySpan memreg; MemorySpan memreg;
} PatchData; } PatchData;
typedef struct {
/// The original module name
const char *modname;
/// The associated symbols for that module
VectorSymbol symbols;
} SymbolData;
vector_def(SymbolData, SymbolData);
// TODO: Stop being lazy and put this with the rest of the data
static VectorSymbolData symbol_data;
static VectorSymbol *symdat_get(const char *modname) {
for (size_t i = 0; i < vector_size(&symbol_data); ++i) {
SymbolData *symdat = &vector_at(&symbol_data, i);
if (strcmp(modname, symdat->modname) == 0) {
return &symdat->symbols;
}
}
return NULL;
}
static void *adjust_if_relative(void *ptr, void *module_base) { static void *adjust_if_relative(void *ptr, void *module_base) {
uptr p = (uptr)ptr; uptr p = (uptr)ptr;
if (p && (p < (uptr)module_base)) { if (p && (p < (uptr)module_base)) {
@@ -36,7 +59,7 @@ static void *adjust_if_relative(void *ptr, void *module_base) {
static HiResult moduler_collect_symbols(VectorSymbol *symbols, static HiResult moduler_collect_symbols(VectorSymbol *symbols,
const char *module_name, const char *module_name,
void *module_base) { void *module_base) {
symbol_clear(symbols); symbols_clear(symbols);
HiResult ret = HI_FAIL; HiResult ret = HI_FAIL;
@@ -181,10 +204,9 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols,
size_t relasz = 0; size_t relasz = 0;
Elf *elf = NULL; Elf *elf = NULL;
ElfW(Dyn) *dyn_sct = ElfW(Dyn) *dyn = hi_elf_find_dynseg(module_base, module_size, &elf);
hi_elf_find_dynamic_segment(module_base, module_size, &elf);
for (ElfW(Dyn) *d = dyn_sct; d->d_tag != DT_NULL; d++) { for (ElfW(Dyn) *d = dyn; d->d_tag != DT_NULL; d++) {
log_debugv("%s\n", hi_elf_dyntostr(d->d_tag)); log_debugv("%s\n", hi_elf_dyntostr(d->d_tag));
switch (d->d_tag) { switch (d->d_tag) {
@@ -247,12 +269,12 @@ static HiResult moduler_patch_functions(VectorSymbol *psymbols,
// Check if this is a symbol we want to patch // Check if this is a symbol we want to patch
for (size_t j = 0; j < num_symbols; j++) { for (size_t j = 0; j < num_symbols; j++) {
Symbol *sym = &vector_at(psymbols, j); Symbol *s = &vector_at(psymbols, j);
if (strcmp(sym->name, name) == 0) { if (strcmp(s->name, name) == 0) {
sym->got_entry = got_entry; s->got_entry = got_entry;
sym->orig_address = *got_entry; // Save the original function s->orig_address = *got_entry; // Save the original function
*got_entry = sym->address; *got_entry = s->address;
log_debug("Found GOT entry for '%s' at %p (points to %p)\n", name, log_debug("Found GOT entry for '%s' at %p (points to %p)\n", name,
got_entry, *got_entry); got_entry, *got_entry);
@@ -280,15 +302,15 @@ static PatchData moduler_create_patch(ModuleData *module) {
} }
char filename[512]; char filename[512];
size_t written = size_t written = string_concat_buf(sizeof(filename), filename,
string_concat_buf(sizeof(filename), filename, module->name, file_append); module->original_name, file_append);
if (written == 0) { if (written == 0) {
log_error("Failed to concat %s and %s\n", module->name, ".patch"); log_error("Failed to concat %s and %s\n", module->original_name, ".patch");
return (PatchData){0}; return (PatchData){0};
} }
file_copy(module->name, filename); file_copy(module->original_name, filename);
PatchData data = {.filename = strdup(filename)}; PatchData data = {.filename = strdup(filename)};
return data; return data;
@@ -313,7 +335,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
// Load patch // Load patch
dlerror(); // clear any previous errors dlerror(); // clear any previous errors
log_debug("Opening: %s\n", patch.filename); log_debug("Opening patch: %s\n", patch.filename);
void *new_handle = dlopen(patch.filename, RTLD_LAZY); void *new_handle = dlopen(patch.filename, RTLD_LAZY);
if (!new_handle) { if (!new_handle) {
log_error("Couldn't load: %s\n", dlerror()); log_error("Couldn't load: %s\n", dlerror());
@@ -333,7 +355,7 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
void *patch_base = (void *)patch.memreg.start; void *patch_base = (void *)patch.memreg.start;
VectorSymbol patch_symbols; VectorSymbol patch_symbols;
symbol_init(&patch_symbols); symbols_init(&patch_symbols);
HiResult ret = HiResult ret =
moduler_collect_symbols(&patch_symbols, patch.filename, patch_base); moduler_collect_symbols(&patch_symbols, patch.filename, patch_base);
@@ -353,44 +375,72 @@ HiResult moduler_reload(VectorModuleData *modules, size_t modindx) {
moduler_patch_functions(&patch_symbols, module_memory); moduler_patch_functions(&patch_symbols, module_memory);
// This relies on the patch filename being different only by the append // This relies on the patch filename being different only by the append
if (strncmp(mod.name, patch.filename, strlen(mod.name)) == 0) { if (strncmp(mod.original_name, patch.filename, strlen(mod.original_name)) ==
0) {
// If patch is for the same module, also collect local object symbols for // If patch is for the same module, also collect local object symbols for
// coping those over. // coping those over.
VectorSymbol *module_symbols = symdat_get(mod.original_name);
VectorSymbol module_symbols;
symbol_init(&module_symbols);
ret = moduler_collect_symbols(&module_symbols, mod.name,
(void *)mod.address);
if (!HIOK(ret)) {
log_error("Failed to gather symbols for %s\n", mod.name);
symbol_term(&module_symbols);
continue;
}
// Copy old data to new data location. Breaks if layout changes, // Copy old data to new data location. Breaks if layout changes,
// e.g. struct fields are moved around // e.g. struct fields are moved around
for (size_t j = 0; j < vector_size(&module_symbols); ++j) { // Also update our symbol cache for the module. We need to add new
Symbol *sym = &vector_at(&module_symbols, j); // symbols, but don't care much about deleting old ones
if (sym->type == HI_SYMBOL_TYPE_OBJECT) { for (size_t j = 0; j < vector_size(&patch_symbols); ++j) {
Symbol *ps = symbol_find(&patch_symbols, sym); Symbol *psym = &vector_at(&patch_symbols, j);
if (ps) { if (psym->type == HI_SYMBOL_TYPE_OBJECT) {
size_t copy_size = MIN(sym->size, ps->size); Symbol *msym = symbol_find(module_symbols, psym);
memcpy(ps->address, sym->address, copy_size); if (msym) {
log_debug("Copied data for symbol: %s\n", sym->name); size_t copy_size = MIN(psym->size, msym->size);
memcpy(psym->address, msym->address, copy_size);
log_debug("Copied data for symbol: %s\n", psym->name);
// Maintain our current symbol references with the new address
msym->address = psym->address;
} else {
// A new symbol has been added in the patch
symbols_add(module_symbols, symbol_copy(psym));
log_debug("Found new symbol from patch: %s\n", psym->name);
} }
} }
} }
symbol_term(&module_symbols);
} }
}
symbols_term(&patch_symbols);
module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY); module->info = modinfo_clear(module->info, HI_MODULE_STATE_DIRTY);
}
free((char *)patch.filename);
symbol_term(&patch_symbols);
dlclose(module->dlhandle); dlclose(module->dlhandle);
// Update module reference
if (module->name != module->original_name)
free((char *)module->name);
module->name = patch.filename;
module->address = patch.memreg.start;
module->dlhandle = patch.dlhandle; module->dlhandle = patch.dlhandle;
return HI_OK; return HI_OK;
} }
void moduler_init(const VectorModuleData *modules) {
vector_init(&symbol_data);
ModuleData mod = {0};
vector_foreach(modules, mod) {
SymbolData symdata = {0};
symdata.modname = mod.original_name;
vector_add(&symbol_data, symdata);
if (modinfo_has(mod.info, HI_MODULE_STATE_PATCHABLE)) {
moduler_collect_symbols(&vector_last(&symbol_data).symbols,
mod.original_name, (void *)mod.address);
}
vector_init(&symdata.symbols);
}
}
void moduler_deinit() {
SymbolData symdata = {0};
vector_foreach(&symbol_data, symdata) { symbols_term(&symdata.symbols); }
vector_term(&symbol_data);
}

View File

@@ -45,3 +45,5 @@ static inline bool modinfo_has(ModuleInfo flags, ModuleFlags flag) {
#define HI_MODINFO_CLEAR(info, flag) ((info) &= ~flag) #define HI_MODINFO_CLEAR(info, flag) ((info) &= ~flag)
HiResult moduler_reload(VectorModuleData *modules, size_t modindx); HiResult moduler_reload(VectorModuleData *modules, size_t modindx);
void moduler_init(const VectorModuleData *modules);
void moduler_deinit(void);

View File

@@ -33,23 +33,29 @@ vector_def(Symbol, Symbol);
static inline void symbol_free(Symbol *symbol) { free((char *)symbol->name); } static inline void symbol_free(Symbol *symbol) { free((char *)symbol->name); }
static inline void symbol_init(VectorSymbol *symbols) { vector_init(symbols); } static inline void symbols_init(VectorSymbol *symbols) { vector_init(symbols); }
static inline void symbol_clear(VectorSymbol *symbols) { static inline void symbols_clear(VectorSymbol *symbols) {
for (size_t i = 0; i < vector_size(symbols); ++i) { for (size_t i = 0; i < vector_size(symbols); ++i) {
symbol_free(&vector_at(symbols, i)); symbol_free(&vector_at(symbols, i));
} }
vector_clear(symbols); vector_clear(symbols);
} }
static inline void symbol_term(VectorSymbol *symbols) { static inline void symbols_term(VectorSymbol *symbols) {
symbol_clear(symbols); symbols_clear(symbols);
vector_term(symbols); vector_term(symbols);
} }
static inline void symbol_add(VectorSymbol *symbols, Symbol symbol) { static inline void symbols_add(VectorSymbol *symbols, Symbol symbol) {
vector_add(symbols, symbol); vector_add(symbols, symbol);
} }
static inline Symbol symbol_copy(const Symbol *symbol) {
Symbol sym = *symbol;
sym.name = strdup(sym.name);
return sym;
}
Symbol *symbol_find(VectorSymbol *symbols, Symbol *symbol); Symbol *symbol_find(VectorSymbol *symbols, Symbol *symbol);
SymbolBind symbol_bind_from_efibind(u32 efi_bind); SymbolBind symbol_bind_from_efibind(u32 efi_bind);