import random
import torch 
import numpy as np
from torchvision import transforms
from avalanche.benchmarks import benchmark_with_validation_stream, split_validation_class_balanced
from avalanche.benchmarks.classic import SplitCIFAR10, SplitCIFAR100, SplitTinyImageNet

class _IntLabelDataset(torch.utils.data.Dataset):
    """Wrap a dataset so every label it yields is a plain Python ``int``.

    MedMNIST yields each label as a small NumPy array; Avalanche's
    class-incremental utilities work on scalar integer targets. The class is
    defined at module level (not inside ``load_benchmark``) so instances can
    be pickled, e.g. by ``DataLoader`` worker processes started with ``spawn``.
    """

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __getitem__(self, idx):
        img, label = self.base_dataset[idx]
        # ``.item()`` extracts the single value without NumPy's deprecation
        # warning for ``int()`` on arrays with ndim > 0, and fails loudly if a
        # label unexpectedly holds more than one value.
        return img, int(np.asarray(label).item())

    def __len__(self):
        return len(self.base_dataset)


def _build_transforms(mean, std, augmentation, auto_policy, standard_augment):
    """Build ``(train_transform, eval_transform)`` for PIL-image datasets.

    Args:
        mean (list): Per-channel normalisation mean.
        std (list): Per-channel normalisation standard deviation.
        augmentation (str): 'auto' to use AutoAugment with ``auto_policy``,
            or 'standard' to use the transforms in ``standard_augment``.
        auto_policy: ``transforms.AutoAugmentPolicy`` used for 'auto'.
        standard_augment (list): Transforms applied before ``ToTensor`` for
            'standard'.

    Returns:
        tuple: (train_transform, eval_transform). Both end with
        ``ToTensor`` + ``Normalize``; only the train transform augments.

    Raises:
        ValueError: If ``augmentation`` is neither 'auto' nor 'standard'.
    """
    if augmentation == 'auto':
        augment = [transforms.AutoAugment(policy=auto_policy)]
    elif augmentation == 'standard':
        augment = list(standard_augment)
    else:
        raise ValueError(f"Unknown augmentation type: {augmentation}. Use 'auto' or 'standard'.")

    train_transform = transforms.Compose(
        augment + [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, eval_transform


def _load_medmnist(benchmark_name):
    """Build a class-incremental benchmark for 'bloodmnist' or 'dermamnist'.

    The official train and val splits are pooled into the training stream;
    the caller later re-draws a class-balanced validation split of the same
    overall proportion from each experience. Classes are ordered from most to
    least frequent in the pooled data.

    Args:
        benchmark_name (str): 'bloodmnist' or 'dermamnist'.

    Returns:
        tuple: (benchmark, n_classes, train_transform, eval_transform,
        validation_size).
    """
    import os
    from collections import Counter

    import medmnist
    from avalanche.benchmarks import nc_benchmark
    from avalanche.benchmarks.utils import AvalancheDataset, DataAttribute
    from torch.utils.data import ConcatDataset

    root = './raw/'
    # NOTE(review): recent medmnist versions raise if ``root`` does not exist
    # instead of creating it — creating it here is harmless either way.
    os.makedirs(root, exist_ok=True)

    if benchmark_name == 'bloodmnist':
        mean = [0.7943, 0.6597, 0.6962]
        std = [0.2156, 0.2416, 0.1179]
        dataset_cls, extra_kwargs = medmnist.BloodMNIST, {}
        num_experiences = 4
        per_exp_classes = None
    else:  # 'dermamnist'
        mean = [0.7631, 0.5381, 0.5614]
        std = [0.1366, 0.1543, 0.1692]
        dataset_cls, extra_kwargs = medmnist.DermaMNIST, {'size': 64}
        num_experiences = 3
        per_exp_classes = {num_experiences - 1: 3}  # last experience gets 3 classes

    def load_split(split):
        # Images are converted to tensors here, so the benchmark-level
        # transforms below operate on tensors rather than PIL images.
        base = dataset_cls(split=split, download=True, root=root,
                           transform=transforms.ToTensor(), **extra_kwargs)
        return _IntLabelDataset(base)

    train_ds = load_split('train')
    test_ds = load_split('test')
    val_ds = load_split('val')
    combined_ds = ConcatDataset([train_ds, val_ds])

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean, std),
    ])
    eval_transform = transforms.Compose([
        transforms.Normalize(mean, std),
    ])

    def as_avalanche(ds):
        # Avalanche needs an explicit "targets" attribute to split by class.
        targets = [int(ds[i][1]) for i in range(len(ds))]
        return AvalancheDataset(ds, data_attributes=[DataAttribute(targets, "targets")])

    combined_data = as_avalanche(combined_ds)
    test_data = as_avalanche(test_ds)

    class_order = [label for label, _ in Counter(combined_data.targets).most_common()]
    n_classes = len(class_order)

    benchmark = nc_benchmark(
        train_dataset=combined_data,
        test_dataset=test_data,
        task_labels=False,
        n_experiences=num_experiences,
        fixed_class_order=class_order,
        per_exp_classes=per_exp_classes,
        train_transform=train_transform,
        eval_transform=eval_transform,
    )

    # Keep the original val/(train+val) proportion for the re-drawn split.
    validation_size = len(val_ds) / (len(train_ds) + len(val_ds))
    return benchmark, n_classes, train_transform, eval_transform, validation_size


