Skip to content

graphdb

This module is a wrapper around a graph database that abstracts away all the database-specific features behind various classes (GraphDB, Node, Edge, etc.).

CacheDefault = TypeVar('CacheDefault') module-attribute

CacheId = TypeVar('CacheId') module-attribute

CacheKey = TypeVar('CacheKey') module-attribute

CacheType = TypeVar('CacheType') module-attribute

CacheValue = TypeVar('CacheValue') module-attribute

EdgeCache = GraphCache[EdgeId, Edge] module-attribute

EdgeCallbackFn = Callable[[Edge], None] module-attribute

EdgeConnectionsList = Iterable[tuple[str, str]] module-attribute

EdgeFilterFn = Callable[[Edge], TypeGuard[Edge]] module-attribute

EdgeId = NewType('EdgeId', int) module-attribute

EdgeType = TypeVar('EdgeType', bound='Edge') module-attribute

NodeCache = GraphCache[NodeId, Node] module-attribute

NodeCallbackFn = Callable[[Node], None] module-attribute

NodeFilterFn = Callable[[Node], bool] module-attribute

NodeId = NewType('NodeId', int) module-attribute

NodeType = TypeVar('NodeType', bound='Node') module-attribute

ProgressFn = Callable[[list[Node]], None] module-attribute

QueryParamType = dict[str, Any] module-attribute

RecordFn = Callable[[str, Iterator[Any]], None] module-attribute

WalkMode = Literal['src', 'dst', 'both'] module-attribute

edge_cache = None module-attribute

edge_registry = {} module-attribute

graph_db_singleton = None module-attribute

next_new_edge = cast(EdgeId, -1) module-attribute

next_new_node = cast(NodeId, -1) module-attribute

node_cache = None module-attribute

node_label_registry = {} module-attribute

node_registry = {} module-attribute

Edge

Bases: BaseModel

An edge (a.k.a. Relationship or Connection) between two Nodes. An edge object automatically implements all phases of CRUD in the underlying graph database. This is a directional relationship with a "source" and "destination". The source and destination Nodes are dynamically loaded through property getters when they are accessed, and may trigger a graph database query if they are not already cached.

Source code in roc/graphdb.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
class Edge(BaseModel, extra="allow"):
    """An edge (a.k.a. Relationship or Connection) between two Nodes. An edge object automatically
    implements all phases of CRUD in the underlying graph database. This is a directional
    relationship with a "source" and "destination". The source and destination Nodes are
    resolved lazily by property getters (through ``Node.get``) each time they are accessed,
    which may trigger a graph database query.

    Fields declared with ``Field(exclude=True)`` are bookkeeping and are not dumped by
    ``Edge.to_dict``; extra keyword arguments (allowed by ``extra="allow"``) are dumped and
    become the edge's properties when it is written to the database.
    """

    # database id; a negative value marks an edge that only exists locally (see __init__)
    _id: EdgeId
    # relationship type, used as the label in the CREATE query of Edge.create
    type: str = Field(exclude=True)
    src_id: NodeId = Field(exclude=True)
    dst_id: NodeId = Field(exclude=True)
    allowed_connections: EdgeConnectionsList | None = Field(exclude=True, default=None)
    # when True, create()/update() return without touching the database (set by delete())
    _no_save = False
    # True until create() has written the edge to the database
    _new = False
    # set by delete()
    _deleted = False

    @property
    def id(self) -> EdgeId:
        """The edge identifier (negative while the edge is new / unsaved)."""
        return self._id

    @property
    def src(self) -> Node:
        """The source Node, looked up with ``Node.get(self.src_id)`` on every access."""
        return Node.get(self.src_id)

    @property
    def dst(self) -> Node:
        """The destination Node, looked up with ``Node.get(self.dst_id)`` on every access."""
        return Node.get(self.dst_id)

    @property
    def new(self) -> bool:
        """True if the edge has not been written to the database yet."""
        return self._new

    def __init__(
        self,
        **kwargs: Any,
    ) -> None:
        """Initialize the model from ``kwargs`` and assign the edge id.

        An explicit ``_id`` keyword (as passed by ``Edge.load``) is used as-is; otherwise a
        provisional id is taken from ``get_next_new_edge_id()``. Edges whose id is negative
        are flagged as new and registered in the edge cache under that id.
        """
        super().__init__(**kwargs)

        # set passed-in values or their defaults
        self._id = kwargs["_id"] if "_id" in kwargs else get_next_new_edge_id()

        if self._id < 0:
            self._new = True
            Edge.get_cache()[self.id] = self

    def __del__(self) -> None:
        """Persist the edge when the object is finalized.

        ``Edge.save`` dispatches to ``create``/``update``, both of which return early when
        ``_no_save`` is set (as ``Edge.delete`` does), so deleted edges are not re-written.
        """
        Edge.save(self)

    def __repr__(self) -> str:
        return f"Edge({self.id} [{self.src_id}>>{self.dst_id}])"

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        """Register each Edge subclass in ``edge_registry`` keyed by its edge type.

        A subclass without a ``type`` attribute gets a ``type`` Field whose default is the
        class name. If ``type`` is a FieldInfo its default is used, otherwise its value.

        Raises:
            Exception: if another class already registered the same edge type.
        """
        super().__init_subclass__(*args, **kwargs)

        if not hasattr(cls, "type"):
            cls.type = Field(exclude=True, default_factory=lambda: cls.__name__)
            edgetype = cls.__name__
        else:
            # XXX: not sure why this makes mypy angry here but not in Node.__init_subclass__
            if isinstance(cls.type, FieldInfo):  # type: ignore
                edgetype = cls.type.get_default(call_default_factory=True)  # type: ignore
            else:
                edgetype = cls.type

        if edgetype in edge_registry:
            raise Exception(
                f"edge_register can't register type '{edgetype}' because it has already been registered"
            )

        edge_registry[edgetype] = cls

    @classmethod
    def get_cache(cls) -> EdgeCache:
        """Return the module-level edge cache, creating it on first use with ``maxsize``
        taken from ``Config.get().edge_cache_size``.
        """
        global edge_cache
        if edge_cache is None:
            settings = Config.get()
            edge_cache = EdgeCache(maxsize=settings.edge_cache_size)

        return edge_cache

    @classmethod
    def get(cls, id: EdgeId, *, db: GraphDB | None = None) -> Self:
        """Looks up an Edge based on its ID. If the Edge is cached, the cached edge is returned;
        otherwise the Edge is queried from the graph database based on the ID provided and a new
        Edge is returned and cached.

        Args:
            id (EdgeId): the unique identifier for the Edge
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: returns the Edge requested by the id
        """
        cache = Edge.get_cache()
        e = cache.get(id)
        # test for None rather than truthiness so a cached edge whose subclass defines
        # __bool__/__len__ is never mistaken for a cache miss and reloaded
        if e is None:
            e = cls.load(id, db=db)
            cache[id] = e

        return cast(Self, e)

    @classmethod
    def load(cls, id: EdgeId, *, db: GraphDB | None = None) -> Self:
        """Loads an Edge from the graph database without attempting to check if the Edge
        already exists in the cache. Typically this is only called by Edge.get()

        Args:
            id (EdgeId): the unique identifier of the Edge to fetch
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            EdgeNotFound: if the specified ID does not exist in the database

        Returns:
            Self: returns the Edge requested by the id
        """
        db = db or GraphDB.singleton()
        edge_list = list(db.raw_fetch(f"MATCH (n)-[e]-(m) WHERE id(e) = {id} RETURN e LIMIT 1"))
        if not len(edge_list) == 1:
            raise EdgeNotFound(f"Couldn't find edge ID: {id}")

        e = edge_list[0]["e"]
        # records without a "properties" attribute contribute no extra fields
        props = {}
        if hasattr(e, "properties"):
            props = e.properties
        return cls(
            src_id=e.start_id,
            dst_id=e.end_id,
            _id=id,
            type=e.type,
            **props,
        )

    @classmethod
    def save(cls, e: Self, *, db: GraphDB | None = None) -> Self:
        """Saves the edge to the database. Calls Edge.create if the edge is new, or Edge.update if
        edge already exists in the database.

        Args:
            e (Self): The edge to save
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The same edge that was passed in, for convenience. The Edge may be updated with a
            new identifier if it was newly created in the database.
        """
        if e._new:
            return cls.create(e, db=db)
        else:
            return cls.update(e, db=db)

    @classmethod
    def create(cls, e: Self, *, db: GraphDB | None = None) -> Self:
        """Creates a new edge in the database. Typically only called by Edge.save

        New endpoint Nodes are saved first, then one MATCH/CREATE query writes the edge with
        the properties from ``Edge.to_dict``. The provisional id is then replaced by the
        database id in the edge cache and in the endpoints' ``src_edges``/``dst_edges``.

        Args:
            e (Self): The edge to create
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            EdgeCreateFailed: Failed to write the edge to the database, for example
                if the ID is wrong.

        Returns:
            Self: the edge that was created, with an updated identifier and other changed
                attributes. Returned untouched if the edge or either endpoint has ``_no_save``.
        """
        if e._no_save or e.src._no_save or e.dst._no_save:
            return e

        db = db or GraphDB.singleton()
        # remember the provisional id so cache entries and edge lists can be re-keyed below
        old_id = e.id

        # NOTE(review): the MATCH below uses e.src_id/e.dst_id, so this presumably relies on
        # Node.save updating those ids on this edge — confirm against Node.create
        if e.src._new:
            Node.save(e.src)

        if e.dst._new:
            Node.save(e.dst)

        params = {"props": Edge.to_dict(e)}

        ret = list(
            db.raw_fetch(
                f"""
                MATCH (src), (dst)
                WHERE id(src) = {e.src_id} AND id(dst) = {e.dst_id} 
                CREATE (src)-[e:{e.type} $props]->(dst)
                RETURN id(e) as e_id
                """,
                params=params,
            )
        )

        if len(ret) != 1:
            raise EdgeCreateFailed("failed to create new edge")

        e._id = ret[0]["e_id"]
        e._new = False
        # update the cache; if being called during __del__ then the cache entry may not exist,
        # and the KeyError then also skips re-inserting the edge under its new id
        try:
            cache = Edge.get_cache()
            del cache[old_id]
            cache[e.id] = e
        except KeyError:
            pass
        # update references to edge id
        e.src.src_edges.replace(old_id, e.id)
        e.dst.dst_edges.replace(old_id, e.id)

        return e

    @classmethod
    def update(cls, e: Self, *, db: GraphDB | None = None) -> Self:
        """Updates the edge in the database. Typically only called by Edge.save

        Replaces all of the stored edge's properties with ``Edge.to_dict(e)``.

        Args:
            e (Self): The edge to update
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The same edge that was passed in, for convenience
        """
        if e._no_save:
            return e

        db = db or GraphDB.singleton()

        params = {"props": Edge.to_dict(e)}

        db.raw_execute(f"MATCH ()-[e]->() WHERE id(e) = {e.id} SET e = $props", params=params)

        return e

    @classmethod
    def connect(
        cls,
        src: Node | NodeId,
        dst: Node | NodeId,
        edgetype: str | None = None,
        db: GraphDB | None = None,
        **kwargs: Any,
    ) -> Self:
        """Create a new edge object from ``src`` to ``dst`` and attach it to both Nodes.

        When called on the base ``Edge`` with an ``edgetype`` present in ``edge_registry``, the
        registered subclass is instantiated. A subclass provides the type from its ``type``
        field default; otherwise ``edgetype`` is used.

        Args:
            src (Node | NodeId): source node or its id
            dst (Node | NodeId): destination node or its id
            edgetype (str | None): edge type to use / registry key to look up
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
            **kwargs: extra fields passed to the edge constructor

        Raises:
            Exception: if no edge type can be determined

        Returns:
            Self: the new edge, already added to ``src.src_edges`` and ``dst.dst_edges``
        """
        db = db or GraphDB.singleton()
        src_id = Node.to_id(src)
        dst_id = Node.to_id(dst)
        src_node = Node.get(src_id, db=db)
        dst_node = Node.get(dst_id, db=db)

        clstype: str | None = None
        # lookup class in based on specified type
        if cls is Edge and edgetype in edge_registry:
            cls = edge_registry[edgetype]  # type: ignore

        # get type from class model
        if cls is not Edge:
            clstype = pydantic_get_default(cls, "type")

        # no class found, use edge type instead
        if clstype is None and edgetype is not None:
            clstype = edgetype

        # couldn't find any type
        if clstype is None:
            raise Exception("no Edge type provided")

        # check allowed_connections
        check_schema(cls, clstype, src_node, dst_node, db)

        e = cls(src_id=src_id, dst_id=dst_id, type=clstype, **kwargs)
        src_node.src_edges.add(e)
        dst_node.dst_edges.add(e)

        return e

    @staticmethod
    def delete(e: Edge, *, db: GraphDB | None = None) -> None:
        """Deletes the specified edge from the database. If the edge has not already been persisted
        to the database, it is only detached from its nodes, removed from the cache and marked
        deleted.

        Args:
            e (Edge): The edge to delete
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        """
        # flag first so a later save() (e.g. from __del__) becomes a no-op
        e._deleted = True
        e._no_save = True
        db = db or GraphDB.singleton()

        # remove e from src and dst nodes
        e.src.src_edges.discard(e)
        e.dst.dst_edges.discard(e)

        # remove from cache (local name avoids shadowing the module-level edge_cache global)
        cache = Edge.get_cache()
        if e.id in cache:
            del cache[e.id]

        # delete from db
        if not e._new:
            db.raw_execute(f"MATCH ()-[e]->() WHERE id(e) = {e.id} DELETE e")

    @staticmethod
    def to_dict(e: Edge, include_type: bool = False) -> dict[str, Any]:
        """Convert a Edge to a Python dictionary of its (non-excluded) fields.

        Args:
            e (Edge): the edge to convert
            include_type (bool): also add the edge's ``type`` under the "type" key
        """
        # XXX: "_id" is excluded defensively; pydantic does not normally dump private
        # (underscore) attributes, so this exclusion should be a no-op
        ret = e.model_dump(exclude={"_id"})

        if include_type and hasattr(e, "type"):
            ret["type"] = e.type
        return ret

    @staticmethod
    def to_id(e: Edge | EdgeId) -> EdgeId:
        """Return the id of an Edge; an EdgeId is returned unchanged."""
        if isinstance(e, Edge):
            return e.id
        else:
            return e

allowed_connections = Field(exclude=True, default=None) class-attribute instance-attribute

dst property

dst_id = Field(exclude=True) class-attribute instance-attribute

id property

new property

src property

src_id = Field(exclude=True) class-attribute instance-attribute

type = Field(exclude=True) class-attribute instance-attribute

__del__()

Source code in roc/graphdb.py
351
352
353
def __del__(self) -> None:
    """Finalizer: write the edge back through ``Edge.save`` when it is collected.

    ``Edge.save`` routes to ``create`` or ``update``; both return immediately when
    ``_no_save`` is set (as ``Edge.delete`` does), so deleted edges are skipped.
    """
    Edge.save(self)

__init__(**kwargs)

Source code in roc/graphdb.py
337
338
339
340
341
342
343
344
345
346
347
348
349
def __init__(
    self,
    **kwargs: Any,
):
    """Run pydantic initialization, then assign the edge id.

    An explicit ``_id`` keyword wins; otherwise a provisional id is reserved through
    ``get_next_new_edge_id()``. A negative id marks the edge as new and puts it in the
    edge cache.
    """
    super().__init__(**kwargs)

    # only reserve a provisional id when the caller did not supply one
    if "_id" in kwargs:
        edge_id = kwargs["_id"]
    else:
        edge_id = get_next_new_edge_id()
    self._id = edge_id

    if self._id >= 0:
        return

    self._new = True
    Edge.get_cache()[self.id] = self

__init_subclass__(*args, **kwargs)

Source code in roc/graphdb.py
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
    """Record each Edge subclass in ``edge_registry`` under its edge type.

    A subclass lacking a ``type`` attribute gets a ``type`` Field defaulting to the class
    name. A FieldInfo ``type`` contributes its default; any other value is used directly.

    Raises:
        Exception: when the resolved edge type is already registered.
    """
    super().__init_subclass__(*args, **kwargs)

    if hasattr(cls, "type"):
        declared = cls.type  # type: ignore
        # XXX: not sure why this makes mypy angry here but not in Node.__init_subclass__
        if isinstance(declared, FieldInfo):  # type: ignore
            edgetype = declared.get_default(call_default_factory=True)  # type: ignore
        else:
            edgetype = declared
    else:
        # no explicit type: default it to the class name
        cls.type = Field(exclude=True, default_factory=lambda: cls.__name__)
        edgetype = cls.__name__

    if edgetype in edge_registry:
        raise Exception(
            f"edge_register can't register type '{edgetype}' because it has already been registered"
        )
    edge_registry[edgetype] = cls

__repr__()

Source code in roc/graphdb.py
355
356
def __repr__(self) -> str:
    """Debug form ``Edge(<id> [<src_id>>><dst_id>])``."""
    template = "Edge({} [{}>>{}])"
    return template.format(self.id, self.src_id, self.dst_id)

connect(src, dst, edgetype=None, db=None, **kwargs) classmethod

