from copy import deepcopy

import torch
from torch import optim

from methods.base import AdaptiveCL


class OnlineEWC(AdaptiveCL):
    """
    A class that implements the Online EWC method for continual learning.

    A single model is trained across all tasks. Per-parameter importance is
    estimated online as an exponential moving average (EMA) of the squared
    deviation of each gradient from a running EMA of that gradient. When the
    task index changes, the importance estimate accumulated so far and a copy
    of the current parameters are frozen. Every later update then adds the
    penalty gradient ``Lambda * F * (theta - theta_old)`` before the optimizer
    step.
    """

    def __init__(
        self,
        cl_type,
        model_class,
        n_classes,
        n_tasks: int,
        lr: float,
        device: torch.device,
        alpha=0.9,
        Lambda=0.1,
        **kwargs
    ):
        """
        Args:
            cl_type: continual-learning setting, forwarded to ``AdaptiveCL``.
            model_class: callable that builds the network; it is called with
                ``self.n_class_each_task`` (presumably set by ``AdaptiveCL``).
            n_classes: forwarded to ``AdaptiveCL``.
            n_tasks: forwarded to ``AdaptiveCL``.
            lr: SGD learning rate.
            device: device the model and all statistics are placed on.
            alpha: EMA decay for the gradient mean and importance estimate;
                each update moves the estimate by ``1 - alpha`` toward the new
                value.
            Lambda: strength of the quadratic penalty.
            **kwargs: accepted and ignored.
        """
        super().__init__(cl_type, n_classes, n_tasks, device)
        self.online_ewc = model_class(self.n_class_each_task).to(device)
        self.online_ewc_opt = optim.SGD(self.online_ewc.parameters(), lr=lr)
        self.method_name = "OnlineEWC"
        self.alpha = alpha
        self.Lambda = Lambda
        # A zeroed, non-trainable copy of the model is used purely as
        # name-aligned storage for per-parameter statistics.
        self.fisher = deepcopy(self.online_ewc)
        for p in self.fisher.parameters():
            p.requires_grad = False
            p.data = torch.zeros_like(p.data, device=device)
        # Continuously updated importance estimate; `self.fisher` holds the
        # snapshot frozen at the most recent task change.
        self.fisher_current = deepcopy(self.fisher)
        # Running EMA of each parameter's gradient.
        self.mean = deepcopy(self.fisher)
        self.current_task_index = None
        self.task_change = 0
        # Parameters frozen at the most recent task change (penalty anchor).
        self.old_parameters = None

    def predict(self, inputs: torch.Tensor, task_index) -> torch.Tensor:
        """
        Make predictions on the input data using the shared model.

        ``task_index`` is unused: the same model serves every task.
        """
        return self.online_ewc(inputs)

    def update(
        self, inputs: torch.Tensor, labels: torch.Tensor, task_index: int, test=False
    ):
        """
        Run one SGD step on a batch, including the Online EWC penalty.

        Args:
            inputs: batch fed to ``self.online_ewc``.
            labels: targets passed to ``self.criterion``.
            task_index: identifier of the task this batch belongs to. A value
                different from the previous call marks a task change.
            test: unused; kept for interface compatibility.
        """
        if self.current_task_index is None:
            self.current_task_index = task_index
        else:
            self.task_change = 1 if self.current_task_index != task_index else 0
        self.current_task_index = task_index

        self.online_ewc_opt.zero_grad()
        loss = self.criterion(self.online_ewc(inputs), labels)
        loss.backward()

        with torch.no_grad():
            if self.task_change:
                # Freeze the statistics BEFORE this batch (the first of the new
                # task) is folded in, so the snapshot reflects only earlier
                # data. backward() does not move parameters, so the copy is
                # still the model as it was at the end of the previous task.
                self.fisher = deepcopy(self.fisher_current)
                self.old_parameters = deepcopy(self.online_ewc)
            self._update_fisher()

        if self.old_parameters is not None:
            self._add_penalty_grad()

        self.online_ewc_opt.step()

    def _update_fisher(self):
        """Fold the current gradients into the running gradient mean and importance EMA."""
        # Parameters that did not take part in the forward pass have no
        # gradient; their statistics are left unchanged.
        grads = {
            name: p.grad
            for name, p in self.online_ewc.named_parameters()
            if p.grad is not None
        }
        rate = 1 - self.alpha
        for name, m in self.mean.named_parameters():
            if name in grads:
                # Track the mean of the gradient (not the parameter value) so
                # that the deviation below compares like with like.
                m.data += rate * (grads[name].data - m.data)
        mean_dict = dict(self.mean.named_parameters())
        for name, f in self.fisher_current.named_parameters():
            if name in grads:
                deviation = grads[name].data - mean_dict[name].data
                f.data += rate * (deviation ** 2 - f.data)

    def _add_penalty_grad(self):
        """Add ``Lambda * F * (theta - theta_old)`` to every existing parameter gradient."""
        fisher_dict = dict(self.fisher.named_parameters())
        old_para_dict = dict(self.old_parameters.named_parameters())
        for name, p in self.online_ewc.named_parameters():
            if p.grad is None:
                # SGD skips parameters without a gradient, so there is nothing
                # to regularize for this step.
                continue
            p.grad.data += (
                self.Lambda
                * fisher_dict[name].data
                * (p.data - old_para_dict[name].data)
            )

    def before_fewshot_test(self):
        """Delegate to the base-class hook; no method-specific preparation is needed."""
        super().before_fewshot_test()

    def get_models(self) -> list:
        """
        Return the current model, wrapped in a single-element list.
        """
        return [self.online_ewc]

    def mode(self, is_train: bool = True):
        """
        Set the model to training (``is_train`` truthy) or evaluation mode.
        """
        # nn.Module.eval() is equivalent to train(False). bool() keeps truthy
        # non-bool flags working, since train() only accepts real booleans.
        self.online_ewc.train(bool(is_train))
