from abc import abstractmethod
import time
from typing import Callable, overload
from torch.utils.checkpoint import get_device_states, set_device_states
import torch
import torch.nn as nn
from . import surrogate, base
from functions import Calculate_fgrad_SNN, recombine_gradient
import numpy as np

class BaseNode(base.MemoryModule):
    """Reversible two-half spiking neuron base class.

    The last dimension of the input and of the membrane potential ``self.v`` is
    split into two halves that are processed with additive coupling::

        m1 = charge(x1, v1);  y1 = fire(m1) + beta * x2
        m2 = charge(y1, v2);  y2 = fire(m2) + beta * x1
        v1' = reset(y1, m1) + alpha * v1
        v2' = reset(y2, m2) + alpha * v2

    ``inverse`` recovers ``(x1, x2)`` and ``(v1, v2)`` from ``(y1, y2)`` and
    ``(v1', v2')``. NOTE(review): the halves are combined element-wise, so the
    last dimension presumably has even size -- confirm with callers.

    Subclasses must implement :meth:`neuronal_charge`.
    """

    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False,
                 tau: float = 2.,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 theta: float = 0.1):
        """Create the neuron.

        Args:
            v_threshold: firing threshold subtracted before the surrogate.
            v_reset: hard-reset potential, or ``None`` for soft reset
                (subtract ``v_threshold``).
            surrogate_function: callable mapping ``m - v_threshold`` to spikes.
            detach_reset: detach the spike used in the reset term.
            tau: membrane time constant used by subclasses and ``inverse``.
            alpha: weight of the previous potential added after reset.
            beta: coupling weight of the opposite input half.
            theta: stored on the instance; not used by this class.
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # The membrane starts at the reset value (0. when soft reset is used).
        self.register_memory('v', 0. if v_reset is None else v_reset)
        self.register_memory('v_threshold', v_threshold)
        self.register_memory('v_reset', v_reset)
        self.tau = tau
        self.alpha = alpha
        self.beta = beta
        self.theta = theta

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor, v):
        """Return the pre-spike membrane potential for input ``x`` and potential ``v``."""
        raise NotImplementedError

    def neuronal_fire(self, m):
        """Apply the surrogate spike function to ``m - v_threshold``."""
        return self.surrogate_function(m - self.v_threshold)

    def neuronal_reset(self, spike, m):
        """Reset the membrane potential ``m`` after emitting ``spike``.

        Soft reset (``v_reset is None``) subtracts ``spike * v_threshold``;
        otherwise ``m`` is blended toward ``v_reset`` by ``spike``. With
        ``detach_reset`` the spike used here carries no gradient.
        """
        spike_d = spike.detach() if self.detach_reset else spike

        if self.v_reset is None:
            return m - spike_d * self.v_threshold
        return (1. - spike_d) * m + spike_d * self.v_reset  # default

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'

    def _membrane_halves(self):
        """Return the two halves of ``self.v`` (the same scalar twice if it is not a tensor)."""
        if torch.is_tensor(self.v):
            halves = torch.chunk(self.v, 2, dim=-1)
            return halves[0], halves[1]
        return self.v, self.v

    def _coupled_step(self, x: torch.Tensor):
        """Run one coupled two-half step without touching ``self.v``.

        Returns:
            ``(spike, (v_1, v_2))``: the concatenated output and the two
            updated membrane halves.
        """
        x_halves = torch.chunk(x, 2, dim=-1)
        v_prev_1, v_prev_2 = self._membrane_halves()

        m_1 = self.neuronal_charge(x_halves[0], v_prev_1)
        spike_1 = self.neuronal_fire(m_1) + self.beta * x_halves[1]
        v_1 = self.neuronal_reset(spike_1, m_1) + self.alpha * v_prev_1

        # The second half is driven by the first half's output (coupling).
        m_2 = self.neuronal_charge(spike_1, v_prev_2)
        spike_2 = self.neuronal_fire(m_2) + self.beta * x_halves[0]
        v_2 = self.neuronal_reset(spike_2, m_2) + self.alpha * v_prev_2

        return torch.cat((spike_1, spike_2), dim=-1), (v_1, v_2)

    def forward(self, x: torch.Tensor):
        """Compute the coupled output and store the updated membrane in ``self.v``."""
        spike, (v_1, v_2) = self._coupled_step(x)
        self.v = torch.cat((v_1, v_2), dim=-1)
        return spike

    def inverse(self, y: torch.Tensor):
        """Reconstruct the input and the previous membrane potential from output ``y``.

        ``self.v`` must be a tensor holding the potential produced by the
        forward step that emitted ``y``. It is replaced by the reconstructed
        previous potential. NOTE(review): the closed-form inversion assumes the
        charge rule ``m = v + (x - v) / tau`` (LIF with decay_input=True and
        v_reset == 0) and hard reset. Other charge rules are not inverted exactly.

        Returns:
            ``(x, 1, 2, 3, 4)`` -- the reconstructed input followed by four
            placeholder values that are not real derivatives. Only ``x`` is
            used by InvertibleCheckpointFunction.

        Raises:
            NotImplementedError: if ``v_reset`` is None (soft reset).
        """
        if self.v_reset is None:
            raise NotImplementedError(
                'BaseNode.inverse only supports hard reset (v_reset must not be None)')

        self.v = self.v.detach()
        self.v.requires_grad = True

        v_halves = torch.chunk(self.v, 2, dim=-1)
        y_halves = torch.chunk(y, 2, dim=-1)
        y_1, y_2 = y_halves[0], y_halves[1]
        inv_tau = 1 / self.tau
        leak = 1 - inv_tau

        # Undo v2' = (1 - y2) * (leak * v2 + y1 / tau) + y2 * v_reset + alpha * v2.
        v_2 = (v_halves[1] - (1 - y_2) * inv_tau * y_1 - y_2 * self.v_reset) / ((1 - y_2) * leak + self.alpha)
        m_2 = self.neuronal_charge(y_1, v_2)
        m_2.requires_grad_()
        # y2 = fire(m2) + beta * x1  =>  x1 = (y2 - fire(m2)) / beta
        x_1 = (y_2 - self.neuronal_fire(m_2)) / self.beta

        # Same inversion for the first half, now that x1 is known.
        v_1 = (v_halves[0] - (1 - y_1) * inv_tau * x_1 - y_1 * self.v_reset) / ((1 - y_1) * leak + self.alpha)
        m_1 = self.neuronal_charge(x_1, v_1)
        m_1.requires_grad_()
        # y1 = fire(m1) + beta * x2  =>  x2 = (y1 - fire(m1)) / beta
        x_2 = (y_1 - self.neuronal_fire(m_1)) / self.beta

        self.v = torch.cat((v_1, v_2), dim=-1)
        x = torch.cat((x_1, x_2), dim=-1)

        # NOTE(review): placeholder values kept for the return-tuple shape; not derivatives.
        dy2_dx2, dy2_dx1, dy1_dx1, dy1_dx2 = 1, 2, 3, 4
        return x, dy2_dx2, dy2_dx1, dy1_dx1, dy1_dx2

    def forward_keep_v(self, x: torch.Tensor):
        """Compute the same output as ``forward`` but leave ``self.v`` unchanged."""
        spike, _ = self._coupled_step(x)
        return spike

class LIFNode(BaseNode):
    """Leaky integrate-and-fire neuron on top of the reversible two-half BaseNode."""

    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, cupy_fp32_inference=False, alpha=0.5, beta=0.5, theta=0.5):
        """Create an LIF neuron.

        Args:
            tau: membrane time constant; must be a float greater than 1.
            decay_input: if True the input is scaled by ``1 / tau`` while charging.
            cupy_fp32_inference: accepted for signature compatibility; unused here.
            Remaining arguments are forwarded to BaseNode.
        """
        assert isinstance(tau, float) and tau > 1.
        super().__init__(v_threshold=v_threshold, v_reset=v_reset,
                         surrogate_function=surrogate_function, detach_reset=detach_reset,
                         tau=tau, alpha=alpha, beta=beta, theta=theta)
        self.decay_input = decay_input
        self.tau = tau

    def extra_repr(self):
        base_repr = super().extra_repr()
        return f'{base_repr}, tau={self.tau}'

    def neuronal_charge(self, x: torch.Tensor, v):
        """Leak ``v`` toward the reset level and integrate input ``x``."""
        # A non-zero hard-reset level shifts the leak target away from 0.
        shifted_reset = self.v_reset is not None and self.v_reset != 0.

        if not self.decay_input:
            if shifted_reset:
                return v - (v - self.v_reset) / self.tau + x
            return v * (1. - 1. / self.tau) + x

        if shifted_reset:
            return v + (x - (v - self.v_reset)) / self.tau
        return v + (x - v) / self.tau

    def forward(self, x: torch.Tensor):
        """Delegate to BaseNode.forward."""
        return super().forward(x)

    def inverse(self, x: torch.Tensor):
        """Delegate to BaseNode.inverse (``x`` is the layer output to invert)."""
        return super().inverse(x)
        
# RevSNN BaseBlock
class RevSNNLayer(nn.Module):
    """Wrap a spiking layer in an InvertibleModuleWrapper that does not keep its input.

    With the wrapper's default ``disable=True`` the layer is simply called directly.
    """

    def __init__(self, SNNLayer):
        super().__init__()
        wrapped = InvertibleModuleWrapper(keep_input=False, fn=SNNLayer)
        self.nn = wrapped

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped layer."""
        wrapped = self.nn
        return wrapped(*args, **kwargs)

