import copy
import torch
import torch.nn as nn
from .noisy_spike import NoisySpike


class LIF(nn.Module):
    """Layer of leaky integrate-and-fire (LIF) neurons unrolled over time.

    The time axis is the trailing dimension of the input. At every step the
    membrane potential ``u`` is multiplied by ``decay`` and by ``(1 - o)``,
    where ``o`` is the previous step's output, so it resets where ``o == 1``.
    It then integrates the input current. The activation ``act`` is applied
    to ``u - vth`` to produce the step's output.

    Args:
        nb_steps: number of time steps to simulate. Steps at or beyond
            ``x.shape[-1]`` raise ``IndexError``; trailing inputs beyond
            ``nb_steps`` leave zeros in the output.
        Act: spiking activation, either an ``nn.Module`` instance (deep-copied
            so this layer owns its own copy) or an ``nn.Module`` subclass
            (instantiated with no arguments).
        readout: if True, ``forward`` only averages the input over time.
        trainable_thresh: make the firing threshold ``vth`` (initially 0.5) a
            learnable parameter.
        trainable_decay: make the membrane ``decay`` (initially 0.9) a
            learnable parameter.

    Raises:
        TypeError: if ``Act`` is neither an ``nn.Module`` instance nor an
            ``nn.Module`` subclass.
    """

    def __init__(self, nb_steps, Act=NoisySpike, readout=False, trainable_thresh=False, trainable_decay=False):
        super(LIF, self).__init__()
        self.readout = readout
        self.nb_steps = nb_steps
        if isinstance(Act, nn.Module):
            # Copy so that several layers built from one instance do not share it.
            self.act = copy.deepcopy(Act)
        elif isinstance(Act, type) and issubclass(Act, nn.Module):
            self.act = Act()
        else:
            # Previously self.act was silently left unset and forward() failed later.
            raise TypeError(
                "Act must be an nn.Module instance or subclass, got {!r}".format(Act)
            )
        self.decay = 0.9
        self.vth = 0.5

        if trainable_decay:
            self.decay = nn.Parameter(torch.tensor(self.decay, dtype=torch.float), requires_grad=True)

        if trainable_thresh:
            self.vth = torch.nn.Parameter(torch.tensor(self.vth, dtype=torch.float), requires_grad=True)

    def forward(self, x):
        """Simulate the neurons over the last axis of ``x``.

        In readout mode the input is averaged over its last axis. Otherwise the
        result has ``x``'s shape and holds the activation output of each of
        the first ``nb_steps`` time steps.
        """
        if self.readout:
            return torch.mean(x, dim=-1)
        # Promote to at least the default float dtype. float64 inputs then stay
        # float64 instead of being silently downcast into a float32 buffer.
        # half/bfloat16/integer inputs still get a default-dtype state as before.
        dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        u = torch.zeros(x.shape[:-1], device=x.device, dtype=dtype)
        o = torch.zeros_like(u)  # no output before the first step
        out = torch.zeros(x.shape, device=x.device, dtype=dtype)
        for step in range(self.nb_steps):
            # Carry the previous output directly rather than re-reading a
            # slice of the buffer that is being written in place.
            u, o = self.state_update(u, o, x[..., step])
            out[..., step] = o
        return out

    def state_update(self, u, o, i):
        """Advance one time step.

        Args:
            u: membrane potential from the previous step.
            o: output (spikes) from the previous step.
            i: input current for this step.

        Returns:
            Tuple ``(u, o)`` of the new membrane potential and new output.
        """
        # Leak and reset: the potential is scaled by (1 - o), i.e. zeroed where o == 1.
        u = self.decay * u * (1 - o) + i
        o = self.act(u - self.vth)
        return u, o


class tdLayer(nn.Module):
    """Apply a per-frame layer to every time step of a time-last tensor.

    ``layer`` is called on each slice ``x[..., t]`` for ``t < nb_steps``.
    The results are stacked along a new trailing time axis.

    Args:
        layer: module (or callable) applied to a single time slice.
        nb_steps: number of time steps to process.
    """

    def __init__(self, layer, nb_steps):
        super(tdLayer, self).__init__()
        self.nb_steps = nb_steps
        self.layer = layer

    def forward(self, x):
        """Return ``layer(x[..., t])`` for each step, stacked on the last axis.

        Stacking on ``-1`` keeps time as the trailing axis of the *output* even
        when ``layer`` changes the rank of its input (e.g. ``nn.Flatten``).
        Using the input's last index would be out of range, or put time in the
        wrong place, in that case. For rank-preserving layers the two agree.
        """
        frames = [self.layer(x[..., step]) for step in range(self.nb_steps)]
        return torch.stack(frames, dim=-1)


