"""
Data Loaders for HPC Evaluation

This module provides data loading utilities for CIFAR-10/100, ImageNet,
and CIFAR-10H human annotations used in the HPC paper.
"""

import json
import os
import pickle
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset


class CIFAR10HDataset(Dataset):
    """
    CIFAR-10 images paired with per-image human label distributions.

    Every item is ``(image, true_label, human_distribution)``. The human
    distributions are read from an annotations file when one is supplied and
    exists on disk; otherwise synthetic distributions are generated.
    """

    # Number of CIFAR-10 classes; also the required width of an annotation row.
    NUM_CLASSES = 10

    def __init__(
        self,
        root: str = "./data",
        split: str = "test",
        transform: Optional[Callable] = None,
        human_annotations_path: Optional[str] = None,
        download: bool = True
    ):
        """
        Initialize CIFAR-10H dataset.

        Args:
            root: Root directory for CIFAR-10 data
            split: 'train' selects the training images; any other value
                selects the test images
            transform: Optional callable applied to each image in __getitem__
            human_annotations_path: Path to a .json/.npy/.pkl annotations file
            download: Whether to download CIFAR-10 if missing

        Raises:
            ValueError: If the annotations file has an unsupported extension
                or its shape does not match the selected split.
        """
        self.root = root
        self.split = split
        self.transform = transform

        # The transform is applied in __getitem__, so the base dataset is
        # built without one.
        self.cifar10 = torchvision.datasets.CIFAR10(
            root=root, train=(split == "train"), download=download
        )

        # Load human annotations if available, otherwise synthesize them.
        self.human_labels = None
        if human_annotations_path and os.path.exists(human_annotations_path):
            self.load_human_annotations(human_annotations_path)
        else:
            warnings.warn("Human annotations not found. Using synthetic distributions.")
            self._create_synthetic_human_labels()

    def load_human_annotations(self, annotations_path: str):
        """
        Load human annotations from a .json, .npy or .pkl file.

        JSON and pickle files must hold a mapping with a 'human_labels' entry;
        .npy files hold the array directly. The result is stored in
        ``self.human_labels`` as an array of shape (len(dataset), 10).

        Raises:
            ValueError: If the extension is not supported, or the loaded array
                is not (len(dataset), 10).
        """
        if annotations_path.endswith('.json'):
            with open(annotations_path, 'r') as f:
                labels = json.load(f)['human_labels']
        elif annotations_path.endswith('.npy'):
            labels = np.load(annotations_path)
        elif annotations_path.endswith('.pkl'):
            # NOTE: unpickling can execute arbitrary code; only load
            # annotation files from a trusted source.
            with open(annotations_path, 'rb') as f:
                labels = pickle.load(f)['human_labels']
        else:
            raise ValueError(
                f"Unsupported annotations file (expected .json, .npy or .pkl): "
                f"{annotations_path}"
            )

        # Coerce to ndarray so __getitem__ can always use torch.from_numpy.
        labels = np.asarray(labels)

        if labels.ndim != 2 or labels.shape[1] != self.NUM_CLASSES:
            raise ValueError(
                f"Human annotations must have shape (N, {self.NUM_CLASSES}), "
                f"got {labels.shape}"
            )
        # Fail here rather than with an IndexError in the middle of iteration.
        if labels.shape[0] != len(self.cifar10):
            raise ValueError(
                f"Human annotations cover {labels.shape[0]} samples but the "
                f"dataset has {len(self.cifar10)}"
            )

        self.human_labels = labels
        print(f"Loaded human annotations: {self.human_labels.shape}")

    def _create_synthetic_human_labels(self):
        """
        Create synthetic human label distributions.

        Each row is a sparse Dirichlet(0.1) draw with 0.7 added to the true
        class, renormalized to sum to 1. The true class therefore holds at
        least 0.7/1.7 of the mass, but another class can hold up to 1/1.7, so
        the true class is not guaranteed to be the mode.
        """
        # Read labels from the stored target list; indexing the dataset would
        # decode every image just to obtain its label.
        targets = np.asarray(self.cifar10.targets, dtype=np.intp)
        n_samples = len(targets)

        dists = np.random.dirichlet(
            np.full(self.NUM_CLASSES, 0.1), size=n_samples
        )
        # Boost true label probability, then renormalize each row.
        dists[np.arange(n_samples), targets] += 0.7
        dists /= dists.sum(axis=1, keepdims=True)

        self.human_labels = dists

    def __len__(self) -> int:
        return len(self.cifar10)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get item with image, true label, and human distribution.

        Returns:
            (image, true_label, human_distribution)
        """
        image, true_label = self.cifar10[idx]

        if self.transform:
            image = self.transform(image)

        if self.human_labels is not None:
            human_dist = torch.from_numpy(self.human_labels[idx]).float()
        else:
            # Fallback: one-hot true label
            human_dist = F.one_hot(
                torch.tensor(true_label), self.NUM_CLASSES
            ).float()

        return image, torch.tensor(true_label), human_dist


class CIFAR100Dataset(Dataset):
    """
    CIFAR-100 dataset wrapper with optional human distributions.

    Every item is ``(image, true_label, human_distribution)``. No human
    annotations are loaded for CIFAR-100: the distribution is either synthetic
    or, when synthetic labels are disabled, the one-hot true label.
    """

    # Number of CIFAR-100 classes (width of each distribution).
    NUM_CLASSES = 100

    def __init__(
        self,
        root: str = "./data",
        split: str = "test",
        transform: Optional[Callable] = None,
        download: bool = True,
        use_synthetic_humans: bool = True
    ):
        """
        Initialize the CIFAR-100 wrapper.

        Args:
            root: Root directory for CIFAR-100 data
            split: 'train' selects the training images; any other value
                selects the test images
            transform: Optional callable applied to each image in __getitem__
            download: Whether to download CIFAR-100 if missing
            use_synthetic_humans: Generate synthetic distributions when True,
                fall back to one-hot true labels when False
        """
        self.root = root
        self.split = split
        self.transform = transform

        # The transform is applied in __getitem__, so the base dataset is
        # built without one.
        self.cifar100 = torchvision.datasets.CIFAR100(
            root=root, train=(split == "train"), download=download
        )

        self.human_labels = None
        if use_synthetic_humans:
            self._create_synthetic_human_labels()

    def _create_synthetic_human_labels(self):
        """
        Create synthetic human distributions for CIFAR-100.

        Each row is a sparse Dirichlet(0.01) draw with 0.8 added to the true
        class, renormalized to sum to 1. The true class therefore holds at
        least 0.8/1.8 of the mass, but another class can hold up to 1/1.8, so
        the true class is not guaranteed to be the mode.
        """
        # Read labels from the stored target list; indexing the dataset would
        # decode every image just to obtain its label.
        targets = np.asarray(self.cifar100.targets, dtype=np.intp)
        n_samples = len(targets)

        dists = np.random.dirichlet(
            np.full(self.NUM_CLASSES, 0.01), size=n_samples
        )
        dists[np.arange(n_samples), targets] += 0.8
        dists /= dists.sum(axis=1, keepdims=True)

        self.human_labels = dists

    def __len__(self) -> int:
        return len(self.cifar100)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get item with image, true label, and human distribution.

        Returns:
            (image, true_label, human_distribution)
        """
        image, true_label = self.cifar100[idx]

        if self.transform:
            image = self.transform(image)

        if self.human_labels is not None:
            human_dist = torch.from_numpy(self.human_labels[idx]).float()
        else:
            # Fallback: one-hot true label
            human_dist = F.one_hot(
                torch.tensor(true_label), self.NUM_CLASSES
            ).float()

        return image, torch.tensor(true_label), human_dist


