import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import power, arange, log2


def scan(x, c, h0):
    """Evaluate the linear recurrence ``h[t] = c[t] * h[t-1] + x[t]`` along dim 0.

    Two modes are selected from the arguments:

    * Parallel (``h0 is None`` or ``x`` has more than one step along dim 0):
      a log-depth prefix scan with doubling strides 1, 2, 4, ... < T.  The
      recurrence starts from a zero state, so ``h0`` is NOT used in this mode.
      NOTE(review): if a non-zero ``h0`` must be honoured for T > 1, fold
      ``c[:1] * h0`` into ``x[:1]`` before scanning -- TODO confirm intent.
    * Sequential (``h0`` given and at most one step): ``x + h0 * c``.

    Args:
        x: Additive input term; dim 0 is the scanned (time) axis.
        c: Multiplicative coefficients; same length as ``x`` along dim 0 and
            broadcast-compatible with it.
        h0: Previous state, or ``None``.

    Returns:
        The state at every step, with the same length as ``x`` along dim 0.
    """
    steps = x.shape[0]
    if h0 is not None and steps <= 1:
        # sequential scan: one recurrence step from the carried state
        return x + h0 * c

    # parallel scan.  Integer doubling replaces int(log2(T - 1) + 1), which
    # raised OverflowError for T == 1 (log2(0) == -inf) and failed for T == 0.
    # After the round with stride p, x[t] accumulates the last 2p inputs up to
    # t (weighted by the intervening coefficients) and c[t] holds the product
    # of those 2p coefficients; the first p entries are already final.
    stride = 1
    while stride < steps:
        x = torch.cat([x[:stride], x[stride:] + x[:-stride] * c[stride:]], dim=0)
        c = torch.cat([c[:stride], c[stride:] * c[:-stride]], dim=0)
        stride *= 2
    return x


