"""Initializations of structured state space models"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from src.models.ssm import hippo


def dplr(scaling='linear', N=64, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False):
    """Create a diagonal-plus-low-rank (DPLR) SSM initialization.

    Builds H copies of a complex diagonal ``w = -real_part + 1j * imag_part``
    holding N//2 entries per copy, together with the B vector, the low-rank
    term P and a matrix V.

    scaling: how the imaginary part is laid out. One of
        'random'                  - standard normal samples (ignores random_imag)
        'real'                    - zero imaginary part; real part becomes 1..N//2
                                    (this overwrites real_part, so real_scale and
                                    random_real have no effect in this mode)
        'linear' / 'lin'          - pi * imag_part
        'inverse' / 'inv'         - 1/pi * N * (N/(1+2*imag_part)-1)
        'inverse2' / 'inv2'       - 1/pi * N * (N/(1+imag_part)-1)
        'quadratic' / 'quad'      - 1/pi * (1+2*imag_part)**2
        'legs' / 'hippo'          - imaginary part of hippo.nplr('legsd', N)
    N: state size; N//2 diagonal entries are produced per copy
    rank: leading dimension of P
    H: number of independent copies
    dtype: torch.float or torch.double; selects torch.cfloat / torch.cdouble
        as the complex dtype used to create B, P and V
    real_scale: multiplier applied to the real part
    imag_scale: multiplier applied to the imaginary part
    random_real: sample the real part uniformly from [0, 1) instead of using 0.5
    random_imag: sample the base imaginary part as N//2 * U[0, 1) instead of
        using 0, 1, ..., N//2 - 1
    normalize: rescale B by the factor computed from -B/w below
    diagonal: zero out P so that only the diagonal part remains
    random_B: sample B from a normal distribution instead of using all ones

    Returns (w, P, B, V) with shapes (H, N//2), (rank, H, N//2), (H, N//2)
    and (H, N, N//2).

    Raises:
        ValueError: if dtype is neither torch.float nor torch.double.
        NotImplementedError: if scaling is not one of the names listed above.
    """
    # Compare against each allowed dtype explicitly. An expression of the form
    # `dtype == torch.float or torch.double` is always truthy, which would let
    # any dtype through and silently map it to torch.cdouble below.
    if dtype not in (torch.float, torch.double):
        raise ValueError(f"dtype must be torch.float or torch.double, got {dtype}")
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)
    if random_real:
        real_part = torch.rand(H, N//2)
    else:
        real_part = .5 * torch.ones(H, N//2)
    if random_imag:
        imag_part = N//2 * torch.rand(H, N//2)
    else:
        imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H)

    real_part = real_scale * real_part
    if scaling == 'random':
        imag_part = torch.randn(H, N//2)
    elif scaling == 'real':
        imag_part = 0 * imag_part
        # NOTE: replaces the already-scaled real part computed above.
        real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H)
    elif scaling in ['linear', 'lin']:
        imag_part = pi * imag_part
    elif scaling in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix
        imag_part = 1/pi * N * (N/(1+2*imag_part)-1)
    elif scaling in ['inverse2', 'inv2']:
        imag_part = 1/pi * N * (N/(1+imag_part)-1)
    elif scaling in ['quadratic', 'quad']:
        imag_part = 1/pi * (1+2*imag_part)**2
    elif scaling in ['legs', 'hippo']:
        # Presumably hippo.nplr returns w with N//2 entries so that it
        # broadcasts against real_part -- confirm against hippo.nplr.
        w, _, _, _ = hippo.nplr('legsd', N)
        imag_part = w.imag
    else:
        raise NotImplementedError(f"unknown scaling: {scaling!r}")
    imag_part = imag_scale * imag_part
    w = -real_part + 1j * imag_part

    # Initialize B
    if random_B:
        B = torch.randn(H, N//2, dtype=dtype)
    else:
        B = torch.ones(H, N//2, dtype=dtype)

    if normalize:
        norm = -B/w # (H, N//2) # Result if you integrate the kernel with constant 1 function
        zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector
        B = B / zeta**.5

    # P is always sampled (keeping the random stream identical regardless of
    # `diagonal`) and then zeroed when a purely diagonal init is requested.
    P = torch.randn(rank, H, N//2, dtype=dtype)
    if diagonal:
        P = P * 0.0
    V = torch.eye(N, dtype=dtype)[:, :N//2] # Only used in testing
    V = repeat(V, 'n m -> h n m', h=H)

    return w, P, B, V

def ssm(measure, N, R, H, **ssm_args):
    """Build the initialization for one SSM family, chosen by `measure`.

    Routing:
      - "dplr": calls dplr() with the given keyword arguments.
      - names starting with "diag": must have the form "diag-<scaling>";
        calls dplr() with that scaling and diagonal=True.
      - anything else: forwarded to hippo.nplr(), and each returned tensor
        is repeated H times along a new copy axis.

    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies

    Returns the tuple (w, P, B, V).
    """
    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
        return w, P, B, V

    if measure.startswith("diag"):
        pieces = measure.split("-")
        # Rejects both a bare "diag" and look-alikes such as "diagonal-lin".
        assert pieces[0] == "diag" and len(pieces) > 1
        w, P, B, V = dplr(
            scaling=pieces[1], N=N, rank=R, H=H, diagonal=True, **ssm_args
        )
        return w, P, B, V

    # Single-copy initialization from hippo, tiled across the H copies.
    base_w, base_P, base_B, base_V = hippo.nplr(measure, N, R, **ssm_args)
    tiled_w = repeat(base_w, 'n -> s n', s=H)
    tiled_P = repeat(base_P, 'r n -> r s n', s=H)
    tiled_B = repeat(base_B, 'n -> s n', s=H)
    tiled_V = repeat(base_V, 'n m -> s n m', s=H)
    return tiled_w, tiled_P, tiled_B, tiled_V

# Named presets: each key is a shorthand that expands to a list of measure
# names. combination() looks up a string `measures` argument here and, when
# the string is not a key, treats it as a single measure name instead.
combinations = {
    'hippo': ['legs', 'fourier'],
    'diag': ['diag-inv', 'diag-lin'],
    'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'],
}

def combination(measures, N, R, S, **ssm_args):
    """Build S SSM copies split evenly across several measures.

    measures: a list of measure names, or a string. A string is first looked
        up in `combinations`; if it is not a preset name it is used as a
        single measure.
    N: state size
    R: rank, passed through to ssm()
    S: total number of copies; must be divisible by the number of measures
    ssm_args: extra keyword arguments forwarded to every ssm() call

    Returns (w, P, B, V), each the concatenation of the per-measure results
    along the copy axis.
    """
    if isinstance(measures, str):
        measures = combinations.get(measures, [measures])

    assert S % len(measures) == 0, f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures"

    copies_per_measure = S // len(measures)
    all_w, all_P, all_B, all_V = [], [], [], []
    for name in measures:
        w_i, P_i, B_i, V_i = ssm(name, N, R, copies_per_measure, **ssm_args)
        all_w.append(w_i)
        all_P.append(P_i)
        all_B.append(B_i)
        all_V.append(V_i)

    # The copy axis is dim 0 everywhere except P, where rank comes first.
    w = torch.cat(all_w, dim=0) # (S N)
    P = torch.cat(all_P, dim=1) # (R S N)
    B = torch.cat(all_B, dim=0) # (S N)
    V = torch.cat(all_V, dim=0) # (S N N)
    return w, P, B, V
