import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function

# Public names exported by `from <module> import *`: the network class and
# the three factory functions defined at the bottom of this file.
__all__ = [
    'FGTOnlineSpikingCNN', 'fgtonline_spiking_cnn', 'fgtonline_spiking_cnn_ws', 
    'fgtonline_spiking_dcnn_ws',
]

class ScaledWSConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardized on every forward pass.

    Each output channel's kernel is centred and divided by
    ``sqrt(var * fan_in + eps)``; an optional learnable per-channel ``gain``
    rescales the result. The raw ``weight`` parameter itself is never modified.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, gain=True, eps=1e-4):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        # One scale per output channel, broadcastable over (in, kh, kw).
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and optionally gain-scaled) kernel."""
        kernel = self.weight
        reduce_dims = [1, 2, 3]
        # Number of weights feeding one output unit.
        fan_in = np.prod(kernel.shape[1:])
        centered = kernel - torch.mean(kernel, dim=reduce_dims, keepdim=True)
        spread = torch.var(kernel, dim=reduce_dims, keepdim=True)
        standardized = centered / ((spread * fan_in + self.eps) ** 0.5)
        if self.gain is None:
            return standardized
        return standardized * self.gain

    def forward(self, x):
        """Convolve ``x`` with the standardized kernel."""
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)


class ScaledWSLinear(nn.Linear):
    """``nn.Linear`` whose weight matrix is standardized row-by-row on every call.

    Each output row is centred and divided by ``sqrt(var * fan_in + eps)``;
    an optional learnable per-row ``gain`` rescales the result.
    """

    def __init__(self, in_features, out_features, bias=True, gain=True, eps=1e-4):
        super().__init__(in_features, out_features, bias)
        # One scale per output feature, broadcastable over the input dimension.
        self.gain = nn.Parameter(torch.ones(self.out_features, 1)) if gain else None
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and optionally gain-scaled) weight matrix."""
        matrix = self.weight
        reduce_dims = [1]
        # Number of weights feeding one output unit.
        fan_in = np.prod(matrix.shape[1:])
        centered = matrix - torch.mean(matrix, dim=reduce_dims, keepdim=True)
        spread = torch.var(matrix, dim=reduce_dims, keepdim=True)
        standardized = centered / ((spread * fan_in + self.eps) ** 0.5)
        if self.gain is None:
            return standardized
        return standardized * self.gain

    def forward(self, x):
        """Apply the affine map using the standardized weight."""
        return F.linear(x, self.get_weight(), self.bias)


class Replace(Function):
    """Autograd op that outputs its second input while sending gradients to both.

    The forward value is exactly ``z1_r``; ``z1`` only participates in the
    graph so that it receives the same incoming gradient on the backward pass.
    """

    @staticmethod
    def forward(ctx, z1, z1_r):
        # The value of z1 is discarded; nothing needs to be saved on ctx.
        return z1_r

    @staticmethod
    def backward(ctx, grad):
        # Hand the upstream gradient, unchanged, to z1 and to z1_r.
        return grad, grad


class WrapedSNNOp(nn.Module):
    """Wrap ``op`` so its output value and its gradient come from different inputs.

    With wrapping enabled the input batch is split in two halves, named
    ``spike`` (first half) and ``rate`` (second half) here. The returned value
    is ``op(spike)``, computed without gradient tracking, while the autograd
    graph is built from ``op`` evaluated on the ``rate`` values.
    """

    def __init__(self, op):
        super().__init__()
        self.op = op

    def forward(self, x, **kwargs):
        """Apply ``op``; pass ``require_wrap=False`` to call it on ``x`` directly."""
        if not kwargs.get('require_wrap', True):
            return self.op(x)

        half = x.shape[0] // 2
        spike, rate = x[:half], x[half:]

        # Value path: plain evaluation on the spike half, cut off from autograd.
        with torch.no_grad():
            spike_out = self.op(spike).detach()

        # Gradient path: op sees the rate values, gradients reach spike and rate.
        grad_path_out = self.op(Replace.apply(spike, rate))

        # Emit the spike-path value while keeping the rate-path graph attached.
        return Replace.apply(grad_path_out, spike_out)