class ATan(torch.autograd.Function):
    """Heaviside spike with an arctan-shaped surrogate gradient.

    Forward: ``1.0`` where ``mem > 0`` and ``0.0`` elsewhere (float32 output).
    Backward: the step's zero/undefined derivative is replaced by
    ``1 / (1 + (pi * mem) ** 2)``, the derivative of ``atan(pi * mem) / pi``.

    NOTE: intentionally not decorated with ``@torch.compile``.  Decorating an
    ``autograd.Function`` subclass does not compile the ``forward``/``backward``
    reached through ``ATan.apply`` (only ``__call__`` is wrapped in recent
    PyTorch), and on some older releases it can replace the class with a plain
    wrapper function that has no ``apply`` attribute.
    """

    @staticmethod
    def forward(ctx, mem):
        # keep the pre-threshold potential; the surrogate is evaluated on it
        ctx.save_for_backward(mem)
        return (mem > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (mem,) = ctx.saved_tensors
        surrogate = 1 / (1 + (torch.pi * mem).pow(2))
        return surrogate * grad_output


class ParallelLIF(nn.Module):
    """Leaky integrate-and-fire neuron whose membrane is integrated with ``scan``.

    The membrane follows ``mem[t] = tau * mem[t-1] + (1 - tau) * u[t]`` with a
    learnable per-channel decay ``tau = sigmoid(beta)``.  Spikes come from the
    ATan surrogate step function.

    Args:
        d_hidden: Number of channels (size of the input's last dimension).
        tau: Initial decay; must lie in (0, 1) so its logit is finite.
        gate: Only ``None`` (the default LIF gate) is implemented.
        th: Firing threshold (keyword-only; default 0.5, the previously
            hard-coded value).
    """

    def __init__(self, d_hidden, tau=0.5, gate=None, *, th=0.5):
        super().__init__()
        self.gate = gate
        # store the decay as a logit so sigmoid(beta) always stays in (0, 1)
        self.beta = nn.Parameter(torch.logit(torch.ones(d_hidden) * tau), requires_grad=True)
        self.th = th
        self.step_func = ATan.apply

    def forward(self, u, mem=None):
        """Integrate ``u`` over time and emit spikes.

        Args:
            u: Input current; dim 0 is the time axis scanned by ``scan``.
            mem: Previous membrane state, or ``None`` to start from zero.
                ``scan`` only uses it when ``u`` has a single time step.

        Returns:
            ``(spk, mem)``: binary spikes and the membrane potential at every
            time step.

        Raises:
            NotImplementedError: if ``gate`` is not ``None`` (previously this
                surfaced as an ``UnboundLocalError`` on ``spk``).
        """
        if self.gate is not None:
            raise NotImplementedError(
                f"ParallelLIF: gate={self.gate!r} is not implemented; use gate=None (LIF gate)"
            )

        # broadcast the per-channel decay to u's full shape so scan can slice it along time
        tau = torch.sigmoid(self.beta) * torch.ones_like(u)
        mem = scan(u * (1 - tau), tau, h0=mem)

        if mem.shape[0] > 1:
            # parallel training: stochastic threshold, uniform in (th - 0.5, th + 0.5].
            # One draw per element of u.shape[-2:] (batch x channel for a (T, B, D)
            # input -- TODO confirm layout), reused at every time step via broadcasting.
            noise = torch.rand(u.shape[-2:], device=u.device)
            spk = self.step_func(mem - self.th - 0.5 + noise)
        else:
            # sequential rollout: deterministic threshold
            spk = self.step_func(mem - self.th)
        return spk, mem


class LIFGateUnit(nn.Module):
    """Complex-valued gated linear recurrence whose gates are LIF spikes.

    Per time step, elementwise over ``hidden_size`` channels:

    * ``write_gate``, ``read_gate``: binary spikes from a ParallelLIF driven by
      a real projection of ``x``;
    * ``proj_c``, ``proj_u``: complex projections of ``x``;
    * ``decay``: ``proj_c`` rescaled so its modulus stays below 1;
    * state: ``h[t] = (write_gate * decay + 1 - write_gate) * h[t-1] + write_gate * proj_u``,
      so an open gate decays by ``decay`` and adds ``proj_u`` while a closed
      gate holds the state;
    * readout: ``read_gate * h + (1 - read_gate) * proj_u``, flattened to
      real/imag parts, projected to ``output_size`` and layer-normalised.

    Args:
        input_size: Feature size of ``x``.
        hidden_size: Number of complex state channels.
        output_size: Feature size of the output.
        num_layers, batch_first, bias: accepted for signature compatibility
            but unused.  NOTE(review): ``x`` is always treated as time-major.
        gate: Forwarded to ParallelLIF.

    The packed state ``mem`` holds, along its last dim,
    ``[LIF membrane (2*hidden) | Re h (hidden) | Im h (hidden)]``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, batch_first=False, bias=True, gate=None):
        super().__init__()
        self.hidden_size = hidden_size
        width = 2 * hidden_size

        # Keep this construction order: nn.Linear initialisation draws from the
        # global RNG, so reordering would change the initial weights.
        self.gate_func = ParallelLIF(width, gate=gate)
        self.linear_gate = nn.Linear(input_size, width, bias=False)
        self.linear_in = nn.Linear(input_size, width, dtype=torch.complex64, bias=False)
        self.linear_out = nn.Linear(width, output_size, bias=False)
        self.layer_norm = nn.LayerNorm(output_size, elementwise_affine=False)

    def forward(self, x, mem=None):
        """Run the unit over ``x`` (named (T, B, features) in the original code).

        Returns ``(z, mem)``: the normalised output and the packed state for
        every time step (last dim ``4 * hidden_size``).
        """
        _seq_len, _batch, _ = x.size()  # unpacking also rejects non-3D input
        hid = self.hidden_size

        if mem is None:
            mem = torch.zeros(1, 4 * hid, device=x.device)
        lif_mem, re_state, im_state = mem.split([2 * hid, hid, hid], dim=-1)
        state = torch.complex(re_state, im_state)

        # spike gates
        spikes, lif_mem = self.gate_func(self.linear_gate(x), lif_mem)
        write_gate, read_gate = spikes.split(hid, dim=-1)

        # complex inputs to the recurrence
        proj_c, proj_u = self.linear_in(x.cfloat()).split(hid, dim=-1)

        # squash |proj_c| into the unit disk to avoid gradient explosion
        radius = torch.sqrt(proj_c.real ** 2 + proj_c.imag ** 2 + 1)
        decay = proj_c / radius * torch.tanh(radius)

        state = scan(proj_u * write_gate, decay * write_gate + (1 - write_gate), h0=state)
        readout = state * read_gate + proj_u * (1 - read_gate)

        z = self.layer_norm(self.linear_out(torch.view_as_real(readout).flatten(-2)))
        packed = torch.cat([lif_mem, state.real, state.imag], dim=-1)
        return z, packed
    
class StackedLIFGate(nn.Module):
    """Stack of ``num_layers`` LIFGateUnit layers applied in sequence.

    Layer 0 maps ``input_size -> output_size``; every later layer maps
    ``output_size -> output_size``.  ``batch_first`` and ``bias`` are accepted
    for signature compatibility but unused.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, batch_first=False, bias=True):
        super().__init__()
        self.num_layers = num_layers
        units = []
        for idx in range(num_layers):
            in_features = input_size if idx == 0 else output_size
            units.append(LIFGateUnit(in_features, hidden_size, output_size))
        self.rnns = nn.ModuleList(units)

    def forward(self, x, mem=None):
        """Feed ``x`` through every layer, threading each layer's state.

        Args:
            x: Input to the first layer.
            mem: States from a previous call, or ``None``.  It is cut along
                dim 0 into rows of size 1, and row ``i`` goes to layer ``i``.

        Returns:
            ``(z, mem)``: output of the last layer and all layer states
            concatenated along dim 0.

        NOTE(review): each layer returns its state for every time step, so for
        multi-step input the returned ``mem`` has ``num_layers * T`` rows and
        row ``i`` is no longer layer ``i``'s final state.  Feeding it back is
        only consistent for single-step rollout; confirm intended usage.
        """
        per_layer = [None] * self.num_layers if mem is None else torch.split(mem, 1, dim=0)
        hidden = x
        collected = []
        for idx in range(self.num_layers):
            hidden, layer_state = self.rnns[idx](hidden, per_layer[idx])
            collected.append(layer_state)
        return hidden, torch.cat(collected, dim=0)
