import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
import os.path
from torch.utils.data import DataLoader
import pickle
from typing import Any, Callable, Optional, Tuple
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import check_integrity


class CIFAR10_C(VisionDataset):
    """Single corruption type / severity slice of a CIFAR-10-C style dataset.

    Reads ``<root>/<corruption_type>.npy`` (images) and ``<root>/labels.npy``
    (labels), splits both into five equally sized consecutive blocks and keeps
    the block selected by ``severity``.

    NOTE(review): the slicing assumes the files stack the five severity levels
    in ascending order along axis 0, and the ``/ 255`` scaling assumes 8-bit
    pixel values -- confirm against the data actually used.

    Args:
        root: Directory containing the ``.npy`` files.
        corruption_type: File stem of the corruption array, e.g.
            ``"gaussian_noise"``.
        severity: Which of the five blocks to keep, from 1 to 5 inclusive.
        transform: Optional callable applied to each image tensor.
        target_transform: Optional callable applied to each label.

    Raises:
        ValueError: If ``severity`` is outside ``[1, 5]`` or the image and
            label arrays have different lengths.
        FileNotFoundError: If either ``.npy`` file is missing.
    """

    def __init__(
        self,
        root,
        corruption_type: str,
        severity: int = 1,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        # Fail fast on bad arguments before touching the file system.
        if severity < 1 or severity > 5:
            raise ValueError("Severity must be between 1 and 5")

        # Forward root/transforms to the base class: some torchvision releases
        # require ``root`` positionally, and the base class uses these values
        # for its own bookkeeping (``self.root``, ``self.transform``, ...).
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.severity = severity
        self.corruption_type = corruption_type

        corruption_file = os.path.join(self.root, f"{corruption_type}.npy")
        labels_file = os.path.join(self.root, "labels.npy")

        if not os.path.exists(corruption_file):
            raise FileNotFoundError(f"not find : {corruption_file}")
        if not os.path.exists(labels_file):
            raise FileNotFoundError(f"not find : {labels_file}")

        # Memory-map the image file so that only the requested severity block
        # is actually read, instead of materialising all five blocks.
        all_images = np.load(corruption_file, mmap_mode="r")
        all_labels = np.load(labels_file)

        if all_images.shape[0] != all_labels.shape[0]:
            raise ValueError(
                f"Image/label count mismatch: {all_images.shape[0]} images in "
                f"{corruption_file} vs {all_labels.shape[0]} labels in {labels_file}"
            )

        samples_per_severity = all_images.shape[0] // 5
        start_idx = (severity - 1) * samples_per_severity
        end_idx = severity * samples_per_severity

        # np.array() copies the slice into a regular, writable in-memory array
        # (torch.from_numpy warns on read-only buffers such as a memmap).
        self.data = np.array(all_images[start_idx:end_idx])
        label_block = all_labels[start_idx:end_idx]

        # Pre-convert once so __getitem__ is a cheap tensor index.
        self.tensor_data = torch.from_numpy(self.data).float() / 255.0
        if self.tensor_data.dim() == 4:
            # Move the last axis to position 1 (presumably NHWC -> NCHW).
            self.tensor_data = self.tensor_data.permute(0, 3, 1, 2).contiguous()

        self.targets = torch.from_numpy(label_block).long()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for ``index`` after optional transforms.

        The image is a float tensor taken from the pre-converted data; the
        target is a plain Python number obtained via ``.item()``.
        """
        img, target = self.tensor_data[index], self.targets[index].item()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of samples in the selected severity block."""
        return len(self.data)



def get_test_loader(batch_size, corruption_type='gaussian_noise', severity=1,
                   shuffle=False, num_workers=12, pin_memory=True, input_size=None,
                   data_dir='./dataset/Cifar/CIFAR-10-C/'):
    """Build a ``DataLoader`` over one corruption/severity slice of CIFAR-10-C.

    Args:
        batch_size: Number of samples per batch.
        corruption_type: File stem of the corruption array to load.
        severity: Severity level passed to ``CIFAR10_C`` (1 to 5).
        shuffle: Whether the loader reshuffles the data every epoch.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        input_size: If truthy and different from 32, images are resized to
            ``(input_size, input_size)`` before normalisation.
        data_dir: Directory holding the ``.npy`` files. The default previously
            started with ``,/`` (a typo for ``./``), which pointed at a
            directory literally named ``,``.

    Returns:
        A ``DataLoader`` yielding ``(image, target)`` batches.

    Raises:
        ValueError, FileNotFoundError: Propagated from ``CIFAR10_C``.
    """
    # Per-channel statistics (presumably the usual CIFAR-10 training-set
    # mean/std -- confirm they match the model's training preprocessing).
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    transform_list = []
    if input_size and input_size != 32:
        # Resize before normalising so the dataset's tensors reach the
        # requested spatial size.
        transform_list.append(transforms.Resize((input_size, input_size)))
    transform_list.append(normalize)
    transform = transforms.Compose(transform_list)

    dataset = CIFAR10_C(
        root=data_dir,
        corruption_type=corruption_type,
        severity=severity,
        transform=transform,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # persistent_workers is only valid when worker processes exist.
        persistent_workers=num_workers > 0,
    )