import os
import torch
from avalanche.training.plugins import SupervisedPlugin
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader

class ExperienceReplay(SupervisedPlugin):
    """Replay plugin that rehearses stored samples and checkpoints the model.

    Before each training experience the strategy's data loader is replaced by
    a ``ReplayDataLoader`` that draws from both the current data and the
    storage policy's buffer. After each experience the buffer is updated and
    the model weights are written under
    ``checkpoints/<model>/<dataset>/seed<seed>/``.
    """

    def __init__(self, storage_policy, args, mem_batch_size=32):
        """Store the replay configuration.

        Args:
            storage_policy: Object exposing a ``buffer`` attribute and an
                ``update(strategy, ...)`` method.
            args: Object exposing ``dataset``, ``model`` and ``seed``
                attributes; they are only used to build the checkpoint path.
            mem_batch_size: Batch size used for memory samples. A value of 0
                disables the replay data loader.
        """
        super().__init__()
        self.storage_policy = storage_policy
        self.mem_batch_size = mem_batch_size
        self.benchmark_name = args.dataset
        self.model_name = args.model
        self.seed = args.seed

    def before_training_exp(self, strategy, num_workers: int = 0, shuffle: bool = True, **kwargs):
        """Set up a ReplayDataLoader with memory and current data.

        Does nothing (leaving the strategy's default data loader in place)
        when the buffer is empty or ``mem_batch_size`` is 0.

        Args:
            strategy: Training strategy; its ``dataloader`` attribute is
                overwritten.
            num_workers: Forwarded to ``ReplayDataLoader``.
            shuffle: Forwarded to ``ReplayDataLoader``.
            **kwargs: May contain ``memory_transform``; when it is truthy and
                the buffer offers ``with_transforms``, the buffer is replaced
                by the transformed version.
        """
        if len(self.storage_policy.buffer) == 0 or self.mem_batch_size == 0:
            # No memory yet — use default data loader
            return

        transform = kwargs.get("memory_transform", None)
        if transform and hasattr(self.storage_policy.buffer, 'with_transforms'):
            # NOTE(review): this reassigns the policy's buffer, so the
            # transform stays attached for later experiences — confirm that
            # this is intended.
            self.storage_policy.buffer = self.storage_policy.buffer.with_transforms(transform)

        print("ReplayDataLoader with memory batch size =", self.mem_batch_size)
        strategy.dataloader = ReplayDataLoader(
            data=strategy.adapted_dataset,                # The current experience data
            memory=self.storage_policy.buffer,            # The rehearsal memory
            batch_size=strategy.train_mb_size,            # Batch size for current data
            batch_size_mem=self.mem_batch_size,           # Batch size for memory data
            oversample_small_tasks=True,
            shuffle=shuffle,
            num_workers=num_workers
        )

    def after_training_exp(self, strategy, **kwargs):
        """Update memory buffer after training on experience and save model checkpoint.

        Args:
            strategy: Training strategy whose model is checkpointed.
            **kwargs: May contain ``val_exp``; it is forwarded to the storage
                policy as ``experience`` (``None`` when absent).
        """
        print("Updating memory buffer and saving model checkpoint...")
        # Use validation experience if passed
        val_exp = kwargs.get("val_exp", None)
        self.storage_policy.update(strategy, experience=val_exp)

        # Build the directory once so creation and saving cannot drift apart.
        checkpoint_dir = f"checkpoints/{self.model_name}/{self.benchmark_name}/seed{self.seed}"
        # exist_ok avoids the check-then-create race when several runs share
        # the same checkpoint tree.
        os.makedirs(checkpoint_dir, exist_ok=True)

        task_id = strategy.experience.current_experience
        # Assumes the model exposes ``output.classifier`` with an
        # ``out_features`` attribute — TODO confirm for other architectures.
        num_classes = strategy.model.output.classifier.out_features

        torch.save(
            {
                'state_dict': strategy.model.state_dict(),
                'num_classes': num_classes,
            },
            f"{checkpoint_dir}/model_task{task_id}.pth",
        )
