"""
SSMs used in Hedgehog
- Should clean up because some redundancy with the others
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import opt_einsum as oe
from einops import rearrange, repeat

class OptimModule(nn.Module):
    """nn.Module with a helper for registering tensors that carry optimizer hints."""

    def __init__(self):
        super().__init__()

    def register(self, name, tensor, trainable=False, lr=None, wd=None):
        """
        Attach `tensor` to the module under `name`.

        The tensor becomes an `nn.Parameter` when `trainable` is True and a
        buffer otherwise. If registration raises KeyError (the name is already
        taken by an attribute of the other kind), the existing attribute is
        deleted and registration is retried.

        For trainable tensors, any of `lr` / `wd` that is not None is stored in
        an `_optim` dict on the registered parameter, under the keys "lr" and
        "weight_decay".
        """
        if trainable:
            def attach():
                self.register_parameter(name, nn.Parameter(tensor))
        else:
            def attach():
                self.register_buffer(name, tensor)

        try:
            attach()
        except KeyError:
            # Name already exists as a different kind of attribute: replace it
            delattr(self, name)
            attach()

        # Buffers never receive per-tensor optimizer settings
        if not trainable:
            return

        hyperparams = {}
        for key, value in (("lr", lr), ("weight_decay", wd)):
            if value is not None:
                hyperparams[key] = value
        if hyperparams:
            setattr(getattr(self, name), "_optim", hyperparams)

def init_ssm(method):
    """
    Return the SSM class (not an instance) registered under `method`.

    Supported values are 'companion', 'shift' and 'conv1d'.

    Raises:
        NotImplementedError: if `method` is not one of the supported names.
    """
    # Keep this list in sync with the branches below; it is only used for the
    # error message.
    supported_methods = ['companion', 'shift', 'conv1d']  # , 'diagonal'

    if method == 'companion':
        return CompanionSSM
    elif method == 'shift':
        return ShiftSSM
    elif method == 'conv1d':
        return Conv1dSSM
    # elif method == 'diagonal':  # TODO
    #     return DiagonalSSM
    else:
        raise NotImplementedError(
            f"SSM method {method} not implemented! Please choose from {supported_methods}")
        
# -----------------------------------
# Structured State-Space Model Layers
# -----------------------------------
class SSM(OptimModule):
    """
    Basic SSM class. Inherit from this for all other SSMs.

    The layer maps an input of shape (B, L, D) to an output of the same shape
    by convolving every channel with a length-L kernel via FFT. Subclasses
    supply the kernel through `get_kernel`; alternatively a fixed kernel can
    be passed as `kernel_weights`.

    Args:
        n_kernels: Number of kernels / scales.
        kernel_dim: Number of coefficients per kernel.
        n_heads: Number of heads per kernel. May be None if `head_dim` and
            `model_dim` are given.
        head_dim: Dimension of each head. May be None if `n_heads` and
            `model_dim` are given.
        model_dim: Dimension of layer inputs and outputs. Recomputed as
            head_dim * n_heads * n_kernels when both of those are given.
        kernel_weights: Optional precomputed kernel; registered as the
            trainable parameter `k` and used directly in `forward`.
        kernel_init: One of 'normal', 'xavier', 'identity'.
        kernel_train: Stored on the instance. NOTE: not consulted by this base
            class; `k` is always registered as trainable.
        identity_kernel: Reserve one kernel where the convolution is the
            identity function (only where a subclass requests it).
        skip_connection: Add a learned per-channel scaling of the input.
        bidirectional: Also convolve the reversed sequence with the same kernel.
        norm_order: Accepted for signature compatibility; unused here.
        seed: Seed for `self.generator`. NOTE: `init_kernel_weights` draws from
            the global torch RNG, not from this generator.
        lr: Layer-specific learning rate (None = same as model).
        weight_decay: Layer-specific weight decay (None = same as model).
    """
    def __init__(self,
                 n_kernels: int,       # Number of kernels / scales
                 kernel_dim: int,
                 n_heads: int=None,    # Number of heads per kernel
                 head_dim: int=None,   # Dimension of each head
                 model_dim: int=None,  # Dimension of layer inputs and outputs
                 kernel_weights: torch.Tensor=None,
                 kernel_init: str='normal',
                 kernel_train: bool=True,
                 identity_kernel: bool=True,  # Include kernel where conv is identity func
                 skip_connection: bool=False,
                 bidirectional: bool=False,
                 norm_order: int=0,
                 seed: int=42,
                 lr: float=None,
                 weight_decay: float=None):
        super().__init__()
        # At least one of these should be int
        assert not (n_heads is None and head_dim is None), \
            "At least one of n_heads and head_dim must be specified"

        self.n_kernels = n_kernels
        self.kernel_dim = kernel_dim

        dims = self.init_heads(n_heads, head_dim, model_dim)
        self.head_dim, self.n_heads, self.model_dim = dims

        self.kernel_weights  = kernel_weights
        self.kernel_init     = kernel_init
        self.kernel_train    = kernel_train
        self.identity_kernel = identity_kernel
        self.skip_connection = skip_connection
        self.bidirectional   = bidirectional

        # Layer-specific training hyperparameters
        self.lr = lr
        self.weight_decay = weight_decay

        self.seed = seed
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

        self.init_weights()

    def init_heads(self, n_heads: int, head_dim: int, model_dim: int):
        """
        Resolve (head_dim, n_heads, model_dim) so that
        model_dim == head_dim * n_heads * n_kernels.

        Whichever of `n_heads` / `head_dim` is None is derived from
        `model_dim`; if both are given, `model_dim` is recomputed from them.
        The resolved values are stored on `self` and returned.

        Raises:
            ValueError: if a value must be derived but `model_dim` is None or
                not evenly divisible.
        """
        if head_dim is None or n_heads is None:
            # One factor is missing, so model_dim is required to derive it
            if model_dim is None:
                raise ValueError(
                    "model_dim is required when n_heads or head_dim is None")
            known = self.n_kernels * (n_heads if head_dim is None else head_dim)
            if model_dim % known != 0:
                raise ValueError(
                    f"model_dim ({model_dim}) must be divisible by {known}")
            if head_dim is None:
                head_dim = model_dim // known
            else:
                n_heads = model_dim // known
        else:
            model_dim = head_dim * n_heads * self.n_kernels

        self.head_dim  = head_dim
        self.n_heads   = n_heads
        self.model_dim = model_dim
        return self.head_dim, self.n_heads, self.model_dim

    def init_kernel_weights(self, kernel_init, identity_kernel=False):
        """
        Build an (n_kernels, kernel_dim) tensor of initial kernel coefficients.

        If both `self.identity_kernel` and the `identity_kernel` argument are
        True, the first row is the unit vector e_0 and only the remaining
        n_kernels - 1 rows follow `kernel_init`.

        Raises:
            NotImplementedError: for an unknown `kernel_init`.
        """
        use_identity = self.identity_kernel and identity_kernel
        n_kernels = self.n_kernels - 1 if use_identity else self.n_kernels

        # Set aside one kernel to be identity kernel
        if use_identity:
            kernels = [torch.zeros(1, self.kernel_dim)]
            kernels[0][:, 0] = 1
        else:
            kernels = []

        # Initialize remaining kernels
        if kernel_init == 'normal':
            kernel = torch.randn(n_kernels, self.kernel_dim)
        elif kernel_init == 'xavier':
            # Xavier-ish initialization
            stdv = 1. / math.sqrt(self.kernel_dim)
            kernel = torch.FloatTensor(n_kernels,
                                       self.kernel_dim).uniform_(-stdv, stdv)
        elif kernel_init == 'identity':
            kernel = torch.zeros(n_kernels, self.kernel_dim)
            kernel[:, 0] = 1.
        else:
            raise NotImplementedError(
                f"kernel_init {kernel_init} not implemented")

        kernels.append(kernel)
        return torch.cat(kernels, dim=0)

    def fft_conv(self, u_input: torch.Tensor, v_kernel: torch.Tensor):
        """
        Convolve u with v in O(n log n) time with FFT (n = len(u)).

        `u_input` is (B, H, L) and `v_kernel` is (H, >=1); the kernel is
        truncated to L taps. Both are zero-padded to 2L before the transform
        and only the first L outputs are returned, shape (B, H, L).
        """
        L   = u_input.shape[-1]  # Assume u is input
        u_f = torch.fft.rfft(u_input, n=2*L)  # (B H L+1)
        v_f = torch.fft.rfft(v_kernel[:, :L], n=2*L)  # (H L+1)

        y_f = oe.contract('b h l, h l -> b h l', u_f, v_f)
        y   = torch.fft.irfft(y_f, n=2*L)[..., :L]  # (B H L)
        return y

    def init_weights(self):
        """Register the optional fixed kernel `k` and the `skip` weights."""
        if self.kernel_weights is not None:
            # lr and wd as None means they're the same as model lr and wd
            self.register('k', self.kernel_weights, trainable=True, lr=self.lr, wd=self.weight_decay)

        skip = torch.randn(self.model_dim)
        self.register('skip', skip, trainable=True, lr=self.lr, wd=self.weight_decay)

    def set_weights(self, name, weights, trainable, lr, wd):
        """Replace the registered tensor `name` with `weights` of identical shape."""
        w = getattr(self, name)
        assert w.shape == weights.shape
        self.register(name, weights, trainable, lr, wd)

    def get_kernel(self, u):
        """Return the (n_kernels, L) convolution kernel; implemented by subclasses."""
        raise NotImplementedError

    def forward(self, u):
        """Apply the SSM convolution to `u` of shape (B, L, D); returns (B, L, D)."""
        u = rearrange(u, 'b l d -> b d l')  # Assume u is B x L x D
        # Repeat kernels across heads
        if self.kernel_weights is None:
            k = self.get_kernel(u)
            k = repeat(k, 'nk kd -> (nk nh hd) kd',
                       nh=self.n_heads, hd=self.head_dim)
        else:
            k = self.k
        # Compute SSM output as convolution
        y = self.fft_conv(u, k)

        if self.bidirectional:
            # Reverse input length-wise and convolve with same kernel
            y += self.fft_conv(u.flip([2]), k).flip([2])
            # The lag-0 term k[:, 0] * u appears in both directions
            y -= k[:, :1].unsqueeze(0) * u  # Avoid double-counting

        if self.skip_connection:
            y = y + oe.contract('b d l, d -> b d l', u, self.skip)

        y = rearrange(y, 'b d l -> b l d')
        return y
    

# --------------------
# Companion Matrix SSM
# --------------------
class CompanionSSM(SSM):
    r"""
    Open-loop implementation of Companion SSM:
    -> y_t = \sum_{i = 0}^{t} C A^{t - i} B u_i
       where A is companion matrix

    The convolution kernel (C A^t B for t = 0..L-1) is built by `get_kernel`
    and applied by the parent class's `forward`.

    Args:
        norm_order: Order of the norm applied to the `a` column before
            building the companion matrix; values <= 0 disable normalization.
        **kwargs: Forwarded to `SSM.__init__`. `kernel_weights` is forced to
            None and `kernel_train` to True.
    """
    def __init__(self, norm_order: int=1, **kwargs):
        # Set before super().__init__(), which invokes init_weights()
        self.norm_order = norm_order
        kwargs['kernel_weights'] = None
        kwargs['kernel_train'] = True
        super().__init__(**kwargs)

    def init_weights(self):
        """Register the trainable A column, B and C matrices plus constant helpers."""
        super().init_weights()  # Initializes skip connection
        self._fp = (self.n_kernels, self.kernel_dim)

        # Shift matrix initialization: ones on the first subdiagonal
        shift_matrix = torch.zeros(self.n_kernels,
                                   self.kernel_dim,
                                   self.kernel_dim)
        shift_matrix[:, 1:, :-1] = torch.eye(self.kernel_dim - 1)
        # One-hot selector that places `a` into the last column
        a_padding = torch.zeros(*self._fp)
        a_padding[:, -1] = 1.
        # Non-persistent buffers: they follow the module across devices
        # without being added to the state_dict.
        self.register_buffer('shift_matrix', shift_matrix, persistent=False)
        self.register_buffer('a_padding', a_padding, persistent=False)

        # A matrix
        a = self.init_kernel_weights(self.kernel_init)
        self.register("a", a, trainable=True, lr=self.lr, wd=self.weight_decay)

        # B matrix
        b = self.init_kernel_weights(self.kernel_init)
        self.register("b", b, trainable=True, lr=self.lr, wd=self.weight_decay)

        # C matrix
        c = self.init_kernel_weights(self.kernel_init, self.identity_kernel)
        self.register("c", c, trainable=True, lr=self.lr, wd=self.weight_decay)

    def norm(self, x, ord=1):
        """Normalize `x` along its last dimension by its `ord`-norm."""
        # x.shape is either (H x D) or (H x D x D)
        x_norm = torch.linalg.norm(x, ord=ord, dim=-1, keepdim=True)
        # If norm(x) in batch close to 0, don't normalize (heuristicky, but we norm for stability)
        x = x / x_norm if torch.abs(x_norm).mean().item() > 1e-4 else x
        return x

    def get_companion_matrix(self, a):
        """Return the (H, D, D) matrix: shift matrix with `a` as its last column."""
        # .to() is a no-op when the buffers already live on a's device
        return (self.shift_matrix.to(a.device) +
                oe.contract('h i, h j -> h j i', self.a_padding.to(a.device), a))

    def get_kernel(self, u, c=None, l=None):
        """
        Compute the kernel C A^t B for t = 0..l-1.

        Args:
            u: Input of shape (B, D, L); only its length and device are used.
            c: Optional replacement for `self.c`.
            l: Optional kernel length; defaults to `u.shape[-1]`.
        """
        l = u.shape[-1] if l is None else l
        c = self.c if c is None else c
        a = (self.norm(self.a, ord=self.norm_order)
             if self.norm_order > 0 else self.a)
        A = self.get_companion_matrix(a)
        # Power up companion matrix
        k = krylov(l, A, self.b, c).to(u.device)
        return k
    

# ----------------
# Shift Matrix SSM
# ----------------
class ShiftSSM(CompanionSSM):
    r"""
    Open-loop implementation of Shift SSM:
    -> y_t = \sum_{i = 0}^{t} C S^{t - i} B u_i
       where S is shift matrix

    `a` is fixed to zero and `b` to the first basis vector (both registered as
    buffers), so only `c` is trained and the kernel is `c` itself.
    """
    def __init__(self, norm_order: int=0, **kwargs):
        super().__init__(norm_order=norm_order, **kwargs)

    def init_weights(self):
        """Re-register `a` and `b` as constant buffers and `c` as a parameter."""
        super().init_weights()  # Initializes skip connection, B, C matrices
        # A column initialized in super().init_weights(), but now we zero-out
        a = torch.zeros(self.n_kernels, self.kernel_dim)
        self.register("a", a, trainable=False, lr=self.lr, wd=self.weight_decay)

        # B Matrix - make it not learnable by default
        b = torch.zeros(self.n_kernels, self.kernel_dim)
        b[:, 0] = 1.
        self.register("b", b, trainable=False, lr=self.lr, wd=self.weight_decay)

        # C matrix
        c = self.init_kernel_weights(self.kernel_init, self.identity_kernel)
        self.register("c", c, trainable=True, lr=self.lr, wd=self.weight_decay)

    def get_kernel(self, u, c=None, l=None):
        """
        Return `c` zero-padded along its last dimension to length `l`.

        Short-cut to powering up C A^t B where A is shift, B is first basis
        vector, where this just ends up being zero-padded C matrix.

        Args:
            u: Input, assumed to be of shape B x D x L.
            c: Optional replacement for `self.c`.
            l: Optional kernel length; defaults to `u.shape[-1]`.
        """
        l = u.shape[-1] if l is None else l
        c = self.c if c is None else c
        return F.pad(c, (0, l - c.shape[-1]))
    
# ------------    
# Conv1D "SSM"
# ------------
class Conv1dSSM(SSM):
    """
    Conv1D "SSM": replaces the FFT convolution with an `nn.Conv1d`.

    The convolution has `kernel_dim` taps, no bias and the default `groups`,
    so every output channel sees every input channel. It is padded by
    kernel_dim - 1 on both sides and only the first L outputs are kept, so
    output position t depends only on inputs at positions <= t.

    NOTE: `forward` here does not use `self.bidirectional`.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def init_weights(self):
        """Create the Conv1d kernel in addition to the parent's weights."""
        super().init_weights()  # Initializes skip connection
        input_dim = self.n_kernels * self.n_heads * self.head_dim
        self.kernel = nn.Conv1d(in_channels=input_dim,
                                out_channels=input_dim,
                                kernel_size=self.kernel_dim,
                                stride=1,
                                padding=self.kernel_dim-1,
                                bias=False)

    def forward(self, u):
        """Apply the convolution to `u` of shape (B, L, D); returns (B, L, D)."""
        l = u.shape[1]    # Assume u.shape = (b, l, d)
        u = rearrange(u, 'b l d -> b d l')
        y = self.kernel(u)[:, :, :l]  # Only retain same len outputs

        if self.skip_connection:
            y = y + oe.contract('b d l, d -> b d l', u, self.skip)

        return rearrange(y, 'b d l -> b l d')

        
