Skip to content

Main

Reinforcement Learning of Concepts

ActionData = ActionRequest | TakeAction module-attribute

PerceptionData = VisionData | Settled | Feature[Any] module-attribute

__all__ = ['Component', 'GymComponent', 'Perception', 'PerceptionData', 'ActionData', 'Action', 'ExpMod'] module-attribute

ng = None module-attribute

Action

Bases: Component

Component for determining which action to take.

Source code in roc/action.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@register_component("action", "action", auto=True)
class Action(Component):
    """Component that answers action requests with the action to take.

    It listens on the shared "action" bus. For each event handed to
    :meth:`action_request` it asks the currently selected
    ``DefaultActionExpMod`` (falling back to the module registered as "pass")
    for an action, and publishes it back on the same bus as a ``TakeAction``.
    """

    bus = EventBus[ActionData]("action", cache_depth=10)

    def __init__(self) -> None:
        super().__init__()
        conn = self.connect_bus(self.bus)
        # Keep the attribute assignment ahead of listen(): action_request()
        # sends through self.action_bus_conn, so it must exist if an event is
        # delivered as soon as the listener is attached.
        self.action_bus_conn = conn
        conn.listen(self.action_request)

    def event_filter(self, e: ActionEvent) -> bool:
        """Accept only events whose payload is an ``ActionRequest``.

        This overrides ``Component.event_filter`` without calling it, so the
        base-class "drop events sent by myself" check is replaced (the
        ``TakeAction`` events this component emits are rejected by the type
        check instead).
        """
        payload = e.data
        return isinstance(payload, ActionRequest)

    def action_request(self, e: ActionEvent) -> None:
        """Choose an action via the action ExpMod and publish it as ``TakeAction``."""
        chooser = DefaultActionExpMod.get(default="pass")
        self.action_bus_conn.send(TakeAction(action=chooser.get_action()))

action_bus_conn = self.connect_bus(self.bus) instance-attribute

bus = EventBus[ActionData]('action', cache_depth=10) class-attribute instance-attribute

__init__()

Source code in roc/action.py
33
34
35
36
def __init__(self) -> None:
    """Connect to the action bus and route incoming events to action_request."""
    super().__init__()
    conn = self.connect_bus(self.bus)
    # Assign before listen(): action_request() sends through this attribute.
    self.action_bus_conn = conn
    conn.listen(self.action_request)

action_request(e)

Source code in roc/action.py
41
42
43
44
def action_request(self, e: ActionEvent) -> None:
    """Pick an action with the current action ExpMod and publish a ``TakeAction``.

    Falls back to the ExpMod registered as "pass" when no action module has
    been selected.
    """
    chooser = DefaultActionExpMod.get(default="pass")
    self.action_bus_conn.send(TakeAction(action=chooser.get_action()))

event_filter(e)

Source code in roc/action.py
38
39
def event_filter(self, e: ActionEvent) -> bool:
    """Return True only for events whose payload is an ``ActionRequest``."""
    payload = e.data
    return isinstance(payload, ActionRequest)

Component

Bases: ABC

An abstract component class for building pieces of ROC that will talk to each other.

