Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ disable = [
"too-many-positional-arguments",
"too-many-public-methods",
"too-many-return-statements",
"too-many-instance-attributes",
]


Expand Down
17 changes: 11 additions & 6 deletions sequence_layers/jax/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@

import dataclasses
import typing
from typing import Callable
from typing import Callable, override

import flax.linen as nn
import jax
import jax.numpy as jnp

from sequence_layers.jax import meta
from sequence_layers.jax import types
from sequence_layers.jax import utils

from sequence_layers.specs import dense as spec

__all__ = (
# go/keep-sorted start
Expand All @@ -34,11 +35,11 @@
)


class Dense(types.Stateless, utils.EinsumCommon):
class Dense(types.Stateless, utils.EinsumCommon, spec.Dense):
"""A basic dense layer."""

@dataclasses.dataclass(frozen=True)
class Config(types.SequenceLayerConfig):
class Config(types.SequenceLayerConfig, spec.Dense.Config):
"""Dense config."""

# The number of output features for the dense layer.
Expand Down Expand Up @@ -73,6 +74,8 @@ class Config(types.SequenceLayerConfig):
def make(self) -> 'Dense':
return Dense(self, name=self.name)



config: Config

@nn.nowrap
Expand Down Expand Up @@ -269,7 +272,7 @@ def layer(
)


class EinsumDense(types.Stateless, utils.EinsumCommon):
class EinsumDense(types.Stateless, utils.EinsumCommon, spec.EinsumDense):
"""A dense layer that transforms the channel shape with an einsum equation.

Equation input and output specs must have leading ellipses to broadcast over
Expand All @@ -291,7 +294,7 @@ class EinsumDense(types.Stateless, utils.EinsumCommon):
"""

@dataclasses.dataclass(frozen=True)
class Config(types.SequenceLayerConfig):
class Config(types.SequenceLayerConfig, spec.EinsumDense.Config):
"""EinsumDense config."""

# An equation describing the einsum to perform. This equation must be a
Expand Down Expand Up @@ -338,6 +341,8 @@ def __post_init__(self):
def make(self) -> 'EinsumDense':
return EinsumDense(self, name=self.name)



config: Config

@nn.nowrap
Expand Down
81 changes: 21 additions & 60 deletions sequence_layers/jax/dense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,14 @@
import flax.linen as nn
import jax
import jax.numpy as jnp

from sequence_layers.jax import dense
from sequence_layers.jax import test_utils
from sequence_layers.jax import types
from sequence_layers.specs import dense_behaviors as spec


class DenseTest(test_utils.SequenceLayerTest):

def test_rank2_unsupported(self):
key = jax.random.PRNGKey(1234)
l = dense.Dense.Config(
3, bias_init=nn.initializers.normal(), name='dense'
).make()
x = test_utils.random_sequence(2, 13)
with self.assertRaises(ValueError):
self.init_and_bind_layer(key, l, x)
class DenseTest(test_utils.SequenceLayerTest, spec.DenseTest):

@parameterized.parameters(((5,),), ((5, 7),))
def test_dense(self, channels_shape):
Expand All @@ -49,7 +42,7 @@ def test_dense(self, channels_shape):
self.assertEqual(
l.get_output_shape_for_sequence(x), channels_shape[:-1] + (3,)
)
self.verify_contract(l, x, training=False, grad_rtol=1e-5, grad_atol=1e-5)
self.verify_contract(l, x, training=False, rtol=1e-5, atol=1e-5, grad_rtol=1e-5, grad_atol=1e-5)

chex.assert_trees_all_equal_shapes_and_dtypes(
flax.core.meta.unbox(l.variables),
Expand All @@ -61,17 +54,6 @@ def test_dense(self, channels_shape):
},
)

@parameterized.parameters(True, False)
def test_use_bias(self, use_bias):
"""Check that use_bias controls whether a bias is created."""
key = jax.random.PRNGKey(1234)
l = dense.Dense.Config(3, use_bias=use_bias).make()
x = test_utils.random_sequence(2, 3, 5)
l = self.init_and_bind_layer(key, l, x)
self.assertCountEqual(
l.variables['params'], ['kernel', 'bias'] if use_bias else ['kernel']
)