class ImageNetSubset(Dataset):
    """
    ImageNet subset for HPC evaluation.

    Selects a class-balanced subset of ImageNet indices and serves items as
    ``(image, true_label, human_distribution)``. No human labels are loaded
    for ImageNet, so the distribution is a noisy one-hot of the true label.
    """

    # Number of ImageNet classes (width of each distribution).
    NUM_CLASSES = 1000
    # Total probability mass moved from the true label onto random noise.
    NOISE_MASS = 0.05

    def __init__(
        self,
        root: str,
        split: str = "val",
        transform: Optional[Callable] = None,
        subset_size: Optional[int] = 10000,
        classes_subset: Optional[List[int]] = None
    ):
        """
        Initialize the ImageNet subset.

        Args:
            root: Root directory of the ImageNet data
            split: 'train' selects the training split; any other value
                selects the validation split
            transform: Optional callable applied to each image in __getitem__
            subset_size: Maximum number of samples to keep (None for no cap)
            classes_subset: Class indices to keep (None for all classes)
        """
        self.root = root
        self.split = split
        self.transform = transform

        # The transform is applied in __getitem__, so the base dataset is
        # built without one.
        imagenet_split = 'train' if split == "train" else 'val'
        self.imagenet = torchvision.datasets.ImageNet(root, split=imagenet_split)

        self.subset_indices = self._create_subset_indices(subset_size, classes_subset)

    def _create_subset_indices(
        self,
        subset_size: Optional[int],
        classes_subset: Optional[List[int]]
    ) -> List[int]:
        """
        Create balanced subset indices.

        Walks the dataset in order, keeping a sample when its class is allowed
        and that class has not yet reached its per-class quota
        (``subset_size // number_of_classes``, but at least 1), and stops once
        ``subset_size`` samples are collected.

        Returns:
            Indices into the underlying ImageNet dataset, in dataset order.
        """
        n_total = len(self.imagenet)
        if subset_size is None and classes_subset is None:
            return list(range(n_total))
        if subset_size is not None and subset_size <= 0:
            return []

        target_classes = set(classes_subset) if classes_subset is not None else None

        if subset_size is not None:
            # Count distinct classes so duplicated ids do not shrink the quota.
            n_target_classes = len(target_classes) if target_classes else self.NUM_CLASSES
            # A budget below the class count would floor to 0 per class and
            # yield an empty subset; keep at least one sample per class and
            # let the subset_size cap below stop the walk.
            samples_per_class = max(1, subset_size // n_target_classes)
        else:
            samples_per_class = float('inf')

        # Use the precomputed label list when the dataset exposes one;
        # indexing the dataset would decode every image just for its label.
        labels = getattr(self.imagenet, "targets", None)
        if labels is None:
            labels = (self.imagenet[i][1] for i in range(n_total))

        indices = []
        class_counts = {}
        for idx, label in enumerate(labels):
            if target_classes is not None and label not in target_classes:
                continue
            seen = class_counts.get(label, 0)
            if seen >= samples_per_class:
                continue

            indices.append(idx)
            class_counts[label] = seen + 1

            if subset_size is not None and len(indices) >= subset_size:
                break

        return indices

    def __len__(self) -> int:
        return len(self.subset_indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get item with image, true label, and synthetic human distribution.

        The distribution puts ``1 - NOISE_MASS`` on the true label and spreads
        ``NOISE_MASS`` over all classes in proportion to uniform random
        weights. The weights are seeded by the sample's dataset index, so a
        given sample always yields the same distribution.

        Returns:
            (image, true_label, human_distribution)
        """
        actual_idx = self.subset_indices[idx]
        image, true_label = self.imagenet[actual_idx]

        if self.transform:
            image = self.transform(image)

        # Rescale the noise to a fixed total mass. Adding up to 0.05 to each
        # of the 1000 classes would add about 25 units of noise against 1
        # unit of signal and leave the true label with only a few percent.
        generator = torch.Generator().manual_seed(int(actual_idx))
        noise = torch.rand(self.NUM_CLASSES, generator=generator)
        noise = noise / noise.sum() * self.NOISE_MASS

        one_hot = F.one_hot(torch.tensor(true_label), self.NUM_CLASSES).float()
        human_dist = one_hot * (1.0 - self.NOISE_MASS) + noise

        return image, torch.tensor(true_label), human_dist


def get_standard_transforms(dataset: str, split: str) -> transforms.Compose:
    """
    Build the standard preprocessing pipeline for a dataset and split.

    The dataset name is matched case-insensitively. The 'train' split gets
    random augmentation; every other split value gets the deterministic
    evaluation pipeline. 'cifar10' and 'cifar100' share one pipeline,
    including the same normalization statistics.

    Args:
        dataset: 'cifar10', 'cifar100', or 'imagenet'
        split: 'train' or 'test'/'val'

    Returns:
        Torchvision transform composition

    Raises:
        ValueError: If the dataset name is not recognized.
    """
    name = dataset.lower()
    is_train = (split == 'train')

    if name in ['cifar10', 'cifar100']:
        # Tensor conversion and normalization close both CIFAR pipelines.
        tail = [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        if is_train:
            head = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomCrop(32, padding=4),
            ]
        else:
            head = []
        return transforms.Compose(head + tail)

    if name == 'imagenet':
        if is_train:
            head = [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            head = [
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ]
        tail = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(head + tail)

    raise ValueError(f"Unknown dataset: {dataset}")


class CIFAR10Wrapper(Dataset):
    """
    Adapt an ``(image, label)`` dataset to ``(image, label, one_hot)`` items.

    Defined at module level rather than inside create_data_loader because
    function-local classes cannot be pickled, which DataLoader workers need
    when they are started with the 'spawn' method.
    """

    def __init__(self, base_dataset, num_classes: int = 10):
        self.base_dataset = base_dataset
        self.num_classes = num_classes

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, label = self.base_dataset[idx]
        human_dist = F.one_hot(torch.tensor(label), self.num_classes).float()
        return image, torch.tensor(label), human_dist


def create_data_loader(
    dataset_name: str,
    split: str = "test",
    batch_size: int = 128,
    num_workers: int = 4,
    data_root: str = "./data",
    human_annotations_path: Optional[str] = None,
    **kwargs
) -> DataLoader:
    """
    Create data loader for specified dataset.

    Every dataset yields ``(image, true_label, human_distribution)`` items.
    Batches are shuffled only for the 'train' split.

    Args:
        dataset_name: 'cifar10h', 'cifar10', 'cifar100', 'imagenet'
            (matched case-insensitively)
        split: 'train', 'test', or 'val'
        batch_size: Batch size for data loader
        num_workers: Number of worker processes
        data_root: Root directory for data
        human_annotations_path: Path to human annotations (only used for
            'cifar10h')
        **kwargs: Additional arguments passed to the dataset constructor.
            For plain 'cifar10' only ``download`` is honored; other keys
            are ignored.

    Returns:
        PyTorch DataLoader

    Raises:
        ValueError: If the dataset name is not recognized.
    """
    name = dataset_name.lower()
    if name not in ('cifar10h', 'cifar10', 'cifar100', 'imagenet'):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # CIFAR-10H uses the CIFAR-10 images, so it shares their transforms.
    transform = get_standard_transforms(
        'cifar10' if name == 'cifar10h' else name, split
    )

    if name == 'cifar10h':
        dataset = CIFAR10HDataset(
            root=data_root,
            split=split,
            transform=transform,
            human_annotations_path=human_annotations_path,
            **kwargs
        )

    elif name == 'cifar10':
        # Standard CIFAR-10 without human annotations: the "human"
        # distribution is the one-hot true label.
        base_dataset = torchvision.datasets.CIFAR10(
            root=data_root,
            train=(split == 'train'),
            download=kwargs.get('download', True),
            transform=transform,
        )
        dataset = CIFAR10Wrapper(base_dataset)

    elif name == 'cifar100':
        dataset = CIFAR100Dataset(
            root=data_root,
            split=split,
            transform=transform,
            **kwargs
        )

    else:  # 'imagenet'
        dataset = ImageNetSubset(
            root=data_root,
            split=split,
            transform=transform,
            **kwargs
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=False
    )


# Example usage
if __name__ == "__main__":
    print("Testing data loaders...")

    # Both smoke tests build their loader with the same settings.
    smoke_settings = dict(split='test', batch_size=32, num_workers=2)

    # CIFAR-10H: fetch one batch and show a sample human distribution.
    # Failures are reported and do not stop the second check.
    try:
        loader = create_data_loader('cifar10h', **smoke_settings)
        first_batch = next(iter(loader))
        images, true_labels, human_dists = first_batch
        print(f"CIFAR-10H batch: {images.shape}, {true_labels.shape}, {human_dists.shape}")
        print(f"Human dist example: {human_dists[0]}")
        print(f"Human dist sum: {human_dists[0].sum():.4f}")
    except Exception as e:
        print(f"CIFAR-10H loader failed: {e}")

    # CIFAR-100: fetch one batch and report its shapes.
    try:
        loader = create_data_loader('cifar100', **smoke_settings)
        first_batch = next(iter(loader))
        images, true_labels, human_dists = first_batch
        print(f"CIFAR-100 batch: {images.shape}, {true_labels.shape}, {human_dists.shape}")
    except Exception as e:
        print(f"CIFAR-100 loader failed: {e}")

    print("Data loader tests completed.")
