Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions cq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._core.related_events import AnyIORelatedEvents, RelatedEvents

__all__ = (
"__cq__",
"AnyCommandBus",
"AnyIORelatedEvents",
"Bus",
Expand Down Expand Up @@ -55,30 +56,27 @@
)

try:
from cq.ext.injection import InjectionAdapter as _InjectionAdapter
from .ext.injection import InjectionAdapter as _InjectionAdapter

except ImportError: # pragma: no cover
_default = CQ(_NoDI())
__cq__ = CQ(_NoDI())

else:
_default = CQ(_InjectionAdapter())
__cq__ = CQ(_InjectionAdapter())

_default.register_defaults()
__cq__.register_defaults()

command_handler = _default.command_handler
event_handler = _default.event_handler
query_handler = _default.query_handler
command_handler = __cq__.command_handler
event_handler = __cq__.event_handler
query_handler = __cq__.query_handler

new_command_bus = _default.new_command_bus
new_event_bus = _default.new_event_bus
new_query_bus = _default.new_query_bus
new_command_bus = __cq__.new_command_bus
new_event_bus = __cq__.new_event_bus
new_query_bus = __cq__.new_query_bus


class ContextCommandPipeline[C: Command](_ContextCommandPipeline[C]):
__slots__ = ()

def __init__(self, di: DIAdapter = _default.di) -> None:
def __init__(self, di: DIAdapter = __cq__.di) -> None:
super().__init__(di)


del _default
13 changes: 13 additions & 0 deletions cq/_core/cq.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import KeysView
from typing import Any, Self

from cq._core.di import DIAdapter
Expand Down Expand Up @@ -34,14 +35,26 @@ def di(self) -> DIAdapter:
def command_handler(self) -> HandlerDecorator[Command, Any]:
return HandlerDecorator(self.__command_registry, self.__di)

@property
def command_types(self) -> KeysView[type[Command]]:
return self.__command_registry.message_types

@property
def event_handler(self) -> HandlerDecorator[Event, Any]:
return HandlerDecorator(self.__event_registry, self.__di)

@property
def event_types(self) -> KeysView[type[Event]]:
return self.__event_registry.message_types

@property
def query_handler(self) -> HandlerDecorator[Query, Any]:
return HandlerDecorator(self.__query_registry, self.__di)

@property
def query_types(self) -> KeysView[type[Query]]:
return self.__query_registry.message_types

def new_command_bus(self) -> Bus[Command, Any]:
bus = SimpleBus(self.__command_registry)
command_middleware = CommandDispatchScopeMiddleware(self.__di)
Expand Down
15 changes: 14 additions & 1 deletion cq/_core/handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable, Iterator
from collections.abc import Awaitable, Callable, Iterator, KeysView
from dataclasses import dataclass, field
from functools import partial
from inspect import Parameter, isclass, unwrap
Expand Down Expand Up @@ -49,6 +49,11 @@ def create(
class HandlerRegistry[I, O](Protocol):
__slots__ = ()

@property
@abstractmethod
def message_types(self) -> KeysView[type[I]]:
raise NotImplementedError

@abstractmethod
def handlers_from(self, message_type: type[I]) -> Iterator[HandleFunction[[I], O]]:
raise NotImplementedError
Expand All @@ -71,6 +76,10 @@ class MultipleHandlerRegistry[I, O](HandlerRegistry[I, O]):
init=False,
)

@property
def message_types(self) -> KeysView[type[I]]:
return self.__values.keys()

def handlers_from(self, message_type: type[I]) -> Iterator[HandleFunction[[I], O]]:
for key_type in _iter_key_types(message_type):
yield from self.__values.get(key_type, ())
Expand All @@ -97,6 +106,10 @@ class SingleHandlerRegistry[I, O](HandlerRegistry[I, O]):
init=False,
)

@property
def message_types(self) -> KeysView[type[I]]:
return self.__values.keys()

def handlers_from(self, message_type: type[I]) -> Iterator[HandleFunction[[I], O]]:
for key_type in _iter_key_types(message_type):
function = self.__values.get(key_type, None)
Expand Down