"""(c) 2023 The Authors of Paper #4031"""

from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import he_normal
from jax.numpy.linalg import eigh

from . import diagonal_scans
from .conv_ops import VmapResnetBlock


def make_HiPPO(N):
    P = np.sqrt(1 + 2 * np.arange(N))
    A = P[:, np.newaxis] * P[np.newaxis, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A


def make_Normal_HiPPO(N):
    """normal approximation to the HiPPO-LegS matrix"""
    # Make -HiPPO
    hippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)
    nhippo = hippo + P[:, np.newaxis] * P[np.newaxis, :]

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)

    return nhippo, P, B

def make_NPLR_HiPPO(N):
    """
    Makes components needed for NPLR representation of HiPPO-LegS
     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Args:
        N (int32): state size
    Returns:
        N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B
    """
    # Make -HiPPO
    hippo = make_HiPPO(N)

    # Add in a rank 1 term. Makes it Normal.
    P = np.sqrt(np.arange(N) + 0.5)

    # HiPPO also specifies the B matrix
    B = np.sqrt(2 * np.arange(N) + 1.0)
    return hippo, P, B


def make_DPLR_HiPPO(N):
    """
    Makes components needed for DPLR representation of HiPPO-LegS
     From https://github.com/srush/annotated-s4/blob/main/s4/s4.py
    Note, we will only use the diagonal part
    Args:
        N:
    Returns:
        eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B,
        eigenvectors V, HiPPO B pre-conjugation
    """
    A, P, B = make_NPLR_HiPPO(N)

    S = A + P[:, np.newaxis] * P[np.newaxis, :]

    S_diag = np.diagonal(S)
    Lambda_real = np.mean(S_diag) * np.ones_like(S_diag)

    # Diagonalize S to V \Lambda V^*
    Lambda_imag, V = eigh(S * -1j)

    P = V.conj().T @ P
    B_orig = B
    B = V.conj().T @ B
    return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig


def discretize_zoh(Lambda, B_tilde, Delta):
    """ Discretize a diagonalized, continuous-time linear SSM
        using zero-order hold method.
        Args:
            Lambda (complex64): diagonal state matrix              (P,)
            B_tilde (complex64): input matrix                      (P, H)
            Delta (float32): discretization step sizes             (P,)
        Returns:
            discretized Lambda_bar (complex64), B_bar (complex64)  (P,), (P,H)
    """
    Identity = np.ones(Lambda.shape[0])
    Lambda_bar = np.exp(Lambda * Delta)
    B_bar = (1/Lambda * (Lambda_bar-Identity))[..., None] * B_tilde
    return Lambda_bar, B_bar

def log_step_initializer(dt_min=0.001, dt_max=0.1):
    def init(key, shape):
        return jax.random.uniform(key, shape) * (
            np.log(dt_max) - np.log(dt_min)
        ) + np.log(dt_min)

    return init

def init_log_steps(key, input):
    U, dt_min, dt_max = input
    log_steps = []
    for i in range(U):
        key, skey = jax.random.split(key)
        log_step = log_step_initializer(dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,))
        log_steps.append(log_step)

    return np.array(log_steps)

def initialize_C_kernel(key, shape):
    """For general kernels, e.g. C,D, encoding/decoding"""
    out_dim, in_dim, k = shape
    fan_in = in_dim*(k**2)

    # Note in_axes should be the first by default:
    # https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.variance_scaling.html#jax.nn.initializers.variance_scaling
    return he_normal()(key,
                       (fan_in, out_dim)).reshape(out_dim,
                                                  in_dim,
                                                  k, k).transpose(0,2,3,1).reshape(-1, in_dim)

def initialize_B_kernel(key, shape):
    """We will store the B kernel as a matrix,
    returns shape: (out_dim, in_dim*k*k)"""
    out_dim, in_dim, k = shape
    fan_in = in_dim*(k**2)

    # Note in_axes should be the first by default:
    # https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.variance_scaling.html#jax.nn.initializers.variance_scaling
    return he_normal()(key,
                       (fan_in, out_dim)).T

def init_VinvB(key, shape, Vinv):
    B = initialize_B_kernel(key, shape)
    print('B', B.shape)
    VinvB = Vinv @ B
    print('VinvB', VinvB)
    VinvB_real = VinvB.real
    VinvB_imag = VinvB.imag
    return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1)

def init_CV(key, shape, V):
    out_dim, in_dim, k = shape
    C = initialize_C_kernel(key, shape)
    CV = C @ V
    CV = CV.reshape(out_dim, k, k, in_dim//2).transpose(1,2,3,0)
    CV_real = CV.real
    CV_imag = CV.imag
    return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1)

