
import torch
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
import collections
import numpy
class GradePlugin(SupervisedPlugin):
    """Penalize the squared norm of the loss gradient w.r.t. the inputs.

    Before every backward pass the plugin re-runs ``strategy.model`` on the
    current minibatch, computes the cross-entropy against ``strategy.mb_y``,
    differentiates it with respect to ``strategy.mb_x`` (keeping the graph so
    the penalty is itself differentiable), and adds
    ``penalty_lambda * mean_i ||d CE / d x_i||^2`` to ``strategy.loss``.
    Optionally it also steps a user-supplied scheduler after every optimizer
    update.
    """

    def __init__(self, penalty_lambda=0, schedual=None):
        """
        :param penalty_lambda: weight of the input-gradient penalty added to
               the strategy loss. ``0`` disables the penalty and skips the
               extra forward pass entirely.
        :param schedual: optional object exposing ``step()`` (for example a
               ``torch.optim.lr_scheduler`` instance). When given, ``step()``
               is called after every optimizer update; when ``None`` nothing
               is stepped.
        """
        super().__init__()
        self.penalty = penalty_lambda
        self.schedual = schedual

    def before_backward(self, strategy, **kwargs):
        """Add the input-gradient penalty to ``strategy.loss``.

        Side effects: sets ``strategy.mb_x.requires_grad = True``, zeroes the
        optimizer's gradients and runs one extra forward pass of
        ``strategy.model``.
        """
        if self.penalty == 0:
            # The penalty would contribute exactly nothing; avoid the extra
            # forward pass and the double-backward graph.
            return

        # autograd.grad below differentiates w.r.t. mb_x, so it must track grad.
        strategy.mb_x.requires_grad = True
        strategy.optimizer.zero_grad()
        # NOTE(review): this is a second forward pass, separate from the
        # strategy's own. If the model is in train mode, BatchNorm running
        # stats are updated again and dropout draws a different mask.
        # The call assumes a multi-task model taking (x, task_labels).
        out_new = strategy.model(
            strategy.mb_x.to(strategy.device),
            strategy.mb_task_id.to(strategy.device),
        )
        ce_loss = torch.nn.functional.cross_entropy(out_new, strategy.mb_y)
        # create_graph=True keeps the graph of the gradient so the penalty can
        # be backpropagated into the model parameters (double backprop).
        inputs_grad = torch.autograd.grad(
            ce_loss, strategy.mb_x, create_graph=True
        )[0]
        # Reduce per sample first (sum of squares over all non-batch dims),
        # then average over the batch. A dim-less torch.sum would collapse the
        # batch too and make the penalty grow with batch size.
        per_sample_sq_norm = (
            inputs_grad.pow(2).reshape(inputs_grad.shape[0], -1).sum(dim=1)
        )
        strategy.loss += self.penalty * per_sample_sq_norm.mean()

    def after_update(self, strategy, *args, **kwargs):
        """Step the configured scheduler (if any) after each optimizer update."""
        if self.schedual is not None:
            self.schedual.step()

