import os
from pathlib import Path
from typing import Callable, Dict, Optional, Sequence, Set, Tuple
import pdb 

import numpy as np
import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset  
from enum import Enum
from pathlib import Path
from torchvision.datasets import CIFAR10
from zenodo_download import DownloadError, zenodo_download

class BenchmarkDataset(Enum):
    """Datasets known to the loaders in this module.

    Members serve as keys of the module-level lookup tables
    ``ZENODO_CORRUPTIONS_LINKS`` and ``CORRUPTIONS_DIR_NAMES``.
    """
    cifar_10 = 'cifar10'
    cifar_100 = 'cifar100'
    imagenet = 'imagenet'

# Corruption names used as the default `corruptions` argument of the loaders
# below. Each name is also the stem of the `<name>.npy` file that
# `load_corruptions_cifar` reads from the corruption data directory.
CORRUPTIONS = ("shot_noise", "motion_blur", "snow", "pixelate",
               "gaussian_noise", "defocus_blur", "brightness", "fog",
               "zoom_blur", "frost", "glass_blur", "impulse_noise", "contrast",
               "jpeg_compression", "elastic_transform")

# Per-dataset positional arguments that `load_corruptions_cifar` unpacks into
# `zenodo_download` when the corruption directory is missing: a string
# identifier followed by a set of archive file names.
# NOTE(review): the identifier is presumably the Zenodo record ID — confirm
# against `zenodo_download`. There is no entry for `BenchmarkDataset.imagenet`.
ZENODO_CORRUPTIONS_LINKS: Dict[BenchmarkDataset, Tuple[str, Set[str]]] = {
    BenchmarkDataset.cifar_10: ("2535967", {"CIFAR-10-C.tar"}),
    BenchmarkDataset.cifar_100: ("3555552", {"CIFAR-100-C.tar"})
}

# Name of the sub-directory of `data_dir` in which each dataset's corruption
# files (`labels.npy` and one `.npy` file per corruption) are looked up.
CORRUPTIONS_DIR_NAMES: Dict[BenchmarkDataset, str] = {
    BenchmarkDataset.cifar_10: "CIFAR-10-C",
    BenchmarkDataset.cifar_100: "CIFAR-100-C",
    BenchmarkDataset.imagenet: "ImageNet-C"
}



# Named torchvision preprocessing pipelines; every pipeline ends with
# `ToTensor()`.
# NOTE(review): nothing in the visible part of this file looks these up —
# presumably the keys are meant to be selected through a `prepr` argument;
# confirm against the callers.
PREPROCESSINGS = {
    # Resize to 256, then center-crop to 224.
    'Res256Crop224': transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor()]),
    # Center-crop to 288 without resizing first.
    'Crop288': transforms.Compose([transforms.CenterCrop(288),
                                   transforms.ToTensor()]),
    # Tensor conversion only.
    'none': transforms.Compose([transforms.ToTensor()]),
}

def load_cifar10c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = './data',
    corruptions: Sequence[str] = CORRUPTIONS,
    batch_order: str = 'uniform',
    prepr: Optional[str] = 'none',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load corrupted CIFAR-10 samples.

    Convenience wrapper that calls ``load_corruptions_cifar`` with
    ``BenchmarkDataset.cifar_10`` and returns its result unchanged.

    Args:
        n_examples: Number of samples requested.
        severity: Corruption severity level.
        data_dir: Root directory of the corruption data.
        corruptions: Names of the corruptions to load.
        batch_order: Sample ordering mode forwarded to the loader.
        prepr: Accepted for signature compatibility; it is not forwarded
            and has no effect here.

    Returns:
        The ``(images, labels)`` tensor pair produced by
        ``load_corruptions_cifar``.
    """
    return load_corruptions_cifar(
        dataset=BenchmarkDataset.cifar_10,
        n_examples=n_examples,
        severity=severity,
        data_dir=data_dir,
        corruptions=corruptions,
        batch_order=batch_order,
    )


def load_cifar100c(
    n_examples: int,
    severity: int = 5,
    data_dir: str = './data',
    corruptions: Sequence[str] = CORRUPTIONS,
    batch_order: str = 'uniform',
    prepr: Optional[str] = 'none'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load corrupted CIFAR-100 samples.

    Convenience wrapper that calls ``load_corruptions_cifar`` with
    ``BenchmarkDataset.cifar_100`` and returns its result unchanged.

    Args:
        n_examples: Number of samples requested.
        severity: Corruption severity level.
        data_dir: Root directory of the corruption data.
        corruptions: Names of the corruptions to load.
        batch_order: Sample ordering mode forwarded to the loader.
        prepr: Accepted for signature compatibility; it is not forwarded
            and has no effect here.

    Returns:
        The ``(images, labels)`` tensor pair produced by
        ``load_corruptions_cifar``.
    """
    return load_corruptions_cifar(
        dataset=BenchmarkDataset.cifar_100,
        n_examples=n_examples,
        severity=severity,
        data_dir=data_dir,
        corruptions=corruptions,
        batch_order=batch_order,
    )


                                  
