from flax import linen as nn
import flax
import jax
import jax.numpy as jnp
from functools import partial
from LRU.lru.model import LRU, SequenceLayer
from dataclasses import field
import numpy as np

from LRU.lru.model import (
    theta_init,
    matrix_init,
    nu_init,
    binary_operator_diag,
    gamma_log_init
)

jax.config.update("jax_default_matmul_precision", "high")
# Double precision is left disabled; uncomment to run everything in float64:
# jax.config.update("jax_enable_x64", True)

# Every recurrence in this module is evaluated with JAX's associative scan.
parallel_scan = jax.lax.associative_scan

# Matrix product mapped over the shared leading axis of both operands.
bmm = jax.vmap(jnp.matmul, in_axes=(0, 0))


def binary_operator_nondiag(q_i, q_j):
    """Combine two steps of a dense linear recurrence ``h <- A h + b``.

    Each argument is an ``(A, b)`` pair, where ``q_i`` is the earlier step and
    ``q_j`` the later one. Applying ``q_i`` and then ``q_j`` to a state ``h``
    gives ``A_j (A_i h + b_i) + b_j``, so the combined step is
    ``(A_j @ A_i, A_j @ b_i + b_j)``. This composition is associative, which
    is what ``parallel_scan`` needs.

    ``parallel_scan`` calls this with a leading scan axis on every operand
    (``A`` of shape (k, n, n)). ``bmm`` maps the matrix-vector product over
    that axis.
    """
    earlier_A, earlier_b = q_i
    later_A, later_b = q_j
    combined_A = later_A @ earlier_A
    combined_b = bmm(later_A, earlier_b) + later_b
    return combined_A, combined_b

class LRU_CIFAR10(nn.Module):
    """Real-valued diagonal linear recurrent unit used for the CIFAR10 setup.

    Adapted from the LRU of Orvieto et al. (2023), with real eigenvalues:

        lambda = 1 / (nu_coeff + nu_log**2)            (element-wise)
        s_t    = lambda * s_{t-1} + B_re x_t,   s_{-1} = 0
        y_t    = Re[C_re s_{t-1}] + D * x_t

    The read-out has a one-step delay: the output at step ``t`` uses the state
    from *before* ``x_t`` is absorbed, so ``y_0 = D * x_0``.

    ``theta_log``, ``gamma_log``, ``B_im`` and ``C_im`` are still registered as
    parameters, but the forward pass does not read them.
    """

    d_hidden: int  # size of the recurrent state
    d_model: int  # size of every input / output vector
    r_min: float = 0.0  # not read by this variant
    r_max: float = 1.0  # not read by this variant
    max_phase: float = 6.28  # forwarded to theta_init as max_phase
    nu_coeff: float = 1.0  # additive constant in lambda = 1 / (nu_coeff + nu_log**2)
    nu_init: float = 0.25  # constant initial value for every entry of nu_log

    def setup(self):
        def scaled_init(n):
            # matrix_init with normalization=sqrt(n)
            return partial(matrix_init, normalization=jnp.sqrt(n))

        # NOTE: keep this creation order. The initial values are drawn from
        # the module's 'params' RNG stream in call order.
        self.theta_log = self.param(
            "theta_log",
            partial(theta_init, max_phase=self.max_phase),
            (self.d_hidden,),
        )
        self.nu_log = self.param(
            "nu_log",
            jax.nn.initializers.constant(self.nu_init),
            (self.d_hidden,),
        )
        self.gamma_log = self.param(
            "gamma_log", gamma_log_init, (self.nu_log, self.theta_log)
        )

        input_shape = (self.d_hidden, self.d_model)
        output_shape = (self.d_model, self.d_hidden)
        self.B_re = self.param("B_re", scaled_init(2 * self.d_model), input_shape)
        self.B_im = self.param("B_im", scaled_init(2 * self.d_model), input_shape)
        self.C_re = self.param("C_re", scaled_init(self.d_hidden), output_shape)
        self.C_im = self.param("C_im", scaled_init(self.d_hidden), output_shape)
        self.D = self.param("D", matrix_init, (self.d_model,))

    def __call__(self, inputs):
        """Run the delayed-read-out recurrence over ``inputs``.

        ``inputs`` has time on its first axis and ``d_model`` features per
        step (``B_re @ u`` requires that). The result has the same leading
        length.
        """
        seq_len = inputs.shape[0]
        decay = 1 / (self.nu_coeff + self.nu_log**2)

        decays = jnp.repeat(decay[None, ...], seq_len, axis=0)
        driven = jax.vmap(lambda u_t: self.B_re @ u_t)(inputs)
        _, states = parallel_scan(binary_operator_diag, (decays, driven))

        # Shift by one step so that output t sees the state before input t.
        zero_state = jnp.zeros((1, states.shape[-1]))
        delayed = jnp.vstack([zero_state, states[:-1]])

        def readout(s_t, u_t):
            return (self.C_re @ s_t).real + self.D * u_t

        return jax.vmap(readout)(delayed, inputs)


