import torch
from torchvision import transforms, datasets
from torch.utils.data import Subset, Dataset, TensorDataset
from typing import Sequence, Tuple, Dict, Set, Optional
import numpy as np
from .corruption import load_corruptions_cifar, CORRUPTIONS, BenchmarkDataset
    
def load_cifar100c(
    data_dir: str = './data/cifar100c',
    corruptions: Sequence[str] = CORRUPTIONS,
    severity: int = 5,
    n_examples: int = 10000,
    shuffle: bool = False,
) -> Dataset:
    """Load CIFAR-100-C via ``load_corruptions_cifar``.

    All arguments are forwarded unchanged to ``load_corruptions_cifar`` with
    ``BenchmarkDataset.cifar_100``. If the first attempt raises, the error is
    printed and the load is retried once with ``'./data/cifar100c'`` as the
    root; an exception from that retry propagates to the caller.

    Args:
        data_dir: Root directory tried first.
        corruptions: Corruption names to load.
        severity: Corruption severity level.
        n_examples: Number of examples requested from the loader.
        shuffle: Forwarded shuffle flag.

    Returns:
        Whatever ``load_corruptions_cifar`` returns (annotated as ``Dataset``).
    """
    fallback_dir = './data/cifar100c'

    def _attempt(root: str) -> Dataset:
        # Single place that spells out the positional argument order expected
        # by load_corruptions_cifar.
        return load_corruptions_cifar(
            BenchmarkDataset.cifar_100, n_examples, severity, root, corruptions, shuffle
        )

    try:
        return _attempt(data_dir)
    except Exception as err:
        print(f"Error loading CIFAR-100C from {data_dir}: {err}")
        print(f"Downloading CIFAR-100C to {fallback_dir}")
        return _attempt(fallback_dir)
    
def _permute_split_in_place(ds) -> None:
    """Reorder ``ds.data`` and ``ds.targets`` with one shared random permutation."""
    perm = np.random.permutation(len(ds))
    ds.data = ds.data[perm]
    # Index element-wise so labels stay aligned with the permuted images
    # regardless of whether ``targets`` is a list or an array.
    ds.targets = [ds.targets[i] for i in perm]