Source code in roc/graphdb.py
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
@classmethod
def connect(
    cls,
    src: Node | NodeId,
    dst: Node | NodeId,
    edgetype: str | None = None,
    db: GraphDB | None = None,
    **kwargs: Any,
) -> Self:
    """Build an edge object from ``src`` to ``dst`` and add it to both Nodes' edge lists.

    Class/type resolution: on the base ``Edge`` with an ``edgetype`` found in
    ``edge_registry``, the registered subclass is used; a subclass supplies its type from
    the ``type`` field default; otherwise ``edgetype`` itself is the type.

    Args:
        src (Node | NodeId): source node or its id
        dst (Node | NodeId): destination node or its id
        edgetype (str | None): edge type / registry key
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        **kwargs: extra fields for the edge constructor

    Raises:
        Exception: when no edge type can be resolved

    Returns:
        Self: the new edge
    """
    db = db or GraphDB.singleton()
    src_id = Node.to_id(src)
    dst_id = Node.to_id(dst)
    src_node = Node.get(src_id, db=db)
    dst_node = Node.get(dst_id, db=db)

    edge_cls = cls
    if edge_cls is Edge and edgetype in edge_registry:
        edge_cls = edge_registry[edgetype]  # type: ignore

    resolved: str | None = None if edge_cls is Edge else pydantic_get_default(edge_cls, "type")
    if resolved is None:
        # fall back to the caller-supplied type (may itself be None)
        resolved = edgetype
    if resolved is None:
        raise Exception("no Edge type provided")

    # validate the connection (check_schema is defined elsewhere in the module)
    check_schema(edge_cls, resolved, src_node, dst_node, db)

    edge = edge_cls(src_id=src_id, dst_id=dst_id, type=resolved, **kwargs)
    src_node.src_edges.add(edge)
    dst_node.dst_edges.add(edge)
    return edge  # type: ignore

create(e, *, db=None) classmethod

Creates a new edge in the database. Typically only called by Edge.save

Parameters:

Name Type Description Default
e Self

The edge to create

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Raises:

Type Description
EdgeCreateFailed

Failed to write the edge to the database, for example if the ID is wrong.

Returns:

Name Type Description
Self Self

the edge that was created, with an updated identifier and other changed attributes

Source code in roc/graphdb.py
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
@classmethod
def create(cls, e: Self, *, db: GraphDB | None = None) -> Self:
    """Creates a new edge in the database. Typically only called by Edge.save

    New endpoint Nodes are saved first, then one MATCH/CREATE query writes the edge with
    the properties from ``Edge.to_dict``. The provisional id is then replaced by the
    database id in the edge cache and in the endpoints' ``src_edges``/``dst_edges``.

    Args:
        e (Self): The edge to create
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Raises:
        EdgeCreateFailed: Failed to write the edge to the database, for example
            if the ID is wrong.

    Returns:
        Self: the edge that was created, with an updated identifier and other changed
            attributes. Returned untouched if the edge or either endpoint has ``_no_save``.
    """
    if e._no_save or e.src._no_save or e.dst._no_save:
        return e

    db = db or GraphDB.singleton()
    # remember the provisional id so cache entries and edge lists can be re-keyed below
    old_id = e.id

    # NOTE(review): the MATCH below uses e.src_id/e.dst_id, so this presumably relies on
    # Node.save updating those ids on this edge — confirm against Node.create
    if e.src._new:
        Node.save(e.src)

    if e.dst._new:
        Node.save(e.dst)

    params = {"props": Edge.to_dict(e)}

    ret = list(
        db.raw_fetch(
            f"""
            MATCH (src), (dst)
            WHERE id(src) = {e.src_id} AND id(dst) = {e.dst_id} 
            CREATE (src)-[e:{e.type} $props]->(dst)
            RETURN id(e) as e_id
            """,
            params=params,
        )
    )

    if len(ret) != 1:
        raise EdgeCreateFailed("failed to create new edge")

    e._id = ret[0]["e_id"]
    e._new = False
    # update the cache; if being called during __del__ then the cache entry may not exist,
    # and the KeyError then also skips re-inserting the edge under its new id
    try:
        cache = Edge.get_cache()
        del cache[old_id]
        cache[e.id] = e
    except KeyError:
        pass
    # update references to edge id
    e.src.src_edges.replace(old_id, e.id)
    e.dst.dst_edges.replace(old_id, e.id)

    return e

delete(e, *, db=None) staticmethod

Deletes the specified edge from the database. If the edge has not already been persisted to the database, this marks the edge as deleted and returns.

Parameters:

Name Type Description Default
e Edge

The edge to delete

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None
Source code in roc/graphdb.py
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
@staticmethod
def delete(e: Edge, *, db: GraphDB | None = None) -> None:
    """Remove an edge: detach it from its nodes, evict it from the edge cache and, if it
    was already persisted, delete it from the database.

    Args:
        e (Edge): The edge to delete
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
    """
    # mark first so any later save() (e.g. from __del__) is a no-op
    e._deleted = True
    e._no_save = True
    db = db or GraphDB.singleton()

    # detach from both endpoints
    e.src.src_edges.discard(e)
    e.dst.dst_edges.discard(e)

    cache = Edge.get_cache()
    if e.id in cache:
        del cache[e.id]

    # unsaved edges never reached the database
    if e._new:
        return
    db.raw_execute(f"MATCH ()-[e]->() WHERE id(e) = {e.id} DELETE e")

get(id, *, db=None) classmethod

Looks up an Edge based on its ID. If the Edge is cached, the cached edge is returned; otherwise the Edge is queried from the graph database based on the ID provided and a new Edge is returned and cached.

Parameters:

Name Type Description Default
id EdgeId

the unique identifier for the Edge

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Self Self

returns the Edge requested by the id

Source code in roc/graphdb.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
@classmethod
def get(cls, id: EdgeId, *, db: GraphDB | None = None) -> Self:
    """Looks up an Edge based on its ID. If the Edge is cached, the cached edge is returned;
    otherwise the Edge is queried from the graph database based on the ID provided and a new
    Edge is returned and cached.

    Args:
        id (EdgeId): the unique identifier for the Edge
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Self: returns the Edge requested by the id
    """
    cache = Edge.get_cache()
    e = cache.get(id)
    # test for None rather than truthiness so a cached edge whose subclass defines
    # __bool__/__len__ is never mistaken for a cache miss and reloaded
    if e is None:
        e = cls.load(id, db=db)
        cache[id] = e

    return cast(Self, e)

get_cache() classmethod

Source code in roc/graphdb.py
378
379
380
381
382
383
384
385
@classmethod
def get_cache(cls) -> EdgeCache:
    """Return the shared edge cache, building it lazily on the first call.

    The cache is stored in the module-level ``edge_cache`` global and sized from
    ``Config.get().edge_cache_size``.
    """
    global edge_cache
    if edge_cache is not None:
        return edge_cache

    edge_cache = EdgeCache(maxsize=Config.get().edge_cache_size)
    return edge_cache

load(id, *, db=None) classmethod

Loads an Edge from the graph database without attempting to check if the Edge already exists in the cache. Typically this is only called by Edge.get()

Parameters:

Name Type Description Default
id EdgeId

the unique identifier of the Edge to fetch

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Raises:

Type Description
EdgeNotFound

if the specified ID does not exist in the cache or the database

Returns:

Name Type Description
Self Self

returns the Edge requested by the id

Source code in roc/graphdb.py
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
@classmethod
def load(cls, id: EdgeId, *, db: GraphDB | None = None) -> Self:
    """Fetch an Edge straight from the graph database, bypassing the cache.
    Usually invoked only from Edge.get().

    Args:
        id (EdgeId): the unique identifier of the Edge to fetch
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Raises:
        EdgeNotFound: if the query does not return exactly one edge

    Returns:
        Self: a new edge built from the stored relationship
    """
    db = db or GraphDB.singleton()
    rows = list(db.raw_fetch(f"MATCH (n)-[e]-(m) WHERE id(e) = {id} RETURN e LIMIT 1"))
    if len(rows) != 1:
        raise EdgeNotFound(f"Couldn't find edge ID: {id}")

    record = rows[0]["e"]
    # records without a "properties" attribute contribute no extra fields
    extra_props = getattr(record, "properties", {})
    return cls(
        src_id=record.start_id,
        dst_id=record.end_id,
        _id=id,
        type=record.type,
        **extra_props,
    )

save(e, *, db=None) classmethod

Saves the edge to the database. Calls Edge.create if the edge is new, or Edge.update if edge already exists in the database.

Parameters:

Name Type Description Default
e Self

The edge to save

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Self Self

The same edge that was passed in, for convenience. The Edge may be updated with a new identifier if it was newly created in the database.

Source code in roc/graphdb.py
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
@classmethod
def save(cls, e: Self, *, db: GraphDB | None = None) -> Self:
    """Persist ``e``: edges still flagged ``_new`` go through ``create``, others through
    ``update``.

    Args:
        e (Self): The edge to save
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Self: whatever ``create``/``update`` returns (the same edge, possibly with a new
        identifier after creation).
    """
    persist = cls.create if e._new else cls.update
    return persist(e, db=db)

to_dict(e, include_type=False) staticmethod

Convert a Edge to a Python dictionary

Source code in roc/graphdb.py
606
607
608
609
610
611
612
613
614
615
@staticmethod
def to_dict(e: Edge, include_type: bool = False) -> dict[str, Any]:
    """Dump an Edge's non-excluded fields into a plain dictionary.

    Args:
        e (Edge): the edge to convert
        include_type (bool): when True (and the edge has a ``type``), add it under "type"
    """
    # XXX: "_id" is excluded defensively; pydantic does not normally dump private
    # (underscore) attributes, so this exclusion should be a no-op
    data = e.model_dump(exclude={"_id"})
    if not include_type or not hasattr(e, "type"):
        return data
    data["type"] = e.type
    return data

to_id(e) staticmethod

Source code in roc/graphdb.py
617
618
619
620
621
622
@staticmethod
def to_id(e: Edge | EdgeId) -> EdgeId:
    """Normalize an Edge or EdgeId to an EdgeId (ids pass through untouched)."""
    return e.id if isinstance(e, Edge) else e

update(e, *, db=None) classmethod

Updates the edge in the database. Typically only called by Edge.save

Parameters:

Name Type Description Default
e Self

The edge to update

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Self Self

The same edge that was passed in, for convenience

Source code in roc/graphdb.py
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
@classmethod
def update(cls, e: Self, *, db: GraphDB | None = None) -> Self:
    """Overwrite the stored edge's properties with ``Edge.to_dict(e)``.
    Usually invoked only from Edge.save.

    Args:
        e (Self): The edge to update
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Self: ``e`` itself; nothing is written when ``e._no_save`` is set
    """
    if not e._no_save:
        target_db = db or GraphDB.singleton()
        props = Edge.to_dict(e)
        target_db.raw_execute(
            f"MATCH ()-[e]->() WHERE id(e) = {e.id} SET e = $props", params={"props": props}
        )
    return e

EdgeCreateFailed

Bases: Exception

Source code in roc/graphdb.py
292
293
class EdgeCreateFailed(Exception):
    """Raised by ``Edge.create`` when the CREATE query does not return exactly one edge id."""

EdgeDescription

Bases: ModelDescription

Source code in roc/graphdb.py
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
class EdgeDescription(ModelDescription):
    """Describes an Edge subclass for schema output: class name, edge type, allowed
    connections and the set of node names those connections mention.
    """

    def __init__(self, edge_cls: type[Edge]) -> None:
        super().__init__(edge_cls)

        self.edge_cls = edge_cls
        self.name = edge_cls.__name__
        self.edgetype = pydantic_get_default(edge_cls, "type")

        # allowed connections: pairs rendered as "conn[0] --> conn[1]" by to_mermaid
        self.allowed_connections = cast(
            EdgeConnectionsList, pydantic_get_default(edge_cls, "allowed_connections")
        )
        # NOTE(review): assert is stripped under `python -O`
        assert self.allowed_connections is not None

        # related nodes: every name appearing on either end of an allowed connection
        self.related_nodes: set[str] = {
            node for conn in self.allowed_connections for node in (conn[0], conn[1])
        }

    def __str__(self) -> str:
        return f"EdgeDesc({self.name})"

    @property
    def resolved_name(self) -> str:
        """The edge type, followed by the class name in parentheses when they differ."""
        if self.edgetype == self.name:
            return self.name

        return f"{self.edgetype} ({self.name})"

    def to_mermaid(self, indent: int = 4) -> str:
        """Render this edge as Mermaid text: a ``%%`` comment header followed by one
        ``src --> dst: label`` line per allowed connection, each prefixed by ``indent`` spaces.

        Args:
            indent (int): number of leading spaces per line (0 means no padding)
        """
        # " " * indent instead of a right-aligned format spec: the spec still emitted one
        # space when indent == 0
        pad = " " * indent
        label = self.resolved_name
        parts = [f"\n{pad}%% Edge: {label}\n"]

        # add connections
        parts.extend(
            f"{pad}{conn[0]} --> {conn[1]}: {label}\n" for conn in self.allowed_connections
        )

        return "".join(parts)

allowed_connections = cast(EdgeConnectionsList, pydantic_get_default(edge_cls, 'allowed_connections')) instance-attribute

edge_cls = edge_cls instance-attribute

edgetype = pydantic_get_default(edge_cls, 'type') instance-attribute

name = edge_cls.__name__ instance-attribute

related_nodes = set() instance-attribute

resolved_name property

__init__(edge_cls)

Source code in roc/graphdb.py
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
def __init__(self, edge_cls: type[Edge]) -> None:
    """Capture the edge class, its name, its type default and its allowed connections."""
    super().__init__(edge_cls)

    self.edge_cls = edge_cls
    self.name = edge_cls.__name__
    self.edgetype = pydantic_get_default(edge_cls, "type")

    declared = pydantic_get_default(edge_cls, "allowed_connections")
    self.allowed_connections = cast(EdgeConnectionsList, declared)
    assert self.allowed_connections is not None

    # every node name on either end of an allowed connection
    self.related_nodes: set[str] = {
        endpoint for conn in self.allowed_connections for endpoint in (conn[0], conn[1])
    }

__str__()

Source code in roc/graphdb.py
1806
1807
def __str__(self) -> str:
    """Short label of the form ``EdgeDesc(<class name>)``."""
    return "EdgeDesc({})".format(self.name)

to_mermaid(indent=4)

Source code in roc/graphdb.py
1816
1817
1818
1819
1820
1821
1822
1823
def to_mermaid(self, indent: int = 4) -> str:
    """Render this edge as Mermaid text: a ``%%`` comment header followed by one
    ``src --> dst: label`` line per allowed connection, each prefixed by ``indent`` spaces.

    Args:
        indent (int): number of leading spaces per line (0 means no padding)
    """
    # " " * indent instead of a right-aligned format spec: the spec still emitted one
    # space when indent == 0
    pad = " " * indent
    label = self.resolved_name
    parts = [f"\n{pad}%% Edge: {label}\n"]

    # add connections
    parts.extend(
        f"{pad}{conn[0]} --> {conn[1]}: {label}\n" for conn in self.allowed_connections
    )

    return "".join(parts)

EdgeFetchIterator

The implementation of an iterator for an EdgeList. Only intended to be used internally by EdgeList.

Source code in roc/graphdb.py
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
class EdgeFetchIterator:
    """Iterator used internally by EdgeList: walks a list of EdgeIds and resolves each one
    to an Edge with ``Edge.get`` only when ``next()`` reaches it.
    """

    def __init__(self, edge_list: list[EdgeId]):
        self.__edge_list = edge_list
        # index of the next id to resolve
        self.cur = 0

    def __iter__(self) -> EdgeFetchIterator:
        return self

    def __next__(self) -> Edge:
        position = self.cur
        if position < len(self.__edge_list):
            self.cur = position + 1
            return Edge.get(self.__edge_list[position])
        raise StopIteration

__edge_list = edge_list instance-attribute

cur = 0 instance-attribute

__init__(edge_list)

Source code in roc/graphdb.py
676
677
678
def __init__(self, edge_list: list[EdgeId]):
    """Start iteration at the first id of ``edge_list`` (the list is kept, not copied)."""
    self.cur = 0
    self.__edge_list = edge_list

__iter__()

Source code in roc/graphdb.py
680
681
def __iter__(self) -> EdgeFetchIterator:
    """An iterator is its own iterable."""
    return self

__next__()

Source code in roc/graphdb.py
683
684
685
686
687
688
689
def __next__(self) -> Edge:
    """Advance past the next stored ID and return its Edge.

    Raises:
        StopIteration: once every ID has been consumed.
    """
    position = self.cur
    if position < len(self.__edge_list):
        edge_id = self.__edge_list[position]
        self.cur = position + 1
        return Edge.get(edge_id)
    raise StopIteration

EdgeList

Bases: MutableSet[Edge | EdgeId], Mapping[int, Edge]

A list of Edges that is used by Node for keeping track of the connections it has. Implements interfaces for both a MutableSet (i.e. set()) and a Mapping (i.e. read-only list())