class SequenceLayer_CIFAR10(nn.Module):
    """One sequence layer: the wrapped LRU followed by a ReLU.

    ``setup`` also builds two bias-free dense layers, a normalization layer and
    a dropout layer, but the forward pass currently uses none of them. There is
    no normalization, GLU, dropout or skip connection in ``__call__``.
    """

    lru: LRU  # called with no arguments in setup to build the recurrent module
    d_model: int  # output width of the dense layers
    dropout: float = 0.0  # dropout rate of self.drop
    norm: str = "layer"  # "layer" selects LayerNorm, anything else BatchNorm
    training: bool = True  # enables dropout and BatchNorm batch statistics
    u_T: int = 0  # not read inside this class

    def setup(self):
        """Create the LRU and the auxiliary (currently unused) sub-layers."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        self.normalization = (
            nn.LayerNorm()
            if self.norm == "layer"
            else nn.BatchNorm(
                use_running_average=not self.training,
                axis_name="batch",
                use_bias=False,
                use_scale=False,
            )
        )
        self.drop = nn.Dropout(
            self.dropout, broadcast_dims=[0], deterministic=not self.training
        )

    def __call__(self, inputs):
        """Return ``relu(LRU(inputs))`` (no residual connection)."""
        return nn.relu(self.seq(inputs))

class LRU_small_generator(nn.Module):
    """Fixed second-order linear system in companion form (data generator).

    With ``a = scale * exp(-theta_1**2)`` and ``b = scale * exp(-theta_2**2)``:

        A   = [[0, 1], [-a*b, a + b]]     (eigenvalues a and b)
        B   = [[0], [1]],   C = [[1, 0]]
        h_t = A h_{t-1} + B u_t           (zero state before the first input)
        y_t = C h_t

    For ``scale >= 0`` both eigenvalues lie in ``(0, scale]``. The fields
    ``d_hidden``, ``d_model``, ``r_min``, ``r_max`` and ``max_phase`` are not
    read; the 2-state size is fixed by ``A``, ``B`` and ``C``.
    """

    d_hidden: int = 2  # hidden state dimension (not read; system is 2-state)
    d_model: int = 1  # input and output dimension (not read)
    r_min: float = 0.0  # not read
    r_max: float = 1.0  # not read
    max_phase: float = 6.28  # not read
    scale: float = 0.5  # common multiplier of both eigenvalues

    def setup(self):
        self.theta_1 = self.param("theta_1", nn.initializers.normal(0.001), (1,))
        self.theta_2 = self.param("theta_2", nn.initializers.normal(0.001), (1,))
        self.B = jnp.array([[0], [1]])
        self.C = jnp.array([[1, 0]])

    def __call__(self, inputs):
        """Run the recurrence over ``inputs`` and return ``C h_t`` per step.

        ``inputs`` has time on its first axis. Each step must be compatible
        with ``B`` of shape (2, 1), presumably (T, 1). Returns shape (T, 1).
        """
        # Keep a and b as 0-d arrays. The original ``.item()`` forced a
        # concrete Python float, which raises a ConcretizationTypeError when
        # the module is applied under jax.jit / jax.grad (theta_* are traced).
        a = jnp.exp(-self.theta_1[0] ** 2) * self.scale
        b = jnp.exp(-self.theta_2[0] ** 2) * self.scale
        A = jnp.array([[0.0, 1.0], [-a * b, a + b]])
        A = jnp.repeat(A[None, ...], inputs.shape[0], axis=0)

        Bu_elements = jax.vmap(lambda u: self.B @ u)(inputs)
        # Cumulative composition of (A, B u_t) gives every hidden state.
        _, hidden_states = parallel_scan(binary_operator_nondiag,
                                         (A, Bu_elements))
        return jax.vmap(lambda h: self.C @ h)(hidden_states)

class LRU_small_learner(nn.Module):
    """Dense 2-state linear recurrence with fully learnable A, B and C.

        h_t = A h_{t-1} + B u_t   (zero state before the first input)
        y_t = C h_t

    ``A`` (2, 2), ``B`` (2, 1) and ``C`` (1, 2) are free parameters with small
    Gaussian initial values. The remaining fields are not read.
    """

    d_hidden: int = 2  # hidden state dimension (not read; shapes are fixed)
    d_model: int = 1  # input and output dimension (not read)
    r_min: float = 0.0  # not read
    r_max: float = 1.0  # not read
    max_phase: float = 6.28  # not read
    scale: float = 0.5  # not read

    def setup(self):
        def gaussian(mean, stdev):
            """Return an initializer drawing ``mean + stdev * N(0, 1)``."""
            def sample(key, shape):
                return mean + stdev * jax.random.normal(key, shape)
            return sample

        self.A = self.param("A", gaussian(0, 0.07), (2, 2))
        self.B = self.param("B", gaussian(0, 0.3), (2, 1))
        self.C = self.param("C", gaussian(0, 0.3), (1, 2))

    def __call__(self, inputs):
        """Return ``C h_t`` for every step of ``inputs`` (time on axis 0)."""
        steps = inputs.shape[0]
        transitions = jnp.repeat(self.A[None, ...], steps, axis=0)
        injections = jax.vmap(lambda u_t: self.B @ u_t)(inputs)
        _, states = parallel_scan(binary_operator_nondiag,
                                  (transitions, injections))
        return jax.vmap(lambda s_t: self.C @ s_t)(states)

class SequenceLayer_small(nn.Module):
    """Thin wrapper that returns the wrapped LRU's output unchanged.

    ``setup`` builds two bias-free dense layers and a normalization layer, but
    ``__call__`` uses none of them. There is no activation, dropout or
    residual connection.
    """

    lru: LRU  # called with no arguments in setup to build the recurrent module
    d_model: int  # output width of the dense layers
    dropout: float = 0.0  # not read (no dropout layer is built)
    norm: str = "layer"  # "layer" selects LayerNorm, anything else BatchNorm
    training: bool = True  # BatchNorm uses batch statistics when True

    def setup(self):
        """Create the LRU and the auxiliary (currently unused) sub-layers."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        self.normalization = (
            nn.LayerNorm()
            if self.norm == "layer"
            else nn.BatchNorm(
                use_running_average=not self.training,
                axis_name="batch",
                use_bias=False,
                use_scale=False,
            )
        )

    def __call__(self, inputs):
        """Return ``LRU(inputs)``."""
        return self.seq(inputs)

class LRUCircle(LRU):
    """``LRU`` variant with real eigenvalues ``lambda = 1 / (1 + nu_log**2)``.

        s_t = lambda * s_{t-1} + (B_re * exp(gamma_log)[:, None]) x_t
        y_t = Re[C_re s_t] + D * x_t

    There is no read-out delay: output ``t`` already includes input ``t``.
    """

    def setup(self):
        # Register nu_log first (constant initial value 1); the parent setup
        # then creates the remaining parameters.
        # NOTE(review): assumes LRU.setup reuses an existing "nu_log" rather
        # than registering it again; confirm against LRU.lru.model.
        self.nu_log = self.param(
            "nu_log", jax.nn.initializers.constant(1), (self.d_hidden,)
        )
        super().setup()

    def __call__(self, inputs):
        """Run the recurrence over ``inputs`` (time on the first axis)."""
        decay = 1 / (1 + self.nu_log**2)
        input_gain = jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1)
        scaled_B = self.B_re * input_gain

        decays = jnp.repeat(decay[None, ...], inputs.shape[0], axis=0)
        driven = jax.vmap(lambda u_t: scaled_B @ u_t)(inputs)
        _, states = parallel_scan(binary_operator_diag, (decays, driven))

        def readout(s_t, u_t):
            return (self.C_re @ s_t).real + self.D * u_t

        return jax.vmap(readout)(states, inputs)


class SequenceLayerCircle(SequenceLayer):
    """``SequenceLayer`` whose forward pass is LRU -> Dense -> dropout -> ReLU
    -> dropout -> Dense, with no residual connection."""

    def __call__(self, inputs):
        """Apply the LRU followed by the two-layer dense head."""
        projected = self.drop(self.out1(self.seq(inputs)))
        activated = self.drop(nn.relu(projected))
        return self.out2(activated)