def load_benchmark(benchmark_name: str, augmentation: str = 'standard', seed: int = 93):
    """
    Load a class-incremental benchmark by name and attach a validation stream.

    Args:
        benchmark_name (str): The benchmark to load. Possible values are
            'cifar10', 'cifar100', 'tinyimagenet', 'bloodmnist' and 'dermamnist'.
        augmentation (str): Training augmentation, 'standard' (default) or
            'auto' (AutoAugment). Ignored for 'bloodmnist' and 'dermamnist',
            which always use a horizontal flip.
        seed (int): Random seed for the class order of the CIFAR/TinyImageNet
            splits and for the validation split. Default is 93. Note that this
            also seeds the global ``torch``, ``numpy`` and ``random`` RNGs.

    Returns:
        tuple: (benchmark with validation stream, number of classes,
        train transform, eval transform).

    Raises:
        ValueError: If ``benchmark_name`` or ``augmentation`` is unknown.
    """
    if benchmark_name == 'cifar10':
        # CIFAR-10 training-set channel statistics (the CIFAR-100 values used
        # in the branch below do not apply to this dataset).
        mean = [0.4914, 0.4822, 0.4465]
        std = [0.2470, 0.2435, 0.2616]
        n_classes = 10
        train_transform, eval_transform = _build_transforms(
            mean, std, augmentation,
            auto_policy=transforms.AutoAugmentPolicy.CIFAR10,
            standard_augment=[
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ],
        )
        validation_size = 0.2
        benchmark = SplitCIFAR10(
            n_experiences=5,
            seed=seed,
            train_transform=train_transform,
            eval_transform=eval_transform,
        )

    elif benchmark_name == 'cifar100':
        mean = [0.5071, 0.4865, 0.4409]
        std = [0.2673, 0.2564, 0.2762]
        n_classes = 100
        train_transform, eval_transform = _build_transforms(
            mean, std, augmentation,
            # torchvision ships no CIFAR-100 policy; the CIFAR-10 one is reused.
            auto_policy=transforms.AutoAugmentPolicy.CIFAR10,
            standard_augment=[
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ],
        )
        validation_size = 0.2
        benchmark = SplitCIFAR100(
            n_experiences=10,
            seed=seed,
            train_transform=train_transform,
            eval_transform=eval_transform,
        )

    elif benchmark_name == 'tinyimagenet':
        mean = [0.4802, 0.4481, 0.3975]
        std = [0.2302, 0.2265, 0.2262]
        n_classes = 200
        train_transform, eval_transform = _build_transforms(
            mean, std, augmentation,
            auto_policy=transforms.AutoAugmentPolicy.IMAGENET,
            standard_augment=[
                transforms.RandomCrop(64, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                transforms.RandomRotation(15),
            ],
        )
        validation_size = 0.2
        benchmark = SplitTinyImageNet(
            n_experiences=10,
            seed=seed,
            train_transform=train_transform,
            eval_transform=eval_transform,
        )

    elif benchmark_name in ['bloodmnist', 'dermamnist']:
        benchmark, n_classes, train_transform, eval_transform, validation_size = \
            _load_medmnist(benchmark_name)

    else:
        raise ValueError(f"Unknown benchmark name: {benchmark_name}. Use 'cifar10', 'cifar100', 'tinyimagenet', 'bloodmnist' or 'dermamnist'.")

    # Seed the global RNGs right before drawing the validation split so the
    # split is reproducible across runs.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    def split_strategy(experience_dataset):
        # Hold out ``validation_size`` of each experience, balanced per class.
        return split_validation_class_balanced(validation_size, experience_dataset)

    benchmark = benchmark_with_validation_stream(benchmark, split_strategy=split_strategy)

    return benchmark, n_classes, train_transform, eval_transform