import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function

# Names exported by ``from <module> import *``: only the factory function.
__all__ = [
    'fptonline_spiking_fcn',
]


class ScaledWSLinear(nn.Linear):
    """Linear layer that standardizes its weight on every forward pass.

    Each output row of the stored weight is centred on its mean and divided
    by ``sqrt(var * fan_in + eps)`` before being used; an optional learnable
    per-row ``gain`` (initialised to ones) rescales the result.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer has an additive bias.
        gain: Whether to create the learnable per-row gain.
        eps: Constant added under the square root for numerical stability.
    """

    def __init__(self, in_features, out_features, bias=True, gain=True, eps=1e-4):
        super().__init__(in_features, out_features, bias)
        # One gain per output row, broadcast across that row's inputs.
        self.gain = nn.Parameter(torch.ones(self.out_features, 1)) if gain else None
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and, if enabled, gain-scaled) weight."""
        raw = self.weight
        fan_in = np.prod(raw.shape[1:])
        row_mean = raw.mean(dim=1, keepdim=True)
        row_var = raw.var(dim=1, keepdim=True)
        denom = (row_var * fan_in + self.eps) ** 0.5
        standardized = (raw - row_mean) / denom
        if self.gain is None:
            return standardized
        return standardized * self.gain

    def forward(self, x):
        """Apply the affine map using the standardized weight."""
        weight = self.get_weight()
        return F.linear(x, weight, self.bias)


class Replace(Function):
    """Autograd function that outputs its second input in the forward pass.

    The first input does not contribute to the forward value; in the
    backward pass the incoming gradient is handed, unchanged, to both inputs.
    """

    @staticmethod
    def forward(ctx, z1, z1_r):
        # Only ``z1_r`` determines the forward value.
        return z1_r

    @staticmethod
    def backward(ctx, grad):
        # Same gradient for ``z1`` and ``z1_r``.
        return grad, grad


class WrapedSNNOp(nn.Module):
    """Wrap an op so its output comes from one half of the batch and its
    gradient graph from the other.

    When wrapping is active, the input is split in two along dim 0: the first
    half is treated as ``spike`` and the second as ``rate``. The returned
    value equals ``op(spike)`` (computed without grad), while the autograd
    graph is built from ``op`` applied to ``rate`` via ``Replace``.
    """

    def __init__(self, op):
        super().__init__()
        self.op = op

    def forward(self, x, **kwargs):
        """Apply ``op``; pass ``require_wrap=False`` to call it directly on ``x``."""
        if not kwargs.get('require_wrap', True):
            return self.op(x)

        half = x.shape[0] // 2
        spike, rate = x[:half], x[half:]

        # Forward value: the op evaluated on the spike half, outside autograd.
        with torch.no_grad():
            spike_out = self.op(spike).detach()

        # Gradient path: the op evaluated on the rate half (routed via Replace).
        rate_out = self.op(Replace.apply(spike, rate))
        return Replace.apply(rate_out, spike_out)


class FPTOnlineSpikingFCN(nn.Module):
    """Fully connected spiking network with an autograd and a hand-written
    gradient path.

    ``features`` alternates a linear layer with one ``single_step_neuron``
    instance per entry of ``d_hidden``; ``classifier`` maps the last hidden
    width to ``num_classes``. ``forward`` either runs a regular autograd pass
    (``forward_grad=False``, labelled "OTTT" in the code) or a pass executed
    under ``torch.no_grad()`` that propagates random perturbation directions
    alongside the activations and then writes parameter gradients explicitly
    through ``set_grad`` (``forward_grad=True``).

    Args:
        d_in: Flattened input width.
        d_hidden: Widths of the hidden layers (must be non-empty).
        num_classes: Output width of the classifier.
        weight_standardization: Use ``ScaledWSLinear`` hidden layers and scale
            neuron outputs by 2.74 instead of 1.
        single_step_neuron: Neuron module class. It is instantiated with
            ``**kwargs`` and called as
            ``neuron(x, forward_grad=..., output_type=..., **kwargs)``; with
            ``forward_grad=True`` it must return ``(output, sg)``.
        grad_with_rate: In training mode, ask the neurons for
            ``output_type='spike_rate'`` instead of ``'spike'``.
        momentum_feedback: Create per-layer feedback matrices ``fb`` that map
            the output gradient back to each hidden layer.
        momentum_fb: Momentum coefficient of the feedback-matrix update.
        local_loss: Add a bias-free linear head per hidden layer whose loss
            gradient is mixed into that layer's signal (weight 0.01).
        DFA: Keep ``fb`` at its Kaiming-normal initialisation (no directions
            are sampled, ``fb`` is not updated).
        forward_quantize: Clamp and quantize propagated directions at each
            neuron using ``q_scale`` and ``q_timesteps``.
        q_scale: Clamp bound used by ``forward_quantize``.
        q_timesteps: Number of quantization steps used by ``forward_quantize``.
        DKP: Give ``fb`` gradients from the output gradient and the layer
            activations instead of the momentum update.
        **kwargs: Forwarded to every ``single_step_neuron`` constructor.

    Raises:
        TypeError: If ``single_step_neuron`` is not provided.
    """

    def __init__(self, d_in=784, d_hidden=(400, 400), num_classes=10,
                 weight_standardization=False, single_step_neuron: callable=None,
                 grad_with_rate=True, momentum_feedback=True, momentum_fb=0.999,
                 local_loss=False, DFA=False, forward_quantize=False,
                 q_scale=10., q_timesteps=20, DKP=False, **kwargs):
        super().__init__()
        if single_step_neuron is None:
            # Fail with a readable message instead of "'NoneType' is not callable".
            raise TypeError('single_step_neuron must be a neuron module class, got None')
        self.single_step_neuron = single_step_neuron
        self.grad_with_rate = grad_with_rate

        if weight_standardization:
            linear = ScaledWSLinear
            self.sn_scale = 2.74
        else:
            linear = nn.Linear
            self.sn_scale = 1.

        layers = []
        in_dim = d_in
        for dim in d_hidden:
            layers += [linear(in_dim, dim), single_step_neuron(**kwargs)]
            in_dim = dim

        self.features = nn.ModuleList(layers)
        self.classifier = nn.Linear(d_hidden[-1], num_classes)

        # Runs before fb / local_fc exist, so only features and classifier
        # receive this initialisation.
        self._initialize_weights()

        self.momentum_feedback = momentum_feedback
        self.DFA = DFA
        self.DKP = DKP
        if momentum_feedback:
            self.fb = nn.ModuleList(
                nn.Linear(num_classes, dim, bias=False) for dim in d_hidden
            )
            for m in self.fb:
                if self.DFA:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                else:
                    nn.init.constant_(m.weight, 0)
            self.momentum_fb = momentum_fb

        self.local_loss = local_loss
        if local_loss:
            self.local_fc = nn.ModuleList(
                nn.Linear(dim, num_classes, bias=False) for dim in d_hidden
            )
            self.ll_lambda = 0.01

        self.forward_quantize = forward_quantize
        self.q_scale = q_scale
        self.q_timesteps = q_timesteps

    def forward(self, x, forward_grad=False, loss_func=None, sample_num=1, group_num=1,
                share_forward=False, only_local_loss=False, sample_one_layer=False, **kwargs):
        """Run one step of the network.

        Args:
            x: Input batch; flattened from dim 1 onwards.
            forward_grad: Select the hand-written gradient path.
            loss_func: Callable taking the classifier output and returning a
                scalar loss. Optional when ``forward_grad`` is False, required
                otherwise.
            sample_num: Number of random directions sampled per group.
            group_num: Number of feature groups perturbed separately.
            share_forward: Accumulate all layers' directions into a single
                propagated tensor instead of one per layer.
            only_local_loss: With ``local_loss``, use only the local-loss
                gradient as the layer signal.
            sample_one_layer: Sample directions at a single, randomly chosen
                hidden layer (requires ``momentum_feedback``).
            **kwargs: Forwarded to every neuron call.

        Returns:
            The classifier output when ``forward_grad`` is False and no
            ``loss_func`` is given; otherwise ``(output, loss.item())``.

        Raises:
            TypeError: If ``forward_grad`` is True and ``loss_func`` is None.
        """
        output_type = 'spike_rate' if (self.grad_with_rate and self.training) else 'spike'

        x = torch.flatten(x, 1)

        # OTTT
        if not forward_grad:
            first_op = True
            for module in self.features:
                if isinstance(module, self.single_step_neuron):
                    x = module(x, forward_grad=False, output_type=output_type, **kwargs)
                    x = x * self.sn_scale
                else:
                    # The first linear layer is never wrapped.
                    if output_type == 'spike_rate' and not first_op:
                        module = WrapedSNNOp(module)
                    x = module(x)
                    first_op = False

            if output_type == 'spike_rate':
                classifier = WrapedSNNOp(self.classifier)
            else:
                classifier = self.classifier
            x = classifier(x)

            if loss_func is None:
                return x
            loss = loss_func(x)
            if self.training:
                loss.backward()
            return x, loss.item()

        if loss_func is None:
            raise TypeError('loss_func is required when forward_grad=True')

        if sample_one_layer:
            layer_num = sum(
                1 for module in self.features
                if isinstance(module, self.single_step_neuron)
            )
            layer_index = np.random.randint(layer_num)
            current_index = 0

        # forward propagation
        with torch.no_grad():
            B = x.shape[0]
            z_list = []            # propagated directions, (samples, batch, ...)
            in_for_grad_list = [x]  # input of every linear layer, in order
            v_list = []            # sampled directions (or activations for DKP)
            sg_list = []           # second return value of every neuron
            local_output = []      # neuron outputs fed to the local heads

            for module in self.features:
                if not isinstance(module, self.single_step_neuron):
                    x = module(x)
                    self._propagate_directions(module, z_list)
                    continue

                x, sg = module(x, forward_grad=True, output_type=output_type, **kwargs)
                if self.local_loss:
                    local_output.append(x[:B].flatten(1))
                x = x * self.sn_scale

                # forward propagate directional gradient through the neuron
                for i in range(len(z_list)):
                    if self.forward_quantize:
                        z_list[i] = (torch.clamp(z_list[i] * sg, - self.q_scale, self.q_scale) / self.q_scale * self.q_timesteps).round() / self.q_timesteps * self.q_scale * self.sn_scale
                    else:
                        z_list[i] = z_list[i] * sg * self.sn_scale

                if output_type == 'spike_rate':
                    # Second half feeds the weight gradient; first half goes on.
                    in_for_grad_list.append(x[B:])
                    x = x[:B]
                else:
                    in_for_grad_list.append(x)

                if not self.DKP and not self.DFA:
                    if sample_one_layer:
                        if current_index == layer_index:
                            v = self.get_v(x, sample_num, group_num)
                            v_list.append(v)
                            z_list.append(v * self.sn_scale)
                        current_index += 1
                    else:
                        v = self.get_v(x, sample_num, group_num)
                        v_list.append(v)
                        if share_forward and len(z_list) > 0:
                            z_list[0] = z_list[0] + v * self.sn_scale
                        else:
                            z_list.append(v * self.sn_scale)
                elif self.DKP:
                    v_list.append(in_for_grad_list[-1])
                sg_list.append(sg)

            # classifier
            x = self.classifier(x)
            self._propagate_directions(self.classifier, z_list)

        # feedback signals
        with torch.enable_grad():
            x.requires_grad_(True)
            loss = loss_func(x)
            if self.local_loss:
                loss_local = 0
                for head, activation in zip(self.local_fc, local_output):
                    activation.requires_grad_(True)
                    loss_local += loss_func(head(activation))
        loss.backward()
        if self.local_loss:
            loss_local.backward()

        grad_last = x.grad.data

        # update feedback connections
        if self.momentum_feedback and not self.DFA and not self.DKP:
            num = z_list[0].shape[0] * z_list[0].shape[1]
            if sample_one_layer:
                self._momentum_update_fb(self.fb[layer_index], v_list[0], z_list[0], num)
            else:
                for i in range(len(self.fb)):
                    z = z_list[0] if share_forward else z_list[i]
                    self._momentum_update_fb(self.fb[i], v_list[i], z, num)
        elif self.momentum_feedback and self.DKP:
            for i in range(len(self.fb)):
                self.set_grad(self.fb[i], grad_last, v_list[i].flatten(1))

        # feedback propagation
        with torch.no_grad():
            if self.momentum_feedback:
                for i in range(len(self.fb)):
                    sg_list[i] = self.fb[i](grad_last).reshape(sg_list[i].shape) * sg_list[i]
            else:
                assert not sample_one_layer
                for i in range(len(z_list)):
                    z_list[i] = torch.sum(z_list[i] * grad_last, dim=2)
                    # Pad trailing singleton dims so it broadcasts against v.
                    for _ in range(len(v_list[i].shape) - len(z_list[i].shape)):
                        z_list[i] = z_list[i].unsqueeze(-1)
                for i in range(len(v_list)):
                    z = z_list[0] if share_forward else z_list[i]
                    sg_list[i] = torch.sum(v_list[i] * z * sg_list[i], dim=0)

        if self.local_loss:
            for i in range(len(sg_list)):
                # Reshape to the layer signal's shape. v_list carries a leading
                # sample dimension (and is empty under DFA), so it is not a
                # valid reference shape here.
                local_grad = self.ll_lambda * local_output[i].grad.reshape(sg_list[i].shape)
                if only_local_loss:
                    sg_list[i] = local_grad
                else:
                    sg_list[i] = sg_list[i] + local_grad

        # ScaledWSLinear subclasses nn.Linear, so one isinstance check covers both.
        linear_layers = [m for m in self.features if isinstance(m, nn.Linear)]
        for index, module in enumerate(linear_layers):
            self.set_grad(module, in_for_grad_list[index], sg_list[index])
        self.set_grad(self.classifier, in_for_grad_list[-1], grad_last)

        return x, loss.item()

    def _propagate_directions(self, op, z_list):
        """Apply ``op`` in place to every direction tensor in ``z_list``."""
        for i, z in enumerate(z_list):
            out = op(z.flatten(0, 1))
            z_list[i] = out.reshape(z.shape[0], z.shape[1], *out.shape[1:])

    def _momentum_update_fb(self, fb_layer, v, z, num):
        """Blend ``v^T z / num`` into ``fb_layer.weight`` with ``momentum_fb``."""
        cross = v.flatten(0, 1).flatten(1).t().matmul(z.flatten(0, 1))
        fb_layer.weight.data = self.momentum_fb * fb_layer.weight.data + (1 - self.momentum_fb) * cross / num

    def get_v(self, x, sample_num=1, group_num=1):
        """Sample random +/-1 perturbation directions shaped like ``x``.

        For each of ``sample_num`` samples and each of ``group_num`` groups,
        one tensor is produced that is +/-1 on that group's slice of dim 1
        (width ``x.shape[1] // group_num``) and zero elsewhere.

        Returns:
            Tensor of shape ``(sample_num * group_num, *x.shape)``.

        Raises:
            ValueError: If ``sample_num`` or ``group_num`` is smaller than 1.
        """
        if sample_num < 1 or group_num < 1:
            raise ValueError('sample_num and group_num must both be at least 1')
        group_dim = x.shape[1] // group_num
        directions = []
        for _ in range(sample_num):
            for j in range(group_num):
                lo, hi = j * group_dim, (j + 1) * group_dim
                v = torch.zeros_like(x)
                # Rademacher distribution: threshold uniform noise at 0.5.
                v[:, lo:hi] = (torch.rand_like(x[:, :group_dim]) >= 0.5).float() * 2 - 1
                directions.append(v)
        return torch.stack(directions, dim=0)

    def set_grad(self, op, in_for_grad, grad):
        """Accumulate parameter gradients of ``op`` for output gradient ``grad``.

        Re-runs ``op`` on ``in_for_grad`` with autograd enabled and calls
        ``backward(grad)`` on the result. Note that ``in_for_grad`` is marked
        as requiring grad in place.
        """
        with torch.enable_grad():
            in_for_grad.requires_grad_(True)
            tmp = op(in_for_grad)
        tmp.backward(grad)

    def _initialize_weights(self):
        """Kaiming-normal weights and zero biases for every linear submodule."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def get_grads(self):
        """Return detached CPU copies of the weight gradients of the linear
        layers in ``features`` (the classifier is not included).

        Every such layer must already have a gradient.
        """
        return [
            module.weight.grad.cpu().detach()
            for module in self.features
            if isinstance(module, nn.Linear)
        ]



def fptonline_spiking_fcn(**kwargs):
    """Build an ``FPTOnlineSpikingFCN``, forwarding all keyword arguments to its constructor."""
    model = FPTOnlineSpikingFCN(**kwargs)
    return model
