
import torch
from torch.utils.data import DataLoader
from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
import collections
import numpy
class GPMPlugin(SupervisedPlugin):
    """Plugin combining an input-gradient penalty with gradient projection.

    Before every backward pass the plugin adds ``penalty_lambda`` times the
    sum of squared entries of the cross-entropy gradient taken with respect
    to the mini-batch input to ``strategy.loss``.

    After every backward pass, from the second training experience onwards,
    it subtracts from selected parameter gradients the term
    ``grad @ P.T @ P``, where ``P`` is the matrix stored for that parameter
    in ``self.Pool`` (set through :meth:`renew_pool`).
    """

    def __init__(self, penalty_lambda=0, flatten_layers=(10, 11)):
        """
        :param penalty_lambda: hyperparameter weighing the input-gradient
               penalty inside the total loss. The larger the lambda, the
               larger the regularization.
        :param flatten_layers: indices (positions in
               ``list(model.parameters())``) of the parameters whose
               gradient has its last three dimensions flattened before the
               projection and is reshaped back afterwards. The default
               reproduces the indices that used to be hard-coded.
        """
        super().__init__()
        self.pc = torch.Tensor([])
        self.penalty = penalty_lambda
        # Kept as a tuple so membership is decided with ``==``, exactly like
        # the former ``layer == 10 or layer == 11`` comparison.
        self.flatten_layers = tuple(flatten_layers)
        # No projection matrices until renew_pool() provides them.
        self.Pool = None

    def renew_pool(self, Pool=None):
        """Replace the projection pool used by :meth:`after_backward`.

        :param Pool: container indexed by parameter position (iterating it
               must yield the indices, ``Pool[index]`` the matching matrix),
               or ``None`` to disable the projection.
               NOTE(review): presumably a dict whose values hold the basis
               vectors as rows -- confirm against the caller.
        """
        self.Pool = Pool

    def before_backward(self, strategy, **kwargs):
        """Add the input-gradient penalty to ``strategy.loss``.

        The model is run a second time on the current mini-batch with
        gradient tracking enabled on the input, so the penalty itself is
        differentiable with respect to the model parameters.

        NOTE(review): this extra forward pass also happens when the penalty
        weight is 0; it is left in place because skipping it would change
        side effects of the forward pass (e.g. normalisation statistics).
        """
        strategy.mb_x.requires_grad = True
        strategy.optimizer.zero_grad()
        logits = strategy.model(
            strategy.mb_x.to(strategy.device),
            strategy.mb_task_id.to(strategy.device),
        )
        ce_loss = torch.nn.functional.cross_entropy(logits, strategy.mb_y)
        # create_graph=True keeps the graph of the gradient computation so
        # the penalty can be back-propagated to the parameters.
        inputs_grad = torch.autograd.grad(
            ce_loss, strategy.mb_x, create_graph=True
        )[0]
        # The original wrapped this scalar in torch.mean(..., dim=0), which
        # is a no-op on a 0-dim tensor; the value is unchanged.
        # NOTE(review): if a per-sample sum averaged over the batch was
        # intended, the reduction dims need to be spelled out.
        grad_penalty = torch.sum(inputs_grad ** 2)

        strategy.loss += self.penalty * grad_penalty

    def after_backward(self, strategy, *args, **kwargs):
        """Project parameter gradients using the matrices in ``self.Pool``.

        Nothing is done during the first training experience, when no pool
        has been set, or for parameters that received no gradient.
        """
        if not strategy.clock.train_exp_counter > 0:
            return

        pool = self.Pool
        if pool is None:
            return

        # Materialise the parameter list once instead of once per layer.
        params = list(strategy.model.parameters())
        for layer in pool:
            grad = params[layer].grad
            if grad is None:
                # Parameter did not take part in this backward pass.
                continue
            basis = pool[layer]
            if layer in self.flatten_layers:
                # Collapse the last three dims so the gradient can be
                # multiplied with the 2-D basis, then restore the shape.
                flat = grad.flatten(-3, -1)
                correction = torch.matmul(
                    torch.matmul(flat, basis.T), basis
                ).reshape(grad.shape)
            else:
                correction = torch.matmul(torch.matmul(grad, basis.T), basis)
            # In-place update so the optimizer sees the projected gradient.
            grad -= correction