Source code in roc/graphdb.py
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
class EdgeList(MutableSet[Edge | EdgeId], Mapping[int, Edge]):
    """A list of Edges that is used by Node for keeping track of the connections it has.
    Implements interfaces for both a MutableSet (i.e. set()) and a Mapping (i.e. read-only list())

    Only EdgeIds are stored; Edge objects are looked up with ``Edge.get`` when
    the list is iterated, indexed, or filtered by edge properties.
    """

    def __init__(self, ids: Iterable[EdgeId]):
        # Copy into a private list so later mutation never touches the caller's iterable.
        self.__edges: list[EdgeId] = list(ids)

    def __iter__(self) -> EdgeFetchIterator:
        return EdgeFetchIterator(self.__edges)

    def __getitem__(self, key: int) -> Edge:
        return Edge.get(self.__edges[key])

    def __len__(self) -> int:
        return len(self.__edges)

    def __contains__(self, e: Any) -> bool:
        # Only Edges and integer IDs can be members; anything else cannot be.
        if not isinstance(e, (Edge, int)):
            return False

        return Edge.to_id(e) in self.__edges  # type: ignore

    def __add__(self, l2: EdgeList) -> EdgeList:
        return EdgeList(self.__edges + l2.__edges)

    def add(self, e: Edge | EdgeId) -> None:
        """Adds a new Edge to the list (no-op if its ID is already present)"""
        e_id = Edge.to_id(e)

        if e_id in self.__edges:
            return

        self.__edges.append(e_id)

    def discard(self, e: Edge | EdgeId) -> None:
        """Removes an edge from the list if it is present.

        Per the ``MutableSet.discard`` contract this is a no-op when the edge
        is not in the list. ``MutableSet.remove`` (inherited) is the variant
        that raises ``KeyError`` for a missing edge.
        """
        e_id = Edge.to_id(e)

        # list.remove() would raise ValueError for a missing ID, which breaks
        # the discard() contract that callers of a MutableSet rely on.
        if e_id in self.__edges:
            self.__edges.remove(e_id)

    def replace(self, old: Edge | EdgeId, new: Edge | EdgeId) -> None:
        """Replaces all instances of an old Edge with a new Edge. Useful for when an Edge is
        persisted to the graph database and its permanent ID is assigned
        """
        old_id = Edge.to_id(old)
        new_id = Edge.to_id(new)
        for i, e_id in enumerate(self.__edges):
            if e_id == old_id:
                self.__edges[i] = new_id

    def select(
        self,
        *,
        filter_fn: EdgeFilterFn | None = None,
        type: str | None = None,
        id: EdgeId | None = None,
    ) -> EdgeList:
        """Returns a new EdgeList holding only the edges that match every given
        criterion. Criteria left as None are not applied, and this list is not
        modified.

        Args:
            filter_fn (EdgeFilterFn | None, optional): Predicate called with each
                candidate Edge. It is only called for edges that already match
                ``id`` (when given).
            type (str | None, optional): Keep only edges whose ``type`` equals this.
            id (EdgeId | None, optional): Keep only edges with this ID.

        Returns:
            EdgeList: The matching edges, in their original order.
        """
        # Cheapest criterion first: matching on ID needs no Edge fetch.
        edge_ids = self.__edges if id is None else [e for e in self.__edges if e == id]

        if filter_fn is None and type is None:
            return EdgeList(edge_ids)

        selected: list[EdgeId] = []
        for e_id in edge_ids:
            # TODO: Edge.get_many() would be more efficient here if / when it
            # gets implemented
            edge = Edge.get(e_id)  # fetched once and shared by both checks
            if filter_fn is not None and not filter_fn(edge):
                continue
            if type is not None and edge.type != type:
                continue
            selected.append(e_id)

        return EdgeList(selected)

__edges = list(ids) instance-attribute

__add__(l2)

Source code in roc/graphdb.py
717
718
def __add__(self, l2: EdgeList) -> EdgeList:
    """Concatenate this list's edge IDs with ``l2``'s into a new EdgeList."""
    combined = self.__edges + l2.__edges
    return EdgeList(combined)

__contains__(e)

Source code in roc/graphdb.py
709
710
711
712
713
714
715
def __contains__(self, e: Any) -> bool:
    """Membership test accepting either an Edge or an integer edge ID.

    Any other kind of value is reported as not present.
    """
    if not isinstance(e, (Edge, int)):
        return False
    e_id = Edge.to_id(e)  # type: ignore
    return e_id in self.__edges

__getitem__(key)

Source code in roc/graphdb.py
703
704
def __getitem__(self, key: int) -> Edge:
    """Look up the edge ID stored at position ``key`` and return its Edge."""
    edge_id = self.__edges[key]
    return Edge.get(edge_id)

__init__(ids)

Source code in roc/graphdb.py
697
698
def __init__(self, ids: Iterable[EdgeId]):
    """Keep a private list copy of ``ids``."""
    edges: list[EdgeId] = [*ids]
    self.__edges = edges

__iter__()

Source code in roc/graphdb.py
700
701
def __iter__(self) -> EdgeFetchIterator:
    """Return an EdgeFetchIterator over the stored edge IDs."""
    iterator = EdgeFetchIterator(self.__edges)
    return iterator

__len__()

Source code in roc/graphdb.py
706
707
def __len__(self) -> int:
    """Number of edge IDs held."""
    edges = self.__edges
    return len(edges)

add(e)

Adds a new Edge to the list

Source code in roc/graphdb.py
720
721
722
723
724
725
726
727
def add(self, e: Edge | EdgeId) -> None:
    """Append the edge's ID unless that ID is already in the list."""
    e_id = Edge.to_id(e)
    if e_id not in self.__edges:
        self.__edges.append(e_id)

discard(e)

Removes an edge from the list

Source code in roc/graphdb.py
729
730
731
732
733
def discard(self, e: Edge | EdgeId) -> None:
    """Removes an edge from the list if it is present.

    Per the ``MutableSet.discard`` contract this is a no-op when the edge is
    not in the list. ``MutableSet.remove`` is the variant that raises
    ``KeyError`` for a missing edge.
    """
    e_id = Edge.to_id(e)

    # list.remove() would raise ValueError for a missing ID, which breaks the
    # discard() contract that callers of a MutableSet rely on.
    if e_id in self.__edges:
        self.__edges.remove(e_id)

replace(old, new)

Replaces all instances of an old Edge with a new Edge. Useful for when an Edge is persisted to the graph database and its permanent ID is assigned

Source code in roc/graphdb.py
735
736
737
738
739
740
741
742
743
def replace(self, old: Edge | EdgeId, new: Edge | EdgeId) -> None:
    """Swap every occurrence of ``old``'s ID for ``new``'s ID, in place.

    Useful once an Edge has been persisted to the graph database and has
    received its permanent ID.
    """
    old_id = Edge.to_id(old)
    new_id = Edge.to_id(new)
    for pos, current in enumerate(self.__edges):
        if current == old_id:
            self.__edges[pos] = new_id

select(*, filter_fn=None, type=None, id=None)

Source code in roc/graphdb.py
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
def select(
    self,
    *,
    filter_fn: EdgeFilterFn | None = None,
    type: str | None = None,
    id: EdgeId | None = None,
) -> EdgeList:
    """Returns a new EdgeList holding only the edges that match every given
    criterion. Criteria left as None are not applied, and this list is not
    modified.

    Args:
        filter_fn (EdgeFilterFn | None, optional): Predicate called with each
            candidate Edge. It is only called for edges that already match
            ``id`` (when given).
        type (str | None, optional): Keep only edges whose ``type`` equals this.
        id (EdgeId | None, optional): Keep only edges with this ID.

    Returns:
        EdgeList: The matching edges, in their original order.
    """
    # Cheapest criterion first: matching on ID needs no Edge fetch.
    edge_ids = self.__edges if id is None else [e for e in self.__edges if e == id]

    if filter_fn is None and type is None:
        return EdgeList(edge_ids)

    selected: list[EdgeId] = []
    for e_id in edge_ids:
        # TODO: Edge.get_many() would be more efficient here if / when it
        # gets implemented
        edge = Edge.get(e_id)  # fetched once and shared by both checks
        if filter_fn is not None and not filter_fn(edge):
            continue
        if type is not None and edge.type != type:
            continue
        selected.append(e_id)

    return EdgeList(selected)

EdgeNotFound

Bases: Exception

Source code in roc/graphdb.py
288
289
class EdgeNotFound(Exception):
    """Exception type for a missing Edge.

    NOTE(review): the raise sites are outside this view; the meaning is
    inferred from the name.
    """

ErrorSavingDuringDelWarning

Bases: Warning

An error that occurs while saving a Node during `__del__`

Source code in roc/graphdb.py
61
62
class ErrorSavingDuringDelWarning(Warning):
    """Warning category for an error encountered while saving a Node during ``__del__``."""

FieldDescription

Source code in roc/graphdb.py
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
class FieldDescription:
    """Describes one field of a pydantic model: its name, cleaned type,
    default value and ``exclude`` flag, for use when printing the schema.
    """

    def __init__(self, model: type[BaseModel], fieldname: str) -> None:
        info = pydantic_get_field(model, fieldname)
        self.model = model
        self.field_info = info
        self.name = fieldname
        # default factories are invoked so the concrete default can be shown
        self.default_val = info.get_default(call_default_factory=True)
        self.type = clean_annotation(info.annotation)
        self.exclude = info.exclude

    def __str__(self) -> str:
        return "Field({}: {} = {})".format(self.name, self.type, self.default_val)

    @property
    def default_val_str(self) -> str:
        """String form of the default value that is stable across runs.

        Set defaults are sorted first, because set iteration order is not
        reproducible.
        """
        value = self.default_val
        return str(sorted(value)) if isinstance(value, set) else str(value)

default_val = self.field_info.get_default(call_default_factory=True) instance-attribute

default_val_str property

Provides a reliable, reproducible string form of the default value for printing the schema.

exclude = self.field_info.exclude instance-attribute

field_info = pydantic_get_field(model, fieldname) instance-attribute

model = model instance-attribute

name = fieldname instance-attribute

type = clean_annotation(self.field_info.annotation) instance-attribute

__init__(model, fieldname)

Source code in roc/graphdb.py
1679
1680
1681
1682
1683
1684
1685
def __init__(self, model: type[BaseModel], fieldname: str) -> None:
    """Capture name, cleaned type, default and ``exclude`` flag of ``model.fieldname``."""
    info = pydantic_get_field(model, fieldname)
    self.model = model
    self.field_info = info
    self.name = fieldname
    # default factories are invoked so the concrete default can be shown
    self.default_val = info.get_default(call_default_factory=True)
    self.type = clean_annotation(info.annotation)
    self.exclude = info.exclude

__str__()

Source code in roc/graphdb.py
1687
1688
def __str__(self) -> str:
    """Render as ``Field(<name>: <type> = <default>)``."""
    return "Field({}: {} = {})".format(self.name, self.type, self.default_val)

GraphCache

Bases: LRUCache[CacheKey, CacheValue], Generic[CacheKey, CacheValue]

A generic cache that is used for both the Node cache and the Edge cache

Source code in roc/graphdb.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
class GraphCache(LRUCache[CacheKey, CacheValue], Generic[CacheKey, CacheValue]):
    """A generic cache that is used for both the Node cache and the Edge cache.

    Extends LRUCache with hit/miss counters that are updated by :meth:`get`.
    """

    def __init__(self, maxsize: int):
        super().__init__(maxsize=maxsize)
        self.hits = 0
        self.misses = 0

    def __str__(self) -> str:
        # Guard a zero-sized cache so printing statistics can't raise ZeroDivisionError.
        pct = self.currsize / self.maxsize * 100 if self.maxsize else 0.0
        return f"Size: {self.currsize}/{self.maxsize} ({pct:1.2f}%), Hits: {self.hits}, Misses: {self.misses}"

    def get(  # type: ignore [override]
        self,
        key: CacheKey,
        /,
        default: CacheValue | None = None,
    ) -> CacheValue | None:
        """Uses the specified CacheKey to fetch an object from the cache.

        A lookup counts as a hit whenever the key is present, even if the
        cached value itself is falsy.

        Args:
            key (CacheKey): The key to use to fetch the object
            default (CacheValue | None, optional): If the object isn't found,
                the default value to return. Defaults to None.

        Returns:
            CacheValue | None: The object from the cache, or ``default`` if the
            key is not cached.
        """
        if key in self:
            self.hits = self.hits + 1
            # Read through indexing so the normal cache read path is used.
            # NOTE(review): cachetools' LRUCache refreshes recency on
            # __getitem__ — confirm for the pinned cachetools version.
            return self[key]

        self.misses = self.misses + 1
        if self.currsize == self.maxsize:
            logger.warning(
                f"Cache miss and cache is full ({self.currsize}/{self.maxsize}). Cache may start thrashing and performance may be impaired."
            )
        return default

    def clear(self) -> None:
        """Clears out all items from the cache and resets the cache
        statistics
        """
        super().clear()
        self.hits = 0
        self.misses = 0

hits = 0 instance-attribute

misses = 0 instance-attribute

__init__(maxsize)

Source code in roc/graphdb.py
241
242
243
244
def __init__(self, maxsize: int):
    """Initialise the underlying LRU cache with ``maxsize`` and zero the hit/miss counters."""
    super().__init__(maxsize=maxsize)
    self.hits, self.misses = 0, 0

__str__()

Source code in roc/graphdb.py
246
247
def __str__(self) -> str:
    """Summarise occupancy and hit/miss statistics on one line.

    A cache whose ``maxsize`` is 0 reports 0.00% instead of raising
    ZeroDivisionError.
    """
    pct = self.currsize / self.maxsize * 100 if self.maxsize else 0.0
    return f"Size: {self.currsize}/{self.maxsize} ({pct:1.2f}%), Hits: {self.hits}, Misses: {self.misses}"

clear()

Clears out all items from the cache and resets the cache statistics

Source code in roc/graphdb.py
276
277
278
279
280
281
282
def clear(self) -> None:
    """Empty the cache via the parent class and zero the hit/miss statistics."""
    super().clear()
    self.hits = self.misses = 0

get(key, /, default=None)

Uses the specified CacheKey to fetch an object from the cache.

Parameters:

Name Type Description Default
key CacheKey

The key to use to fetch the object

required
default CacheValue | None

If the object isn't found, the default value to return. Defaults to None.

None

Returns:

Type Description
CacheValue | None

CacheValue | None: The object from the cache, or None if not found.

Source code in roc/graphdb.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def get(  # type: ignore [override]
    self,
    key: CacheKey,
    /,
    default: CacheValue | None = None,
) -> CacheValue | None:
    """Uses the specified CacheKey to fetch an object from the cache.

    A lookup counts as a hit whenever the key is present, even if the cached
    value itself is falsy.

    Args:
        key (CacheKey): The key to use to fetch the object
        default (CacheValue | None, optional): If the object isn't found,
            the default value to return. Defaults to None.

    Returns:
        CacheValue | None: The object from the cache, or ``default`` if the
        key is not cached.
    """
    if key in self:
        self.hits = self.hits + 1
        # Read through indexing so the normal cache read path is used.
        # NOTE(review): cachetools' LRUCache refreshes recency on
        # __getitem__ — confirm for the pinned cachetools version.
        return self[key]

    self.misses = self.misses + 1
    if self.currsize == self.maxsize:
        logger.warning(
            f"Cache miss and cache is full ({self.currsize}/{self.maxsize}). Cache may start thrashing and performance may be impaired."
        )
    return default

GraphDB

A graph database singleton. Settings for the graph database come from the config module.

Source code in roc/graphdb.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
class GraphDB:
    """A graph database singleton. Settings for the graph database come from the config module."""

    def __init__(self) -> None:
        settings = Config.get()
        self.host = settings.db_host
        self.port = settings.db_port
        self.encrypted = settings.db_conn_encrypted
        self.username = settings.db_username
        self.password = settings.db_password
        self.lazy = settings.db_lazy
        self.strict_schema = settings.db_strict_schema
        self.strict_schema_warns = settings.db_strict_schema_warns
        self.client_name = "roc-graphdb-client"
        self.db_conn = self.connect()
        self.closed = False

        if self.strict_schema:
            Schema.validate()

    def raw_fetch(
        self, query: str, *, params: dict[str, Any] | None = None
    ) -> Iterator[dict[str, Any]]:
        """Executes a Cypher query and returns the results as an iterator of
        dictionaries. Used for any query that has a 'RETURN' clause.

        Args:
            query (str): The Cypher query to execute
            params (dict[str, Any] | None, optional): Any parameters to pass to
                the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters

        Yields:
            Iterator[dict[str, Any]]: An iterator of the results from the database,
            one dict per row keyed by column name.
        """
        params = params or {}
        # Quote params like raw_execute does (the closing quote was missing).
        logger.trace(f"raw_fetch: '{query}' *** with params: *** '{params}'")

        cursor = self.db_conn.cursor()
        cursor.execute(query, params)
        while (row := cursor.fetchone()) is not None:
            yield {dsc.name: row[index] for index, dsc in enumerate(cursor.description)}

    def raw_execute(self, query: str, *, params: dict[str, Any] | None = None) -> None:
        """Executes a query with no return value. Used for 'SET', 'DELETE' or
        other queries without a 'RETURN' clause.

        Args:
            query (str): The Cypher query to execute
            params (dict[str, Any] | None, optional): Any parameters to pass to
                the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters
        """
        params = params or {}
        logger.trace(f"raw_execute: '{query}' *** with params: *** '{params}'")

        cursor = self.db_conn.cursor()
        cursor.execute(query, params)
        cursor.fetchall()

    def connected(self) -> bool:
        """Returns True if the database is connected, False otherwise"""
        return self.db_conn is not None and self.db_conn.status == mgclient.CONN_STATUS_READY

    def connect(self) -> mgclient.Connection:
        """Connects to the database and returns a Connection object (autocommit enabled)"""
        sslmode = mgclient.MG_SSLMODE_REQUIRE if self.encrypted else mgclient.MG_SSLMODE_DISABLE
        connection = mgclient.connect(
            host=self.host,
            port=self.port,
            username=self.username,
            password=self.password,
            sslmode=sslmode,
            lazy=self.lazy,
            client_name=self.client_name,
        )
        connection.autocommit = True
        return connection

    def close(self) -> None:
        """Closes the connection to the database"""
        self.db_conn.close()
        self.closed = True

    @classmethod
    def singleton(cls) -> GraphDB:
        """This returns a singleton object for the graph database. If the
        singleton isn't created yet, it creates it.
        """
        global graph_db_singleton
        if not graph_db_singleton:
            graph_db_singleton = GraphDB()

        assert graph_db_singleton.closed is False
        return graph_db_singleton

    @staticmethod
    def to_networkx(
        db: GraphDB | None = None,
        node_ids: set[NodeId] | None = None,
        filter: NodeFilterFn | None = None,
    ) -> nx.DiGraph:
        """Converts the entire graph database (and local cache of objects) into
        a NetworkX graph

        Args:
            db (GraphDB | None, optional): The database to convert to NetworkX.
                Defaults to the GraphDB singleton if not specified.
            node_ids (set[NodeId] | None, optional): The NodeIDs to add to the
                NetworkX graph. Defaults to all IDs when None; an explicitly
                empty set produces an empty graph.
            filter (NodeFilterFn | None, optional): A Node filter to filter out
                nodes before adding them to the NetworkX graph. Also useful for a
                callback that can be used for progress updates. Defaults to None.

        Returns:
            nx.DiGraph: A directed graph with one node per accepted Node (keyed by
            its id, attributes from ``Node.to_dict``) and one edge per outgoing
            edge (``src_edges``) of each accepted Node.
        """
        db = db or GraphDB.singleton()
        # Only fall back to "every node" when no IDs were given at all; an
        # explicitly empty set must not silently export the whole database.
        if node_ids is None:
            node_ids = Node.all_ids(db=db)
        filter = filter or true_filter
        G = nx.DiGraph()

        def nx_add(n: Node) -> None:
            n_data = Node.to_dict(n, include_labels=True)

            # TODO: this converts labels to a string, but maybe there's a better
            # way to preserve the list so that it can be used for filtering in
            # external programs
            # Sorted so the attribute string is stable between runs (set order is not).
            if "labels" in n_data and isinstance(n_data["labels"], set):
                n_data["labels"] = ", ".join(sorted(n_data["labels"]))

            G.add_node(n.id, **n_data)

            for e in n.src_edges:
                e_data = Edge.to_dict(e, include_type=True)
                G.add_edge(e.src_id, e.dst_id, **e_data)

        # iterate all specified node_ids, adding all of them to the nx graph
        def nx_add_many(nodes: list[Node]) -> None:
            for n in nodes:
                if filter(n):
                    nx_add(n)

        # NOTE(review): `db` is only passed to all_ids; get_many is called
        # without it — confirm it targets the same database.
        Node.get_many(node_ids, load_edges=True, progress_callback=nx_add_many)

        return G