def load_cifar100(
    data_dir: str = './data/cifar100',
    examples_num: Optional[int] = None,
    shuffle: bool = False,
) -> Tuple[Dataset, Dataset]:
    """Load the CIFAR-100 train and test splits, converted with ``ToTensor``.

    Args:
        data_dir: Root passed to ``torchvision.datasets.CIFAR100`` (with
            ``download=True``). If loading from it raises, the error is
            printed and loading is retried with ``'./data/cifar100'``.
        examples_num: If given, each split is truncated to its first
            ``examples_num`` samples (taken after shuffling, if enabled).
            Values larger than a split's size keep the whole split.
        shuffle: If True, each split is randomly permuted using NumPy's
            global RNG; images and labels are permuted together.

    Returns:
        ``(train, test)``: ``CIFAR100`` datasets, or ``Subset`` views of them
        when ``examples_num`` is given.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    try:
        d_tr = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=transform)
        d_ts = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform)
    except Exception as e:
        print(f"Error loading CIFAR-100 from {data_dir}: {e}")
        data_dir = './data/cifar100'
        print(f"Downloading CIFAR-100 to {data_dir}")
        d_tr = datasets.CIFAR100(root=data_dir, train=True, download=True, transform=transform)
        d_ts = datasets.CIFAR100(root=data_dir, train=False, download=True, transform=transform)

    if shuffle:
        # Previously only ``.data`` was permuted, which detached every image
        # from its label; permute both with the same index array.
        _permute_split_in_place(d_tr)
        _permute_split_in_place(d_ts)

    if examples_num is not None:
        # Clip so the Subset never holds indices past the end of the split.
        d_tr = Subset(d_tr, range(min(examples_num, len(d_tr))))
        d_ts = Subset(d_ts, range(min(examples_num, len(d_ts))))

    return d_tr, d_ts

# Number of clean CIFAR-100 test examples used when a domain has severity 0
# (see get_domain).
d_num = 10_000

# Domain-sequence table: key = number of domains, value = the severity of each
# domain in order. Severity 0 means clean CIFAR-100; any other value is loaded
# from CIFAR-100-C at that severity. Every sequence starts clean and ends at 5.
n2idx = {
    2: [0, 5],
    3: [0, 3, 5],
    4: [0, 2, 4, 5],
    5: [0, 2, 3, 4, 5],
    6: [0, 1, 2, 3, 4, 5],
}

def get_domain(data_dir: str, data_c_dir: str, domains_num: int, corruption: str, idx: int = None) -> Dataset:
    """Return domain ``idx`` of a ``domains_num``-step CIFAR-100 -> CIFAR-100-C sequence.

    The severity of the domain is ``n2idx[domains_num][idx]``. Severity 0
    yields the clean CIFAR-100 test split truncated to ``d_num`` examples;
    any other severity yields CIFAR-100-C at that severity for the single
    ``corruption``.

    Args:
        data_dir: Root of clean CIFAR-100 (used for severity 0).
        data_c_dir: Root of CIFAR-100-C (used for severity > 0).
        domains_num: Length of the domain sequence; must be a key of ``n2idx``.
        corruption: Corruption name forwarded to ``load_cifar100c``.
        idx: Position in ``n2idx[domains_num]``. Required despite the
            ``None`` default (kept for backward compatibility).

    Raises:
        TypeError: If ``idx`` is None.
        KeyError: If ``domains_num`` is not a key of ``n2idx``.
        IndexError: If ``idx`` is out of range for that sequence.
        RuntimeError: If loading the dataset fails; the original exception
            is chained as ``__cause__``.
    """
    if idx is None:
        raise TypeError("get_domain() requires an integer 'idx' (position of the domain in the sequence)")
    severity = n2idx[domains_num][idx]
    try:
        if severity == 0:
            # Only the test split is used; the train split is discarded.
            _, d_ts = load_cifar100(data_dir, examples_num=d_num)
            return d_ts
        return load_cifar100c(data_c_dir, severity=severity, corruptions=[corruption])
    except Exception as e:
        # Report the dataset and directory that were actually being loaded.
        name, source = ("CIFAR-100", data_dir) if severity == 0 else ("CIFAR-100C", data_c_dir)
        raise RuntimeError(
            f"Error loading {name} from '{source}' with domains_num {domains_num}, "
            f"idx {idx}, severity {severity}, corruption '{corruption}': {e}"
        ) from e

def get_domains(data_dir: str, data_c_dir: str, domains_num: int, corruption: str) -> list[Dataset]:
    """Return every domain of a ``domains_num``-step sequence, in order.

    Element ``i`` of the result is ``get_domain(..., i)``; domains are loaded
    eagerly from index 0 upward, and the first failure propagates.
    """
    return [
        get_domain(data_dir, data_c_dir, domains_num, corruption, position)
        for position in range(domains_num)
    ]


# ------------ test-code ------------

if __name__ == "__main__":
    import matplotlib.pyplot as plt

    data_dir = "/home/me/share/CIFAR100/"
    data_c_dir = "./data/cifar100c/"

    # Smoke test: the clean splits load and print their summaries.
    train_set, test_set = load_cifar100(data_dir)
    print(train_set)
    print(test_set)

    # Image grid: one row per corruption, one column per domain of the
    # 6-domain sequence; each cell shows the first sample of that domain.
    n_cols = 6
    n_rows = len(CORRUPTIONS)
    plt.figure(figsize=(64, 120))
    for row, corruption in enumerate(CORRUPTIONS):
        for col in range(n_cols):
            label = f"{corruption} {col}"
            print(label)
            domain = get_domain(data_dir, data_c_dir, n_cols, corruption, col)
            plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
            # (C, H, W) tensor -> (H, W, C) for imshow.
            plt.imshow(domain[0][0].permute(1, 2, 0))
            plt.title(label)
            plt.axis("off")
    plt.tight_layout()
    plt.savefig("cifar100c.png")
    plt.show()
    
    
    
    