Source code in roc/component.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class Component(ABC):
    """An abstract component class for building pieces of ROC that will talk to each other.

    Every instance is tracked in the module-level ``component_set`` from
    construction until it is destroyed. It records each bus connection it
    opens in ``bus_conns`` so that :meth:`shutdown` can close them.
    """

    name: str = "<name unassigned>"
    type: str = "<type unassigned>"

    def __init__(self) -> None:
        global component_set
        component_set.add(self)
        # bus name -> connection opened through connect_bus()
        self.bus_conns: dict[str, BusConnection[Any]] = {}
        logger.trace(f"++ incrementing component count: {self.name}:{self.type} {self}")

    def __del__(self) -> None:
        global component_set
        # Counterpart of the add() in __init__: remove this instance so that
        # get_component_count() actually goes down. discard() (rather than
        # remove()) tolerates the instance already being absent.
        component_set.discard(self)
        logger.trace(f"-- decrementing component count: {self.name}:{self.type} {self}")

    def connect_bus(self, bus: EventBus[T]) -> BusConnection[T]:
        """Create a new bus connection for the component, storing the result for
        later shutdown.

        Args:
            bus (EventBus[T]): The event bus to attach to

        Raises:
            ValueError: if the bus has already been connected to by this component

        Returns:
            BusConnection[T]: The bus connection for listening or sending events
        """
        if bus.name in self.bus_conns:
            raise ValueError(
                f"Component '{self.name}' attempting duplicate connection to bus '{bus.name}'"
            )

        conn = bus.connect(self)
        self.bus_conns[bus.name] = conn
        return conn

    def event_filter(self, e: Event[Any]) -> bool:
        """A filter for any incoming events. By default it filters out events
        sent by itself, but it is especially useful for creating new filters in
        sub-classes.

        Args:
            e (Event[Any]): The event to be evaluated

        Returns:
            bool: True if the event should be sent, False if it should be dropped
        """
        return e.src_id != self.id

    def shutdown(self) -> None:
        """De-initializes the component, closing its bus connections and
        performing any other clean-up that is needed.

        For every connected bus, ``on_completed()`` is called on every observer
        currently attached to that bus's subject (not only this component's
        observers), and then the connection is closed.
        """
        logger.debug(f"Component {self.name}:{self.type} shutting down.")

        for conn in self.bus_conns.values():
            # Iterate over a snapshot: completing an observer may detach it
            # from the subject, and mutating the live list while iterating it
            # would skip observers.
            for obs in list(conn.attached_bus.subject.observers):
                obs.on_completed()
            conn.close()

    @property
    def id(self) -> ComponentId:
        return ComponentId(self.type, self.name)

    @staticmethod
    def init() -> None:
        """Loads all components registered as `auto` and perception components
        in the `perception_components` config field.
        """
        settings = Config.get()
        # Interpolate the value: loguru formats extra positional args into "{}"
        # placeholders, so passing it as a separate argument dropped it.
        logger.debug(f"perception components from settings: {settings.perception_components}")
        component_list = default_components.union(settings.perception_components)
        logger.debug(f"Component.init: default components: {component_list}")

        # TODO: shutdown previously loaded components

        for reg_str in component_list:
            logger.trace(f"Loading component: {reg_str} ...")
            # registry keys have the form "<name>:<type>"
            name, comp_type = reg_str.split(":")
            loaded_components[reg_str] = Component.get(name, comp_type)

    @classmethod
    def get(cls, name: str, type: str, *args: Any, **kwargs: Any) -> Self:
        """Retrieves a component with the specified name from the registry and
        creates a new version of it with the specified args. Used by
        `Component.init` and for testing.

        Args:
            name (str): The name of the component to get, as specified during
                its registration
            type (str): The type of the component to get, as specified during
                its registration
            args (Any): Fixed position arguments to pass to the Component
                constructor
            kwargs (Any): Keyword args to pass to the Component constructor

        Returns:
            Self: the component that was created, cast as the calling class.
            (e.g. `Perception.get(...)` will return a Perception component and
            `Action.get(...)` will return an Action component)
        """
        reg_str = _component_registry_key(name, type)
        return cast(Self, component_registry[reg_str](*args, **kwargs))

    @staticmethod
    def get_component_count() -> int:
        """Returns the number of currently created Components. The number goes
        up on __init__ and down on __del__. Primarily used for testing to ensure
        Components are being shut down appropriately.

        Returns:
            int: The number of currently active Component instances
        """
        return len(component_set)

    @staticmethod
    def get_loaded_components() -> list[str]:
        """Returns the names and types of all initiated components.

        Returns:
            list[str]: A list of the names and types of components, as strings.
        """
        return list(loaded_components)

    @staticmethod
    def deregister(name: str, type: str) -> None:
        """Removes a component from the Component registry. Primarily used for testing.

        Args:
            name (str): The name of the Component to deregister
            type (str): The type of the Component to deregister
        """
        reg_str = _component_registry_key(name, type)
        del component_registry[reg_str]

    @staticmethod
    def reset() -> None:
        """Shuts down all components.

        First every component in ``loaded_components`` is shut down and that
        mapping is cleared. Then every other component still tracked in
        ``component_set`` is shut down. Components handled in the first pass
        are not shut down a second time.
        """
        global loaded_components
        # Hold references to the loaded components until the end of reset so
        # that their id() values stay unique while used as skip markers.
        already_shut_down = list(loaded_components.values())
        for name, c in loaded_components.items():
            logger.trace(f"Shutting down component: {name}.")
            c.shutdown()

        loaded_components.clear()

        skip_ids = {id(c) for c in already_shut_down}
        global component_set
        # Snapshot the set: shutting components down may release instances,
        # whose __del__ removes them from component_set.
        for c in list(component_set):
            if id(c) in skip_ids:
                continue
            c.shutdown()