client_name = 'roc-graphdb-client' instance-attribute

closed = False instance-attribute

db_conn = self.connect() instance-attribute

encrypted = settings.db_conn_encrypted instance-attribute

host = settings.db_host instance-attribute

lazy = settings.db_lazy instance-attribute

password = settings.db_password instance-attribute

port = settings.db_port instance-attribute

strict_schema = settings.db_strict_schema instance-attribute

strict_schema_warns = settings.db_strict_schema_warns instance-attribute

username = settings.db_username instance-attribute

__init__()

Source code in roc/graphdb.py
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def __init__(self) -> None:
    """Read the db_* settings from ``Config`` and open the connection.

    When ``db_strict_schema`` is enabled the schema is validated after the
    connection has been opened.
    """
    cfg = Config.get()

    # connection parameters consumed by connect()
    self.host = cfg.db_host
    self.port = cfg.db_port
    self.encrypted = cfg.db_conn_encrypted
    self.username = cfg.db_username
    self.password = cfg.db_password
    self.lazy = cfg.db_lazy
    self.client_name = "roc-graphdb-client"

    # schema enforcement options
    self.strict_schema = cfg.db_strict_schema
    self.strict_schema_warns = cfg.db_strict_schema_warns

    self.db_conn = self.connect()
    self.closed = False

    if self.strict_schema:
        Schema.validate()

close()

Closes the connection to the database

Source code in roc/graphdb.py
161
162
163
164
def close(self) -> None:
    """Close the underlying connection, then flag this GraphDB as closed."""
    conn = self.db_conn
    conn.close()
    self.closed = True

connect()

Connects to the database and returns a Connection object

Source code in roc/graphdb.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def connect(self) -> mgclient.Connection:
    """Open a new mgclient connection from this object's settings, with autocommit on."""
    if self.encrypted:
        sslmode = mgclient.MG_SSLMODE_REQUIRE
    else:
        sslmode = mgclient.MG_SSLMODE_DISABLE

    conn = mgclient.connect(
        host=self.host,
        port=self.port,
        username=self.username,
        password=self.password,
        sslmode=sslmode,
        lazy=self.lazy,
        client_name=self.client_name,
    )
    conn.autocommit = True
    return conn

connected()

Returns True if the database is connected, False otherwise

Source code in roc/graphdb.py
142
143
144
def connected(self) -> bool:
    """True only when a connection object exists and mgclient reports it ready."""
    conn = self.db_conn
    if conn is None:
        return False
    return conn.status == mgclient.CONN_STATUS_READY

raw_execute(query, *, params=None)

Executes a query with no return value. Used for 'SET', 'DELETE' or other queries without a 'RETURN' clause.

Parameters:

Name Type Description Default
query str

The Cypher query to execute

required
params dict[str, Any] | None

Any parameters to pass to the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters

None
Source code in roc/graphdb.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def raw_execute(self, query: str, *, params: dict[str, Any] | None = None) -> None:
    """Run a Cypher statement whose result is not needed ('SET', 'DELETE',
    or any query without a 'RETURN' clause).

    Args:
        query (str): The Cypher query to execute
        params (dict[str, Any] | None, optional): Parameters bound into the
            query; a falsy value is replaced by an empty dict. See also:
            https://memgraph.com/docs/querying/expressions#parameters
    """
    if not params:
        params = {}
    logger.trace(f"raw_execute: '{query}' *** with params: *** '{params}'")

    cur = self.db_conn.cursor()
    cur.execute(query, params)
    # Any result rows are fetched and discarded.
    # NOTE(review): presumably needed so the statement completes under mgclient.
    cur.fetchall()

raw_fetch(query, *, params=None)

Executes a Cypher query and returns the results as an iterator of dictionaries. Used for any query that has a 'RETURN' clause.

Parameters:

Name Type Description Default
query str

The Cypher query to execute

required
params dict[str, Any] | None

Any parameters to pass to the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters

None

Yields:

Type Description
dict[str, Any]

Iterator[dict[str, Any]]: An iterator of the results from the database.

Source code in roc/graphdb.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def raw_fetch(
    self, query: str, *, params: dict[str, Any] | None = None
) -> Iterator[dict[str, Any]]:
    """Executes a Cypher query and returns the results as an iterator of
    dictionaries. Used for any query that has a 'RETURN' clause.

    Args:
        query (str): The Cypher query to execute
        params (dict[str, Any] | None, optional): Any parameters to pass to
            the query. Defaults to None. See also: https://memgraph.com/docs/querying/expressions#parameters

    Yields:
        Iterator[dict[str, Any]]: An iterator of the results from the database,
        one dict per row keyed by column name.
    """
    params = params or {}
    # Quote params like raw_execute does (the closing quote was missing).
    logger.trace(f"raw_fetch: '{query}' *** with params: *** '{params}'")

    cursor = self.db_conn.cursor()
    cursor.execute(query, params)
    while (row := cursor.fetchone()) is not None:
        yield {dsc.name: row[index] for index, dsc in enumerate(cursor.description)}

singleton() classmethod

This returns a singleton object for the graph database. If the singleton isn't created yet, it creates it.

Source code in roc/graphdb.py
166
167
168
169
170
171
172
173
174
175
176
@classmethod
def singleton(cls) -> GraphDB:
    """Return the process-wide GraphDB stored in ``graph_db_singleton``,
    creating it on first use. Asserts that the instance has not been closed.
    """
    global graph_db_singleton
    db = graph_db_singleton
    if not db:
        db = graph_db_singleton = GraphDB()

    assert db.closed is False
    return db

to_networkx(db=None, node_ids=None, filter=None) staticmethod

Converts the entire graph database (and local cache of objects) into a NetworkX graph

Parameters:

Name Type Description Default
db GraphDB | None

The database to convert to NetworkX. Defaults to the GraphDB singleton if not specified.

None
node_ids set[NodeId] | None

The NodeIDs to add to the NetworkX graph. Defaults to all IDs if not specified.

None
filter NodeFilterFn | None

A Node filter to filter out nodes before adding them to the NetworkX graph. Also useful for a callback that can be used for progress updates. Defaults to None.

None

Returns:

Type Description
DiGraph

nx.DiGraph: The resulting NetworkX directed graph.

Source code in roc/graphdb.py
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
@staticmethod
def to_networkx(
    db: GraphDB | None = None,
    node_ids: set[NodeId] | None = None,
    filter: NodeFilterFn | None = None,
) -> nx.DiGraph:
    """Converts the entire graph database (and local cache of objects) into
    a NetworkX graph

    Args:
        db (GraphDB | None, optional): The database to convert to NetworkX.
            Defaults to the GraphDB singleton if not specified.
        node_ids (set[NodeId] | None, optional): The NodeIDs to add to the
            NetworkX graph. Defaults to all IDs when None; an explicitly
            empty set produces an empty graph.
        filter (NodeFilterFn | None, optional): A Node filter to filter out
            nodes before adding them to the NetworkX graph. Also useful for a
            callback that can be used for progress updates. Defaults to None.

    Returns:
        nx.DiGraph: A directed graph with one node per accepted Node (keyed by
        its id, attributes from ``Node.to_dict``) and one edge per outgoing
        edge (``src_edges``) of each accepted Node.
    """
    db = db or GraphDB.singleton()
    # Only fall back to "every node" when no IDs were given at all; an
    # explicitly empty set must not silently export the whole database.
    if node_ids is None:
        node_ids = Node.all_ids(db=db)
    filter = filter or true_filter
    G = nx.DiGraph()

    def nx_add(n: Node) -> None:
        n_data = Node.to_dict(n, include_labels=True)

        # TODO: this converts labels to a string, but maybe there's a better
        # way to preserve the list so that it can be used for filtering in
        # external programs
        # Sorted so the attribute string is stable between runs (set order is not).
        if "labels" in n_data and isinstance(n_data["labels"], set):
            n_data["labels"] = ", ".join(sorted(n_data["labels"]))

        G.add_node(n.id, **n_data)

        for e in n.src_edges:
            e_data = Edge.to_dict(e, include_type=True)
            G.add_edge(e.src_id, e.dst_id, **e_data)

    # iterate all specified node_ids, adding all of them to the nx graph
    def nx_add_many(nodes: list[Node]) -> None:
        for n in nodes:
            if filter(n):
                nx_add(n)

    # NOTE(review): `db` is only passed to all_ids; get_many is called
    # without it — confirm it targets the same database.
    Node.get_many(node_ids, load_edges=True, progress_callback=nx_add_many)

    return G

GraphDBInternalError

Bases: Exception

A generic exception for unexpected errors

Source code in roc/graphdb.py
65
66
class GraphDBInternalError(Exception):
    """A generic exception for unexpected errors."""

MethodDescription

Source code in roc/graphdb.py
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
class MethodDescription:
    """Introspected signature of one method of a model, used to render it in a
    class diagram.
    """

    def __init__(self, model: type[BaseModel], name: str) -> None:
        self.model = model
        self.name = name
        self.signature = inspect.signature(getattr(model, name))
        self.return_type = clean_annotation(self.signature.return_annotation)
        self.params = self.signature.parameters

    @property
    def uml_params(self) -> list[str]:
        """Parameter strings for a UML-style method signature.

        Each entry has the form ``"<type> <name> = <default>"``. The type part
        is omitted for unannotated parameters and the default part for
        parameters without a default. ``self`` is skipped.
        """
        # Public alias of the "no annotation / no default" marker; this
        # avoids relying on the private ``inspect._empty``.
        empty = inspect.Parameter.empty
        ret: list[str] = []

        for param_name, param in self.params.items():
            if param_name == "self":
                continue

            t = f"{clean_annotation(param.annotation)} " if param.annotation is not empty else ""
            default_val = f" = {param.default}" if param.default is not empty else ""
            ret.append(f"{t}{param_name}{default_val}")

        return ret

model = model instance-attribute

name = name instance-attribute

params = self.signature.parameters instance-attribute

return_type = clean_annotation(self.signature.return_annotation) instance-attribute

signature = inspect.signature(getattr(model, name)) instance-attribute

uml_params property

__init__(model, name)

Source code in roc/graphdb.py
1702
1703
1704
1705
1706
1707
def __init__(self, model: type[BaseModel], name: str) -> None:
    """Inspect ``model.<name>`` and keep its signature, cleaned return type and parameters."""
    sig = inspect.signature(getattr(model, name))
    self.model = model
    self.name = name
    self.signature = sig
    self.return_type = clean_annotation(sig.return_annotation)
    self.params = sig.parameters

ModelDescription

Source code in roc/graphdb.py
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
class ModelDescription:
    """Introspects a pydantic model and collects three things:

    - its fields, sorted by name;
    - a NodeDescription for each parent Node class, sorted by name;
    - the methods it defines beyond those of ``object``, ``BaseModel`` and
      ``Node``, sorted by name.
    """

    def __init__(self, model: type[BaseModel]) -> None:
        self.model = model

        # fields, alphabetical by name
        self.fields = sorted(
            (FieldDescription(model, fieldname) for fieldname in pydantic_get_fields(model)),
            key=lambda f: f.name,
        )

        # parent Node classes, alphabetical by name
        self.parent_class_names = get_node_parent_names(model)
        self.parents = sorted(
            (NodeDescription(node_registry[node_name]) for node_name in self.parent_class_names),
            key=lambda p: p.name,
        )

        # only the model's own methods: drop everything inherited from the base types
        inherited = get_methods(object) | get_methods(BaseModel) | get_methods(Node)
        self.method_names = get_methods(model) - inherited
        self.methods = sorted(
            (MethodDescription(model, name) for name in self.method_names),
            key=lambda m: m.name,
        )

fields = [FieldDescription(model, fieldname) for fieldname in pydantic_get_fields(model)] instance-attribute

method_names = get_methods(model) - get_methods(object) - get_methods(BaseModel) - get_methods(Node) instance-attribute

methods = [MethodDescription(model, name) for name in self.method_names] instance-attribute

model = model instance-attribute

parent_class_names = get_node_parent_names(model) instance-attribute

parents = [NodeDescription(node_registry[node_name]) for node_name in self.parent_class_names] instance-attribute

__init__(model)

Source code in roc/graphdb.py
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
def __init__(self, model: type[BaseModel]) -> None:
    """Collect sorted descriptions of a model's fields, parent Node classes and own methods."""
    self.model = model

    # fields: one description per pydantic field, alphabetized
    field_descs = [FieldDescription(model, fname) for fname in pydantic_get_fields(model)]
    field_descs.sort(key=lambda fd: fd.name)
    self.fields = field_descs

    # parents: registered Node classes this model inherits from, alphabetized
    self.parent_class_names = get_node_parent_names(model)
    parent_descs = [NodeDescription(node_registry[pname]) for pname in self.parent_class_names]
    parent_descs.sort(key=lambda nd: nd.name)
    self.parents = parent_descs

    # methods: only those not already provided by object, BaseModel or Node
    own_methods = get_methods(model)
    self.method_names = (
        own_methods - get_methods(object) - get_methods(BaseModel) - get_methods(Node)
    )
    method_descs = [MethodDescription(model, mname) for mname in self.method_names]
    method_descs.sort(key=lambda md: md.name)
    self.methods = method_descs

Node

Bases: BaseModel

A graph database node that automatically handles CRUD for the underlying graph database objects.

