import math
from re import X
from sre_parse import State
from typing_extensions import Self
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import opt_einsum as oe
import numpy as np

from src.models.ssm.components import LinearActivation, Activation, DropoutNd
from src.ops.krylov import krylov


class OurModule(nn.Module):
    """``nn.Module`` base class adding a helper to register buffers or parameters.

    Tensors registered as trainable parameters through :meth:`register` may carry
    per-tensor optimizer overrides, stored as a dict in their ``_optim`` attribute.
    """

    def __init__(self):
        super().__init__()

    def register(self, name, tensor, trainable=False, lr=None, wd=None):
        """Register ``tensor`` on this module under ``name``.

        Args:
            name: attribute name to register under.
            tensor: tensor to register.
            trainable: if True, wrap ``tensor`` in ``nn.Parameter``; otherwise
                register it as a buffer.
            lr: optional learning-rate override, recorded only when trainable.
            wd: optional weight-decay override, recorded only when trainable.
        """
        if not trainable:
            # Buffers never receive optimizer overrides.
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        overrides = {
            key: value
            for key, value in (("lr", lr), ("weight_decay", wd))
            if value is not None
        }
        if overrides:
            setattr(getattr(self, name), "_optim", overrides)

#
# fft_conv(u, v) is intended to match np.convolve(u, v)[:L] along the last axis,
# where L = u.shape[-1] (i.e. the causal part of the full linear convolution).
# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j]
# Here y = (u \ast v) on return.
# We assume the inputs are:
# u (B H L)
# v (C H L)
# and we want to produce y that is (B C H L)
#


def fft_conv(u, v):
    """Causal linear convolution of ``u`` with ``v`` along the last axis, via FFT.

    Computes y[..., k] = sum_j u[..., k - j] * v[..., j] for 0 <= k < L, where
    L = u.shape[-1]. Both signals are zero-padded to length 2L, so the circular
    FFT convolution coincides with the linear one on the kept prefix.

    Args:
        u: tensor of shape (B, H, L).
        v: tensor of shape (C, H, L).

    Returns:
        Real tensor of shape (B, C, H, L).
    """
    seq_len = u.shape[-1]
    n_fft = 2 * seq_len
    u_hat = torch.fft.rfft(u, n=n_fft)[:, None]   # (B, 1, H, F)
    v_hat = torch.fft.rfft(v, n=n_fft)[None]      # (1, C, H, F)
    # Pointwise product in frequency space == convolution in time.
    return torch.fft.irfft(u_hat * v_hat, n=n_fft)[..., :seq_len]