bus_conns = {} instance-attribute

id property

name = '<name unassigned>' class-attribute instance-attribute

type = '<type unassigned>' class-attribute instance-attribute

__del__()

Source code in roc/component.py
48
49
50
51
def __del__(self) -> None:
    """Remove this instance from the global component set when it is destroyed.

    This is the counterpart of ``__init__`` (which adds the instance), so that
    ``Component.get_component_count()`` decreases as components are freed.
    """
    global component_set
    # discard(), not add(): re-adding here meant the count never went down.
    # discard() (rather than remove()) tolerates an already-absent entry.
    component_set.discard(self)
    logger.trace(f"-- decrementing component count: {self.name}:{self.type} {self}")

__init__()

Source code in roc/component.py
41
42
43
44
45
def __init__(self) -> None:
    """Track this instance in ``component_set`` and start with no bus connections."""
    # .add() mutates the module-level set in place; no rebinding is needed.
    component_set.add(self)
    self.bus_conns: dict[str, BusConnection[Any]] = {}
    logger.trace(f"++ incrementing component count: {self.name}:{self.type} {self}")

connect_bus(bus)

Create a new bus connection for the component, storing the result for later shutdown.

Parameters:

Name Type Description Default
bus EventBus[T]

The event bus to attach to

required

Raises:

Type Description
ValueError

if the bus has already been connected to by this component

Returns:

Type Description
BusConnection[T]

BusConnection[T]: The bus connection for listening or sending events

Source code in roc/component.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def connect_bus(self, bus: EventBus[T]) -> BusConnection[T]:
    """Attach this component to ``bus`` and remember the connection.

    The connection is stored in ``self.bus_conns`` under the bus's name so it
    can be closed later during shutdown.

    Args:
        bus (EventBus[T]): The event bus to attach to

    Raises:
        ValueError: if this component already holds a connection to a bus
            with the same name

    Returns:
        BusConnection[T]: The connection used to listen for or send events
    """
    bus_key = bus.name
    if bus_key in self.bus_conns:
        raise ValueError(
            f"Component '{self.name}' attempting duplicate connection to bus '{bus.name}'"
        )

    new_conn = bus.connect(self)
    self.bus_conns[bus_key] = new_conn
    return new_conn

deregister(name, type) staticmethod

Removes a component from the Component registry. Primarily used for testing.

Parameters:

Name Type Description Default
name str

The name of the Component to deregister

required
type str

The type of the Component to deregister

required
Source code in roc/component.py
168
169
170
171
172
173
174
175
176
177
@staticmethod
def deregister(name: str, type: str) -> None:
    """Remove a component from the Component registry. Primarily used for testing.

    Args:
        name (str): The name of the Component to deregister
        type (str): The type of the Component to deregister

    Raises:
        KeyError: if no component is registered under that name and type
            (assuming ``component_registry`` is a plain mapping)
    """
    component_registry.pop(_component_registry_key(name, type))

event_filter(e)

A filter for any incoming events. By default it filters out events sent by itself, but it is especially useful for creating new filters in sub-classes.

Parameters:

Name Type Description Default
e Event[Any]

The event to be evaluated

required

Returns:

Name Type Description
bool bool

True if the event should be sent, False if it should be dropped

