"""
Regularization-based methods
- Online L2 Reg
- EWC
"""
import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, RandomSampler
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

from algorithms.ocl import OCLAlgorithm
from algorithms.alg_utils import ConcatDataset, ConcatBatchSampler
from utils import key_is_none


class L2Reg(OCLAlgorithm):
    """
    Online L2 regularization against forgetting.

    The training loss is the criterion value plus a quadratic penalty on
    the distance between the current flattened model weights and ``w0``,
    a snapshot of the weights stored when ``__call__`` runs with ``t == 0``.

    Config:
        config.epochs              - Number of epochs handed to ``erm``
        config.lbd                 - L2 regularization level
        All configs for ERM
    """
    def __init__(self, model, config):
        super().__init__(model, config)

    def __call__(self, t: int, feedback: dict) -> None:
        """Run the base-class step, then snapshot the weights on step 0."""
        super().__call__(t, feedback)
        if t != 0:
            return
        # Detached copy: the anchor must stay constant and outside autograd.
        flat_weights = parameters_to_vector(self.model.parameters())
        self.w0 = flat_weights.detach()

    def train(self, t: int, feedback: dict) -> None:
        """
        Fit on the labeled batch in ``feedback`` with the regularized loss.

        Does nothing when ``key_is_none`` reports no usable
        ``'batch_labeled'`` entry.
        """
        if not key_is_none(feedback, 'batch_labeled'):
            labeled = feedback['batch_labeled']
            n_epochs = self.config.epochs
            # NOTE(review): ``erm`` is not defined in this class; presumably
            # inherited from OCLAlgorithm - confirm its signature there.
            self.erm(n_epochs, labeled, loss_func=self.loss_func)

    def loss_func(self, x, y, outputs):
        """
        Return ``criterion(outputs, y) + lbd / 2 * ||w - w0||^2``.

        ``x`` is accepted for signature compatibility and is not used here.
        """
        current = parameters_to_vector(self.model.parameters())
        data_term = self.criterion(outputs, y)
        drift_sq = ((current - self.w0) ** 2).sum()
        return data_term + self.config.lbd / 2 * drift_sq


class EWC(OCLAlgorithm):
    """
    Elastic Weight Consolidation (EWC)
    Adds a Fisher Information matrix diagonal to L2 regularization
    so that more influential model weights change less
    We use the regression set for EWC

    The anchor weights ``w0`` are a detached snapshot of the model
    parameters taken when ``__call__`` runs with ``t == 0``.

    Config:
    config.epochs              - Number of epochs for other batches
    config.lbd                 - L2 regularization level
    config.kr_size             - Number of loader batches consumed per loss
                                 evaluation to estimate the Fisher diagonal;
                                 also the regression-to-labeled sampling ratio
    config.batch_size          - Batch size given to the batch sampler
    config.loader_kwargs       - Extra keyword arguments for the DataLoader
    config.device              - Device the Fisher batches are moved to
    All configs for ERM
    """
    def __init__(self, model, config):
        super().__init__(model, config)

    def __call__(self, t: int, feedback: dict) -> None:
        """Run the base-class step, then snapshot the weights on step 0."""
        super().__call__(t, feedback)
        if t == 0:
            # Detached so the anchor is a constant in the penalty term.
            self.w0 = parameters_to_vector(self.model.parameters()).detach()

    def train(self, t: int, feedback: dict) -> None:
        """
        Train on the labeled batch with the EWC-regularized loss.

        Builds a loader over the labeled batch concatenated with the
        ``'train_regression'`` memory set and hands it to ``erm``. Does
        nothing when ``key_is_none`` reports no usable ``'batch_labeled'``.
        """
        if key_is_none(feedback, 'batch_labeled'):
            return
        batch_labeled = feedback['batch_labeled']
        kr_size = self.config.kr_size

        dataset_kr = self.mem['train_regression']
        datasets = [batch_labeled, dataset_kr]
        # The regression set is sampled with replacement so that kr_size
        # regression samples are available per labeled sample, whatever
        # the size of the regression set.
        samplers = [
            RandomSampler(batch_labeled),
            RandomSampler(dataset_kr, replacement=True,
                          num_samples=len(batch_labeled) * kr_size),
        ]
        # NOTE(review): presumably makes the sampler emit 1 labeled batch
        # followed by kr_size regression batches, which loss_func then
        # consumes - confirm against ConcatBatchSampler.
        index_change = [1, kr_size]
        concat_dataset = ConcatDataset(datasets)
        concat_batch_sampler = ConcatBatchSampler(samplers, self.config.batch_size,
                                                  concat_dataset.offsets, index_change)
        concat_loader = DataLoader(concat_dataset, batch_sampler=concat_batch_sampler,
                                   **self.config.loader_kwargs)

        epochs = self.config.epochs
        self.erm(epochs, loader=concat_loader, n=len(batch_labeled), loss_func=self.loss_func)

    def loss_func(self, x, y, outputs):
        """
        Return the criterion loss plus the Fisher-weighted L2 penalty.

        ``loss = criterion(outputs, y) + lbd / 2 * sum(fim * (w - w0)^2)``
        where ``fim`` is the Fisher diagonal estimate rescaled so that its
        entries average to 1. ``x`` is accepted for signature compatibility
        and is not used here.
        """
        w = parameters_to_vector(self.model.parameters())
        fim = self._estimate_fisher_diag(w)

        # Normalize so the entries sum to len(fim), i.e. average to 1.
        # The norm is clamped to the smallest positive normal float so that
        # an all-zero estimate yields a zero penalty instead of 0/0 = NaN.
        l1 = torch.clamp(torch.linalg.norm(fim, 1), min=torch.finfo(fim.dtype).tiny)
        fim = fim / l1 * fim.numel()

        penalty = self.config.lbd / 2 * (fim * ((w - self.w0) ** 2)).sum()
        return self.criterion(outputs, y) + penalty

    def _estimate_fisher_diag(self, w):
        """
        Estimate the Fisher diagonal as the mean squared batch gradient.

        Pulls ``config.kr_size`` batches from ``self.loader_enum``, and for
        each one squares the flattened gradient of the criterion loss.
        ``w`` is only used as a template for the shape, dtype and device of
        the result.
        """
        device = self.config.device
        fim = torch.zeros_like(w)
        for _ in range(self.config.kr_size):
            # NOTE(review): loader_enum is not set in this class; presumably
            # ``erm`` stores an enumerate() over the loader - confirm.
            _, (x0, y0) = next(self.loader_enum)
            x0, y0 = x0.to(device), y0.to(device)
            loss = self.criterion(self.model(x0), y0)
            self.optimizer.zero_grad()
            loss.backward()
            g = parameters_to_vector([p.grad for p in self.model.parameters()]).detach()
            fim += g ** 2
        # Drop the gradients of the last Fisher batch so they cannot be
        # accumulated into the caller's backward pass on the training loss.
        self.optimizer.zero_grad()
        fim /= self.config.kr_size

        # Squared values are never negative, so this can only fail on NaN.
        assert torch.all(fim >= 0), "Fisher estimate contains NaN gradients"
        return fim
