"""
Vision dataset loading and preprocessing tools
Supports multiple vision datasets, providing data for CLIP/ViT experiments
"""

import os
from typing import Any, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset, DataLoader, random_split

try:
    import torchvision
    from torchvision import transforms
except ImportError:
    torchvision = None
    transforms = None

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None


class VisionDataset(Dataset):
    """
    Generic vision dataset wrapper that supports different formats.

    Wraps any indexable ``dataset`` whose items are either ``(image, label)``
    sequences (list or tuple) or mappings with ``"image"`` and ``"label"``
    keys, and applies ``transform`` to the image when one is given.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        if isinstance(record, (list, tuple)):
            # torchvision-style sample: positional (image, label)
            img, target = record[0], record[1]
        else:
            # dict-style sample (e.g. a HuggingFace row)
            img, target = record["image"], record["label"]

        if self.transform:
            img = self.transform(img)

        return img, target


class _HFRGBDataset(Dataset):
    """
    Wrap a HuggingFace ``datasets`` split whose rows have ``"image"`` and
    ``"label"`` columns, converting each image to RGB before transforming.

    Defined at module level (rather than inside ``get_vision_dataset``) so
    instances can be pickled and shipped to DataLoader worker processes when
    the ``spawn`` multiprocessing start method is used.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Force 3 channels: the Normalize step uses 3-element mean/std, so
        # non-RGB images (e.g. grayscale or RGBA) would otherwise fail there.
        image = item["image"].convert("RGB")
        label = item["label"]

        if self.transform:
            image = self.transform(image)

        return image, label


def get_vision_dataset(
    name: str,
    root: str = "./data",
    train: bool = True,
    image_size: int = 224,
) -> Dataset:
    """
    Load a vision dataset split with a shared square preprocessing pipeline.

    Supported datasets (``name`` is case-insensitive):
    - CIFAR10/100 (downloaded into ``root``)
    - ImageNet-1K (``root`` must point to an existing local copy)
    - EuroSAT (HuggingFace ``timm/eurosat-rgb``)
    - ISIC skin cancer (HuggingFace ``marmal88/skin_cancer``)

    Args:
        name: Dataset name.
        root: Storage/lookup directory for the torchvision datasets; not used
            by the HuggingFace-backed datasets.
        train: Load the training split if True, otherwise the test split
            (``"val"`` for ImageNet).
        image_size: Side length of the square output images. The default of
            224 is the CLIP/ViT input size.

    Returns:
        A Dataset yielding ``(image_tensor, label)`` pairs.

    Raises:
        ImportError: torchvision is missing, or the ``datasets`` library is
            missing for EuroSAT/ISIC.
        ValueError: Unknown ``name``, nonexistent ImageNet ``root``, or the
            ISIC download/load failed.
    """
    if torchvision is None:
        raise ImportError("torchvision is required for vision datasets")

    # Resize the shorter edge, then center-crop, so every sample is exactly
    # image_size x image_size even for non-square sources (ImageNet, ISIC).
    # For square inputs such as CIFAR/EuroSAT the crop is a no-op.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    key = name.lower()
    hf_split = "train" if train else "test"

    if key == "cifar10":
        ds_raw = torchvision.datasets.CIFAR10(
            root=root, train=train, download=True, transform=None
        )
        return VisionDataset(ds_raw, transform)

    if key == "cifar100":
        ds_raw = torchvision.datasets.CIFAR100(
            root=root, train=train, download=True, transform=None
        )
        return VisionDataset(ds_raw, transform)

    if key == "eurosat":
        if load_dataset is None:
            raise ImportError("datasets library is required for EuroSAT")
        hf_ds = load_dataset("timm/eurosat-rgb", split=hf_split)
        return _HFRGBDataset(hf_ds, transform)

    if key == "imagenet":
        # ImageNet is not downloadable; it must already exist locally.
        if not os.path.exists(root):
            raise ValueError(f"ImageNet path {root} does not exist. Please provide a valid path.")
        ds_raw = torchvision.datasets.ImageNet(
            root=root, split="train" if train else "val", transform=None
        )
        return VisionDataset(ds_raw, transform)

    if key == "isic":
        if load_dataset is None:
            raise ImportError("datasets library is required for ISIC dataset")
        try:
            hf_ds = load_dataset("marmal88/skin_cancer", split=hf_split)
        except Exception as e:
            print(f"Failed to load ISIC via HuggingFace: {e}")
            # Chain the cause so the underlying HF/network error stays visible.
            raise ValueError("ISIC dataset loading failed. Please check the dataset path.") from e
        # NOTE(review): assumes this HF dataset exposes integer "label" and
        # PIL "image" columns — confirm against the dataset card.
        return _HFRGBDataset(hf_ds, transform)

    raise ValueError(f"Unknown dataset: {name}")