Source code in roc/component.py
75
76
77
78
79
80
81
82
83
84
85
86
def event_filter(self, e: Event[Any]) -> bool:
    """Decide whether an incoming event should be delivered to this component.

    The default drops events whose source id equals this component's own
    id, i.e. events it sent itself. Subclasses override this to build their
    own filters.

    Args:
        e (Event[Any]): The event to be evaluated

    Returns:
        bool: True if the event should be delivered, False if it should be dropped
    """
    sender = e.src_id
    return sender != self.id

get(name, type, *args, **kwargs) classmethod

Retrieves a component with the specified name from the registry and creates a new version of it with the specified args. Used by Component.init and for testing.

Parameters:

Name Type Description Default
name str

The name of the component to get, as specified during its registration

required
type str

The type of the component to get, as specified during its registration

required
args Any

Fixed position arguments to pass to the Component constructor

()
kwargs Any

Keyword args to pass to the Component constructor

{}

Returns:

Name Type Description
Self Self

the component that was created, cast as the calling class (e.g. Perception.get(...) will return a Perception component and Action.get(...) will return an Action component).

Source code in roc/component.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@classmethod
def get(cls, name: str, type: str, *args: Any, **kwargs: Any) -> Self:
    """Instantiate the component registered under ``name``/``type``.

    The registered factory is looked up by its registry key and called with
    the given arguments. Used by ``Component.init`` and for testing.

    Args:
        name (str): The name the component was registered with
        type (str): The type the component was registered with
        args (Any): Positional arguments forwarded to the constructor
        kwargs (Any): Keyword arguments forwarded to the constructor

    Returns:
        Self: the newly created component, cast to the calling class
        (e.g. `Perception.get(...)` is typed as a Perception component and
        `Action.get(...)` as an Action component)
    """
    factory = component_registry[_component_registry_key(name, type)]
    instance = factory(*args, **kwargs)
    return cast(Self, instance)

get_component_count() staticmethod

Returns the number of currently created Components. The number goes up on `__init__` and down on `__del__`. Primarily used for testing to ensure Components are being shut down appropriately.

Returns:

Name Type Description
int int

The number of currently active Component instances

Source code in roc/component.py
144
145
146
147
148
149
150
151
152
153
154
155
156
@staticmethod
def get_component_count() -> int:
    """Return how many Component instances are currently tracked.

    Instances are added to ``component_set`` in ``__init__``. The count is
    mainly used by tests to check that Components are shut down and released
    as expected.

    Returns:
        int: The number of tracked Component instances
    """
    return len(component_set)

get_loaded_components() staticmethod

Returns the names and types of all initiated components.

Returns:

Type Description
list[str]

list[str]: A list of the names and types of components, as strings.

Source code in roc/component.py
158
159
160
161
162
163
164
165
166
@staticmethod
def get_loaded_components() -> list[str]:
    """Return the registry keys ("name:type") of all components loaded by init.

    Returns:
        list[str]: The keys of ``loaded_components``, in insertion order.
    """
    return list(loaded_components)

init() staticmethod

Loads all components registered as auto and perception components in the perception_components config field.

Source code in roc/component.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@staticmethod
def init() -> None:
    """Loads all components registered as `auto` and perception components
    in the `perception_components` config field.

    Each key in the union of ``default_components`` and the configured
    perception components has the form "name:type". Each is instantiated via
    ``Component.get`` and stored in ``loaded_components`` under that key.
    """
    settings = Config.get()
    # Interpolate the value: loguru only formats extra positional args into
    # "{}" placeholders, so passing it as a separate argument dropped it.
    logger.debug(f"perception components from settings: {settings.perception_components}")
    component_list = default_components.union(settings.perception_components)
    logger.debug(f"Component.init: default components: {component_list}")

    # TODO: shutdown previously loaded components

    for reg_str in component_list:
        logger.trace(f"Loading component: {reg_str} ...")
        name, comp_type = reg_str.split(":")
        loaded_components[reg_str] = Component.get(name, comp_type)

reset() staticmethod

Shuts down all components