class ConvS5SSM(nn.Module):
    Lambda_re_init: np.DeviceArray
    Lambda_im_init: np.DeviceArray
    V: np.DeviceArray
    Vinv: np.DeviceArray
    clip_eigs: bool
    parallel: bool  # Compute scan in parallel

    activation: nn.module
    num_groups: int

    U: int    # Number of ConvSSM input and output features
    P: int    # Number of state features of ConvSSM
    k_B: int  # B kernel width/height
    k_C: int  # C kernel width/height
    k_D: int  # D kernel width/height

    dt_min: float  # for initializing discretization step
    dt_max: float
    C_D_config: str = "resnet"
    squeeze_excite: bool = False

    def setup(self):
        # Initialize diagonal state to state transition kernel Lambda (eigenvalues)
        self.Lambda_re = self.param("Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,))
        self.Lambda_im = self.param("Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,))
        if self.clip_eigs:
            self.Lambda = np.clip(self.Lambda_re, None, -1e-4) + 1j * self.Lambda_im
        else:
            self.Lambda = self.Lambda_re + 1j * self.Lambda_im

        # Initialize input to state (B) and output to state (C) kernels
        self.B = self.param("B",
                            lambda rng, shape: init_VinvB(rng,
                                                          shape,
                                                          self.Vinv),
                            (2*self.P, self.U, self.k_B))
        B_tilde = self.B[..., 0] + 1j * self.B[..., 1]

        self.C = self.param("C",
                            lambda rng, shape: init_CV(rng, shape, self.V),
                            (self.U, 2*self.P, self.k_C))
        self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1]



        if self.C_D_config == "resnet":
            self.C_D_conv = VmapResnetBlock(activation=self.activation,
                                            k_size=self.k_C,
                                            out_channels=self.U,
                                            num_groups=self.num_groups,
                                            squeeze_excite=self.squeeze_excite)

        # Initialize learnable discretization steps
        self.log_step = self.param("log_step",
                                   init_log_steps,
                                   (self.P, self.dt_min, self.dt_max))
        step = np.exp(self.log_step[:, 0])

        if self.parallel:
            # Discretize
            self.A_bar, self.B_bar = discretize_zoh(self.Lambda,
                                                    B_tilde,
                                                    step)
            self.B_bar = self.B_bar.reshape(self.P, self.U, self.k_B, self.k_B).transpose(2,3,1,0)
        else:
            # trick to cache the discretization for step-by-step
            def init_discrete():
                A_bar, B_bar = discretize_zoh(self.Lambda,
                                                   B_tilde,
                                                   step)
                B_bar = B_bar.reshape(self.P, self.U, self.k_B, self.k_B).transpose(2,3,1,0)
                return A_bar, B_bar
            ssm_var = self.variable("prime", "ssm", init_discrete)
            if self.is_mutable_collection("prime"):
                ssm_var.value = init_discrete()
            self.ssm = ssm_var.value


    def __call__(self, input_sequence, x0):
        """
        input sequence is shape (L, bsz, H, W, U)
        x0 is (bsz, H, W, U)
        Returns:
            x_L (float32): the last state of the ConvSSM  (bsz, H, W, P)
            ys (float32): the convSSM outputs       (L,bsz, H, W, U)
        """
        if self.parallel:
            x_last, ys = diagonal_scans.apply_convSSM_parallel(self.A_bar,
                                                               self.B_bar,
                                                               self.C_tilde,
                                                               input_sequence,
                                                               x0)

        else:
            # For sequential generation (e.g. autoregressive decoding)
            x_last, ys = diagonal_scans.apply_convSSM_sequential(*self.ssm,
                                                                 self.C_tilde,
                                                                 input_sequence,
                                                                 x0)
        if self.C_D_config == "resnet":
            ys = self.C_D_conv(ys)
        return x_last, ys


def hippo_initializer(ssm_size):
    Lambda, _, _, V, _ = make_DPLR_HiPPO(ssm_size)
    # Enforce conjugate symmetry
    ssm_size = ssm_size // 2
    Lambda = Lambda[:ssm_size]
    V = V[:, :ssm_size]
    Vinv = V.conj().T
    return Lambda.real, Lambda.imag, V, Vinv, ssm_size

def init_ConvS5SSM(ssm_size,
                   clip_eigs,
                   U,
                   k_B,
                   k_C,
                   k_D,
                   dt_min,
                   dt_max,
                   C_D_config):
    Lambda_re_init, Lambda_im_init,\
    V, Vinv, ssm_size = hippo_initializer(ssm_size)

    return partial(ConvS5SSM,
                   Lambda_re_init=Lambda_re_init,
                   Lambda_im_init=Lambda_im_init,
                   V=V,
                   Vinv=Vinv,
                   clip_eigs=clip_eigs,
                   U=U,
                   P=ssm_size,
                   k_B=k_B,
                   k_C=k_C,
                   k_D=k_D,
                   dt_min=dt_min,
                   dt_max=dt_max,
                   C_D_config=C_D_config)
