import torch
import torch.nn as nn
from copy import deepcopy
from modules import layer

# Names exported by a star-import of this module: only the factory function.
__all__ = [
    'opzo_spiking_fcn',
]

class Scale(nn.Module):
    """Module that multiplies whatever it receives by a stored factor."""

    def __init__(self, scale):
        super().__init__()
        # Kept exactly as passed in by the caller.
        self.scale = scale

    def forward(self, x):
        """Return ``x`` multiplied by ``self.scale``."""
        factor = self.scale
        return x * factor


class OPZOSpikingFCN(nn.Module):
    """Fully connected spiking network with several feedback-based training modes.

    The network is a stack of ``linear -> spiking_neuron`` stages (each
    optionally followed by dropout and a constant ``Scale``) plus a final
    linear classifier. ``forward`` can run ordinary backpropagation
    (``bp=True``), plain inference (``eval()`` mode), or an online update in
    which every hidden layer receives its gradient from a feedback signal.

    Feedback modes (``feedback_mode``):
        * ``'DFA'``: per-layer feedback matrices with Kaiming-normal
          initialisation; this class never updates them itself.
        * ``'PZO'``: zero-initialised feedback matrices, updated inside
          ``forward`` as a momentum average of ``perturbation^T @ output``.
        * ``'DKP'``: zero-initialised feedback matrices that receive their
          gradients through ``set_grad``.
        * ``'ZO'``: no feedback matrices; the hidden-layer signal is
          ``perturbation * loss / p_scale``.
    """

    def __init__(self, d_in=784, d_hidden=[800, 800], num_classes=10, weight_standardization=False, spiking_neuron: callable=None, feedback_mode='PZO', momentum_fb=0.99999, p_scale=0.2, local_loss=False, drop_rate=0.0, **kwargs):
        """Build the hidden stack, the classifier and the feedback modules.

        Args:
            d_in: Number of input features after flattening.
            d_hidden: Sizes of the hidden layers, in order. Never mutated.
            num_classes: Width of the classifier output.
            weight_standardization: Use ``layer.WSLinear`` instead of
                ``nn.Linear`` and append ``Scale(2.74)`` after each stage.
            spiking_neuron: Class used to build the neuron after every hidden
                linear layer; it is called with a deep copy of ``kwargs`` and
                is also used for ``isinstance`` checks in ``forward``.
            feedback_mode: One of ``'DFA'``, ``'PZO'``, ``'ZO'``, ``'DKP'``.
            momentum_fb: Momentum of the running average used to update the
                feedback matrices in ``'PZO'`` mode.
            p_scale: Standard deviation of the injected Gaussian perturbation.
            local_loss: Add one linear read-out per hidden layer whose loss
                contributes a local learning signal.
            drop_rate: Dropout rate; dropout layers are added only if > 0.
            **kwargs: Forwarded to ``spiking_neuron``.
        """
        super().__init__()
        self.neuron = spiking_neuron

        # Private copy so a shared (default) list can never be modified.
        d_hidden = list(d_hidden)

        if weight_standardization:
            linear = layer.WSLinear
        else:
            linear = nn.Linear

        layers = []
        in_dim = d_in
        for dim in d_hidden:
            layers += [linear(in_dim, dim), spiking_neuron(**deepcopy(kwargs))]
            if drop_rate > 0.:
                layers += [layer.Dropout(drop_rate)]
            if weight_standardization:
                layers += [Scale(2.74)]
            in_dim = dim

        self.features = layer.OTTTSequential(*layers)

        # ``in_dim`` equals the last hidden width (or ``d_in`` if there are
        # no hidden layers).
        self.classifier = layer.OTTTSequential(nn.Linear(in_dim, num_classes))

        # Only the modules created so far are re-initialised; the feedback
        # and local read-out layers below keep their own initialisation.
        self._initialize_weights()

        # DFA or PZO or ZO or DKP
        self.feedback_mode = feedback_mode
        self.momentum_fb = momentum_fb
        self.p_scale = p_scale
        if feedback_mode in ['DFA', 'PZO', 'DKP']:
            self.fb = nn.ModuleList()
            for dim in d_hidden:
                self.fb.append(nn.Linear(num_classes, dim, bias=False))

            if feedback_mode == 'DFA':
                for m in self.fb:
                    nn.init.kaiming_normal_(m.weight, mode='fan_out')
            else:
                for m in self.fb:
                    nn.init.constant_(m.weight, 0)

        # Perturbations stored by one online step so that the next step can
        # replay them with the opposite sign (antithetic sampling).
        self.perturb_var_list = []

        # local loss
        self.local_loss = local_loss
        if local_loss:
            self.local_fc = nn.ModuleList()
            for dim in d_hidden:
                self.local_fc.append(nn.Linear(dim, num_classes))

            # Weight of the local-loss signal when mixed with the global one.
            self.ll_lambda = 0.01

    def forward(self, x, loss_func=None, bp=False, deterministic=True, antithetic=True, perturb_before_neuron=False, only_local_loss=False):
        """Run the network and, in training mode, populate parameter gradients.

        Args:
            x: Input batch; flattened from dimension 1 onwards.
            loss_func: Callable mapping the classifier output to a scalar
                loss. Required for online training (training mode with
                ``bp=False``); optional otherwise.
            bp: Use ordinary backpropagation through ``features`` and
                ``classifier``.
            deterministic: In evaluation mode, skip the random perturbation
                around each neuron when True.
            antithetic: In online training, store the perturbations of one
                call and replay their negation on the following call.
            perturb_before_neuron: Add the perturbation to the linear output
                (before the neuron) instead of to the neuron output.
            only_local_loss: Use only the local read-out losses for the
                hidden layers; requires ``local_loss=True``.

        Returns:
            The classifier output alone when no loss is computed, otherwise
            the tuple ``(output, loss_value)`` with ``loss_value`` a float.

        Raises:
            ValueError: If online training is requested without ``loss_func``
                or ``only_local_loss`` is set on a model without local loss.
        """
        x = torch.flatten(x, 1)

        if bp:
            x = self.features(x)
            x = self.classifier(x)
            if loss_func is None:
                return x
            loss = loss_func(x)
            if self.training:
                loss.backward()
            return x, loss.item()

        if not self.training:
            for module in self.features:
                if (not deterministic) and isinstance(module, self.neuron):
                    perturb = self.get_perturb(x, self.p_scale)
                    if perturb_before_neuron:
                        x = module(x + perturb)
                    else:
                        x = module(x) + perturb
                else:
                    x = module(x)
            x = self.classifier(x)
            if loss_func is None:
                return x
            loss = loss_func(x)
            return x, loss.item()

        # ---- online training with feedback signals ----
        if loss_func is None:
            raise ValueError('loss_func is required for online training (training mode with bp=False)')
        if only_local_loss and not self.local_loss:
            raise ValueError('only_local_loss=True requires a model built with local_loss=True')

        in_for_grad_list = []   # input of every parameterised layer
        perturb_list = []       # perturbations, later overwritten by the layer signals
        sg_list = []            # second value returned by each neuron (return_grad=True)
        local_output = []       # neuron outputs fed to the local read-outs
        dkp_for_grad_list = []  # second element of each neuron output, used in DKP mode

        # Replay stored perturbations only when a previous call left some.
        use_antithetic = antithetic and len(self.perturb_var_list) > 0
        if use_antithetic:
            self.perturb_var_index = 0

        perturb_enabled = self.feedback_mode in ['PZO', 'ZO'] and not only_local_loss

        def draw_perturb(ref):
            # Return the next perturbation shaped like ``ref`` and record it.
            if use_antithetic:
                perturb = -self.perturb_var_list[self.perturb_var_index]
                self.perturb_var_index += 1
            else:
                perturb = self.get_perturb(ref, self.p_scale)
                if antithetic:
                    self.perturb_var_list.append(perturb)
            perturb_list.append(perturb)
            return perturb

        def sequential_forward(modules, x):
            for module in modules:
                if next(module.parameters(), None) is not None:  # Conv/Linear
                    # A neuron hands over a two-element list; the second
                    # element is kept as the input used for the gradient,
                    # the first is what is propagated forward.
                    if isinstance(x, list):
                        in_for_grad_list.append(x[1])
                        x = x[0]
                    else:
                        in_for_grad_list.append(x)
                    x = module(x)
                    if perturb_enabled and perturb_before_neuron:
                        x += draw_perturb(x)
                elif isinstance(module, self.neuron):
                    x, sg = module(x, return_grad=True)
                    sg_list.append(sg)
                    if self.local_loss:
                        local_output.append(x[0].flatten(1))
                    if perturb_enabled and not perturb_before_neuron:
                        perturb = draw_perturb(x[0])
                        x[0] += perturb
                        x[1] += perturb
                    if self.feedback_mode == 'DKP':
                        dkp_for_grad_list.append(x[1])
                else:
                    if isinstance(x, list):
                        x = layer.SpikeTraceOp(module)(x)
                    else:
                        x = module(x)
            return x

        # forward propagation
        with torch.no_grad():
            x = sequential_forward(self.features, x)
            x = sequential_forward(self.classifier, x)
            if perturb_enabled and perturb_before_neuron:
                # The last layer does not need a perturbation: undo it and
                # drop it so ``perturb_list`` lines up with ``sg_list``.
                x = x - perturb_list.pop()
            output = x.clone().detach()

        # feedback signals
        with torch.enable_grad():
            x.requires_grad_(True)
            loss = loss_func(x)
            if self.local_loss:
                loss_local = 0
                for i in range(len(local_output)):
                    local_output[i].requires_grad_(True)
                    loss_local += loss_func(self.local_fc[i](local_output[i]))
        loss.backward()
        if self.local_loss:
            loss_local.backward()

        grad_last = x.grad.data

        # update feedback connections
        if not only_local_loss:
            if self.feedback_mode == 'PZO':
                for i, fb in enumerate(self.fb):
                    new_weight = perturb_list[i].flatten(1).t().mm(output) / output.shape[0] / self.p_scale
                    fb.weight.data *= self.momentum_fb
                    fb.weight.data += (1 - self.momentum_fb) * new_weight
            elif self.feedback_mode == 'DKP':
                for i, fb in enumerate(self.fb):
                    self.set_grad(fb, grad_last, dkp_for_grad_list[i].flatten(1))

        # feedback propagation: turn ``perturb_list`` into per-layer signals
        with torch.no_grad():
            if not only_local_loss:
                if self.feedback_mode == 'PZO':
                    for i, fb in enumerate(self.fb):
                        perturb_list[i] = fb(grad_last).reshape(perturb_list[i].shape) * sg_list[i]
                elif self.feedback_mode in ['DFA', 'DKP']:
                    for i, fb in enumerate(self.fb):
                        perturb_list.append(fb(grad_last).reshape(sg_list[i].shape) * sg_list[i])
                else:
                    for i, perturb in enumerate(perturb_list):
                        perturb_list[i] = perturb * loss * sg_list[i] / self.p_scale

                if self.local_loss:
                    for i, local in enumerate(local_output):
                        perturb_list[i] += self.ll_lambda * local.grad.reshape(sg_list[i].shape) * sg_list[i]
            else:
                for i, local in enumerate(local_output):
                    perturb_list.append(local.grad.reshape(sg_list[i].shape) * sg_list[i])

        # set grad
        # NOTE(review): this relies on every parameterised feature layer being
        # an nn.Linear (sub)class - confirm this holds for layer.WSLinear.
        index = 0
        for module in self.features:
            if isinstance(module, nn.Linear):
                self.set_grad(module, in_for_grad_list[index], perturb_list[index])
                index += 1
        self.set_grad(self.classifier[-1], in_for_grad_list[index], grad_last)

        if use_antithetic:
            # The stored perturbations have been replayed; start a new pair.
            self.perturb_var_list = []

        return x, loss.item()

    def get_perturb(self, x, p_scale):
        """Return standard-normal noise shaped like ``x``, scaled by ``p_scale``."""
        perturb = torch.randn_like(x)
        return perturb * p_scale

    def set_p_scale(self, p_scale):
        """Set the perturbation scale used by subsequent forward passes."""
        self.p_scale = p_scale

    def set_grad(self, op, in_for_grad, grad):
        """Accumulate gradients of ``op`` by backpropagating ``grad`` through it.

        ``op`` is applied to ``in_for_grad`` with autograd enabled and
        ``grad`` is passed to ``backward`` as the gradient of that output.
        Note that ``in_for_grad`` is switched to ``requires_grad=True`` in place.
        """
        with torch.enable_grad():
            in_for_grad.requires_grad_(True)
            tmp = op(in_for_grad)
        tmp.backward(grad)

    def _initialize_weights(self):
        """Kaiming-normal weights and zero biases for every ``nn.Linear`` present."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def get_grads(self):
        """Return detached CPU copies of the weight gradients of the ``nn.Linear`` feature layers."""
        return [
            module.weight.grad.cpu().detach()
            for module in self.features
            if isinstance(module, nn.Linear)
        ]



def opzo_spiking_fcn(**kwargs):
    """Build an ``OPZOSpikingFCN``, passing every keyword argument to its constructor."""
    model = OPZOSpikingFCN(**kwargs)
    return model
