import math
from math import log

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.nn.components import LinearActivation, Activation, Normalization, DropoutNd
from einops import rearrange, repeat
import opt_einsum as oe
from functools import partial

from models.nn.components import NonLinear


class OurModule(nn.Module):
    """``nn.Module`` base class providing a ``register`` helper for tensors."""

    def __init__(self):
        super().__init__()

    def register(self, name, tensor, trainable=False, lr=None, wd=None):
        """Attach ``tensor`` to this module under attribute ``name``.

        Args:
            name: attribute name to register.
            tensor: the tensor to store.
            trainable: if True, wrap it in ``nn.Parameter`` and register it as a
                parameter; otherwise register it as a buffer.
            lr: optional learning-rate hint (trainable tensors only).
            wd: optional weight-decay hint (trainable tensors only).

        If an attribute called ``name`` already exists and the registration
        raises ``KeyError``, that attribute is deleted and registration retried.
        When ``trainable`` and ``lr``/``wd`` are given, they are stored on the
        parameter as an ``_optim`` dict with keys ``"lr"`` / ``"weight_decay"``.
        """
        if trainable:
            value, attach = nn.Parameter(tensor), self.register_parameter
        else:
            value, attach = tensor, self.register_buffer

        try:
            attach(name, value)
        except KeyError:
            # Name is already taken by some other attribute: replace it.
            delattr(self, name)
            attach(name, value)

        if not trainable:
            return
        hints = {key: val for key, val in (("lr", lr), ("weight_decay", wd)) if val is not None}
        if hints:
            setattr(getattr(self, name), "_optim", hints)

#
# This is intended to match np.convolve(u,v)[:len(u)] along the last axis.
# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j]
# Here y = (u \ast v) on return.
# We assume the inputs are:
# u (B H L)
# v (C H L)
# and we want to produce y that is (B C H L)
#
def fft_conv(u, v):  # y = fft_conv(g,f)
    """Causal (truncated) linear convolution of ``u`` with every kernel in ``v``.

    For every batch b, channel c, feature h and position k < L:

        y[b, c, h, k] = sum_{j=0}^{k} u[b, h, k - j] * v[c, h, j]

    i.e. ``np.convolve(u[b, h], v[c, h])[:L]``. Both inputs are zero-padded to
    length 2L before the FFT so that the circular convolution computed in the
    frequency domain equals the linear convolution on the first L outputs.

    Args:
        u: real tensor of shape (B, H, L).
        v: real tensor of shape (C, H, L) — kernels, same length as ``u``.

    Returns:
        Real tensor of shape (B, C, H, L).
    """
    L = u.shape[-1]
    u_f = torch.fft.rfft(u, n=2 * L)  # (B, H, L + 1), complex
    v_f = torch.fft.rfft(v, n=2 * L)  # (C, H, L + 1), complex
    # Pointwise spectrum product for every (batch, channel) pair:
    # (B, 1, H, F) * (1, C, H, F) -> (B, C, H, F).
    y_f = u_f.unsqueeze(1) * v_f.unsqueeze(0)
    return torch.fft.irfft(y_f, n=2 * L)[..., :L]  # (B, C, H, L)


def get_linear_layer(input_dim, output_dim,
                     initial_weight=None,
                     initial_bias=None,
                     no_grad=False):
    """Create ``nn.Linear(input_dim, output_dim)`` with optional fixed initial values.

    Args:
        input_dim: ``in_features`` of the layer.
        output_dim: ``out_features`` of the layer.
        initial_weight: optional tensor copied into ``layer.weight``. It is cast
            to the weight's tensor type and ``view``-ed to the weight's shape,
            so it must contain exactly ``output_dim * input_dim`` elements.
        initial_bias: optional tensor copied into ``layer.bias`` the same way.
        no_grad: if True, set ``requires_grad = False`` on weight and bias.

    Returns:
        The configured ``nn.Linear`` module.
    """
    layer = nn.Linear(input_dim, output_dim)
    with torch.no_grad():
        for target, source in ((layer.weight, initial_weight),
                               (layer.bias, initial_bias)):
            if source is None:
                continue
            target.copy_(source.type(target.type()).view(*target.shape))
        if no_grad:
            for target in (layer.weight, layer.bias):
                target.requires_grad = False
    return layer


