import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import check_integrity


class FastCIFAR100(VisionDataset):
    """CIFAR-100 dataset that serves samples from a pre-converted float tensor.

    The selected split is unpickled once in ``__init__``. Two views of the
    images are kept on the instance:

    * ``data``: NumPy array reshaped to ``(-1, 3, 32, 32)`` and then transposed
      to channel-last order.
    * ``tensor_data``: the same images as a contiguous channel-first float
      tensor divided by 255.0. ``__getitem__`` reads from this tensor, so
      ``transform`` receives a ``torch.Tensor``.

    ``targets`` is a ``torch.Tensor`` of labels; ``classes`` and
    ``class_to_idx`` are populated from the metadata file.
    """

    # Location of the extracted archive relative to ``root``.
    base_folder = "cifar-100-python"
    # Archive source and its checksum, used by ``download()``.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    # [file name, md5] pairs for each split.
    train_list = [
        ["train", "16019d7e3df5f24257cddd939b257f8d"],
    ]
    test_list = [
        ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
    ]
    # Metadata file, the key holding the class names, and the file's md5.
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        cache_size: int = 50000  # full dataset size
    ) -> None:
        """Load one split of CIFAR-100 into memory.

        Args:
            root: Directory that contains (or will receive) ``base_folder``.
            train: Read ``train_list`` when true, ``test_list`` otherwise.
            transform: Optional callable applied to each image tensor.
            target_transform: Optional callable applied to each label.
            download: Call ``download()`` before the integrity check.
            cache_size: Stored on the instance as ``cache_size``; no method of
                this class reads it.

        Raises:
            RuntimeError: If the data files or the metadata file fail their
                integrity check.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        self.train = train
        # ``cache`` is initialised empty and not used by any method here.
        self.cache = {}
        self.cache_size = cache_size

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        split_files = self.train_list if self.train else self.test_list

        raw_chunks = []
        labels = []
        for file_name, _md5 in split_files:
            batch_path = os.path.join(self.root, self.base_folder, file_name)
            # NOTE: pickle.load can execute arbitrary code from the file; the
            # files were MD5-verified by _check_integrity() just above.
            with open(batch_path, "rb") as handle:
                batch = pickle.load(handle, encoding="latin1")
            raw_chunks.append(batch["data"])
            label_key = "labels" if "labels" in batch else "fine_labels"
            labels.extend(batch[label_key])

        # Rows -> (N, 3, 32, 32), then channel-last for the NumPy copy.
        channel_first = np.vstack(raw_chunks).reshape(-1, 3, 32, 32)
        self.data = channel_first.transpose((0, 2, 3, 1))

        # Channel-first float tensor scaled by 1/255, used by __getitem__.
        as_tensor = torch.from_numpy(self.data).permute(0, 3, 1, 2).contiguous()
        self.tensor_data = as_tensor.float() / 255.0

        self.targets = torch.tensor(labels)
        self._load_meta()

    def _load_meta(self) -> None:
        """Read class names from the metadata file and build ``class_to_idx``.

        Raises:
            RuntimeError: If the metadata file fails its MD5 check.
        """
        meta_path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        if not check_integrity(meta_path, self.meta["md5"]):
            raise RuntimeError("Dataset metadata file not found or corrupted.")
        with open(meta_path, "rb") as handle:
            meta_blob = pickle.load(handle, encoding="latin1")
        self.classes = meta_blob[self.meta["key"]]
        self.class_to_idx = {label: position for position, label in enumerate(self.classes)}

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, target)`` for ``index`` after applying the transforms."""
        img = self.tensor_data[index]
        target = self.targets[index].item()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of samples in the loaded split."""
        return len(self.data)

    def _check_integrity(self) -> bool:
        """Return True only if every train and test file passes its MD5 check."""
        return all(
            check_integrity(os.path.join(self.root, self.base_folder, name), md5)
            for name, md5 in self.train_list + self.test_list
        )

    def download(self) -> None:
        """Download and extract the archive unless the files already verify."""
        from torchvision.datasets.utils import download_and_extract_archive

        if self._check_integrity():
            print("Files already downloaded and verified")
        else:
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self) -> str:
        """Return the split description used in the dataset's repr."""
        split_name = "Train" if self.train else "Test"
        return f"Split: {split_name}"