def test_use_einsum_factory(self):
"""Check that einsum_factory produces is used for dense einsum."""

Expand Down Expand Up @@ -254,7 +236,7 @@ def test_dtypes(self, param_dtype, input_dtype, compute_dtype, use_bias):
)


class EinsumDenseTest(test_utils.SequenceLayerTest):
class EinsumDenseTest(test_utils.SequenceLayerTest, spec.EinsumDenseTest):

@parameterized.parameters(
(
Expand Down Expand Up @@ -461,22 +443,22 @@ def custom_einsum(equation, *args, **kwargs):
@parameterized.product(
test_utils.standard_dtype_configs(),
(
dict(
shape=(2, 3, 5, 7, 11),
equation='...abc,bd->...bd',
output_shape=(None, 13),
expected_kernel_shape=(7, 13),
bias_axes='',
expected_bias_shape=None,
),
dict(
shape=(2, 3, 5),
equation='...a,abcd->...bcd',
output_shape=(7, 11, 13),
expected_kernel_shape=(5, 7, 11, 13),
bias_axes='cd',
expected_bias_shape=(11, 13),
),
{
'shape': (2, 3, 5, 7, 11),
'equation': '...abc,bd->...bd',
'output_shape': (None, 13),
'expected_kernel_shape': (7, 13),
'bias_axes': '',
'expected_bias_shape': None,
},
{
'shape': (2, 3, 5),
'equation': '...a,abcd->...bcd',
'output_shape': (7, 11, 13),
'expected_kernel_shape': (5, 7, 11, 13),
'bias_axes': 'cd',
'expected_bias_shape': (11, 13),
},
),
)
def test_dtypes(
Expand Down Expand Up @@ -536,27 +518,6 @@ def test_dtypes(
).mask_invalid()
self.assertSequencesClose(y, y_expected)

def test_einsum_dense_nonbroadcasting_equation(self):
with self.assertRaises(ValueError):
key = jax.random.PRNGKey(1234)
x = test_utils.random_sequence(2, 3, 4, 5, 6)
l = dense.EinsumDense.Config(
'btabc,bc->btad', output_shape=[None, 2]
).make()
self.init_and_bind_layer(key, l, x)

def test_einsum_dense_inconsistent_input_shape(self):
key = jax.random.PRNGKey(1234)
x = test_utils.random_sequence(2, 3, 5)
l = dense.EinsumDense.Config(
'...abc,bc->...ad', output_shape=[None, 2]
).make()
with self.assertRaises(ValueError):
self.init_and_bind_layer(key, l, x)
# Show it works with the right input shape.
x = test_utils.random_sequence(2, 3, 5, 7, 11)
self.assertEqual(l.get_output_shape_for_sequence(x), (5, 2))


if __name__ == '__main__':
test_utils.main()
6 changes: 6 additions & 0 deletions sequence_layers/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Sequence layers in MLX."""

from . import backend
from . import dense
from . import simple
from . import types
# CRITICAL: Do NOT use wildcard imports (e.g., `from .simple import *`) here.
Expand All @@ -26,6 +27,8 @@
# Explicit imports (e.g., `from .simple import Relu`) DO NOT trigger this issue.
# If you need to expose specific layers at the package level, import them
# explicitly instead of using a star import.
from .dense import Dense
from .dense import EinsumDense
from .simple import Abs
from .simple import Add
from .simple import Cast
Expand Down Expand Up @@ -65,13 +68,16 @@
from .types import SequenceLayerConfig

__all__ = [
'dense',
'backend',
'types',
'simple',
'Sequence',
'MaskedSequence',
'SequenceLayer',
'SequenceLayerConfig',
'Dense',
'EinsumDense',
'Identity',
'Relu',
'Gelu',
Expand Down
Loading