import torch
import global_v as glv
import numpy as np


def psp(inputs, network_config):
    """Filter input spikes through a first-order synaptic kernel.

    At every time step ``t`` the synaptic state decays by ``state / tau_s`` and
    accumulates ``inputs[..., t]``; the recorded trace is ``state / tau_s``.

    Args:
        inputs: tensor whose last dimension indexes time. It must have at
            least ``network_config['n_steps']`` entries along that dimension.
        network_config: dict providing ``'n_steps'`` and ``'tau_s'``.

    Returns:
        Tensor of shape ``inputs.shape[:-1] + (n_steps,)`` (default float
        dtype) on the same device as ``inputs``.
    """
    n_steps = network_config['n_steps']
    tau_s = network_config['tau_s']

    # Allocate on the input's device rather than a hard-coded ``.cuda()`` so
    # the filter works on CPU and on whichever GPU already holds ``inputs``.
    state_shape = inputs.shape[:-1]
    syn = torch.zeros(state_shape, device=inputs.device)
    syns = torch.zeros(tuple(state_shape) + (n_steps,), device=inputs.device)

    for t in range(n_steps):
        syn = syn - syn / tau_s + inputs[..., t]
        syns[..., t] = syn / tau_s

    return syns


class SpikeLoss(torch.nn.Module):
    """
    Spike-based loss functions for optimizing an SNN.

    Every ``spike_*`` method receives the network's output spike tensor, a
    target/label tensor and a network configuration dict, and returns a loss
    tensor. The cross-entropy variants share one ``CrossEntropyLoss`` instance.
    """

    def __init__(self, network_config):
        super().__init__()
        self.network_config = network_config
        self.f_loss = torch.nn.CrossEntropyLoss()

    def spike_count(self, outputs, target, network_config):
        """Return half the summed squared error produced by ``loss_count``."""
        err = loss_count.apply(outputs, target, network_config)
        return 0.5 * err.pow(2).sum()

    def spike_softmax_last(self, outputs, labels, network_config):
        """Cross-entropy on the last-time-step logits from ``loss_softmax_last``."""
        last_logits = loss_softmax_last.apply(outputs, network_config)
        flat_labels = labels.view(-1)
        return self.f_loss(last_logits, flat_labels)

    def spike_softmax(self, outputs, labels, network_config):
        """Cross-entropy on spike counts summed over the last dimension."""
        n_batch, n_class = outputs.shape[0], outputs.shape[1]
        counts = outputs.sum(dim=-1).view(n_batch, n_class)
        flat_labels = labels.view(-1)
        return self.f_loss(counts, flat_labels)

    def spike_framewise(self, outputs, labels, network_config):
        """Cross-entropy with every trailing element treated as its own frame.

        ``outputs`` is reshaped to ``(batch, classes, -1)``, so ``labels`` must
        match that trailing extent.
        """
        n_batch, n_class = outputs.shape[0], outputs.shape[1]
        frame_logits = outputs.view(n_batch, n_class, -1)
        return self.f_loss(frame_logits, labels).sum()


class loss_count(torch.autograd.Function):
    """Spike-count error with a one-sided penalty.

    ``forward`` computes ``(sum over time of outputs - target) / n_steps``
    (time is the last dimension of ``outputs``) and then zeroes:

    * entries whose target is not ``undesired_count`` and whose error is not
      negative (e.g. a desired neuron firing too much is not penalized), and
    * entries whose target is not ``desired_count`` and whose error is not
      positive (e.g. an undesired neuron firing too little is not penalized).

    The remaining per-neuron error is repeated across every time step.

    ``backward`` passes the incoming gradient straight through to
    ``outputs``; ``target`` and ``network_config`` get no gradient.
    """

    @staticmethod
    def forward(ctx, outputs, target, network_config):
        desired_count = network_config['desired_count']
        undesired_count = network_config['undesired_count']
        n_steps = outputs.shape[-1]
        out_count = torch.sum(outputs, dim=-1)

        delta = (out_count - target) / n_steps
        # Pass 1: mask ends up 1 only where target != undesired_count and
        # delta is not negative; those errors are cleared.
        mask = torch.ones_like(out_count)
        mask[target == undesired_count] = 0
        mask[delta < 0] = 0
        delta[mask == 1] = 0
        # Pass 2 (sees the delta updated above): clear errors where
        # target != desired_count and delta is not positive.
        mask = torch.ones_like(out_count)
        mask[target == desired_count] = 0
        mask[delta > 0] = 0
        delta[mask == 1] = 0
        # Repeat (copy, not a view) the error over every time step; the repeat
        # spec has one entry per existing dim plus the new time dim.
        delta = delta.unsqueeze(-1).repeat(*([1] * delta.dim()), n_steps)
        return delta

    @staticmethod
    def backward(ctx, grad):
        # Exactly one gradient per forward input: outputs, target, network_config.
        return grad, None, None


class loss_softmax_last(torch.autograd.Function):
    """Take the last time step of ``outputs`` as ``(batch, -1)`` logits.

    ``forward`` returns ``outputs[..., -1]`` flattened to
    ``(outputs.shape[0], -1)``; ``network_config`` is unused.
    ``backward`` routes the incoming gradient to the last time step only;
    every earlier time step receives zero gradient.
    """

    @staticmethod
    def forward(ctx, outputs, network_config):
        # Plain shape metadata goes on ctx; no need for save_for_backward.
        ctx.in_shape = outputs.shape
        out = outputs[..., -1].clone()
        return out.view(outputs.shape[0], -1)

    @staticmethod
    def backward(ctx, grad):
        # Build the gradient with grad's device/dtype (not a hard-coded
        # ``.cuda()``) and the exact input shape, so spatial dims other than
        # 1x1 are restored correctly.
        grad_out = grad.new_zeros(ctx.in_shape)
        grad_out[..., -1] = grad.reshape(ctx.in_shape[:-1])
        return grad_out, None


class loss_framewise(torch.autograd.Function):
    """Slice the first ``end`` time steps of batch sample ``i`` as frame logits.

    ``forward`` returns ``outputs`` unchanged together with
    ``outputs[i, ..., 0:end]`` reshaped to ``(1, outputs.shape[1], -1)``.

    NOTE(review): the backward logic suggests a specific calling pattern —
    confirm against callers. ``apply`` is chained once per sample, each call
    receiving the ``outputs`` returned by the previous one. The backward of the
    call with ``i == n_batch - 1`` allocates a fresh zero gradient for
    ``outputs`` and ignores the incoming one. Every other call clones the
    gradient it receives. Each call then writes its slice's gradient into
    sample ``i``.
    """

    @staticmethod
    def forward(ctx, outputs, network_config, end, i):
        shape = outputs.shape
        out = outputs[i, ..., 0:end].clone()
        out = out.view(1, shape[1], -1)
        # Plain Python metadata lives on ctx instead of being packed into a tensor.
        ctx.in_shape = shape
        ctx.end = int(end)
        ctx.i = int(i)
        return outputs, out

    @staticmethod
    def backward(ctx, grad_outputs, grad):
        i, end = ctx.i, ctx.end
        if i == ctx.in_shape[0] - 1:
            # Use grad's device/dtype rather than a hard-coded ``.cuda()``.
            grad_out = grad.new_zeros(ctx.in_shape)
        else:
            grad_out = grad_outputs.clone()
        # Write through a view of sample i's first ``end`` steps; reshaping to
        # the view's own shape also supports spatial dims other than 1x1.
        frame_grad = grad_out[i, ..., 0:end]
        frame_grad.copy_(grad.reshape(frame_grad.shape))
        return grad_out, None, None, None

