import torch
from torch import optim

from methods.base import AdaptiveCL
from utils import Reservoir


class AGEM(AdaptiveCL):
    """
    Average Gradient Episodic Memory (A-GEM) for continual learning.

    On each update, a reference gradient ``g_ref`` is computed on a batch
    sampled from a reservoir replay memory. A gradient ``g`` is computed on
    the incoming batch. If the two conflict (``g . g_ref < 0``), ``g`` is
    projected before the optimizer step:

        g <- g - (g . g_ref / (g_ref . g_ref + eps)) * g_ref

    Until the memory holds ``agem_batch_size`` samples, plain SGD on the
    incoming batch is used instead.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        n_classes,
        n_tasks,
        lr: float,
        buffer_size: int,
        agem_batch_size: int,
        device: torch.device,
        **kwargs
    ):
        """
        Args:
            cl_type: continual-learning setting, forwarded to the base class.
            model_class: callable that builds the network. It is called with
                ``self.n_class_each_task``, presumably set by the base class
                from ``n_classes``/``n_tasks`` (confirm in ``AdaptiveCL``).
            n_classes: forwarded to the base class.
            n_tasks: forwarded to the base class.
            lr: SGD learning rate.
            buffer_size: capacity of the reservoir replay memory.
            agem_batch_size: number of memory samples used for the reference
                gradient. Projection is skipped until the memory holds at
                least this many samples.
            device: device for the model and the replay memory.
            **kwargs: accepted for interface compatibility and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        self.agem = model_class(self.n_class_each_task).to(device)
        self.agem_opt = optim.SGD(self.agem.parameters(), lr=lr)
        self.agem_batch_size = agem_batch_size
        self.buffer = Reservoir(buffer_size, device)
        self.method_name = "AGEM"

    def predict(self, inputs: torch.Tensor, task_index) -> torch.Tensor:
        """Return the model's raw outputs for ``inputs`` (``task_index`` is unused)."""
        return self.agem(inputs)

    def update(
        self, inputs: torch.Tensor, labels: torch.Tensor, task_index: int, test=False
    ):
        """
        Take one A-GEM optimization step on ``(inputs, labels)``, then add the
        batch to the replay memory.

        Args:
            inputs: current batch of inputs.
            labels: current batch of labels, one per input.
            task_index: index of the current task (unused here).
            test: unused; kept for interface compatibility.
        """
        if len(self.buffer) >= self.agem_batch_size:
            # Reference gradient on replayed memory. get_gradient(True)
            # returns copies, so the zero_grad() inside the next _get_loss
            # call cannot overwrite them.
            mem_inputs, mem_labels = self.buffer.sample(self.agem_batch_size)
            self._get_loss(mem_inputs, mem_labels).backward()
            ref_flat, ref_grads = self.get_gradient(True)

            # Gradient on the current batch; it stays in param.grad.
            self._get_loss(inputs, labels).backward()
            cur_flat = self.get_gradient(False)

            with torch.no_grad():
                dot_g = torch.dot(cur_flat, ref_flat)
                if dot_g < 0:
                    # eps guards against division by zero when g_ref == 0.
                    c = dot_g / (
                        torch.dot(ref_flat, ref_flat) + torch.finfo(torch.float32).eps
                    )
                    self._project_gradients(c, ref_grads)
        else:
            # Not enough memory yet: plain gradient step on the current batch.
            self._get_loss(inputs, labels).backward()

        self.agem_opt.step()
        self.agem_opt.zero_grad()

        # Update the replay memory. Each label is stored as a 1-element
        # tensor via view(-1, 1); presumably this is what Reservoir expects.
        self.buffer.add(zip(inputs, labels.view(-1, 1)))

    def _project_gradients(self, c: torch.Tensor, ref_grads: dict) -> None:
        """Apply ``grad <- grad - c * g_ref`` to every model parameter in place."""
        with torch.no_grad():
            for name, param in self.agem.named_parameters():
                correction = c * ref_grads[name]
                if param.grad is not None:
                    param.grad.sub_(correction)
                elif param.requires_grad:
                    # No gradient from the current batch (g == 0 here), so
                    # the projected gradient for this parameter is -c * g_ref.
                    param.grad = -correction

    def before_fewshot_test(self):
        """Defer to the base-class hook; A-GEM needs no extra preparation."""
        super().before_fewshot_test()

    def _get_loss(self, inputs: torch.Tensor, labels: torch.Tensor):
        """
        Zero the gradients, then return ``self.criterion`` applied to the
        model's outputs on ``inputs`` and to ``labels``.

        ``self.criterion`` is not defined in this class; presumably the base
        class provides it.
        """
        self.agem_opt.zero_grad()
        loss = self.criterion(self.agem(inputs), labels)
        return loss

    def get_models(self) -> list:
        """Return the list of trainable models owned by this method."""
        return [self.agem]

    def get_gradient(self, is_dict: bool = False):
        """
        Collect the current gradients of all model parameters.

        A parameter whose ``.grad`` is ``None`` (for example a frozen or
        unused one) contributes zeros. This keeps flattened vectors from
        different calls aligned element for element.

        Args:
            is_dict: if True, also return a per-parameter dict of gradients.

        Returns:
            A new 1-D tensor concatenating all flattened gradients. When
            ``is_dict`` is True, returns the tuple ``(flat, grads_dict)``,
            where ``grads_dict`` maps parameter names to independent copies
            of their gradients.
        """
        flat_parts = []
        grads_dict = {}
        for name, param in self.agem.named_parameters():
            if param.grad is None:
                grad = torch.zeros_like(param)
            else:
                grad = param.grad.detach()
            # reshape (not view): gradients may be non-contiguous,
            # e.g. channels_last conv weights.
            flat_parts.append(grad.reshape(-1))
            if is_dict:
                # Clone: detach() shares storage with param.grad, which an
                # in-place zero_grad() (set_to_none=False) would overwrite.
                grads_dict[name] = grad.clone()
        flat = torch.cat(flat_parts)  # torch.cat always returns a new tensor
        if is_dict:
            return flat, grads_dict
        return flat

    def mode(self, is_train=True):
        """Put the model in training mode if ``is_train`` is true, else eval mode."""
        if is_train:
            self.agem.train()
        else:
            self.agem.eval()
