diff --git a/pyproject.toml b/pyproject.toml index d9a6672..32b834c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,7 @@ disable = [ "too-many-positional-arguments", "too-many-public-methods", "too-many-return-statements", + "too-many-instance-attributes", ] diff --git a/sequence_layers/jax/dense.py b/sequence_layers/jax/dense.py index 5fbadb0..7abae4c 100644 --- a/sequence_layers/jax/dense.py +++ b/sequence_layers/jax/dense.py @@ -15,15 +15,16 @@ import dataclasses import typing -from typing import Callable +from typing import Callable, override import flax.linen as nn import jax import jax.numpy as jnp + from sequence_layers.jax import meta from sequence_layers.jax import types from sequence_layers.jax import utils - +from sequence_layers.specs import dense as spec __all__ = ( # go/keep-sorted start @@ -34,11 +35,11 @@ ) -class Dense(types.Stateless, utils.EinsumCommon): +class Dense(types.Stateless, utils.EinsumCommon, spec.Dense): """A basic dense layer.""" @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.Dense.Config): """Dense config.""" # The number of output features for the dense layer. @@ -73,6 +74,8 @@ class Config(types.SequenceLayerConfig): def make(self) -> 'Dense': return Dense(self, name=self.name) + + config: Config @nn.nowrap @@ -269,7 +272,7 @@ def layer( ) -class EinsumDense(types.Stateless, utils.EinsumCommon): +class EinsumDense(types.Stateless, utils.EinsumCommon, spec.EinsumDense): """A dense layer that transforms the channel shape with an einsum equation. Equation input and output specs must have leading ellipses to broadcast over @@ -291,7 +294,7 @@ class EinsumDense(types.Stateless, utils.EinsumCommon): """ @dataclasses.dataclass(frozen=True) - class Config(types.SequenceLayerConfig): + class Config(types.SequenceLayerConfig, spec.EinsumDense.Config): """EinsumDense config.""" # An equation describing the einsum to perform. This equation must be a @@ -338,6 +341,8 @@ def __post_init__(self): def make(self) -> 'EinsumDense': return EinsumDense(self, name=self.name) + + config: Config @nn.nowrap diff --git a/sequence_layers/jax/dense_test.py b/sequence_layers/jax/dense_test.py index 0edd20b..3820c1f 100644 --- a/sequence_layers/jax/dense_test.py +++ b/sequence_layers/jax/dense_test.py @@ -19,21 +19,14 @@ import flax.linen as nn import jax import jax.numpy as jnp + from sequence_layers.jax import dense from sequence_layers.jax import test_utils from sequence_layers.jax import types +from sequence_layers.specs import dense_behaviors as spec -class DenseTest(test_utils.SequenceLayerTest): - - def test_rank2_unsupported(self): - key = jax.random.PRNGKey(1234) - l = dense.Dense.Config( - 3, bias_init=nn.initializers.normal(), name='dense' - ).make() - x = test_utils.random_sequence(2, 13) - with self.assertRaises(ValueError): - self.init_and_bind_layer(key, l, x) +class DenseTest(test_utils.SequenceLayerTest, spec.DenseTest): @parameterized.parameters(((5,),), ((5, 7),)) def test_dense(self, channels_shape): @@ -49,7 +42,7 @@ def test_dense(self, channels_shape): self.assertEqual( l.get_output_shape_for_sequence(x), channels_shape[:-1] + (3,) ) - self.verify_contract(l, x, training=False, grad_rtol=1e-5, grad_atol=1e-5) + self.verify_contract(l, x, training=False, rtol=1e-5, atol=1e-5, grad_rtol=1e-5, grad_atol=1e-5) chex.assert_trees_all_equal_shapes_and_dtypes( flax.core.meta.unbox(l.variables), @@ -61,17 +54,6 @@ def test_dense(self, channels_shape): }, ) - @parameterized.parameters(True, False) - def test_use_bias(self, use_bias): - """Check that use_bias controls whether a bias is created.""" - key = jax.random.PRNGKey(1234) - l = dense.Dense.Config(3, use_bias=use_bias).make() - x = test_utils.random_sequence(2, 3, 5) - l = self.init_and_bind_layer(key, l, x) - self.assertCountEqual( - l.variables['params'], ['kernel', 'bias'] if use_bias else ['kernel'] - ) - def test_use_einsum_factory(self): """Check that einsum_factory produces is used for dense einsum.""" @@ -254,7 +236,7 @@ def test_dtypes(self, param_dtype, input_dtype, compute_dtype, use_bias): ) -class EinsumDenseTest(test_utils.SequenceLayerTest): +class EinsumDenseTest(test_utils.SequenceLayerTest, spec.EinsumDenseTest): @parameterized.parameters( ( @@ -461,22 +443,22 @@ def custom_einsum(equation, *args, **kwargs): @parameterized.product( test_utils.standard_dtype_configs(), ( - dict( - shape=(2, 3, 5, 7, 11), - equation='...abc,bd->...bd', - output_shape=(None, 13), - expected_kernel_shape=(7, 13), - bias_axes='', - expected_bias_shape=None, - ), - dict( - shape=(2, 3, 5), - equation='...a,abcd->...bcd', - output_shape=(7, 11, 13), - expected_kernel_shape=(5, 7, 11, 13), - bias_axes='cd', - expected_bias_shape=(11, 13), - ), + { + 'shape': (2, 3, 5, 7, 11), + 'equation': '...abc,bd->...bd', + 'output_shape': (None, 13), + 'expected_kernel_shape': (7, 13), + 'bias_axes': '', + 'expected_bias_shape': None, + }, + { + 'shape': (2, 3, 5), + 'equation': '...a,abcd->...bcd', + 'output_shape': (7, 11, 13), + 'expected_kernel_shape': (5, 7, 11, 13), + 'bias_axes': 'cd', + 'expected_bias_shape': (11, 13), + }, ), ) def test_dtypes( @@ -536,27 +518,6 @@ def test_dtypes( ).mask_invalid() self.assertSequencesClose(y, y_expected) - def test_einsum_dense_nonbroadcasting_equation(self): - with self.assertRaises(ValueError): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(2, 3, 4, 5, 6) - l = dense.EinsumDense.Config( - 'btabc,bc->btad', output_shape=[None, 2] - ).make() - self.init_and_bind_layer(key, l, x) - - def test_einsum_dense_inconsistent_input_shape(self): - key = jax.random.PRNGKey(1234) - x = test_utils.random_sequence(2, 3, 5) - l = dense.EinsumDense.Config( - '...abc,bc->...ad', output_shape=[None, 2] - ).make() - with self.assertRaises(ValueError): - self.init_and_bind_layer(key, l, x) - # Show it works with the right input shape. - x = test_utils.random_sequence(2, 3, 5, 7, 11) - self.assertEqual(l.get_output_shape_for_sequence(x), (5, 2)) - if __name__ == '__main__': test_utils.main() diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 6a17923..4eab10f 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -14,6 +14,7 @@ """Sequence layers in MLX.""" from . import backend +from . import dense from . import simple from . import types # CRITICAL: Do NOT use wildcard imports (e.g., `from .simple import *`) here. @@ -26,6 +27,8 @@ # Explicit imports (e.g., `from .simple import Relu`) DO NOT trigger this issue. # If you need to expose specific layers at the package level, import them # explicitly instead of using a star import. +from .dense import Dense +from .dense import EinsumDense from .simple import Abs from .simple import Add from .simple import Cast @@ -65,6 +68,7 @@ from .types import SequenceLayerConfig __all__ = [ + 'dense', 'backend', 'types', 'simple', @@ -72,6 +76,8 @@ 'MaskedSequence', 'SequenceLayer', 'SequenceLayerConfig', + 'Dense', + 'EinsumDense', 'Identity', 'Relu', 'Gelu', diff --git a/sequence_layers/mlx/dense.py b/sequence_layers/mlx/dense.py new file mode 100644 index 0000000..2472d76 --- /dev/null +++ b/sequence_layers/mlx/dense.py @@ -0,0 +1,250 @@ +"""Dense sequence layer for MLX.""" + +import dataclasses +from typing import Callable, override + +from mlx import nn +import mlx.core as mx + +from sequence_layers.mlx import types +from sequence_layers.mlx.simple import _to_mx_dtype +from sequence_layers.specs import dense as spec + + +class Dense(types.Stateless, spec.Dense): + """A basic dense layer with deferred initialization. + + Matches JAX interface where in_features is inferred on first call. + """ + + @dataclasses.dataclass(frozen=True) + class Config(types.SequenceLayerConfig, spec.Dense.Config): + """Dense config.""" + + features: int + use_bias: bool = True + activation: Callable | None = None + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + @override + def make(self) -> 'Dense': + return Dense(self) + + def __init__(self, config: Config): + """Initialize Dense.""" + super().__init__() + self.config = config + self._compute_dtype = _to_mx_dtype(config.compute_dtype) + self._param_dtype = _to_mx_dtype(config.param_dtype) + self._linear = None + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + def _ensure_initialized(self, in_features: int): + """Ensure nn.Linear is initialized on first call.""" + if self._linear is not None: + return + self._linear = nn.Linear( + in_features, self.config.features, bias=self.config.use_bias + ) + + @override + def get_output_shape(self, input_shape, *, constants=None): + """Get output shape.""" + if not input_shape: + raise ValueError( + f'Dense requires at least rank 3 input. Got: {input_shape=}' + ) + return tuple(input_shape[:-1]) + (self.config.features,) + + @override + def get_output_dtype(self, input_dtype, *, constants=None): + if self._compute_dtype is not None: + return self._compute_dtype + assert self._param_dtype is not None + return self._param_dtype + + @override + @types.check_layer + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + if x.ndim < 3: + raise ValueError(f'Dense requires at least rank 3 input. Got: {x.shape=}') + self._ensure_initialized(x.shape[-1]) + assert self._linear is not None + activation = self.config.activation + compute_dtype = self.get_output_dtype(x.dtype) + + def dense_fn(v): + y = self._linear(v.astype(compute_dtype)) + if activation is not None: + y = activation(y) + return y + + if self.config.use_bias or activation is not None: + return x.apply_values(dense_fn) + return x.apply_values_masked(dense_fn) + + +class EinsumDense(types.Stateless, spec.EinsumDense): + """Dense layer using Einstein summation notation.""" + + @dataclasses.dataclass(frozen=True) + class Config(types.SequenceLayerConfig, spec.EinsumDense.Config): + """MLX-native configuration for EinsumDense.""" + + equation: str = '' + output_shape: tuple[int | None, ...] = () + bias_axes: str = '' + activation: Callable | None = None + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + + @override + def make(self) -> 'EinsumDense': + return EinsumDense(self) + + def __init__(self, config: Config): + """Initialize EinsumDense.""" + super().__init__() + self.config = config + self._compute_dtype = _to_mx_dtype(config.compute_dtype) + self._param_dtype = _to_mx_dtype(config.param_dtype) + self.kernel = None + self.bias = None + self._initialized = False + self._resolved_output_shape = None + + @property + @override + def receptive_field(self) -> tuple[int, int]: + return (0, 0) + + def _ensure_initialized(self, input_shape): + """Ensure parameters are initialized.""" + if self._initialized: + return + output_shape, kernel_shape, bias_shape = _compute_shapes( + self.config.equation, + input_shape, + self.config.output_shape, + self.config.bias_axes, + ) + self._resolved_output_shape = output_shape + self.kernel = mx.zeros(kernel_shape, dtype=self._param_dtype) + if bias_shape is not None: + self.bias = mx.zeros(bias_shape, dtype=self._param_dtype) + self._initialized = True + + @override + def get_output_shape(self, input_shape, *, constants=None): + """Get output shape.""" + output_shape, _, _ = _compute_shapes( + self.config.equation, + input_shape, + self.config.output_shape, + self.config.bias_axes, + ) + return output_shape + + @override + def get_output_dtype(self, input_dtype, *, constants=None): + if self._compute_dtype is not None: + return self._compute_dtype + assert self._param_dtype is not None + return self._param_dtype + + @override + @types.check_layer + def layer( # pyrefly: ignore[missing-override-decorator] + self, x, *, training: bool, constants=None + ): + self._ensure_initialized(x.channel_shape) + compute_dtype = self.get_output_dtype(x.dtype) + activation = self.config.activation + + def einsum_fn(v): + y = mx.einsum(self.config.equation, v.astype(compute_dtype), self.kernel) + if self.bias is not None: + y = y + self.bias + if activation is not None: + y = activation(y) + return y + + if self.bias is not None or activation is not None: + return x.apply_values(einsum_fn) + return x.apply_values_masked(einsum_fn) + + +def _parse_equation(equation): + """Parse einsum equation of form '...ab,bc->...ac'.""" + if '->' not in equation: + raise ValueError(f'equation is not valid for EinsumDense: {equation}') + left, output_spec = equation.split('->') + input_spec, kernel_spec = left.split(',') + if not input_spec.startswith('...') or not output_spec.startswith('...'): + raise ValueError('Equation must be of the form "...X,Y->...Z".') + if 3 + len(set(input_spec[3:])) != len(input_spec): + raise ValueError( + f'Equation {input_spec=} must not contain duplicate variables.' + ) + if 3 + len(set(output_spec[3:])) != len(output_spec): + raise ValueError( + f'Equation {output_spec=} must not contain duplicate variables.' + ) + return input_spec, kernel_spec, output_spec + + +def _compute_shapes(equation, input_shape, output_shape_spec, bias_axes): + """Compute kernel_shape and bias_shape from equation and shapes.""" + input_spec, kernel_spec, output_spec = _parse_equation(equation) + in_spec = input_spec[3:] + out_spec = output_spec[3:] + + if len(in_spec) != len(input_shape): + raise ValueError(f'Equation {in_spec=} does not match {input_shape=} rank.') + + input_dims = {d: input_shape[i] for i, d in enumerate(in_spec)} + output_shape = list(output_shape_spec) + if len(out_spec) != len(output_shape): + raise ValueError(f'Equation {out_spec=} does not match {output_shape=}.') + + for i, d in enumerate(out_spec): + if output_shape[i] is None: + output_shape[i] = input_dims[d] + elif d in input_dims and output_shape[i] != input_dims[d]: + raise ValueError( + f'Inconsistent dimension {d=}. {output_shape=} vs {input_shape=}' + ) + + output_dim_map = {d: output_shape[i] for i, d in enumerate(out_spec)} + + kernel_shape = [] + for d in kernel_spec: + if d in input_dims: + kernel_shape.append(input_dims[d]) + elif d in output_dim_map: + kernel_shape.append(output_dim_map[d]) + else: + raise ValueError(f"Weight dimension '{d}' not in input or output spec.") + + if bias_axes: + first_bias_loc = min(out_spec.find(c) for c in bias_axes) + bias_out_spec = out_spec[first_bias_loc:] + bias_shape = [ + output_dim_map[c] if c in bias_axes else 1 for c in bias_out_spec + ] + else: + bias_shape = None + + return tuple(output_shape), tuple(kernel_shape), bias_shape diff --git a/sequence_layers/mlx/dense_test.py b/sequence_layers/mlx/dense_test.py new file mode 100644 index 0000000..a323f24 --- /dev/null +++ b/sequence_layers/mlx/dense_test.py @@ -0,0 +1,27 @@ +"""Tests for Dense MLX sequence layers.""" + +from absl.testing import absltest +from mlx import nn + +from sequence_layers.mlx import dense +from sequence_layers.mlx import test_utils +from sequence_layers.specs import dense_behaviors as spec + + +class DenseTest(test_utils.SequenceLayerTest, spec.DenseTest): + """Test behavior of Dense layer.""" + + def test_activation(self): + """Test activation in Dense.""" + layer = dense.Dense.Config(features=8, activation=nn.relu).make() + x = self.random_sequence(2, 3, 4) + layer = self.init_layer(layer, x) + self.verify_contract(layer, x) + + +class EinsumDenseTest(test_utils.SequenceLayerTest, spec.EinsumDenseTest): + """Test behavior of EinsumDense layer.""" + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/specs/__init__.py b/sequence_layers/specs/__init__.py index 10df687..c5ee4cd 100644 --- a/sequence_layers/specs/__init__.py +++ b/sequence_layers/specs/__init__.py @@ -5,6 +5,7 @@ from typing import Protocol, runtime_checkable, TYPE_CHECKING from . import backend as _backend +from . import dense as _dense from . import simple as _simple from . import types as _types @@ -116,3 +117,11 @@ def Embedding(self) -> type[_simple.Embedding]: @property def Softmax(self) -> type[_simple.Softmax]: ... + + @property + def Dense(self) -> type[_dense.Dense]: + ... + + @property + def EinsumDense(self) -> type[_dense.EinsumDense]: + ... diff --git a/sequence_layers/specs/dense.py b/sequence_layers/specs/dense.py new file mode 100644 index 0000000..ec5e10d --- /dev/null +++ b/sequence_layers/specs/dense.py @@ -0,0 +1,47 @@ +"""Specifications for dense layers. + +See the corresponding _behaviors module for behaviors. +""" + +import abc +import dataclasses +from typing import Any, Callable, Sequence + +from sequence_layers.specs import types as types_spec + + +class Dense(types_spec.Stateless, metaclass=abc.ABCMeta): + """Specification for Dense layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for Dense layer.""" + + features: int + use_bias: bool = True + activation: Callable | None = None + compute_dtype: types_spec.DType | None = None + param_dtype: types_spec.DType | None = None + name: str | None = None + + def make(self) -> Any: + """Dummy make to satisfy Pyrefly.""" + + +class EinsumDense(types_spec.Stateless, metaclass=abc.ABCMeta): + """Specification for EinsumDense layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(types_spec.SequenceLayerConfig): + """Configuration for EinsumDense layer.""" + + equation: str + output_shape: Sequence[int | None] + bias_axes: str = '' + activation: Callable | None = None + compute_dtype: types_spec.DType | None = None + param_dtype: types_spec.DType | None = None + name: str | None = None + + def make(self) -> Any: + """Dummy make to satisfy Pyrefly.""" diff --git a/sequence_layers/specs/dense_behaviors.py b/sequence_layers/specs/dense_behaviors.py new file mode 100644 index 0000000..87a2773 --- /dev/null +++ b/sequence_layers/specs/dense_behaviors.py @@ -0,0 +1,109 @@ +"""Behavior tests for dense layers. + +Backend-specific test files should inherit from these tests. +""" + +# pylint: disable=abstract-method + +from absl.testing import parameterized + +from sequence_layers.specs import test_utils + + +class DenseTest(test_utils.SequenceLayerTest): + """Test behavior of Dense layer.""" + + def test_rank2_unsupported(self): + l = self.sl.Dense.Config(features=3, name='dense').make() + x = self.random_sequence(2, 13) + with self.assertRaises(ValueError): + l = self.init_layer(l, x) + l.layer(x, training=False) + + @parameterized.parameters(((5,),), ((5, 7),)) + def test_dense(self, channels_shape): + l = self.sl.Dense.Config(features=3, name='dense').make() + x = self.random_sequence(2, 13, *channels_shape, random_mask=True) + l = self.init_layer(l, x) + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.name, 'dense') + self.assertEqual( + l.get_output_shape_for_sequence(x), channels_shape[:-1] + (3,) + ) + self.verify_contract(l, x, training=False) + + @parameterized.parameters(True, False) + def test_use_bias(self, use_bias): + l = self.sl.Dense.Config(features=3, use_bias=use_bias).make() + x = self.random_sequence(2, 3, 5) + l = self.init_layer(l, x) + self.verify_contract(l, x, training=False) + + +class EinsumDenseTest(test_utils.SequenceLayerTest): + """Test behavior of EinsumDense layer.""" + + @parameterized.parameters( + ( + (2, 3, 5), + '...a,ab->...b', + (7,), + '', + (7,), + ), + ( + (2, 3, 5, 7), + '...ab,ac->...cb', + (11, 7), + 'c', + (11, 7), + ), + ( + (2, 3, 5, 7), + '...ab,b->...a', + (None,), + '', + (5,), + ), + ) + def test_einsum_dense( + self, + shape, + equation, + output_shape, + bias_axes, + expected_output_shape, + ): + x = self.random_sequence(*shape) + l = self.sl.EinsumDense.Config( + equation=equation, + output_shape=output_shape, + bias_axes=bias_axes, + name='einsum_dense', + ).make() + l = self.init_layer(l, x) + + self.assertEqual(l.block_size, 1) + self.assertEqual(l.output_ratio, 1) + self.assertEqual(l.name, 'einsum_dense') + self.assertEqual(l.get_output_shape_for_sequence(x), expected_output_shape) + self.verify_contract(l, x, training=False) + + def test_einsum_dense_nonbroadcasting_equation(self): + with self.assertRaises(ValueError): + x = self.random_sequence(2, 3, 4, 5, 6) + l = self.sl.EinsumDense.Config( + equation='btabc,bc->btad', output_shape=[None, 2] + ).make() + l = self.init_layer(l, x) + l.layer(x, training=False) + + def test_einsum_dense_inconsistent_input_shape(self): + x = self.random_sequence(2, 3, 5) + l = self.sl.EinsumDense.Config( + equation='...abc,bc->...ad', output_shape=[None, 2] + ).make() + with self.assertRaises(ValueError): + l = self.init_layer(l, x) + l.layer(x, training=False)