Source code in roc/component.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
@staticmethod
def reset() -> None:
    """Shuts down all components.

    First every component in ``loaded_components`` is shut down and that
    mapping is cleared. Then every other component still tracked in
    ``component_set`` is shut down. Components handled in the first pass
    are not shut down a second time.
    """
    global loaded_components
    # Hold references until the end of reset so id() values stay unique
    # while they are used as skip markers below.
    already_shut_down = list(loaded_components.values())
    for name, c in loaded_components.items():
        logger.trace(f"Shutting down component: {name}.")
        c.shutdown()

    loaded_components.clear()

    skip_ids = {id(c) for c in already_shut_down}
    global component_set
    # Snapshot the set: shutting components down may release instances,
    # whose __del__ removes them from component_set.
    for c in list(component_set):
        if id(c) in skip_ids:
            continue
        c.shutdown()

shutdown()

De-initializes the component, closing its bus connections and performing any other clean-up that is needed.

Source code in roc/component.py
88
89
90
91
92
93
94
95
96
97
def shutdown(self) -> None:
    """De-initializes the component, closing its bus connections and
    performing any other clean-up that is needed.

    For every connected bus, ``on_completed()`` is called on every observer
    currently attached to that bus's subject (not only this component's
    observers), and then the connection is closed.
    """
    logger.debug(f"Component {self.name}:{self.type} shutting down.")

    for conn in self.bus_conns.values():
        # Iterate over a snapshot: completing an observer may detach it from
        # the subject, and mutating the live list while iterating it would
        # skip observers.
        for obs in list(conn.attached_bus.subject.observers):
            obs.on_completed()
        conn.close()

ExpMod

