# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron


class LIFLayer(neuron.LIFNode):
    """LIF neuron layer that unrolls a whole time-major input in one call.

    ``forward`` returns ``(spikes, rate)``:

    * When ``rate_flag`` is set and the module is training, ``rate`` has the
      forward value ``spikes.mean(0)``. Its gradient w.r.t. the incoming
      ``rate`` is scaled by an eligibility trace accumulated while unrolling
      (see ``_forward_rate``).
    * Otherwise ``rate`` is an all-zero tensor shaped like one time step.
    """

    # Slope of the sigmoid surrogate used for d(spike)/d(v) in the eligibility
    # trace.
    _SG_ALPHA = 4.0

    def __init__(self, timestep=0, **cell_args):
        """Build the layer.

        Args:
            timestep: number of simulation steps T; must be > 0.
            **cell_args: required keys are ``decay`` (tensor or number; its
                sigmoid gives the leak), ``thresh`` (firing threshold) and
                ``detach_reset``. The optional key ``rate_flag`` defaults
                to False.

        Raises:
            AssertionError: if ``timestep`` is not positive.
            KeyError: if a required key is missing from ``cell_args``.
        """
        assert timestep > 0, 'the number of time steps should be specified'

        decay = cell_args['decay']
        if not torch.is_tensor(decay):
            # Accept plain Python numbers as well as tensors.
            decay = torch.tensor(float(decay))
        # The leak sigmoid(decay) corresponds to LIFNode's (1 - 1/tau).
        tau = 1.0 / (1.0 - torch.sigmoid(decay)).item()

        # Initialise the parent exactly once.
        super().__init__(tau=tau, decay_input=False, v_threshold=cell_args['thresh'], v_reset=None,
                         detach_reset=cell_args['detach_reset'], step_mode='s')
        self.timestep = timestep
        self.rate_flag = cell_args.get("rate_flag", False)

    def forward(self, x, rate=None):
        """Run the neuron over ``self.timestep`` steps of ``x`` (``x[t]`` is step t).

        Args:
            x: input current indexed by time along dim 0.
            rate: required in rate mode; must have shape ``x.shape[1:]``.

        Returns:
            ``(spikes, rate_out)``. ``spikes`` stacks the per-step outputs
            along dim 0.
        """
        if self.rate_flag and self.training:
            return self._forward_rate(x, rate)
        return self._forward_plain(x)

    def _charge_and_fire(self, x_t):
        """Advance the membrane by one step and return the emitted spike (no reset)."""
        self.v_float_to_tensor(x_t)
        self.neuronal_charge(x_t)
        return self.neuronal_fire()

    def _forward_plain(self, x):
        """Unroll without eligibility tracking; the second output is zeros."""
        self.reset()
        spikes = []
        for t in range(self.timestep):
            spike = self._charge_and_fire(x[t])
            spikes.append(spike)
            self.neuronal_reset(spike)

        out = torch.stack(spikes, dim=0)
        return out, torch.zeros_like(out[0])

    def _forward_rate(self, x, rate):
        """Unroll while accumulating the eligibility trace used for the rate gradient."""
        assert x.shape[0] == self.timestep
        assert rate is not None
        assert x.shape[1:] == rate.shape

        self.reset()
        spikes = []

        self.post_elig = 0.
        self.post_elig_factor = 1.

        # Leak factor; LIFLayer always derives from LIFNode, so tau is defined.
        lam = 1.0 - 1. / self.tau

        for t in range(self.timestep):
            spike = self._charge_and_fire(x[t])

            # Sigmoid surrogate derivative d(spike)/d(v), evaluated at the
            # pre-reset membrane potential.
            sgax = ((self.v - self.v_threshold) * self._SG_ALPHA).sigmoid()
            sg = (1. - sgax) * sgax * self._SG_ALPHA

            spikes.append(spike)

            # Running mean over steps 0..t of post_elig_factor * sg.
            self.post_elig = 1. / (t + 1) * (t * self.post_elig + self.post_elig_factor * sg)

            if self.v_reset is not None:  # hard-reset
                # self.v is still the pre-reset membrane here. NOTE(review): the
                # -v * sg term assumes v_reset == 0.
                self.post_elig_factor = 1. + self.post_elig_factor * (lam * (1. - spike) - lam * self.v * sg)
            elif not self.detach_reset:  # soft-reset w/ detach_reset == False
                self.post_elig_factor = 1. + self.post_elig_factor * (lam - lam * sg)
            else:  # soft-reset w/ detach_reset == True
                self.post_elig_factor = 1. + self.post_elig_factor * lam

            self.neuronal_reset(spike)

        out = torch.stack(spikes, dim=0)
        gu = self.post_elig.clone().detach()

        # Straight-through: the forward value is the spike rate; the gradient
        # w.r.t. `rate` is scaled element-wise by the eligibility `gu`.
        rate = out.mean(dim=0).clone().detach() + (rate * gu) - (rate * gu).detach()
        return out, rate


