import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

from common import set_random_seed
from multihead_learner import ContinualMultiheadMLP

def generate_new_network(shared_structure, in_size, device, convnet=False):
    """Build the shared feature extractor as an ``nn.Sequential``.

    Every entry of ``shared_structure`` produces one trainable layer followed
    by a ``Tanh`` activation.

    * ``convnet=False``: each entry is the width of an ``nn.Linear`` layer;
      the first layer takes ``in_size`` input features.
    * ``convnet=True``: every entry except the last is the channel count of a
      ``Conv2d(kernel_size=5)`` followed by ``MaxPool2d(kernel_size=2)``; the
      first convolution takes ``in_size[0]`` input channels.  The last entry
      is the width of an ``nn.Linear`` placed after an ``nn.Flatten``; its
      input width is measured by pushing one dummy sample of shape
      ``(1, *in_size)`` through the layers built so far.

    Args:
        shared_structure: Sequence of layer sizes (widths / channel counts).
        in_size: Number of input features (MLP) or the per-sample input shape
            whose first element is the channel count (convnet).
        device: Device the parameterised layers are moved to.
        convnet: Select the convolutional layout instead of the MLP layout.

    Returns:
        The assembled ``nn.Sequential``; sub-modules are named ``layer{i}``,
        ``cnn{i}``, ``pooling{i}``, ``flatten`` and ``activation{i}``.
    """
    last_index = len(shared_structure) - 1
    modules = []
    for i, width in enumerate(shared_structure):
        if not convnet:
            fan_in = shared_structure[i - 1] if i > 0 else in_size
            modules.append((f"layer{i}", nn.Linear(in_features=fan_in, out_features=width).to(device=device)))
        elif i < last_index:
            in_channels = shared_structure[i - 1] if i > 0 else in_size[0]
            conv = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=5).to(device=device)
            modules.append((f"cnn{i}", conv))
            modules.append((f"pooling{i}", nn.MaxPool2d(kernel_size=2).to(device)))
        else:
            modules.append(("flatten", nn.Flatten().to(device)))
            # Measure the flattened feature count with a throw-away forward
            # pass.  No gradients are needed for a shape probe.  The probe
            # still draws from torch.rand (as before) so the global RNG is
            # advanced identically and seeded initialisation is unchanged.
            with torch.no_grad():
                probe = torch.rand(1, *in_size, device=device)
                flat_features = nn.Sequential(OrderedDict(modules))(probe).size(1)
            modules.append((f"layer{i}", nn.Linear(in_features=flat_features, out_features=width).to(device=device)))
        # Every layer, in both layouts, is followed by the same activation.
        modules.append((f"activation{i}", nn.Tanh()))
    return nn.Sequential(OrderedDict(modules))