Source code in roc/expmod.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
class ExpMod:
    """Base class for experiment modules ("ExpMods") that can be swapped by name.

    Every subclass must define a ``modtype`` class attribute naming the family
    of modules it belongs to. Implementations are registered with
    :meth:`register`. One name per ``modtype`` is selected with :meth:`set`
    (or through the ``expmods_use`` setting in :meth:`init`) and looked up
    with :meth:`get`.
    """

    modtype: str

    def __init_subclass__(cls) -> None:
        """Verify that every subclass defines the ``modtype`` class attribute.

        Raises:
            NotImplementedError: if ``cls`` has no ``modtype`` attribute.
        """
        if not hasattr(cls, "modtype"):
            raise NotImplementedError(f"{cls} must implement class attribute 'modtype'")
        # Chain to the next __init_subclass__ in the MRO so ExpMod cooperates
        # with other base classes that define their own hook.
        super().__init_subclass__()

    @staticmethod
    def register(name: str) -> Callable[[type[ExpMod]], type[ExpMod]]:
        """Class decorator that registers an instance of the class under ``name``.

        The decorated class is instantiated immediately (with no arguments).
        The instance is stored in ``expmod_registry[cls.modtype][name]``, and
        the class is returned unchanged.

        Raises:
            Exception: if ``name`` is already registered for that modtype.
        """

        def register_decorator(cls: type[ExpMod]) -> type[ExpMod]:
            if name in expmod_registry[cls.modtype]:
                raise Exception(
                    f"ExpMod.register attempting to register duplicate name '{name}' for module '{cls.modtype}'"
                )
            expmod_registry[cls.modtype][name] = cls()

            return cls

        return register_decorator

    @classmethod
    def get(cls, default: str | None = None) -> Self:
        """Return the registered instance for this class's ``modtype``.

        Uses the name currently selected for the modtype (see :meth:`set`).
        If none is selected, it falls back to ``default``.

        Raises:
            Exception: if no name is selected and ``default`` is None.
        """
        modtype = cls.modtype
        name: str | None = (
            expmod_modtype_current[modtype]
            if expmod_modtype_current[modtype] is not None
            else default
        )
        if name is None:
            raise Exception(f"ExpMod couldn't get module for type: '{modtype}'")

        return cast(Self, expmod_registry[modtype][name])

    @classmethod
    def set(cls, name: str, modtype: str | None = None) -> None:
        """Select the module registered as ``name`` for ``modtype``.

        ``modtype`` defaults to this class's ``modtype``.

        Raises:
            Exception: if the modtype or the name is not registered.
        """
        if modtype is None:
            modtype = cls.modtype

        if modtype not in expmod_registry:
            raise Exception(f"ExpMod.set can't find module for type: '{modtype}'")

        if name not in expmod_registry[modtype]:
            raise Exception(
                f"ExpMod.set can't find module for name: '{name}' in module '{modtype}'"
            )

        expmod_modtype_current[modtype] = name

    @staticmethod
    def import_file(filename: str, basepath: str = "") -> ModuleType:
        """Import a Python source file as module ``roc:expmod:<filename>``.

        Args:
            filename: The file name, relative to ``basepath``.
            basepath: Directory to resolve ``filename`` against ("" means the
                current working directory).

        Returns:
            The executed module, which is also left in ``sys.modules``.

        Raises:
            FileNotFoundError: if the file does not exist. Any exception raised
                while executing the module also propagates. In either case
                ``sys.modules`` is restored to what it held before the call.
        """
        module_name = f"roc:expmod:{filename}"
        filepath = Path(basepath) / filename

        spec = importlib.util.spec_from_file_location(module_name, filepath)
        assert spec is not None
        assert spec.loader is not None

        module = importlib.util.module_from_spec(spec)
        previous = sys.modules.get(module_name)
        # The module must be in sys.modules while it executes.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialized module behind; restore any module
            # previously imported under this name.
            if previous is None:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
            raise

        return module

    @staticmethod
    def init() -> None:
        """Load the configured experiment module files and apply selections.

        Each name in ``settings.expmods`` (with ".py" appended if missing) is
        searched for in "" (the current directory) and then in each of
        ``settings.expmod_dirs``. Loaded modules are stored in
        ``expmod_loaded``. Then each "modtype:name" entry of
        ``settings.expmods_use`` is applied with :meth:`set`.

        Raises:
            FileNotFoundError: if some module file was not found in any base path.
            Exception: if ``expmods_use`` selects the same modtype more than once.
        """
        settings = Config.get()

        mods = settings.expmods.copy()
        basepaths = settings.expmod_dirs.copy()
        basepaths.insert(0, "")

        # load module files
        missing_mods: list[str] = []
        for base in basepaths:
            for mod in mods:
                file = mod if mod.endswith(".py") else mod + ".py"
                try:
                    expmod_loaded[mod] = ExpMod.import_file(file, base)
                except FileNotFoundError:
                    missing_mods.append(mod)
            # only the modules not found yet are retried in the next base path
            mods = missing_mods.copy()
            missing_mods.clear()

        if len(mods) > 0:
            raise FileNotFoundError(f"could not load experiment modules: {mods}")

        # set modules
        use_mods = [m.split(":") for m in settings.expmods_use]
        mod_name_count = Counter([m[0] for m in use_mods])
        duplicate_names = {k: v for k, v in mod_name_count.items() if v > 1}
        if len(duplicate_names) > 0:
            dupes = ", ".join(duplicate_names.keys())
            raise Exception(f"ExpMod.init found multiple attempts to set the same modules: {dupes}")

        for mod_tn in use_mods:
            t, n = mod_tn
            ExpMod.set(name=n, modtype=t)

modtype instance-attribute

__init_subclass__()

Source code in roc/expmod.py
21
22
23
def __init_subclass__(cls) -> None:
    """Verify that every ExpMod subclass defines the ``modtype`` class attribute.

    Raises:
        NotImplementedError: if ``cls`` has no ``modtype`` attribute.
    """
    if not hasattr(cls, "modtype"):
        raise NotImplementedError(f"{cls} must implement class attribute 'modtype'")
    # Chain to the next __init_subclass__ in the MRO so ExpMod cooperates with
    # other base classes that define their own hook.
    super().__init_subclass__()

get(default=None) classmethod

Source code in roc/expmod.py
38
39
40
41
42
43
44
45
46
47
48
49
@classmethod
def get(cls, default: str | None = None) -> Self:
    """Return the registered instance for this class's ``modtype``.

    Uses the name currently selected for the modtype. If that is None, it
    falls back to ``default``.

    Raises:
        Exception: if no name is selected and ``default`` is None.
    """
    modtype = cls.modtype
    chosen = expmod_modtype_current[modtype]
    if chosen is None:
        chosen = default
    if chosen is None:
        raise Exception(f"ExpMod couldn't get module for type: '{modtype}'")

    return cast(Self, expmod_registry[modtype][chosen])

