"""
This code is based on the implementation of ER-ACE in the
Avalanche library https://github.com/ContinualAI/avalanche
# Author : Shahar Ariel
# GitHub : https://github.com/shahariel/TEAL
"""

from typing import Callable, List, Optional, Union

import torch
from scipy.sparse.linalg import svds
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer

from avalanche.core import SupervisedPlugin
from avalanche.models.utils import avalanche_forward
from avalanche.training import ACECriterion
from avalanche.training.plugins.evaluation import (
    EvaluationPlugin,
    default_evaluator,
)
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.utils import cycle

from sampling_strategies.SelecetionStrategy import ProbCoverExemplarsSelectionStrategy, \
    SoloTEALExemplarsSelectionStrategy
from sampling_strategies.closest_to_canter import ClosestToCenterSelectionStrategy
from MERS.mers_utils.storage_policy import ExemplarsBuffer
from sampling_strategies.herding import HerdingSelectionStrategy
from sampling_strategies.rainbow_memory import RainbowMemorySelectionStrategy
from sampling_strategies.teal import TEALExemplarsSelectionStrategy

from MERS.sampling_strategies.SelecetionStrategy import MaxHerding
import torch.nn.functional as F
from avalanche.training.plugins import MIRPlugin, GSS_greedyPlugin


