import torch
import torch.nn as nn
import torch.nn.functional as F
import util


class FINAL_MLP_PLIF(nn.Module):
    """Two-layer fully connected spiking network with an accumulating readout.

    Data flow of one simulation step (as wired in ``forward``)::

        x -> fc1 -> pPLIF1 -> fc2 -> pPLIF2 -> AvgPool1d(10, 10) -> pPLI

    The neuron nodes come from the project-local ``util`` module; their exact
    dynamics are defined there and are not restated here.  Only the last of
    the ``num_steps`` simulation steps is executed with autograd enabled.
    """

    def __init__(
        self,
        num_steps: int = 5,
        init_tau: float = 2.0,  # membrane decaying time constant
        init_spk_trace_tau: float = 0.5,  # spike trace decaying time constant
        init_acc_tau: float = 2.0,  # accumulative membrane decaying time constant
        init_parametric_tau=True,
        init_version="v3",
        init_decay_acc=True,
        scale=1000,
    ):
        """Build the layers, neuron nodes and decay constants.

        Args:
            num_steps: Number of simulation steps run per ``forward`` call.
            init_tau: Initial value of ``tau_0`` (learnable) and ``tau_vector``;
                both are passed through a sigmoid in ``forward``.
            init_spk_trace_tau: Initial value of ``spk_trace_tau_vector``.
            init_acc_tau: Initial value of ``acc_tau``.
            init_parametric_tau: Forwarded to ``util.pPTRACE_Node`` as
                ``parametric_tau``.
            init_version: Forwarded to ``util.pPTRACE_Node`` as ``version``.
            init_decay_acc: Forwarded to ``util.pPLI_Node`` as ``decay_acc``.
            scale: Width of the first hidden layer.
        """
        super().__init__()
        self.num_steps = num_steps
        self.scale = scale

        # Layer 1: 784 input features -> `scale` hidden units.
        self.fc1 = nn.Linear(784, scale, bias=False)
        self.pPLIF1 = util.pPLIF_Node(surrogate_function=util.HeavisideBoxcarCall())

        # Spike-trace node; not used by this class's forward, but kept so that
        # subclasses sharing this constructor can rely on it.
        self.pPTRACE1 = util.pPTRACE_Node(
            parametric_tau=init_parametric_tau,
            version=init_version,
            surrogate_function1=util.SpkTraceBinaryCall,
            surrogate_function2=util.SpkTraceModifiedCall,
        )

        # Layer 2: `scale` hidden units -> 100 units.
        self.fc2 = nn.Linear(scale, 100, bias=False)
        self.pPLIF2 = util.pPLIF_Node(surrogate_function=util.HeavisideBoxcarCall())

        # Averages groups of 10 neighbouring units: 100 -> 10 outputs.
        self.boost1 = nn.AvgPool1d(10, 10)

        self.pPLI = util.pPLI_Node(
            decay_acc=init_decay_acc, surrogate_function=util.AccAlwaysGradCall()
        )

        # Learnable time constant of layer 1.  Parameters follow the module
        # when it is moved with `.to(...)`, so no explicit device move is needed.
        self.tau_0 = nn.Parameter(torch.ones(1, dtype=torch.float) * init_tau)

        # Fixed time constants.  `Tensor.to` is not in-place, so the original
        # discarded `.to(...)` calls had no effect and the tensors stayed on the
        # CPU.  Registering them as buffers makes them follow `model.to(device)`;
        # `persistent=False` keeps the state_dict keys unchanged.
        self.register_buffer(
            "tau_vector", torch.ones(1, dtype=torch.float) * init_tau, persistent=False
        )
        self.register_buffer(
            "spk_trace_tau_vector",
            torch.ones(1, dtype=torch.float) * init_spk_trace_tau,
            persistent=False,
        )
        self.register_buffer(
            "acc_tau", torch.ones(1, dtype=torch.float) * init_acc_tau, persistent=False
        )

    def _step(
        self, x, h1_mem, h1_spike, h2_mem, h2_spike, acc_mem, decay_0, decay_1, acc_decay
    ):
        """Advance every layer by one simulation step and return the new state."""
        h1_mem, h1_spike = self.pPLIF1(h1_mem.detach(), h1_spike.detach(), decay_0, self.fc1(x))
        h2_mem, h2_spike = self.pPLIF2(
            h2_mem.detach(), h2_spike.detach(), decay_1, self.fc2(h1_spike)
        )
        # Pool the *second* layer's spikes (100 -> 10) so the result matches the
        # 10-wide accumulator and layer 2 actually contributes to the output.
        boost1 = self.boost1(h2_spike.unsqueeze(1)).squeeze(1)
        acc_mem = self.pPLI(acc_mem.detach(), acc_decay, boost1)
        return h1_mem, h1_spike, h2_mem, h2_spike, acc_mem

    def forward(self, x):
        """Simulate ``num_steps`` steps on a static input.

        Args:
            x: Input batch; everything after dim 0 is flattened and must
                amount to 784 features per sample (the input width of ``fc1``).

        Returns:
            Tuple ``(acc_mem, num_steps)``: the ``(batch, 10)`` accumulator
            state after the last step and the number of configured steps.
            The caller is expected to apply softmax / cross-entropy itself.
        """
        self.device = x.device
        batch_size = x.size(0)
        x = x.flatten(1)

        # Separate tensors per state variable (no aliasing between them).
        h1_mem = torch.zeros(batch_size, self.scale, device=self.device)
        h1_spike = torch.zeros_like(h1_mem)
        h2_mem = torch.zeros(batch_size, self.fc2.out_features, device=self.device)
        h2_spike = torch.zeros_like(h2_mem)
        acc_mem = torch.zeros(batch_size, 10, device=self.device)

        decay_0 = torch.sigmoid(self.tau_0)
        # `tau_vector` holds a single entry, so layer 2 reads index 0
        # (index 1 was out of range).
        decay_1 = torch.sigmoid(self.tau_vector)[0]
        acc_decay = torch.sigmoid(self.acc_tau)

        state = (h1_mem, h1_spike, h2_mem, h2_spike, acc_mem)

        # All steps but the last run without building an autograd graph.
        with torch.no_grad():
            for _ in range(self.num_steps - 1):
                state = self._step(x, *state, decay_0, decay_1, acc_decay)

        # Final step runs under the caller's grad mode.
        state = self._step(x, *state, decay_0, decay_1, acc_decay)

        return state[-1], self.num_steps


