import logging


import torch
import torch.nn as nn


_logger = logging.getLogger(__name__)


class EWCWrapper(nn.Module):
    """Backbone plus a single head, with an EWC-style quadratic penalty.

    The wrapper keeps, per task index, a snapshot of the tracked parameters
    (``saved_params``) and a dictionary of per-parameter importances
    (``importances``). ``penalty`` sums
    ``importance * (param - saved_param) ** 2`` over every task for which
    importances have been registered.

    Tracked parameters are the ones yielded by ``iter_backbone(backbone)``
    plus ``head.weight`` and ``head.bias``; the head returned by
    ``head_factory`` must therefore expose ``weight`` and ``bias``.

    Note: the snapshots and importances are stored in plain dictionaries, so
    they are neither moved by ``.to()`` / ``.cuda()`` nor included in
    ``state_dict()``. ``penalty`` moves them to the parameter's device on use.
    """

    def __init__(self, backbone, iter_backbone, head_factory, use_importances=True):
        """Build the wrapper and snapshot the initial parameters as task 0.

        Args:
            backbone: Module providing ``forward_features(x)`` and
                ``forward_head(x, pre_logits=True)`` (both are called by
                ``forward``).
            iter_backbone: Callable taking the backbone and yielding
                ``(name, parameter)`` pairs for the parameters to track.
                Names are expected to be relative to the backbone, because
                ``penalty`` strips the leading ``"backbone."`` before looking
                them up.
            head_factory: Callable taking the backbone and returning the head.
            use_importances: Stored on the instance; not read inside this
                class.
        """
        super().__init__()
        self.backbone = backbone
        self.use_importances = use_importances
        self.iter_backbone = iter_backbone
        self.head = head_factory(backbone)  # Only one head with all the task logits

        self.importances = dict()
        self.saved_params = dict()
        self.save_params(0)

        # Make every tracked backbone parameter trainable.
        for _, param in self.iter_backbone(self.backbone):
            param.requires_grad = True

    def importance_template(self):
        """Return an importance dict filled with ones for every tracked parameter.

        Keys are the names yielded by ``iter_backbone`` plus ``"head.weight"``
        and ``"head.bias"``; each value has the shape, dtype and device of the
        corresponding parameter.
        """
        importances = {
            param_name: torch.ones_like(param)
            for param_name, param in self.iter_backbone(self.backbone)
        }

        # Add the head parameters
        importances["head.weight"] = torch.ones_like(self.head.weight)
        importances["head.bias"] = torch.ones_like(self.head.bias)

        return importances

    @torch.no_grad()
    def save_params(self, task_idx):
        """Store a detached copy of the tracked parameters under ``task_idx``.

        An existing snapshot for the same ``task_idx`` is overwritten.
        """
        snapshot = {
            param_name: param.detach().clone()
            for param_name, param in self.iter_backbone(self.backbone)
        }

        # Head
        snapshot["head.weight"] = self.head.weight.detach().clone()
        snapshot["head.bias"] = self.head.bias.detach().clone()

        self.saved_params[task_idx] = snapshot

    def update_importances(self, task_idx, importances):
        """Register (or replace) the importance dict used for ``task_idx``."""
        self.importances[task_idx] = importances

    @staticmethod
    def _on_device(value, device):
        """Move ``value`` to ``device`` if it is a tensor; otherwise return it as is."""
        if torch.is_tensor(value):
            return value.to(device)
        return value

    def penalty(self, mask=None):
        """Return the summed quadratic penalty over all registered tasks.

        For every task in ``importances`` and every parameter whose name is a
        key of that task's importance dict, adds
        ``(importance * (param - saved_param) ** 2).sum()``.

        Args:
            mask: Accepted for interface compatibility; currently unused.

        Returns:
            A 0-dim tensor. It is a zero tensor (without gradient history)
            when no importances are registered or no parameter name matches.

        Raises:
            KeyError: If a task has importances but no matching snapshot in
                ``saved_params``.
        """
        prefix = "backbone."

        # The (lookup name, parameter) pairs do not depend on the task, so
        # resolve them once instead of once per task.
        tracked = []
        for param_name, param in self.named_parameters():
            # The backbone's own head is not part of the penalty; the
            # wrapper's head (named "head.*") is.
            if "backbone.head" in param_name:
                continue
            # Strip only the wrapper's leading prefix. Removing every
            # occurrence would corrupt names of backbones that themselves
            # contain a submodule called "backbone".
            if param_name.startswith(prefix):
                name = param_name[len(prefix):]
            else:
                name = param_name
            tracked.append((name, param))

        penalty = 0.
        for task_idx, task_importances in self.importances.items():
            for name, param in tracked:
                if name not in task_importances:
                    continue
                # Snapshots/importances live in plain dicts and may sit on a
                # different device than the (possibly moved) parameters.
                importance = self._on_device(task_importances[name], param.device)
                saved_param = self._on_device(
                    self.saved_params[task_idx][name], param.device
                )

                penalty = penalty + (importance * (param - saved_param) ** 2).sum()

        if not torch.is_tensor(penalty):
            # Nothing contributed: still hand back a tensor so callers can
            # treat the result uniformly (e.g. call ``.item()``).
            reference = next(self.parameters(), None)
            device = reference.device if reference is not None else None
            penalty = torch.zeros((), device=device)

        return penalty

    def forward(self, x):
        """Run the backbone up to its pre-logits features, then the wrapper's head."""
        x = self.backbone.forward_features(x)
        x = self.backbone.forward_head(x, pre_logits=True)
        return self.head(x)