Source code in roc/graphdb.py
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
class Node(BaseModel, extra="allow"):
    """An graph database node that automatically handles CRUD for the underlying graph database objects"""

    _id: NodeId
    labels: set[str] = Field(exclude=True, default_factory=set)
    _orig_labels: set[str]
    _src_edges: EdgeList
    _dst_edges: EdgeList
    _db: GraphDB
    _new = False
    _no_save = False
    _deleted = False

    @property
    def id(self) -> NodeId:
        """The unique ID of the node"""
        return self._id

    @property
    def src_edges(self) -> EdgeList:
        """All Edges that originate at this Node"""
        return self._src_edges

    @property
    def dst_edges(self) -> EdgeList:
        """All Edges that terminate at this Node"""
        return self._dst_edges

    @property
    def edges(self) -> EdgeList:
        """All Edges attached to this Node, regardless of direction"""
        return self._src_edges + self._dst_edges

    @property
    def predecessors(self) -> NodeList:
        """All Nodes connected with an directed Edge that ends with this node.
        Also referred to as an 'in-neighbor'.
        """
        return NodeList([e.src.id for e in self.dst_edges])

    @property
    def successors(self) -> NodeList:
        """All Nodes connected with an directed Edge that starts with this node.
        Also referred to as an 'out-neighbor'.
        """
        return NodeList([e.dst.id for e in self.src_edges])

    @property
    def neighbors(self) -> NodeList:
        """All adjacent nodes, regardless of edge direction"""
        return self.successors + self.predecessors

    @property
    def new(self) -> bool:
        """Whether or not this Node is new (not saved to the database yet)"""
        return self._new

    def __init__(
        self,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)

        # set passed-in private values or their defaults
        self._db = kwargs["_db"] if "_db" in kwargs else GraphDB.singleton()
        self._id = kwargs["_id"] if "_id" in kwargs else get_next_new_node_id()
        self._src_edges = kwargs["_src_edges"] if "_src_edges" in kwargs else EdgeList([])
        self._dst_edges = kwargs["_dst_edges"] if "_dst_edges" in kwargs else EdgeList([])

        if self.id < 0:
            self._new = True  # TODO: derived?
            Node.get_cache()[self.id] = self

        self._orig_labels = self.labels.copy()

    def __del__(self) -> None:
        # print("Node.__del__:", self)
        try:
            self.__class__.save(self, db=self._db)
        except Exception as e:
            err_msg = f"error saving during del: {e}"
            # logger.warning(err_msg)
            warnings.warn(err_msg, ErrorSavingDuringDelWarning)

    def __repr__(self) -> str:
        return f"Node({self.id})"

    def __str__(self) -> str:
        return f"Node({self.id}, labels={self.labels})"

    def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
        super().__init_subclass__(*args, **kwargs)
        clsname = cls.__name__

        if not hasattr(cls, "labels"):
            new_lbls = {c.__name__ for c in cls.__mro__ if c not in [Node, BaseModel, object]}

            def default_subclass_fields() -> set[str]:
                return new_lbls

            cls.labels = Field(default_factory=default_subclass_fields, exclude=True)
            labels_key = frozenset(new_lbls)
        else:
            if isinstance(cls.labels, FieldInfo):
                labels_key = frozenset(cls.labels.get_default(call_default_factory=True))
            else:
                labels_key = frozenset(cls.labels)

        if clsname in node_registry:
            raise Exception(
                f"""node_register can't register '{clsname}' because that name has already been registered"""
            )

        if labels_key in node_label_registry:
            labels = ", ".join(sorted(list(labels_key)))
            raise Exception(
                f"""node_register can't register labels '{labels}' because they have already been registered"""
            )

        node_registry[clsname] = cls
        node_label_registry[labels_key] = cls

    @classmethod
    def load(cls, id: NodeId, *, db: GraphDB | None = None) -> Self:
        """Loads a node from the database. Use `Node.get` or other methods instead.

        Args:
            id (NodeId): The identifier of the node to fetch
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            NodeNotFound: The node specified by the identifier does not exist in the database
            GraphDBInternalError: If the requested ID returns multiple nodes

        Returns:
            Self: The node from the database
        """
        res = cls.load_many({id}, db=db)

        # print("RES", res)

        if len(res) < 1:
            raise NodeNotFound(f"Couldn't find node ID: {id}")

        if len(res) > 1:
            raise GraphDBInternalError(
                f"Too many nodes returned while trying to load single node: {id}"
            )

        return res[0]

    @classmethod
    def load_many(
        cls,
        node_set: set[NodeId],
        db: GraphDB | None = None,
        load_edges: bool = False,
    ) -> list[Self]:
        db = db or GraphDB.singleton()
        node_ids = ",".join(map(str, node_set))

        ret = cls.find(
            where=f"id(src) IN [{node_ids}]",  # TODO: use params?
            db=db,
            load_edges=load_edges,
        )

        if len(ret) != len(node_set):
            id_set = {n.id for n in ret}
            missing_ids = node_set - id_set
            raise NodeNotFound(f"Couldn't find node IDs: {', '.join(map(str, missing_ids))}")

        return ret

    @classmethod
    def find(
        cls,
        where: str,
        src_node_name: str = "src",
        src_labels: set[str] = set(),
        edge_name: str = "e",
        edge_type: str = "",
        params: QueryParamType = dict(),
        db: GraphDB | None = None,
        load_edges: bool = False,
        params_to_str: bool = True,
    ) -> list[Self]:
        db = db or GraphDB.singleton()

        if load_edges:
            edge_fmt = f"{edge_name}"
        else:
            edge_fmt = f"{{id: id({edge_name}), start: id(startNode({edge_name})), end: id(endNode({edge_name}))}}"

        if len(src_labels) == 0:
            src_label_str = ""
        else:
            src_label_str = f":{':'.join(src_labels)}"

        if len(edge_type) > 0:
            edge_type = ":" + edge_type

        if params_to_str:
            for k in params.keys():
                params[k] = str(params[k])

        res_iter = db.raw_fetch(
            f"""
                MATCH ({src_node_name}{src_label_str})-[{edge_name}{edge_type}*0..1]-() 
                WITH {src_node_name}, head({edge_name}) AS {edge_name}
                WHERE {where}
                RETURN {src_node_name} AS n, collect({edge_fmt}) AS edges
                """,
            params=params,
        )

        ret_list = list()
        for r in res_iter:
            logger.trace(f"find result: {r}")
            n = r["n"]
            if n is None:
                # NOTE: I can't think of any circumstances where there would be
                # multiple "None" results, so I think this is just an empty list
                continue

            if load_edges:
                # XXX: memgraph converts edges to Relationship objects if you
                # return the whole edge
                src_edges = list()
                dst_edges = list()
                edge_cache = Edge.get_cache()
                for e in r["edges"]:
                    # add edge_id to to the right list for the node creation below
                    if n.id == e.start_id:
                        src_edges.append(e.id)
                    else:
                        dst_edges.append(e.id)

                    # edge already loaded, continue to next one
                    if e.id in edge_cache:
                        continue

                    # create a new edge
                    props = {}
                    if hasattr(e, "properties"):
                        props = e.properties
                    new_edge = Edge(
                        src_id=e.start_id,
                        dst_id=e.end_id,
                        _id=e.id,
                        type=e.type,
                        **props,
                    )
                    edge_cache[e.id] = new_edge
            else:
                # edges are just the IDs
                src_edges = [e["id"] for e in r["edges"] if e["start"] == n.id]
                dst_edges = [e["id"] for e in r["edges"] if e["end"] == n.id]

            node_cache = cls.get_cache()
            if n.id in node_cache:
                new_node = cast(Self, node_cache[n.id])
            else:
                mkcls = cls
                cls_lbls = frozenset(n.labels)
                if cls is Node and cls_lbls in node_label_registry:
                    mkcls = cast(type[Self], node_label_registry[cls_lbls])
                new_node = mkcls(
                    _id=n.id,
                    _src_edges=EdgeList(src_edges),
                    _dst_edges=EdgeList(dst_edges),
                    labels=n.labels,
                    **n.properties,
                )
                node_cache[n.id] = new_node
            ret_list.append(new_node)

        return ret_list

    @classmethod
    def find_one(
        cls,
        where: str,
        src_node_name: str = "src",
        src_labels: set[str] = set(),
        edge_name: str = "e",
        edge_type: str = "",
        params: QueryParamType = dict(),
        db: GraphDB | None = None,
        load_edges: bool = False,
        params_to_str: bool = True,
        exactly_one: bool = False,
    ) -> Self | None:
        """Finds a single Node.find results down to a single node. Raises an
        exception of the list contains more than one node.

        Args:
            nodes (Sequence[NodeType]): The list of nodes returned by Node.find

        Raises:
            Exception: Raised if there is more than 1 node in the list

        Returns:
            NodeType | None: Returns None if the list is empty, or the node in the list.
        """
        nodes = cls.find(
            where=where,
            src_node_name=src_node_name,
            src_labels=src_labels,
            edge_name=edge_name,
            edge_type=edge_type,
            params=params,
            db=db,
            load_edges=load_edges,
            params_to_str=params_to_str,
        )

        match len(nodes):
            case 0:
                if exactly_one:
                    raise Exception("expect exactly one node in find_one")
                return None
            case 1:
                return nodes[0]
            case _:
                raise Exception("expected zero or one node in find_one")

    @classmethod
    def get_many(
        cls,
        node_ids: Collection[NodeId],
        *,
        batch_size: int = 128,
        db: GraphDB | None = None,
        load_edges: bool = False,
        return_nodes: bool = False,
        progress_callback: ProgressFn | None = None,
    ) -> list[Node]:
        db = db or GraphDB.singleton()

        if not isinstance(node_ids, set):
            node_ids = set(node_ids)

        c = Node.get_cache()
        if len(node_ids) > c.maxsize:
            raise GraphDBInternalError(
                f"get_many attempting to load more nodes than cache size ({len(node_ids)} > {c.maxsize})"
            )

        cache_ids = set(c.keys())
        fetch_ids = node_ids - cache_ids

        start = 0
        curr = batch_size
        ret_list = [c[nid] for nid in c]
        if progress_callback:
            progress_callback(ret_list)
        while start < len(fetch_ids):
            id_set = set(islice(fetch_ids, start, curr))

            res = cls.load_many(id_set, db=db, load_edges=load_edges)
            for n in res:
                c[n.id] = n

            if progress_callback:
                progress_callback(res)

            ret_list.extend(res)
            # import pprint
            # pprint.pp(list(res))
            # print(f"got {len(list(res))} nodes")

            start = curr
            curr += batch_size

        assert len(ret_list) == len(node_ids)
        return ret_list

    @classmethod
    def get_cache(cls) -> NodeCache:
        global node_cache
        if node_cache is None:
            settings = Config.get()
            node_cache = NodeCache(settings.node_cache_size)

        return node_cache

    @classmethod
    def get(cls, id: NodeId, *, db: GraphDB | None = None) -> Self:
        """Returns a cached node with the specified id. If no node is cached, it is retrieved from
        the database.


        Args:
            id (NodeId): The unique identifier of the node to fetch
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: the cached or newly retrieved node
        """
        cache = Node.get_cache()
        n = cache.get(id)
        if n is None:
            n = cls.load(id, db=db)

        return cast(Self, n)

    @classmethod
    def save(cls, n: Self, *, db: GraphDB | None = None) -> Self:
        """Save a node to persistent storage

        Writes the specified node to the GraphDB for persistent storage. If the node does not
        already exist in storage, it is created via the `create` method. If the node does exist, it
        is updated via the `update` method.

        If the _no_save flag is True on the node, the save request will be silently ignored.

        Args:
            n (Self): The Node to be saved
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: As a convenience, the node that was stored is returned. This may be useful
            since the the id of the node may change if it was created in the database.
        """
        if n._new:
            return cls.create(n, db=db)
        else:
            return cls.update(n, db=db)

    @classmethod
    def update(cls, n: Self, *, db: GraphDB | None = None) -> Self:
        """Update an existing node in the GraphDB.

        Calling `save` is preferred to using this method so that the caller doesn't need to know the
        state of the node.

        Args:
            n (Self): The node to be updated
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Self: The node that was passed in, for convenience
        """
        if n._no_save:
            return n

        db = db or GraphDB.singleton()

        orig_labels = n._orig_labels
        curr_labels = set(n.labels)
        new_labels = curr_labels - orig_labels
        rm_labels = orig_labels - curr_labels
        set_label_str = Node.mklabels(new_labels)
        if set_label_str:
            set_query = f"SET n{set_label_str}, n = $props"
        else:
            set_query = "SET n = $props"
        rm_label_str = Node.mklabels(rm_labels)
        if rm_label_str:
            rm_query = f"REMOVE n{rm_label_str}"
        else:
            rm_query = ""

        params = {"props": Node.to_dict(n)}

        db.raw_execute(f"MATCH (n) WHERE id(n) = {n.id} {set_query} {rm_query}", params=params)

        return n

    @classmethod
    def create(cls, n: Self, *, db: GraphDB | None = None) -> Self:
        """Creates the specified node in the GraphDB.

        Calling `save` is preferred to using this method so that the caller doesn't need to know the
        state of the node.

        Args:
            n (Self): the node to be created
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Raises:
            NodeCreationFailed: if creating the node failed in the database

        Returns:
            Self: the node that was passed in, albeit with a new `id` and potenitally other new
            fields
        """
        if n._no_save:
            return n

        db = db or GraphDB.singleton()
        old_id = n.id

        label_str = Node.mklabels(n.labels)
        params = {"props": Node.to_dict(n)}

        res = list(db.raw_fetch(f"CREATE (n{label_str} $props) RETURN id(n) as id", params=params))

        if not len(res) >= 1:
            raise NodeCreationFailed(f"Couldn't create node ID: {id}")

        new_id = res[0]["id"]
        n._id = new_id
        n._new = False
        # update the cache; if being called during c then the cache entry may not exist
        try:
            cache = Node.get_cache()
            del cache[old_id]
            cache[new_id] = n
        except KeyError:
            pass

        for e in n.src_edges:
            assert e.src_id == old_id
            e.src_id = new_id

        for e in n.dst_edges:
            assert e.dst_id == old_id
            e.dst_id = new_id

        return n

    @classmethod
    def connect(
        cls,
        src: NodeId | Self,
        dst: NodeId | Self,
        type: str | None = None,
        *,
        db: GraphDB | None = None,
    ) -> Edge:
        """Connects two nodes (creates an Edge between two nodes)

        Args:
            src (NodeId | Node): The Node to use at the start of the connection
            dst (NodeId | Node): The Node to use at the end of the connection
            type (str): The type of the edge to use for the connection
            db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

        Returns:
            Edge: The Edge that was created
        """
        return Edge.connect(src, dst, type, db=db)

    @staticmethod
    def delete(n: Node, *, db: GraphDB | None = None) -> None:
        db = db or GraphDB.singleton()

        # remove edges
        for e in n.src_edges:
            Edge.delete(e)

        for e in n.dst_edges:
            Edge.delete(e)

        # remove from cache
        node_cache = Node.get_cache()
        if n.id in node_cache:
            del node_cache[n.id]

        if not n._new:
            db.raw_execute(f"MATCH (n) WHERE id(n) = {n.id} DELETE n")

        n._deleted = True
        n._no_save = True

    @staticmethod
    def to_dict(n: Node, include_labels: bool = False) -> dict[str, Any]:
        """Convert a Node to a Python dictionary"""
        # XXX: the excluded fields below shouldn't have been included in the
        # first place because Pythonic should exclude fields with underscores
        ret = n.model_dump(exclude={"_id", "_src_edges", "_dst_edges"})

        if include_labels and hasattr(n, "labels"):
            ret["labels"] = n.labels

        return ret

    @staticmethod
    def mklabels(labels: set[str]) -> str:
        """Converts a list of strings into proper Cypher syntax for a graph database query"""
        labels_list = [i for i in labels]
        labels_list.sort()
        label_str = ":".join(labels_list)
        if len(label_str) > 0:
            label_str = ":" + label_str

        return label_str

    @staticmethod
    def all_ids(db: GraphDB | None = None) -> set[NodeId]:
        """Returns an exhaustive Set of all NodeIds that exist in both the graph
        database and the NodeCache
        """
        db = db or GraphDB.singleton()

        # get all NodeIds in the cache
        c = Node.get_cache()
        cached_ids = set(c.keys())

        # get all NodeIds in the database
        db_ids = {n["id"] for n in db.raw_fetch("MATCH (n) RETURN id(n) as id")}

        # return the combination of both
        return db_ids.union(cached_ids)

    @staticmethod
    def to_id(n: Node | NodeId) -> NodeId:
        if isinstance(n, Node):
            return n.id
        else:
            return n

    @staticmethod
    def walk(
        n: Node,
        *,
        mode: WalkMode = "both",
        edge_filter: EdgeFilterFn | None = None,
        # edge_callback: EdgeCallbackFn | None = None,
        node_filter: NodeFilterFn | None = None,
        node_callback: NodeCallbackFn | None = None,
        _walk_history: set[int] | None = None,
    ) -> None:
        # if we have walked this node before, just return
        _walk_history = _walk_history or set()
        if n.id in _walk_history:
            return
        _walk_history.add(n.id)

        edge_filter = edge_filter or cast(EdgeFilterFn, true_filter)
        node_filter = node_filter or true_filter
        # edge_callback = edge_callback or no_callback
        node_callback = node_callback or no_callback

        # callback for this node, if not filtered
        if node_filter(n):
            node_callback(n)
        else:
            return

        if mode == "src" or mode == "both":
            for e in n.src_edges:
                if edge_filter(e):
                    Node.walk(
                        e.dst,
                        mode=mode,
                        edge_filter=edge_filter,
                        # edge_callback=edge_callback,
                        node_filter=node_filter,
                        node_callback=node_callback,
                        _walk_history=_walk_history,
                    )

        if mode == "dst" or mode == "both":
            for e in n.dst_edges:
                if edge_filter(e):
                    Node.walk(
                        e.src,
                        mode=mode,
                        edge_filter=edge_filter,
                        # edge_callback=edge_callback,
                        node_filter=node_filter,
                        node_callback=node_callback,
                        _walk_history=_walk_history,
                    )

dst_edges property

All Edges that terminate at this Node

edges property

All Edges attached to this Node, regardless of direction

id property

The unique ID of the node

labels = Field(exclude=True, default_factory=set) class-attribute instance-attribute

neighbors property

All adjacent nodes, regardless of edge direction

new property

Whether or not this Node is new (not saved to the database yet)

predecessors property

All Nodes connected with a directed Edge that ends with this node. Also referred to as an 'in-neighbor'.

src_edges property

All Edges that originate at this Node

successors property

All Nodes connected with a directed Edge that starts with this node. Also referred to as an 'out-neighbor'.

__del__()

Source code in roc/graphdb.py
860
861
862
863
864
865
866
867
def __del__(self) -> None:
    """Persist the node when it is garbage collected.

    This is best effort: any failure is downgraded to an ErrorSavingDuringDelWarning,
    because an exception must never escape a finalizer.
    """
    try:
        type(self).save(self, db=self._db)
    except Exception as exc:
        warnings.warn(f"error saving during del: {exc}", ErrorSavingDuringDelWarning)

__init__(**kwargs)

