import torch
from avalanche.core import SupervisedPlugin
from avalanche.training.plugins import ReplayPlugin
from avalanche.training.plugins import GSS_greedyPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer

class FlashbackLearningPlugin(SupervisedPlugin):
    """
    Flashback Learning plugin that can be integrated with existing CL strategies.
    Based on the 2024/2025 paper on bidirectional regularization for CL.
    Applies regularization during training iterations, not after.

    From the second experience on, ``before_backward`` adds to
    ``strategy.loss``::

        flashback_weight * (stability_weight * ||theta - prev||^2
                            - (1 - stability_weight) * ||theta - init||^2)

    summed over trainable parameters that have a stored anchor.

    NOTE(review): both anchors are snapshotted in ``before_training_exp`` from
    the same model state, so ``prev == init``. The penalty therefore reduces to
    ``flashback_weight * (2 * stability_weight - 1) * ||theta - init||^2``.
    It is zero at the default ``stability_weight=0.5`` and unbounded below for
    ``stability_weight < 0.5``. Confirm this matches the intended method.

    Args:
        flashback_weight: overall scale of the penalty added to the loss.
        stability_weight: weight of the stability term; the plasticity term
            is weighted by ``1 - stability_weight``.
    """

    def __init__(self, flashback_weight=0.1, stability_weight=0.5):
        super().__init__()
        if stability_weight < 0.5:
            import warnings

            # a*||x-p||^2 - b*||x-c||^2 with a < b has a negative leading
            # coefficient, so minimizing it drives parameters off to infinity.
            warnings.warn(
                "FlashbackLearningPlugin: stability_weight < 0.5 makes the "
                "flashback penalty unbounded below (it rewards moving "
                "parameters arbitrarily far from their anchors).",
                stacklevel=2,
            )
        self.flashback_weight = flashback_weight
        self.stability_weight = stability_weight
        # name -> detached parameter copy; empty until the second experience.
        self.previous_params = {}
        # name -> detached parameter copy taken at the start of each experience.
        self.current_task_params = {}

    def before_training_exp(self, strategy, **kwargs):
        """Snapshot the model's parameters at the start of the experience.

        The snapshot becomes the plasticity anchor (``current_task_params``).
        From the second experience on it is also the stability anchor
        (``previous_params``). Both dicts reference the same detached copies,
        so only one copy of the parameters is kept in memory.
        """
        snapshot = {
            name: param.detach().clone()
            for name, param in strategy.model.named_parameters()
        }
        if strategy.clock.train_exp_counter > 0:
            self.previous_params = dict(snapshot)
        self.current_task_params = snapshot

    def before_backward(self, strategy, **kwargs):
        """Add the flashback penalty to ``strategy.loss``.

        This is a no-op until ``previous_params`` has been populated, i.e.
        during the first experience.
        """
        if not self.previous_params:
            return

        stability_loss = 0
        plasticity_loss = 0
        for name, param in strategy.model.named_parameters():
            if not param.requires_grad:
                continue
            prev = self.previous_params.get(name)
            if prev is not None:
                # Squared L2 distance; avoids the sqrt of norm(...) ** 2.
                stability_loss += (param - prev).pow(2).sum()
            init = self.current_task_params.get(name)
            if init is not None:
                plasticity_loss += (param - init).pow(2).sum()

        # Bidirectional regularization: stability pulls towards the previous
        # anchor, the (negated) plasticity term pushes away from the task start.
        flashback_loss = (self.stability_weight * stability_loss
                          - (1 - self.stability_weight) * plasticity_loss)

        strategy.loss += self.flashback_weight * flashback_loss



import torch
import torch.nn as nn
import numpy as np
from avalanche.core import SupervisedPlugin
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins.evaluation import default_evaluator
from typing import Optional, List, Union, Callable
from torch.optim import Optimizer
from avalanche.training.plugins.evaluation import EvaluationPlugin
from torch.nn import CrossEntropyLoss