def load_corruptions_cifar(
        dataset: "BenchmarkDataset",
        n_examples: int,
        severity: int,
        data_dir: str,
        corruptions: Sequence[str] = CORRUPTIONS,
        batch_order: str = 'uniform'
        ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a mix of corrupted CIFAR samples at one severity level.

    An equal share of ``ceil(n_examples / len(corruptions))`` samples is taken
    from the start of every requested corruption's severity slice. The shares
    are concatenated, optionally reordered, divided by 255, truncated to
    ``n_examples`` and normalized per channel.

    Args:
        dataset: Key into ``CORRUPTIONS_DIR_NAMES`` (and, when the data
            directory is missing, ``ZENODO_CORRUPTIONS_LINKS``).
        n_examples: Total number of samples requested. At most
            ``10000 * len(corruptions)`` samples can be returned.
        severity: Corruption severity, from 1 to 5.
        data_dir: Directory that holds (or will receive) the corruption data;
            it is created when missing.
        corruptions: Corruption names; each needs a ``<name>.npy`` file.
        batch_order: ``'uniform'`` shuffles the samples with NumPy's global
            RNG, ``'by_class'`` groups them with ``sort_by_class``; any other
            value keeps the corruption-by-corruption order.

    Returns:
        ``(x, y)``: the float32 images laid out as (N, 3, H, W) and their
        labels.

    Raises:
        AssertionError: If ``severity`` is outside 1..5.
        ValueError: If ``corruptions`` is empty.
        DownloadError: If the labels file or a corruption file is missing.
    """
    assert 1 <= severity <= 5
    n_pert = len(corruptions)
    if n_pert == 0:
        # Fail early with a clear message instead of the opaque
        # np.concatenate error this used to end in.
        raise ValueError("corruptions must contain at least one name")
    n_total_cifar = 10000

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_dir, exist_ok=True)

    data_dir = Path(data_dir)
    data_root_dir = data_dir / CORRUPTIONS_DIR_NAMES[dataset]
    print(data_root_dir)
    if not data_root_dir.exists():
        zenodo_download(*ZENODO_CORRUPTIONS_LINKS[dataset], save_dir=data_dir)

    labels_path = data_root_dir / 'labels.npy'
    if not labels_path.is_file():
        raise DownloadError("Labels are missing, try to re-download them.")
    labels = np.load(labels_path)

    # Per-corruption share. It is capped at the size of one severity slice:
    # the image slice below cannot yield more than that, and an uncapped
    # label slice would leave images and labels with different lengths.
    n_img = min(int(np.ceil(n_examples / n_pert)), n_total_cifar)
    severity_start = (severity - 1) * n_total_cifar
    severity_stop = severity * n_total_cifar

    x_test_list, y_test_list = [], []
    for corruption in corruptions:
        corruption_file_path = data_root_dir / (corruption + '.npy')
        if not corruption_file_path.is_file():
            raise DownloadError(
                f"{corruption} file is missing, try to re-download it.")

        images_all = np.load(corruption_file_path)
        images = images_all[severity_start:severity_stop]
        x_test_list.append(images[:n_img])
        # The same leading labels are reused for every corruption.
        y_test_list.append(labels[:n_img])

    x_test, y_test = np.concatenate(x_test_list), np.concatenate(y_test_list)

    if batch_order == 'by_class':
        y_test, x_test = sort_by_class(y_test, x_test)
    elif batch_order == 'uniform':
        rand_idx = np.random.permutation(np.arange(len(x_test)))
        x_test, y_test = x_test[rand_idx], y_test[rand_idx]

    # Make it in the PyTorch format: (N, H, W, C) -> (N, C, H, W)
    x_test = np.transpose(x_test, (0, 3, 1, 2))
    # Make it compatible with our models
    x_test = x_test.astype(np.float32) / 255
    # Make sure that we get exactly n_examples but not a few samples more
    x_test = torch.tensor(x_test)[:n_examples]
    y_test = torch.tensor(y_test)[:n_examples]

    # Per-channel normalization.
    # NOTE(review): these constants match the widely used ImageNet
    # statistics — confirm they are what the evaluated models expect.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    x_test = (x_test - mean) / std

    return x_test, y_test

def sort_by_class(labels, imgs):
    """Reorder samples so that every class forms one contiguous run.

    The classes appear in a random order drawn from NumPy's global RNG.
    Within a class the samples keep their original relative order.

    Args:
        labels: 1-D array (or CPU tensor) of class labels.
        imgs: Array or tensor aligned with ``labels`` along its first axis.

    Returns:
        ``(labels, imgs)`` — labels first — both reordered with the same
        index array.
    """
    # np.unique yields the sorted distinct classes and, for every sample,
    # the index of its class within that sorted array.
    unique_classes, class_index = np.unique(np.asarray(labels),
                                            return_inverse=True)
    random_order = np.random.permutation(unique_classes)
    # unique_classes is sorted and duplicate-free, so the argsort of its
    # permutation gives each class's position in the random order.
    class_rank = np.argsort(random_order)
    # A stable sort keeps the within-class order deterministic; the default
    # sort kind makes no such guarantee.
    sorted_indices = np.argsort(class_rank[class_index.ravel()], kind='stable')
    imgs = imgs[sorted_indices]
    labels = labels[sorted_indices]
    return labels, imgs

def load_cifar10_like_c(data_dir: str = "./data", n_examples: Optional[int] = None):
    """Load the clean CIFAR-10 test split in the format of the corruption loader.

    The images are converted to float32, divided by 255, laid out as
    (N, 3, H, W) and normalized with the same per-channel constants as
    ``load_corruptions_cifar``. Samples keep their dataset order.

    Args:
        data_dir: Root directory passed to torchvision's ``CIFAR10``; the
            dataset is downloaded there when missing.
        n_examples: Number of leading samples to return; ``None`` returns the
            whole split and ``0`` returns empty tensors.

    Returns:
        ``(x, y)``: the normalized float32 images and the integer labels.

    Raises:
        ValueError: If ``n_examples`` is negative.
    """
    if n_examples is not None and n_examples < 0:
        raise ValueError(f"n_examples must be non-negative, got {n_examples}")

    dataset = CIFAR10(root=data_dir, train=False, download=True)

    # torchvision's CIFAR10 keeps the split as one uint8 (N, H, W, C) array in
    # `data` and the labels as a list in `targets`; slicing them avoids a PIL
    # round-trip per image. A `None` stop takes everything.
    images = dataset.data[:n_examples]
    targets = dataset.targets[:n_examples]

    # NHWC -> NCHW; contiguous() copies, so the result does not alias the
    # dataset's own array.
    x_tensor = torch.from_numpy(images).permute(0, 3, 1, 2).contiguous()
    x_tensor = x_tensor.float() / 255.0
    # An explicit dtype keeps the labels integral even when the list is empty.
    y_tensor = torch.tensor(targets, dtype=torch.long)

    # Per-channel normalization.
    # NOTE(review): these constants match the widely used ImageNet
    # statistics — confirm they are what the evaluated models expect.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    x_tensor = (x_tensor - mean) / std

    return x_tensor, y_tensor