class WrappedSNNOp(nn.Module):
    """Apply a stateless layer to a ``[T, B, ...]`` spike tensor and a ``[B, ...]`` rate tensor.

    Spikes are processed for all time steps at once by merging the T and B
    dimensions. In training mode with ``rate_flag`` set:

    * spikes pass through ``op`` without gradient tracking;
    * the rate tensor passes through ``op`` with gradient tracking;
    * for ``nn.BatchNorm1d``/``nn.BatchNorm2d`` the rate is normalised with
      the batch statistics of the spike tensor (forward value), while the
      gradient flows through the rate's own statistics (straight-through).
    """

    def __init__(self, op, time_step, rate_flag: bool):
        """
        Args:
            op: the layer to wrap.
            time_step: number of time steps T (leading dim of the spike input).
            rate_flag: whether to propagate the rate tensor in training mode.
        """
        super(WrappedSNNOp, self).__init__()
        self.op = op
        self.time_step = time_step
        self.rate_flag = rate_flag

        assert self.op is not None and self.time_step is not None

    def _apply_per_step(self, x):
        """Apply ``op`` to every time step at once by merging the first two dims."""
        x = x.contiguous()
        y = self.op(x.view(-1, *x.shape[2:]))
        return y.view(self.time_step, -1, *y.shape[1:])

    def _batch_norm_rate(self, x, rate):
        """Normalise ``rate`` using the batch statistics of the detached spikes ``x``.

        Statistics are reduced over every dim except channels (dim 1). The
        forward value is ``gamma * (rate - mean_x) / sqrt(var_x + eps) + beta``.
        Gradients flow as if the mean/var had been computed from ``rate``.
        """
        # (0,) for [N, C], (0, 2) for [N, C, L], (0, 2, 3) for [N, C, H, W].
        dims = (0,) + tuple(range(2, rate.dim()))

        with torch.no_grad():
            x1 = x.contiguous()
            x1 = x1.view(-1, *x1.shape[2:])
            x_mean = x1.mean(dim=dims, keepdim=True)
            x_var = ((x1 - x_mean) ** 2).mean(dim=dims, keepdim=True)

        rate_mean = rate.mean(dim=dims, keepdim=True)
        rate_mean = x_mean.detach() + (rate_mean - rate_mean.detach())

        rate_var = ((rate - rate_mean) ** 2).mean(dim=dims, keepdim=True)
        rate_var = x_var.detach() + (rate_var - rate_var.detach())
        rate_var = nn.functional.relu(rate_var)

        rate_hat = (rate - rate_mean) / torch.sqrt((rate_var + self.op.eps))
        if not self.op.affine:
            # weight/bias are None when the layer was built with affine=False.
            return rate_hat

        affine_shape = (1, -1) + (1,) * (rate.dim() - 2)
        gamma = self.op.weight.view(affine_shape)
        beta = self.op.bias.view(affine_shape)
        return gamma * rate_hat + beta

    def forward(self, x: torch.Tensor, rate, **kwargs):
        """Return ``(spikes_out, rate_out)``.

        Outside rate training, ``rate`` is ignored and ``rate_out`` is zeros
        shaped like one time step of the output. Extra ``kwargs`` are ignored.
        """
        if not self.training or not self.rate_flag:
            out = self._apply_per_step(x)
            return out, torch.zeros_like(out[0])

        if isinstance(self.op, (nn.BatchNorm1d, nn.BatchNorm2d)):
            assert self.time_step == x.shape[0]
            with torch.no_grad():
                # In training mode this also updates the layer's running stats
                # from the spike batch.
                out = self._apply_per_step(x)
            rate = self._batch_norm_rate(x.detach(), rate)
        else:
            with torch.no_grad():
                out = self._apply_per_step(x)
            rate = self.op(rate)

        return out.detach(), rate