def get_encoder(input_dim, n_hippos, n_layers=1, activation='none',
                initialize_identity=False, no_grad=False):
    """Build the input encoder starting with ``Linear(input_dim, n_hippos)``.

    Args:
        input_dim: number of input features.
        n_hippos: number of output features.
        n_layers: total number of Linear layers; each extra layer is
            ``activation`` followed by ``nn.Linear(n_hippos, n_hippos)``.
            Only honoured when ``initialize_identity`` is False.
        activation: 'relu', 'gelu', or anything else for identity.
        initialize_identity: initialise the first layer with all-ones weights
            and zero bias, so every output equals the sum of the inputs (a
            copy of the input when ``input_dim == 1``).
        no_grad: freeze the first layer's weight and bias.

    Returns:
        ``nn.Sequential`` containing the layers.
    """
    if activation == 'relu':
        activation = nn.ReLU()
    elif activation == 'gelu':
        activation = nn.GELU()
    else:
        activation = nn.Identity()

    if initialize_identity:
        # nn.Linear(in, out) stores weight as (out, in) and bias as (out,).
        # A bias built as (input_dim, n_hippos) cannot be viewed as
        # (n_hippos,) unless input_dim == 1, so build it with the right shape.
        initial_weight = torch.ones((n_hippos, input_dim))
        initial_bias = torch.zeros(n_hippos)
        layers = [get_linear_layer(input_dim, n_hippos,
                                   initial_weight,
                                   initial_bias,
                                   no_grad=no_grad)]
    else:
        layers = [get_linear_layer(input_dim, n_hippos,
                                   None, None, no_grad)]
        # NOTE(review): extra layers are only appended on this branch, so
        # n_layers is ignored when initialize_identity=True — confirm intent.
        for _ in range(n_layers - 1):
            layers.append(activation)
            layers.append(nn.Linear(n_hippos, n_hippos))

    return nn.Sequential(*layers)


def get_decoder(output_dim, n_hippos, n_layers=1, activation='none'):
    """Build the output decoder ending in ``nn.Linear(n_hippos, output_dim)``.

    Args:
        output_dim: number of output features.
        n_hippos: width of the input and of every hidden layer.
        n_layers: total number of Linear layers; the ``n_layers - 1`` extra
            ``nn.Linear(n_hippos, n_hippos)`` layers come first, each followed
            by ``activation``.
        activation: 'relu', 'gelu', or anything else for identity. A single
            activation module instance is reused at every position.

    Returns:
        ``nn.Sequential``: [Linear, act, ..., Linear, act, Linear(n_hippos, output_dim)].
    """
    act = (nn.ReLU() if activation == 'relu'
           else nn.GELU() if activation == 'gelu'
           else nn.Identity())

    # The output projection is created first and the hidden layers after it;
    # this fixes the order in which parameter initialisation draws from the RNG.
    head = nn.Linear(n_hippos, output_dim)
    hidden = []
    for _ in range(n_layers - 1):
        hidden += [act, nn.Linear(n_hippos, n_hippos)]

    # Modules were collected output-first; flip them into forward order.
    return nn.Sequential(*reversed([head, *hidden]))