Source code in roc/graphdb.py
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
def __init__(
    self,
    **kwargs: Any,
):
    """Build the pydantic model, then fill the private bookkeeping attributes from
    `kwargs` when supplied, or from their defaults otherwise."""
    super().__init__(**kwargs)

    # Each default is computed only when the caller did not supply a value, so
    # GraphDB.singleton() / get_next_new_node_id() run only when actually needed.
    if "_db" in kwargs:
        self._db = kwargs["_db"]
    else:
        self._db = GraphDB.singleton()

    if "_id" in kwargs:
        self._id = kwargs["_id"]
    else:
        self._id = get_next_new_node_id()

    if "_src_edges" in kwargs:
        self._src_edges = kwargs["_src_edges"]
    else:
        self._src_edges = EdgeList([])

    if "_dst_edges" in kwargs:
        self._dst_edges = kwargs["_dst_edges"]
    else:
        self._dst_edges = EdgeList([])

    # negative ids mark nodes that have not been created in the database yet
    is_unsaved = self.id < 0
    if is_unsaved:
        self._new = True  # TODO: derived?
        Node.get_cache()[self.id] = self

    self._orig_labels = self.labels.copy()

__init_subclass__(*args, **kwargs)

Source code in roc/graphdb.py
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
    """Register each Node subclass in `node_registry` (by class name) and in
    `node_label_registry` (by its frozen set of labels).

    Raises:
        Exception: if the class name, or the exact same label set, is already registered
    """
    super().__init_subclass__(*args, **kwargs)
    clsname = cls.__name__

    if not hasattr(cls, "labels"):
        # default labels: every class name in the MRO except Node, BaseModel and object
        new_lbls = {c.__name__ for c in cls.__mro__ if c not in [Node, BaseModel, object]}

        def default_subclass_fields() -> set[str]:
            # return a fresh set on every call so instances never share (and
            # mutate) a single labels object
            return set(new_lbls)

        cls.labels = Field(default_factory=default_subclass_fields, exclude=True)
        labels_key = frozenset(new_lbls)
    else:
        # `labels` is either a pydantic FieldInfo (evaluate its default, running any
        # default_factory) or a plain iterable of label names
        if isinstance(cls.labels, FieldInfo):
            labels_key = frozenset(cls.labels.get_default(call_default_factory=True))
        else:
            labels_key = frozenset(cls.labels)

    if clsname in node_registry:
        raise Exception(
            f"""node_register can't register '{clsname}' because that name has already been registered"""
        )

    if labels_key in node_label_registry:
        labels = ", ".join(sorted(list(labels_key)))
        raise Exception(
            f"""node_register can't register labels '{labels}' because they have already been registered"""
        )

    node_registry[clsname] = cls
    node_label_registry[labels_key] = cls

__repr__()

Source code in roc/graphdb.py
869
870
def __repr__(self) -> str:
    """Compact representation showing only the node's id."""
    node_id = self.id
    return f"Node({node_id})"

__str__()

Source code in roc/graphdb.py
872
873
def __str__(self) -> str:
    """Human-readable form with the node id and its labels."""
    node_id, node_labels = self.id, self.labels
    return f"Node({node_id}, labels={node_labels})"

all_ids(db=None) staticmethod

Returns an exhaustive Set of all NodeIds that exist in both the graph database and the NodeCache

Source code in roc/graphdb.py
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
@staticmethod
def all_ids(db: GraphDB | None = None) -> set[NodeId]:
    """Return every NodeId known either to the node cache or to the graph database.

    Args:
        db: the graph database to query, or None to use the GraphDB singleton

    Returns:
        The union of the cached ids and the ids stored in the database.
    """
    graph = db or GraphDB.singleton()

    # ids currently held in memory (this includes unsaved nodes)
    ids_in_cache = set(Node.get_cache().keys())

    # ids persisted in the database
    ids_in_db = set()
    for row in graph.raw_fetch("MATCH (n) RETURN id(n) as id"):
        ids_in_db.add(row["id"])

    return ids_in_db.union(ids_in_cache)

connect(src, dst, type=None, *, db=None) classmethod

Connects two nodes (creates an Edge between two nodes)

Parameters:

Name Type Description Default
src NodeId | Node

The Node to use at the start of the connection

required
dst NodeId | Node

The Node to use at the end of the connection

required
type str

The type of the edge to use for the connection

None
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Edge Edge

The Edge that was created

Source code in roc/graphdb.py
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
@classmethod
def connect(
    cls,
    src: NodeId | Self,
    dst: NodeId | Self,
    type: str | None = None,
    *,
    db: GraphDB | None = None,
) -> Edge:
    """Create an Edge from `src` to `dst`. This is a thin delegate to `Edge.connect`.

    Args:
        src (NodeId | Node): node (or node id) at the start of the edge
        dst (NodeId | Node): node (or node id) at the end of the edge
        type (str | None): the edge type
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Edge: the newly created edge
    """
    new_edge = Edge.connect(src, dst, type, db=db)
    return new_edge

create(n, *, db=None) classmethod

Creates the specified node in the GraphDB.

Calling save is preferred to using this method so that the caller doesn't need to know the state of the node.

Parameters:

Name Type Description Default
n Self

the node to be created

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Raises:

Type Description
NodeCreationFailed

if creating the node failed in the database

Returns:

Name Type Description
Self Self

the node that was passed in, albeit with a new id and potentially other new

Self

fields

Source code in roc/graphdb.py
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
@classmethod
def create(cls, n: Self, *, db: GraphDB | None = None) -> Self:
    """Creates the specified node in the GraphDB.

    Calling `save` is preferred to using this method so that the caller doesn't need to know the
    state of the node.

    The node is inserted with its labels and properties. Its temporary id is then replaced by
    the id the database assigned, both on the node itself and on any edges that reference it.
    If `n._no_save` is set, nothing is written and `n` is returned unchanged.

    Args:
        n (Self): the node to be created
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Raises:
        NodeCreationFailed: if creating the node failed in the database

    Returns:
        Self: the node that was passed in, albeit with a new `id` and potentially other new
        fields
    """
    if n._no_save:
        return n

    db = db or GraphDB.singleton()
    old_id = n.id

    label_str = Node.mklabels(n.labels)
    params = {"props": Node.to_dict(n)}

    res = list(db.raw_fetch(f"CREATE (n{label_str} $props) RETURN id(n) as id", params=params))

    if not res:
        # Report the node's (temporary) id. Previously this interpolated the builtin `id`
        # function, producing "<built-in function id>".
        raise NodeCreationFailed(f"Couldn't create node ID: {old_id}")

    new_id = res[0]["id"]
    n._id = new_id
    n._new = False

    # Re-key the cache entry from the temporary id to the database id. If the node was never
    # cached under `old_id`, `del` raises KeyError and the node is likewise not added under
    # `new_id`.
    try:
        cache = Node.get_cache()
        del cache[old_id]
        cache[new_id] = n
    except KeyError:
        pass

    # Re-point attached edges from the temporary id to the new database id.
    for e in n.src_edges:
        assert e.src_id == old_id
        e.src_id = new_id

    for e in n.dst_edges:
        assert e.dst_id == old_id
        e.dst_id = new_id

    return n

delete(n, *, db=None) staticmethod

Source code in roc/graphdb.py
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
@staticmethod
def delete(n: Node, *, db: GraphDB | None = None) -> None:
    """Delete a node together with its edges and its cache entry.

    Every edge attached to `n` is removed via `Edge.delete`, and the node is evicted from the
    NodeCache. If the node has already been persisted (`n._new` is False), it is also deleted
    from the database. Finally the node is flagged `_deleted` and `_no_save`, so later
    `create`/`update`/`save` calls on it do nothing.

    Args:
        n (Node): the node to delete
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
    """
    db = db or GraphDB.singleton()

    # Iterate over snapshots of the edge lists. If `Edge.delete` detaches the edge from this
    # node's lists (not visible here -- TODO confirm), walking the live list while it shrinks
    # would skip edges. A snapshot is harmless if it does not.
    # NOTE(review): `Edge.delete` is called without `db`, so edges are removed through its own
    # default database even when a non-singleton `db` is passed here -- confirm intended.
    for e in list(n.src_edges):
        Edge.delete(e)

    for e in list(n.dst_edges):
        Edge.delete(e)

    # remove from cache (named `cache` to avoid shadowing the module-level `node_cache`)
    cache = Node.get_cache()
    if n.id in cache:
        del cache[n.id]

    if not n._new:
        db.raw_execute(f"MATCH (n) WHERE id(n) = {n.id} DELETE n")

    n._deleted = True
    n._no_save = True

find(where, src_node_name='src', src_labels=set(), edge_name='e', edge_type='', params=dict(), db=None, load_edges=False, params_to_str=True) classmethod

Source code in roc/graphdb.py
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
@classmethod
def find(
    cls,
    where: str,
    src_node_name: str = "src",
    src_labels: set[str] = set(),
    edge_name: str = "e",
    edge_type: str = "",
    params: QueryParamType = dict(),
    db: GraphDB | None = None,
    load_edges: bool = False,
    params_to_str: bool = True,
) -> list[Self]:
    """Query the database for nodes matching a Cypher WHERE expression.

    Each matching node is returned along with the ids of the edges attached to it. A node
    already in the NodeCache is returned from the cache. Otherwise a new instance is built and
    cached; when called as `Node.find`, that instance uses the class registered for the node's
    exact label set.

    Args:
        where (str): Cypher boolean expression inserted verbatim into the WHERE clause. It is
            NOT escaped, so pass untrusted values through `params` and reference them as
            `$name`.
        src_node_name (str): Cypher variable bound to the matched node
        src_labels (set[str]): labels the matched node must carry (empty means any)
        edge_name (str): Cypher variable bound to an attached edge
        edge_type (str): edge type used in the edge pattern (empty means any)
        params (QueryParamType): query parameters; never modified by this method
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        load_edges (bool): if True, also build and cache Edge objects for attached edges that
            are not yet in the EdgeCache
        params_to_str (bool): if True, every parameter value is converted with `str()`
            before the query is sent

    Returns:
        list[Self]: the matching nodes
    """
    db = db or GraphDB.singleton()

    if load_edges:
        edge_fmt = edge_name
    else:
        edge_fmt = f"{{id: id({edge_name}), start: id(startNode({edge_name})), end: id(endNode({edge_name}))}}"

    src_label_str = f":{':'.join(src_labels)}" if src_labels else ""
    edge_type_str = f":{edge_type}" if edge_type else ""

    if params_to_str:
        # Build a new dict. Stringifying in place would mutate the caller's dict (and could
        # leak into the shared default argument).
        params = {k: str(v) for k, v in params.items()}

    res_iter = db.raw_fetch(
        f"""
            MATCH ({src_node_name}{src_label_str})-[{edge_name}{edge_type_str}*0..1]-() 
            WITH {src_node_name}, head({edge_name}) AS {edge_name}
            WHERE {where}
            RETURN {src_node_name} AS n, collect({edge_fmt}) AS edges
            """,
        params=params,
    )

    ret_list = list()
    for r in res_iter:
        logger.trace(f"find result: {r}")
        n = r["n"]
        if n is None:
            # NOTE: I can't think of any circumstances where there would be
            # multiple "None" results, so I think this is just an empty list
            continue

        if load_edges:
            # XXX: memgraph converts edges to Relationship objects if you
            # return the whole edge
            src_edges = list()
            dst_edges = list()
            e_cache = Edge.get_cache()
            for e in r["edges"]:
                # add edge_id to the right list for the node creation below
                if n.id == e.start_id:
                    src_edges.append(e.id)
                else:
                    dst_edges.append(e.id)

                # edge already loaded, continue to next one
                if e.id in e_cache:
                    continue

                # create a new edge
                props = {}
                if hasattr(e, "properties"):
                    props = e.properties
                new_edge = Edge(
                    src_id=e.start_id,
                    dst_id=e.end_id,
                    _id=e.id,
                    type=e.type,
                    **props,
                )
                e_cache[e.id] = new_edge
        else:
            # edges are just the IDs
            src_edges = [e["id"] for e in r["edges"] if e["start"] == n.id]
            dst_edges = [e["id"] for e in r["edges"] if e["end"] == n.id]

        n_cache = cls.get_cache()
        if n.id in n_cache:
            new_node = cast(Self, n_cache[n.id])
        else:
            mkcls = cls
            cls_lbls = frozenset(n.labels)
            # only the generic Node.find dispatches to a registered subclass
            if cls is Node and cls_lbls in node_label_registry:
                mkcls = cast(type[Self], node_label_registry[cls_lbls])
            new_node = mkcls(
                _id=n.id,
                _src_edges=EdgeList(src_edges),
                _dst_edges=EdgeList(dst_edges),
                labels=n.labels,
                **n.properties,
            )
            n_cache[n.id] = new_node
        ret_list.append(new_node)

    return ret_list

find_one(where, src_node_name='src', src_labels=set(), edge_name='e', edge_type='', params=dict(), db=None, load_edges=False, params_to_str=True, exactly_one=False) classmethod

Runs Node.find with the same arguments and narrows the results down to a single node. Raises an exception if more than one node matches (or, when exactly_one is set, if no node matches).

Parameters:

Name Type Description Default
nodes Sequence[NodeType]

The list of nodes returned by Node.find

required

Raises:

Type Description
Exception

Raised if there is more than 1 node in the list

Returns:

Type Description
Self | None

NodeType | None: Returns None if the list is empty, or the node in the list.

Source code in roc/graphdb.py
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
@classmethod
def find_one(
    cls,
    where: str,
    src_node_name: str = "src",
    src_labels: set[str] = set(),
    edge_name: str = "e",
    edge_type: str = "",
    params: QueryParamType = dict(),
    db: GraphDB | None = None,
    load_edges: bool = False,
    params_to_str: bool = True,
    exactly_one: bool = False,
) -> Self | None:
    """Finds a single Node.find results down to a single node. Raises an
    exception of the list contains more than one node.

    Args:
        nodes (Sequence[NodeType]): The list of nodes returned by Node.find

    Raises:
        Exception: Raised if there is more than 1 node in the list

    Returns:
        NodeType | None: Returns None if the list is empty, or the node in the list.
    """
    nodes = cls.find(
        where=where,
        src_node_name=src_node_name,
        src_labels=src_labels,
        edge_name=edge_name,
        edge_type=edge_type,
        params=params,
        db=db,
        load_edges=load_edges,
        params_to_str=params_to_str,
    )

    match len(nodes):
        case 0:
            if exactly_one:
                raise Exception("expect exactly one node in find_one")
            return None
        case 1:
            return nodes[0]
        case _:
            raise Exception("expected zero or one node in find_one")

get(id, *, db=None) classmethod

Returns a cached node with the specified id. If no node is cached, it is retrieved from the database.

Parameters:

Name Type Description Default
id NodeId

The unique identifier of the node to fetch

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Self Self

the cached or newly retrieved node

Source code in roc/graphdb.py
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
@classmethod
def get(cls, id: NodeId, *, db: GraphDB | None = None) -> Self:
    """Return the node with the given id, preferring the NodeCache.

    The database is only queried (via `load`) when the id is not cached.

    Args:
        id (NodeId): The unique identifier of the node to fetch
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Self: the cached or newly retrieved node
    """
    cached = Node.get_cache().get(id)
    if cached is not None:
        return cast(Self, cached)

    return cast(Self, cls.load(id, db=db))

get_cache() classmethod

Source code in roc/graphdb.py
1163
1164
1165
1166
1167
1168
1169
1170
@classmethod
def get_cache(cls) -> NodeCache:
    """Return the module-level NodeCache, creating it on first use.

    The cache capacity comes from `Config.get().node_cache_size`.
    """
    global node_cache
    if node_cache is not None:
        return node_cache

    node_cache = NodeCache(Config.get().node_cache_size)
    return node_cache

get_many(node_ids, *, batch_size=128, db=None, load_edges=False, return_nodes=False, progress_callback=None) classmethod

Source code in roc/graphdb.py
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
@classmethod
def get_many(
    cls,
    node_ids: Collection[NodeId],
    *,
    batch_size: int = 128,
    db: GraphDB | None = None,
    load_edges: bool = False,
    return_nodes: bool = False,
    progress_callback: ProgressFn | None = None,
) -> list[Node]:
    """Return the nodes for `node_ids`, serving what it can from the NodeCache.

    Nodes not already cached are loaded from the database in batches and added to the cache.

    Args:
        node_ids (Collection[NodeId]): ids of the nodes to fetch (duplicates are ignored)
        batch_size (int): maximum number of ids loaded per `load_many` query; must be >= 1
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        load_edges (bool): forwarded to `load_many`
        return_nodes (bool): currently unused; kept for API compatibility
        progress_callback (ProgressFn | None): called once with the requested nodes that were
            already cached, then once per loaded batch with that batch's nodes

    Raises:
        ValueError: if `batch_size` is less than 1
        GraphDBInternalError: if more nodes are requested than the cache can hold
        NodeNotFound: (from `load_many`) if a requested id does not exist in the database

    Returns:
        list[Node]: one node per distinct requested id, in no particular order
    """
    if batch_size < 1:
        # a zero batch size previously looped forever
        raise ValueError(f"get_many batch_size must be >= 1, got {batch_size}")

    db = db or GraphDB.singleton()

    wanted: set[NodeId] = node_ids if isinstance(node_ids, set) else set(node_ids)

    c = Node.get_cache()
    if len(wanted) > c.maxsize:
        raise GraphDBInternalError(
            f"get_many attempting to load more nodes than cache size ({len(wanted)} > {c.maxsize})"
        )

    cache_ids = set(c.keys())

    # Only the *requested* nodes that are already cached. Seeding this with every cached node
    # broke the final length check whenever the cache held unrelated nodes.
    ret_list: list[Node] = [c[nid] for nid in wanted & cache_ids]
    if progress_callback:
        progress_callback(ret_list)

    # Materialize once so each batch is a cheap slice rather than a re-walk of the set.
    fetch_ids = list(wanted - cache_ids)
    for start in range(0, len(fetch_ids), batch_size):
        id_set = set(fetch_ids[start : start + batch_size])

        res = cls.load_many(id_set, db=db, load_edges=load_edges)
        for n in res:
            c[n.id] = n

        if progress_callback:
            progress_callback(res)

        ret_list.extend(res)

    assert len(ret_list) == len(wanted)
    return ret_list