def wrap_model(model, time_step, rate_flag):
    """Recursively replace supported stateless layers of ``model`` in place.

    Children that are ``nn.Linear``, ``nn.Conv2d``, ``nn.BatchNorm2d``,
    ``nn.Flatten``, ``nn.AvgPool2d``, ``nn.AdaptiveAvgPool2d`` or ``nn.Dropout``
    are replaced by ``WrappedSNNOp(child, time_step=..., rate_flag=...)``.
    ``LIFLayer`` children are kept as they are. Every other child is recursed
    into. For leaf modules with no children (e.g. activations) that recursion
    is a no-op, so those modules are left unwrapped.

    Args:
        model: module whose children are rewritten in place.
        time_step: forwarded to every ``WrappedSNNOp``.
        rate_flag: forwarded to every ``WrappedSNNOp``.
    """
    wrappable = (
        nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.Flatten, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.Dropout,
    )
    for name, module in model.named_children():
        if isinstance(module, wrappable):
            # Replacing the value of an existing key does not disturb iteration.
            setattr(model, name, WrappedSNNOp(module, time_step=time_step, rate_flag=rate_flag))
        elif not isinstance(module, LIFLayer):
            # The previous `len(children) >= 0` guard was always true, so its
            # `raise NotImplementedError()` branch could never run.
            wrap_model(module, time_step=time_step, rate_flag=rate_flag)


class SequentialModule(nn.Sequential):
    """``nn.Sequential`` whose children consume and return ``(input, rate)`` pairs."""

    def forward(self, input, rate=None, **kwargs):
        """Thread ``(input, rate)`` through every child in registration order.

        Each child is called as ``child(input, rate, **kwargs)`` and must return
        a 2-item pair, which becomes the input to the next child. With no
        children, ``(input, rate)`` is returned unchanged.
        """
        x, r = input, rate
        for layer in self:
            x, r = layer(x, r, **kwargs)
        return x, r


def affine_forward_hook(module, args, output):
    """Forward hook for modules that return ``(spikes, rate)`` pairs.

    Behaviour depends on the module's state:

    * Not training: leaves the output untouched (returns None).
    * Training with ``module.rate_flag`` set: checks that ``rate`` agrees with
      the time-mean of ``spikes``. They agree if they are element-wise close or
      their cosine similarity exceeds ``1 - 1e-3``. On agreement the output is
      left untouched; on mismatch ``NotImplementedError`` is raised.
    * Training otherwise: replaces the output with
      ``(spikes, spikes.mean(dim=0))``.

    Args:
        module: the hooked module; reads ``training`` and ``rate_flag``.
        args: positional inputs of the forward call (unused).
        output: 2-item ``(spikes, rate)`` returned by the module.
    """
    assert len(output) == 2
    if not module.training:
        return None

    spikes, rate = output
    spike_rate = spikes.mean(dim=0)

    if module.rate_flag:
        with torch.no_grad():
            expected = spike_rate.reshape(1, -1)
            actual = rate.reshape(1, -1)
            # Cosine similarity of two all-zero vectors is 0 (e.g. a layer that
            # emitted no spikes), so test for an exact match first.
            if torch.allclose(expected, actual) or \
                    torch.nn.functional.cosine_similarity(expected, actual) > 1.0 - 1.0e-03:
                return None
        raise NotImplementedError(
            f"rate mismatch in {type(module).__name__}: "
            f"spike mean {expected} vs rate {actual}")

    return spikes, spike_rate