class InvertibleModuleWrapper(nn.Module):
    """Optionally run ``fn`` through InvertibleCheckpointFunction.

    When ``disable`` is True (the default) ``fn`` is called directly. Otherwise
    ``fn`` must expose ``forward``, ``inverse`` and ``forward_keep_v``. Its
    parameters that require grad are appended after the inputs in the call to
    InvertibleCheckpointFunction.apply. A 1-tuple result is unwrapped.
    """

    def __init__(
        self, fn, keep_input=False, keep_input_inverse=False, num_bwd_passes=1, disable=True, preserve_rng_state=False
    ):
        super().__init__()
        self._fn = fn
        self.disable = disable
        self.keep_input = keep_input
        self.keep_input_inverse = keep_input_inverse  # stored; not read by this class
        self.num_bwd_passes = num_bwd_passes
        self.preserve_rng_state = preserve_rng_state

    def forward(self, *xin):
        if self.disable:
            result = self._fn(*xin)
        else:
            trainable = tuple(p for p in self._fn.parameters() if p.requires_grad)
            result = InvertibleCheckpointFunction.apply(
                self._fn.forward,
                self._fn.inverse,
                self._fn.forward_keep_v,
                self.keep_input,
                self.num_bwd_passes,
                self.preserve_rng_state,
                len(xin),
                *xin,
                *trainable,
            )

        is_singleton = isinstance(result, tuple) and len(result) == 1
        return result[0] if is_singleton else result

