import os
import random
import time
from types import SimpleNamespace
import numpy as np
import torch
from avalanche.benchmarks.classic import SplitCIFAR100, SplitCIFAR10 , SplitCUB200
from torch import nn
from avalanche.evaluation.metrics import loss_metrics, TrainedExperienceAccuracy
from torch.utils.data import DataLoader
from torchvision import transforms
from avalanche.training.plugins import MIRPlugin, GSS_greedyPlugin
from avalanche.training import GDumb
from avalanche.training.supervised import SCR

from MERS.cl_algorithm_plugins.acr_er_ace import  ER_ACE_Clfd
from MERS.cl_algorithm_plugins.bocl import BudgetedContinualLearning
from MERS.cl_algorithm_plugins.fl import FlashbackLearningPlugin
from MERS.cl_algorithm_plugins.online_er import ReplayPlugin

from MERS.mers_datasets import SplitTinyImageNet
from MERS.mers_utils.avalnche_to_mammoth import AvalancheToMammothDataset
from MERS.sampling_strategies.SelecetionStrategy import ProbCoverExemplarsSelectionStrategy, \
    SoloTEALExemplarsSelectionStrategy
from MERS.sampling_strategies.budgeted_coverage_strategy import BudgetedCoverageSelectionStrategy
# from cl_algorithm_plugins.naive_replay import NaiveReplayPlugin
from mers_datasets import SplitCUB200

from avalanche.evaluation.metrics import accuracy_metrics, forgetting_metrics, forward_transfer_metrics
from avalanche.logging import BaseLogger

from mers_models.arch_craft import arch_craft
from sampling_strategies.herding import HerdingSelectionStrategy
from sampling_strategies.rainbow_memory import RainbowMemorySelectionStrategy
from sampling_strategies.closest_to_canter import ClosestToCenterSelectionStrategy
from MERS.sampling_strategies.SelecetionStrategy import MaxHerding
from MERS.mers_utils.text_logging import TextLogger
from avalanche.models import IncrementalClassifier, SlimResNet18
# FeatureReplay is an optional dependency: tolerate its absence at import time,
# but only swallow import failures so unrelated errors (including
# KeyboardInterrupt/SystemExit) are no longer silently hidden.
try:
    from avalanche.training.supervised import FeatureReplay
except ImportError:
    pass

from mers_models.resnet18 import resnet18
from cl_algorithm_plugins.er_ace import ER_ACE

from avalanche.training.plugins import LRSchedulerPlugin, EvaluationPlugin
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import StepLR, MultiStepLR

from MERS.mers_utils.storage_policy import ParametricBuffer, ClassBalancedBuffer
from cl_algorithm_plugins.strategy_wrappers import Replay, Naive
# from avalanche.benchmarks.utils.classification_dataset import make_classification_dataset

from torch.optim import SGD
import sys
MAMMOTH_ROOT = "/cs/labs/daphna/danit.yanowsky/CL/MERS/mammoth"  # <- adjust if different
if MAMMOTH_ROOT not in sys.path:
    sys.path.insert(0, MAMMOTH_ROOT)