import_file(filename, basepath='') staticmethod

Source code in roc/expmod.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@staticmethod
def import_file(filename: str, basepath: str = "") -> ModuleType:
    """Import a Python source file as module ``roc:expmod:<filename>``.

    Args:
        filename: The file name, relative to ``basepath``.
        basepath: Directory to resolve ``filename`` against ("" means the
            current working directory).

    Returns:
        The executed module, which is also left in ``sys.modules``.

    Raises:
        FileNotFoundError: if the file does not exist. Any exception raised
            while executing the module also propagates. In either case
            ``sys.modules`` is restored to what it held before the call.
    """
    module_name = f"roc:expmod:{filename}"
    filepath = Path(basepath) / filename

    spec = importlib.util.spec_from_file_location(module_name, filepath)
    assert spec is not None
    assert spec.loader is not None

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    # The module must be in sys.modules while it executes.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module behind; restore any module
        # previously imported under this name.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise

    return module

init() staticmethod

Source code in roc/expmod.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@staticmethod
def init() -> None:
    """Load the configured experiment module files and apply selections.

    Each name in ``settings.expmods`` (with ".py" appended if missing) is
    searched for in "" first, then in each of ``settings.expmod_dirs``. Loaded
    modules go into ``expmod_loaded``. Afterwards each "modtype:name" entry
    of ``settings.expmods_use`` is applied via ``ExpMod.set``.

    Raises:
        FileNotFoundError: if some module file was not found in any base path.
        Exception: if ``expmods_use`` selects the same modtype more than once.
    """
    settings = Config.get()

    search_paths = ["", *settings.expmod_dirs]
    pending = list(settings.expmods)

    # Try every base path in turn; only modules not yet found are retried.
    for base in search_paths:
        not_found: list[str] = []
        for mod in pending:
            filename = mod if mod.endswith(".py") else f"{mod}.py"
            try:
                expmod_loaded[mod] = ExpMod.import_file(filename, base)
            except FileNotFoundError:
                not_found.append(mod)
        pending = not_found

    if pending:
        raise FileNotFoundError(f"could not load experiment modules: {pending}")

    # Each selection has the form "modtype:name"; a modtype may be set only once.
    selections = [entry.split(":") for entry in settings.expmods_use]
    per_type = Counter(sel[0] for sel in selections)
    repeated = [modtype for modtype, count in per_type.items() if count > 1]
    if repeated:
        raise Exception(
            f"ExpMod.init found multiple attempts to set the same modules: {', '.join(repeated)}"
        )

    for modtype, name in selections:
        ExpMod.set(name=name, modtype=modtype)

register(name) staticmethod

Source code in roc/expmod.py
25
26
27
28
29
30
31
32
33
34
35
36
@staticmethod
def register(name: str) -> Callable[[type[ExpMod]], type[ExpMod]]:
    """Build a class decorator that registers the class under ``name``.

    When applied, the decorator instantiates the class with no arguments and
    stores the instance in ``expmod_registry[cls.modtype][name]``. It then
    returns the class unchanged.

    Raises:
        Exception: (from the decorator) if ``name`` is already registered for
            the class's modtype.
    """

    def decorator(cls: type[ExpMod]) -> type[ExpMod]:
        modtype = cls.modtype
        if name in expmod_registry[modtype]:
            raise Exception(
                f"ExpMod.register attempting to register duplicate name '{name}' for module '{cls.modtype}'"
            )
        instance = cls()
        expmod_registry[modtype][name] = instance
        return cls

    return decorator

set(name, modtype=None) classmethod

Source code in roc/expmod.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
@classmethod
def set(cls, name: str, modtype: str | None = None) -> None:
    """Make ``name`` the selected module for ``modtype``.

    ``modtype`` defaults to the calling class's ``modtype``.

    Raises:
        Exception: if the modtype is not registered, or ``name`` is not
            registered under it.
    """
    target = cls.modtype if modtype is None else modtype

    if target not in expmod_registry:
        raise Exception(f"ExpMod.set can't find module for type: '{target}'")
    if name not in expmod_registry[target]:
        raise Exception(
            f"ExpMod.set can't find module for name: '{name}' in module '{target}'"
        )

    expmod_modtype_current[target] = name