# COPIED FROM s4_simple.py
class SimpleS4(OurModule):
    """Simplified S4-style layer whose convolution kernel is a sum of damped cosines.

    Maps an input ``u`` of shape (B, H, L), with H == ``nHippos``, to an output
    of shape (B, C*H, L), with C == ``channels`` (the channel axis is flattened
    into the feature axis on return).

    Constructor arguments:
        nHippos: number of independent input features H.
        d_state: state size; the kernel uses ``d_state // 2`` modes per feature.
        channels: number of output channels C per input feature.
        use_initial: learn separate ``b``, ``c`` and an initial state ``x0``
            whose response is added to the output; otherwise learn only the
            fused coefficient ``q`` (as if ``x0 = 0``).
        zero_order_hold: use ``zoh_method`` instead of ``quadrature_method``.
        trap_rule: in ``quadrature_method``, use trapezoid-rule endpoint
            weights (order 2) instead of a plain Riemann sum (order 1).
        dt_min, dt_max: set the range of the per-feature scale ``h_scale``.
        lr: optional learning-rate hint passed to ``register`` for ``theta``
            and ``a`` (stored only when they are trainable).
        learn_a, learn_theta: make ``a`` / ``theta`` trainable parameters
            instead of buffers.
        theta_scale: multiply the frequencies ``theta`` by ``h_scale``.
        unconstrained_a: use ``a`` as-is as the real part of the eigenvalues
            instead of ``-|a|`` (honoured by both forward methods).
        skip_connection: add a learned per-channel skip term ``u * D``.
        **kernel_args: accepted for interface compatibility and ignored.
    """

    def __init__(self,
                 nHippos,
                 d_state=64,
                 channels=1,
                 use_initial=True, # Use the initial state?
                 zero_order_hold=False, # Use zero-order hold approximation
                 trap_rule=True, # Use the trapezoid rule
                 dt_min=0.001,
                 dt_max=0.1,
                 lr=None, # Hook to set LR of SSM parameters differently
                 learn_a=True,
                 learn_theta=True,
                 theta_scale=False,
                 unconstrained_a=False,  # Don't force real(A) < 0
                 skip_connection=True,
                 **kernel_args,):
        super().__init__()
        # Shape legend used in the comments below:
        #   B = batch, C = channels, H = number of hippos (input features),
        #   D = number of modes per hippo (d_state // 2), L = sequence length.
        self.h = nHippos
        self.d = d_state // 2
        self.channels = channels
        self.use_initial = use_initial
        self.zero_order_hold = zero_order_hold
        # quadrature_method: trapezoid-rule endpoint weights (order 2) if True,
        # otherwise a plain Riemann sum (order 1).
        self.trap_rule = trap_rule
        # Don't restrict A s.t. real(A) < 0
        self.unconstrained_a = unconstrained_a

        _fp = (self.channels, self.h, self.d)  # (C, H, D)

        # Chebyshev initialization
        h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min))
        angles = torch.arange(self.d)*torch.pi
        t_scale = h_scale if theta_scale else torch.ones(self.h)
        theta = oe.contract('c,h,d->chd', torch.ones(self.channels), t_scale, angles)
        a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d)

        self.register("theta", theta, learn_theta, lr=lr, wd=None)
        self.register("a", a, learn_a, lr=lr, wd=None)
        # Per-channel skip-connection weights, shape (C, H).
        if skip_connection:
            self.D = nn.Parameter(torch.randn(channels, self.h))

        # Easier for inheritance
        self.lr = lr
        self.dt_min = dt_min
        self.dt_max = dt_max
        self.theta_scale = theta_scale
        self.d_state = d_state
        self.learn_theta = learn_theta
        self.learn_a = learn_a
        self.skip_connection = skip_connection

        if use_initial:
            self.b = nn.Parameter(torch.randn(*_fp))
            self.c = nn.Parameter(torch.randn(*_fp))
            self.x0 = nn.Parameter(torch.randn(*_fp))
        else:
            # This is an optimization that we combine q = c * b
            # It's as if we're setting x0 = 0.
            self.q = nn.Parameter(torch.randn(*_fp))

    def _decay_rate(self):
        """Real part of the eigenvalues: ``a`` if unconstrained, else ``-|a|`` (C, H, D)."""
        return self.a if self.unconstrained_a else -self.a.abs()

    def zoh_method(self, u):
        """Apply the layer using a zero-order-hold discretisation of the kernel.

        Args:
            u: input of shape (B, H, L); L must be >= 2 (step size is 1/(L-1)).

        Returns:
            Tensor of shape (B, C*H, L).
        """
        l = u.size(-1)
        T = 1/(l-1)  # step size on the unit interval
        zk = T*torch.arange(l, device=u.device).view(1,1,-1,1)  # (1, 1, L, 1)
        _a = self._decay_rate()
        ls = torch.complex(_a, self.theta)  # (C, H, D)
        # Integral of exp(ls * t) for t in [0, T].
        term_0 = (torch.exp(ls*T) - 1)/ls
        base_term = (2*term_0.unsqueeze(2)*torch.exp(ls.unsqueeze(2) * zk)).real  # (C, H, L, D)
        q = self.b*self.c if self.use_initial else self.q
        f = (q.unsqueeze(2)*base_term).sum(-1)  # (C, H, L)
        y = fft_conv(u, f)  # (B, C, H, L)
        # self.D only exists when skip_connection=True; using it
        # unconditionally raised AttributeError otherwise.
        if self.skip_connection:
            y = y + oe.contract('bhl,ch->bchl', u, self.D)
        if self.use_initial:
            # This the cosine formula from the note
            cos_term = 2*T*torch.exp(_a.unsqueeze(2) * zk)*torch.cos(self.theta.unsqueeze(2) * zk)
            y = y + (2*(self.c*self.x0).unsqueeze(2)*cos_term).sum(-1)
        return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels.

    # Endpoint weights for each quadrature order (interior weights are 1).
    order2const = {1 : [], 2: [1/2], 3: [5/12,13/12], 4: [3/8,7/6,23/24]}

    def numerical_quadrature(self, f, g, order=3):
        """Convolve ``g`` with ``f`` using endpoint-corrected quadrature weights.

        Starts from the plain convolution y[k] = sum_j g[k-j] f[j] and then,
        for each constant c_i in ``order2const[order]``, reweights the i-th term
        from each end of every sum (f[i]*g[k-i] and f[k-i]*g[i]) from 1 to c_i.
        For very small k the two ends coincide and both corrections apply.

        Args:
            f: kernel of shape (C, H, L) with the step size T already
               multiplied in.
            g: input of shape (B, H, L).
            order: key of ``order2const`` (1 = no correction, 2 = trapezoid,
               3 and 4 = higher-order extended formulas).

        Returns:
            Tensor of shape (B, C, H, L).
        """
        # Extended closed formulas (Numerical Recipes in C, section 4.1,
        # http://www.foo.be/docs-free/Numerical_Recipe_In_C/c4-1.pdf):
        # order 2 = T*[1/2*f_1  + f_2 + ... + f_{n-1} + 1/2*f_n]
        # order 3 = T*[5/12*f_1 + 13/12*f_2 + f_3 + ... + f_{n-2} + 13/12*f_{n-1} + 5/12*f_n]
        # order 4 = T*[3/8*f_1  + 7/6*f_2 + 23/24*f_3 + ... + f_{n-3} + 23/24*f_{n-2} + 7/6*f_{n-1} + 3/8*f_n]
        # Only the endpoint weights differ from 1, so interior terms are plain
        # sums (unlike composite Simpson, which reweights interior samples).
        # "Order" is the error order, e.g. O(N^{-3}) for order 3.
        #
        # BE WARNED: this relies on the argument order (f kernel, g input) and
        # on T having been pre-multiplied into f.
        y = fft_conv(g, f)
        # y[k] = T*sum_{j} g[k-j]f[j] = T*sum_{j} h_k[j]

        # Drop the last j samples so that element k-j lines up with output k >= j.
        def _roll(h, j): return h[..., :-j] if j > 0 else h

        for i, c in enumerate(self.order2const[order]):
            # For outputs k >= i: term = f[i]*g[k-i] + f[k-i]*g[i]
            term = oe.contract('ch,bhl -> bchl', f[..., i], _roll(g, i))
            term += oe.contract('chl,bh -> bchl', _roll(f, i), g[..., i])
            y[..., i:] += (c-1)*term  # Note: f is premultiplied with T.
        return y

    def quadrature_method(self, u):
        """Apply the layer by numerically integrating the damped-cosine kernel.

        From the note: f[k] = 2 sum_{j=1}^{d} q_j e^{a_j z_k} cos(z_k * theta_j).

        Args:
            u: input of shape (B, H, L); L must be >= 2 (step size is 1/(L-1)).

        Returns:
            Tensor of shape (B, C*H, L).
        """
        l = u.size(-1)
        T = 1/(l-1)  # the step size
        zk = T*torch.arange(l, device=u.device).view(1,1,-1,1)  # (1, 1, L, 1)
        _a = self._decay_rate()
        try:
            # Body of the sum, (C, H, L, D); T is folded in here.
            base_term = 2*T*torch.exp(_a.unsqueeze(2) * zk)*torch.cos(
                self.theta.unsqueeze(2) * zk)
        except Exception as e:
            print('_a.unsqueeze(2).shape:', _a.unsqueeze(2).shape)
            print('self.theta.unsqueeze(2).shape:', self.theta.unsqueeze(2).shape)
            print('zk.shape:', zk.shape)
            raise e
        q = self.b*self.c if self.use_initial else self.q
        f = (q.unsqueeze(2)*base_term).sum(-1)  # (C, H, L)

        # Convolve along L: (B, H, L) input with (C, H, L) kernel -> (B, C, H, L).
        # NB: T is incorporated into f!
        y = self.numerical_quadrature(f, u, order=2 if self.trap_rule else 1)

        # Add in the skip connection with per-channel D matrix
        if self.skip_connection:
            y = y + oe.contract('bhl,ch->bchl', u, self.D)
        # Add back the initial state
        if self.use_initial:
            y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1)
        return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels.

    def forward(self, u):
        """Dispatch to ``zoh_method`` or ``quadrature_method``; (B, H, L) -> (B, C*H, L)."""
        if self.zero_order_hold:
            return self.zoh_method(u)
        return self.quadrature_method(u)
    
    
