import torch
import torch.nn as nn

from math import sqrt
from torch import Tensor
from torch.nn import Module
from typing import Optional, Tuple, Union

from hyptorch.nn import ToPoincare, FromPoincare
from .activation import HopfieldCore


def _make_c_tensor(hyper_c: float, train_c: bool, device=None, dtype=None) -> torch.Tensor:
    t = torch.tensor(hyper_c, dtype=dtype or torch.float32, device=device)
    if train_c:
        t = nn.Parameter(t.clamp(min=1e-6))   # positive curvature
    return t


class Hyperbolic_Hopfield(Module):
    """
    Module with underlying Hopfield association on the Poincaré ball.

    All curvature-dependent ops (``association_core`` / ``to_poincare`` / ``from_Poincare``) read
    the SAME Tensor handle ``self.c``. ``_sync_curvature`` (re)binds that handle, so an outer
    wrapper may replace ``self.c`` and call it to propagate the new handle.

    Hyperbolic options:
        input_as_hyper: if False, inputs are (optionally) layer-normalized and mapped onto the
            ball with ``to_poincare`` before association; if True they are passed through as-is.
        out_as_hyper: if False, the association output is mapped back with ``from_Poincare`` and
            then the optional ``association_activation`` is applied; if True the output is
            returned as-is and the activation is NOT applied.
    """

    def __init__(self,
                 input_size: Optional[int] = None,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_stored_pattern_eps: float = 1e-5,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 normalize_state_pattern_eps: float = 1e-5,

                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,

                 stored_pattern_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,

                 # Hyperbolic config
                 hyper_c: float = 1.0,
                 train_c: bool = False,
                 train_x: bool = False,
                 clip_r: float = 0.9,
                 input_as_hyper: bool = True,
                 out_as_hyper: bool = True,

                 # Hopfield core extras
                 theta: float = 1.0,
                 lr: float = 1e-3,
                 ):
        super().__init__()
        assert type(batch_first) == bool, f'"batch_first" needs to be a boolean, not {type(batch_first)}.'
        assert (association_activation is None) or (type(association_activation) == str)

        # 1) One shared curvature tensor
        self.c = _make_c_tensor(hyper_c, train_c)

        # 2) Hopfield core receives the SAME handle (its forward takes no c)
        self.association_core = HopfieldCore(
            embed_dim=input_size, num_heads=num_heads, dropout=dropout, bias=input_bias,
            add_bias_k=concat_bias_pattern, add_zero_attn=add_zero_association, kdim=stored_pattern_size,
            head_dim=hidden_size, pattern_dim=pattern_size, out_dim=output_size,
            disable_out_projection=disable_out_projection,
            key_as_static=stored_pattern_as_static,
            query_as_static=state_pattern_as_static,
            theta=theta, lr=lr, c=self.c
        )

        # Optional activation looked up by name in the torch namespace (e.g. "tanh" -> torch.tanh).
        self.association_activation = self._resolve_activation(association_activation)

        # 3) Optional norms
        self.norm_stored_pattern = None
        if normalize_stored_pattern_affine:
            assert normalize_stored_pattern, "affine normalization without normalization has no effect."
        if normalize_stored_pattern:
            normalized_shape = input_size if stored_pattern_size is None else stored_pattern_size
            assert normalized_shape is not None, "stored pattern size required for setting up normalisation"
            self.norm_stored_pattern = nn.LayerNorm(
                normalized_shape=normalized_shape, elementwise_affine=normalize_stored_pattern_affine,
                eps=normalize_stored_pattern_eps)

        self.norm_state_pattern = None
        if normalize_state_pattern_affine:
            assert normalize_state_pattern, "affine normalization without normalization has no effect."
        if normalize_state_pattern:
            assert input_size is not None, "input size required for setting up normalisation"
            self.norm_state_pattern = nn.LayerNorm(
                normalized_shape=input_size, elementwise_affine=normalize_state_pattern_affine,
                eps=normalize_state_pattern_eps)

        # 4) Scaling helper
        # NOTE(review): the scaling value is stored and exposed via the `scaling` property, but
        # nothing in this class forwards it to `association_core` — confirm HopfieldCore does not
        # need it passed in.
        if self.association_core.static_execution:
            self.__scaling = 1.0 if scaling is None else scaling
        else:
            assert self.association_core.head_dim > 0, f'invalid hidden dimension encountered.'
            self.__scaling = (1.0 / sqrt(self.association_core.head_dim)) if scaling is None else scaling

        self.__batch_first = batch_first
        self.__update_steps_max = update_steps_max
        self.__update_steps_eps = update_steps_eps

        # 5) Poincaré mappers (bound to the SAME c handle in _sync_curvature below)
        self.to_poincare = ToPoincare(c=hyper_c, train_c=False, train_x=train_x,
                                      ball_dim=input_size,
                                      riemannian=True, clip_r=clip_r)
        self.from_Poincare = FromPoincare(c=hyper_c, train_c=False, train_x=train_x,
                                          ball_dim=input_size)

        self.input_as_hyper = input_as_hyper
        self.out_as_hyper = out_as_hyper
        self.reset_parameters()
        self._sync_curvature()   # bind handles once at init

    @staticmethod
    def _resolve_activation(name: Optional[str]):
        """Return ``getattr(torch, name)``, or None when ``name`` is None.

        Raises:
            ValueError: if ``torch`` has no callable attribute ``name``. Without this check a
                misspelled name would silently disable the activation.
        """
        if name is None:
            return None
        activation = getattr(torch, name, None)
        if not callable(activation):
            raise ValueError(f'unknown association activation "{name}": torch.{name} is not callable.')
        return activation

    # --- curvature sync: make all sub-modules share the SAME Tensor handle self.c ---
    def _sync_curvature(self):
        """Point every curvature-dependent sub-module at the very same ``self.c`` object.

        No ``.data`` copy is made, so autograd through a trainable curvature is preserved.
        Assigning an ``nn.Parameter`` here registers it on each sub-module as well.
        """
        self.association_core.c = self.c
        if hasattr(self.to_poincare, 'c'):
            self.to_poincare.c = self.c
        if hasattr(self.from_Poincare, 'c'):
            self.from_Poincare.c = self.c

    def reset_parameters(self) -> None:
        """Reset the core and the optional LayerNorms (``None`` entries are skipped)."""
        for module in (self.association_core, self.norm_stored_pattern, self.norm_state_pattern):
            if hasattr(module, r'reset_parameters'):
                module.reset_parameters()

    def _maybe_transpose(self, *args: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]:
        """Swap dims 0 and 1 of every argument when ``batch_first``; unwrap a single result."""
        transposed_result = tuple(_.transpose(0, 1) for _ in args) if self.__batch_first else args
        return transposed_result[0] if len(transposed_result) == 1 else transposed_result

    def _associate(self, data: Union[Tensor, Tuple[Tensor, Tensor]],
                   return_raw_associations: bool = False,
                   stored_pattern_padding_mask: Optional[Tensor] = None,
                   association_mask: Optional[Tensor] = None) -> Tuple[Optional[Tensor], ...]:
        """Run the Hopfield core on ``data`` and return its raw output tuple.

        ``data`` is either one tensor (used as both stored and state pattern) or a
        ``(stored_pattern, state_pattern)`` tuple.
        """
        # Re-bind handles in case an outer wrapper replaced ``self.c`` (in-place optimizer
        # updates keep the same object, so they need no re-sync).
        self._sync_curvature()

        # isinstance (not exact type) so nn.Parameter inputs and tuple subclasses are accepted.
        is_single = isinstance(data, Tensor)
        assert is_single or (isinstance(data, tuple) and (len(data) == 2)), \
            r'either one tensor (used as both "stored pattern" and "state pattern") or a tuple of two tensors ' \
            r'("stored pattern", "state pattern") must be provided.'
        if is_single:
            stored_pattern, state_pattern = data, data
        else:
            stored_pattern, state_pattern = data

        stored_pattern, state_pattern = self._maybe_transpose(stored_pattern, state_pattern)

        if not self.input_as_hyper:
            # Optional norms in Euclidean coords (flatten to 2-D, normalize last dim, restore shape)
            if self.norm_stored_pattern is not None:
                stored_pattern = self.norm_stored_pattern(
                    input=stored_pattern.reshape(shape=(-1, stored_pattern.shape[2]))
                ).reshape(shape=stored_pattern.shape)
            if self.norm_state_pattern is not None:
                state_pattern = self.norm_state_pattern(
                    input=state_pattern.reshape(shape=(-1, state_pattern.shape[2]))
                ).reshape(shape=state_pattern.shape)

            # Map both to the ball with shared c
            stored_pattern = self.to_poincare(stored_pattern)
            state_pattern = self.to_poincare(state_pattern)

        # HopfieldCore forward DOES NOT take c; it reads the shared handle set above
        return self.association_core(
            query=state_pattern, key=stored_pattern,
            key_padding_mask=stored_pattern_padding_mask, need_weights=False, attn_mask=association_mask,
            update_steps_max=self.__update_steps_max, update_steps_eps=self.__update_steps_eps,
            return_raw_associations=return_raw_associations
        )

    def forward(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """Associate ``input`` and return the core's first output in the caller's layout.

        When ``out_as_hyper`` is False the result is mapped back with ``from_Poincare`` and the
        optional activation is applied; otherwise it is returned unchanged.
        """
        assoc_out = self._maybe_transpose(self._associate(
            data=input, return_raw_associations=False,
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask
        )[0])

        if not self.out_as_hyper:
            assoc_out = self.from_Poincare(assoc_out)
            if self.association_activation is not None:
                assoc_out = self.association_activation(assoc_out)

        return assoc_out

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """Return element ``[2]`` of the core's output (raw associations), without gradients."""
        with torch.no_grad():
            return self._associate(
                data=input, return_raw_associations=True,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask
            )[2]

    # --- properties ---
    @property
    def batch_first(self) -> bool:
        return self.__batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        return self.__scaling.clone() if isinstance(self.__scaling, Tensor) else self.__scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        return self.association_core.kdim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        return self.association_core.embed_dim

    @property
    def input_size(self) -> Optional[int]:
        return self.state_pattern_dim

    @property
    def hidden_size(self) -> Optional[int]:
        return self.association_core.head_dim

    @property
    def output_size(self) -> Optional[int]:
        return self.association_core.out_dim

    @property
    def pattern_size(self) -> Optional[int]:
        return self.association_core.pattern_dim

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        return self.__update_steps_max.clone() if isinstance(self.__update_steps_max, Tensor) else self.__update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        return self.__update_steps_eps.clone() if isinstance(self.__update_steps_eps, Tensor) else self.__update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        return self.association_core.key_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        return self.association_core.query_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        return self.norm_stored_pattern is not None

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        return self.normalize_stored_pattern and self.norm_stored_pattern.elementwise_affine

    @property
    def normalize_state_pattern(self) -> bool:
        return self.norm_state_pattern is not None

    @property
    def normalize_state_pattern_affine(self) -> bool:
        return self.normalize_state_pattern and self.norm_state_pattern.elementwise_affine


class Hyperbolic_HopfieldPooling(Module):
    """
    Pooling wrapper: a trainable state pattern (``pooling_weights``) queries the input through
    an inner ``Hyperbolic_Hopfield``.

    Curvature is shared via the same Tensor handle ``self.c``; ``_sync_curvature`` rebinds it
    onto the local Poincaré mappers and the inner Hopfield (and its sub-modules).
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,
                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,
                 stored_pattern_size: Optional[int] = None,
                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True,

                 # hyperbolic
                 hyper_c: float = 1.0,
                 train_c: bool = False,
                 train_x: bool = False,
                 clip_r: float = 0.98,
                 input_as_hyper: bool = True,
                 out_as_hyper: bool = True,

                 theta: float = 1.0,
                 lr: float = 1e-3,
                 ):
        super().__init__()

        # Shared curvature tensor
        self.c = _make_c_tensor(hyper_c, train_c)

        self.hopfield = Hyperbolic_Hopfield(
            input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
            num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
            stored_pattern_size=stored_pattern_size,
            batch_first=batch_first,
            association_activation=association_activation, dropout=dropout, input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection,

            # The inner Hopfield builds its own (non-trainable) curvature tensor; it is replaced
            # by our handle right below.
            hyper_c=hyper_c, train_c=False,
            train_x=train_x, clip_r=clip_r, input_as_hyper=input_as_hyper, out_as_hyper=True,
            theta=theta, lr=lr
        )
        # replace inner Hopfield c with our handle
        self.hopfield.c = self.c
        self.hopfield._sync_curvature()

        self._quantity = quantity
        pooling_weight_size = self.hopfield.hidden_size if state_pattern_as_static else self.hopfield.input_size
        weight_dim = input_size if pooling_weight_size is None else pooling_weight_size
        self.pooling_weights = nn.Parameter(torch.empty(
            size=(((1, quantity) if batch_first else (quantity, 1)) + (weight_dim,))),
            requires_grad=trainable
        )

        # Mappers share the SAME c (bound in _sync_curvature)
        self.to_poincare = ToPoincare(c=hyper_c, train_c=False, train_x=train_x,
                                      ball_dim=weight_dim,
                                      riemannian=True, clip_r=clip_r)
        self.from_Poincare = FromPoincare(c=hyper_c, train_c=False, train_x=train_x,
                                          ball_dim=weight_dim)
        self.input_as_hyper = input_as_hyper
        self.out_as_hyper = out_as_hyper
        self.reset_parameters()
        self._sync_curvature()

    def _sync_curvature(self):
        """Bind ``self.c`` onto the local mappers and the inner Hopfield (and its sub-modules)."""
        if hasattr(self.to_poincare, 'c'):
            self.to_poincare.c = self.c
        if hasattr(self.from_Poincare, 'c'):
            self.from_Poincare.c = self.c
        self.hopfield.c = self.c
        self.hopfield._sync_curvature()

    def reset_parameters(self) -> None:
        """Reset the inner Hopfield and re-draw the pooling weights from N(0, 0.02^2)."""
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()
        nn.init.normal_(self.pooling_weights, mean=0.0, std=0.02)

    def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(stored_pattern, state_pattern)``: the input and the batch-expanded pooling weights.

        When ``input_as_hyper`` is True the input is already on the ball, so the (Euclidean)
        pooling weights are mapped onto it with ``to_poincare`` as well.
        """
        batch_size = input.shape[0 if self.batch_first else 1]
        leading = (batch_size, self.quantity) if self.batch_first else (self.quantity, batch_size)
        state_pattern = self.pooling_weights.expand(size=(*leading, self.pooling_weights.shape[2]))
        if self.input_as_hyper:
            state_pattern = self.to_poincare(state_pattern)
        return input, state_pattern

    def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """Pool ``input`` into ``quantity`` vectors, flattened from dim 1 onward."""
        self._sync_curvature()
        association_output = self.hopfield(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask)

        if not self.out_as_hyper:
            # Map back BEFORE flattening: FromPoincare is presumably a per-vector map over the
            # last dimension (hyptorch logmap/logmap0), so applying it to the flattened output
            # would treat the `quantity` pooled vectors (and, when not batch_first, the whole
            # batch) as one point.
            association_output = self.from_Poincare(association_output)

        return association_output.flatten(start_dim=1)

    def get_association_matrix(self, input: Union[Tensor, Tuple[Tensor, Tensor]],
                               stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """Return the inner Hopfield's raw association matrix (no gradients)."""
        self._sync_curvature()
        with torch.no_grad():
            return self.hopfield.get_association_matrix(
                input=self._prepare_input(input=input),
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask)

    @property
    def batch_first(self) -> bool:
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        return self.hopfield.state_pattern_dim

    @property
    def input_size(self) -> Optional[int]:
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> int:
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        return self.hopfield.state_pattern_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        return self.hopfield.normalize_stored_pattern

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        return self.hopfield.normalize_stored_pattern_affine

    @property
    def normalize_state_pattern(self) -> bool:
        return self.hopfield.normalize_state_pattern

    @property
    def normalize_state_pattern_affine(self) -> bool:
        return self.hopfield.normalize_state_pattern_affine


class Hyperbolic_HopfieldLayer(Module):
    """
    Lookup wrapper: the input (state pattern) queries trainable stored patterns
    (``lookup_weights``) through an inner ``Hyperbolic_Hopfield``.

    The curvature lives in one Tensor handle, ``self.c``, which ``_sync_curvature`` rebinds onto
    the local Poincaré mappers and the inner Hopfield before every association.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 pattern_size: Optional[int] = None,
                 num_heads: int = 1,
                 scaling: Optional[Union[float, Tensor]] = None,
                 update_steps_max: Optional[Union[int, Tensor]] = 0,
                 update_steps_eps: Union[float, Tensor] = 1e-4,

                 normalize_stored_pattern: bool = True,
                 normalize_stored_pattern_affine: bool = True,
                 normalize_state_pattern: bool = True,
                 normalize_state_pattern_affine: bool = True,
                 stored_pattern_as_static: bool = False,
                 state_pattern_as_static: bool = False,

                 stored_pattern_size: Optional[int] = None,

                 batch_first: bool = True,
                 association_activation: Optional[str] = None,
                 dropout: float = 0.0,
                 input_bias: bool = True,
                 concat_bias_pattern: bool = False,
                 add_zero_association: bool = False,
                 disable_out_projection: bool = False,
                 quantity: int = 1,
                 trainable: bool = True,

                 # hyperbolic
                 hyper_c: float = 1.0,
                 train_c: bool = False,
                 train_x: bool = False,
                 clip_r: float = 0.98,
                 input_as_hyper: bool = True,
                 out_as_hyper: bool = True,

                 theta: float = 1.0,
                 lr: float = 1e-3,
                 ):
        super().__init__()
        self.c = _make_c_tensor(hyper_c, train_c)

        inner = Hyperbolic_Hopfield(
            input_size=input_size, hidden_size=hidden_size, output_size=output_size, pattern_size=pattern_size,
            num_heads=num_heads, scaling=scaling, update_steps_max=update_steps_max, update_steps_eps=update_steps_eps,
            normalize_stored_pattern=normalize_stored_pattern,
            normalize_stored_pattern_affine=normalize_stored_pattern_affine,
            normalize_state_pattern=normalize_state_pattern,
            normalize_state_pattern_affine=normalize_state_pattern_affine,
            stored_pattern_as_static=stored_pattern_as_static, state_pattern_as_static=state_pattern_as_static,
            stored_pattern_size=stored_pattern_size,
            batch_first=batch_first,
            association_activation=association_activation, dropout=dropout, input_bias=input_bias,
            concat_bias_pattern=concat_bias_pattern, add_zero_association=add_zero_association,
            disable_out_projection=disable_out_projection,
            # The inner module gets a fixed curvature of its own; it is swapped for our handle below.
            hyper_c=hyper_c, train_c=False,
            train_x=train_x, clip_r=clip_r, input_as_hyper=input_as_hyper, out_as_hyper=True,
            theta=theta, lr=lr,
        )
        self.hopfield = inner
        inner.c = self.c
        inner._sync_curvature()

        self._quantity = quantity

        # Width of every stored pattern vector (falls back to input_size when the inner module reports None).
        if stored_pattern_as_static:
            weight_width = inner.hidden_size
        else:
            weight_width = inner.stored_pattern_dim
        if weight_width is None:
            weight_width = input_size

        leading_shape = (1, quantity) if batch_first else (quantity, 1)
        self.lookup_weights = nn.Parameter(
            torch.empty(size=leading_shape + (weight_width,)),
            requires_grad=trainable,
        )

        self.to_poincare = ToPoincare(c=hyper_c, train_c=False, train_x=train_x,
                                      ball_dim=weight_width,
                                      riemannian=True, clip_r=clip_r)
        self.from_Poincare = FromPoincare(c=hyper_c, train_c=False, train_x=train_x,
                                          ball_dim=weight_width)
        self.input_as_hyper = input_as_hyper
        self.out_as_hyper = out_as_hyper
        self.reset_parameters()
        self._sync_curvature()

    def _sync_curvature(self):
        """Point the mappers and the inner Hopfield at the shared ``self.c`` object."""
        for mapper in (self.to_poincare, self.from_Poincare):
            if hasattr(mapper, 'c'):
                mapper.c = self.c
        self.hopfield.c = self.c
        self.hopfield._sync_curvature()

    def reset_parameters(self) -> None:
        """Re-initialize the inner Hopfield, then draw ``lookup_weights`` from N(0, 0.02^2)."""
        if hasattr(self.hopfield, r'reset_parameters'):
            self.hopfield.reset_parameters()
        nn.init.normal_(self.lookup_weights, mean=0.0, std=0.02)

    def _prepare_input(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        """Build ``(stored_pattern, state_pattern)`` = (batch-expanded lookup weights, input).

        The lookup weights are moved onto the ball with ``to_poincare`` when ``input_as_hyper``
        is set; otherwise they stay Euclidean and the inner Hopfield maps them itself.
        """
        if self.batch_first:
            leading = (input.shape[0], self.quantity)
        else:
            leading = (self.quantity, input.shape[1])
        patterns = self.lookup_weights.expand(size=(*leading, self.lookup_weights.shape[2]))
        stored_pattern = self.to_poincare(patterns) if self.input_as_hyper else patterns
        return stored_pattern, input

    def forward(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                association_mask: Optional[Tensor] = None) -> Tensor:
        """Look ``input`` up against the stored patterns; map back to Euclidean unless ``out_as_hyper``."""
        self._sync_curvature()
        result = self.hopfield(
            input=self._prepare_input(input=input),
            stored_pattern_padding_mask=stored_pattern_padding_mask,
            association_mask=association_mask,
        )
        return result if self.out_as_hyper else self.from_Poincare(result)

    def get_association_matrix(self, input: Tensor, stored_pattern_padding_mask: Optional[Tensor] = None,
                               association_mask: Optional[Tensor] = None) -> Tensor:
        """Gradient-free raw association matrix from the inner Hopfield."""
        self._sync_curvature()
        with torch.no_grad():
            prepared = self._prepare_input(input=input)
            return self.hopfield.get_association_matrix(
                input=prepared,
                stored_pattern_padding_mask=stored_pattern_padding_mask,
                association_mask=association_mask,
            )

    # ---- read-only views forwarded to the inner Hyperbolic_Hopfield ----

    @property
    def batch_first(self) -> bool:
        """True when dimension 0 (not 1) indexes the batch."""
        return self.hopfield.batch_first

    @property
    def scaling(self) -> Union[float, Tensor]:
        """Association scaling reported by the inner Hopfield."""
        return self.hopfield.scaling

    @property
    def stored_pattern_dim(self) -> Optional[int]:
        """Stored-pattern (key) dimension of the inner Hopfield."""
        return self.hopfield.stored_pattern_dim

    @property
    def state_pattern_dim(self) -> Optional[int]:
        """State-pattern (query) dimension of the inner Hopfield."""
        return self.hopfield.state_pattern_dim

    @property
    def input_size(self) -> Optional[int]:
        """Input width of the inner Hopfield."""
        return self.hopfield.input_size

    @property
    def hidden_size(self) -> int:
        """Per-head hidden width of the inner Hopfield."""
        return self.hopfield.hidden_size

    @property
    def output_size(self) -> Optional[int]:
        """Output width of the inner Hopfield."""
        return self.hopfield.output_size

    @property
    def pattern_size(self) -> Optional[int]:
        """Pattern dimension of the inner Hopfield."""
        return self.hopfield.pattern_size

    @property
    def quantity(self) -> int:
        """Number of stored patterns held in ``lookup_weights``."""
        return self._quantity

    @property
    def update_steps_max(self) -> Optional[Union[int, Tensor]]:
        """Maximum number of association update steps."""
        return self.hopfield.update_steps_max

    @property
    def update_steps_eps(self) -> Optional[Union[float, Tensor]]:
        """Tolerance used by the association update steps."""
        return self.hopfield.update_steps_eps

    @property
    def stored_pattern_as_static(self) -> bool:
        """Whether stored patterns are treated as static by the core."""
        return self.hopfield.stored_pattern_as_static

    @property
    def state_pattern_as_static(self) -> bool:
        """Whether state patterns are treated as static by the core."""
        return self.hopfield.state_pattern_as_static

    @property
    def normalize_stored_pattern(self) -> bool:
        """Whether a LayerNorm exists for stored patterns."""
        return self.hopfield.normalize_stored_pattern

    @property
    def normalize_stored_pattern_affine(self) -> bool:
        """Whether the stored-pattern LayerNorm is affine."""
        return self.hopfield.normalize_stored_pattern_affine

    @property
    def normalize_state_pattern(self) -> bool:
        """Whether a LayerNorm exists for state patterns."""
        return self.hopfield.normalize_state_pattern

    @property
    def normalize_state_pattern_affine(self) -> bool:
        """Whether the state-pattern LayerNorm is affine."""
        return self.hopfield.normalize_state_pattern_affine
