# https://github.com/NicolasZucchet/minimal-LRU/blob/main/lru/model.py
import jax
import jax.numpy as jnp

from memory.module import MemoryModule
import equinox as eqx
from equinox import nn
from typing import List, Tuple

parallel_scan = jax.lax.associative_scan


# Parallel scan operations
@jax.vmap
def binary_operator_diag(q_i, q_j):
    """Binary operator for parallel scan of linear recurrence"""
    A_i, b_i = q_i
    A_j, b_j = q_j
    return A_j * A_i, A_j * b_i + b_j


def wrapped_associative_update(
    carry: jax.Array, incoming: jax.Array
) -> Tuple[jax.Array, ...]:
    """The reset-wrapped form of the associative update.

    You might need to override this
    if you use variables in associative_update that are not from initial_state.
    This is equivalent to the h function in the paper:
    b x H -> b x H
    """
    prev_start, *carry = carry
    start, *incoming = incoming
    # Reset all elements in the carry if we are starting a new episode
    A, b = carry

    A = jnp.logical_not(start) * A + start * jnp.ones_like(A)
    b = jnp.logical_not(start) * b

    out = binary_operator_diag((A, b), incoming)
    start_out = jnp.logical_or(start, prev_start)
    return (start_out, *out)


def matrix_init(key, shape, dtype=jnp.float32, normalization=1):
    return jax.random.normal(key=key, shape=shape, dtype=dtype) / normalization


def nu_init(key, shape, r_min, r_max, dtype=jnp.float32):
    u = jax.random.uniform(key=key, shape=shape, dtype=dtype)
    return jnp.log(-0.5 * jnp.log(u * (r_max**2 - r_min**2) + r_min**2))


def theta_init(key, shape, max_phase, dtype=jnp.float32):
    u = jax.random.uniform(key, shape=shape, dtype=dtype)
    return jnp.log(max_phase * u)


def gamma_log_init(key, lamb):
    nu, theta = lamb
    diag_lambda = jnp.exp(-jnp.exp(nu) + 1j * jnp.exp(theta))
    return jnp.log(jnp.sqrt(1 - jnp.abs(diag_lambda) ** 2))


class LRU(eqx.Module):
    """
    LRU module in charge of the recurrent processing.
    Implementation following the one of Orvieto et al. 2023.
    """

    theta_log: jax.Array
    nu_log: jax.Array
    gamma_log: jax.Array
    B_re: jax.Array
    B_im: jax.Array
    C_re: jax.Array
    C_im: jax.Array
    D: jax.Array

    d_model: int  # input and output dimensions
    d_hidden: int  # hidden state dimension
    r_min: float  # smallest lambda norm
    r_max: float  # largest lambda norm
    max_phase: float = 6.28  # max phase lambda

    def __init__(self, d_model, d_hidden, r_min, r_max, key):
        keys = jax.random.split(key, 8)
        self.d_hidden = d_hidden
        self.d_model = d_model
        self.r_min = r_min
        self.r_max = r_max

        self.theta_log = theta_init(keys[0], (self.d_hidden,), self.max_phase)
        self.nu_log = nu_init(keys[1], (self.d_hidden,), self.r_min, self.r_max)
        self.gamma_log = gamma_log_init(keys[2], (self.nu_log, self.theta_log))
        self.B_re = matrix_init(
            keys[3],
            (self.d_hidden, self.d_model),
            normalization=jnp.sqrt(2 * self.d_model),
        )
        self.B_im = matrix_init(
            keys[4],
            (self.d_hidden, self.d_model),
            normalization=jnp.sqrt(2 * self.d_model),
        )
        self.C_re = matrix_init(
            keys[5],
            (self.d_model, self.d_hidden),
            normalization=jnp.sqrt(self.d_hidden),
        )
        self.C_im = matrix_init(
            keys[6],
            (self.d_model, self.d_hidden),
            normalization=jnp.sqrt(self.d_hidden),
        )
        self.D = matrix_init(keys[7], (self.d_model,))

    def __call__(self, state, x, start):
        """Forward pass of a LRU: h_t+1 = lambda * h_t + B x_t+1, y_t = Re[C h_t + D x_t]"""
        diag_lambda = jnp.exp(-jnp.exp(self.nu_log) + 1j * jnp.exp(self.theta_log))
        B_norm = (self.B_re + 1j * self.B_im) * jnp.expand_dims(
            jnp.exp(self.gamma_log), axis=-1
        )
        C = self.C_re + 1j * self.C_im

        Lambda_elements = jnp.repeat(diag_lambda[None, ...], x.shape[0], axis=0)
        Bu_elements = jax.vmap(lambda u: B_norm @ u)(x.astype(jnp.complex64))

        Lambda_elements = jnp.concatenate(
            [
                jnp.ones((1, diag_lambda.shape[0])),
                Lambda_elements,
            ]
        )  # (T+1, d_hidden)

        Bu_elements = jnp.concatenate(
            [
                state,
                Bu_elements,
            ]
        )  # (T+1, d_hidden)

        start = start.reshape([-1, 1])
        start = jnp.concatenate([jnp.zeros_like(start[:1]), start], axis=0)

        # Compute hidden states
        _, _, xs = parallel_scan(
            wrapped_associative_update, (start, Lambda_elements, Bu_elements)
        )
        xs = xs[1:]  # (T, d_hidden)

        # Use them to compute the output of the module
        outputs = jax.vmap(lambda x, u: (C @ x).real + self.D * u)(xs, x)

        return xs[None, -1], outputs


