import torch
import numpy as np
import ShiftingWindowSetting as sw


# Switch the model to eval mode to make sure batch-norm statistics are not affected.
def empirical_fisher_diag(model, loss_fn, data_loader, device, nullClasses):
    """Estimate the diagonal of the empirical Fisher information matrix.

    For every batch in ``data_loader`` the loss is back-propagated and the
    element-wise square of each trainable parameter's gradient is accumulated;
    the accumulated values are then divided by the number of batches.  The
    squared gradient is that of the batch loss as returned by ``loss_fn``.

    The model is put into eval mode while the batches are processed and is
    switched to train mode, with gradients cleared, before returning.

    Args:
        model: Module whose trainable parameters are analysed.
        loss_fn: Callable ``loss_fn(output, Y)`` whose result supports
            ``backward()``.
        data_loader: Sized iterable yielding ``(X, Y)`` batches.
        device: Device the batches are moved to and the result is created on.
        nullClasses: Forwarded unchanged to ``sw.calc_model_output``.

    Returns:
        dict mapping each trainable parameter's name to a tensor with the
        same shape as that parameter.  The values stay zero when
        ``data_loader`` is empty.
    """
    model.eval()
    fisher = {
        name: torch.zeros_like(param, device=device)
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    for X, Y in data_loader:
        model.zero_grad()
        X, Y = X.to(device), Y.to(device)
        loss = loss_fn(sw.calc_model_output(model, X, nullClasses), Y)
        loss.backward()
        for name, param in model.named_parameters():
            # A trainable parameter that did not contribute to this batch's
            # loss has no gradient (``grad`` is None); it adds nothing.
            if param.requires_grad and param.grad is not None:
                fisher[name] += param.grad.detach() ** 2

    num_batches = len(data_loader)
    if num_batches > 0:  # avoid 0/0 -> NaN for an empty loader
        for name in fisher:
            fisher[name] /= num_batches

    model.zero_grad()
    model.train()

    return fisher


# Pass in the FIM of a layer's weights to calculate which nodes (filters) are "active".
def calc_active_filters_and_sum_of_fim(fim, device):
    """Flag the filters of one layer that hold at least an even share of its FIM.

    The FIM entries are normalised so they sum to one and are then summed per
    filter (index along dimension 0).  A filter counts as "active" when its
    normalised sum is at least ``1 / number_of_filters``, i.e. the share each
    filter would hold if the information were spread evenly.

    Args:
        fim: Tensor of Fisher information values for one layer's weights,
            indexed by filter along dimension 0.
        device: Not needed for the computation (the tensor is always copied
            to host memory); kept so existing callers keep working.

    Returns:
        Tuple ``(active, total)`` where ``active`` is a boolean NumPy array
        with one entry per filter and ``total`` is the NumPy scalar sum of
        all FIM entries before normalisation.  When ``total`` is zero no
        filter is marked active.
    """
    # ``.cpu()`` is a no-op for tensors already in host memory, so there is
    # no need to inspect ``device``.
    values = fim.detach().cpu().numpy()
    total = values.sum()
    num_filters = values.shape[0]

    if total == 0:
        # Normalising would divide by zero; nothing carries information.
        return np.zeros(num_filters, dtype=bool), total

    normalised = (values / total).reshape((num_filters, -1))

    # Expected sum of normalised Fisher information terms for each filter.
    expected_filter_sum = 1 / num_filters
    return normalised.sum(axis=1) >= expected_filter_sum, total


# Be careful: a filter is seen as "active" if it is relatively so within its layer for that task. This does not take
# into account that another task might have a higher absolute FIM for that filter while, within that task, the filter
# is not relatively important and so is not marked as active, which affects our shared and new active filter metrics.
def analyse_fim(fim, previous_fims, device):
    """Print, layer by layer, how the current task's active filters overlap with earlier ones.

    Entries of ``fim`` whose name contains "IC", "bn" or "bias" are skipped.
    For every remaining entry the function prints the proportion of active
    filters, the total FIM divided by ``shape[0] * shape[1]``, the overlap
    with each FIM in ``previous_fims`` and three summary proportions.

    Args:
        fim: dict mapping parameter name to its FIM tensor for the current
            task.
        previous_fims: Sequence of dicts with the same keys as ``fim``.
        device: Forwarded to ``calc_active_filters_and_sum_of_fim``.

    Returns:
        None; all results are written to stdout.
    """
    for name, val in fim.items():
        if any(tag in name for tag in ("IC", "bn", "bias")):
            continue
        print(name)

        current_active, total_fim = calc_active_filters_and_sum_of_fim(val, device)
        num_filters = val.shape[0]
        num_active = np.sum(current_active)
        mean_fim = total_fim / (num_filters * val.shape[1])
        # ``!s`` keeps the plain ``str()`` rendering of the NumPy scalars.
        print(f"prop of filters active: {num_active / num_filters!s}")
        print(f"average fim per layer: {mean_fim!s}")

        # Independent copies, because each mask is updated in place below.
        in_every_task = current_active.copy()
        only_current = current_active.copy()
        in_any_task = current_active.copy()

        for i, old_fim in enumerate(previous_fims):
            old_active, _ = calc_active_filters_and_sum_of_fim(old_fim[name], device)
            in_every_task &= old_active
            # The final entry of previous_fims is left out of this mask --
            # presumably it is the current task's own FIM; confirm with caller.
            if i < len(previous_fims) - 1:
                only_current &= ~old_active
            in_any_task |= old_active
            overlap = np.sum(old_active & current_active) / num_active
            print(f"prop of shared filters (in active filters) with task {i!s}: {overlap!s}")

        shared_prop = np.sum(in_every_task) / num_filters
        print(f"prop of shared filters across all seen tasks: {shared_prop!s}")
        new_prop = np.sum(only_current) / num_filters
        print(f"prop of filters used only by the current task: {new_prop!s}")
        used_prop = np.sum(in_any_task) / num_filters
        print(f"prop of filters used by any task: {used_prop!s}")


class EWC_reg(sw.CLLearningAlgo):
    """EWC regulariser keeping one Fisher estimate and one parameter snapshot per task.

    Attributes:
        prev_param_list: One dict per finished task mapping parameter name to
            the parameter values recorded at the end of that task.
        fim_list: One dict per finished task mapping parameter name to the
            diagonal Fisher estimate returned by ``empirical_fisher_diag``.
        reg_coef: Coefficient stored at construction; it is not applied by
            any method of this class.
    """

    def __init__(self, args, reg_coef=10):
        """Initialise the base algorithm and start with empty per-task state.

        Args:
            args: Forwarded to the base class constructor.
            reg_coef: Value stored as ``self.reg_coef``.
        """
        super().__init__(args=args)
        self.reg_coef = reg_coef
        # Per-instance containers: class-level lists would be shared by (and
        # leak state between) every instance of this class.
        self.prev_param_list = []
        self.fim_list = []

    def calc_reg_loss_term(self):
        """Return the EWC penalty summed over all stored tasks.

        For every stored task and every trainable parameter this adds
        ``sum(fim * (param - prev_param) ** 2)``.

        Returns:
            Tensor of shape ``(1,)`` on ``self.device``; zero when no task has
            been stored yet.
        """
        reg_loss = torch.zeros(1, device=self.device)
        for fim, prev_params in zip(self.fim_list, self.prev_param_list):
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    reg_loss += torch.sum(fim[name] * (param - prev_params[name]) ** 2)
        return reg_loss

    def at_end_of_task(self):
        """Record the Fisher estimate and a parameter snapshot for the finished task."""
        fim = empirical_fisher_diag(self.model, self.loss_fn, self.data_loader, self.device, self.nullClasses)
        # detach() before clone(): a plain clone() stays connected to the
        # live parameter in the autograd graph, so the gradient flowing back
        # through the snapshot would exactly cancel the penalty's gradient.
        old_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        self.prev_param_list.append(old_params)
        self.fim_list.append(fim)

    def capcity_reg(self):
        """Return an L1-style penalty on weights whose Fisher value is below the layer mean.

        Only trainable parameters whose name contains both "conv" and "layer"
        contribute.  Each element of ``abs(param)`` is weighted by how far its
        Fisher value lies below ``factor * mean(fim)`` of that parameter
        (zero when it is at or above), and every stored task is weighted by
        ``1 / number_of_stored_tasks``.  The method name's spelling is kept
        as is for existing callers.

        Returns:
            Tensor of shape ``(1,)`` on ``self.device``.
        """
        reg_loss = torch.zeros(1, device=self.device)
        if not self.fim_list:
            return reg_loss

        factor = 1.0
        task_weight = 1 / len(self.fim_list)
        for fim in self.fim_list:
            for name, param in self.model.named_parameters():
                if param.requires_grad and ("conv" in name) and ("layer" in name):
                    threshold = factor * torch.mean(fim[name])
                    # threshold - min(fim, threshold): positive where the
                    # Fisher value is below the threshold, zero elsewhere.
                    modifier = threshold - torch.clamp(fim[name], max=threshold)
                    reg_loss += torch.sum(task_weight * modifier * torch.abs(param))
        return reg_loss


# This class computes the constant-memory version of EWC, where you
# sum over the FIMs of the previous tasks and shrink towards the last task's params.
class EWC_constMemReg(sw.CLLearningAlgo):
    """EWC regulariser that stores a single accumulated Fisher estimate.

    Attributes:
        prev_param: dict mapping parameter name to the values recorded at the
            end of the most recent task; empty until a task has finished.
        acc_fim: dict mapping parameter name to the sum of the Fisher
            estimates of all finished tasks.
        reg_coef: Coefficient stored at construction; it is not applied by
            any method of this class.
    """

    def __init__(self, args, reg_coef=10):
        """Initialise the base algorithm and zero the accumulated Fisher estimate.

        Args:
            args: Forwarded to the base class constructor.
            reg_coef: Value stored as ``self.reg_coef``.
        """
        super().__init__(args=args)
        self.reg_coef = reg_coef
        # Per-instance containers: class-level dicts would be shared by (and
        # leak state between) every instance of this class.
        self.prev_param = {}
        self.acc_fim = {
            name: torch.zeros_like(param, device=self.device)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    def at_end_of_task(self):
        """Add the finished task's Fisher estimate and re-anchor on the current parameters."""
        fim = empirical_fisher_diag(self.model, self.loss_fn, self.data_loader, self.device, self.nullClasses)
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.acc_fim[name] += fim[name]
                # detach() before clone(): a plain clone() stays connected to
                # the live parameter in the autograd graph, so the gradient
                # flowing back through the snapshot would exactly cancel the
                # penalty's gradient.
                self.prev_param[name] = param.detach().clone()

    def calc_reg_loss_term(self):
        """Return the EWC penalty towards the most recent task's parameters.

        Adds ``sum(acc_fim * (param - prev_param) ** 2)`` over every trainable
        parameter.

        Returns:
            Tensor of shape ``(1,)`` on ``self.device``; zero while no task
            has finished.
        """
        reg_loss = torch.zeros(1, device=self.device)

        if not self.prev_param:
            return reg_loss

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                reg_loss += torch.sum(self.acc_fim[name] * (param - self.prev_param[name]) ** 2)
        return reg_loss

