import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class ImageNetDataset(Dataset):
    """
    Dataset class for loading images from a class-per-folder directory tree

    Every sub-directory of ``root_dir`` is one class. Class indices are
    assigned by enumerating the sorted sub-directory names, so stray files
    sitting next to the class folders never shift the labels.

    Args:
        root_dir (str): Path to the directory containing one folder per class
        transform (callable, optional): Optional transform to be applied on
            images. When omitted, Resize(256) -> CenterCrop(224) -> ToTensor()
            is used.

    Attributes:
        classes (list): Sorted class folder names; position == class index
        class_to_idx (dict): Maps class folder name -> class index
        samples (list): ``(image_path, class_index)`` tuples, ordered by
            class and then by file name

    Raises:
        OSError: Propagated from ``os.listdir`` if ``root_dir`` cannot be read
    """

    # Accepted image extensions, compared against the lower-cased file name so
    # that '.JPG' / '.PNG' are picked up as well as '.JPEG' / '.jpg' / '.png'.
    IMG_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Filter to directories BEFORE numbering: enumerating the raw listing
        # would let any non-directory entry consume a class index.
        self.classes = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        # Collect (path, label) pairs in a deterministic order
        self.samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(self.root_dir, class_name)
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self.IMG_EXTENSIONS):
                    self.samples.append(
                        (os.path.join(class_path, img_name), class_idx)
                    )

        # Default transform if none provided. Note that no Normalize step is
        # applied; callers that need the usual ImageNet statistics
        # (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) must add it.
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the sample at position ``idx``."""
        img_path, label = self.samples[idx]

        # Image.open is lazy. convert() reads the pixel data and returns a new
        # image, so the result stays usable after the file is closed, even for
        # images that are already RGB.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label


class ImageNetCDataset(Dataset):
    """
    Dataset class for loading ImageNet-C data

    Images are read from ``<root_dir>/<corruption_type>/<severity>/<class>/``.
    Class indices are assigned by enumerating the sorted class folder names,
    so stray files sitting next to the class folders never shift the labels.

    Args:
        root_dir (str): Path to ImageNet-C root directory
        corruption_type (str): Type of corruption (e.g., 'gaussian_noise', 'shot_noise', etc.)
        severity (int): Severity level of corruption (1-5)
        transform (callable, optional): Optional transform to be applied on
            images. When omitted, Resize(256) -> CenterCrop(224) -> ToTensor()
            is used.

    Attributes:
        corruption_path (str): Directory that holds the class folders
        classes (list): Sorted class folder names; position == class index
        class_to_idx (dict): Maps class folder name -> class index
        samples (list): ``(image_path, class_index)`` tuples, ordered by
            class and then by file name

    Raises:
        ValueError: If ``severity`` is not one of 1, 2, 3, 4, 5
        OSError: Propagated from ``os.listdir`` if the corruption directory
            cannot be read
    """

    # Accepted image extensions, compared against the lower-cased file name so
    # that '.JPG' / '.PNG' are picked up as well as '.JPEG' / '.jpg' / '.png'.
    IMG_EXTENSIONS = ('.jpeg', '.jpg', '.png')

    def __init__(self, root_dir, corruption_type, severity, transform=None):
        # Validate severity level
        if severity not in range(1, 6):
            raise ValueError("Severity must be between 1 and 5")

        self.root_dir = root_dir
        self.corruption_type = corruption_type
        self.severity = severity
        self.transform = transform

        # Path to corrupted images. int() normalises values that pass the
        # range check but do not stringify as a bare digit (3.0 -> '3.0',
        # True -> 'True'), which would point at a non-existent folder.
        self.corruption_path = os.path.join(
            root_dir, corruption_type, str(int(severity))
        )

        # Filter to directories BEFORE numbering: enumerating the raw listing
        # would let any non-directory entry consume a class index.
        self.classes = sorted(
            entry for entry in os.listdir(self.corruption_path)
            if os.path.isdir(os.path.join(self.corruption_path, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        # Collect (path, label) pairs in a deterministic order
        self.samples = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(self.corruption_path, class_name)
            for img_name in sorted(os.listdir(class_path)):
                if img_name.lower().endswith(self.IMG_EXTENSIONS):
                    self.samples.append(
                        (os.path.join(class_path, img_name), class_idx)
                    )

        # Default transform if none provided. Note that no Normalize step is
        # applied; callers that need the usual ImageNet statistics
        # (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) must add it.
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])

    def __len__(self):
        """Return the number of images found for this corruption/severity."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the sample at position ``idx``."""
        img_path, label = self.samples[idx]

        # Image.open is lazy. convert() reads the pixel data and returns a new
        # image, so the result stays usable after the file is closed, even for
        # images that are already RGB.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label


# Maps a CIFAR-10 class name to a list of integer class indices.
# NOTE(review): the indices are presumably ImageNet-1k class ids (the names in
# the trailing comments match that labelling) -- confirm against the label map
# used by the consumer of this table.
#
# 'deer' and 'horse' are intentionally absent. Candidates noted previously:
#   deer:  [351, 352, 353]  (hartebeest, impala, gazelle)
#   horse: [603]            (horse cart; closest match, though not ideal)
cifar10_indices = {
    # airliner, warplane
    'airplane': [404, 895],

    # beach wagon, convertible, Model T, sports car
    'automobile': [436, 511, 661, 817],

    # Two contiguous runs of indices: 10-13 and 92-96.
    # NOTE(review): species names not listed here -- verify against the label map.
    'bird': [*range(10, 14), *range(92, 97)],

    # tabby, tiger cat, Persian cat, Siamese cat, Egyptian cat
    'cat': list(range(281, 286)),

    # A sample of the dog-breed range 151..268 (Chihuahua .. Mexican hairless):
    # every 20th index only, i.e. [151, 171, 191, 211, 231, 251].
    'dog': list(range(151, 269, 20)),

    # bullfrog, tree frog
    'frog': [30, 31],

    # container ship, liner, pirate ship, schooner, submarine
    'ship': [510, 628, 724, 780, 833],

    # fire truck, garbage truck, moving van, pickup truck, police van,
    # tow truck, trailer truck
    'truck': [555, 569, 675, 717, 734, 864, 867],
}