class FINAL_MLP_et_PLIF(FINAL_MLP_PLIF):
    """Variant of ``FINAL_MLP_PLIF`` whose second layer is driven by a spike trace.

    Data flow of one simulation step (as wired in ``forward``)::

        x -> fc1 -> pPLIF1 -> pPTRACE1 -> fc2 -> pPLIF2 -> AvgPool1d(10, 10) -> pPLI

    ``pPTRACE1`` returns two tensors; the second one is what is fed to ``fc2``.
    The trace dynamics themselves live in the project-local ``util`` module.
    """

    def _trace_step(
        self,
        x,
        h1_mem,
        h1_spike,
        h2_spike_trace,
        h2_mem,
        h2_spike,
        acc_mem,
        decay_0,
        decay_1,
        trace_decay,
        acc_decay,
    ):
        """Advance every layer by one simulation step and return the new state."""
        h1_mem, h1_spike = self.pPLIF1(h1_mem.detach(), h1_spike.detach(), decay_0, self.fc1(x))

        # Trace of the layer-1 spikes; its second output is layer 2's input.
        h2_spike_trace, h2_spike_trace_b = self.pPTRACE1(
            h2_spike_trace.detach(), trace_decay, h1_spike
        )
        h2_mem, h2_spike = self.pPLIF2(
            h2_mem.detach(), h2_spike.detach(), decay_1, self.fc2(h2_spike_trace_b)
        )

        # Pool the *second* layer's spikes (100 -> 10) so the result matches the
        # 10-wide accumulator and layer 2 actually contributes to the output.
        boost1 = self.boost1(h2_spike.unsqueeze(1)).squeeze(1)
        acc_mem = self.pPLI(acc_mem.detach(), acc_decay, boost1)
        return h1_mem, h1_spike, h2_spike_trace, h2_mem, h2_spike, acc_mem

    def forward(self, x):
        """Simulate ``num_steps`` steps on a static input.

        Args:
            x: Input batch; everything after dim 0 is flattened and must
                amount to 784 features per sample (the input width of ``fc1``).

        Returns:
            Tuple ``(acc_mem, num_steps)``: the ``(batch, 10)`` accumulator
            state after the last step and the number of configured steps.
            The caller is expected to apply softmax / cross-entropy itself.
        """
        self.device = x.device
        device = x.device
        batch_size = x.size(0)
        x = x.flatten(1)

        # Separate tensors per state variable (no aliasing between them).
        h1_mem = torch.zeros(batch_size, self.scale, device=device)
        h1_spike = torch.zeros_like(h1_mem)
        # The trace follows the layer-1 spikes and feeds fc2, so it must be
        # `scale` wide (not 100 like the layer-2 state).
        h2_spike_trace = torch.zeros_like(h1_mem)
        h2_mem = torch.zeros(batch_size, self.fc2.out_features, device=device)
        h2_spike = torch.zeros_like(h2_mem)
        acc_mem = torch.zeros(batch_size, 10, device=device)

        decay_0 = torch.sigmoid(self.tau_0)
        # The fixed time constants are moved to the input's device here so the
        # computation does not mix devices; `.to` is a no-op when they already
        # live there.  `tau_vector` holds a single entry, so layer 2 reads
        # index 0 (index 1 was out of range).
        decay_1 = torch.sigmoid(self.tau_vector.to(device))[0]
        trace_decay = torch.sigmoid(self.spk_trace_tau_vector.to(device))[0]
        acc_decay = torch.sigmoid(self.acc_tau.to(device))

        state = (h1_mem, h1_spike, h2_spike_trace, h2_mem, h2_spike, acc_mem)
        decays = (decay_0, decay_1, trace_decay, acc_decay)

        # All steps but the last run without building an autograd graph.
        with torch.no_grad():
            for _ in range(self.num_steps - 1):
                state = self._trace_step(x, *state, *decays)

        # Final step runs under the caller's grad mode.
        state = self._trace_step(x, *state, *decays)

        return state[-1], self.num_steps