class SimpleS4(OurModule):
    """Simplified S4-style state-space convolution layer.

    Each of the ``nHippos`` input features is convolved with ``channels`` learned
    kernels of length L; the output has shape (B, channels * nHippos, L). The
    kernel is parameterized according to ``repr``:

    * ``'cont'`` (default): continuous diagonal,
      f[k] = 2*dt * sum_j q_j * exp(-|a_j| z_k) * cos(theta_j z_k);
    * ``'disc'``: discrete diagonal, using |a_j| ** z_k instead of exp(-|a_j| z_k);
    * any value containing ``'comp'``: companion matrix whose last column is the
      L1-normalized ``a`` (``'comp_cheby'`` uses a Chebyshev-style J matrix).

    Shape letters used below: B = batch, C = channels, H = hippos (``self.h``),
    D = ``self.d`` (half of ``d_state``), L = sequence length.
    """

    def __init__(self,
            nHippos,
            d_state=64,
            channels=1,
            use_initial=True, # Use the initial state?
            zero_order_hold=False, # Use zero-order hold approximation
            trap_rule=True, # Use the trapezoid rule (else a plain rectangle sum)
            dt_min=0.001,
            dt_max=0.1,
            lr=None, # Hook to set LR of SSM parameters differently
            learn_a=True,
            learn_theta=True,
            learn_dt=False, # whether to learn separate dt for each hippo
            theta_scale=False,
            skip_connection=True,
            repr='cont', # representation to use: ['cont','disc','comp']
            **kernel_args,): # NOTE(review): extra kernel_args are accepted but unused
        super().__init__()
        self.h = nHippos
        self.d = d_state // 2
        self.channels = channels
        self.use_initial = use_initial
        self.zero_order_hold = zero_order_hold
        # Apply the trapezoid-rule endpoint correction, or just use the plain sum.
        self.trap_rule = trap_rule
        self.repr = repr
        self.learn_dt = learn_dt

        _fp = (self.channels, self.h, self.d)

        # Chebyshev initialization.
        # h_scale grows geometrically from 1 toward dt_max/dt_min across hippos.
        h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min))
        angles = torch.arange(self.d)*torch.pi
        t_scale = h_scale if theta_scale else torch.ones(self.h)
        theta = oe.contract('c,h,d->chd', torch.ones(self.channels), t_scale, angles)
        if 'comp' in self.repr:
            # Companion matrix representation: `a` is the last column ("p" in the notes).
            a = torch.randn(*_fp)
        elif self.repr == 'disc':
            # Discrete diagonal representation.
            a = torch.randn(*_fp).abs()
        else:
            # Default continuous diagonal representation.
            a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d)

        self.register("theta", theta, learn_theta, lr=lr, wd=None)
        self.register("a", a, learn_a, lr=lr, wd=None)

        if self.learn_dt:
            # log_dt uniform in [log(dt_min), log(dt_max)), one per hippo.
            log_dt = torch.rand(self.h) * (
                math.log(dt_max) - math.log(dt_min)
            ) + math.log(dt_min)
            self.register("log_dt", log_dt, True, lr=lr, wd=None)

        # Skip connection: learnable per-(channel, hippo) D, or a frozen zero buffer.
        if not skip_connection:
            self.register("D", torch.zeros((channels, self.h)), False)
        else:
            self.D = nn.Parameter(torch.randn(channels, self.h))

        if use_initial or 'comp' in self.repr:
            self.b = nn.Parameter(torch.randn(*_fp))
            self.c = nn.Parameter(torch.randn(*_fp))
            self.x0 = nn.Parameter(torch.randn(*_fp))
        else:
            # This is an optimization that we combine q = c * b
            # It's as if we're setting x0 = 0.
            self.q = nn.Parameter(torch.randn(*_fp))

    def companion_power_roll(self, n, c, b, p):
        """Return g[..., k] = c^T C(p)^k b for k in [0, n), by explicit iteration.

        C(p) has ones on the sub-diagonal and ``p`` as its last column, so
        C(p) x = shift_down(x) + x[-1] * p.

        Args:
            n: number of powers to compute.
            c, b, p: tensors of shape (C, H, D).

        Returns:
            Tensor of shape (C, H, n) with ``c``'s dtype and device.
        """
        ch, h, _ = c.shape
        g = torch.zeros((ch, h, n), device=c.device, dtype=c.dtype)
        x = b.clone()
        for i in range(n):
            g[:, :, i] = oe.contract('chd,chd->ch', c, x)  # c^T x
            u = x[:, :, -1].unsqueeze(-1) * p               # x[-1] * p
            x = torch.roll(x, shifts=1, dims=2)             # shift down by one
            x[:, :, 0] = 0
            x = x + u
        return g

    def companion_power(self, l, b, p, cheby=False):
        """Build the companion matrix C(p) and return ``krylov(l, C, b)``.

        Args:
            l: sequence length passed to ``krylov``.
            b: tensor of shape (C, H, D).
            p: tensor of shape (C, H, D) added to the matrix's last column.
            cheby: if True, start from a Chebyshev-style J matrix (1/2 on both
                off-diagonals, J[1, 0] = 1) and add ``0.5 * p``; otherwise start
                from the shift matrix (ones on the sub-diagonal) and add ``p``.

        Returns:
            Whatever ``krylov`` returns. NOTE(review): callers in this class index
            it as (C, H, D, L) -- confirm against ``krylov``.
        """
        ch, h, d = p.shape

        if cheby:
            J = 0.5 * (torch.diag(torch.ones(d - 1), 1) + torch.diag(torch.ones(d - 1), -1))
            J[1, 0] = 1
            # einops.repeat may return a broadcast (stride-0) view of J; writing
            # into such a view in place fails, so materialize a real copy first.
            C = repeat(J, 'd1 d2 -> c h d1 d2', c=ch, h=h).to(p.device).contiguous()
            C[:, :, :, -1] = C[:, :, :, -1] + 0.5 * p
        else:
            # Shift matrix: ones on the sub-diagonal.
            C = torch.zeros((ch, h, d, d), device=p.device)
            C[:, :, 1:, :-1] = torch.eye(d - 1)
            C[:, :, :, -1] = C[:, :, :, -1] + p

        # krylov uses repeated squaring (per the original note).
        return krylov(l, C, b)

    def zoh_method(self, u):
        """Apply the layer using a zero-order-hold discretization of the kernel.

        Args:
            u: input of shape (B, H, L); L >= 2 is required (step size 1/(L-1)).

        Returns:
            Tensor of shape (B, C*H, L). Unlike ``quadrature_method``, no state
            is returned.
        """
        l = u.size(-1)
        T = 1/(l-1)
        zk = T*torch.arange(u.size(-1), device=u.device).view(1,1,-1,1)
        ls = torch.complex(-self.a.abs(), self.theta)
        # Exact integral of exp(ls * t) over one step t in [0, T].
        term_0 = (torch.exp(ls*T) - 1)/ls
        base_term = (2*term_0.unsqueeze(2)*torch.exp(ls.unsqueeze(2)*zk)).real
        q = self.b*self.c if self.use_initial else self.q
        f = (q.unsqueeze(2)*base_term).sum(-1)
        y = fft_conv(u, f)
        y = y + oe.contract('bhl,ch->bchl', u, self.D)
        if self.use_initial:
            # This is the cosine formula from the note.
            cos_term = 2*T*torch.exp(-self.a.abs().unsqueeze(2)*zk)*torch.cos(self.theta.unsqueeze(2)*zk)
            y = y + (2*(self.c*self.x0).unsqueeze(2)*cos_term).sum(-1)
        return rearrange(y, 'b c h l-> b (c h) l') # flatten the channels.

    # Endpoint correction weights per quadrature order (see numerical_quadrature).
    order2const = {1 : [], 2: [1/2], 3: [5/12,13/12], 4: [3/8,7/6,23/24]}

    def numerical_quadrature(self, f, g, order=3):
        """Causal convolution y[k] = sum_j g[k-j] f[j] with endpoint-corrected weights.

        Uses the closed extended formulas of Numerical Recipes in C, section 4.1
        (http://www.foo.be/docs-free/Numerical_Recipe_In_C/c4-1.pdf):

          order 2: T*[1/2 f_1 + f_2 + ... + f_{n-1} + 1/2 f_n]   (trapezoid)
          order 3: T*[5/12 f_1 + 13/12 f_2 + f_3 + ... + f_{n-2} + 13/12 f_{n-1} + 5/12 f_n]
          order 4: T*[3/8 f_1 + 7/6 f_2 + 23/24 f_3 + ... + 23/24 f_{n-2} + 7/6 f_{n-1} + 3/8 f_n]

        All interior weights are 1, so the bulk is a plain FFT convolution, and
        only the terms j = i and j = k - i (for i < len(order2const[order])) of
        each sum are re-weighted. Order 1 applies no correction. "Order" refers to
        the error term, e.g. O(N^{-3}) for order 3.

        BE WARNED: the step size T is NOT applied here -- callers must pre-multiply
        it into ``f``. The argument order matters.

        Args:
            f: kernel samples of shape (C, H, L), already scaled by T.
            g: input of shape (B, H, L).
            order: key into ``order2const``.

        Returns:
            Tensor of shape (B, C, H, L).
        """
        y = fft_conv(g, f)

        def _roll(h, j):
            # Drop the last j samples (j == 0 keeps all) so index k - j of the
            # result lines up with output index k after slicing y[..., j:].
            return h[..., :-j] if j > 0 else h

        for i, c in enumerate(self.order2const[order]):
            # Re-weight the j = i and j = k - i terms from 1 to c:
            #   y[k] += (c - 1) * (f[i] g[k-i] + g[i] f[k-i])   for k >= i.
            term = oe.contract('ch,bhl -> bchl', f[..., i], _roll(g, i))
            term += oe.contract('chl,bh -> bchl', _roll(f, i), g[..., i])
            y[..., i:] += (c - 1) * term # Note: f is premultiplied with T.
        return y

    def quadrature_method(self, u):
        """Apply the layer by numerically integrating the continuous convolution.

        Args:
            u: input of shape (B, H, L). Unless ``learn_dt`` is set, the step size
                is 1/(L-1), so L >= 2 is required.

        Returns:
            ``(y, state)`` where y has shape (B, C*H, L). ``state`` is None for the
            diagonal representations; for the companion representation it is the
            (B, C, H, D) product of the last slice of ``companion_power``'s output
            with the last input sample.

        Raises:
            NotImplementedError: for the companion representation with
                ``use_initial=True`` (no initial-state term is defined for it).
        """
        l = u.size(-1)
        if self.learn_dt:
            dt = torch.exp(self.log_dt).view(1,-1,1,1)
        else:
            dt = 1/(l-1) # the step size

        if 'comp' in self.repr:
            if self.use_initial:
                raise NotImplementedError(
                    "use_initial=True is not supported with the companion ('comp') "
                    "representation; construct SimpleS4 with use_initial=False."
                )
            cheby = self.repr == "comp_cheby" # chebychev
            # L1-normalize the companion column and the input/output vectors.
            a_ = self.a/torch.linalg.norm(self.a,ord=1,dim=2).unsqueeze(2)
            b_ = self.b/torch.linalg.norm(self.b,ord=1,dim=2).unsqueeze(2)
            c_ = self.c/torch.linalg.norm(self.c,ord=1,dim=2).unsqueeze(2)

            f = self.companion_power(l, b_, a_, cheby=cheby)
            f_k = f[:,:,:,-1]
            u_k = u[:,:,-1]
            state = oe.contract('c h d, b h -> b c h d', f_k, u_k)
            f = oe.contract('c h d, c h d l -> c h l', c_, f)
        else:
            state = None
            # z_k = k * dt for k in [0, L), shaped to broadcast against (C, H, 1, D).
            zk = dt*torch.arange(l, device=u.device).view(1,1,-1,1)
            # From the note: f[k] = 2 sum_{j=1}^{d} q_j e^{a_j z_k} cos(z_k * theta_j);
            # base_term holds the summand (without q), with dt folded in.
            if self.repr == 'disc':
                # discrete diagonal representation
                a_ = (self.a).abs()
                base_term = 2 * dt * torch.pow(a_.unsqueeze(2), zk) * torch.cos(self.theta.unsqueeze(2) * zk)
            else:
                # continuous diagonal representation
                base_term = 2*dt*torch.exp(-self.a.abs().unsqueeze(2) * zk)*torch.cos(self.theta.unsqueeze(2) * zk)

            q = self.b*self.c if self.use_initial else self.q
            f = (q.unsqueeze(2)*base_term).sum(-1) # (C, H, L)

        y = self.numerical_quadrature(f, u, order = 2 if self.trap_rule else 1)

        # Add in the skip connection with per-channel D matrix.
        y = y + oe.contract('bhl,ch->bchl', u, self.D)
        # Add back the initial state (diagonal representations only; the
        # companion case with use_initial was rejected above).
        if self.use_initial:
            y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1)
        return rearrange(y, 'b c h l-> b (c h) l'), state # flatten the channels.

    def forward(self, u):
        """Apply the layer to ``u`` of shape (B, H, L).

        Returns a tensor from ``zoh_method`` when ``zero_order_hold`` is set,
        otherwise the ``(y, state)`` tuple from ``quadrature_method``.
        """
        return self.zoh_method(u) if self.zero_order_hold else self.quadrature_method(u)

    def setup_step(self):
        """Build the companion matrix ``self.A`` used by :meth:`step`.

        Shift matrix (ones on the sub-diagonal) with the L1-normalized ``a`` added
        to its last column. Must be called before ``step``.
        """
        ch, h, d = self.a.shape
        a_ = self.a/torch.linalg.norm(self.a,ord=1,dim=2).unsqueeze(2)

        C = torch.zeros((ch,h,d,d),device=self.a.device)
        C[:,:,1:,:-1] = torch.eye(d-1)
        C[:,:,:,-1] = C[:,:,:,-1] + a_
        self.A = C

    def default_state(self, *batch_shape):
        """Return a zero state of shape (*batch_shape, H, D) with ``a``'s dtype/device.

        NOTE(review): ``step`` returns states of shape (B, C, H, D); this default has
        no channel axis -- confirm the intended layout when channels > 1.
        """
        state = torch.zeros(*batch_shape, self.h, self.d, dtype=self.a.dtype, device=self.a.device)
        return state

    def step(self, u, state):
        """Advance the companion-form recurrence by one time step.

        Computes next_state = A . state + b_ * u and y = c_^T next_state, where b_
        and c_ are the L1-normalized ``b`` and ``c`` and ``A`` comes from
        ``setup_step``.

        NOTE(review): the state contraction below sums ``A`` over the channel axis c;
        confirm this is intended when channels > 1.

        Args:
            u: input of shape (B, H).
            state: current state (see ``default_state``).

        Returns:
            ``(y, next_state)`` with y of shape (B, C, H).
        """
        b_ = self.b/torch.linalg.norm(self.b,ord=1,dim=2).unsqueeze(2)
        c_ = self.c/torch.linalg.norm(self.c,ord=1,dim=2).unsqueeze(2)

        state_contraction = oe.contract_expression(
            "c h m n, ... h n -> ... h m",
            (self.channels, self.h, self.d, self.d),
            (state.shape),
        )
        next_state = state_contraction(self.A, state) \
                + oe.contract("c h d, b h -> b c h d", b_, u)
        y = oe.contract("c h n, b c h n -> b c h", c_, next_state)
        return y, next_state