class LRU_spiral(nn.Module):
    """Real-valued diagonal LRU with a one-step delayed read-out (spiral task).

    Adapted from the LRU of Orvieto et al. (2023), with real eigenvalues:

        lambda = 1 / (nu_log_coeff + nu_log**2)        (element-wise)
        s_t    = lambda * s_{t-1} + B_re x_t,   s_{-1} = 0
        y_t    = Re[C_re s_{t-1}] + D * x_t

    Output ``t`` uses the state from before ``x_t`` is absorbed, so the first
    output is ``D * x_0``. ``D`` is all zeros unless ``use_D`` is set.
    ``theta_log``, ``gamma_log``, ``B_im`` and ``C_im`` are registered but not
    read in the forward pass.
    """

    d_hidden: int  # size of the recurrent state (n_x)
    d_model: int  # size of every input / output vector (n_u)
    r_min: float = 0.0  # not read by this variant
    r_max: float = 1.0  # not read by this variant
    max_phase: float = 6.28  # forwarded to theta_init as max_phase
    use_D: bool = False  # learn a feed-through D instead of a fixed zero vector
    nu_log_coeff: float = 1.0  # additive constant in the lambda formula

    def setup(self):
        def scaled_init(n):
            # matrix_init with normalization=sqrt(n)
            return partial(matrix_init, normalization=jnp.sqrt(n))

        # NOTE: keep this creation order. The initial values are drawn from
        # the module's 'params' RNG stream in call order.
        self.theta_log = self.param(
            "theta_log",
            partial(theta_init, max_phase=self.max_phase),
            (self.d_hidden,),
        )
        self.nu_log = self.param(
            "nu_log", jax.nn.initializers.constant(2.25), (self.d_hidden,)
        )
        self.gamma_log = self.param(
            "gamma_log", gamma_log_init, (self.nu_log, self.theta_log)
        )

        input_shape = (self.d_hidden, self.d_model)
        output_shape = (self.d_model, self.d_hidden)
        self.B_re = self.param("B_re", scaled_init(2 * self.d_model), input_shape)
        self.B_im = self.param("B_im", scaled_init(2 * self.d_model), input_shape)
        self.C_re = self.param("C_re", scaled_init(self.d_hidden), output_shape)
        self.C_im = self.param("C_im", scaled_init(self.d_hidden), output_shape)

        if self.use_D:
            self.D = self.param("D", matrix_init, (self.d_model,))
        else:
            self.D = jnp.zeros((self.d_model,))

    def __call__(self, inputs):
        """Run the delayed-read-out recurrence over ``inputs``.

        ``inputs`` has time on its first axis and ``d_model`` features per
        step (``B_re @ u`` requires that).
        """
        seq_len = inputs.shape[0]
        decay = 1 / (self.nu_log_coeff + self.nu_log**2)

        decays = jnp.repeat(decay[None, ...], seq_len, axis=0)
        driven = jax.vmap(lambda u_t: self.B_re @ u_t)(inputs)
        _, states = parallel_scan(binary_operator_diag, (decays, driven))

        # Shift by one step so the output trajectory starts from state 0.
        zero_state = jnp.zeros((1, states.shape[-1]))
        delayed = jnp.vstack([zero_state, states[:-1]])

        def readout(s_t, u_t):
            return (self.C_re @ s_t).real + self.D * u_t

        return jax.vmap(readout)(delayed, inputs)


class SequenceLayer_spiral(nn.Module):
    """One sequence layer: the wrapped LRU followed by a ReLU.

    ``setup`` also builds two bias-free dense layers and a normalization
    layer, but the forward pass does not use them. There is no normalization,
    GLU, dropout or skip connection.
    """

    lru: LRU  # called with no arguments in setup to build the recurrent module
    d_model: int  # output width of the dense layers
    dropout: float = 0.0  # not read (no dropout layer is built)
    norm: str = "layer"  # "layer" selects LayerNorm, anything else BatchNorm
    training: bool = True  # BatchNorm uses batch statistics when True
    u_T: int = 0  # not read inside this class

    def setup(self):
        """Create the LRU and the auxiliary (currently unused) sub-layers."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        self.normalization = (
            nn.LayerNorm()
            if self.norm == "layer"
            else nn.BatchNorm(
                use_running_average=not self.training,
                axis_name="batch",
                use_bias=False,
                use_scale=False,
            )
        )

    def __call__(self, inputs):
        """Return ``relu(LRU(inputs))`` (no residual connection)."""
        return nn.relu(self.seq(inputs))

class LRU_classif_generator(nn.Module):
    """Real-valued diagonal LRU with gamma-scaled input (classification data).

        lambda = 1 / (1 + nu_log**2)                       (element-wise)
        s_t    = lambda * s_{t-1} + (B_re * exp(gamma_log)[:, None]) x_t
        y_t    = Re[C_re s_t] + D * x_t

    There is no read-out delay. ``D`` is all zeros unless ``use_D`` is set.
    ``theta_log``, ``B_im`` and ``C_im`` are registered but not read in the
    forward pass.

    NOTE(review): ``gamma_log`` is initialised by ``gamma_log_init`` from the
    ``exp(-exp(nu_log))`` eigenvalue formula, whereas this module uses
    ``1 / (1 + nu_log**2)``. The initial input scale is therefore not
    ``sqrt(1 - lambda**2)``; confirm this is intended.
    """

    d_hidden: int  # size of the recurrent state (n_x)
    d_model: int  # size of every input / output vector (n_u)
    r_min: float = 0.0  # not read by this variant
    r_max: float = 1.0  # not read by this variant
    max_phase: float = 6.28  # forwarded to theta_init as max_phase
    use_D: bool = False  # learn a feed-through D instead of a fixed zero vector

    def setup(self):
        def scaled_init(n):
            # matrix_init with normalization=sqrt(n)
            return partial(matrix_init, normalization=jnp.sqrt(n))

        # NOTE: keep this creation order. The initial values are drawn from
        # the module's 'params' RNG stream in call order.
        self.theta_log = self.param(
            "theta_log",
            partial(theta_init, max_phase=self.max_phase),
            (self.d_hidden,),
        )
        self.nu_log = self.param(
            "nu_log", jax.nn.initializers.constant(0.35), (self.d_hidden,)
        )
        self.gamma_log = self.param(
            "gamma_log", gamma_log_init, (self.nu_log, self.theta_log)
        )

        input_shape = (self.d_hidden, self.d_model)
        output_shape = (self.d_model, self.d_hidden)
        self.B_re = self.param("B_re", scaled_init(2 * self.d_model), input_shape)
        self.B_im = self.param("B_im", scaled_init(2 * self.d_model), input_shape)
        self.C_re = self.param("C_re", scaled_init(self.d_hidden), output_shape)
        self.C_im = self.param("C_im", scaled_init(self.d_hidden), output_shape)

        if self.use_D:
            self.D = self.param("D", matrix_init, (self.d_model,))
        else:
            self.D = jnp.zeros((self.d_model,))

    def __call__(self, inputs):
        """Run the recurrence over ``inputs`` (time on the first axis)."""
        decay = 1 / (1 + self.nu_log**2)
        input_gain = jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1)
        scaled_B = self.B_re * input_gain

        decays = jnp.repeat(decay[None, ...], inputs.shape[0], axis=0)
        driven = jax.vmap(lambda u_t: scaled_B @ u_t)(inputs)
        _, states = parallel_scan(binary_operator_diag, (decays, driven))

        def readout(s_t, u_t):
            return (self.C_re @ s_t).real + self.D * u_t

        return jax.vmap(readout)(states, inputs)


class SequenceLayer_classif_generator(nn.Module):
    """One sequence layer: the wrapped LRU followed by a ReLU.

    ``setup`` also builds two bias-free dense layers and a normalization
    layer, but the forward pass does not use them. There is no normalization,
    GLU, dropout or skip connection.
    """

    lru: LRU  # called with no arguments in setup to build the recurrent module
    d_model: int  # output width of the dense layers
    dropout: float = 0.0  # not read (no dropout layer is built)
    norm: str = "layer"  # "layer" selects LayerNorm, anything else BatchNorm
    training: bool = True  # BatchNorm uses batch statistics when True

    def setup(self):
        """Create the LRU and the auxiliary (currently unused) sub-layers."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        self.normalization = (
            nn.LayerNorm()
            if self.norm == "layer"
            else nn.BatchNorm(
                use_running_average=not self.training,
                axis_name="batch",
                use_bias=False,
                use_scale=False,
            )
        )

    def __call__(self, inputs):
        """Return ``relu(LRU(inputs))`` (no residual connection)."""
        return nn.relu(self.seq(inputs))

