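# Implements online Elastic Weight Consolidation (EWC: Kirkpatrick et al., PNAS 2017;
# online variant with a decayed Fisher: Schwarz et al., ICML 2018), following the
# mammoth framework. Reference carried over from the mammoth/DER template this file
# was adapted from:
# https://proceedings.neurips.cc/paper_files/paper/2020/file/b704ea2c39778f07c617f6b7ce480e9e-Paper.pdf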

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Buffer import Buffer
from .base import *
from .utils_model import backbone


class EWC(ContinualLearning):
    """Online Elastic Weight Consolidation (EWC) for multi-head continual learning.

    Adapted from mammoth's ``ewc_on`` model:
    https://github.com/aimagelab/mammoth/blob/master/models/ewc_on.py

    ``backbone`` is asked for ``cls_output_dim * num_tasks`` outputs.
    Task ``t`` reads the contiguous slice
    ``[t * cls_output_dim, (t + 1) * cls_output_dim)`` of those outputs.

    After each task, :meth:`end_task` estimates a diagonal Fisher vector over
    the parameters returned by ``encoder.get_params()`` and stores a parameter
    checkpoint. While training later tasks, :meth:`compute_loss` seeds the
    gradients with the derivative of
    ``e_lambda * sum(fish * (theta - checkpoint) ** 2)`` before back-propagating
    the task loss.
    """

    def __init__(
        self,
        encoder: nn.Module,
        lr=0.001,
        temperature: float = 1,
        lambda_: float = 1,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        buffer_size: int = 2000,
        input_size: int = 64,
        z_dim: int = 512,
        dataset_name: str = "celeba",
        device="cuda",
        **kwargs
    ) -> None:
        """Wrap ``encoder`` in a multi-head backbone and initialise EWC state.

        Args:
            encoder: Feature extractor handed to ``backbone``.
            lr: Learning rate forwarded to the base class.
            temperature: Stored as ``gamma``, the decay applied to the
                accumulated Fisher before each new task's estimate is added.
                Despite its name it is not a softmax temperature.
            lambda_: Penalty strength, stored as ``e_lambda``.
            cls_output_dim: Number of logits per task head.
            num_tasks: Number of task heads.
            buffer_size: Unused here.
            input_size: Unused here.
            z_dim: Feature dimension forwarded to ``backbone``.
            dataset_name: Unused here.
            device: Device that batches are moved to and on which the zero
                penalty is created.
            **kwargs: Ignored.

        The unused arguments are presumably kept so that all methods share one
        constructor signature (TODO confirm).
        """
        encoder = backbone(
            encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim
        )
        super(EWC, self).__init__(encoder, lr, num_tasks, cls_output_dim)
        self.e_lambda = lambda_

        self.cls_output_dim = cls_output_dim

        self.gamma = temperature
        self.logsoft = nn.LogSoftmax(dim=1)
        # Both stay None until the first call to end_task().
        self.checkpoint = None
        self.fish = None
        self.device = device

    def penalty(self) -> torch.Tensor:
        """Return ``e_lambda * sum(fish * (params - checkpoint) ** 2)``.

        Before any task has ended (no checkpoint yet) this returns a scalar
        zero on ``self.device``. :meth:`compute_loss` does not call this; it
        injects the equivalent gradient via :meth:`get_penalty_grads`.
        """
        if self.checkpoint is None:
            return torch.zeros((), device=self.device)
        drift = self.encoder.get_params() - self.checkpoint
        return self.e_lambda * (self.fish * drift**2).sum()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw outputs of all task heads for ``x``."""
        z = self.encoder(x)
        return z

    def _task_logits(self, outputs: torch.Tensor, task_id: int) -> torch.Tensor:
        """Select the columns of ``outputs`` that belong to head ``task_id``."""
        start = task_id * self.cls_output_dim
        return outputs[:, start : start + self.cls_output_dim]

    def end_task(self, dataloader, task_name, task_id):
        """Accumulate the diagonal Fisher for ``task_id`` and checkpoint params.

        Each example is forwarded on its own. The squared gradient of
        ``log p(y | x)`` under head ``task_id`` is weighted by ``p(y | x)``
        (as in mammoth's ``ewc_on``) and averaged over all examples. Any
        previous Fisher is multiplied by ``gamma`` before the new estimate is
        added. The current parameters are then stored as the checkpoint.

        NOTE(review): the encoder's train/eval mode is not changed here. If it
        contains dropout or batch-norm, these single-example passes run in
        whatever mode the caller left it in. Confirm this is intended.

        Args:
            dataloader: Iterable of dict batches holding an ``"image"`` tensor
                and a label tensor under ``task_name``.
            task_name: Key of this task's labels in each batch.
            task_id: Index of the output head to use.

        Raises:
            ValueError: If ``dataloader`` yields no examples. Otherwise the
                average would be 0/0 = NaN and would corrupt every later
                penalty gradient.
        """
        fish = torch.zeros_like(self.encoder.get_params())
        total = 0
        for sample in dataloader:
            inputs = sample["image"].to(self.device)
            labels = sample[task_name].to(self.device)
            for ex, lab in zip(inputs, labels):
                self.optimizer.zero_grad()
                output = self._task_logits(self.encoder(ex.unsqueeze(0)), task_id)
                # log p(lab | ex) for this head; nll_loss returns -log p.
                log_lik = -F.nll_loss(
                    self.logsoft(output), lab.unsqueeze(0), reduction="none"
                )
                # p(lab | ex), detached so it is a constant weight, not a gradient path.
                exp_cond_prob = torch.mean(torch.exp(log_lik.detach().clone()))
                torch.mean(log_lik).backward()
                fish += exp_cond_prob * self.encoder.get_grads() ** 2

            total += len(labels)

        if total == 0:
            raise ValueError(
                "EWC.end_task received an empty dataloader; "
                "cannot estimate the Fisher information."
            )
        fish /= total

        if self.fish is None:
            self.fish = fish
        else:
            self.fish *= self.gamma
            self.fish += fish

        self.checkpoint = self.encoder.get_params().data.clone()
        # Clear the last per-example gradient so it cannot leak into a later step.
        self.optimizer.zero_grad()

    def get_penalty_grads(self):
        """Return the gradient of :meth:`penalty` w.r.t. the flat parameters.

        Computes ``2 * e_lambda * fish * (params - checkpoint)``. Only valid
        once :meth:`end_task` has set ``fish`` and ``checkpoint``.
        """
        return (
            self.e_lambda
            * 2
            * self.fish
            * (self.encoder.get_params().data - self.checkpoint)
        )

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """Run one optimisation step on a batch for head ``task_id``.

        After the first task, the EWC penalty gradient is written into the
        parameter gradients before ``loss.backward()``. Backward therefore adds
        the task-loss gradient on top of it.

        ``not_aug_inputs`` and ``transform`` are accepted but unused here.

        Returns:
            The task loss, excluding the penalty, as a Python float.

        Raises:
            AssertionError: If the task loss is NaN (when asserts are enabled).
        """
        self.optimizer.zero_grad()
        outputs_sliced = self._task_logits(self.encoder(inputs), task_id)
        if self.checkpoint is not None:
            # Presumably writes into each parameter's .grad (TODO confirm in backbone).
            self.encoder.set_grads(self.get_penalty_grads())
        loss = loss_func(outputs_sliced, labels)
        assert not torch.isnan(loss)
        loss.backward()
        self.optimizer.step()

        return loss.item()