class TSVDPlugin(SupervisedPlugin):
    """
    TSVD (Truncated Singular Value Decomposition) for Continual Learning.

    Based on ICLR 2025 paper: "TSVD: Bridging Theory and Practice in Continual Learning
    with Pre-trained Models" by Liangzu Peng et al.

    Key idea: Use SVD to decompose weight matrices and selectively update components
    to balance stability (old tasks) and plasticity (new tasks).

    Lifecycle as implemented here:

    * First experience: record which 2-D parameters whose name contains
      ``'weight'`` are tracked. No decomposition or blending is applied.
    * Each later experience:
        - At its start, every tracked weight is decomposed and truncated to
          ``rank_ratio * min(shape)`` components (at least one).
        - Every ``update_frequency`` iterations, each trainable tracked weight
          is overwritten with
          ``(1 - plasticity_weight) * low_rank + plasticity_weight * current``.
        - At its end, the truncated factors are recomputed from the trained
          weights, keeping each weight's rank.

    Args:
        rank_ratio: fraction of singular values to keep.
        plasticity_weight: weight of the current weights in the blend.
        update_frequency: blend when ``strategy.clock.train_iter_counter`` is a
            multiple of this value.
    """

    def __init__(self,
                 rank_ratio=0.8,  # Ratio of singular values to keep
                 plasticity_weight=0.5,  # Balance between old/new tasks
                 update_frequency=10):  # How often to apply TSVD
        super().__init__()
        self.rank_ratio = rank_ratio
        self.plasticity_weight = plasticity_weight
        self.update_frequency = update_frequency
        self.task_count = 0

        # Copies of tracked 2-D weights from the first experience; the keys
        # decide which parameters are decomposed later.
        self.original_weights = {}
        # name -> {'U', 'S', 'Vt', 'rank'}: truncated SVD factors of the weight.
        self.svd_components = {}
        # Not used by this class; kept so code reading the attribute still works.
        self.task_specific_components = {}

    def before_training_exp(self, strategy, **kwargs):
        """Initialize TSVD for new task.

        The first experience only records which weights to track. Every later
        experience decomposes the current weights before training starts.
        """
        self.task_count += 1

        if self.task_count == 1:
            # First task: store original pre-trained weights
            self._store_original_weights(strategy.model)
        else:
            # Subsequent tasks: apply TSVD decomposition
            self._apply_tsvd_decomposition(strategy.model)

    def after_training_iteration(self, strategy, **kwargs):
        """Periodically blend weights towards their low-rank reconstruction.

        Active only after the first experience.
        """
        if (strategy.clock.train_iter_counter % self.update_frequency == 0 and
                self.task_count > 1):
            self._selective_weight_update(strategy.model)

    def after_training_exp(self, strategy, **kwargs):
        """Re-decompose the trained weights at the end of each later experience."""
        if self.task_count > 1:
            self._consolidate_task_knowledge(strategy.model)

    def _store_original_weights(self, model):
        """Store a copy of every 2-D parameter whose name contains 'weight'."""
        for name, param in model.named_parameters():
            if 'weight' in name and len(param.shape) == 2:  # Only 2D weight matrices
                self.original_weights[name] = param.data.clone()

    def _rank_for(self, shape):
        """Return how many singular components to keep for a ``shape`` matrix.

        The value is ``int(rank_ratio * min(shape))``, clamped to
        ``[1, min(shape)]``.
        """
        full_rank = min(shape)
        return max(1, min(int(self.rank_ratio * full_rank), full_rank))

    @staticmethod
    def _truncated_svd(param, rank):
        """Return the rank-``rank`` SVD factors of the 2-D tensor ``param``.

        The decomposition uses ``torch.linalg.svd`` on the parameter's own
        device, which avoids a device-to-host copy. float16/bfloat16 weights
        are upcast to float32 first because low-precision SVD is not
        generally supported. The factors are cast back to ``param.dtype`` and
        copied, so the full U/Vt storage is not retained.

        Returns:
            dict with keys 'U' (m x r), 'S' (r,), 'Vt' (r x n) and 'rank'.
        """
        weight = param.detach()
        if weight.dtype in (torch.float16, torch.bfloat16):
            weight = weight.float()
        U, S, Vt = torch.linalg.svd(weight, full_matrices=False)
        return {
            'U': U[:, :rank].to(dtype=param.dtype, copy=True),
            'S': S[:rank].to(dtype=param.dtype, copy=True),
            'Vt': Vt[:rank, :].to(dtype=param.dtype, copy=True),
            'rank': rank,
        }

    def _apply_tsvd_decomposition(self, model):
        """Decompose every tracked weight and store its truncated SVD factors."""
        for name, param in model.named_parameters():
            if name not in self.original_weights:
                continue
            shape = tuple(param.shape)
            rank = self._rank_for(shape)
            self.svd_components[name] = self._truncated_svd(param, rank)
            print(f"Applied TSVD to {name}: shape {shape} -> rank {rank}")

    def _selective_weight_update(self, model):
        """Blend each trainable tracked weight with its low-rank reconstruction.

        The update ``W <- (1 - plasticity_weight) * U diag(S) Vt +
        plasticity_weight * W`` is written in place into ``param.data``.
        """
        for name, param in model.named_parameters():
            if name in self.svd_components and param.requires_grad:
                components = self.svd_components[name]

                # Low-rank reconstruction (intended to preserve old knowledge).
                U, S, Vt = components['U'], components['S'], components['Vt']
                stable_weight = torch.mm(U * S.unsqueeze(0), Vt)

                blended_weight = (
                        (1 - self.plasticity_weight) * stable_weight +
                        self.plasticity_weight * param.data
                )
                param.data.copy_(blended_weight)

    def _consolidate_task_knowledge(self, model):
        """Recompute the stored factors from the trained weights.

        Each weight keeps the rank chosen for it at decomposition time.
        """
        for name, param in model.named_parameters():
            if name in self.svd_components:
                rank = self.svd_components[name]['rank']
                self.svd_components[name] = self._truncated_svd(param, rank)