class LRU_classif(nn.Module):
    """Real-valued diagonal LRU with gamma-scaled input (classification model).

        lambda = 1 / (1 + nu_log**2)                       (element-wise)
        s_t    = lambda * s_{t-1} + (B_re * exp(gamma_log)[:, None]) x_t
        y_t    = Re[C_re s_t] + D * x_t

    There is no read-out delay. ``D`` is all zeros unless ``use_D`` is set.
    ``theta_log``, ``B_im`` and ``C_im`` are registered but not read in the
    forward pass.

    NOTE(review): ``gamma_log`` is initialised by ``gamma_log_init`` from the
    ``exp(-exp(nu_log))`` eigenvalue formula, whereas this module uses
    ``1 / (1 + nu_log**2)``. The initial input scale is therefore not
    ``sqrt(1 - lambda**2)``; confirm this is intended.
    """

    d_hidden: int  # size of the recurrent state (n_x)
    d_model: int  # size of every input / output vector (n_u)
    r_min: float = 0.0  # not read by this variant
    r_max: float = 1.0  # not read by this variant
    max_phase: float = 6.28  # forwarded to theta_init as max_phase
    use_D: bool = False  # learn a feed-through D instead of a fixed zero vector

    def setup(self):
        # NOTE: keep this creation order. The initial values are drawn from
        # the module's 'params' RNG stream in call order.
        self.theta_log = self.param(
            "theta_log", partial(theta_init, max_phase=self.max_phase),
            (self.d_hidden,)
        )
        self.nu_log = self.param(
            "nu_log", jax.nn.initializers.constant(0.35),
            (self.d_hidden,)
        )
        self.gamma_log = self.param("gamma_log", gamma_log_init,
                                    (self.nu_log, self.theta_log))

        # Input/output projections: matrix_init with normalization=sqrt(fan)
        self.B_re = self.param(
            "B_re",
            partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
            (self.d_hidden, self.d_model),
        )
        self.B_im = self.param(
            "B_im",
            partial(matrix_init, normalization=jnp.sqrt(2 * self.d_model)),
            (self.d_hidden, self.d_model),
        )
        self.C_re = self.param(
            "C_re",
            partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
            (self.d_model, self.d_hidden),
        )
        self.C_im = self.param(
            "C_im",
            partial(matrix_init, normalization=jnp.sqrt(self.d_hidden)),
            (self.d_model, self.d_hidden),
        )
        self.D = jnp.zeros((self.d_model,))
        if self.use_D:
            self.D = self.param("D", matrix_init, (self.d_model,))

    def __call__(self, inputs):
        """Run the recurrence over ``inputs``.

        ``inputs`` has shape (T, n_u): time on the first axis, ``d_model``
        features per step. Returns an array with the same leading length.
        """
        diag_lambda = 1 / (1 + self.nu_log**2)
        B_norm = self.B_re * jnp.expand_dims(jnp.exp(self.gamma_log), axis=-1)
        C = self.C_re

        # The leftover jax.debug.print calls are removed: they printed the
        # whole Lambda tensor on every forward pass (also under jit/vmap).
        Lambda_elements = jnp.repeat(diag_lambda[None, ...],
                                     inputs.shape[0], axis=0)
        Bu_elements = jax.vmap(lambda u: B_norm @ u)(inputs)
        _, hidden_states = parallel_scan(binary_operator_diag,
                                         (Lambda_elements, Bu_elements))
        return jax.vmap(lambda h, x: (C @ h).real + self.D * x)(hidden_states,
                                                                inputs)