class ContinualMultiheadEWC(ContinualMultiheadMLP):
    """Multi-head continual learner with an EWC-style quadratic penalty.

    After a task has been trained, ``get_previous_training`` stores a copy of
    the shared network's parameters together with squared gradients (used as
    a diagonal Fisher estimate).  ``loss`` then adds
    ``lamda * sum(fisher * (stored_param - param) ** 2)`` over all stored
    tasks to the label loss.
    """

    def __init__(self, shared_structure, in_size, out_size, device, fisher_estimation_sample_size=16, lamda=40,
                 noise_level=1e-2):
        """Create the learner and replace the shared network.

        Args:
            shared_structure: Layer sizes handed to ``generate_new_network``.
            in_size: Input size handed to the parent and to
                ``generate_new_network``.
            out_size: Forwarded to the parent constructor.
            device: Device for the shared network and for sampled noise.
            fisher_estimation_sample_size: Stored on the instance; not read
                by any code visible in this class.
            lamda: Multiplier applied to the EWC penalty in ``loss``.
            noise_level: Standard deviation of the Gaussian weight noise used
                when ``loss`` is called with ``specific_seed``.
        """
        super().__init__(shared_structure, in_size, out_size, device)
        # NOTE(review): the convolutional layout is inferred from the layer
        # sizes (more than one layer and first size != last size); an MLP
        # with differing widths would also match this test — confirm that
        # callers only rely on it as intended.
        is_conv = (len(shared_structure) > 1 and shared_structure[0] != shared_structure[-1])
        self.shared_net = generate_new_network(shared_structure, in_size, device, convnet=is_conv)
        self.fisher_estimation_sample_size = fisher_estimation_sample_size
        self.lamda = lamda
        self.noise_level = noise_level
        self.n_MC = 3  # number of noisy evaluations per seeded ``loss`` call
        self.batch_size = 64  # batch size of the single Fisher-estimation batch
        self.fisher_dict = {}  # task_id -> {param name -> squared gradient}
        self.opt_dict = {}  # task_id -> {param name -> parameter snapshot}

    def forward(self, task_data, task_id=None):
        """Delegate to the parent's ``forward``."""
        return super().forward(task_data, task_id)

    def adapt_new_task(self, task_id=None):
        """Delegate to the parent's ``adapt_new_task`` (its result is discarded)."""
        super().adapt_new_task(task_id)

    def get_previous_training(self, old_task_id, trainset):
        """Record parameter snapshot and squared gradients for a finished task.

        One shuffled batch of ``trainset`` is pushed through the model, the
        cross-entropy loss is back-propagated, and for every parameter of the
        shared network a detached copy of its value and of its squared
        gradient is stored under ``old_task_id``.

        Args:
            old_task_id: Key under which the snapshot is stored; also passed
                to the forward call as ``task_id``.
            trainset: Sequence of tensors unpacked into ``TensorDataset``
                (iterated as ``(x, y)`` pairs).
        """
        data_loader = DataLoader(TensorDataset(*trainset), batch_size=self.batch_size, shuffle=True)

        # Drop gradients left over from the last training step; otherwise
        # backward() below would accumulate on top of them and the stored
        # Fisher values would be the square of the mixed gradient.
        self.shared_net.zero_grad()

        for x, y in data_loader:
            x = x.to(self.device)
            # NOTE(review): with float targets F.cross_entropy treats y as
            # class probabilities — presumably labels are one-hot; confirm.
            y = y.float().to(self.device)
            F.cross_entropy(self(x, task_id=old_task_id), y).backward()
            break  # only need a small sample for fisher

        opt_params = {}
        fisher = {}
        for name, param in self.shared_net.named_parameters():
            opt_params[name] = param.detach().clone()
            grad = param.grad
            # No gradient (e.g. no batch was processed) means no importance
            # information for this parameter: store a zero Fisher entry.
            fisher[name] = (torch.zeros_like(param.detach()) if grad is None
                            else grad.detach().clone().pow(2))

        # Publish only after everything succeeded so that a failure above
        # cannot leave a half-filled entry behind for ewc_fisher_loss.
        self.opt_dict[old_task_id] = opt_params
        self.fisher_dict[old_task_id] = fisher

    def ewc_fisher_loss(self):
        """Return the summed Fisher-weighted squared distance to stored parameters.

        Returns:
            The int ``0`` when no task has been recorded, otherwise a scalar
            tensor summed over all recorded tasks and shared parameters.
        """
        penalty = 0
        if not self.fisher_dict:
            return penalty
        shared_params = list(self.shared_net.named_parameters())
        for task_id, fisher in self.fisher_dict.items():
            anchors = self.opt_dict[task_id]
            for name, param in shared_params:
                penalty = penalty + (fisher[name] * (anchors[name] - param).pow(2)).sum()
        return penalty

    def loss(self, task_data, task_labels, label_loss, task_id=None, specific_seed=None, is_test=False):
        """Compute the loss for a batch.

        Args:
            task_data: Inputs passed to ``forward``.
            task_labels: Targets passed to ``label_loss``.
            label_loss: Callable ``(predictions, task_labels) -> loss``.
            task_id: Passed through to ``forward``.
            specific_seed: When not ``None``, switch to the noisy
                Monte-Carlo evaluation described below.
            is_test: When true (and ``specific_seed`` is ``None``), return
                the label loss without the EWC penalty.

        Returns:
            * ``specific_seed is None``: the label loss, plus
              ``lamda * ewc_fisher_loss()`` unless ``is_test``.
            * otherwise: a list of ``n_MC`` label losses, each computed with
              fresh Gaussian noise (std ``noise_level``) added to the shared
              network's parameters.  No EWC penalty is added and ``is_test``
              is ignored on this path.
        """
        if specific_seed is None:
            normal_loss = label_loss(self.forward(task_data, task_id), task_labels)
            if is_test:
                return normal_loss
            return normal_loss + self.lamda * self.ewc_fisher_loss()

        # NOTE(review): torch.random.seed() does not merely read the current
        # seed; it reseeds the global generator with a fresh non-deterministic
        # value and returns that value.  Re-applying it afterwards therefore
        # does not resume the RNG stream that was active before this call.
        # Left unchanged because set_random_seed is defined elsewhere;
        # consider saving/restoring the RNG state instead.
        old_seed = torch.random.seed()
        set_random_seed(specific_seed)

        shared_params = list(self.shared_net.named_parameters())
        # Exact copies used to restore the weights; subtracting the noise
        # again is not exact in floating point and would drift over calls.
        originals = {name: param.detach().clone() for name, param in shared_params}
        losses = []
        try:
            for _ in range(self.n_MC):
                # In-place edits of leaf parameters that require grad are
                # only legal with autograd recording disabled.
                with torch.no_grad():
                    for _name, param in shared_params:
                        param.add_(torch.normal(0, self.noise_level, param.shape, device=self.device))
                try:
                    losses.append(label_loss(self.forward(task_data, task_id), task_labels))
                finally:
                    # Always put the clean weights back, even if the forward
                    # pass or the loss callable raised.
                    with torch.no_grad():
                        for name, param in shared_params:
                            param.copy_(originals[name])
        finally:
            set_random_seed(old_seed)
        return losses