def set_complete_mammoth_args(args, buffer_size, device):
    """Populate ``args`` in place with the attributes Mammoth's XDER expects.

    ``weight_decay``, ``momentum`` and ``num_epochs`` are reused when they are
    already present on ``args`` (falling back to 0.0002, 0.9 and 100), and an
    existing ``seed`` is kept. Every other attribute set below is overwritten
    unconditionally.

    Args:
        args: Mutable namespace-like object; attributes are assigned on it.
        buffer_size: Replay buffer capacity, stored as ``args.buffer_size``.
        device: Device handle, stored as ``args.device``.

    Returns:
        The same ``args`` object that was passed in.
    """
    # XDER-specific parameters
    args.eval_future = False
    args.print_results = True
    args.simclr_temp = 5
    args.gamma = 0.85
    args.simclr_batch_size = 64
    args.simclr_num_aug = 4
    args.lambd = 0.05
    args.constr_eta = 0.1
    args.constr_margin = 0.3
    args.dp_weight = 0
    args.past_constraint = False
    args.future_constraint = True
    args.align_bn = True

    # Optimizer arguments (reuse the caller's values when available)
    args.optimizer = 'sgd'
    args.optim_wd = getattr(args, 'weight_decay', 0.0002)
    args.optim_mom = getattr(args, 'momentum', 0.9)
    args.optim_nesterov = False

    # Learning rate scheduler
    args.lr_scheduler = None
    args.lr_milestones = None
    args.lr_gamma = 0.1

    # Buffer and training settings
    args.buffer_size = buffer_size
    # Clamp to >= 1: `buffer_size // 4` is 0 for buffers smaller than 4, and a
    # zero-sized minibatch cannot be sampled.
    args.minibatch_size = max(1, min(32, buffer_size // 4))
    args.n_epochs = getattr(args, 'num_epochs', 100)

    # Remaining bookkeeping flags
    args.label_perc = 1.0  # fraction of labelled data
    args.model = 'resnet18'
    args.csv_log = False
    args.tensorboard = False
    args.validation = False
    args.ignore_other_metrics = False
    args.debug_mode = False
    args.non_verbose = False
    args.disable_log = False
    args.notes = ""
    args.load_best_args = False

    # Device settings
    args.device = device

    # Keep a caller-provided seed, otherwise default to 42
    if not hasattr(args, 'seed'):
        args.seed = 42

    args.nowand = True  # wandb stays disabled

    return args
def get_der_transform(dataset_name):
    """Return the augmentation pipeline used for the DER-family strategies.

    Every pipeline ends with a random horizontal flip, tensor conversion and
    the same normalisation constants. Known datasets additionally get a
    leading crop: 32px for ``cifar10``/``cifar100``, 64px for ``tinyimg`` and
    a random-resized 224px crop for ``cub200``. Any other name gets no crop.
    """
    if dataset_name in ['cifar10', 'cifar100']:
        leading = [transforms.RandomCrop(32, padding=4)]
    elif dataset_name == 'tinyimg':
        leading = [transforms.RandomCrop(64, padding=4)]
    elif dataset_name == 'cub200':
        leading = [transforms.RandomResizedCrop(224)]
    else:
        # Fallback for datasets without a dedicated crop
        leading = []

    shared_tail = [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(leading + shared_tail)
def init_dataset(args):
    """Build the class-incremental benchmark selected by ``args.dataset``.

    Supported names are ``cifar10``, ``cifar100``, ``tinyimg`` and ``cub200``.
    For ``cifar100`` and ``tinyimg`` the class order is the identity order,
    shuffled with ``args.order`` as seed when ``args.order`` is truthy. For
    ``cifar10`` and ``cub200`` a fixed identity order is only passed when
    ``args.seed`` is ``None``.

    Raises:
        ValueError: If ``args.dataset`` is not one of the supported names.
    """

    def _class_order(num_classes):
        order = list(range(num_classes))
        if args.order:
            # NOTE: this seeds the module-level RNG, so later `random` calls
            # elsewhere in the process are affected as well.
            random.seed(args.order)
            random.shuffle(order)
        return order

    name = args.dataset

    if name == 'cifar10':
        return SplitCIFAR10(
            n_experiences=args.num_experiences,
            return_task_id=False,
            shuffle=True,
            class_ids_from_zero_in_each_exp=False,
            fixed_class_order=list(range(10)) if args.seed is None else None,
            seed=args.seed,
            dataset_root='data/avalanche/cifar10',
        )

    if name == 'cifar100':
        class_order = _class_order(100)
        return SplitCIFAR100(
            n_experiences=args.num_experiences,
            return_task_id=False,
            shuffle=True,
            class_ids_from_zero_in_each_exp=False,
            fixed_class_order=class_order,
            seed=args.seed,
        )

    if name == 'tinyimg':
        class_order = _class_order(200)
        return SplitTinyImageNet(
            n_experiences=args.num_experiences,
            return_task_id=False,
            shuffle=True,
            class_ids_from_zero_in_each_exp=False,
            fixed_class_order=class_order,
            seed=args.seed,
            dataset_root='/cs/usr/danit.yanowsky/data/avalanche/tiny_imagenet',
        )

    if name == 'cub200':
        return SplitCUB200(
            n_experiences=args.num_experiences,
            return_task_id=False,
            shuffle=True,
            class_ids_from_zero_in_each_exp=False,
            fixed_class_order=list(range(200)) if args.seed is None else None,
            seed=args.seed,
            classes_first_batch=None,
            dataset_root='../data/',
        )

    raise ValueError(f'Unknown dataset {args.dataset}')


def init_incremental_model(inc_model_name, dataset, dataset_name):
    """Create the backbone named ``inc_model_name`` with an incremental head.

    The backbone's ``linear`` layer is replaced by an ``IncrementalClassifier``
    initially sized to ``dataset.n_classes // dataset.n_experiences`` outputs.
    For ``resnet18`` the head input width is fixed at 512 and additionally
    stored on the model as ``linear_input_features``.

    Raises:
        ValueError: If ``inc_model_name`` is not a supported model name.
    """
    if inc_model_name == 'arch_craft':
        # for debug: [9, 12, [4, 6, 8, 8, 9], [1, 3, 5, 8, 9]]
        net = arch_craft(code=[10, 144, [3, 7, 8, 10, 10], [3, 8, 10, 10, 10]], dataset=dataset_name)
        head_inputs = net.linear.in_features
    elif inc_model_name == 'resnet18':
        net = resnet18(dataset.n_classes)
        # This ResNet18 feeds 512 features into its final linear layer
        head_inputs = 512
    elif inc_model_name == 'slim_resnet18':
        net = SlimResNet18(nclasses=1)
        head_inputs = net.linear.in_features
    else:
        raise ValueError(f'Unknown incremental model {inc_model_name}')

    net.linear = IncrementalClassifier(head_inputs,
                                       dataset.n_classes // dataset.n_experiences)
    if inc_model_name == 'resnet18':
        # Kept as an attribute so DER++ code can look the width up
        net.linear_input_features = head_inputs

    return net


class ContinualLearningPipeline:

    def __init__(self, args, device, exp_dir):
        """Wire up dataset, model, optimizer, scheduler and CL strategy.

        Args:
            args: Experiment configuration; ``inc_model``, ``dataset``,
                ``buffer``, ``batch_id`` and ``alpha`` are read here, further
                attributes are read by the builder methods.
            device: Device the strategy is created for.
            exp_dir: Directory used when saving results.
        """
        self.args, self.device = args, device

        # Benchmark first: the model head is sized from it
        self.dataset = init_dataset(args)
        self.inc_model = init_incremental_model(args.inc_model, self.dataset, args.dataset)
        self.buffer_size = args.buffer

        # Each builder depends on the attributes assigned above it
        self.optimizer = self.get_optimizer()
        self.scheduler_plugin = self.get_scheduler_plugin()
        self.cl_strategy = self.init_cl_strategy()

        self.exp_dir = exp_dir
        self.batch_id = args.batch_id
        self.alpha = args.alpha

    def get_optimizer(self):
        optimizer = SGD(self.inc_model.parameters(), momentum=self.args.momentum,
                        weight_decay=self.args.weight_decay, lr=self.args.lr)
        return optimizer

    def get_scheduler_plugin(self):
        """Build the learning-rate scheduler plugin bound to ``self.optimizer``.

        ``der_pp`` uses a ``MultiStepLR`` (milestones 100 and 120, gamma 0.1);
        every other algorithm uses a ``StepLR`` that scales the learning rate
        by 0.3 every ``max(num_epochs // 3, 1)`` epochs.

        Returns:
            An ``LRSchedulerPlugin`` configured with epoch granularity and
            ``first_exp_only=False``.
        """
        if self.args.algorithm == 'der_pp':
            scheduler = MultiStepLR(self.optimizer, milestones=[100, 120], gamma=0.1)
        else:
            scheduler = StepLR(self.optimizer, step_size=max(self.args.num_epochs // 3, 1), gamma=0.3)

        # Always wrap the scheduler: the result is handed to strategies via
        # `plugins=[...]`, and a bare torch scheduler is not an Avalanche
        # plugin (it has no training callbacks to be driven by).
        return LRSchedulerPlugin(scheduler, step_granularity="epoch", first_exp_only=False)

    def init_cl_strategy(self):
        """Build the continual-learning strategy named by ``self.args.algorithm``.

        The replay storage policy, the loss and the evaluation plugin are always
        created first; the branch matching the algorithm name then builds and
        returns the strategy. Not every branch uses the storage policy.

        Side effects: for ``xder`` and ``der_pp`` ``self.inc_model`` is replaced
        by a torchvision ResNet-18 and ``self.args`` is extended through
        ``set_complete_mammoth_args``.

        Returns:
            The configured strategy object.

        Raises:
            ValueError: If ``self.args.algorithm`` is not a supported name.
        """
        storage_policy = self._build_storage_policy()

        criterion = CrossEntropyLoss()
        loggers = [BaseLogger(), TextLogger()]
        evaluator = EvaluationPlugin(
            accuracy_metrics(epoch=True, trained_experience=True, experience=True, epoch_running=True),
            loss_metrics(minibatch=True, epoch=True),
            forgetting_metrics(experience=True),
            loggers=loggers,
        )

        algorithm = self.args.algorithm

        if algorithm == 'xder':
            # NOTE(review): self.optimizer was created from the previous
            # model's parameters; it is not passed to XDer here.
            self.inc_model = self._build_torchvision_resnet18()

            self.args = set_complete_mammoth_args(self.args, self.buffer_size, self.device)
            self.inc_model = self.inc_model.to(self.device)

            mammoth_dataset = AvalancheToMammothDataset(self.dataset, self.args.dataset)
            from mammoth.models.xder import XDer
            transform = get_der_transform(self.args.dataset)

            xder_strategy = XDer(self.inc_model, criterion, self.args, transform, mammoth_dataset)
            xder_strategy = xder_strategy.to(self.device)

            if hasattr(xder_strategy, 'net'):
                xder_strategy.net = xder_strategy.net.to(self.device)

            return xder_strategy

        elif algorithm == 'mir':
            mir_plugin = MIRPlugin(
                # Number of patterns to retrieve for MIR
                subsample=getattr(self.args, 'mir_subsample', min(self.buffer_size, 128)),
                # Memory batch size
                batch_size_mem=getattr(self.args, 'batch_size_mem', self.args.batch_size),
            )

            replay_strategy = Replay(
                model=self.inc_model,
                optimizer=self.optimizer,
                criterion=criterion,
                mem_size=self.buffer_size,
                device=self.device,
                train_epochs=self.args.num_epochs,
                train_mb_size=self.args.batch_size,
                eval_mb_size=self.args.batch_size,
                plugins=[self.scheduler_plugin],  # only the scheduler initially
                evaluator=evaluator,
                storage_policy=storage_policy,
            )

            # The MIR plugin goes in at index 2 so the plugins already at
            # indices 0 and 1 keep their positions.
            # NOTE(review): presumably other code looks up the storage plugin
            # by index - confirm against the Replay wrapper.
            replay_strategy.plugins.insert(2, mir_plugin)

            return replay_strategy

        elif algorithm == 'gdumb':
            gdumb_strategy = GDumb(
                model=self.inc_model,
                optimizer=self.optimizer,
                criterion=criterion,
                mem_size=self.buffer_size,
                train_mb_size=self.args.batch_size,
                train_epochs=self.args.num_epochs,
                eval_mb_size=self.args.batch_size,
                device=self.device,
                plugins=[self.scheduler_plugin],
                evaluator=evaluator,
            )
            # Swap the storage policy of the plugin at index 1 for ours
            gdumb_strategy.plugins[1].storage_policy = storage_policy
            return gdumb_strategy

        elif algorithm == 'budgeted_cl':
            # Budgeted Online Continual Learning (aL-SAR) - ICLR 2025
            budgeted_strategy = BudgetedContinualLearning(
                model=self.inc_model,
                optimizer=self.optimizer,
                criterion=criterion,
                mem_size=self.buffer_size,
                device=self.device,
                train_epochs=self.args.num_epochs,
                train_mb_size=self.args.batch_size,
                eval_mb_size=self.args.batch_size,
                flop_budget_ratio=getattr(self.args, 'flop_budget_ratio', 1.0),
                temperature=getattr(self.args, 'temperature', 1.0),
                freeze_threshold=getattr(self.args, 'freeze_threshold', 0.1),
                plugins=[self.scheduler_plugin],
                evaluator=evaluator,
            )
            # Install our storage policy on the first plugin that has one
            if hasattr(budgeted_strategy, 'plugins'):
                for plugin in budgeted_strategy.plugins:
                    if hasattr(plugin, 'storage_policy'):
                        plugin.storage_policy = storage_policy
                        break

            return budgeted_strategy

        elif algorithm == 'der':
            # Avalanche's native DER; beta defaults to 0.0 (plain DER)
            return self._build_der(
                criterion, evaluator,
                alpha=getattr(self.args, 'alpha', 0.1),
                beta=getattr(self.args, 'beta', 0.0),
            )

        elif algorithm == 'der_pp':
            # DER++ (DER with beta > 0) on a torchvision ResNet-18
            self.inc_model = self._build_torchvision_resnet18()

            self.args = set_complete_mammoth_args(self.args, self.buffer_size, self.device)
            self.inc_model = self.inc_model.to(self.device)

            # Unlike 'der', the values are read from `alpha_` / `beta_`
            # (with a trailing underscore).
            return self._build_der(
                criterion, evaluator,
                alpha=getattr(self.args, 'alpha_', 0.1),
                beta=getattr(self.args, 'beta_', 0.5),
            )

        elif algorithm == 'der_pp_probcover':
            # NOTE: the storage policy built above is not passed to DER in
            # this branch, so the strategy keeps DER's own buffer.
            return self._build_der(
                criterion, evaluator,
                alpha=getattr(self.args, 'alpha_', 0.1),
                beta=getattr(self.args, 'beta_', 0.5),
            )

        elif algorithm == 'er_ace':
            use_gss = getattr(self.args, 'sel_strategy', None) == 'gss'
            if use_gss:
                gss = GSS_greedyPlugin(
                    mem_size=self.buffer_size,
                    mem_strength=getattr(self.args, 'gss_mem_strength', 4),
                    input_size=getattr(self.args, 'gss_input_size', [3, 32, 32]),
                )

            strategy = self._build_er_ace(ER_ACE, criterion, evaluator, storage_policy)

            if use_gss:
                strategy.plugins.insert(2, gss)
            return strategy

        elif algorithm == 'er':
            use_gss = getattr(self.args, 'sel_strategy', None) == 'gss'
            plugins = [self.scheduler_plugin]
            if use_gss:
                plugins.append(GSS_greedyPlugin(
                    mem_size=self.buffer_size,
                    mem_strength=getattr(self.args, 'gss_mem_strength', 5),
                    input_size=getattr(self.args, 'gss_input_size', [3, 32, 32]),
                ))

            return Replay(
                self.inc_model, self.optimizer, criterion, self.buffer_size, device=self.device,
                train_epochs=self.args.num_epochs, train_mb_size=self.args.batch_size,
                eval_mb_size=self.args.batch_size,
                plugins=plugins, evaluator=evaluator, storage_policy=storage_policy
            )

        elif algorithm == 'naive_replay':
            # NOTE(review): the NaiveReplayPlugin import is commented out in
            # the visible part of this file; unless the name is defined
            # elsewhere this branch raises NameError.
            return Naive(model=self.inc_model, optimizer=self.optimizer,
                         plugins=[self.scheduler_plugin, NaiveReplayPlugin(storage_policy=storage_policy)],
                         evaluator=evaluator, device=self.device, train_epochs=self.args.num_epochs,
                         train_mb_size=self.args.batch_size, eval_mb_size=self.args.batch_size,
                         )

        elif algorithm == 'finetune':
            return Naive(
                self.inc_model, self.optimizer, criterion,
                device=self.device,
                train_epochs=self.args.num_epochs, train_mb_size=self.args.batch_size,
                eval_mb_size=self.args.batch_size,
                plugins=[self.scheduler_plugin], evaluator=evaluator,
            )

        elif algorithm == 'flashback_er_ace':
            # ER-ACE extended with the Flashback Learning plugin
            flashback_plugin = FlashbackLearningPlugin(
                flashback_weight=getattr(self.args, 'flashback_weight', 0.1),
                stability_weight=getattr(self.args, 'stability_weight', 0.5)
            )

            use_gss = getattr(self.args, 'sel_strategy', None) == 'gss'
            if use_gss:
                replay = ReplayPlugin(
                    mem_size=self.buffer_size,
                    batch_size_mem=self.args.batch_size,
                    storage_policy=ClassBalancedBuffer(max_size=self.buffer_size)
                )
                gss = GSS_greedyPlugin(
                    mem_size=self.buffer_size,
                    mem_strength=getattr(self.args, 'gss_mem_strength', 5),
                    input_size=getattr(self.args, 'gss_input_size', [3, 32, 32])
                )

            # Create the strategy with just the scheduler plugin, then insert
            # the extra plugins at fixed positions.
            strategy = self._build_er_ace(ER_ACE, criterion, evaluator, storage_policy)

            strategy.plugins.insert(2, flashback_plugin)

            if use_gss:
                strategy.plugins.insert(1, replay)
                strategy.plugins.insert(4, gss)  # index accounts for the two inserts above

            return strategy

        elif algorithm == 'feature_replay':
            print(f"DEBUG: Model type: {type(self.inc_model)}")
            print(f"DEBUG: Model attributes: {[attr for attr in dir(self.inc_model) if not attr.startswith('_')]}")

            model = self.inc_model

            # The models built by init_incremental_model expose the head as 'linear'
            last_layer_name = 'linear'

            if hasattr(model, last_layer_name):
                last_layer = getattr(model, last_layer_name)
                print(f"DEBUG: Found last layer '{last_layer_name}': {last_layer}")
                print(f"DEBUG: Last layer type: {type(last_layer)}")
                if hasattr(last_layer, 'in_features'):
                    print(f"DEBUG: Last layer input features: {last_layer.in_features}")
            else:
                print(f"ERROR: Model does not have attribute '{last_layer_name}'")
                # Fall back to the first commonly used head attribute that exists
                for attr_name in ['linear', 'fc', 'classifier']:
                    if hasattr(model, attr_name):
                        last_layer_name = attr_name
                        print(f"DEBUG: Using fallback last layer name: {last_layer_name}")
                        break

            # Report whether the head is already an IncrementalClassifier
            if hasattr(model, last_layer_name):
                last_layer = getattr(model, last_layer_name)
                if isinstance(last_layer, IncrementalClassifier):
                    print(f"DEBUG: Model already has IncrementalClassifier")
                else:
                    print(f"WARNING: Last layer is not IncrementalClassifier: {type(last_layer)}")

            # NOTE(review): FeatureReplay is only bound when the optional
            # module-level import succeeded; otherwise this raises NameError.
            feature_replay_strategy = FeatureReplay(
                model=model,
                optimizer=self.optimizer,
                last_layer_name=last_layer_name,
                mem_size=self.buffer_size,
                # Conservative memory batch size, clamped to >= 1 because
                # `buffer_size // 4` is 0 for buffers smaller than 4.
                batch_size_mem=max(1, min(32, self.buffer_size // 4)),
                train_mb_size=self.args.batch_size,
                train_epochs=self.args.num_epochs,
                eval_mb_size=self.args.batch_size,
                device=self.device,
                plugins=[self.scheduler_plugin],
                evaluator=evaluator
            )

            print(f"DEBUG: FeatureReplay strategy created successfully")
            return feature_replay_strategy

        elif algorithm == 'acr_er_ace':
            return self._build_er_ace(ER_ACE_Clfd, criterion, evaluator, storage_policy)

        else:
            raise ValueError(f'Unknown algorithm {self.args.algorithm}')

    def _build_storage_policy(self):
        """Return the replay storage policy for the configured algorithm and ``sel_strategy``."""
        if self.args.algorithm in ('er_ace', 'acr_er_ace'):
            # ER-ACE variants use a class-grouped buffer without a custom selection strategy
            return ParametricBuffer(self.buffer_size, groupby='class')

        sel_strategy = self.args.sel_strategy

        if sel_strategy == 'rm':
            # The dataset name is temporarily switched to "TinyImagenet" while
            # the Rainbow Memory strategy is built; `finally` guarantees the
            # original name is restored even if construction fails.
            if self.args.dataset == 'tinyimg':
                self.args.dataset = "TinyImagenet"
            try:
                selection_strategy = RainbowMemorySelectionStrategy(self.args, self.device)
                policy = ParametricBuffer(self.buffer_size, groupby='class',
                                          selection_strategy=selection_strategy)
            finally:
                if self.args.dataset == 'TinyImagenet':
                    self.args.dataset = 'tinyimg'
            return policy

        if sel_strategy == 'herding':
            selection_strategy = HerdingSelectionStrategy(self.args)
        elif sel_strategy == 'budget':
            selection_strategy = BudgetedCoverageSelectionStrategy(self.args, self.device)
        elif sel_strategy == 'teal':
            selection_strategy = SoloTEALExemplarsSelectionStrategy(self.args, self.device)
        elif sel_strategy == 'max_herding':
            selection_strategy = MaxHerding(self.args, self.device)
        elif sel_strategy == 'probcover':
            selection_strategy = ProbCoverExemplarsSelectionStrategy(self.args, self.device)
        elif sel_strategy == 'centered':
            selection_strategy = ClosestToCenterSelectionStrategy()
        else:
            # Any other value falls back to a class-balanced buffer
            return ClassBalancedBuffer(self.buffer_size, adaptive_size=True)

        return ParametricBuffer(self.buffer_size, groupby='class',
                                selection_strategy=selection_strategy)

    def _build_torchvision_resnet18(self):
        """Return an untrained torchvision ResNet-18 whose head has ``n_classes`` outputs."""
        import torchvision.models as models

        backbone = models.resnet18(weights=None)
        backbone.fc = nn.Linear(backbone.fc.in_features, self.dataset.n_classes)
        # `linear` is an alias of `fc` (same module object); `fc` must stay in place
        backbone.linear = backbone.fc
        return backbone

    def _build_der(self, criterion, evaluator, alpha, beta):
        """Return an Avalanche ``DER`` strategy on the current model with the given ``alpha``/``beta``."""
        from avalanche.training import DER

        return DER(
            model=self.inc_model,
            optimizer=self.optimizer,
            criterion=criterion,
            mem_size=self.buffer_size,
            batch_size_mem=self.args.batch_size,
            alpha=alpha,
            beta=beta,
            train_mb_size=self.args.batch_size,
            train_epochs=self.args.num_epochs,
            eval_mb_size=self.args.batch_size,
            device=self.device,
            plugins=[self.scheduler_plugin],
            evaluator=evaluator
        )

    def _build_er_ace(self, strategy_cls, criterion, evaluator, storage_policy):
        """Instantiate ``strategy_cls`` (an ER-ACE style strategy) with the shared arguments."""
        return strategy_cls(
            self.inc_model, self.optimizer, criterion, self.buffer_size,
            device=self.device,
            train_epochs=self.args.num_epochs,
            train_mb_size=self.args.batch_size,
            eval_mb_size=self.args.batch_size,
            batch_size_mem=self.args.batch_size,
            plugins=[self.scheduler_plugin],
            evaluator=evaluator,
            storage_policy=storage_policy,
            args=self.args,
        )
    def run_and_save_results(self, seed, exp_results, output_file="accuracy_results"):
        """Persist the per-experience results of one seed as a ``.npy`` file.

        The results are written as a dict keyed by experience id to
        ``<exp_dir>/_alpha_<alpha>_<output_file>_seed_<seed>.npy``. Because the
        payload is a dict, NumPy stores it as a pickled object; read it back
        with ``np.load(path, allow_pickle=True).item()``.

        Args:
            seed: Seed identifier embedded in the output file name.
            exp_results: Dict keyed by experience id, or a sequence indexed by
                experience id. ``None`` means there is nothing to save and the
                call returns without writing anything.
            output_file: Base name embedded in the output file name.
        """
        print(f"DEBUG: exp_results type: {type(exp_results)}")
        print(f"DEBUG: exp_results value: {exp_results}")

        if exp_results is None:
            print("No experimental results to save")
            return

        # Dicts are copied as-is (no assumption that keys are 0..n-1);
        # sequences are keyed by their position.
        if isinstance(exp_results, dict):
            results_by_seed = dict(exp_results)
        else:
            results_by_seed = dict(enumerate(exp_results))

        output_path = os.path.join(self.exp_dir, f"_alpha_{self.alpha}_{output_file}_seed_{seed}.npy")

        # np.save does not create missing directories; make sure a finished
        # run is not lost just because the experiment folder is absent.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        np.save(output_path, results_by_seed)
        print(f"Accuracy results saved to {output_path}")


    def train(self, exp_folder):
        """Train the strategy on every experience and track test accuracy.

        For each experience of ``self.dataset.train_stream`` the strategy is
        trained, then evaluated on the test experiences seen so far
        (``test_stream[: t + 1]``). The top-1 accuracy reported for each of
        those experiences is appended to its history, and
        ``save_episode_summary`` is called. After the loop,
        ``save_model_and_features`` is called once.

        Args:
            exp_folder: Experiment directory; stored on ``self.args.exp_dir``.

        Returns:
            Dict mapping every experience id in ``range(n_experiences)`` to the
            list of accuracies recorded for it, one entry per evaluation round
            that included the experience. The lists stay empty when
            ``self.args.algorithm`` is ``'contrastive'``.

        Raises:
            KeyError: If the evaluation result lacks the expected
                ``Top1_Acc_Exp/...`` entry for an evaluated experience.
        """
        self.args.exp_dir = exp_folder
        exp_acc_dict = {exp_id: [] for exp_id in range(self.dataset.n_experiences)}

        for t, experience in enumerate(self.dataset.train_stream):
            print(f"Start of experience: {experience.current_experience}")
            print(f"Current Classes: {experience.classes_in_this_experience}")
            print(f"Experience {t + 1}/{self.dataset.n_experiences}")
            print(f"Current Time: {time.strftime('%H:%M:%S', time.localtime())}")

            # Safety net: never train on more experiences than configured,
            # even if the stream yields more.
            if t >= self.args.num_experiences:
                print(f"WARNING: Stopping at experience {t} (expected {self.args.num_experiences})")
                break

            self.cl_strategy.train(
                experience,
                num_workers=1,
                drop_last=True,
            )
            exp_res = self.cl_strategy.eval(self.dataset.test_stream[: t + 1])

            # Accuracy is not recorded for the 'contrastive' algorithm.
            if self.args.algorithm != 'contrastive':
                for exp_id in range(t + 1):
                    metric_key = f'Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp{str(exp_id).zfill(3)}'
                    exp_acc_dict[exp_id].append(exp_res[metric_key])

            print(f"\nExperience {experience.current_experience} Results:\n{exp_res}\n")

            # Save episode summary for weight analysis
            self.save_episode_summary(t)

        # Runs once after the loop, including when the loop stopped early.
        self.save_model_and_features()

        print(f"DEBUG: Returning exp_acc_dict with {len(exp_acc_dict)} experiences")
        return exp_acc_dict

    def save_episode_summary(self, experience_id):
        """
        Ask the buffer's selection strategy to write its episode summary.

        The storage policy is read from the strategy plugin at index 1. Only
        the first buffer group is consulted, on the assumption that all groups
        share one selection strategy. This is best-effort: every failure is
        reported with a printed message and never raised.

        Args:
            experience_id: Experience index, used only in printed messages.
        """
        try:
            strategy = self.cl_strategy
            if not hasattr(strategy, 'plugins') or len(strategy.plugins) <= 1:
                print("No storage plugin found in strategy")
                return

            storage_policy = strategy.plugins[1].storage_policy
            if not hasattr(storage_policy, 'buffer_groups'):
                print("Storage policy does not have buffer_groups")
                return

            print(f"Saving episode {experience_id} weight summary:")

            # Pick the first buffer group, if any; the others are not visited.
            groups = storage_policy.buffer_groups
            nothing = object()
            first_class_id = next(iter(groups.keys()), nothing)
            if first_class_id is nothing:
                return
            first_group = groups[first_class_id]

            if not hasattr(first_group, 'selection_strategy'):
                print("Buffer group does not have selection strategy")
                return

            selector = first_group.selection_strategy
            if not hasattr(selector, 'save_episode_summary'):
                print(f"Selection strategy {type(selector).__name__} does not support episode summary saving")
                return

            summary_file = selector.save_episode_summary()
            if summary_file:
                print(f"Episode {experience_id} weight summary saved to: {summary_file}")
            else:
                print(f"No weights to save for episode {experience_id}")
        except Exception as e:
            print(f"Error saving episode summary for experience {experience_id}: {e}")

    def save_model_and_features(self):
        """Create the output directory for saved embeddings.

        Only the directory ``<exp_dir>/saved_embeddings_<method>`` is created
        here. ``<method>`` is ``args.sel_strategy`` (``'default'`` when that
        attribute is missing), suffixed with ``_<args.features_type>`` when
        ``args`` has a ``features_type`` attribute. The code that would write
        the model weights and per-class ``features_ss`` arrays into this
        directory is currently commented out below, so nothing else is saved.
        """
        # Directory suffix naming the selection strategy (and feature type).
        ss_method = getattr(self.args, 'sel_strategy', 'default')
        if hasattr(self.args, 'features_type'):
            ss_method = f"{ss_method}_{self.args.features_type}"

        # exist_ok keeps repeated calls for the same experiment harmless.
        save_dir = os.path.join(self.exp_dir, f"saved_embeddings_{ss_method}")
        os.makedirs(save_dir, exist_ok=True)
        
        # # Save the trained model
        # model_path = os.path.join(save_dir, f"model_alpha_{self.alpha}_batch_{self.batch_id}.pth")
        # torch.save(self.cl_strategy.model.state_dict(), model_path)
        # print(f"Model saved to: {model_path}")
        #
        # # Save features_ss embeddings from all buffer groups
        # if hasattr(self.cl_strategy, 'plugins') and len(self.cl_strategy.plugins) > 1:
        #     storage_policy = self.cl_strategy.plugins[1].storage_policy
        #
        #     if hasattr(storage_policy, 'buffer_groups'):
        #         print(f"Saving features_ss embeddings ({ss_method}):")
        #
        #         for class_id in storage_policy.buffer_groups.keys():
        #             buffer_group = storage_policy.buffer_groups[class_id]
        #
        #             # Access the selection strategy from the buffer group
        #             if hasattr(buffer_group, 'selection_strategy'):
        #                 ss = buffer_group.selection_strategy
        #                 # Check if selected_features_ss exists and is not None (top 5 selected)
        #                 if hasattr(ss, 'selected_features_ss') and ss.selected_features_ss is not None:
        #                     # Save top 5 selected features_ss
        #                     np_path = os.path.join(save_dir, f"selected_features_ss_class_{class_id}_{ss_method}.npy")
        #                     np.save(np_path, ss.selected_features_ss)
        #                     print(f"Class {class_id}: {len(ss.selected_features_ss)} selected features_ss saved to: {np_path}")
        #                 else:
        #                     print(f"Class {class_id}: No selected_features_ss available in selection strategy")
        #
        #                 # Save remaining features_ss (everything not selected) for KNN testing
        #                 if hasattr(ss, 'remaining_features_ss') and ss.remaining_features_ss is not None:
        #                     # Save remaining features_ss for KNN testing
        #                     np_path = os.path.join(save_dir, f"remaining_features_ss_class_{class_id}_{ss_method}.npy")
        #                     np.save(np_path, ss.remaining_features_ss)
        #                     print(f"Class {class_id}: {len(ss.remaining_features_ss)} remaining features_ss saved to: {np_path}")
        #                 else:
        #                     print(f"Class {class_id}: No remaining_features_ss available in selection strategy")
        #
        #                 # Save all important metadata as numeric arrays for clarity
        #                 # Metadata encoding:
        #                 # - alpha: float (weight for model-based features in integrated mode)
        #                 # - ss_method: int (0=model_based, 1=dino, 2=simclr, 3=vicreg, -1=unknown)
        #                 # - integrated: int (0=False, 1=True, -1=unknown)
        #                 # - weight_method: int (0=ratio_median_knn_density_k_1, 1=herding_specific_balance,
        #                 #                        2=mean_approximation_quality, 3=spread_vs_compactness, 4=equal, -1=unknown)
        #                 metadata = {}
        #
        #                 # Save alpha value (important for integrated features and KNN testing)
        #                 if hasattr(ss, 'alpha'):
        #                     metadata['alpha'] = ss.alpha
        #
        #                 # Save ss_method as numeric encoding (important for understanding feature type)
        #                 if hasattr(ss, 'ss_method'):
        #                     ss_method_encoding = {
        #                         'model_based': 0,
        #                         'dino': 1,
        #                         'simclr': 2,
        #                         'vicreg': 3
        #                     }
        #                     metadata['ss_method'] = ss_method_encoding.get(ss.ss_method, -1)
        #
        #                 # Save integrated flag (important for understanding feature combination)
        #                 if hasattr(ss, 'integrated'):
        #                     metadata['integrated'] = 1 if ss.integrated else 0
        #
        #                 # Save weight_method as numeric encoding if available
        #                 if hasattr(ss, 'weight_method'):
        #                     weight_method_encoding = {
        #                         'ratio_median_knn_density_k_1': 0,
        #                         'herding_specific_balance': 1,
        #                         'mean_approximation_quality': 2,
        #                         'spread_vs_compactness': 3,
        #                         'equal': 4
        #                     }
        #                     metadata['weight_method'] = weight_method_encoding.get(ss.weight_method, -1)
        #
        #                 # Save all metadata as a single numpy array
        #                 if metadata:
        #                     metadata_path = os.path.join(save_dir, f"metadata_class_{class_id}_{ss_method}.npy")
        #                     # Create structured array with metadata
        #                     metadata_array = np.array([metadata.get('alpha', -1),
        #                                              metadata.get('ss_method', -1),
        #                                              metadata.get('integrated', -1),
        #                                              metadata.get('weight_method', -1)])
        #                     np.save(metadata_path, metadata_array)
        #                     print(f"Class {class_id}: metadata saved to: {metadata_path}")
        #                     print(f"  - alpha: {metadata.get('alpha', 'N/A')}")
        #                     print(f"  - ss_method: {metadata.get('ss_method', 'N/A')} ({ss.ss_method if hasattr(ss, 'ss_method') else 'N/A'})")
        #                     print(f"  - integrated: {metadata.get('integrated', 'N/A')}")
        #                     print(f"  - weight_method: {metadata.get('weight_method', 'N/A')} ({ss.weight_method if hasattr(ss, 'weight_method') else 'N/A'})")
        #             else:
        #                 print(f"Class {class_id}: No selection strategy found in buffer group")
        #     else:
        #         print("No buffer_groups found in storage policy")
        # else:
        #     print("No storage policy found in strategy plugins")
        #
        # print(f"All embeddings saved to: {save_dir}")
        #
        #