class SequenceLayer(eqx.Module):
    """Single layer, with one LRU module, GLU, and batch/layer norm"""

    lru: LRU  # lru module
    d_model: int  # model size
    d_hidden: int  # hidden size
    out1: eqx.Module  # first output linear layer
    out2: eqx.Module  # second output linear layer
    normalization: eqx.Module  # layer norm

    def __init__(self, d_model, d_hidden, r_min, r_max, key):
        """Initializes the ssm"""
        keys = jax.random.split(key, 3)
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.lru = LRU(self.d_model, d_hidden, r_min, r_max, key=keys[0])
        self.out1 = eqx.filter_vmap(nn.Linear(self.d_model, self.d_model, key=keys[1]))
        self.out2 = eqx.filter_vmap(nn.Linear(self.d_model, self.d_model, key=keys[2]))
        self.normalization = eqx.filter_vmap(nn.LayerNorm(self.d_model))

    def __call__(self, state, x, start):
        skip = x
        x = self.normalization(x)  # pre normalization
        state, x = self.lru(state, x, start)  # call LRU
        x = jax.nn.gelu(x)
        o1 = self.out1(x)
        x = o1 * jax.nn.sigmoid(self.out2(x))  # GLU
        return state, skip + x  # skip connection


class StackedLRU(MemoryModule):

    layers: List[SequenceLayer]
    d_model: int
    d_hidden: int
    n_layers: int
    name: str = "StackedLRU"

    def __init__(self, d_model, d_hidden, n_layers, r_min, r_max, key):
        keys = jax.random.split(key, 7)
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.n_layers = n_layers

        self.layers = [
            SequenceLayer(
                d_model=self.d_model,
                d_hidden=self.d_hidden,
                r_min=r_min,
                r_max=r_max,
                key=keys[i + 1],
            )
            for i in range(self.n_layers)
        ]

    def __call__(self, x, state, start, key=None):
        new_states = []
        for i, layer in enumerate(self.layers):
            new_s, x = layer(state[i], x, start)
            new_states.append(new_s)

        return x, new_states

    def initial_state(self, shape=tuple()):
        return [
            jnp.zeros((*shape, 1, self.d_hidden), dtype=jnp.complex64)
            for _ in range(self.n_layers)
        ]