load(id, *, db=None) classmethod

Loads a node from the database. Use Node.get or other methods instead.

Parameters:

Name Type Description Default
id NodeId

The identifier of the node to fetch

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Raises:

Type Description
NodeNotFound

The node specified by the identifier does not exist in the database

GraphDBInternalError

If the requested ID returns multiple nodes

Returns:

Name Type Description
Self Self

The node from the database

Source code in roc/graphdb.py
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
@classmethod
def load(cls, id: NodeId, *, db: GraphDB | None = None) -> Self:
    """Query the database for a single node.

    Prefer `Node.get`, which only queries the database when the node is not cached.

    Args:
        id (NodeId): The identifier of the node to fetch
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Raises:
        NodeNotFound: The node specified by the identifier does not exist in the database
        GraphDBInternalError: If the requested ID returns multiple nodes

    Returns:
        Self: The node from the database
    """
    found = cls.load_many({id}, db=db)
    count = len(found)

    if count == 0:
        raise NodeNotFound(f"Couldn't find node ID: {id}")
    if count > 1:
        raise GraphDBInternalError(
            f"Too many nodes returned while trying to load single node: {id}"
        )

    return found[0]

load_many(node_set, db=None, load_edges=False) classmethod

Source code in roc/graphdb.py
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
@classmethod
def load_many(
    cls,
    node_set: set[NodeId],
    db: GraphDB | None = None,
    load_edges: bool = False,
) -> list[Self]:
    """Load every node in `node_set` with a single `find` query.

    Args:
        node_set (set[NodeId]): ids of the nodes to load
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton
        load_edges (bool): forwarded to `find`

    Raises:
        NodeNotFound: if the query returns a different number of nodes than requested

    Returns:
        list[Self]: the loaded nodes
    """
    graph = db or GraphDB.singleton()
    id_list = ",".join(str(node_id) for node_id in node_set)

    nodes = cls.find(
        where=f"id(src) IN [{id_list}]",  # TODO: use params?
        db=graph,
        load_edges=load_edges,
    )

    if len(nodes) == len(node_set):
        return nodes

    missing = node_set - {node.id for node in nodes}
    raise NodeNotFound("Couldn't find node IDs: " + ", ".join(str(m) for m in missing))

mklabels(labels) staticmethod

Converts a list of strings into proper Cypher syntax for a graph database query

Source code in roc/graphdb.py
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
@staticmethod
def mklabels(labels: set[str]) -> str:
    """Converts a list of strings into proper Cypher syntax for a graph database query"""
    labels_list = [i for i in labels]
    labels_list.sort()
    label_str = ":".join(labels_list)
    if len(label_str) > 0:
        label_str = ":" + label_str

    return label_str

save(n, *, db=None) classmethod

Save a node to persistent storage

Writes the specified node to the GraphDB for persistent storage. If the node does not already exist in storage, it is created via the create method. If the node does exist, it is updated via the update method.

If the _no_save flag is True on the node, the save request will be silently ignored.

Parameters:

Name Type Description Default
n Self

The Node to be saved

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Self Self

As a convenience, the node that was stored is returned. This may be useful

Self

since the id of the node may change if it was created in the database.

Source code in roc/graphdb.py
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
@classmethod
def save(cls, n: Self, *, db: GraphDB | None = None) -> Self:
    """Persist a node, choosing between creation and update.

    A node that has never been stored (`n._new`) goes through `create`; any other node goes
    through `update`. Both of those return early without writing when `n._no_save` is set.

    Args:
        n (Self): The Node to be saved
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Self: the stored node, returned for convenience because creating it in the database
        can change its id
    """
    persist = cls.create if n._new else cls.update
    return persist(n, db=db)

to_dict(n, include_labels=False) staticmethod

Convert a Node to a Python dictionary

Source code in roc/graphdb.py
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
@staticmethod
def to_dict(n: Node, include_labels: bool = False) -> dict[str, Any]:
    """Convert a Node to a Python dictionary"""
    # XXX: the excluded fields below shouldn't have been included in the
    # first place because Pythonic should exclude fields with underscores
    ret = n.model_dump(exclude={"_id", "_src_edges", "_dst_edges"})

    if include_labels and hasattr(n, "labels"):
        ret["labels"] = n.labels

    return ret

to_id(n) staticmethod

Source code in roc/graphdb.py
1392
1393
1394
1395
1396
1397
@staticmethod
def to_id(n: Node | NodeId) -> NodeId:
    """Normalize a Node or NodeId to a NodeId (a NodeId is returned unchanged)."""
    return n.id if isinstance(n, Node) else n

update(n, *, db=None) classmethod

Update an existing node in the GraphDB.

Calling save is preferred to using this method so that the caller doesn't need to know the state of the node.

Parameters:

Name Type Description Default
n Self

The node to be updated

required
db GraphDB | None

the graph database to use, or None to use the GraphDB singleton

None

Returns:

Name Type Description
Self Self

The node that was passed in, for convenience

Source code in roc/graphdb.py
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
@classmethod
def update(cls, n: Self, *, db: GraphDB | None = None) -> Self:
    """Update an existing node in the GraphDB.

    Calling `save` is preferred to using this method so that the caller doesn't need to know the
    state of the node.

    The node's stored properties are replaced wholesale (`SET n = $props`). Labels on the node
    that are missing from `n._orig_labels` are added, and labels in `n._orig_labels` that the
    node no longer has are removed. Afterwards `n._orig_labels` is reset to the labels just
    written, so the next update diffs against what is actually stored. If `n._no_save` is set,
    nothing is written.

    Args:
        n (Self): The node to be updated
        db (GraphDB | None): the graph database to use, or None to use the GraphDB singleton

    Returns:
        Self: The node that was passed in, for convenience
    """
    if n._no_save:
        return n

    db = db or GraphDB.singleton()

    orig_labels = n._orig_labels
    curr_labels = set(n.labels)
    new_labels = curr_labels - orig_labels
    rm_labels = orig_labels - curr_labels

    set_label_str = Node.mklabels(new_labels)
    set_query = f"SET n{set_label_str}, n = $props" if set_label_str else "SET n = $props"

    rm_label_str = Node.mklabels(rm_labels)
    rm_query = f"REMOVE n{rm_label_str}" if rm_label_str else ""

    params = {"props": Node.to_dict(n)}

    db.raw_execute(f"MATCH (n) WHERE id(n) = {n.id} {set_query} {rm_query}", params=params)

    # Record the labels now stored as the new baseline. Without this, a label added by one
    # save and removed before a later update is never REMOVEd, because it was never in the
    # original baseline.
    # NOTE(review): assumes `_orig_labels` is a writable private attribute like `_id`/`_new`.
    n._orig_labels = curr_labels

    return n

walk(n, *, mode='both', edge_filter=None, node_filter=None, node_callback=None, _walk_history=None) staticmethod

Source code in roc/graphdb.py
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
@staticmethod
def walk(
    n: Node,
    *,
    mode: WalkMode = "both",
    edge_filter: EdgeFilterFn | None = None,
    node_filter: NodeFilterFn | None = None,
    node_callback: NodeCallbackFn | None = None,
    _walk_history: set[int] | None = None,
) -> None:
    """Recursive depth-first traversal of the graph starting at `n`.

    `node_callback` runs for every reached node that passes `node_filter`. A node that fails
    the filter is not expanded further. An edge is followed only if `edge_filter` accepts it.
    Each node id is visited at most once per walk.

    Args:
        n (Node): the node to start from
        mode (WalkMode): "src" follows `n.src_edges` to each edge's `dst`; "dst" follows
            `n.dst_edges` to each edge's `src`; "both" does src first, then dst
        edge_filter (EdgeFilterFn | None): edge predicate; None accepts every edge
        node_filter (NodeFilterFn | None): node predicate; None accepts every node
        node_callback (NodeCallbackFn | None): called with each accepted node
        _walk_history (set[int] | None): internal; ids already visited during this walk
    """
    visited = _walk_history if _walk_history else set()
    if n.id in visited:
        return
    visited.add(n.id)

    keep_edge = edge_filter or cast(EdgeFilterFn, true_filter)
    keep_node = node_filter or true_filter
    visit = node_callback or no_callback

    if not keep_node(n):
        return
    visit(n)

    # (edge-list attribute on the node, attribute naming the far endpoint of each edge).
    # Attributes are resolved lazily so edge lists are read in the same order as before.
    steps = []
    if mode in ("src", "both"):
        steps.append(("src_edges", "dst"))
    if mode in ("dst", "both"):
        steps.append(("dst_edges", "src"))

    for edges_attr, far_end in steps:
        for e in getattr(n, edges_attr):
            if not keep_edge(e):
                continue
            Node.walk(
                getattr(e, far_end),
                mode=mode,
                edge_filter=keep_edge,
                node_filter=keep_node,
                node_callback=visit,
                _walk_history=visited,
            )

NodeCreationFailed

Bases: Exception

An exception raised when trying to create a Node in the graph database fails

Source code in roc/graphdb.py
774
775
class NodeCreationFailed(Exception):
    """Raised when the graph database fails to create a Node (e.g. the CREATE returns no id)."""

NodeDescription

Bases: ModelDescription

Source code in roc/graphdb.py
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
class NodeDescription(ModelDescription):
    """Schema description of a registered Node class, used to render Mermaid diagrams."""

    def __init__(self, node_cls: type[Node]) -> None:
        super().__init__(node_cls)

        # display name used in the rendered diagram
        self.name = node_cls.__name__

    def __str__(self) -> str:
        return f"NodeDesc({self.name})"

    def to_mermaid(self, indent: int = 4) -> str:
        """Render this node class as Mermaid `classDiagram` statements.

        Emits a `%%` comment header, one line per field and per method, and an `..|>` link to
        each parent class. Members are prefixed with `+` when `is_local` reports them as
        defined on this model and `^` otherwise (presumably inherited). A field's default is
        appended as ` = <default>` unless it is `PydanticUndefined`.

        Args:
            indent (int): number of spaces to prefix each emitted line with

        Returns:
            str: the Mermaid text, each line terminated by a newline
        """
        # `" " * indent` rather than f"{' ':>{indent}}": the format-spec form always emits at
        # least one space, so indent=0 still produced indented output.
        pad = " " * indent
        out = [f"\n{pad}%% Node: {self.name}\n"]

        # add fields
        for field in self.fields:
            sym = "+" if is_local(self.model, field.name) else "^"
            default_val = (
                f" = {field.default_val_str}" if field.default_val is not PydanticUndefined else ""
            )
            out.append(f"{pad}{self.name}: {sym}{field.type} {field.name}{default_val}\n")

        # add methods
        for method in self.methods:
            sym = "+" if is_local(self.model, method.name) else "^"
            params = ", ".join(method.uml_params)
            out.append(f"{pad}{self.name}: {sym}{method.name}({params}) {method.return_type}\n")

        # add links to inherited nodes
        for parent in self.parent_class_names:
            out.append(f"{pad}{self.name} ..|> {parent}: inherits\n")

        return "".join(out)

name = node_cls.__name__ instance-attribute

__init__(node_cls)

Source code in roc/graphdb.py
1754
1755
1756
1757
def __init__(self, node_cls: type[Node]) -> None:
    """Describe `node_cls`, remembering its class name for display."""
    super().__init__(node_cls)
    cls_name = node_cls.__name__
    self.name = cls_name

__str__()

Source code in roc/graphdb.py
1759
1760
def __str__(self) -> str:
    """Return a short label such as "NodeDesc(Foo)"."""
    return "NodeDesc({})".format(self.name)

to_mermaid(indent=4)

Source code in roc/graphdb.py
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
def to_mermaid(self, indent: int = 4) -> str:
    """Render this node class as Mermaid `classDiagram` statements.

    Emits a `%%` comment header, one line per field and per method, and an `..|>` link to each
    parent class. Members are prefixed with `+` when `is_local` reports them as defined on this
    model and `^` otherwise (presumably inherited). A field's default is appended as
    ` = <default>` unless it is `PydanticUndefined`.

    Args:
        indent (int): number of spaces to prefix each emitted line with

    Returns:
        str: the Mermaid text, each line terminated by a newline
    """
    # `" " * indent` rather than f"{' ':>{indent}}": the format-spec form always emits at least
    # one space, so indent=0 still produced indented output.
    pad = " " * indent
    out = [f"\n{pad}%% Node: {self.name}\n"]

    # add fields
    for field in self.fields:
        sym = "+" if is_local(self.model, field.name) else "^"
        default_val = (
            f" = {field.default_val_str}" if field.default_val is not PydanticUndefined else ""
        )
        out.append(f"{pad}{self.name}: {sym}{field.type} {field.name}{default_val}\n")

    # add methods
    for method in self.methods:
        sym = "+" if is_local(self.model, method.name) else "^"
        params = ", ".join(method.uml_params)
        out.append(f"{pad}{self.name}: {sym}{method.name}({params}) {method.return_type}\n")

    # add links to inherited nodes
    for parent in self.parent_class_names:
        out.append(f"{pad}{self.name} ..|> {parent}: inherits\n")

    return "".join(out)

NodeFetchIterator

The implementation of an iterator for a NodeList. Only intended to be used internally by NodeList.

Source code in roc/graphdb.py
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
class NodeFetchIterator:
    """Iterator that turns a list of NodeIds into Node objects one at a time.

    Each id is resolved with `Node.get` only when iteration reaches it. Intended for
    internal use by NodeList.
    """

    def __init__(self, node_list: list[NodeId]):
        self._node_list = node_list
        # index of the next id to resolve
        self.cur = 0

    def __iter__(self) -> NodeFetchIterator:
        return self

    def __next__(self) -> Node:
        position = self.cur
        if position < len(self._node_list):
            self.cur = position + 1
            return Node.get(self._node_list[position])
        raise StopIteration

cur = 0 instance-attribute

__init__(node_list)

Source code in roc/graphdb.py
1462
1463
1464
def __init__(self, node_list: list[NodeId]):
    """Remember the ids to iterate over and start at the first one."""
    self.cur = 0
    self._node_list = node_list

__iter__()

Source code in roc/graphdb.py
1466
1467
def __iter__(self) -> NodeFetchIterator:
    """An iterator is its own iterable."""
    iterator = self
    return iterator

__next__()

Source code in roc/graphdb.py
1469
1470
1471
1472
1473
1474
1475
def __next__(self) -> Node:
    """Resolve the next id via `Node.get`; raise StopIteration once the list is exhausted."""
    position = self.cur
    if position < len(self._node_list):
        self.cur = position + 1
        return Node.get(self._node_list[position])
    raise StopIteration

NodeList

Bases: MutableSet[Node | NodeId], Mapping[int, Node]

A list of Nodes. Implements interfaces for both a MutableSet (i.e. set()) and a Mapping (i.e. read-only dict())

Source code in roc/graphdb.py
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
class NodeList(MutableSet[Node | NodeId], Mapping[int, Node]):
    """A list of Nodes. Implements interfaces for both a MutableSet (i.e. set())
    and a Mapping (i.e. read-only dict())

    Only NodeIds are stored; Node objects are resolved through `Node.get` when accessed.
    """

    def __init__(self, ids: Iterable[NodeId]):
        self._nodes: list[NodeId] = list(ids)

    def __iter__(self) -> NodeFetchIterator:
        return NodeFetchIterator(self._nodes)

    def __getitem__(self, key: int) -> Node:
        # `key` is a position in the list, not a NodeId
        return Node.get(self._nodes[key])

    def __len__(self) -> int:
        return len(self._nodes)

    def __contains__(self, n: Any) -> bool:
        if isinstance(n, Node) or isinstance(n, int):
            n_id = Node.to_id(n)  # type: ignore
        else:
            return False

        return n_id in self._nodes

    def __add__(self, l2: NodeList) -> NodeList:
        return NodeList(self._nodes + l2._nodes)

    def add(self, n: Node | NodeId) -> None:
        """Adds a new Node to the list (no-op if its id is already present)"""
        n_id = Node.to_id(n)

        if n_id in self._nodes:
            return

        self._nodes.append(n_id)

    def discard(self, n: Node | NodeId) -> None:
        """Removes a Node from the list; does nothing if it is not present.

        `MutableSet.discard` must not raise for a missing element (the inherited `remove`
        raises KeyError for that), so the ValueError from `list.remove` is avoided.
        """
        n_id = Node.to_id(n)

        if n_id in self._nodes:
            self._nodes.remove(n_id)

    def select(
        self,
        *,
        filter_fn: NodeFilterFn | None = None,
        labels: set[str] | str | None = None,
    ) -> NodeList:
        """Return a new NodeList containing the nodes that pass every given criterion.

        Args:
            filter_fn (NodeFilterFn | None): keep only nodes for which this returns True
            labels (set[str] | str | None): keep only nodes whose label set equals this
                exactly; a single string is treated as a one-label set

        Returns:
            NodeList: the selected node ids, in their original order
        """
        node_ids = self._nodes
        if filter_fn is not None:
            # bulk-load (and cache) the nodes so the per-node Node.get calls below are cheap
            Node.get_many(node_ids)
            node_ids = [n for n in node_ids if filter_fn(Node.get(n))]

        if labels is not None:
            # A bare string is one label; `set("Foo")` would split it into characters.
            wanted = {labels} if isinstance(labels, str) else set(labels)
            node_ids = [n for n in node_ids if set(Node.get(n).labels) == wanted]

        return NodeList(node_ids)

__add__(l2)

