18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114 | class ExpMod:
modtype: str
def __init_subclass__(cls) -> None:
if not hasattr(cls, "modtype"):
raise NotImplementedError(f"{cls} must implement class attribute 'modtype'")
@staticmethod
def register(name: str) -> Callable[[type[ExpMod]], type[ExpMod]]:
def register_decorator(cls: type[ExpMod]) -> type[ExpMod]:
if name in expmod_registry[cls.modtype]:
raise Exception(
f"ExpMod.register attempting to register duplicate name '{name}' for module '{cls.modtype}'"
)
expmod_registry[cls.modtype][name] = cls()
return cls
return register_decorator
@classmethod
def get(cls, default: str | None = None) -> Self:
modtype = cls.modtype
name: str | None = (
expmod_modtype_current[modtype]
if expmod_modtype_current[modtype] is not None
else default
)
if name is None:
raise Exception(f"ExpMod couldn't get module for type: '{modtype}'")
return cast(Self, expmod_registry[modtype][name])
@classmethod
def set(cls, name: str, modtype: str | None = None) -> None:
if modtype is None:
modtype = cls.modtype
if modtype not in expmod_registry:
raise Exception(f"ExpMod.set can't find module for type: '{modtype}'")
if name not in expmod_registry[modtype]:
raise Exception(
f"ExpMod.set can't find module for name: '{name}' in module '{modtype}'"
)
expmod_modtype_current[modtype] = name
@staticmethod
def import_file(filename: str, basepath: str = "") -> ModuleType:
module_name = f"roc:expmod:{filename}"
filepath = Path(basepath) / filename
spec = importlib.util.spec_from_file_location(module_name, filepath)
assert spec is not None
assert spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
@staticmethod
def init() -> None:
settings = Config.get()
mods = settings.expmods.copy()
basepaths = settings.expmod_dirs.copy()
basepaths.insert(0, "")
# load module files
missing_mods: list[str] = []
for base in basepaths:
for mod in mods:
file = mod if mod.endswith(".py") else mod + ".py"
try:
expmod_loaded[mod] = ExpMod.import_file(file, base)
except FileNotFoundError:
missing_mods.append(mod)
mods = missing_mods.copy()
missing_mods.clear()
if len(mods) > 0:
raise FileNotFoundError(f"could not load experiment modules: {mods}")
# set modules
use_mods = [m.split(":") for m in settings.expmods_use]
mod_name_count = Counter([m[0] for m in use_mods])
duplicate_names = {k: v for k, v in mod_name_count.items() if v > 1}
if len(duplicate_names) > 0:
dupes = ", ".join(duplicate_names.keys())
raise Exception(f"ExpMod.init found multiple attempts to set the same modules: {dupes}")
for mod_tn in use_mods:
t, n = mod_tn
ExpMod.set(name=n, modtype=t)
|