class AvalancheToMammothDataset:
    """Wrapper to make Avalanche dataset compatible with Mammoth XDER.

    The wrapper copies the benchmark-level values Mammoth reads
    (``N_CLASSES``, ``N_CLASSES_PER_TASK``, ``N_TASKS``, ``SETTING``,
    ``SIZE``) from the wrapped object, builds one ``DataLoader`` per
    experience of its train and test streams, and forwards every other
    attribute lookup to the wrapped object.
    """

    # Batch size used for every loader built in ``__init__``.
    _LOADER_BATCH_SIZE = 64

    # Image size used when neither the wrapped dataset nor ``dataset_name``
    # provides one.
    _DEFAULT_SIZE = (32, 32)

    # Known image sizes keyed by dataset name.
    _SIZE_BY_NAME = {
        'cifar10': (32, 32),
        'cifar100': (32, 32),
        'tinyimg': (64, 64),
        'cub200': (224, 224),
    }

    # Per-channel statistics shared by the normalization and
    # denormalization transforms.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, avalanche_dataset, dataset_name=None):
        """Wrap ``avalanche_dataset`` and expose Mammoth-style attributes.

        Args:
            avalanche_dataset: Object providing ``n_classes``,
                ``n_experiences``, ``train_stream`` and ``test_stream``.
                Each stream item must expose a ``dataset`` attribute.
            dataset_name: Optional key (``'cifar10'``, ``'cifar100'``,
                ``'tinyimg'``, ``'cub200'``) used to pick ``SIZE`` when the
                wrapped object has no ``SIZE`` attribute of its own.
        """
        self.avalanche_dataset = avalanche_dataset

        # Map Avalanche attributes to Mammoth expected attributes.
        self.N_CLASSES = avalanche_dataset.n_classes
        self.N_TASKS = avalanche_dataset.n_experiences
        self.n_experiences = avalanche_dataset.n_experiences
        # Integer division: any remainder classes are not counted per task.
        self.N_CLASSES_PER_TASK = self.N_CLASSES // self.N_TASKS
        self.train_stream = avalanche_dataset.train_stream
        self.test_stream = avalanche_dataset.test_stream

        # Mammoth continual learning setting; defaults to class-incremental.
        self.SETTING = 'class-il'

        # Prefer the wrapped dataset's own SIZE; otherwise look the name up,
        # falling back to the default for unknown or missing names.
        if hasattr(avalanche_dataset, 'SIZE'):
            self.SIZE = avalanche_dataset.SIZE
        else:
            size = self._DEFAULT_SIZE
            if dataset_name:
                size = self._SIZE_BY_NAME.get(dataset_name, size)
            # A fresh list per instance, so mutating SIZE never leaks into
            # the class-level table.
            self.SIZE = list(size)

        # Other potentially needed Mammoth attributes.
        self.classes_names = getattr(avalanche_dataset, 'classes_names', None)

        # One loader per experience; each stream is traversed exactly once.
        self.test_loaders = self._build_loaders(self.test_stream)
        self.train_loaders = self._build_loaders(self.train_stream)

    @staticmethod
    def _build_loaders(stream, batch_size=_LOADER_BATCH_SIZE):
        """Return one non-shuffling ``DataLoader`` per experience in ``stream``."""
        from torch.utils.data import DataLoader

        return [
            DataLoader(experience.dataset, batch_size=batch_size, shuffle=False)
            for experience in stream
        ]

    def __getattr__(self, name):
        """Delegate any other attribute access to the original dataset.

        Only called when normal lookup fails. The wrapped object is read
        straight from ``__dict__`` so that a lookup made before
        ``avalanche_dataset`` is assigned (``copy``/``pickle`` reconstruction,
        ``__new__``-created instances) raises ``AttributeError`` instead of
        re-entering this method until ``RecursionError``.
        """
        try:
            wrapped = self.__dict__['avalanche_dataset']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def get_offsets(self, task=None):
        """Get dataset offsets for Mammoth compatibility.

        Args:
            task: Optional zero-based task index.

        Returns:
            ``(start_class, end_class)`` for ``task`` when it is given,
            otherwise ``(N_CLASSES, N_CLASSES_PER_TASK)``.
        """
        classes_per_task = self.N_CLASSES_PER_TASK
        if task is not None:
            start_class = task * classes_per_task
            return (start_class, start_class + classes_per_task)
        return (self.N_CLASSES, classes_per_task)

    def get_loss(self):
        """Get loss function for Mammoth compatibility.

        Returns:
            A new ``torch.nn.CrossEntropyLoss`` instance.
        """
        from torch.nn import CrossEntropyLoss
        return CrossEntropyLoss()

    def get_data_loaders(self):
        """Return ``(train_loaders, test_loaders)`` as built in ``__init__``."""
        return self.train_loaders, self.test_loaders

    def get_transform(self):
        """Return the wrapped dataset's ``transform`` attribute, or ``None``."""
        return getattr(self.avalanche_dataset, 'transform', None)

    def get_normalization_transform(self):
        """Get normalization transform for the dataset.

        The same per-channel mean/std is returned for every ``SIZE``.

        Returns:
            A ``torchvision.transforms.Normalize`` instance.
        """
        from torchvision import transforms

        return transforms.Normalize(self._NORM_MEAN, self._NORM_STD)

    def get_denormalization_transform(self):
        """Get denormalization transform for the dataset.

        Returns:
            A callable that, for each leading-axis slice of its argument,
            multiplies by the matching std and adds the matching mean. The
            argument is modified in place and also returned.
        """

        class DeNormalize:
            def __init__(self, mean, std):
                self.mean = mean
                self.std = std

            def __call__(self, tensor):
                # In-place on purpose: ``mul_``/``add_`` mutate each slice.
                for t, m, s in zip(tensor, self.mean, self.std):
                    t.mul_(s).add_(m)
                return tensor

        return DeNormalize(list(self._NORM_MEAN), list(self._NORM_STD))

    def get_backbone(self):
        """Return backbone info: type ``'resnet18'`` and ``N_CLASSES`` outputs."""
        return {'type': 'resnet18', 'num_classes': self.N_CLASSES}

    def get_epochs(self):
        """Return ``n_epochs`` from the wrapper or wrapped dataset, else 100."""
        return getattr(self, 'n_epochs', 100)

    def get_batch_size(self):
        """Return ``batch_size`` from the wrapper or wrapped dataset, else 64."""
        return getattr(self, 'batch_size', 64)

    def get_device(self):
        """Return ``device`` from the wrapper or wrapped dataset, else ``'cuda'``."""
        return getattr(self, 'device', 'cuda')

    def get_scheduler(self):
        """Get learning rate scheduler info."""
        return None  # No scheduler by default

    def get_optimizer(self):
        """Return the fixed optimizer settings as a dict."""
        return {'type': 'sgd', 'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0002}

    def get_classes_per_task(self):
        """Get number of classes per task."""
        return self.N_CLASSES_PER_TASK

    def get_num_tasks(self):
        """Get number of tasks."""
        return self.N_TASKS

    def get_num_classes(self):
        """Get total number of classes."""
        return self.N_CLASSES

    def get_setting(self):
        """Get continual learning setting."""
        return self.SETTING


# Alternative: instead of using the full class above, the methods can be added
# one by one to an existing class.
#
# In your existing mers_utils/avalnche_to_mammoth.py file, add the methods
# below to your existing AvalancheToMammothDataset class:

def get_offsets(self, task=None):
    """Get dataset offsets for Mammoth compatibility.

    Accepts the same optional ``task`` argument as the
    ``AvalancheToMammothDataset.get_offsets`` method, so callers that pass a
    task index get a class range instead of a ``TypeError``.

    Args:
        task: Optional zero-based task index.

    Returns:
        ``(start_class, end_class)`` for ``task`` when it is given,
        otherwise ``(N_CLASSES, N_CLASSES_PER_TASK)``.
    """
    classes_per_task = self.N_CLASSES_PER_TASK
    if task is not None:
        start_class = task * classes_per_task
        return (start_class, start_class + classes_per_task)
    return (self.N_CLASSES, classes_per_task)


def get_loss(self):
    """Build the loss used for Mammoth compatibility.

    Returns:
        A newly constructed ``torch.nn.CrossEntropyLoss``; ``self`` is not
        consulted.
    """
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss()
    return criterion
