# pylint: disable=invalid-name
"""From https://github.com/lindermanlab/S5/blob/main/s5/ssm.py"""

from flax import nnx
from jax import Array, lax, numpy as jnp, vmap
from jax.nn.initializers import lecun_normal, normal

from offline.lbp.tc.modules.ssm_init import (
    init_CV,
    init_VinvB,
    init_log_steps,
    trunc_standard_normal,
)


# Discretization functions
def discretize_bilinear(Lambda, B_tilde, Delta):
    """Bilinear-transform discretization of a diagonal continuous-time SSM.

    With a = (Delta / 2) * Lambda, computed elementwise, this returns
        Lambda_bar = (1 + a) / (1 - a)
        B_bar      = (Delta / (1 - a))[:, None] * B_tilde
    Args:
        Lambda (complex64): diagonal state matrix              (P,)
        B_tilde (complex64): input matrix                      (P, H)
        Delta (float32): discretization step sizes             (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    ones = jnp.ones(Lambda.shape[0])
    half_step = (Delta / 2.0) * Lambda

    # The matrix is diagonal, so the inverse is an elementwise reciprocal.
    inv_denominator = 1 / (ones - half_step)
    Lambda_bar = inv_denominator * (ones + half_step)

    # Add a trailing axis so each state's factor scales its whole row of B.
    B_bar = (inv_denominator * Delta)[..., None] * B_tilde
    return Lambda_bar, B_bar


def discretize_zoh(Lambda, B_tilde, Delta):
    """Zero-order-hold discretization of a diagonal continuous-time SSM.

    Computed elementwise over the diagonal:
        Lambda_bar = exp(Lambda * Delta)
        B_bar      = ((Lambda_bar - 1) / Lambda)[:, None] * B_tilde
    Args:
        Lambda (complex64): diagonal state matrix              (P,)
        B_tilde (complex64): input matrix                      (P, H)
        Delta (float32): discretization step sizes             (P,)
    Returns:
        discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    ones = jnp.ones(Lambda.shape[0])
    Lambda_bar = jnp.exp(Lambda * Delta)

    # Per-state input gain; note the division by Lambda (no guard for 0).
    input_gain = 1 / Lambda * (Lambda_bar - ones)

    # Add a trailing axis so each state's gain scales its whole row of B.
    B_bar = input_gain[..., None] * B_tilde
    return Lambda_bar, B_bar


# Parallel scan operations
@vmap
def binary_operator(q_i, q_j):
    """Associative combine step for the parallel scan of a linear recurrence.

    Each element (A, b) stands for the update x -> A * x + b with a diagonal
    A, so all products are elementwise. Applying q_i and then q_j gives
    A_j * (A_i * x + b_i) + b_j, which is the element returned here.
    The function is vmapped, so it is applied independently along the
    leading axis of its inputs.
    Args:
        q_i: tuple containing A_i and Bu_i at position i       (P,), (P,)
        q_j: tuple containing A_j and Bu_j at position j       (P,), (P,)
    Returns:
        new element ( A_out, Bu_out )
    """
    earlier_A, earlier_b = q_i
    later_A, later_b = q_j

    combined_A = later_A * earlier_A
    combined_b = later_A * earlier_b + later_b
    return combined_A, combined_b


def apply_ssm(
    Lambda_bar, B_bar, C_tilde, input_sequence, conj_sym, bidirectional
):
    """Run a discretized diagonal SSM over a sequence with a parallel scan.

    The hidden states are obtained with `lax.associative_scan`, then
    projected with C_tilde; only the real part of the projection is
    returned (doubled when conj_sym is set).
    Args:
        Lambda_bar (complex64): discretized diagonal state matrix    (P,)
        B_bar      (complex64): discretized input matrix             (P, H)
        C_tilde    (complex64): output matrix                        (H, P)
        input_sequence (float32): input sequence of features         (L, H)
        conj_sym (bool):         whether conjugate symmetry is enforced
        bidirectional (bool):    whether bidirectional setup is used,
                              Note for this case C_tilde will have 2P cols
    Returns:
        ys (float32): the SSM outputs (S5 layer preactivations)      (L, H)
    """
    seq_len = input_sequence.shape[0]
    state_dim = Lambda_bar.shape[0]

    # One (A, Bu) pair per time step; A is the same diagonal at every step.
    scan_elems = (
        Lambda_bar * jnp.ones((seq_len, state_dim)),
        vmap(lambda u: B_bar @ u)(input_sequence),
    )

    states = lax.associative_scan(binary_operator, scan_elems)[1]

    if bidirectional:
        # Same elements scanned from the end of the sequence; the two state
        # sequences are stacked along the state axis, giving 2P entries.
        backward_states = lax.associative_scan(
            binary_operator, scan_elems, reverse=True
        )[1]
        states = jnp.concatenate((states, backward_states), axis=-1)

    def readout(x):
        return (C_tilde @ x).real

    ys = vmap(readout)(states)
    return 2 * ys if conj_sym else ys