class TemporalBN(nn.Module):
    """Batch normalization applied to each time step of a time-last tensor.

    Every slice ``x[..., t]`` is passed through ``nn.BatchNorm2d``. The
    normalized slices are stacked back along the last axis.

    Args:
        in_channels: number of channels (dim 1) of each per-step slice.
        nb_steps: number of time steps to normalize.
        step_wise: if True, keep a separate BatchNorm2d (own affine parameters
            and running statistics) per step; otherwise share one module.
    """

    def __init__(self, in_channels, nb_steps, step_wise=False):
        super(TemporalBN, self).__init__()
        self.nb_steps = nb_steps
        if not step_wise:
            self.bns = nn.BatchNorm2d(in_channels)
        else:
            self.bns = nn.ModuleList(
                nn.BatchNorm2d(in_channels) for _ in range(nb_steps)
            )
        self.step_wise = step_wise

    def _bn_at(self, t):
        # The step-specific module when step_wise is set, else the shared one.
        return self.bns[t] if self.step_wise else self.bns

    def forward(self, x):
        """Normalize each time slice of ``x`` and restack them on the last axis."""
        normed = [self._bn_at(t)(x[..., t]) for t in range(self.nb_steps)]
        return torch.stack(normed, dim=x.dim() - 1)


class tdBatchNorm(nn.Module):
    """Threshold-dependent batch normalization (tdBN) for time-last tensors.

    Statistics are computed per channel (dim 1), pooled over the batch, all
    remaining spatial dimensions and the trailing time axis:

        y = alpha * Vth * (x - E[x]) / sqrt(Var[x] + eps)

    followed by the optional affine transform of the wrapped module. The
    wrapped ``bn`` supplies ``eps``, ``momentum``, ``affine`` parameters and
    the running-statistics buffers; its own ``forward`` is never called.

    Args:
        bn: a ``torch.nn`` batch-norm module whose ``num_features`` equals
            ``x.size(1)`` (e.g. ``nn.BatchNorm2d`` for ``(N, C, H, W, T)``).
        alpha: extra scale applied to the normalized input.
        Vth: firing threshold the normalized input is scaled to.
    """

    def __init__(self, bn, alpha=1, Vth=0.5):
        super(tdBatchNorm, self).__init__()
        self.bn = bn
        self.alpha = alpha
        self.Vth = Vth

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *spatial, T)`` (rank >= 3).

        Training mode uses batch statistics and updates the running buffers
        when the wrapped module tracks them. Eval mode uses the running
        buffers, or batch statistics if the module keeps none.

        Raises:
            ValueError: in training mode when there is only one value per
                channel (the unbiased variance would be undefined).
        """
        bn = self.bn
        # Reduce over every dim except channels; broadcast shape for per-channel vectors.
        reduce_dims = [0] + list(range(2, x.dim()))
        bshape = [1, -1] + [1] * (x.dim() - 2)
        n = x.numel() / x.size(1)  # number of values per channel
        if self.training and n <= 1:
            raise ValueError(
                "Expected more than 1 value per channel when training, "
                "got input size {}".format(tuple(x.shape))
            )

        exponential_average_factor = 0.0
        if self.training and bn.track_running_stats:
            if bn.num_batches_tracked is not None:
                bn.num_batches_tracked.add_(1)
                if bn.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = bn.momentum

        if self.training or bn.running_mean is None:
            mean = x.mean(reduce_dims, keepdim=True)
            var = x.var(reduce_dims, keepdim=True, unbiased=False)
            if self.training and bn.track_running_stats and bn.running_mean is not None:
                f = exponential_average_factor
                with torch.no_grad():
                    # In-place update like nn.BatchNorm; running_var stores the unbiased estimate.
                    bn.running_mean.mul_(1 - f).add_(mean.reshape(-1), alpha=f)
                    bn.running_var.mul_(1 - f).add_(var.reshape(-1) * (n / (n - 1)), alpha=f)
        else:
            mean = bn.running_mean.view(bshape)
            var = bn.running_var.view(bshape)

        # eps belongs inside the square root, as in nn.BatchNorm and the tdBN formula.
        x = self.alpha * self.Vth * (x - mean) / torch.sqrt(var + bn.eps)

        if bn.affine:
            x = x * bn.weight.view(bshape) + bn.bias.view(bshape)

        return x
