Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions cq/middlewares/exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Concatenate, Self

from cq import MiddlewareResult

__all__ = ("CaptureExceptionMiddleware",)


class CaptureExceptionMiddleware[**P, Exc: BaseException]:
__slots__ = ("__exceptions", "__on_error", "__reraise")

__exceptions: tuple[type[Exc], ...]
__on_error: Callable[Concatenate[Exc, P], Awaitable[Any]]
__reraise: bool

def __init__(
self,
on_error: Callable[Concatenate[Exc, P], Awaitable[Any]],
/,
exceptions: Sequence[type[Exc]] | None = None,
reraise: bool = False,
) -> None:
self.__exceptions = (Exception,) if exceptions is None else tuple(exceptions) # type: ignore[assignment]
self.__on_error = on_error
self.__reraise = reraise

async def __call__(
self,
/,
*args: P.args,
**kwargs: P.kwargs,
) -> MiddlewareResult[Any]:
try:
yield
except self.__exceptions as exc:
await self.__on_error(exc, *args, **kwargs)
if self.__reraise:
raise

@classmethod
def sync(
cls,
on_error: Callable[Concatenate[Exc, P], Any],
/,
exceptions: Sequence[type[Exc]] | None = None,
reraise: bool = False,
) -> Self:
async def async_on_error(
exception: Exc,
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Any:
return on_error(exception, *args, **kwargs)

return cls(async_on_error, exceptions, reraise)
23 changes: 23 additions & 0 deletions docs/guides/configuring.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,26 @@ The parameters are:
* `exceptions`: the exception types that trigger a retry. Defaults to `(Exception,)`, which retries on any non-`BaseException` failure.

If every attempt fails, the last exception is re-raised.

### `CaptureExceptionMiddleware`

`cq.middlewares.exc.CaptureExceptionMiddleware` catches exceptions raised by downstream handlers and forwards them to a callback. Use it to log, report, or push errors to an external sink without changing how they propagate:

```python
from cq import new_command_bus
from cq.middlewares.exc import CaptureExceptionMiddleware

async def report(exception, message):
sentry_sdk.capture_exception(exception)

bus = new_command_bus()
bus.add_middlewares(CaptureExceptionMiddleware(report, reraise=True))
```

The parameters are:

* `on_error`: an async callback invoked with the captured exception followed by the same arguments the handler received (typically the message). Use `CaptureExceptionMiddleware.sync(...)` if your callback is synchronous.
* `exceptions`: the exception types to capture. Defaults to `(Exception,)`.
* `reraise`: whether to re-raise the exception after the callback returns. Defaults to `False`, in which case the exception is swallowed.

`on_error` is meant for side effects only (logging, metrics, notifications) and must not raise. If it does, its own exception will propagate in place of the original one.
60 changes: 60 additions & 0 deletions tests/middlewares/test_exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Any, Self

import anyio
import pytest

from cq import Bus
from cq.middlewares.exc import CaptureExceptionMiddleware


class TestCaptureExceptionMiddleware:
async def test_capture_exception_middleware_with_success(
self,
bus: Bus[Any, Any],
) -> None:
class Handler:
async def handle(self, message: str) -> str:
raise Exception

@classmethod
async def async_factory(cls) -> Self:
return cls()

captured = anyio.Event()

def capture(exc: Exception, message: str) -> None:
captured.set()

bus.add_middlewares(CaptureExceptionMiddleware.sync(capture))
bus.subscribe(str, Handler.async_factory)

assert not captured.is_set()
await bus.dispatch("Hello world!")
assert captured.is_set()

async def test_capture_exception_middleware_with_reraise(
self,
bus: Bus[Any, Any],
) -> None:
class Handler:
async def handle(self, message: str) -> str:
raise Exception

@classmethod
async def async_factory(cls) -> Self:
return cls()

captured = anyio.Event()

def capture(exc: Exception, message: str) -> None:
captured.set()

bus.add_middlewares(CaptureExceptionMiddleware.sync(capture, reraise=True))
bus.subscribe(str, Handler.async_factory)

assert not captured.is_set()

with pytest.raises(Exception):
await bus.dispatch("Hello world!")

assert captured.is_set()