class AdaptiveTSVDPlugin(TSVDPlugin):
    """
    Enhanced TSVD with adaptive rank selection and task-aware updates.

    Behaves like ``TSVDPlugin`` except for how the rank ratio is chosen.
    Before each experience, the ratio used for the truncated decomposition is
    adapted from the mean performance of the last two recorded tasks:

    * mean < 0.7: the ratio is divided by ``adaptation_factor``, capped at
      ``initial_rank_ratio`` (more components preserved).
    * otherwise: the ratio is multiplied by ``adaptation_factor``, floored at
      ``min_rank_ratio``.

    Args:
        initial_rank_ratio: starting (and maximum) rank ratio.
        min_rank_ratio: lower bound for the rank ratio.
        adaptation_factor: multiplicative step used to adapt the ratio.
        plasticity_weight: forwarded to ``TSVDPlugin``.
        update_frequency: forwarded to ``TSVDPlugin``.
        performance_fn: optional callable ``f(strategy) -> float`` evaluated
            after each experience. When ``None``, a constant placeholder of 0.8
            is recorded instead.
    """

    def __init__(self,
                 initial_rank_ratio=0.8,
                 min_rank_ratio=0.3,
                 adaptation_factor=0.9,
                 plasticity_weight=0.5,
                 update_frequency=10,
                 performance_fn=None):
        super().__init__(rank_ratio=initial_rank_ratio,
                         plasticity_weight=plasticity_weight,
                         update_frequency=update_frequency)
        self.initial_rank_ratio = initial_rank_ratio
        self.min_rank_ratio = min_rank_ratio
        self.adaptation_factor = adaptation_factor
        self.current_rank_ratio = initial_rank_ratio
        self.performance_fn = performance_fn

        self.task_performances = []
        # Not used by this class; kept so code reading the attribute still works.
        self.weight_importance = {}

    def before_training_exp(self, strategy, **kwargs):
        """Adapt the rank ratio, then run the TSVD setup for the experience."""
        if self.task_performances:
            avg_performance = np.mean(self.task_performances[-2:])  # Last 2 tasks
            if avg_performance < 0.7:  # Performance threshold
                # Performance dropped: preserve more components (less plasticity).
                self.current_rank_ratio = min(
                    self.initial_rank_ratio,
                    self.current_rank_ratio / self.adaptation_factor
                )
            else:
                self.current_rank_ratio = max(
                    self.min_rank_ratio,
                    self.current_rank_ratio * self.adaptation_factor
                )

            print(f"Adapted rank ratio to: {self.current_rank_ratio:.3f}")

        # The parent's decomposition reads self.rank_ratio.
        self.rank_ratio = self.current_rank_ratio
        super().before_training_exp(strategy, **kwargs)

    def after_training_exp(self, strategy, **kwargs):
        """Consolidate TSVD factors and record task performance for adaptation."""
        super().after_training_exp(strategy, **kwargs)
        if self.performance_fn is None:
            performance = 0.8  # Placeholder - integrate with actual evaluation
        else:
            performance = float(self.performance_fn(strategy))
        self.task_performances.append(performance)


class TSVDContinualLearning(SupervisedTemplate):
    """
    TSVD-based Continual Learning Strategy

    Integrates TSVD decomposition with standard continual learning template.

    A TSVD plugin is created and placed after any user-supplied ``plugins``.
    All remaining arguments are forwarded unchanged to ``SupervisedTemplate``.

    Args:
        rank_ratio: fraction of singular values to keep. It is passed as
            ``initial_rank_ratio`` to ``AdaptiveTSVDPlugin`` when ``adaptive``
            is true, otherwise as ``rank_ratio`` to ``TSVDPlugin``.
        plasticity_weight: forwarded only to ``TSVDPlugin``; not passed when
            ``adaptive`` is true.
        adaptive: use ``AdaptiveTSVDPlugin`` (default) instead of ``TSVDPlugin``.
        plugins: additional plugins. The caller's list is not modified.
        Other arguments: see ``SupervisedTemplate``.
    """

    def __init__(
            self,
            model: nn.Module,
            optimizer: Optimizer,
            criterion=CrossEntropyLoss(),
            rank_ratio: float = 0.8,
            plasticity_weight: float = 0.5,
            adaptive: bool = True,
            train_mb_size: int = 1,
            train_epochs: int = 1,
            eval_mb_size: Optional[int] = 1,
            device: Union[str, torch.device] = "cpu",
            plugins: Optional[List[SupervisedPlugin]] = None,
            evaluator: Union[EvaluationPlugin, Callable[[], EvaluationPlugin]] = default_evaluator,
            eval_every=-1,
            **kwargs
    ):
        if adaptive:
            tsvd_plugin = AdaptiveTSVDPlugin(
                initial_rank_ratio=rank_ratio,
            )
        else:
            tsvd_plugin = TSVDPlugin(
                rank_ratio=rank_ratio,
                plasticity_weight=plasticity_weight
            )

        # Build a new list so a list passed by the caller is never mutated
        # (reusing it across strategies would otherwise stack TSVD plugins).
        plugins = [tsvd_plugin] if plugins is None else [*plugins, tsvd_plugin]

        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            **kwargs
        )