Source code in roc/graphdb.py
1503
1504
def __add__(self, l2: NodeList) -> NodeList:
    """Concatenate two NodeLists into a new one; duplicate ids are kept."""
    combined = self._nodes + l2._nodes
    return NodeList(combined)

__contains__(n)

Source code in roc/graphdb.py
1495
1496
1497
1498
1499
1500
1501
def __contains__(self, n: Any) -> bool:
    """Membership test accepting either a Node or a NodeId; anything else is never present."""
    if not isinstance(n, (Node, int)):
        return False
    return Node.to_id(n) in self._nodes  # type: ignore

__getitem__(key)

Source code in roc/graphdb.py
1489
1490
def __getitem__(self, key: int) -> Node:
    """Return the Node stored at list position `key` (not a NodeId lookup)."""
    node_id = self._nodes[key]
    return Node.get(node_id)

__init__(ids)

Source code in roc/graphdb.py
1483
1484
def __init__(self, ids: Iterable[NodeId]):
    """Store a private copy of the given ids, preserving order and duplicates."""
    self._nodes: list[NodeId] = [node_id for node_id in ids]

__iter__()

Source code in roc/graphdb.py
1486
1487
def __iter__(self) -> NodeFetchIterator:
    """Iterate over the Nodes, resolving each stored id lazily."""
    iterator = NodeFetchIterator(self._nodes)
    return iterator

__len__()

Source code in roc/graphdb.py
1492
1493
def __len__(self) -> int:
    """Number of stored ids."""
    count = len(self._nodes)
    return count

add(n)

Adds a new Node to the list

Source code in roc/graphdb.py
1506
1507
1508
1509
1510
1511
1512
1513
def add(self, n: Node | NodeId) -> None:
    """Append a Node (or NodeId) to the list unless its id is already present."""
    n_id = Node.to_id(n)
    if n_id not in self._nodes:
        self._nodes.append(n_id)

discard(n)

Removes a Node from the list

Source code in roc/graphdb.py
1515
1516
1517
1518
1519
def discard(self, n: Node | NodeId) -> None:
    """Removes a Node from the list; does nothing if it is not present.

    `MutableSet.discard` must not raise for a missing element (the inherited `remove` raises
    KeyError for that), so the ValueError from `list.remove` is avoided.
    """
    n_id = Node.to_id(n)

    if n_id in self._nodes:
        self._nodes.remove(n_id)

select(*, filter_fn=None, labels=None)

Source code in roc/graphdb.py
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
def select(
    self,
    *,
    filter_fn: NodeFilterFn | None = None,
    labels: set[str] | str | None = None,
) -> NodeList:
    """Return a new NodeList containing the nodes that pass every given criterion.

    Args:
        filter_fn (NodeFilterFn | None): keep only nodes for which this returns True
        labels (set[str] | str | None): keep only nodes whose label set equals this exactly;
            a single string is treated as a one-label set

    Returns:
        NodeList: the selected node ids, in their original order
    """
    node_ids = self._nodes
    if filter_fn is not None:
        # bulk-load (and cache) the nodes so the per-node Node.get calls below are cheap
        Node.get_many(node_ids)
        node_ids = [n for n in node_ids if filter_fn(Node.get(n))]

    if labels is not None:
        # A bare string is one label; `set("Foo")` would split it into characters.
        wanted = {labels} if isinstance(labels, str) else set(labels)
        node_ids = [n for n in node_ids if set(Node.get(n).labels) == wanted]

    return NodeList(node_ids)

NodeNotFound

Bases: Exception

An exception raised when trying to retrieve a Node that doesn't exist.

Source code in roc/graphdb.py
770
771
class NodeNotFound(Exception):
    """Raised when a requested Node does not exist (e.g. its id is not in the database)."""

Schema

Source code in roc/graphdb.py
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
class Schema:
    """A snapshot of every registered Edge and Node class.

    The name sets and description lists are copied out of ``edge_registry`` /
    ``node_registry`` when the instance is built, so classes registered later
    are not reflected in an existing Schema.
    """

    def __init__(self, skip_validation: bool = False) -> None:
        """Collect descriptions of all registered edges and nodes.

        Args:
            skip_validation: when False (the default), ``validate()`` runs
                first and raises SchemaValidationError on an inconsistent
                schema.
        """
        if not skip_validation:
            self.validate()

        # edges, sorted by name
        self.edge_names = set(edge_registry.keys())
        self.edges = [EdgeDescription(edge_cls) for edge_cls in edge_registry.values()]
        self.edges.sort(key=lambda e: e.name)

        # nodes, sorted by name
        self.node_names = set(node_registry.keys())
        self.nodes = [NodeDescription(node_cls) for node_cls in node_registry.values()]
        self.nodes.sort(key=lambda n: n.name)

    @classmethod
    def validate(cls) -> None:
        """Check that every Node named in an Edge's ``allowed_connections`` is registered.

        Edges whose ``allowed_connections`` default is None are not checked.
        Every problem is collected before raising.

        Raises:
            SchemaValidationError: if any src or dst Node name referenced by
                an Edge is missing from ``node_registry``.
        """
        errors: list[str] = []
        for edge_name, edge_cls in edge_registry.items():
            allowed_connections = pydantic_get_default(edge_cls, "allowed_connections")

            if allowed_connections is None:
                continue

            for src, dst in allowed_connections:
                if src not in node_registry:
                    errors.append(
                        f"Edge '{edge_name}' requires src Node '{src}', which is not registered"
                    )

                if dst not in node_registry:
                    errors.append(
                        f"Edge '{edge_name}' requires dst Node '{dst}', which is not registered"
                    )

        if len(errors) > 0:
            raise SchemaValidationError(errors)

    def to_mermaid(self) -> str:
        """Render the schema as a Mermaid ``classDiagram``: all nodes, then all edges."""
        # collect the pieces and join once instead of repeated string +=
        parts = ["classDiagram\n"]
        parts.extend(n.to_mermaid() for n in self.nodes)
        parts.extend(e.to_mermaid() for e in self.edges)
        return "".join(parts)

    @classmethod
    def _repr_markdown_(cls) -> str:
        """Return a fenced mermaid diagram of a freshly built Schema.

        Presumably used by IPython/Jupyter rich display. Note that building
        the Schema runs ``validate()``, which may raise.
        """
        return f"``` mermaid\n{Schema().to_mermaid()}\n```\n"

edge_names = set(edge_registry.keys()) instance-attribute

edges = [EdgeDescription(edge_cls) for edge_cls in edge_registry.values()] instance-attribute

node_names = set(node_registry.keys()) instance-attribute

nodes = [NodeDescription(node_cls) for node_cls in node_registry.values()] instance-attribute

__init__(skip_validation=False)

Source code in roc/graphdb.py
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
def __init__(self, skip_validation: bool = False) -> None:
    """Snapshot the registered edge and node classes, each sorted by name.

    Args:
        skip_validation: unless set, ``validate()`` runs first and may raise
            SchemaValidationError before anything is collected.
    """
    if not skip_validation:
        self.validate()

    # edge names plus name-ordered edge descriptions
    self.edge_names = set(edge_registry.keys())
    self.edges = sorted(
        [EdgeDescription(registered) for registered in edge_registry.values()],
        key=lambda desc: desc.name,
    )

    # node names plus name-ordered node descriptions
    self.node_names = set(node_registry.keys())
    self.nodes = sorted(
        [NodeDescription(registered) for registered in node_registry.values()],
        key=lambda desc: desc.name,
    )

to_mermaid()

Source code in roc/graphdb.py
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
def to_mermaid(self) -> str:
    """Render the schema as a Mermaid ``classDiagram``: all nodes, then all edges."""
    # collect the pieces and join once instead of repeated string +=
    parts = ["classDiagram\n"]
    parts.extend(n.to_mermaid() for n in self.nodes)
    parts.extend(e.to_mermaid() for e in self.edges)
    return "".join(parts)

validate() classmethod

Source code in roc/graphdb.py
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
@classmethod
def validate(cls) -> None:
    """Confirm that each Node named in an Edge's ``allowed_connections`` is registered.

    Edges whose ``allowed_connections`` default is None are not checked. All
    problems are gathered first and then reported together.

    Raises:
        SchemaValidationError: if any referenced src or dst Node name is
            missing from ``node_registry``.
    """
    problems: list[str] = []

    for edge_name, registered_edge in edge_registry.items():
        connections = pydantic_get_default(registered_edge, "allowed_connections")
        if connections is None:
            continue

        for src, dst in connections:
            if src not in node_registry:
                problems.append(
                    f"Edge '{edge_name}' requires src Node '{src}', which is not registered"
                )
            if dst not in node_registry:
                problems.append(
                    f"Edge '{edge_name}' requires dst Node '{dst}', which is not registered"
                )

    if problems:
        raise SchemaValidationError(problems)

SchemaValidationError

Bases: Exception

Source code in roc/graphdb.py
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
class SchemaValidationError(Exception):
    """Raised when the registered Edge/Node classes do not form a consistent schema.

    The exception message lists every problem on its own numbered,
    tab-indented line; the raw messages are kept on ``errors``.
    """

    def __init__(self, errors: list[str]) -> None:
        """Store the individual messages and build the combined exception text.

        Args:
            errors: one human-readable message per validation problem.
        """
        self.errors = errors
        err_str = "".join(f"\t{errno}: {err}\n" for errno, err in enumerate(errors))
        super().__init__(f"Error validating schema:\n{err_str}")

errors = errors instance-attribute

__init__(errors)

Source code in roc/graphdb.py
1555
1556
1557
1558
1559
1560
1561
1562
1563
def __init__(self, errors: list[str]) -> None:
    """Store the individual messages and build the combined exception text.

    Each message becomes its own numbered, tab-indented line in the final
    exception message.

    Args:
        errors: one human-readable message per validation problem.
    """
    self.errors = errors
    err_str = "".join(f"\t{errno}: {err}\n" for errno, err in enumerate(errors))
    super().__init__(f"Error validating schema:\n{err_str}")

StrictSchemaWarning

Bases: Warning

A warning issued when strict schema mode is enabled and a schema violation is found.

Source code in roc/graphdb.py
69
70
class StrictSchemaWarning(Warning):
    """Warning emitted for a schema violation while strict schema mode is on,
    when the database is configured to warn rather than raise.
    """

check_schema(edge_cls, clstype, src, dst, db)

Source code in roc/graphdb.py
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
def check_schema(
    edge_cls: type[Edge],
    clstype: str,
    src: Node,
    dst: Node,
    db: GraphDB,
) -> None:
    """Check that an edge of type ``edge_cls`` may connect ``src`` to ``dst``.

    A connection is allowed when some entry of the edge's
    ``allowed_connections`` has its first element equal to the class name of
    ``src`` (or of one of its Node parent classes) and its second element
    equal to the class name of ``dst`` (or of one of its Node parents).

    Args:
        edge_cls: the Edge class being connected.
        clstype: the edge type name; only used in the error message.
        src: the source Node.
        dst: the destination Node.
        db: the database whose ``strict_schema`` / ``strict_schema_warns``
            flags decide what happens when ``allowed_connections`` is unset.

    Raises:
        Exception: if the connection is not in ``allowed_connections``, or if
            ``allowed_connections`` is unset while ``db.strict_schema`` is set
            and ``db.strict_schema_warns`` is not.

    Warns:
        StrictSchemaWarning: if ``allowed_connections`` is unset while both
            ``db.strict_schema`` and ``db.strict_schema_warns`` are set.
    """
    allowed_connections = pydantic_get_default(edge_cls, "allowed_connections")

    # Each side matches on its own class name or any Node parent's name.
    src_name = src.__class__.__name__
    src_names = get_node_parent_names(src.__class__)
    src_names.add(src_name)

    dst_name = dst.__class__.__name__
    dst_names = get_node_parent_names(dst.__class__)
    dst_names.add(dst_name)

    # check if the src (or its parents) are allowed to connect to dst (or its parents)
    if allowed_connections is not None:
        allowed = any(
            conn[0] in src_names and conn[1] in dst_names for conn in allowed_connections
        )

        if not allowed:
            # NOTE(review): a dedicated Exception subclass would let callers
            # catch this precisely; kept as Exception for compatibility.
            raise Exception(
                f"attempting to connect edge '{clstype}' from '{src_name}' to '{dst_name}' not in allowed connections list"
            )
    # no allowed_connections set, which is a no-no for strict mode
    elif db.strict_schema:
        err_msg = f"allowed_connections missing in '{edge_cls.__name__}' and strict_schema is set"
        if db.strict_schema_warns:
            warnings.warn(err_msg, StrictSchemaWarning)
        else:
            raise Exception(err_msg)

clean_annotation(annotation)

Source code in roc/graphdb.py
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
def clean_annotation(annotation: Any) -> str:
    import typing

    if isinstance(annotation, str):
        return annotation
    elif annotation is None:
        return "None"
    elif isinstance(annotation, typing._GenericAlias):  # type: ignore
        # Handle generics like List, Dict, etc.
        origin = annotation.__origin__
        args = [clean_annotation(arg) for arg in annotation.__args__]
        return f"{origin.__name__}[{', '.join(args)}]"
    elif isinstance(annotation, _SpecialForm):
        # Handle special forms like Any, Union, etc.
        return annotation._name  # type: ignore
    else:
        return annotation.__name__  # type: ignore

get_methods(c) cached

Source code in roc/graphdb.py
1673
1674
1675
@functools.cache
def get_methods(c: type[object]) -> set[str]:
    """Return the names of the plain-function attributes of class ``c``.

    ``inspect.getmembers`` also reports inherited members, so functions
    defined on base classes are included. Results are memoized per class.
    """
    function_members = inspect.getmembers(c, inspect.isfunction)
    return {member_name for member_name, _ in function_members}

get_next_new_edge_id()

Source code in roc/graphdb.py
296
297
298
299
300
301
def get_next_new_edge_id() -> EdgeId:
    """Return the next temporary id for an Edge that has no database id yet.

    The ids start from the module-level ``next_new_edge`` (initially -1) and
    count downward, so each call yields a distinct negative id.
    """
    global next_new_edge
    new_id = next_new_edge  # avoid shadowing the builtin ``id``
    next_new_edge = cast(EdgeId, next_new_edge - 1)

    return new_id

get_next_new_node_id()

Source code in roc/graphdb.py
778
779
780
781
782
def get_next_new_node_id() -> NodeId:
    """Return the next temporary id for a Node that has no database id yet.

    The ids start from the module-level ``next_new_node`` (initially -1) and
    count downward, so each call yields a distinct negative id.
    """
    global next_new_node
    new_id = next_new_node  # avoid shadowing the builtin ``id``
    next_new_node = cast(NodeId, next_new_node - 1)

    return new_id

get_node_parent_names(model)

Source code in roc/graphdb.py
1644
1645
1646
1647
1648
1649
1650
1651
def get_node_parent_names(model: type[BaseModel]) -> set[str]:
    """Return the names of the Node classes in ``model``'s MRO.

    Only classes that themselves have ``Node`` in their MRO are considered.
    ``model``'s own name and the base name ``"Node"`` are excluded, which
    leaves the intermediate Node subclasses that ``model`` inherits from.
    A fresh set is returned, so callers may mutate it.
    """
    ret = {c.__name__ for c in model.__mro__ if Node in c.__mro__}
    # set.discard is a no-op for absent names, so no membership check needed
    ret.discard(model.__name__)
    ret.discard("Node")

    return ret

is_local(c, attr)

Source code in roc/graphdb.py
1634
1635
1636
1637
1638
1639
1640
1641
def is_local(c: type[object], attr: str) -> bool:
    """Report whether ``attr`` is defined directly on ``c`` rather than inherited.

    If ``c`` is not the direct owner but exposes a ``__wrapped__`` attribute,
    the check is repeated on the wrapped object, following the chain.
    """
    return attr in c.__dict__ or (
        hasattr(c, "__wrapped__") and is_local(c.__wrapped__, attr)
    )

no_callback(_)

Helper function that accepts any value and returns None. Great for default callback functions.

Source code in roc/graphdb.py
55
56
57
58
def no_callback(_: Any) -> None:
    """Take one argument of any type, ignore it, and return None.

    Useful as the default value of an optional callback parameter.
    """
    return None

pydantic_get_default(m, f)

Source code in roc/graphdb.py
1630
1631
def pydantic_get_default(m: type[BaseModel], f: str) -> Any:
    """Return the default value declared for field ``f`` on model class ``m``.

    ``call_default_factory=True`` asks pydantic to invoke a field's
    ``default_factory`` (if it has one) and return the produced value.
    Raises ``KeyError`` if ``f`` is not a field of ``m``.
    """
    field_info = m.model_fields[f]
    return field_info.get_default(call_default_factory=True)

pydantic_get_field(m, f)

Source code in roc/graphdb.py
1626
1627
def pydantic_get_field(m: type[BaseModel], f: str) -> FieldInfo:
    """Return the FieldInfo for field ``f`` of model class ``m``.

    Raises ``KeyError`` if ``f`` is not a field of ``m``.
    """
    fields = m.model_fields
    return fields[f]

pydantic_get_fields(m)

Source code in roc/graphdb.py
1622
1623
def pydantic_get_fields(m: type[BaseModel]) -> set[str]:
    """Return the set of field names declared on model class ``m``."""
    return {field_name for field_name in m.model_fields}

true_filter(_)

Helper function that accepts any value and returns True. Great for default filters.

Source code in roc/graphdb.py
48
49
50
51
52
def true_filter(_: Any) -> bool:
    """Take one argument of any type, ignore it, and always return True.

    Useful as the default value of an optional filter parameter: it keeps
    every item.
    """
    keep_everything = True
    return keep_everything