import torch
import torch.nn as nn
from copy import deepcopy
from modules import layer

# Public API of this module: the model class and its two preset factories.
__all__ = ['OPZOSpikingCNN', 'opzo_spiking_cnnws', 'opzo_spiking_dcnnws']

class Scale(nn.Module):
    """Multiply the input by a constant factor fixed at construction time."""

    def __init__(self, scale):
        super().__init__()
        # Plain Python attribute (not a Parameter/buffer), so it is not trained.
        self.scale = scale

    def forward(self, x):
        """Return ``x * self.scale``."""
        factor = self.scale
        scaled = x * factor
        return scaled


class OPZOSpikingCNN(nn.Module):
    """Spiking CNN trainable with backprop or with BP-free feedback rules.

    The network is ``features`` (the conv / spiking-neuron / pooling stack built
    by :meth:`make_layers`) followed by ``classifier`` (adaptive average pool,
    flatten, linear). :meth:`forward` selects the training path from its flags
    and from ``feedback_mode``.

    Args:
        cfg: layer spec. Each entry is either ``'M'`` (2x2 average pooling,
            stride 2) or ``[out_channels, kernel_size]`` (conv + neuron).
        weight_standardization: use ``layer.WSConv2d`` instead of
            ``nn.Conv2d`` and append ``Scale(2.74)`` after every conv block.
        num_classes: number of output logits.
        spiking_neuron: neuron module class. It is instantiated once per conv
            with a deep copy of ``**kwargs`` and also used for ``isinstance``
            checks, so it must be a class.
        fc_hw: output size of the adaptive pool in front of the classifier.
        c_in: number of input channels.
        feedback_mode: one of 'DFA', 'PZO', 'ZO', 'DKP'. 'DFA', 'PZO' and
            'DKP' create one bias-free ``nn.Linear(num_classes, C*H*W)`` per
            conv block in ``self.fb``; 'ZO' uses none.
        momentum_fb: EMA factor for the 'PZO' feedback-weight estimate.
        p_scale: multiplier applied to the random perturbations.
        h_in, w_in: input spatial size. Used to size the feedback and
            local-loss layers.
        local_loss: add one ``nn.Linear(C*H*W, num_classes)`` local classifier
            per conv block.
        drop_rate: if > 0, insert ``layer.Dropout(drop_rate)`` after each
            neuron.
        p_type: 'Rademacher' for +-1 perturbations; anything else is Gaussian.
        **kwargs: forwarded to every ``spiking_neuron`` instance.
    """

    def __init__(self, cfg, weight_standardization=True, num_classes=10, spiking_neuron: callable = None, fc_hw=1, c_in=3, feedback_mode='PZO', momentum_fb=0.99999, p_scale=0.2, h_in=32, w_in=32, local_loss=False, drop_rate=0.0, p_type='Gaussian', **kwargs):
        super(OPZOSpikingCNN, self).__init__()
        self.neuron = spiking_neuron
        self.features, c_final = self.make_layers(cfg=cfg, ws=weight_standardization, c_in=c_in, neuron=spiking_neuron, drop_rate=drop_rate, **kwargs)
        self.classifier = layer.OTTTSequential(
            nn.AdaptiveAvgPool2d((fc_hw, fc_hw)),
            nn.Flatten(1),
            nn.Linear(c_final*(fc_hw**2), num_classes),
        )
        # Runs before fb / local_fc exist, so those keep their own init below.
        self._initialize_weights()

        # Flattened C*H*W size of every conv block's output, in cfg order.
        feature_dims = self._conv_feature_sizes(cfg, h_in, w_in)

        # DFA or PZO or ZO or DKP
        self.feedback_mode = feedback_mode
        self.momentum_fb = momentum_fb
        self.p_scale = p_scale
        self.p_type = p_type
        if feedback_mode in ['DFA', 'PZO', 'DKP']:
            self.fb = nn.ModuleList()
            for dim in feature_dims:
                self.fb.append(nn.Linear(num_classes, dim, bias=False))
            if feedback_mode == 'DFA':
                # Fixed random feedback: these weights are never updated here.
                for m in self.fb:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
            else:
                # 'PZO' / 'DKP' learn the feedback weights starting from zero.
                for m in self.fb:
                    nn.init.constant_(m.weight, 0)

        # Perturbations stored by one call and reused (negated) by the next
        # when forward(..., antithetic=True).
        self.perturb_var_list = []

        # local loss
        self.local_loss = local_loss
        if local_loss:
            self.local_fc = nn.ModuleList()
            for dim in feature_dims:
                self.local_fc.append(nn.Linear(dim, num_classes))

            # Weight applied to the local-loss terms.
            self.ll_lambda = 0.01

    @staticmethod
    def _conv_feature_sizes(cfg, h_in, w_in):
        """Return the flattened output size (C*H*W) of each conv block in ``cfg``.

        Mirrors the spatial arithmetic of :meth:`make_layers`:
        - each conv is stride 1 with padding ``(k-1)//2``, so its output side
          is ``h + 2*pad - k + 1`` (unchanged for odd ``k``);
        - each ``'M'`` is ``nn.AvgPool2d(2, 2)``, which floors odd sides.

        H and W are tracked separately. ``(H*W)//4`` would be wrong as soon as
        a side becomes odd (e.g. 7x7 pools to 3x3 = 9, not 49//4 = 12).
        """
        h, w = h_in, w_in
        sizes = []
        for v in cfg:
            if v == 'M':
                h, w = h // 2, w // 2
            else:
                channels, k = v[0], v[1]
                pad = (k - 1) // 2
                h = h + 2 * pad - k + 1
                w = w + 2 * pad - k + 1
                sizes.append(channels * h * w)
        return sizes

    def forward(self, x, loss_func=None, bp=False, deterministic=True, antithetic=True, perturb_before_neuron=False, only_local_loss=False, zo_bias=0.):
        """Run the network; in training mode also fill parameter ``.grad``.

        Paths:

        * ``bp`` with ``local_loss`` in training mode: ordinary backprop of
          ``loss_func`` through the whole network, then a second pass over
          ``features`` from the same input that backpropagates every block's
          local loss (scaled by ``self.ll_lambda``).
          Returns ``(logits, global_loss + local_losses)`` as a float.
        * ``bp`` otherwise: plain forward. With ``loss_func`` the loss is
          computed, backpropagated in training mode, and ``(logits, loss)`` is
          returned; without it, only ``logits``.
        * evaluation mode without ``bp``: forward pass. If ``deterministic``
          is False, a random perturbation is added at every neuron. Returns
          ``(logits, loss)`` if ``loss_func`` is given, else ``logits``.
        * training mode without ``bp``: a no-grad forward records each
          parameterised layer's input and each neuron's ``sg`` output. Only
          ``dL/dlogits`` comes from autograd. Per-block error signals are then
          built according to ``self.feedback_mode`` and written into
          ``.grad`` through :meth:`set_grad`. Returns ``(logits, loss)``.

        Args:
            x: network input.
            loss_func: callable mapping logits to a scalar loss tensor.
                Required for the training paths that compute a loss.
            bp: use autograd backprop.
            deterministic: evaluation only; False enables perturbations.
            antithetic: BP-free path. The first call with an empty
                ``self.perturb_var_list`` stores its perturbations; the next
                call reuses their negations and then clears the list.
            perturb_before_neuron: perturb layer outputs before the neuron
                (pre-activations) instead of the neuron outputs.
            only_local_loss: BP-free path; block errors come only from the
                local classifiers (requires ``local_loss``).
            zo_bias: baseline subtracted from the loss in 'ZO' mode.
        """
        if bp and self.local_loss and self.training:
            assert loss_func is not None
            # first bp
            x_bp = self.features(x)
            x_bp = self.classifier(x_bp)
            loss = loss_func(x_bp)
            loss.backward()
            loss_all = loss.item()
            # then local loss, recomputed from the raw input
            ll_loss = 0
            ll_idx = 0
            for module in self.features:
                if len(list(module.parameters())) > 0: # Conv/Linear
                    if not isinstance(x, list):
                        x = module(x)
                    else:
                        x = layer.GradwithTrace(module)(x)
                elif isinstance(module, self.neuron):
                    x = module(x)
                    # calculate local loss
                    ll_loss += loss_func(self.local_fc[ll_idx](x[0].flatten(1))) * self.ll_lambda
                    ll_idx += 1
                    # Detach the first neuron output so later local losses do
                    # not backprop through it (presumably so each block is
                    # trained only by its own local loss).
                    x[0] = x[0].detach()
                else:
                    if isinstance(x, list):
                        x = layer.SpikeTraceOp(module)(x)
                    else:
                        x = module(x)
            ll_loss.backward()
            loss_all += ll_loss.item()
            return x_bp, loss_all
        elif bp:
            x = self.features(x)
            x = self.classifier(x)

            if loss_func is not None:
                loss = loss_func(x)
                if self.training:
                    loss.backward()
                return x, loss.item()
            else:
                return x
        elif not self.training:
            for module in self.features:
                if (not deterministic) and isinstance(module, self.neuron):
                    perturb = self.get_perturb(x, self.p_scale)
                    if perturb_before_neuron:
                        x = module(x + perturb)
                    else:
                        x = module(x) + perturb
                else:
                    x = module(x)
            x = self.classifier(x)

            if loss_func is not None:
                loss = loss_func(x)
                return x, loss.item()
            else:
                return x
        else:
            in_for_grad_list = []  # input of every parameterised layer, for set_grad
            perturb_list = []  # perturbations; later overwritten by block error signals
            sg_list = []  # second output of every neuron (presumably its surrogate gradient)
            local_output = []  # flattened neuron outputs, inputs of the local classifiers
            dkp_for_grad_list = []  # neuron x[1] outputs, used to train the 'DKP' feedback

            # Antithetic sampling: reuse the negated perturbations stored by
            # the previous call if there are any, otherwise sample and store.
            use_antithetic = False
            if antithetic:
                if len(self.perturb_var_list) > 0:
                    use_antithetic = True
                    self.perturb_var_index = 0

            def draw_perturb(ref):
                """Return the next perturbation, shaped like ``ref``."""
                if antithetic and use_antithetic:
                    perturb = -self.perturb_var_list[self.perturb_var_index]
                    self.perturb_var_index += 1
                    return perturb
                perturb = self.get_perturb(ref, self.p_scale)
                if antithetic:
                    self.perturb_var_list.append(perturb)
                return perturb

            perturbing = self.feedback_mode in ['PZO', 'ZO'] and not only_local_loss

            def sequential_forward(modules, x):
                for module in modules:
                    if len(list(module.parameters())) > 0: # Conv/Linear
                        # For list inputs, x[1] is recorded for the weight
                        # gradient while x[0] is fed forward (presumably
                        # [spikes, traces]).
                        if isinstance(x, list):
                            in_for_grad_list.append(x[1])
                            x = x[0]
                        else:
                            in_for_grad_list.append(x)
                        x = module(x)
                        if perturbing and perturb_before_neuron:
                            # Note: this also perturbs the classifier's Linear;
                            # that entry is removed after the forward pass.
                            perturb = draw_perturb(x)
                            x += perturb
                            perturb_list.append(perturb)
                    elif isinstance(module, self.neuron):
                        x, sg = module(x, return_grad=True)
                        sg_list.append(sg)
                        if self.local_loss:
                            local_output.append(x[0].flatten(1))
                        if perturbing and not perturb_before_neuron:
                            perturb = draw_perturb(x[0])
                            x[0] += perturb
                            x[1] += perturb
                            perturb_list.append(perturb)
                        if self.feedback_mode == 'DKP':
                            dkp_for_grad_list.append(x[1])
                    else:
                        if isinstance(x, list):
                            x = layer.SpikeTraceOp(module)(x)
                        else:
                            x = module(x)
                return x

            # forward propagation
            with torch.no_grad():
                x = sequential_forward(self.features, x)
                x = sequential_forward(self.classifier, x)
                if perturbing and perturb_before_neuron:
                    # The last layer does not need perturbation: undo it.
                    # Also drop its entry so perturb_list lines up with
                    # sg_list (one entry per conv block).
                    x -= perturb_list.pop()
                output = x.clone().detach()

            # feedback signals: only dL/dlogits (and local-classifier grads)
            # come from autograd
            with torch.enable_grad():
                x.requires_grad_(True)
                loss = loss_func(x)
                if self.local_loss:
                    loss_local = 0
                    for i in range(len(local_output)):
                        local_output[i].requires_grad_(True)
                        loss_local += loss_func(self.local_fc[i](local_output[i]))
            loss.backward()
            if self.local_loss:
                loss_local.backward()

            grad_last = x.grad.data

            # update feedback connections
            if self.feedback_mode == 'PZO' and not only_local_loss:
                # EMA of perturbation-output correlations, one per block.
                for i in range(len(self.fb)):
                    new_weight = perturb_list[i].flatten(1).t().mm(output) / output.shape[0] / self.p_scale
                    self.fb[i].weight.data *= self.momentum_fb
                    self.fb[i].weight.data += (1 - self.momentum_fb) * new_weight
            elif self.feedback_mode == 'DKP' and not only_local_loss:
                for i in range(len(self.fb)):
                    self.set_grad(self.fb[i], grad_last, dkp_for_grad_list[i].flatten(1))

            # feedback propagation: turn perturb_list into per-block errors
            with torch.no_grad():
                if not only_local_loss:
                    if self.feedback_mode == 'PZO':
                        for i in range(len(self.fb)):
                            perturb_list[i] = self.fb[i](grad_last).reshape(perturb_list[i].shape) * sg_list[i]
                    elif self.feedback_mode in ['DFA', 'DKP']:
                        for i in range(len(self.fb)):
                            perturb_list.append(self.fb[i](grad_last).reshape(sg_list[i].shape) * sg_list[i])
                    else:
                        # 'ZO': scale each perturbation by the (baselined) loss.
                        for i in range(len(perturb_list)):
                            perturb_list[i] = perturb_list[i] * (loss - zo_bias) * sg_list[i] / self.p_scale

                    if self.local_loss:
                        for i in range(len(local_output)):
                            perturb_list[i] += self.ll_lambda * local_output[i].grad.reshape(sg_list[i].shape) * sg_list[i]
                else:
                    assert self.local_loss
                    for i in range(len(local_output)):
                        perturb_list.append(local_output[i].grad.reshape(sg_list[i].shape) * sg_list[i])

            # set grad: conv layers get their block error, the classifier's
            # Linear gets the true output gradient
            index = 0
            for module in self.features:
                if isinstance(module, nn.Conv2d):
                    self.set_grad(module, in_for_grad_list[index], perturb_list[index])
                    index += 1
            self.set_grad(self.classifier[-1], in_for_grad_list[index], grad_last)

            if antithetic and use_antithetic:
                self.perturb_var_list = []

            return x, loss.item()

    def get_perturb(self, x, p_scale):
        """Sample a perturbation shaped like ``x``, multiplied by ``p_scale``.

        Uses +-1 entries when ``self.p_type == 'Rademacher'``; otherwise
        standard normal entries.
        """
        if self.p_type == 'Rademacher':
            perturb = torch.rand_like(x)
            perturb = (perturb >= 0.5).float() * 2 - 1
        else: # Gaussian
            perturb = torch.randn_like(x)
        return perturb * p_scale

    def set_p_scale(self, p_scale):
        """Replace the perturbation scale used by later forward calls."""
        self.p_scale = p_scale

    def set_grad(self, op, in_for_grad, grad):
        """Backpropagate ``grad`` through ``op(in_for_grad)``.

        This accumulates into the ``.grad`` of ``op``'s parameters (and of
        ``in_for_grad``, which is marked as requiring grad).
        """
        with torch.enable_grad():
            in_for_grad.requires_grad_(True)
            tmp = op(in_for_grad)
        tmp.backward(grad)

    def _initialize_weights(self):
        """Initialise conv / batch-norm / linear layers currently in the model.

        - Convs: Kaiming normal (fan_out).
        - BatchNorm2d: weight 1, bias 0.
        - Linear: weights N(0, 0.01), zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    @staticmethod
    def make_layers(cfg, ws=True, c_in=3, neuron: callable = None, drop_rate=0., **kwargs):
        """Build the feature stack described by ``cfg``.

        Each ``[c, k]`` entry becomes, in order:
        - a ``k``x``k`` conv with padding ``(k-1)//2`` (``layer.WSConv2d``
          when ``ws``, else ``nn.Conv2d``);
        - ``neuron(**deepcopy(kwargs))``;
        - ``layer.Dropout(drop_rate)``, if ``drop_rate > 0``;
        - ``Scale(2.74)``, if ``ws``.

        Each ``'M'`` becomes ``nn.AvgPool2d(2, 2)``.

        Returns ``(layer.OTTTSequential(...), out_channels_of_last_conv)``.
        """
        layers = []
        in_channels = c_in
        Conv2d = layer.WSConv2d if ws else nn.Conv2d
        for v in cfg:
            if v == 'M':
                layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = Conv2d(in_channels, v[0], kernel_size=v[1], padding=(v[1]-1)//2)
                layers += [conv2d, neuron(**deepcopy(kwargs))]
                if drop_rate > 0.:
                    layers += [layer.Dropout(drop_rate)]
                if ws:
                    layers += [Scale(2.74)]
                in_channels = v[0]
        return layer.OTTTSequential(*layers), in_channels

    def get_grads(self):
        """Return CPU copies of the weight gradients of every ``nn.Conv2d`` in ``features``.

        Raises AttributeError if a conv has no gradient yet (``.grad`` is None).
        """
        grads = []
        for module in self.features:
            if isinstance(module, nn.Conv2d):
                grads.append(module.weight.grad.cpu().detach())
        return grads

    def get_spike(self, x, deterministic=True):
        """Return the flattened output of every neuron in ``features`` (eval mode only).

        ``deterministic`` is accepted but unused. Assumes neurons return a
        single tensor in eval mode — TODO confirm.
        """
        assert not self.training
        spikes = []
        for module in self.features:
            x = module(x)
            if isinstance(module, self.neuron):
                spikes.append(x.clone().detach().cpu().flatten(1))
        return spikes




# Network presets. An [out_channels, kernel_size] entry is a conv + spiking
# neuron block; 'M' is a 2x2 average pooling with stride 2.
cfgs = dict(
    A=[[128, 3], 'M', [256, 3], 'M', [512, 3], 'M', [512, 3]],
    B=[
        [64, 3], [128, 3], 'M',
        [256, 3], [256, 3], 'M',
        [512, 3], [512, 3], 'M',
        [512, 3], [512, 3],
    ],
)


def opzo_spiking_cnnws(**kwargs):
    """Build an ``OPZOSpikingCNN`` from preset ``cfgs['A']`` with weight standardization.

    Every keyword argument is forwarded to the ``OPZOSpikingCNN`` constructor.
    """
    layer_spec = cfgs['A']
    return OPZOSpikingCNN(layer_spec, weight_standardization=True, **kwargs)

def opzo_spiking_dcnnws(**kwargs):
    """Build an ``OPZOSpikingCNN`` from the deeper preset ``cfgs['B']`` with weight standardization.

    Every keyword argument is forwarded to the ``OPZOSpikingCNN`` constructor.
    """
    layer_spec = cfgs['B']
    return OPZOSpikingCNN(layer_spec, weight_standardization=True, **kwargs)