def make_vision_loaders(
    name: str,
    root: str,
    batch_size: int,
    val_fraction: float = 0.1,
    num_workers: int = 4,
    pin_memory: bool = True,
    seed: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Create training, validation, and test data loaders.

    The validation set is carved out of the training split with
    ``random_split``; the test loader uses the dataset's own test split.

    Args:
        name: Dataset name (see ``get_vision_dataset``).
        root: Data storage path.
        batch_size: Batch size for all three loaders.
        val_fraction: Fraction of the training split held out for
            validation; must be in ``[0, 1)``.
        num_workers: Number of data-loading worker processes.
        pin_memory: Whether to use pinned (page-locked) host memory.
        seed: If given, seeds the train/val split so it is reproducible
            across runs. ``None`` keeps the previous behavior of drawing from
            the global torch RNG. Does not affect per-epoch shuffling.

    Returns:
        train_loader, val_loader, test_loader

    Raises:
        ValueError: If ``val_fraction`` is outside ``[0, 1)``.
    """
    # Validate before loading/downloading anything expensive.
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

    loader_kwargs = dict(
        batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
    )

    # Split the full training split into training and validation subsets.
    ds_train_full = get_vision_dataset(name, root=root, train=True)
    n_full = len(ds_train_full)
    n_val = int(val_fraction * n_full)
    n_train = n_full - n_val
    if seed is None:
        ds_train, ds_val = random_split(ds_train_full, [n_train, n_val])
    else:
        gen = torch.Generator().manual_seed(seed)
        ds_train, ds_val = random_split(ds_train_full, [n_train, n_val], generator=gen)

    train_loader = DataLoader(ds_train, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(ds_val, shuffle=False, **loader_kwargs)

    ds_test = get_vision_dataset(name, root=root, train=False)
    test_loader = DataLoader(ds_test, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


def get_dataset_info(name: str) -> Dict[str, Any]:
    """
    Return basic metadata for a supported dataset.

    The returned dict always contains ``num_classes``, ``input_size``,
    ``mean`` and ``std``; ``class_names`` is included only for datasets whose
    label names are listed here (CIFAR10, CIFAR100, EuroSAT). A fresh dict is
    built on every call, so callers may mutate it freely.

    Raises:
        ValueError: If ``name`` (case-insensitive) is not a known dataset.
    """
    # dataset key -> (number of classes, class names or None)
    catalog = {
        "cifar10": (10, [
            "airplane", "automobile", "bird", "cat", "deer",
            "dog", "frog", "horse", "ship", "truck",
        ]),
        "cifar100": (100, [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
            'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
            'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank',
            'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip',
            'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm',
        ]),
        "eurosat": (10, [
            "AnnualCrop", "Forest", "HerbaceousVegetation", "Highway",
            "Industrial", "Pasture", "PermanentCrop", "Residential",
            "River", "SeaLake",
        ]),
        "isic": (8, None),
        "imagenet": (1000, None),
    }

    key = name.lower()
    if key not in catalog:
        raise ValueError(f"Unknown dataset: {name}")

    num_classes, class_names = catalog[key]

    # All datasets share the same preprocessed input shape and
    # ImageNet normalization statistics.
    info: Dict[str, Any] = {
        "num_classes": num_classes,
        "input_size": (3, 224, 224),
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    }
    if class_names is not None:
        info["class_names"] = class_names
    return info