class S4SimpleModel(nn.Module):
    """Linear projection -> SimpleS4 -> GELU -> linear projection back.

    Input and output are (B, L, input_dim). ``output_proj`` expects ``dim``
    features, so this presumably assumes SimpleS4 is built with channels=1
    (its output has channels * dim features) — TODO confirm.
    """

    def __init__(self, dim, *args, input_dim=768, **kwargs):
        super().__init__()
        # Registration order (and therefore parameter order) is
        # input_proj, ssm, output_proj.
        self.input_proj = nn.Linear(input_dim, dim)
        self.ssm = SimpleS4(dim, *args, **kwargs)
        self.output_proj = nn.Linear(dim, input_dim)

    def forward(self, x):
        """x: (B, L, D)
        """
        # SimpleS4 works on (B, features, L), so swap axes around it.
        features_first = rearrange(self.input_proj(x), 'b l d -> b d l')
        mixed = rearrange(self.ssm(features_first), 'b d l -> b l d')
        return self.output_proj(F.gelu(mixed))

    
class SimpleS4Wrapper(nn.Module):
    """Sequence layer: SimpleS4 followed by a NonLinear output map.

    ``forward`` takes (B, H, L) when ``transposed`` is True and (B, L, H)
    otherwise, returns the output in the same layout, and passes ``state``
    through unchanged. Bidirectional mode is not implemented.
    """
    def __init__(
            self,
            d_model,
            d_state=64,
            channels=1,
            bidirectional=False,
            dropout=0.0,
            transposed=True, # True: (B, D, L); False: (B, L, D)
            ln=True, # forwarded to NonLinear as its `ln` option
            postact=None, # activation after FF
            activation='gelu', # activation in between SS and FF
            initializer=None, # initializer on FF
            weight_norm=False, # weight normalization on FF
            linear=False, # if True, skip the NonLinear map (identity output)
            # SSM Kernel arguments
            **kernel_args,
        ):
        super().__init__()
        self.h = d_model
        self.d = d_state
        self.channels = channels
        self.out_d = self.h
        self.transposed = transposed
        self.bidirectional = bidirectional
        assert not bidirectional, f"Bidirectional NYI"
        self.s4 = SimpleS4(nHippos=d_model, d_state=d_state,
                           channels=channels, **kernel_args)
        # NonLinear is constructed even when `linear` is True; in that case it
        # is discarded and the output map is the identity. It always receives
        # (B, C*H, L), hence transposed=True.
        nonlinear = NonLinear(self.h, channels=self.channels, ln=ln,
                              dropout=dropout, postact=postact,
                              activation=activation, transposed=True,
                              initializer=initializer, weight_norm=weight_norm)
        self.out = nonlinear if not linear else nn.Identity()

    def forward(self, u, state=None):
        """Run SimpleS4 and the output map; returns ``(output, state)``."""
        # SimpleS4 always consumes (B, H, L).
        x = u if self.transposed else u.transpose(-1, -2)
        y = self.out(self.s4(x))  # s4 output is (B, C*H, L)
        if not self.transposed:
            y = y.transpose(-1, -2)
        return y, state

    @property
    def d_state(self):
        """Reported state size: ``d_model * d_state``."""
        return self.h * self.d

    @property
    def d_output(self):
        """Reported output feature size (``d_model``)."""
        return self.out_d