# Below here are standard wrapper classes to handle
# (1) Non-linearity
# (2) Integration with the Hippo Code base
class NonLinear(nn.Module):
    """Output head: activation -> (optional) dropout -> linear projection.

    ``LinearActivation`` maps ``h * channels`` features (the flattened SSM
    channels) to ``h`` features, presumably along the last axis -- confirm
    against its definition.
    """

    def __init__(self, h, channels,
                 dropout=0.0,
                 postact=None,  # activation after FF
                 activation='gelu',  # activation in between SS and FF
                 initializer=None,  # initializer on FF
                 ):
        super().__init__()
        # Submodules are created in the same order as before (dropout, activation,
        # projection) so any RNG consumed during construction is unchanged.
        # DropoutNd is used because nn.Dropout2d is bugged in PyTorch 1.11.
        drop_layer = DropoutNd(dropout) if dropout > 0.0 else nn.Identity()
        act_layer = Activation(activation)
        proj_layer = LinearActivation(
            h * channels,
            h,
            initializer=initializer,
            activation=postact,
            activate=True,
        )
        self.f = nn.Sequential(act_layer, drop_layer, proj_layer)

    def forward(self, x):
        # NOTE(review): SimpleS4Wrapper.forward passes (B, L, C*H) here (features
        # last), not (B, H, L) -- confirm DropoutNd masks the intended axis.
        return self.f(x)