def get_train_valid_loader(batch_size, augment=True, random_seed=42, valid_size=0.1,
                          shuffle=True, num_workers=4, pin_memory=True, get_val_temp=0, input_size=None):
    """Build training and validation loaders over the CIFAR-100 training split.

    The training split is partitioned by index: the first
    ``floor(valid_size * N)`` indices (after the optional shuffle) form the
    validation set and the remainder the training set. Both loaders draw
    from their index subset through ``SubsetRandomSampler``.

    Args:
        batch_size: Batch size passed to every loader.
        augment: When true, the training transform starts with
            ``RandomCrop(32, padding=4)`` and ``RandomHorizontalFlip()``.
        random_seed: Seed used for the index shuffle.
        valid_size: Fraction of the training split used for validation;
            must be within [0, 1].
        shuffle: When true, seeds NumPy's *global* RNG with ``random_seed``
            and shuffles the indices before splitting.
        num_workers: Worker count for every loader.
        pin_memory: Passed through to every loader.
        get_val_temp: When > 0, the first ``floor(get_val_temp * split)``
            validation indices are moved into a separate loader and removed
            from the validation loader.
        input_size: When truthy and not 32, a ``Resize((input_size,
            input_size))`` step is inserted before normalization.

    Returns:
        ``(train_loader, valid_loader, valid_temp_loader)`` when
        ``get_val_temp > 0``; otherwise
        ``(train_loader, valid_loader, train_idx)`` where ``train_idx`` is the
        list of training indices.
    """
    data_dir = './dataset/Cifar/'
    assert 0 <= valid_size <= 1, "valid_size should be in the range [0, 1]."

    # NOTE(review): these constants match the statistics commonly quoted for
    # CIFAR-10; confirm they are the intended values for CIFAR-100.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    def _compose(leading_ops):
        """Return ``leading_ops`` + optional resize + normalize as a Compose."""
        ops = list(leading_ops)
        if input_size and input_size != 32:
            ops.append(transforms.Resize((input_size, input_size)))
        ops.append(normalize)
        return transforms.Compose(ops)

    if augment:
        train_transform = _compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ])
    else:
        train_transform = _compose([])
    valid_transform = _compose([])

    train_dataset = FastCIFAR100(
        root=data_dir,
        train=True,
        download=True,
        transform=train_transform
    )

    # Same underlying split as train_dataset but without augmentation.
    valid_dataset = FastCIFAR100(
        root=data_dir,
        train=True,
        download=False,
        transform=valid_transform
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # Seeds the process-wide NumPy RNG; later np.random calls are affected.
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    valid_temp_loader = None
    if get_val_temp > 0:
        split_temp = int(np.floor(get_val_temp * split))
        valid_idx, valid_temp_idx = valid_idx[split_temp:], valid_idx[:split_temp]
        valid_temp_sampler = SubsetRandomSampler(valid_temp_idx)
        # Reuse valid_dataset: the temp subset needs the same split and the
        # same transform, so building another dataset would only re-read the
        # files and hold a further full in-memory copy of the images.
        valid_temp_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=batch_size,
            sampler=valid_temp_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0
    )

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0
    )

    if get_val_temp > 0:
        return (train_loader, valid_loader, valid_temp_loader)
    return (train_loader, valid_loader, train_idx)


def get_test_loader(batch_size, shuffle=True, num_workers=4, pin_memory=True, input_size=None):
    """Build a DataLoader over the CIFAR-100 test split.

    Args:
        batch_size: Batch size for the loader.
        shuffle: Passed through to the DataLoader.
        num_workers: Worker count; workers are kept persistent when > 0.
        pin_memory: Passed through to the DataLoader.
        input_size: When truthy and not 32, images are resized to
            ``(input_size, input_size)`` before normalization.

    Returns:
        A ``torch.utils.data.DataLoader`` over the test split.
    """
    data_dir = './dataset/Cifar/'

    # Optional resize first, normalization always last.
    steps = (
        [transforms.Resize((input_size, input_size))]
        if input_size and input_size != 32
        else []
    )
    # NOTE(review): these constants match the statistics commonly quoted for
    # CIFAR-10; confirm they are the intended values for CIFAR-100.
    steps.append(
        transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )
    )

    test_set = FastCIFAR100(
        root=data_dir,
        train=False,
        download=True,
        transform=transforms.Compose(steps)
    )

    return torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0
    )