class S5SSM(nnx.Module):
    """The S5 SSM

    Parameters are created once in ``__init__``; the continuous-time system
    is discretized again, using the current step sizes, on every call.

    Args:
        H           (int32):     Number of features of input seq
        Lambda_re_init (float32):
            Real part of init diag state matrix  (P,)
        Lambda_im_init (float32):
            Imag part of init diag state matrix  (P,)
        V           (complex64): Eigenvectors used for init           (P,P)
        Vinv        (complex64): Inverse eigenvectors used for init   (P,P)
        P           (int32):     state size
        C_init      (string):    Specifies How C is initialized
            Options:
                trunc_standard_normal:
                    sample from truncated standard normal
                    and then multiply by V, i.e. C_tilde=CV.
                lecun_normal:
                    sample from Lecun_normal and then multiply by V.
                complex_normal:
                    directly sample a complex valued output matrix
                    from standard normal, does not multiply by V
        discretization: (string) Specifies discretization method
            Options:
                zoh: zero-order hold method
                bilinear: bilinear transform
        dt_min      (float32):
            minimum value to draw timescale values from when initializing
            log_step
        dt_max      (float32):
            maximum value to draw timescale values from when initializing
            log_step
        rngs        (nnx.Rngs):
            RNG container; one key is drawn from its ``params`` stream for
            each sampled parameter (B, C1, C2 if present, D, log_step,
            in that order).
        conj_sym    (bool):    Whether conjugate symmetry is enforced
        clip_eigs   (bool):
            Whether to enforce left-half plane condition, i.e. constrain
            real part of eigenvalues to be negative. True recommended for
            autoregressive task/unbounded sequence lengths.
            Discussed in https://arxiv.org/pdf/2206.11893.pdf.
        bidirectional (bool):
            Whether model is bidirectional, if True, uses two C matrices
        step_rescale:  (float32):
            allows for uniformly changing the timescale parameter,
            e.g. after training on a different resolution for the speech
            commands benchmark
    """

    def __init__(
        self,
        H: int,
        Lambda_re_init: Array,
        Lambda_im_init: Array,
        V: Array,
        Vinv: Array,
        P: int,
        C_init: str,
        discretization: str,
        dt_min: float,
        dt_max: float,
        rngs: nnx.Rngs,
        conj_sym: bool = True,
        clip_eigs: bool = False,
        bidirectional: bool = False,
        step_rescale: float = 1.0,
    ):
        """Initializes parameters once and performs discretization each time
        the SSM is applied to a sequence

        Raises:
            NotImplementedError: if ``C_init`` or ``discretization`` is not
                one of the supported option strings.
        """
        # Resolve both string options before touching `rngs`, so that an
        # unsupported value fails without having drawn any RNG keys.
        if C_init == "trunc_standard_normal":
            C_init_fn = trunc_standard_normal
        elif C_init == "lecun_normal":
            C_init_fn = lecun_normal()
        elif C_init == "complex_normal":
            C_init_fn = normal(stddev=0.5**0.5)
        else:
            raise NotImplementedError(
                f"C_init method {C_init} not implemented"
            )

        if discretization == "zoh":
            discretize_fn = discretize_zoh
        elif discretization == "bilinear":
            discretize_fn = discretize_bilinear
        else:
            raise NotImplementedError(
                f"Discretization method {discretization} not implemented"
            )

        self.bidirectional = bidirectional
        self.clip_eigs = clip_eigs
        self.conj_sym = conj_sym
        self.step_rescale = step_rescale

        # Need to account for case where we actually sample real B and C,
        # and then multiply by the half sized Vinv and possibly V
        local_P = 2 * P if conj_sym else P

        # Initialize diagonal state to state matrix Lambda (eigenvalues)
        self.Lambda_re = nnx.Param(Lambda_re_init)
        self.Lambda_im = nnx.Param(Lambda_im_init)

        # Initialize input to state (B) matrix
        self.B = nnx.Param(
            init_VinvB(lecun_normal(), rngs.params(), (local_P, H), Vinv)
        )

        # Initialize state to output (C) matrix. The trailing axis of size 2
        # holds the real and imaginary parts (recombined in __call__).
        if C_init == "complex_normal":
            # Sampled directly, without multiplying by V; a single matrix
            # covers both directions in the bidirectional case.
            C_cols = 2 * P if bidirectional else P
            self.C1 = nnx.Param(C_init_fn(rngs.params(), (H, C_cols, 2)))
            self.C2 = None
        else:
            C_shape = (H, local_P, 2)
            self.C1 = nnx.Param(init_CV(C_init_fn, rngs.params(), C_shape, V))
            # The second matrix exists only for the bidirectional model.
            self.C2 = (
                nnx.Param(init_CV(C_init_fn, rngs.params(), C_shape, V))
                if bidirectional
                else None
            )

        # Initialize feedthrough (D) matrix
        self.D = nnx.Param(normal(stddev=1.0)(rngs.params(), (H,)))

        # Initialize learnable discretization timescale value
        self.log_step = nnx.Param(
            init_log_steps(rngs.params(), (P, dt_min, dt_max))
        )

        # Discretization function applied on every call
        self.discretize = discretize_fn

    def __call__(self, input_sequence):
        """
        Compute the LxH output of the S5 SSM given an LxH input sequence
        using a parallel scan.
        Args:
             input_sequence (float32): input sequence (L, H)
        Returns:
            output sequence (float32): (L, H)
        """
        # Rebuild the complex output matrix from its stored (re, im) parts;
        # with two C matrices they are joined along the state axis.
        C_tilde = self.C1[..., 0] + 1j * self.C1[..., 1]
        if self.C2 is not None:
            C2 = self.C2[..., 0] + 1j * self.C2[..., 1]
            C_tilde = jnp.concatenate((C_tilde, C2), axis=-1)

        # Step sizes are stored in log space; step_rescale scales them all.
        step = self.step_rescale * jnp.exp(self.log_step[:, 0])

        if self.clip_eigs:
            # Cap the real part at -1e-4 (left-half plane condition).
            Lambda = (
                jnp.clip(self.Lambda_re, None, -1e-4)  # type: ignore
                + 1j * self.Lambda_im
            )
        else:
            Lambda = self.Lambda_re + 1j * self.Lambda_im

        B_tilde = self.B[..., 0] + 1j * self.B[..., 1]
        Lambda_bar, B_bar = self.discretize(Lambda, B_tilde, step)

        ys = apply_ssm(
            Lambda_bar,
            B_bar,
            C_tilde,
            input_sequence,
            self.conj_sym,
            self.bidirectional,
        )

        # Add feedthrough matrix output Du;
        Du = vmap(lambda u: self.D * u)(input_sequence)
        return ys + Du


class BatchS5SSM(S5SSM):
    """S5SSM applied independently to every element of a batch.

    Construction is inherited unchanged from ``S5SSM`` (same arguments,
    same order, same defaults); only ``__call__`` differs.
    """

    @nnx.vmap(in_axes=(None, 0))
    def __call__(self, input_sequence):
        """Apply ``S5SSM.__call__`` to each slice along axis 0 of the input.

        ``in_axes=(None, 0)`` leaves the module unmapped and maps over the
        leading axis of ``input_sequence``.
        Args:
             input_sequence: stack of input sequences; presumably (B, L, H),
                 since each slice is passed to ``S5SSM.__call__`` -- confirm
                 against callers.
        Returns:
            the per-slice outputs of ``S5SSM.__call__``, stacked by the vmap
        """
        return super().__call__(input_sequence)