class SequenceLayer_classif(nn.Module):
    """One sequence layer: the wrapped LRU followed by a ReLU.

    ``setup`` also builds two bias-free dense layers and a normalization
    layer, but the forward pass does not use them. There is no normalization,
    GLU, dropout or skip connection.
    """

    lru: LRU  # called with no arguments in setup to build the recurrent module
    d_model: int  # output width of the dense layers
    dropout: float = 0.0  # not read (no dropout layer is built)
    norm: str = "layer"  # "layer" selects LayerNorm, anything else BatchNorm
    training: bool = True  # BatchNorm uses batch statistics when True

    def setup(self):
        """Create the LRU and the auxiliary (currently unused) sub-layers."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        self.normalization = (
            nn.LayerNorm()
            if self.norm == "layer"
            else nn.BatchNorm(
                use_running_average=not self.training,
                axis_name="batch",
                use_bias=False,
                use_scale=False,
            )
        )

    def __call__(self, inputs):
        """Return ``relu(LRU(inputs))`` (no residual connection)."""
        return nn.relu(self.seq(inputs))

def gamma_log_init(key, lamb):
    """Return ``log(sqrt(1 - |lambda|**2))`` for the LRU input normalisation.

    ``lamb`` is a ``(nu_log, theta_log)`` pair defining the complex
    eigenvalues ``lambda = exp(-exp(nu_log) + 1j * exp(theta_log))``. ``key``
    only matches the Flax initializer signature and is ignored.

    This module-level definition rebinds the ``gamma_log_init`` name imported
    from ``LRU.lru.model`` at the top of the file.
    """
    nu_log, theta_log = lamb
    eigenvalues = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    modulus_sq = jnp.abs(eigenvalues) ** 2
    return jnp.log(jnp.sqrt(1 - modulus_sq))

def matrix_init_LRU_spec_gen(matrix):
    assert matrix in ["A", "B", "C", "D"]
    match matrix:
        case "A":
            def init(key, lambs):
                lamb1, lamb2, _, __ = lambs
                exp1 = jnp.exp(-lamb1.value**2).item()
                exp2 = jnp.exp(-lamb2.value**2).item()
                return jnp.array([[0, 1], [-exp1 * exp2, exp1 + exp2]])
        case "B":
            def init(key, lambs):
                lamb1, lamb2, lamb3, lamb4 = lambs
                exp1 = jnp.exp(-lamb1.value**2).item()
                exp2 = jnp.exp(-lamb2.value**2).item()
                exp3 = jnp.exp(-lamb3.value**2).item()
                exp4 = jnp.exp(-lamb4.value**2).item()
                return jnp.array([[0],[-1 * exp1 * exp2 + exp3 * exp4]])
        case "C":
            def init(key, lambs):
                return jnp.array([[1, 0]])
        case "D":
            def init(key, lambs):
                return jnp.eye(1)
    return init

def matrix_init_LRU_spec_learner_lambdas(matrix):
    """Return the initializer for one matrix of the lambda-parametrised learner.

    ``matrix`` is one of ``"A"``, ``"B"``, ``"C"``, ``"D"``. Each returned
    ``init(key, lambs)`` ignores ``key``. ``lambs`` holds four arrays whose
    first entry is used, with ``e_k = exp(-lambda_k**2)``:

        A = [[0, 1], [-e1*e2, e1 + e2]]
        B = [[0], [e1*e2 + e3*e4]]
        C = [[1.0, 0.0]]
        D = float32 identity of size 1
    """
    assert matrix in ["A", "B", "C", "D"]

    def _decay(lamb):
        # exp(-lambda^2) taken at the first entry.
        return jnp.exp(-lamb**2)[0]

    def init_A(key, lambs):
        first, second, _, _ = lambs
        d1 = _decay(first)
        d2 = _decay(second)
        return jnp.array([[0, 1], [-d1 * d2, d1 + d2]])

    def init_B(key, lambs):
        first, second, third, fourth = lambs
        d1 = _decay(first)
        d2 = _decay(second)
        d3 = _decay(third)
        d4 = _decay(fourth)
        return jnp.array([[0], [d1 * d2 + d3 * d4]])

    def init_C(key, lambs):
        return jnp.array([[1.0, 0.0]])

    def init_D(key, lambs):
        return jnp.eye(1, dtype=jnp.float32)

    builders = {"A": init_A, "B": init_B, "C": init_C, "D": init_D}
    return builders[matrix]

def matrix_init_LRU_spec_learner(matrix):
    """Return the initializer for one learnable matrix of the spec learner.

    ``matrix`` is one of ``"A"``, ``"B"``, ``"C"``, ``"D"``; all four currently
    share the same rule. The returned ``init(key, shape, const=None)`` gives
    ``const * ones(shape)`` when ``const`` is provided and standard-normal
    samples drawn with ``key`` otherwise.
    """
    assert matrix in ["A", "B", "C", "D"]

    def init(key, shape, const=None):
        # `is not None` (not truthiness): const=0.0 must give zeros instead of
        # silently falling back to random initialisation.
        if const is not None:
            return const * jnp.ones(shape)
        return jax.random.normal(key, shape)

    return init


# Maps a module class name to the factory that builds its matrix initializers
# (looked up via ``self.__class__.__name__`` in LRU_spec_gen.setup).
# NOTE(review): "LRU_spec_gen_inv" reuses the plain generator factory; confirm
# that no separate inverted variant was intended.
matrix_initializers = dict(
    LRU_spec_gen=matrix_init_LRU_spec_gen,
    LRU_spec_gen_inv=matrix_init_LRU_spec_gen,
    LRU_spec_learner=matrix_init_LRU_spec_learner,
    LRU_spec_learner_lambdas=matrix_init_LRU_spec_learner_lambdas,
)


class LRU_spec_gen(nn.Module):
    """
    Fixed linear state-space system used as the "generator".

    For an input sequence u_0 .. u_{T-1} it computes

        x_0 = 0,   x_{t+1} = A x_t + B u_t,   y_t = C x_t + D u_t

    with dense matrices A, B, C, D produced from the four scalars in
    ``lambs`` by ``matrix_initializers["LRU_spec_gen"]``. The scalars and the
    matrices are stored in the "const" variable collection (not "params").
    Layout adapted from the LRU implementation of Orvieto et al. 2023.
    """

    d_hidden: int  # hidden state dimension, n_x
    d_model: int  # input and output dimensions, n_u
    r_min: float = 0.0  # smallest lambda norm
    r_max: float = 1.0  # largest lambda norm
    max_phase: float = 6.28  # max phase lambda
    # NOTE(review): use_D is not read anywhere in this class; D is always
    # applied in __call__.
    use_D: bool = True  # whether to use D or not
    lambs: dict[str, float] = field(default_factory=lambda: {
        "lambda_1": 0.0,
        "lambda_2": 0.0,
        "lambda_3": 0.0,
        "lambda_4": 0.0,
        })

    def setup(self):
        """Create the four lambda constants and the A, B, C, D constants."""
        # Each lambda is a shape-(1,) constant. make_rng('const') supplies the
        # key argument of the initializer signature; the constant initializer
        # itself does not use it.
        self.lambda_1 = self.variable(
            "const", "lambda_1",
            jax.nn.initializers.constant(self.lambs["lambda_1"]),
            self.make_rng('const'), (1,)
        )
        self.lambda_2 = self.variable(
            "const", "lambda_2",
            jax.nn.initializers.constant(self.lambs["lambda_2"]),
            self.make_rng('const'), (1,)
        )
        self.lambda_3 = self.variable(
            "const", "lambda_3",
            jax.nn.initializers.constant(self.lambs["lambda_3"]),
            self.make_rng('const'), (1,)
        )
        self.lambda_4 = self.variable(
            "const", "lambda_4",
            jax.nn.initializers.constant(self.lambs["lambda_4"]),
            self.make_rng('const'), (1,)
        )

        # The matrix initializers receive no RNG key (None) and the lambda
        # Variables themselves, not their values.
        lambdas = (self.lambda_1, self.lambda_2, self.lambda_3, self.lambda_4)
        init_for = matrix_initializers[self.__class__.__name__]
        self.A = self.variable("const", "A", init_for("A"), None, lambdas)
        self.B = self.variable("const", "B", init_for("B"), None, lambdas)
        self.C = self.variable("const", "C", init_for("C"), None, lambdas)
        self.D = self.variable("const", "D", init_for("D"), None, lambdas)

    @staticmethod
    def get_matrices(cls, lambs):
        """Build (A, B, C, D) without a bound module.

        Args:
            cls: key into ``matrix_initializers`` (a class-name string such
                as "LRU_spec_gen"), despite its name.
            lambs: sequence of the four lambdas, or None. Elements that are
                not flax Variables are wrapped in a small object exposing
                ``.value``, mimicking what the initializers receive in
                ``setup``. Elements that already are Variables are passed
                through unchanged.

        Returns:
            Tuple (A, B, C, D) as produced by the registered initializers.
        """
        class float_with_value:
            # Minimal stand-in for a flax Variable: only exposes ``.value``.
            def __init__(self, value):
                self.value = value

        variable_type = flax.core.scope.Variable
        if lambs is not None and not isinstance(lambs, variable_type):
            # Check each element, not the container, so Variables are not
            # wrapped a second time.
            lambs = [
                lamb if isinstance(lamb, variable_type)
                else float_with_value(lamb)
                for lamb in lambs
            ]
        init_for = matrix_initializers[cls]
        A = init_for("A")(None, lambs)
        B = init_for("B")(None, lambs)
        C = init_for("C")(None, lambs)
        D = init_for("D")(None, lambs)

        return A, B, C, D

    def __call__(self, inputs):
        """Run the system over a sequence.

        Args:
            inputs: 2-D array (T, n_u); row t is the input u_t.

        Returns:
            (T, n_y) array with y_t = C x_t + D u_t, where x_0 = 0 and
            x_{t+1} = A x_t + B u_t.
        """
        seq_len = inputs.shape[0]
        A_seq = jnp.repeat(self.A.value[None, ...], seq_len, axis=0)
        Bu_elements = jax.vmap(lambda u: self.B.value @ u)(inputs)
        # The inclusive scan yields s_t = A s_{t-1} + B u_t with s_0 = B u_0,
        # i.e. the state *after* consuming u_t.
        _, hidden_states = parallel_scan(binary_operator_nondiag,
                                         (A_seq, Bu_elements))
        # Shift by one step so that the state paired with u_t is x_t, and the
        # sequence starts from the initial state x_0 = 0.
        hidden_states = jnp.vstack([jnp.zeros((1, hidden_states.shape[-1])),
                                    hidden_states[:-1]])
        outputs = jax.vmap(lambda h, x: self.C.value @ h
                           + self.D.value @ x)(hidden_states, inputs)

        return outputs

class LRU_spec_gen_inv(nn.Module):
    """
    Inverse of the generator system built from the same A, B, C, D.

    If the generator is y_t = C x_t + u_t (feedthrough D = I), then
    u_t = y_t - C x_t, and substituting into x_{t+1} = A x_t + B u_t gives
    the system computed here:

        x_0 = 0,   x_{t+1} = (A - B C) x_t + B y_t,   u_t = -C x_t + D y_t

    NOTE(review): the code uses D where a general inverse would use D^{-1}
    (and A - B D^{-1} C, -D^{-1} C). It is an exact inverse only when D = I;
    confirm against matrix_init_LRU_spec_gen.

    The lambdas and matrices are stored in the "const" variable collection,
    initialised by ``matrix_initializers["LRU_spec_gen_inv"]``. Layout adapted
    from the LRU implementation of Orvieto et al. 2023.
    """

    d_hidden: int  # hidden state dimension, n_x
    d_model: int  # input and output dimensions, n_u
    r_min: float = 0.0  # smallest lambda norm
    r_max: float = 1.0  # largest lambda norm
    max_phase: float = 6.28  # max phase lambda
    # NOTE(review): use_D is not read anywhere in this class; D is always
    # applied in __call__.
    use_D: bool = True  # whether to use D or not
    lambs: dict[str, float] = field(default_factory=lambda: {
        "lambda_1": 0.0,
        "lambda_2": 0.0,
        "lambda_3": 0.0,
        "lambda_4": 0.0,
        })

    def setup(self):
        """Create the four lambda constants and the A, B, C, D constants."""
        # Each lambda is a shape-(1,) constant. make_rng('const') supplies the
        # key argument of the initializer signature; the constant initializer
        # itself does not use it.
        self.lambda_1 = self.variable(
            "const", "lambda_1",
            jax.nn.initializers.constant(self.lambs["lambda_1"]),
            self.make_rng('const'), (1,)
        )
        self.lambda_2 = self.variable(
            "const", "lambda_2",
            jax.nn.initializers.constant(self.lambs["lambda_2"]),
            self.make_rng('const'), (1,)
        )
        self.lambda_3 = self.variable(
            "const", "lambda_3",
            jax.nn.initializers.constant(self.lambs["lambda_3"]),
            self.make_rng('const'), (1,)
        )
        self.lambda_4 = self.variable(
            "const", "lambda_4",
            jax.nn.initializers.constant(self.lambs["lambda_4"]),
            self.make_rng('const'), (1,)
        )

        # The matrix initializers receive no RNG key (None) and the lambda
        # Variables themselves, not their values.
        lambdas = (self.lambda_1, self.lambda_2, self.lambda_3, self.lambda_4)
        init_for = matrix_initializers[self.__class__.__name__]
        self.A = self.variable("const", "A", init_for("A"), None, lambdas)
        self.B = self.variable("const", "B", init_for("B"), None, lambdas)
        self.C = self.variable("const", "C", init_for("C"), None, lambdas)
        self.D = self.variable("const", "D", init_for("D"), None, lambdas)

    @staticmethod
    def get_matrices(cls, lambs):
        """Build (A, B, C, D) without a bound module.

        Args:
            cls: key into ``matrix_initializers`` (a class-name string such
                as "LRU_spec_gen_inv"), despite its name.
            lambs: sequence of the four lambdas, or None. Elements that are
                not flax Variables are wrapped in a small object exposing
                ``.value``, mimicking what the initializers receive in
                ``setup``. Elements that already are Variables are passed
                through unchanged.

        Returns:
            Tuple (A, B, C, D) as produced by the registered initializers
            (the un-inverted generator matrices).
        """
        class float_with_value:
            # Minimal stand-in for a flax Variable: only exposes ``.value``.
            def __init__(self, value):
                self.value = value

        variable_type = flax.core.scope.Variable
        if lambs is not None and not isinstance(lambs, variable_type):
            # Check each element, not the container, so Variables are not
            # wrapped a second time.
            lambs = [
                lamb if isinstance(lamb, variable_type)
                else float_with_value(lamb)
                for lamb in lambs
            ]
        init_for = matrix_initializers[cls]
        A = init_for("A")(None, lambs)
        B = init_for("B")(None, lambs)
        C = init_for("C")(None, lambs)
        D = init_for("D")(None, lambs)

        return A, B, C, D

    def __call__(self, inputs):
        """Run the inverse system over a sequence.

        Args:
            inputs: 2-D array (T, n); row t is the input y_t.

        Returns:
            Array with rows u_t = -C x_t + D y_t, where x_0 = 0 and
            x_{t+1} = (A - B C) x_t + B y_t.
        """
        A_inv = self.A.value - self.B.value @ self.C.value
        C_inv = -1 * self.C.value

        seq_len = inputs.shape[0]
        A_seq = jnp.repeat(A_inv[None, ...], seq_len, axis=0)
        Bu_elements = jax.vmap(lambda u: self.B.value @ u)(inputs)
        # The inclusive scan yields the state *after* consuming each input.
        _, hidden_states = parallel_scan(binary_operator_nondiag,
                                         (A_seq, Bu_elements))
        # Shift by one step so that the state paired with the input at step t
        # is x_t, and the sequence starts from the initial state x_0 = 0.
        hidden_states = jnp.vstack([jnp.zeros((1, hidden_states.shape[-1])),
                                    hidden_states[:-1]])
        outputs = jax.vmap(lambda h, x: C_inv @ h
                           + self.D.value @ x)(hidden_states, inputs)

        return outputs


class SequenceLayer_spec_gen(nn.Module):
    """Thin layer that applies a single LRU-style module to a sequence.

    The forward pass returns ``self.seq(inputs)`` unchanged: there is no GLU,
    dropout, normalization or skip connection. The dense layers and the
    normalization created in ``setup``, and the ``dropout``/``u_T`` fields,
    are currently not used in the forward pass.
    """

    lru: LRU  # factory (called with no arguments) producing the sequence module
    d_model: int  # model size
    dropout: float = 0.0  # dropout probability (currently unused)
    norm: str = "layer"  # which normalization to build: "layer" or batch norm
    training: bool = True  # training mode; controls batch-norm running stats
    u_T: int = 0  # currently unused

    def setup(self):
        """Build the sequence module, two dense layers and a normalization."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        if self.norm in ["layer"]:
            self.normalization = nn.LayerNorm()
        else:
            self.normalization = nn.BatchNorm(
                use_running_average=not self.training, axis_name="batch",
                use_bias=False, use_scale=False
            )

    def build_matrices(self):
        """Placeholder hook; intentionally does nothing."""
        pass

    def __call__(self, inputs):
        """Return the output of the wrapped sequence module for ``inputs``."""
        return self.seq(inputs)