class SimpleS4Wrapper(nn.Module):
    """Sequence-model wrapper around :class:`SimpleS4`.

    Takes (B, L, H) inputs, runs the SSM over the (B, H, L) layout, and applies a
    :class:`NonLinear` head (or an identity when ``linear`` is True).
    """

    def __init__(
            self,
            d_model,
            d_state=64,
            channels=1,
            bidirectional=False,
            dropout=0.0,
            postact=None, # activation after FF
            activation='gelu', # activation in between SS and FF
            initializer=None, # initializer on FF
            linear=False,
            # SSM Kernel arguments
            **kernel_args,
        ):
        super().__init__()
        self.h = d_model
        self.d = d_state
        self.channels = channels
        self.out_d = self.h
        self.bidirectional = bidirectional
        assert not bidirectional, "Bidirectional NYI"
        self.s4 = SimpleS4(nHippos=d_model, d_state=d_state,
                           channels=channels, **kernel_args)
        # The head is always constructed (so parameter-init RNG consumption does
        # not depend on `linear`) but only used when `linear` is False.
        # NOTE(review): with linear=True and channels > 1 the output has C*H
        # features while d_output reports H.
        nl = NonLinear(self.h, channels=self.channels,
                       dropout=dropout, postact=postact, activation=activation,
                       initializer=initializer)
        self.out = nn.Identity() if linear else nl

    def forward(self, u, state=None):
        """Apply the layer to ``u`` of shape (B, L, H).

        ``state`` is accepted for interface compatibility but ignored, and no
        state is returned.
        """
        u = u.transpose(-1, -2)  # (B, L, H) -> (B, H, L), the layout SimpleS4 expects
        out = self.s4(u)
        # SimpleS4.forward returns a bare tensor on its zero-order-hold path but a
        # (y, state) tuple on its quadrature path; take y in both cases.
        y = out[0] if isinstance(out, tuple) else out
        y = y.transpose(-1, -2)  # (B, C*H, L) -> (B, L, C*H)
        # TODO: don't return state for now
        return self.out(y)

    def setup_step(self):
        """Prepare the underlying SimpleS4 for recurrent stepping."""
        self.s4.setup_step()

    def step(self, u, state):
        """ Step one time step as a recurrent model. Intended to be used during validation.
        u: (B H)
        state: recurrent state as produced by SimpleS4.step
        Returns: output of the head applied to the flattened (B, C*H) output, and the next state
        """
        y, next_state = self.s4.step(u, state) # (B C H)
        # Skip connection with per-channel D, matching the convolutional path.
        y = y + oe.contract('bh,ch->bch', u, self.s4.D)
        y = rearrange(y, '... c h -> ... (c h)')
        ret = self.out(y)
        return ret, next_state

    def default_state(self, *batch_shape, device=None):
        """Return a zero initial state from the underlying SimpleS4 layer.

        If ``device`` is given, the state is moved to it.
        """
        state = self.s4.default_state(*batch_shape)
        return state if device is None else state.to(device)

    @property
    def d_state(self): return self.h * self.d

    @property
    def d_output(self): return self.out_d
