from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from .Buffer import Buffer
from .base import *
from .utils_model import backbone


class SI(ContinualLearning):
    """Synaptic Intelligence (SI) regularizer for continual learning.

    Reference: Zenke, Poole & Ganguli, "Continual Learning Through Synaptic
    Intelligence", https://arxiv.org/pdf/1703.04200

    The encoder is wrapped by ``backbone`` with ``cls_output_dim * num_tasks``
    output logits; ``compute_loss`` evaluates the task loss only on the
    ``cls_output_dim``-wide slice that belongs to ``task_id``.

    While a task is trained, a per-parameter path integral ω
    (``small_omega``) is accumulated. At the end of the task it is converted
    into a per-parameter importance Ω (``big_omega``), which weights a
    quadratic penalty pulling the parameters back towards the checkpoint θ*
    taken at the end of the previous task.

    Args:
        encoder: feature extractor handed to ``backbone``.
        lr: learning rate, forwarded to the base class.
        xi: damping term ξ added to the squared parameter drift in the
            importance denominator (keeps it away from zero for ξ > 0).
        c: strength of the SI penalty.
        cls_output_dim: number of output logits per task head.
        num_tasks: number of task heads.
        z_dim: feature dimension, forwarded to ``backbone``.
        device: device the wrapped encoder and the SI buffers are placed on.
        **kwargs: accepted for interface compatibility and ignored.
    """

    # https://arxiv.org/pdf/1703.04200
    def __init__(
        self,
        encoder: nn.Module,
        lr=0.001,
        xi: float = 0.1,  # Csi, ξ
        c: float = 0.1,
        cls_output_dim: int = 2,
        num_tasks: int = 10,
        z_dim: int = 512,
        device="cuda",
        **kwargs
    ) -> None:
        encoder = backbone(
            encoder, cls_output_dim=cls_output_dim * num_tasks, z_dim=z_dim
        ).to(device)
        super().__init__(encoder, lr, num_tasks, cls_output_dim)
        self.c = c
        self.xi = xi
        self.cls_output_dim = cls_output_dim
        # Per-parameter importance Ω; None until the first task has ended.
        self.big_omega = None
        # Path integral ω of the current task; the scalar 0 until the first step.
        self.small_omega = 0
        self.lr = lr
        self.device = device
        # Parameter snapshot θ* (as returned by get_params) from the end of
        # the previous task; initially the freshly built encoder's parameters.
        self.checkpoint = self.encoder.get_params().data.clone().to(self.device)

    def penalty(self) -> torch.Tensor:
        """Return the unscaled SI penalty ``sum(Ω * (θ - θ*) ** 2)``.

        Returns a zero scalar on ``self.device`` before the first task has
        ended. The strength ``c`` is not applied here.
        """
        if self.big_omega is None:
            return torch.tensor(0.0, device=self.device)
        drift = self.encoder.get_params() - self.checkpoint
        return (self.big_omega * drift ** 2).sum()

    def end_task(self, dataloader, task_name, task_id, **kwargs):
        """Fold the task's path integral into Ω and take a new checkpoint.

        Applies Ω += ω / ((θ - θ*) ** 2 + ξ), then stores the current
        parameters as the new θ* and resets ω. The arguments are unused.
        """
        current = self.encoder.get_params().data
        if self.big_omega is None:
            self.big_omega = torch.zeros_like(current, device=self.device)

        self.big_omega += self.small_omega / (
            (current - self.checkpoint) ** 2 + self.xi
        )

        # Store the parameter checkpoint and reset the path integral.
        self.checkpoint = current.clone().to(self.device)
        self.small_omega = 0

    def get_penalty_grads(self) -> torch.Tensor:
        """Return the gradient of ``c * penalty()`` w.r.t. the parameters.

        d/dθ [c * Σ Ω (θ - θ*)²] = 2 c Ω (θ - θ*). Requires ``big_omega`` to
        have been initialised by ``end_task``.
        """
        return (
            self.c
            * 2
            * self.big_omega
            * (self.encoder.get_params().data - self.checkpoint)
        )

    def compute_loss(
        self,
        inputs: torch.Tensor,
        labels: torch.Tensor,
        not_aug_inputs: torch.Tensor,
        loss_func: nn.Module,
        transform,
        task_id,
    ) -> float:
        """Run one training step on the head of ``task_id`` and update ω.

        The task loss is back-propagated. Once an importance Ω exists, the
        analytic SI penalty gradient is added to the parameter gradients.
        Gradients are then clipped element-wise to [-1, 1] and the optimizer
        steps.

        The path integral is accumulated as ω += -g_task * Δθ. Here g_task is
        the gradient of the unregularised task loss and Δθ is the parameter
        change the optimizer actually applied. Using the real Δθ keeps the
        estimate correct under gradient clipping and for optimizers other
        than plain SGD.

        ``not_aug_inputs`` and ``transform`` are unused.

        Returns:
            The task loss (without the SI penalty) as a Python float.
        """
        self.optimizer.zero_grad()
        outputs = self.forward(inputs)
        start = task_id * self.cls_output_dim
        outputs_sliced = outputs[:, start : start + self.cls_output_dim]
        loss = loss_func(outputs_sliced, labels)
        loss.backward()

        # Snapshot the unregularised task gradient; cloned so that
        # overwriting the parameter grads below can never alias it.
        task_grads = self.encoder.get_grads().data.clone()
        if self.big_omega is not None:
            self.encoder.set_grads(task_grads + self.get_penalty_grads())
        nn.utils.clip_grad.clip_grad_value_(self.get_parameters(), 1)

        params_before = self.encoder.get_params().data.clone()
        self.optimizer.step()
        delta = self.encoder.get_params().data - params_before
        # First-order estimate of how much each parameter's update reduced
        # the task loss during this step.
        self.small_omega += -task_grads * delta

        return loss.item()