def format_array(arr):
    """Return ``arr`` as a compact string without bare trailing decimal points.

    The array is rendered with ``np.array2string(..., separator=', ')``. The
    bare '.' numpy prints after whole floats ("1." -> "1") is then removed
    wherever it ends a number: before a comma, a closing bracket, whitespace,
    or the end of the string. Fractional values such as "1.5", exponent forms
    such as "1.e+00", and numpy's "..." summarisation marker are left intact.
    """
    import re

    formatted_str = np.array2string(np.array(arr), separator=', ')
    # Only drop a '.' that directly follows a digit and terminates the number;
    # the digit anchor keeps the "..." ellipsis untouched.
    return re.sub(r"(?<=\d)\.(?=[\s,\]]|$)", "", formatted_str)

class LRU_spec_learner(nn.Module):
    """
    Learner system embedded in a larger linear state-space model.

    The learner (A, B, C, D) is parameterised by the trainable scalars
    ``lambda_1`` and ``lambda_2`` (see ``matrices_from_lambdas``). The full
    system Sigma_l is assembled by ``build_sigma_l`` from four parts:

    * the learner;
    * the generator Sigma_g, taken from ``gen_matrices``;
    * a second copy of the generator with A_g - B_g C_g;
    * the delay-line filter Sigma_f (``get_sigma_f``).

    ``__call__`` runs Sigma_l over the input sequence:

        x_0 = 0,   x_{t+1} = A_l x_t + B_l u_t,   y_t = C_l x_t + D_l u_t

    Layout adapted from the LRU implementation of Orvieto et al. 2023.
    """

    d_hidden: int  # hidden state dimension, n_x
    d_model: int = 1  # input and output dimensions, n_u
    r_min: float = 0.0  # smallest lambda norm
    r_max: float = 1.0  # largest lambda norm
    max_phase: float = 6.28  # max phase lambda
    gen_matrices: dict = None  # {"A_g", "B_g", "C_g", "D_g"} generator matrices
    use_D: bool = True  # whether to use D or not (currently unused)
    k: int = 10  # size of the filter Sigma_f (number of delayed inputs kept)
    d: int = 2  # base of the filter coefficients C_f = [d, d^2, ..., d^k]
    init_lambs: dict[str, float] = field(default_factory=lambda: {
        "lambda_1": 0.0,
        "lambda_2": 0.0,
        "lambda_3": 0.0,
        "lambda_4": 0.0,
        })

    def setup(self):
        """Create the trainable lambda scalars (shape (1,) each)."""
        # NOTE: only lambda_1 and lambda_2 enter the forward pass. lambda_3 and
        # lambda_4 are still created, so they appear in "params", but unused.
        self.lambda_1 = self.param(
            "lambda_1", jax.nn.initializers.constant(self.init_lambs["lambda_1"]),
            (1,)
        )
        self.lambda_2 = self.param(
            "lambda_2", jax.nn.initializers.constant(self.init_lambs["lambda_2"]),
            (1,)
        )
        self.lambda_3 = self.param(
            "lambda_3", jax.nn.initializers.constant(self.init_lambs["lambda_3"]),
            (1,)
        )
        self.lambda_4 = self.param(
            "lambda_4", jax.nn.initializers.constant(self.init_lambs["lambda_4"]),
            (1,)
        )

    def get_sigma_g(self):
        """Return the generator matrices (A_g, B_g, C_g, D_g) from gen_matrices."""
        A_g = self.gen_matrices["A_g"]
        B_g = self.gen_matrices["B_g"]
        C_g = self.gen_matrices["C_g"]
        D_g = self.gen_matrices["D_g"]
        return A_g, B_g, C_g, D_g

    @staticmethod
    def get_sigma_f(k, d):
        """Return the delay-line filter (A_f, B_f, C_f, D_f).

        The matrices are:

        * A_f: k x k shift matrix (ones on the sub-diagonal);
        * B_f: the first unit vector e_1 (k x 1);
        * C_f: [d, d^2, ..., d^k] (1 x k);
        * D_f: I_1.

        Run on its own, the state holds the last k inputs and
        y_t = u_t + sum_{i=1..k} d^i u_{t-i}.
        """
        A_f = jnp.hstack([jnp.vstack([jnp.zeros(k-1), jnp.eye(k-1)]),
                          jnp.zeros((k, 1))])
        B_f = jnp.vstack([jnp.ones(1), jnp.zeros((k-1, 1))])
        C_f = jnp.pow(d * jnp.ones((1, k)), jnp.arange(1, k+1, 1))
        D_f = jnp.eye(1)
        return A_f, B_f, C_f, D_f

    @staticmethod
    def build_sigma_l(k, d, A, B, C, D, A_g, B_g, C_g, D_g, A_f, B_f, C_f, D_f):
        """Assemble the combined system (A_l, B_l, C_l, D_l).

        The state is ordered as [learner (n), generator (n_g),
        inverse-generator copy (n_g), filter (k)]. Block sizes are read from
        the matrix shapes, so systems of any state size are supported. For
        n = n_g = 2 this matches the previous hard-coded layout. ``d`` is
        accepted for signature compatibility and not used here.
        """
        n = A.shape[0]      # learner state size
        n_g = A_g.shape[0]  # generator state size (shared by both copies)
        p_g = C_g.shape[0]  # rows of the first output channel
        p_f = C_f.shape[0]  # rows of the second output channel

        A_l = jnp.vstack([
            jnp.hstack([A, jnp.zeros((n, 2 * n_g + k))]),
            jnp.hstack([B_g @ C, A_g, jnp.zeros((n_g, n_g + k))]),
            jnp.hstack([jnp.zeros((n_g, n + n_g)), A_g - B_g @ C_g,
                        jnp.zeros((n_g, k))]),
            jnp.hstack([jnp.zeros((k, n + n_g)), -B_f @ C_g, A_f]),
        ])

        B_l = jnp.vstack([B, B_g @ D, B_g, B_f])

        C_l = jnp.vstack([
            jnp.hstack([-D_g @ C, -C_g, jnp.zeros((p_g, n_g + k))]),
            jnp.hstack([jnp.zeros((p_f, n + n_g)), -D_f @ C_g, C_f]),
        ])

        D_l = jnp.vstack([jnp.eye(p_g) - D_g @ D, D_f])
        return A_l, B_l, C_l, D_l

    @staticmethod
    def matrices_from_lambdas(lambda_1, lambda_2):
        """Build the learner's (A, B, C, D) from two scalars.

        With e_i = exp(-lambda_i^2), the original parameterisation used
        lambda_3 = sqrt((lambda_1^2 + lambda_2^2) / 2) and
        lambda_4 = sqrt(-log(e1 + e2 - e3)), but only ever through
        e3 = exp(-lambda_3^2) and e4 = exp(-lambda_4^2). These are evaluated
        directly here:

            e3 = exp(-(lambda_1^2 + lambda_2^2) / 2),   e4 = e1 + e2 - e3

        The values are the same up to rounding. The sqrt/log round trip gave
        NaN gradients at lambda_1 = lambda_2 = 0 (the default init), and NaN
        values when rounding pushed e1 + e2 - e3 above 1.

        Returns:
            A = [[0, 1], [-e3 e4, e1 + e2]], B = [[0], [e3 e4 - e1 e2]],
            C = [[-1, 0]], D = I_1.
        """
        e1 = jnp.exp(-lambda_1**2)
        e2 = jnp.exp(-lambda_2**2)
        e3 = jnp.exp(-(lambda_1**2 + lambda_2**2) / 2)
        e4 = e1 + e2 - e3
        A = jnp.vstack([
            jnp.array([0, 1]),
            jnp.hstack([-e3 * e4, e1 + e2]),
        ])
        B = jnp.vstack([
            jnp.zeros((1,)),
            -1 * e1 * e2 + e3 * e4,
        ])
        C = jnp.array([[-1, 0]])
        D = jnp.eye(1)
        return A, B, C, D

    @staticmethod
    def matrices_from_lambdas_old(lambda_1, lambda_2, lambda_3, B_g):
        """Legacy learner parameterisation, kept for reference.

        With e_i = exp(-lambda_i^2) and A = [[0, 1], [-e1 e2, e1 + e2]],
        returns (A - B_g [1, 0], B_g, [[-1, 0]], I_1). ``lambda_3`` is
        accepted for signature compatibility and is not used.
        """
        e1 = jnp.exp(-lambda_1**2)
        e2 = jnp.exp(-lambda_2**2)
        A = jnp.vstack([
            jnp.array([0, 1]),
            jnp.hstack([-1 * e1 * e2, e1 + e2]),
        ])
        B = B_g
        C = jnp.array([[-1, 0]])
        D = jnp.eye(1)
        return A - B @ jnp.array([[1, 0]]), B, C, D

    def build_matrices_inner(self):
        """Assemble Sigma_l from this module's lambdas and configuration."""
        A_f, B_f, C_f, D_f = self.get_sigma_f(self.k, self.d)
        A_g, B_g, C_g, D_g = self.get_sigma_g()
        A, B, C, D = self.matrices_from_lambdas(self.lambda_1, self.lambda_2)

        return self.build_sigma_l(self.k, self.d,
                                  A, B, C, D,
                                  A_g, B_g, C_g, D_g,
                                  A_f, B_f, C_f, D_f)

    @staticmethod
    def build_matrices_outer(cls, sigma):
        """Assemble Sigma_l from explicit lambdas, outside of ``apply``.

        Args:
            cls: object providing ``k``, ``d``, ``get_sigma_f``,
                ``get_sigma_g`` (which reads ``gen_matrices``) and
                ``matrices_from_lambdas``. Presumably an LRU_spec_learner
                instance; verify against callers.
            sigma: mapping with at least "lambda_1" and "lambda_2".

        Returns:
            Tuple (A_l, B_l, C_l, D_l).
        """
        A_f, B_f, C_f, D_f = cls.get_sigma_f(cls.k, cls.d)
        A_g, B_g, C_g, D_g = cls.get_sigma_g()
        A, B, C, D = cls.matrices_from_lambdas(sigma["lambda_1"],
                                               sigma["lambda_2"])

        return cls.build_sigma_l(cls.k, cls.d,
                                 A, B, C, D,
                                 A_g, B_g, C_g, D_g,
                                 A_f, B_f, C_f, D_f)

    def __call__(self, inputs):
        """Run Sigma_l over a sequence.

        Args:
            inputs: 2-D array (T, n_u); row t is the input u_t.

        Returns:
            Array with rows y_t = C_l x_t + D_l u_t, where x_0 = 0 and
            x_{t+1} = A_l x_t + B_l u_t.
        """
        A, B, C, D = self.build_matrices_inner()
        A_seq = jnp.repeat(A[None, ...], inputs.shape[0], axis=0)
        Bu_elements = jax.vmap(lambda u: B @ u)(inputs)
        # The inclusive scan yields the state *after* consuming each input.
        _, hidden_states = parallel_scan(binary_operator_nondiag,
                                         (A_seq, Bu_elements))
        # Shift by one step so that the state paired with u_t is x_t, and the
        # sequence starts from the initial state x_0 = 0.
        hidden_states = jnp.vstack([jnp.zeros((1, hidden_states.shape[-1])),
                                    hidden_states[:-1]])
        outputs = jax.vmap(lambda h, x: C @ h + D @ x)(hidden_states, inputs)

        return outputs