# ----------------
# Helper functions
# ----------------
def krylov(L, A, b, c=None, return_power=False):
    """
    Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick.

    Args:
        L: Number of Krylov columns to compute; must be >= 1.
        A: Tensor of shape (..., N, N).
        b: Tensor of shape (..., N).
        c: Optional tensor of shape (..., N). If given, the state dimension is
            contracted with it.
        return_power: If True, return A^{L-1} as well.

    Returns:
        Tensor of shape (..., N, L), or (..., L) when `c` is given. With
        `return_power=True`, a tuple of that tensor and A^{L-1}.

    Raises:
        ValueError: if L < 1.
    """
    # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises
    if L < 1:
        raise ValueError(f"krylov requires L >= 1, got L={L}")

    x = b.unsqueeze(-1)  # (..., N, 1)
    A_ = A

    AL = None
    if return_power:
        AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
        _L = L-1

    done = L == 1
    # loop invariant: _L represents how many indices left to compute
    while not done:
        if return_power:
            # Binary exponentiation: fold in the current power of A when the
            # matching bit of L-1 is set
            if _L % 2 == 1: AL = A_ @ AL
            _L //= 2

        # Save memory on last iteration
        l = x.shape[-1]
        if L - l <= l:
            done = True
            _x = x[..., :L-l]
        else:
            _x = x

        # Shape mismatches surface here as the underlying torch error
        _x = A_ @ _x
        x = torch.cat([x, _x], dim=-1)  # there might be a more efficient way of ordering axes
        if not done: A_ = A_ @ A_

    # Internal invariant of the doubling loop above
    assert x.shape[-1] == L, f"krylov produced {x.shape[-1]} columns, expected {L}"

    if c is not None:
        x = torch.einsum('...nl, ...n -> ...l', x, c)
    x = x.contiguous()  # Original author flagged this call as important ("WOW!!")
    if return_power:
        return x, AL
    else:
        return x