class ER_ACE(SupervisedTemplate):
    """
    ER ACE, as proposed in
    "New Insights on Reducing Abrupt Representation
    Change in Online Continual Learning"
    by Lucas Caccia et al.
    https://openreview.net/forum?id=N8MaByOzUfb

    This version is adapted to the non-online scenario:
    the difference with OnlineER_ACE is that it introduces
    all of the examples from the new classes in the buffer at the
    beginning of the task instead of introducing them progressively.

    After each experience, the buffer groups of the classes seen in that
    experience can be refilled with a custom exemplar selection strategy
    chosen through ``args.sel_strategy`` (see ``_after_training_exp``).
    """

    # Values of ``args.sel_strategy`` that keep the storage policy's own
    # selection, so no custom re-selection happens after an experience.
    _DEFAULT_SELECTION_STRATEGIES = (None, 'random', 'gss')
    # Values of ``args.sel_strategy`` for which gradients are clipped to
    # unit norm before each optimizer step.
    _CLIPPED_GRAD_STRATEGIES = ('rm', 'gss')

    def __init__(
        self,
        model: Module,
        optimizer: Optimizer,
        criterion=CrossEntropyLoss(),
        mem_size: int = 200,
        batch_size_mem: int = 10,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = 1,
        device: Union[str, torch.device] = "cpu",
        storage_policy: Optional[ExemplarsBuffer] = None,
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin, Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        peval_mode="epoch",
        args=None,
        use_gss: bool = False
    ):
        """
        :param model: PyTorch model.
        :param optimizer: PyTorch optimizer.
        :param criterion: loss function.
        :param mem_size: int       : Fixed memory size
        :param batch_size_mem: int : Size of the batch sampled from the buffer
        :param train_mb_size: mini-batch size for training.
        :param train_epochs: number of training epochs.
        :param eval_mb_size: mini-batch size for eval.
        :param device: PyTorch device where the model will be allocated.
        :param storage_policy: (optional) exemplar buffer policy. When None,
            ER_ACE_Plugin creates a class-balanced buffer of ``mem_size``.
        :param plugins: (optional) list of StrategyPlugins. Note: the
            ER-ACE plugin (and the GSS plugin if ``use_gss``) is appended to
            this list in place.
        :param evaluator: (optional) instance of EvaluationPlugin for logging
            and metric computations. None to remove logging.
        :param eval_every: the frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether the
            periodic evaluation during training should execute every
            `eval_every` epochs or iterations (Default='epoch').
        :param args: (optional) experiment options object. The attributes
            read here are ``sel_strategy``, ``dataset``, ``gss_mem_strength``
            and ``gss_input_size``, all optional. It is also forwarded to
            the selection strategies.
        :param use_gss: if True, also attach a GSS_greedyPlugin.
        """

        # Instantiate plugin
        er_ace = ER_ACE_Plugin(
            storage_policy=storage_policy, mem_size=mem_size
        )
        # Keep a direct handle on the plugin instead of relying on its
        # position in ``self.plugins``. The template adds its own plugins
        # (evaluator, ...), and the caller's list may be of any length.
        self._er_ace_plugin = er_ace

        # Add plugin to the strategy
        if plugins is None:
            plugins = [er_ace]
        else:
            plugins.append(er_ace)
        if use_gss:
            gss_plugin = GSS_greedyPlugin(
                mem_size=mem_size,
                mem_strength=getattr(args, 'gss_mem_strength', 20),
                input_size=getattr(args, 'gss_input_size', [3, 32, 32]),  # Adjust based on dataset
            )
            plugins.append(gss_plugin)
        self.mem_size = mem_size
        self.batch_size_mem = batch_size_mem

        self.replay_loader = None
        self.ace_criterion = ACECriterion()
        self.args = args
        self.prev_features = None
        self.use_gss = use_gss

        super().__init__(
            model,
            optimizer,
            criterion,
            train_mb_size,
            train_epochs,
            eval_mb_size,
            device,
            plugins,
            evaluator,
            eval_every,
            peval_mode,
        )

    def training_epoch(self, **kwargs):
        """Run one training epoch over ``self.dataloader``.

        When a replay loader exists, each current mini-batch is paired with
        one buffer mini-batch, and the loss is the ACE criterion over both.
        Otherwise the plain criterion is used.

        :param kwargs: forwarded to the template's callback hooks.
        """
        for self.mbatch in self.dataloader:
            if self._stop_training:
                break

            self._unpack_minibatch()
            self._before_training_iteration(**kwargs)

            if self.replay_loader is not None:
                buf_x, buf_y, buf_tid = next(self.replay_loader)
                self.mb_buffer_x = buf_x.to(self.device)
                self.mb_buffer_y = buf_y.to(self.device)
                self.mb_buffer_tid = buf_tid.to(self.device)

            self.optimizer.zero_grad()
            self.loss = self._make_empty_loss()

            # Forward
            self._before_forward(**kwargs)
            self.mb_output = self.forward()
            if self.replay_loader is not None:
                self.mb_buffer_out = avalanche_forward(
                    self.model, self.mb_buffer_x, self.mb_buffer_tid
                )
            self._after_forward(**kwargs)

            # Loss & Backward
            if self.replay_loader is None:
                self.loss += self.criterion()
            else:
                self.loss += self.ace_criterion(
                    self.mb_output,
                    self.mb_y,
                    self.mb_buffer_out,
                    self.mb_buffer_y,
                )

            self._before_backward(**kwargs)
            self.backward()
            # Clip gradients to unit norm for the strategies listed in
            # _CLIPPED_GRAD_STRATEGIES (a missing args/sel_strategy disables it).
            if getattr(self.args, 'sel_strategy', None) in self._CLIPPED_GRAD_STRATEGIES:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self._after_backward(**kwargs)

            # Optimization step
            self._before_update(**kwargs)
            self.optimizer_step()
            self._after_update(**kwargs)

            self._after_training_iteration(**kwargs)

    def _before_training_exp(self, **kwargs):
        """Update the buffer and build the replay loader for this experience.

        The buffer is updated before training so that it already holds data
        from the current experience (the non-online ER-ACE variant). A replay
        loader is created only from the second experience on, and only when
        the buffer can fill at least one replay mini-batch (``drop_last``).
        """
        storage_policy = self._er_ace_plugin.storage_policy
        storage_policy.update(self, **kwargs)
        buffer = storage_policy.buffer
        if (
            len(buffer) >= self.batch_size_mem
            and self.experience.current_experience > 0
        ):
            self.replay_loader = cycle(
                torch.utils.data.DataLoader(
                    buffer,
                    batch_size=self.batch_size_mem,
                    shuffle=True,
                    drop_last=True,
                )
            )
        super()._before_training_exp(**kwargs)

    def _announce_selection(self, label):
        """Print the header of a buffer re-selection (no trailing newline)."""
        print(f"Updating buffer of the experience {self.experience.current_experience} into "
              f"{label} selection strategy: classes ", end='')

    def _build_rainbow_memory_strategy(self):
        """Announce and instantiate the Rainbow-Memory selection strategy.

        The dataset name is temporarily switched from 'tinyimg' to
        'TinyImagenet' for the constructor call. Presumably that is the name
        RainbowMemorySelectionStrategy expects; verify. The caller's original
        value is always restored, even if the constructor raises.
        """
        original_dataset = self.args.dataset
        if original_dataset == 'tinyimg':
            self.args.dataset = "TinyImagenet"
        try:
            self._announce_selection("RM (Uncertainty)")
            return RainbowMemorySelectionStrategy(self.args, self.device)
        finally:
            self.args.dataset = original_dataset

    def _build_selection_strategy(self, sel_strategy):
        """Announce and instantiate the exemplar selection strategy by name.

        :param sel_strategy: one of 'herding', 'teal', 'rm', 'centered',
            'probcover', 'max_herding', 'budget'.
        :return: the selection strategy instance.
        :raises ValueError: if ``sel_strategy`` is not a known name.
        """
        if sel_strategy == 'herding':
            self._announce_selection("Herding")
            return HerdingSelectionStrategy(self.args)
        if sel_strategy == 'teal':
            self._announce_selection("TEAL")
            return SoloTEALExemplarsSelectionStrategy(self.args, self.device)
        if sel_strategy == 'rm':
            return self._build_rainbow_memory_strategy()
        if sel_strategy == 'centered':
            self._announce_selection("Centered")
            return ClosestToCenterSelectionStrategy()
        if sel_strategy == 'probcover':
            self._announce_selection("probcover")
            return ProbCoverExemplarsSelectionStrategy(self.args, self.device)
        if sel_strategy == 'max_herding':
            self._announce_selection("Max Herding")
            return MaxHerding(self.args, self.device)
        if sel_strategy == 'budget':
            self._announce_selection("Budgeted Coverage")
            from sampling_strategies.budgeted_coverage_strategy import BudgetedCoverageSelectionStrategy
            return BudgetedCoverageSelectionStrategy(self.args, self.device)
        raise ValueError(f"Unknown selection strategy: {sel_strategy!r}")

    def _after_training_exp(self, **kwargs):
        """Optionally refill this experience's buffer groups with a custom strategy.

        For a missing ``args``/``sel_strategy``, 'random' or 'gss', the
        storage policy's own selection is kept. Otherwise the buffer group of
        every class in the current experience is emptied and given the chosen
        strategy, then the storage policy is updated again.

        :raises ValueError: if ``args.sel_strategy`` is an unknown name.
        """
        sel_strategy = getattr(self.args, 'sel_strategy', None)
        if sel_strategy == 'gss':
            print("Using GSS plugin - skipping custom selection strategy")
        if sel_strategy in self._DEFAULT_SELECTION_STRATEGIES:
            super()._after_training_exp(**kwargs)
            return

        self.ss = self._build_selection_strategy(sel_strategy)
        storage_policy = self._er_ace_plugin.storage_policy
        for c in self.experience.classes_in_this_experience:
            print(f"{c},", end=' ')
            buffer_group = storage_policy.buffer_groups[c]
            # Empty the group and swap in the new strategy, so the update()
            # call below refills it using that strategy.
            buffer_group.buffer = buffer_group.buffer.subset([])
            buffer_group.selection_strategy = self.ss
        print()
        storage_policy.update(self, **kwargs)

        super()._after_training_exp(**kwargs)

    def _train_cleanup(self):
        super()._train_cleanup()
        # reset the value to avoid serialization failures
        self.replay_loader = None


class ER_ACE_Plugin(SupervisedPlugin):
    """Container plugin holding the exemplar storage policy for ER_ACE.

    This class defines no callback overrides of its own. The ER_ACE strategy
    reads ``storage_policy`` directly to update and sample its replay buffer.
    """

    def __init__(
        self,
        storage_policy: Optional["ExemplarsBuffer"] = None,
        mem_size: int = 2000,
    ):
        """
        :param storage_policy: The policy that controls how to add new
            exemplars in memory. When None, a ClassBalancedBuffer with
            ``max_size=mem_size`` and ``adaptive_size=True`` is created.
        :param mem_size: total memory size, used only for the default buffer.
        """
        super().__init__()
        self.mem_size = mem_size
        self.storage_policy = (
            storage_policy
            if storage_policy is not None
            else ClassBalancedBuffer(max_size=mem_size, adaptive_size=True)
        )