class SequenceLayer_spec_learner(nn.Module):
    """Run the learner system and fold its first two output channels into one.

    Let y = |self.seq(inputs)| have shape (T, >=2). Take its first two columns
    as (T, 1) arrays z_1 and z_2. The output is the (T, 1) array

        out[t] = z_1[t] + z_2[t] - z_2[t-1],   with z_2[-1] := 0.

    The z_2 terms telescope, so sum_{t<=s} out[t] = sum_{t<=s} z_1[t] + z_2[s].

    There is no GLU, dropout, normalization or skip connection in the forward
    pass. The dense layers and normalization created in ``setup`` and the
    ``dropout``/``u_T`` fields are currently unused.
    """

    lru: LRU  # factory (called with no arguments) producing the sequence module
    d_model: int  # model size
    u_T: float  # currently unused
    dropout: float = 0.0  # dropout probability (currently unused)
    norm: str = "layer"  # which normalization to build: "layer" or batch norm
    training: bool = True  # training mode; controls batch-norm running stats

    def setup(self):
        """Build the sequence module, two dense layers and a normalization."""
        self.seq = self.lru()
        self.out1 = nn.Dense(self.d_model, use_bias=False)
        self.out2 = nn.Dense(self.d_model, use_bias=False)
        if self.norm in ["layer"]:
            self.normalization = nn.LayerNorm()
        else:
            self.normalization = nn.BatchNorm(
                use_running_average=not self.training, axis_name="batch",
                use_bias=False, use_scale=False
            )

    def __call__(self, inputs):
        """Return the (T, 1) combined signal described in the class docstring."""
        y = jnp.abs(self.seq(inputs))
        # z_1 = |y - (A,B,C)(y)|
        # z_2 = |u(T)|
        z_1 = jnp.expand_dims(y[:, 0], 1)
        z_2 = jnp.expand_dims(y[:, 1], 1)
        # z_2 delayed by one step, with a leading zero row.
        z_2_prev = jnp.vstack([jnp.zeros((1, 1)), z_2[:-1, :]])
        return z_1 + z_2 - z_2_prev