Perception

Bases: Component, ABC

The abstract class for Perception components. Handles perception bus connections and corresponding clean-up.

Source code in roc/perception.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
class Perception(Component, ABC):
    """Abstract base class for Perception components.

    On construction each instance connects to the shared perception bus via
    ``connect_bus``, which records the connection for later shutdown. Every
    delivered event is routed to :meth:`do_perception`, which subclasses must
    implement.
    """

    bus = EventBus[PerceptionData]("perception")

    def __init__(self) -> None:
        super().__init__()
        # NOTE: this connects to Perception.bus explicitly, while init() rebinds
        # cls.bus; calling init() on a subclass therefore does not change the
        # bus that new instances connect to.
        conn = self.connect_bus(Perception.bus)
        self.pb_conn = conn
        conn.listen(self.do_perception)

    @abstractmethod
    def do_perception(self, e: PerceptionEvent) -> None:
        """Handle one event delivered on the perception bus."""

    @classmethod
    def init(cls) -> None:
        """Replace ``cls.bus`` with a freshly created perception bus."""
        fresh_bus = EventBus[PerceptionData]("perception")
        cls.bus = fresh_bus

bus = EventBus[PerceptionData]('perception') class-attribute instance-attribute

pb_conn = self.connect_bus(Perception.bus) instance-attribute

__init__()

Source code in roc/perception.py
196
197
198
199
def __init__(self) -> None:
    """Connect to the shared perception bus and route its events to do_perception."""
    super().__init__()
    conn = self.connect_bus(Perception.bus)
    self.pb_conn = conn
    conn.listen(self.do_perception)

do_perception(e) abstractmethod

Source code in roc/perception.py
201
202
@abstractmethod
def do_perception(self, e: PerceptionEvent) -> None:
    """Process one event delivered on the perception bus.

    Subclasses must override this. ``Perception.__init__`` registers it as
    the perception-bus listener.
    """

init() classmethod

Source code in roc/perception.py
204
205
206
207
@classmethod
def init(cls) -> None:
    """Give ``cls`` a brand-new perception bus, replacing the previous one."""
    fresh_bus = EventBus[PerceptionData]("perception")
    cls.bus = fresh_bus

init(config=None)

Initializes the agent; must be called before `start()`.

Source code in roc/__init__.py
40
41
42
43
44
45
46
47
48
49
def init(config: dict[str, Any] | None = None) -> None:
    """Initializes the agent; must be called before ``start()``.

    Runs the initialization steps in a fixed order:

    1. configuration (``Config.init`` with the optional ``config``)
    2. logging
    3. creation of the ``NethackGym`` stored in the module-level ``ng``
    4. component loading (``Component.init``)
    5. experiment-module loading (``ExpMod.init``)
    6. Jupyter magics
    7. ``init_state()``

    Args:
        config: Optional configuration values passed straight to
            ``Config.init`` (presumably overrides of the default settings —
            TODO confirm against Config.init).
    """
    # Component.init and ExpMod.init both read Config.get(), so configuration
    # has to be initialized before them.
    Config.init(config)
    roc_logger.init()
    global ng
    # start() raises while ng is None, so this assignment is what enables it.
    ng = NethackGym()
    Component.init()
    ExpMod.init()
    RocJupyterMagics.init()
    init_state()

start()

Starts the agent.

Source code in roc/__init__.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def start() -> None:
    """Start the agent by running ``ng.start``.

    Outside Jupyter the gym runs in the calling thread. Inside Jupyter it
    runs on a separate ``Thread``, so the iPython shell stays available for
    inspection and debugging.

    Raises:
        Exception: if ``init()`` has not been called yet (``ng`` is None).
    """
    if ng is None:
        raise Exception("Call .init() before .start()")

    if not is_jupyter():
        roc_logger.logger.debug("Starting ROC: NOT running in thread")
        ng.start()
        return

    roc_logger.logger.debug("Starting ROC: running in thread")
    worker = Thread(target=ng.start)
    worker.start()