class FGTOnlineSpikingCNN(nn.Module):

    def __init__(self, cfg, weight_standardization=True, num_classes=10, single_step_neuron: callable=None, grad_with_rate=True, fc_hw=1, c_in=3, momentum_feedback=True, momentum_fb=0.999, h_in=32, w_in=32, local_loss=False, DFA=False, forward_quantize=False, q_scale=10., q_timesteps=20, global_step=4, global_num=2, **kwargs):
        """Build the conv/neuron stack, the two classifier heads and the optional feedback layers.

        Args:
            cfg: Layer spec. Each entry is ``'M'`` (2x2 average pooling) or
                ``[out_channels, kernel_size]`` (one conv followed by one neuron).
            weight_standardization: Use ``ScaledWSConv2d`` and ``sn_scale = 2.74``
                when true, plain ``nn.Conv2d`` and ``sn_scale = 1`` otherwise.
            num_classes: Output size of both classifier heads.
            single_step_neuron: Neuron class; instantiated as
                ``single_step_neuron(**kwargs)`` after every conv layer.
            grad_with_rate: Stored; ``forward`` selects the ``'spike_rate'``
                output type when this is true and the module is in training mode.
            fc_hw: Output height/width of the adaptive average pooling that
                precedes each classifier head.
            c_in: Number of input channels.
            momentum_feedback: Create the per-layer feedback matrices ``fb`` and
                ``fb_m`` (``num_classes -> C*H*W`` of each conv layer's output).
            momentum_fb: Momentum coefficient stored for the feedback update.
            h_in, w_in: Spatial size of the network input, used only to size the
                feedback / local linear layers.
            local_loss: Create one bias-free linear read-out per conv layer
                (``local_fc``) and set ``ll_lambda = 0.01``.
            DFA: Initialise ``fb`` with Kaiming-normal weights instead of zeros.
            forward_quantize, q_scale, q_timesteps: Stored for use in ``forward``.
            global_step: 1-based index of the conv layer after which the
                intermediate classifier ``classifier_m`` is attached.
            global_num: Must be 2 (only one intermediate head is supported).
            **kwargs: Forwarded to the neuron constructor.

        Raises:
            TypeError: If ``single_step_neuron`` is None.
            ValueError: If ``global_step`` does not select a conv layer of ``cfg``.
        """
        super(FGTOnlineSpikingCNN, self).__init__()
        if single_step_neuron is None:
            # Fail early with a readable message; the layer builder would
            # otherwise die on "'NoneType' object is not callable".
            raise TypeError('single_step_neuron must be a neuron class, got None')
        assert global_num == 2 # only support one intermediate global loss now

        # Walk cfg once and record, per conv layer, its channel count and the
        # flattened size (C*H*W) of its output. Height and width are tracked
        # separately so the sizes match what the layers really produce:
        # AvgPool2d(2, 2) floors each side, and a stride-1 conv with padding
        # (ks-1)//2 keeps the size for odd ks but loses one pixel for even ks.
        # NOTE(review): assumes the neuron modules preserve tensor shape.
        conv_channels = []
        feat_dims = []
        h, w = h_in, w_in
        for v in cfg:
            if v == 'M':
                h, w = h // 2, w // 2
            else:
                channels, ks = v[0], v[1]
                shrink = (ks - 1) % 2
                h, w = h - shrink, w - shrink
                conv_channels.append(channels)
                feat_dims.append(h * w * channels)

        # Locate the conv layer that feeds the intermediate classifier.
        remaining = global_step
        for n_global, c_m in enumerate(conv_channels, start=1):
            remaining -= 1
            if remaining == 0:
                break
        else:
            raise ValueError(
                'global_step=%r does not select a conv layer; cfg has %d conv layers'
                % (global_step, len(conv_channels)))

        self.single_step_neuron = single_step_neuron
        self.grad_with_rate = grad_with_rate
        if weight_standardization:
            conv = ScaledWSConv2d
            self.sn_scale = 2.74
        else:
            conv = nn.Conv2d
            self.sn_scale = 1.

        self.features, c_final = self.make_layers(cfg=cfg, c_in=c_in, conv=conv, neuron=single_step_neuron, **kwargs)

        self.avgpool = nn.AdaptiveAvgPool2d((fc_hw, fc_hw))
        self.classifier = nn.Linear(c_final*(fc_hw**2), num_classes)

        # intermediate global learning
        self.global_step = global_step
        self.classifier_m = nn.Linear(c_m*(fc_hw**2), num_classes)

        self._initialize_weights()

        self.momentum_feedback = momentum_feedback
        self.DFA = DFA
        if momentum_feedback:
            # One feedback matrix per conv layer, mapping output-space
            # gradients back to that layer's flattened activation size.
            self.fb = nn.ModuleList([nn.Linear(num_classes, dim, bias=False) for dim in feat_dims])
            for m in self.fb:
                if self.DFA:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
                else:
                    nn.init.constant_(m.weight, 0)

            self.momentum_fb = momentum_fb

            # intermediate global learning: only the layers up to and
            # including the one feeding classifier_m get a second matrix.
            self.fb_m = nn.ModuleList([nn.Linear(num_classes, dim, bias=False) for dim in feat_dims[:n_global]])
            for m in self.fb_m:
                nn.init.constant_(m.weight, 0)

        self.local_loss = local_loss
        if local_loss:
            self.local_fc = nn.ModuleList([nn.Linear(dim, num_classes, bias=False) for dim in feat_dims])
            self.ll_lambda = 0.01

        self.forward_quantize = forward_quantize
        self.q_scale = q_scale
        self.q_timesteps = q_timesteps

    @staticmethod
    def make_layers(cfg, c_in=3, conv=ScaledWSConv2d, neuron: callable = None, **kwargs):
        """Translate ``cfg`` into the feature extractor.

        ``'M'`` entries become 2x2 average pooling; ``[channels, kernel_size]``
        entries become ``conv(...)`` followed by ``neuron(**kwargs)``.

        Returns:
            ``(nn.ModuleList of the layers, channel count of the last conv)``;
            the second element is ``c_in`` when ``cfg`` holds no conv entry.
        """
        modules = []
        prev_channels = c_in
        for entry in cfg:
            if entry == 'M':
                modules.append(nn.AvgPool2d(kernel_size=2, stride=2))
                continue
            out_channels = entry[0]
            ks = entry[1]
            modules.append(conv(prev_channels, out_channels, kernel_size=ks, padding=(ks-1)//2))
            modules.append(neuron(**kwargs))
            prev_channels = out_channels
        return nn.ModuleList(modules), prev_channels

    def forward(self, x, forward_grad=False, loss_func=None, sample_num=1, group_num=1, share_forward=False, only_local_loss=False, sample_one_layer=False, **kwargs):
        """Run the network once and, when a loss is available, fill parameter gradients.

        The output type requested from the neurons is ``'spike_rate'`` when
        ``self.grad_with_rate`` is set and the module is in training mode,
        otherwise ``'spike'``.

        Two modes are implemented:

        * ``forward_grad=False``: ordinary pass through ``features``, the
          adaptive pooling and ``classifier``. In ``'spike_rate'`` mode every
          conv except the first, and the classifier, are wrapped in
          ``WrapedSNNOp``. With a ``loss_func`` the loss is computed, backpropagated
          if training, and ``(output, loss.item())`` is returned; without one
          only the output is returned.
        * ``forward_grad=True``: activations and random perturbation tensors
          (from ``get_v``) are pushed through the layers under ``no_grad``.
          Gradients of the loss w.r.t. the outputs of ``classifier`` and
          ``classifier_m`` are then turned into per-layer signals, either via
          the feedback matrices ``fb`` / ``fb_m`` or via the projection of the
          propagated perturbations, and ``set_grad`` accumulates the
          resulting gradients. Returns ``(output, loss.item())``.

        Args:
            x: Network input. # NOTE(review): presumably (B, C, H, W) -- confirm.
            forward_grad: Select the second mode described above.
            loss_func: Callable mapping logits to a scalar loss; it is called
                unconditionally when ``forward_grad`` is true.
            sample_num, group_num: Passed to ``get_v``.
            share_forward: Accumulate all perturbations into a single
                propagated tensor (``z_list[0]``) instead of one per layer.
            only_local_loss: With ``self.local_loss``, replace rather than add
                to the global signal.
            sample_one_layer: Perturb only one randomly chosen neuron layer.
            **kwargs: Forwarded to every neuron module call.
        """
        if self.grad_with_rate and self.training:
            output_type = 'spike_rate'
        else:
            output_type = 'spike'

        # OTTT
        if not forward_grad:
            first_conv = True
            for module in self.features:
                if isinstance(module, self.single_step_neuron):
                    x = module(x, forward_grad=False, output_type=output_type, **kwargs)
                    x = x * self.sn_scale
                elif isinstance(module, nn.Conv2d) or isinstance(module, ScaledWSConv2d):
                    # The first conv is never wrapped: its input is the raw
                    # network input, not a neuron output.
                    if output_type == 'spike_rate' and not first_conv:
                        module = WrapedSNNOp(module)
                    x = module(x)
                    first_conv = False
                else:
                    x = module(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            if output_type == 'spike_rate':
                classifier = WrapedSNNOp(self.classifier)
            else:
                classifier = self.classifier
            x = classifier(x)

            if loss_func is not None:
                loss = loss_func(x)
                if self.training:
                    loss.backward()
                return x, loss.item()
            else:
                return x
        else:
            if sample_one_layer:
                # Pick uniformly one neuron layer to perturb.
                layer_num = 0
                for module in self.features:
                    if isinstance(module, self.single_step_neuron):
                        layer_num += 1
                layer_index = np.random.randint(layer_num)
                current_index = 0

            # Countdown of neuron layers until the intermediate head is evaluated.
            tmp = self.global_step
            # forward propagation
            with torch.no_grad():
                B = x.shape[0]
                # Perturbations propagated forward through the remaining layers.
                z_list = []

                # in_for_grad_list: input later fed to each conv / the classifier
                # by set_grad; v_list: sampled perturbations; sg_list: the second
                # value returned by each neuron call.
                in_for_grad_list = []
                v_list = []
                sg_list = []
                if self.local_loss:
                    local_output = []

                # for conv1
                in_for_grad_list.append(x)

                # features
                for module in self.features:
                    if isinstance(module, self.single_step_neuron):
                        x, sg = module(x, forward_grad=True, output_type=output_type, **kwargs)
                        if self.local_loss:
                            local_output.append(x[:B].flatten(1))
                        x = x * self.sn_scale
                        # forward propagate directional gradient
                        for i in range(len(z_list)):
                            if self.forward_quantize:
                                # Clamp to [-q_scale, q_scale] and round to
                                # multiples of q_scale / q_timesteps before scaling.
                                z_list[i] = (torch.clamp(z_list[i] * sg, - self.q_scale, self.q_scale) / self.q_scale * self.q_timesteps).round() / self.q_timesteps * self.q_scale * self.sn_scale
                            else:
                                z_list[i] = z_list[i] * sg * self.sn_scale
                        if output_type == 'spike_rate':
                            # The neuron output is split along the batch axis:
                            # the second part is kept as the next layer's input
                            # for gradient computation, the first part is
                            # propagated forward.
                            in_for_grad_list.append(x[B:])
                            x = x[:B]
                        else:
                            in_for_grad_list.append(x)
                        if sample_one_layer:
                            if current_index == layer_index:
                                # v dim: N*B*C*H*W
                                v = self.get_v(x, sample_num, group_num)
                                v_list.append(v)
                                z_list.append(v * self.sn_scale)
                            current_index += 1
                        else:
                            # v dim: N*B*C*H*W
                            v = self.get_v(x, sample_num, group_num)
                            v_list.append(v)
                            if share_forward and len(z_list) > 0:
                                z_list[0] = z_list[0] + v * self.sn_scale
                            else:
                                z_list.append(v * self.sn_scale)
                        sg_list.append(sg)

                        # intermediate global
                        # NOTE(review): x_m, in_for_grad_m and z_list_m are bound
                        # only inside this branch; if the countdown never reaches
                        # zero the later uses raise NameError.
                        tmp -= 1
                        if tmp == 0:
                            x_m = torch.flatten(self.avgpool(x), 1)
                            in_for_grad_m = torch.flatten(self.avgpool(in_for_grad_list[-1]), 1)
                            z_list_m = []
                            for i in range(len(z_list)):
                                # Merge the two leading axes so the pooling sees a
                                # 4-D batch, then restore them afterwards.
                                z = torch.flatten(self.avgpool(z_list[i].flatten(0, 1)), 1)
                                z_list_m.append(z.reshape(z_list[i].shape[0], z_list[i].shape[1], *z.shape[1:]))
                            x_m = self.classifier_m(x_m)
                            for i in range(len(z_list_m)):
                                z = self.classifier_m(z_list_m[i].flatten(0, 1))
                                z_list_m[i] = z.reshape(z_list_m[i].shape[0], z_list_m[i].shape[1], *z.shape[1:])

                    elif isinstance(module, nn.Conv2d) or isinstance(module, ScaledWSConv2d):
                        x = module(x)
                        # forward propagate directional gradient
                        # NOTE(review): the perturbations go through the full
                        # module call, bias included -- confirm this is intended.
                        for i in range(len(z_list)):
                            z = module(z_list[i].flatten(0, 1))
                            z_list[i] = z.reshape(z_list[i].shape[0], z_list[i].shape[1], *z.shape[1:])
                    else: # pooling
                        x = module(x)
                        # Pool the stored input too so it matches what the next
                        # conv actually receives.
                        in_for_grad_list[-1] = module(in_for_grad_list[-1])
                        # forward propagate directional gradient
                        for i in range(len(z_list)):
                            z = module(z_list[i].flatten(0, 1))
                            z_list[i] = z.reshape(z_list[i].shape[0], z_list[i].shape[1], *z.shape[1:])

                # classifier
                x = torch.flatten(self.avgpool(x), 1)
                in_for_grad_list[-1] = torch.flatten(self.avgpool(in_for_grad_list[-1]), 1)
                # forward propagate directional gradient
                for i in range(len(z_list)):
                    z = torch.flatten(self.avgpool(z_list[i].flatten(0, 1)), 1)
                    z_list[i] = z.reshape(z_list[i].shape[0], z_list[i].shape[1], *z.shape[1:])
                x = self.classifier(x)
                for i in range(len(z_list)):
                    z = self.classifier(z_list[i].flatten(0, 1))
                    z_list[i] = z.reshape(z_list[i].shape[0], z_list[i].shape[1], *z.shape[1:])

            # feedback signals
            # x was produced under no_grad, so it is made a leaf here to obtain
            # the gradient of the loss w.r.t. the network output.
            with torch.enable_grad():
                x.requires_grad_(True)
                loss = loss_func(x)
                if self.local_loss:
                    loss_local = 0
                    for i in range(len(local_output)):
                        local_output[i].requires_grad_(True)
                        loss_local += loss_func(self.local_fc[i](local_output[i]))
            loss.backward()
            if self.local_loss:
                loss_local.backward()

            grad_last = x.grad.data

            # intermediate global
            with torch.enable_grad():
                x_m.requires_grad_(True)
                loss_m = loss_func(x_m)
            loss_m.backward()
            grad_m = x_m.grad.data

            # update feedback connections
            # Each feedback matrix is an exponential moving average of
            # v^T z (perturbation times propagated perturbation) divided by
            # the product of z's two leading dimensions.
            if self.momentum_feedback and not self.DFA:
                num = z_list[0].shape[0] * z_list[0].shape[1]
                if sample_one_layer:
                    self.fb[layer_index].weight.data = self.momentum_fb * self.fb[layer_index].weight.data + (1 - self.momentum_fb) * v_list[0].flatten(0, 1).flatten(1).t().matmul(z_list[0].flatten(0, 1)) / num
                elif share_forward:
                    for i in range(len(self.fb)):
                        self.fb[i].weight.data = self.momentum_fb * self.fb[i].weight.data + (1 - self.momentum_fb) * v_list[i].flatten(0, 1).flatten(1).t().matmul(z_list[0].flatten(0, 1)) / num
                else:
                    for i in range(len(self.fb)):
                        self.fb[i].weight.data = self.momentum_fb * self.fb[i].weight.data + (1 - self.momentum_fb) * v_list[i].flatten(0, 1).flatten(1).t().matmul(z_list[i].flatten(0, 1)) / num

                # intermediate global
                # NOTE(review): these asserts run after self.fb has already been
                # updated above, so the sample_one_layer / share_forward
                # branches mutate state and then fail here.
                assert not sample_one_layer
                assert not share_forward
                for i in range(len(self.fb_m)):
                    self.fb_m[i].weight.data = self.momentum_fb * self.fb_m[i].weight.data + (1 - self.momentum_fb) * v_list[i].flatten(0, 1).flatten(1).t().matmul(z_list_m[i].flatten(0, 1)) / num

            # feedback propagation
            # From here on sg_list[i] holds the signal handed to set_grad for
            # the i-th conv layer.
            with torch.no_grad():
                if self.momentum_feedback:
                    for i in range(len(self.fb)):
                        #sg_list[i] = self.fb[i](grad_last).reshape(sg_list[i].shape) * sg_list[i]

                        # intermediate global
                        # Layers that also have an fb_m matrix use the mean of
                        # the final and intermediate feedback signals.
                        if i < len(self.fb_m):
                            sg_list[i] = (self.fb[i](grad_last).reshape(sg_list[i].shape) + self.fb_m[i](grad_m).reshape(sg_list[i].shape)) / 2. * sg_list[i]
                        else:
                            sg_list[i] = self.fb[i](grad_last).reshape(sg_list[i].shape) * sg_list[i]
                else:
                    assert not sample_one_layer
                    for i in range(len(z_list)):
                        # Project each propagated perturbation onto the output
                        # gradient, then append trailing singleton axes so the
                        # result broadcasts against v.
                        z_list[i] = torch.sum(z_list[i] * grad_last, dim=2)
                        for j in range(len(v_list[i].shape) - len(z_list[i].shape)):
                            z_list[i] = z_list[i].unsqueeze(-1)
                    if share_forward:
                        for i in range(len(v_list)):
                            sg_list[i] = v_list[i] * z_list[0] * sg_list[i]
                            sg_list[i] = torch.sum(sg_list[i], dim=0)
                    else:
                        for i in range(len(v_list)):
                            sg_list[i] = v_list[i] * z_list[i] * sg_list[i]
                            sg_list[i] = torch.sum(sg_list[i], dim=0)

            if self.local_loss:
                for i in range(len(sg_list)):
                    if only_local_loss:
                        sg_list[i] = self.ll_lambda * local_output[i].grad.reshape(sg_list[i].shape)
                    else:
                        sg_list[i] = sg_list[i] + self.ll_lambda * local_output[i].grad.reshape(sg_list[i].shape)

            # Accumulate parameter gradients: each conv is re-run on its stored
            # input and backpropagated with its layer signal.
            index = 0
            for module in self.features:
                if isinstance(module, nn.Conv2d) or isinstance(module, ScaledWSConv2d):
                    self.set_grad(module, in_for_grad_list[index], sg_list[index])
                    index += 1
            self.set_grad(self.classifier, in_for_grad_list[-1], grad_last)
            # intermediate global
            self.set_grad(self.classifier_m, in_for_grad_m, grad_m)

            return x, loss.item()

    def get_v(self, x, sample_num=1, group_num=1):
        """Sample random +/-1 perturbation tensors shaped like ``x``.

        The channel axis (dim 1) is cut into ``group_num`` equal groups. For
        every sample and every group one tensor is produced that holds
        independent +1/-1 entries on that group's channels and zeros elsewhere.

        Returns:
            Tensor of shape ``(sample_num * group_num, *x.shape)``.
        """
        def one_direction(group):
            width = x.shape[1] // group_num
            lo, hi = group * width, (group + 1) * width
            direction = torch.zeros_like(x)
            # Rademacher distribution: threshold uniform noise at 0.5.
            # (A Gaussian alternative would draw torch.randn_like instead.)
            direction[:, lo:hi] = torch.rand_like(x[:, :width])
            direction[:, lo:hi] = (direction[:, lo:hi] >= 0.5).float() * 2 - 1
            return direction

        directions = [one_direction(g) for _ in range(sample_num) for g in range(group_num)]
        return torch.stack(directions, dim=0)

    def set_grad(self, op, in_for_grad, grad):
        """Accumulate gradients of ``op`` for a given input and output gradient.

        ``op`` is re-evaluated on ``in_for_grad`` with autograd enabled and
        ``grad`` is backpropagated from its output. ``in_for_grad`` is switched
        to ``requires_grad=True`` in place.
        """
        with torch.enable_grad():
            tracked_input = in_for_grad.requires_grad_(True)
            op_output = op(tracked_input)
        op_output.backward(grad)

    def _initialize_weights(self):
        """Kaiming-normal init (fan_in, relu) for every conv weight; conv biases set to 0.

        Only convolution layers are touched; other submodules keep their
        default initialisation.
        """
        conv_types = (nn.Conv2d, ScaledWSConv2d)
        for layer in self.modules():
            if not isinstance(layer, conv_types):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            if layer.bias is None:
                continue
            nn.init.constant_(layer.bias, 0)

    def get_spike(self):
        """Collect the ``spike`` attribute of every neuron layer.

        Each neuron's ``spike`` tensor is moved to the CPU and reshaped to
        ``(spike.shape[0], -1)``. The result follows the layer order in
        ``self.features``.
        """
        neurons = (m for m in self.features if isinstance(m, self.single_step_neuron))
        on_cpu = (neuron.spike.cpu() for neuron in neurons)
        return [s.reshape(s.shape[0], -1) for s in on_cpu]


# Layer specifications consumed by FGTOnlineSpikingCNN.
# Entry format: [out_channels, kernel_size] -> one conv followed by one neuron,
#               'M'                         -> 2x2 average pooling.
cfgs = {
    # Four conv layers with pooling after each of the first three.
    'A': [
        [128, 3], 'M',
        [256, 3], 'M',
        [512, 3], 'M',
        [512, 3],
    ],
    # Eight conv layers in pairs, with pooling after each of the first three pairs.
    'B': [
        [64, 3], [128, 3], 'M',
        [256, 3], [256, 3], 'M',
        [512, 3], [512, 3], 'M',
        [512, 3], [512, 3],
    ],
}


def fgtonline_spiking_cnn(**kwargs):
    """Build the configuration-'A' network with plain ``nn.Conv2d`` layers.

    All keyword arguments are forwarded to ``FGTOnlineSpikingCNN``.
    """
    layer_spec = cfgs['A']
    return FGTOnlineSpikingCNN(layer_spec, weight_standardization=False, **kwargs)


def fgtonline_spiking_cnn_ws(**kwargs):
    """Build the configuration-'A' network with weight-standardized convolutions.

    All keyword arguments are forwarded to ``FGTOnlineSpikingCNN``.
    """
    layer_spec = cfgs['A']
    return FGTOnlineSpikingCNN(layer_spec, weight_standardization=True, **kwargs)


def fgtonline_spiking_dcnn_ws(**kwargs):
    """Build the deeper configuration-'B' network with weight-standardized convolutions.

    All keyword arguments are forwarded to ``FGTOnlineSpikingCNN``.
    """
    layer_spec = cfgs['B']
    return FGTOnlineSpikingCNN(layer_spec, weight_standardization=True, **kwargs)
