from copy import deepcopy
import torch
from torch import nn
from torch import autograd
from torch.nn import functional as F
from torch.autograd import Variable
import torch.utils.data

import os
import sys
import pickle
from collections import defaultdict
from tqdm import tqdm
import random

# Globally turn off the cuDNN backend for this process (import-time side effect).
# NOTE(review): the reason is not recorded in this file -- presumably a
# reproducibility or cuDNN-compatibility workaround; confirm before removing.
torch.backends.cudnn.enabled = False


def variable(t: torch.Tensor, use_cuda=True, **kwargs):
    """Wrap ``t`` in a ``Variable``, moving it to the GPU first when possible.

    Args:
        t: Tensor to wrap.
        use_cuda: When truthy and CUDA is available, ``t.cuda()`` is wrapped
            instead of ``t`` itself.
        **kwargs: Forwarded unchanged to ``Variable`` (e.g. ``requires_grad``).

    Returns:
        The ``Variable`` built from the (possibly moved) tensor.
    """
    gpu_ready = torch.cuda.is_available()
    tensor = t.cuda() if (gpu_ready and use_cuda) else t
    return Variable(tensor, **kwargs)

class EWC(object):
    """Elastic Weight Consolidation (EWC) bookkeeping and penalty terms.

    Two dictionaries, both keyed by task id, hold the per-task state:

    * ``fisher_dict[task_id][name]`` -- squared loss gradients of a parameter,
      averaged over the batches seen in :meth:`estimate_fisher`.
    * ``optpar_dict[task_id][name]`` -- a copy of the parameter's value taken
      when :meth:`estimate_fisher` ran for that task.

    ``name`` is the name from ``model.named_parameters()`` with every ``'.'``
    replaced by ``'__'``. The penalty methods walk ``range(task_id)``, so task
    ids are expected to be consecutive integers starting at 0.
    """

    def __init__(self, dicts_path):
        """Load previously saved dictionaries from ``dicts_path`` if it exists.

        Args:
            dicts_path: File used by :meth:`load_dicts` and :meth:`save_dicts`.
                When the path does not exist both dictionaries start empty.
        """
        self.dicts_path = dicts_path
        if os.path.exists(self.dicts_path):
            self.load_dicts()
        else:
            self.fisher_dict = {}
            self.optpar_dict = {}

    def estimate_fisher(self, model, data_loader, run_batch_fn, task_id):
        """Estimate per-parameter Fisher values for ``task_id`` and snapshot the model.

        The model is switched to eval mode (and left there) and its gradients
        are zeroed. For every batch the loss returned by ``run_batch_fn(batch)``
        is differentiated with respect to the trainable parameters and the
        squared gradients are accumulated; the sum is divided by the number of
        batches. Parameters that do not require grad, or that the loss does not
        depend on, get a Fisher value of zero.

        Args:
            model: Module whose parameters are analysed.
            data_loader: Sized iterable of batches.
            run_batch_fn: Callable mapping one batch to a scalar loss that is
                connected to ``model``'s parameters.
            task_id: Key under which the results are stored.

        Raises:
            ValueError: If ``data_loader`` has no batches.
        """
        num_batches = len(data_loader)
        if num_batches == 0:
            # Fail before touching stored state instead of dividing by zero later.
            raise ValueError('estimate_fisher requires a non-empty data_loader')

        model.eval()
        model.zero_grad()

        named_params = [
            (name.replace('.', '__'), param)
            for name, param in model.named_parameters()
        ]
        # Differentiate only w.r.t. trainable parameters so that every gradient
        # lines up with its own name (frozen parameters would otherwise shift
        # the pairing or make autograd.grad fail).
        trainable = [(key, param) for key, param in named_params if param.requires_grad]
        trainable_params = [param for _, param in trainable]

        fisher = defaultdict(float)
        for batch in tqdm(data_loader, total=num_batches):
            loss = run_batch_fn(batch)
            # retain_graph is kept from the original implementation in case
            # run_batch_fn shares graph state between calls -- TODO confirm.
            grads = autograd.grad(
                loss, trainable_params, retain_graph=True, allow_unused=True
            )
            for (key, _), grad in zip(trainable, grads):
                if grad is not None:  # None: the loss does not depend on this parameter
                    fisher[key] += grad ** 2

        optpar = defaultdict(float)
        for key, param in named_params:
            optpar[key] = param.data.clone()
            # Missing entries default to 0.0, i.e. zero Fisher for frozen/unused params.
            fisher[key] = fisher[key] / float(num_batches)

        # Publish only once the whole estimate succeeded.
        self.fisher_dict[task_id] = fisher
        self.optpar_dict[task_id] = optpar

    def load_dicts(self):
        """Replace both dictionaries with the contents of ``self.dicts_path``."""
        # SECURITY: pickle.load can execute arbitrary code; only point
        # dicts_path at files produced by save_dicts from a trusted source.
        with open(self.dicts_path, 'rb') as fp:
            dicts = pickle.load(fp)
        self.fisher_dict, self.optpar_dict = dicts['fisher'], dicts['optpar']

    def save_dicts(self):
        """Pickle both dictionaries to ``self.dicts_path``."""
        with open(self.dicts_path, 'wb') as fp:
            pickle.dump({'fisher': self.fisher_dict, 'optpar': self.optpar_dict}, fp)

    def penalty(self, model: nn.Module, task_id, w_decay=1, penalty_mode='single'):
        """Return the EWC penalty for ``model`` while training task ``task_id``.

        Args:
            model: Module whose current parameters are penalised.
            task_id: Id of the task being trained; only earlier tasks contribute.
            w_decay: Per-task weight decay used by the ``'all'`` and
                ``'random'`` modes (ignored by ``'single'``).
            penalty_mode: ``'single'`` (previous task only), ``'all'`` (every
                earlier task, most recent weighted highest) or ``'random'``
                (every earlier task, decay weights randomly assigned).

        Raises:
            ValueError: If ``penalty_mode`` is not one of the values above.
        """
        if penalty_mode == 'single':
            return self.single_penalty(model, task_id, 1)
        if penalty_mode == 'all':
            return self.multi_penalty(model, task_id, w_decay)
        if penalty_mode == 'random':
            return self.random_penalty(model, task_id, w_decay)
        # Previously fell through and returned None, which only failed later
        # when the caller added it to a loss.
        raise ValueError(
            "unknown penalty_mode {!r}; expected 'single', 'all' or 'random'".format(
                penalty_mode
            )
        )

    def _task_penalty(self, model, task):
        """Return ``sum(fisher * (optpar - param) ** 2)`` over all parameters for one stored task."""
        fisher_by_name = self.fisher_dict[task]
        optpar_by_name = self.optpar_dict[task]
        total = 0
        for name, param in model.named_parameters():
            key = name.replace('.', '__')
            total += (fisher_by_name[key] * (optpar_by_name[key] - param) ** 2).sum()
        return total

    def single_penalty(self, model: nn.Module, task_id, w_decay=1):
        """Return the penalty of the task immediately preceding ``task_id``.

        ``w_decay`` is accepted for signature compatibility only: with a single
        contributing task its weight is always 1. Returns 0 when there is no
        preceding task (``task_id <= 0``), matching the other penalty modes.
        """
        if task_id <= 0:
            return 0
        return self._task_penalty(model, task_id - 1)

    def multi_penalty(self, model: nn.Module, task_id, w_decay):
        """Return the decayed sum of the penalties of all tasks before ``task_id``.

        Task ``task_id - 1`` has weight 1 and each older task's weight is the
        previous one multiplied by ``w_decay``. Returns 0 when ``task_id <= 0``.
        """
        ewc_penalty = 0
        w = 1
        for task in range(task_id - 1, -1, -1):
            ewc_penalty += w * self._task_penalty(model, task)
            # Decay once per task (not once per parameter).
            w = w * w_decay
        return ewc_penalty

    def random_penalty(self, model: nn.Module, task_id, w_decay):
        """Return the sum of all earlier tasks' penalties with shuffled decay weights.

        The weights ``1, w_decay, w_decay**2, ...`` (one per earlier task) are
        randomly permuted with ``random.shuffle`` and assigned to tasks
        ``0 .. task_id - 1``. Returns 0 when ``task_id <= 0``.
        """
        # Exactly one weight per earlier task, so none is dropped by the shuffle.
        weights = []
        w = 1
        for _ in range(task_id):
            weights.append(w)
            w = w * w_decay
        random.shuffle(weights)

        ewc_penalty = 0
        for task, task_weight in enumerate(weights):
            ewc_penalty += task_weight * self._task_penalty(model, task)
        return ewc_penalty

        # Now we need to generate the EWC-updated loss

