from copy import deepcopy
from typing import *

import torch
from torch import Tensor, nn
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader


class MixinEWC:
    """Mixin adding Elastic Weight Consolidation (EWC) regularization.

    The host class must set these attributes (placeholders are created in
    ``__init__``):

    * ``self.model`` -- the ``nn.Module`` being trained. It is called as
      ``self.model(idx_task, x, s=s, args_on_forward=args_on_forward)`` and
      must return a 2-tuple whose first element is passed to the criterion.
    * ``self.criterion`` -- loss, called as ``self.criterion(output, y)``.
    * ``self.device`` -- device every batch is moved to.

    ``ewc_in_train`` snapshots the current weights into ``self.model_old`` and
    (re)estimates the diagonal Fisher information. ``ewc_compute_loss`` then
    penalizes deviation from that snapshot. The penalty is returned unscaled;
    any EWC coefficient has to be applied by the caller.
    """

    def __init__(self):
        self.model = NotImplemented  # type: nn.Module
        # Frozen snapshot of ``self.model`` taken by ``ewc_in_train``.
        self.model_old = None  # type: Optional[nn.Module]
        self.criterion = NotImplemented  # type: _Loss
        self.device = NotImplemented  # type: str

        # Diagonal Fisher information keyed by parameter name. It stays None
        # until the first call to ``ewc_in_train``.
        self.fisher = None  # type: Optional[Dict[str, Tensor]]
    # enddef

    def ewc_compute_loss(self, list__name_startswith: Optional[List[str]] = None) -> Union[Tensor, int]:
        """Return the EWC penalty ``sum(F * (theta_old - theta) ** 2) / 2``.

        Args:
            list__name_startswith: If given, only parameters whose name starts
                with one of these prefixes contribute. Each prefix must match
                at least one parameter of ``self.model`` (checked by ``assert``).

        Returns:
            A scalar tensor. Returns the int ``0`` when no Fisher information
            has been computed yet or when no parameter is selected.
        """
        if list__name_startswith is not None:
            param_names = [n for n, _ in self.model.named_parameters()]
            for name in list__name_startswith:
                assert any(pn.startswith(name) for pn in param_names), f'{name} not in {param_names}'
            # endfor
        # endif

        if self.fisher is None:
            return 0
        # endif

        # Pair current and old parameters by name rather than zipping two
        # iterators, so the pairing does not depend on registration order.
        params_old = dict(self.model_old.named_parameters())
        # str.startswith accepts a tuple; an empty tuple matches nothing,
        # which is the same as any() over an empty list.
        prefixes = None if list__name_startswith is None else tuple(list__name_startswith)

        l_reg = 0
        for name, param in self.model.named_parameters():
            if prefixes is None or name.startswith(prefixes):
                l_reg += torch.sum(self.fisher[name] * (params_old[name] - param).pow(2)) / 2
            # endif
        # endfor

        return l_reg
    # enddef

    def ewc_in_train(self, idx_task: int, dl_train: DataLoader,
                     smax: float, args_on_forward: Dict[str, Any]) -> None:
        """Snapshot the current weights and update the diagonal Fisher estimate.

        Copies ``self.model``'s state into a frozen ``self.model_old``, which
        is created with ``deepcopy`` on first use. It then estimates the
        Fisher diagonal on ``dl_train`` and merges it with the previous
        estimate as ``(F_new + F_old * idx_task) / (idx_task + 1)``. That
        formula is a running mean over tasks, assuming ``idx_task`` is 0-based
        and this method runs once per task (TODO confirm against callers).

        Args:
            idx_task: Task index forwarded to the model and used as the
                weight of the previous Fisher estimate.
            dl_train: Loader yielding ``(x, y)`` batches.
            smax: Upper bound of the ``s`` value passed to the model; see
                ``_fisher_matrix_diag``.
            args_on_forward: Extra keyword payload forwarded to the model.
        """
        if self.model_old is None:
            self.model_old = deepcopy(self.model)
        # endif
        # load_state_dict copies values into model_old's own tensors, so the
        # state dict does not need to be deep-copied first.
        self.model_old.load_state_dict(self.model.state_dict())
        self.model_old.eval()
        for p in self.model_old.parameters():
            p.grad = None
            p.requires_grad_(False)
        # endfor

        # _fisher_matrix_diag returns a new dict of new tensors, so the
        # previous estimate can be kept by reference instead of cloned.
        fisher_old = self.fisher
        self.fisher = self._fisher_matrix_diag(idx_task=idx_task, dl_train=dl_train,
                                               smax=smax, args_on_forward=args_on_forward)
        if fisher_old is not None:
            for n, _ in self.model.named_parameters():
                self.fisher[n] = (self.fisher[n] + fisher_old[n] * idx_task) / (idx_task + 1)
            # endfor
        # endif
    # enddef

    def _fisher_matrix_diag(self, idx_task: int, dl_train: DataLoader, smax: float,
                            args_on_forward: Dict[str, Any]) -> Dict[str, Tensor]:
        """Estimate the diagonal empirical Fisher information on ``dl_train``.

        For every batch, the squared gradient of the loss is weighted by the
        batch size and accumulated. The sum is divided by the number of
        samples actually processed. That count, not ``len(dataset)``, stays
        correct with ``drop_last`` or subset samplers. Parameters that never
        receive a gradient keep a zero entry.

        The model runs in train mode. ``s`` increases linearly over the batches
        and reaches ``smax`` on the last one. The caller's train/eval mode is
        restored and the model's gradients are cleared afterwards, even if an
        exception is raised.

        Returns:
            Dict mapping each parameter name of ``self.model`` to a detached
            tensor of the parameter's shape.
        """
        fisher = {n: torch.zeros_like(p.detach())
                  for n, p in self.model.named_parameters()}  # type: Dict[str, Tensor]

        num_batch_train = len(dl_train)
        num_seen = 0
        was_training = self.model.training
        self.model.train()
        try:
            for idx_batch, (x, y) in enumerate(dl_train):
                x = x.to(self.device)
                y = y.to(self.device)

                batch_size = y.shape[0]
                num_seen += batch_size
                s = 1 / smax + (smax - 1 / smax) * (idx_batch + 1) / num_batch_train

                output, _ = self.model(idx_task, x, s=s, args_on_forward=args_on_forward)
                loss = self.criterion(output, y)
                self.model.zero_grad()
                loss.backward()

                for n, p in self.model.named_parameters():
                    if p.grad is not None:
                        fisher[n] += batch_size * p.grad.detach().pow(2)
                    # endif
                # endfor
            # endfor
        finally:
            # Do not leak the last batch's gradients or a forced train mode
            # into the caller's training loop.
            self.model.zero_grad()
            self.model.train(was_training)
        # endtry

        # With no samples the estimate stays all-zero instead of becoming NaN.
        if num_seen > 0:
            for n in fisher:
                fisher[n] /= num_seen
            # endfor
        # endif

        return fisher
    # enddef

# endclass