class InvertibleCheckpointFunction(torch.autograd.Function):
    """Autograd function that frees its input after forward and rebuilds it in backward.

    ``forward`` runs ``fn`` without recording a graph. Unless ``keep_input`` is
    set, it then frees the storage of the first input. ``backward`` restores
    the inputs by applying ``fn_inverse`` to the stored outputs. It then re-runs
    ``fn_keep_v`` with grad enabled to obtain gradients for the inputs and for
    the trailing ``weights`` tensors.
    """

    @staticmethod
    def forward(ctx, fn, fn_inverse, fn_keep_v, keep_input, num_bwd_passes, preserve_rng_state, num_inputs, *inputs_and_weights):
        """Run ``fn`` on the first ``num_inputs`` tensors; the rest are treated as weights."""
        ctx.fn = fn
        ctx.fn_inverse = fn_inverse
        ctx.fn_keep_v = fn_keep_v
        ctx.keep_input = keep_input
        ctx.weights = inputs_and_weights[num_inputs:]
        ctx.num_bwd_passes = num_bwd_passes
        ctx.preserve_rng_state = preserve_rng_state
        ctx.num_inputs = num_inputs
        inputs = inputs_and_weights[:num_inputs]
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*inputs)

        # Non-tensor inputs never require grad (they are passed through unchanged below).
        ctx.input_requires_grad = [
            isinstance(element, torch.Tensor) and element.requires_grad for element in inputs
        ]

        with torch.no_grad():
            x = []
            for element in inputs:
                if isinstance(element, torch.Tensor):
                    # NOTE(review): presumably re-allocates storage freed by an earlier
                    # pass; if it was freed, the contents are undefined -- confirm intent.
                    element.storage().resize_(int(np.prod(element.size())))
                    x.append(element.detach())
                else:
                    x.append(element)
            outputs = ctx.fn(*x)

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        detached_outputs = tuple([element.detach_() for element in outputs])

        # Free the input's memory; backward reconstructs it through fn_inverse.
        if not ctx.keep_input:
            inputs[0].storage().resize_(0)

        ctx.inputs = [inputs] * num_bwd_passes
        ctx.outputs = [detached_outputs] * num_bwd_passes

        return detached_outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Reconstruct the inputs, recompute with grad, and return all gradients.

        Returns ``None`` for the seven non-tensor leading arguments of
        ``forward``. It then returns one gradient per input (``None`` where the
        input does not require grad), followed by one gradient per weight.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "InvertibleCheckpointFunction is not compatible with .grad(), please use .backward() if possible"
            )
        if len(ctx.outputs) == 0:
            raise RuntimeError(
                "Trying to perform backward on the InvertibleCheckpointFunction for more than "
                "{} times! Try raising `num_bwd_passes` by one.".format(ctx.num_bwd_passes)
            )
        inputs = ctx.inputs.pop()
        outputs = ctx.outputs.pop()

        if not ctx.keep_input:
            rng_devices = []
            if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
                rng_devices = ctx.fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
                if ctx.preserve_rng_state:
                    torch.set_rng_state(ctx.fwd_cpu_state)
                    if ctx.had_cuda_in_fwd:
                        set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)

                with torch.no_grad():
                    # fn_inverse returns a tuple whose first element is the reconstructed input.
                    inputs_inverted = ctx.fn_inverse(*(outputs + inputs[1:]))[0]
                    # The stored outputs are no longer needed; release their memory.
                    for element in outputs:
                        element.storage().resize_(0)

                    if not isinstance(inputs_inverted, tuple):
                        inputs_inverted = (inputs_inverted,)
                    # Re-allocate the freed input storage and point it at the reconstruction.
                    for element_original, element_inverted in zip(inputs, inputs_inverted):
                        element_original.storage().resize_(int(np.prod(element_original.size())))
                        element_original.set_(element_inverted)

        with torch.set_grad_enabled(True):
            detached_inputs = tuple(
                element.detach() if isinstance(element, torch.Tensor) else element
                for element in inputs
            )
            for det_input, requires_grad in zip(detached_inputs, ctx.input_requires_grad):
                if isinstance(det_input, torch.Tensor):
                    det_input.requires_grad = requires_grad
            temp_output = ctx.fn_keep_v(*detached_inputs)
        if not isinstance(temp_output, tuple):
            temp_output = (temp_output,)

        filtered_detached_inputs = tuple(
            element for element, rg in zip(detached_inputs, ctx.input_requires_grad) if rg
        )
        gradients = torch.autograd.grad(
            outputs=temp_output, inputs=filtered_detached_inputs + ctx.weights, grad_outputs=grad_outputs
        )

        # Map the flat gradient tuple back onto per-input slots (None where no grad).
        input_gradients = []
        i = 0
        for rg in ctx.input_requires_grad:
            if rg:
                input_gradients.append(gradients[i])
                i += 1
            else:
                input_gradients.append(None)
        if input_gradients and input_gradients[0] is not None:
            input_gradients[0] = torch.nan_to_num(input_gradients[0], nan=0)
        # Everything after the input gradients belongs to the weights; autograd
        # requires one returned gradient per forward argument.
        weight_gradients = tuple(gradients[i:])
        return (None, None, None, None, None, None, None) + tuple(input_gradients) + weight_gradients
