"""
ImageNet data loading utilities.

Provides ImageFolder-based loaders with proper transforms for ViT-B/16,
supporting distributed training and chunked loading for memory efficiency.

DDP-aware: Provides distributed samplers for sharding data across ranks.
Performance optimizations:
    - persistent_workers=True: Avoids worker respawn overhead
    - prefetch_factor=4: Preloads batches ahead for GPU overlap
    - pin_memory=True: Faster CPU→GPU transfer
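    - multiprocessing_context='spawn': Better CUDA compatibility
      (NOTE: no DataLoader built in this file currently sets this option;
      callers that need it must configure it themselves)
"""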

import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, Subset, DistributedSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import ViT_B_16_Weights
from typing import Optional, Tuple, List
import numpy as np

from .logging_utils import get_logger, log_dict, is_ddp, DDPState

# Module-level logger shared by every loader factory in this file.
logger = get_logger("ifc_vit.data")

# Default dataloader settings for optimal GPU utilization.
# These are used as default argument values by the loader factories below
# (ImageNetDataset.get_loader / get_subset_loader, get_indexed_loader,
# make_distributed_subset_loader).
# NOTE(review): 20 workers is presumably tuned for a specific machine;
# lower it on hosts with fewer CPU cores.
DEFAULT_NUM_WORKERS = 20
DEFAULT_PREFETCH_FACTOR = 4 # Batches to prefetch per worker
DEFAULT_PERSISTENT_WORKERS = True  # Keep workers alive between epochs


# Standard ImageNet normalization.
# Per-channel mean and std handed to transforms.Normalize by
# get_train_transform() and get_val_transform().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def get_train_transform() -> transforms.Compose:
    """Build the preprocessing pipeline applied to training images.

    The pipeline is fully deterministic (no random augmentation): resize to
    256, center-crop to 224, convert to a tensor, then normalize with
    IMAGENET_MEAN / IMAGENET_STD.

    Returns:
        A ``transforms.Compose`` chaining the four steps above.
    """
    pipeline = [
        transforms.Resize(256),
        # Center crop (rather than a random crop) is used for consistency.
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)


def get_val_transform() -> transforms.Compose:
    """Build the preprocessing pipeline applied to validation images.

    Steps, in order: resize to 256, center-crop to 224, convert to a tensor,
    normalize with IMAGENET_MEAN / IMAGENET_STD.

    Returns:
        A ``transforms.Compose`` chaining the four steps above.
    """
    resize = transforms.Resize(256)
    crop = transforms.CenterCrop(224)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    return transforms.Compose([resize, crop, to_tensor, normalize])


class ImageNetDataset:
    """
    ImageNet dataset wrapper with efficient loading.

    Wraps a torchvision ``ImageFolder`` rooted at ``<root>/<split>``
    (optionally truncated to its first ``subset_n`` samples) and builds
    DataLoaders over it.

    Supports:
    - Train/val splits
    - Distributed sampling
    - Subset loading by indices
    - Chunked iteration for memory efficiency

    Attributes:
        root: Root directory passed to the constructor.
        split: Split name passed to the constructor.
        dataset: The underlying ``ImageFolder`` or a ``Subset`` of it.
        num_samples: ``len(self.dataset)``.
        num_classes: Fixed at 1000; it is not derived from the directory.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[transforms.Compose] = None,
        subset_n: Optional[int] = None,
    ):
        """
        Args:
            root: Path to ImageNet root directory (containing 'train' and 'val' folders)
            split: 'train' or 'val'
            transform: Optional transforms to apply. When None, the train
                transform is used for split == 'train' and the validation
                transform for any other split.
            subset_n: If given, keep only the first ``subset_n`` samples of
                the folder (capped at the dataset size).
                NOTE(review): ImageFolder presumably lists samples class by
                class, so a small subset_n may cover only the first few
                classes - confirm this is acceptable for the caller.
        """
        self.root = root
        self.split = split

        if transform is None:
            transform = get_train_transform() if split == "train" else get_val_transform()

        split_path = os.path.join(root, split)
        logger.info(f"Loading ImageNet {split} from {split_path}")
        if subset_n is not None:
            logger.info(f"Using subset of {subset_n} samples")

        # Build the folder dataset once, then optionally narrow it.
        full_dataset = ImageFolder(split_path, transform=transform)
        if subset_n is not None:
            keep = np.arange(min(subset_n, len(full_dataset)))
            self.dataset = Subset(full_dataset, keep.tolist())
        else:
            self.dataset = full_dataset

        self.num_samples = len(self.dataset)
        self.num_classes = 1000

        # Log dataset statistics
        logger.info(f"Loaded ImageNet {split}: {self.num_samples:,} samples, {self.num_classes} classes")
        log_dict(logger, f"Dataset info ({split})", {
            'root': root,
            'split': split,
            'num_samples': self.num_samples,
            'num_classes': self.num_classes,
        })

    def _make_loader(
        self,
        dataset,
        *,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
        pin_memory: bool,
        distributed: bool,
        drop_last: bool,
        prefetch_factor: int,
        persistent_workers: bool,
    ) -> DataLoader:
        """Shared DataLoader construction for get_loader / get_subset_loader.

        When ``distributed`` is true a ``DistributedSampler`` is created over
        ``dataset`` and given the ``shuffle`` flag; the DataLoader itself is
        then built with ``shuffle=False`` because a sampler and
        ``shuffle=True`` cannot be combined.
        """
        sampler = DistributedSampler(dataset, shuffle=shuffle) if distributed else None

        # persistent_workers / prefetch_factor are only meaningful when
        # worker processes exist, so they are disabled for num_workers == 0.
        has_workers = num_workers > 0

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False if distributed else shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            sampler=sampler,
            drop_last=drop_last,
            persistent_workers=persistent_workers and has_workers,
            prefetch_factor=prefetch_factor if has_workers else None,
        )

    def get_loader(
        self,
        batch_size: int = 64,
        shuffle: bool = False,
        num_workers: int = DEFAULT_NUM_WORKERS,
        pin_memory: bool = True,
        distributed: bool = False,
        drop_last: bool = True,
        prefetch_factor: int = DEFAULT_PREFETCH_FACTOR,
        persistent_workers: bool = DEFAULT_PERSISTENT_WORKERS,
    ) -> DataLoader:
        """
        Get DataLoader for the dataset.

        Args:
            batch_size: Batch size
            shuffle: Whether to shuffle. When ``distributed`` is true the
                flag is forwarded to the DistributedSampler instead of the
                DataLoader.
            num_workers: Number of data loading workers
            pin_memory: Pin memory for faster GPU transfer
            distributed: Use distributed sampler
            drop_last: Drop last incomplete batch (default: True)
            prefetch_factor: Batches to prefetch per worker (default: 4);
                ignored when num_workers == 0
            persistent_workers: Keep workers alive between epochs
                (default: True); ignored when num_workers == 0

        Returns:
            A DataLoader over ``self.dataset``.
        """
        return self._make_loader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            distributed=distributed,
            drop_last=drop_last,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )

    def get_subset_loader(
        self,
        indices: np.ndarray,
        batch_size: int = 64,
        shuffle: bool = False,
        num_workers: int = DEFAULT_NUM_WORKERS,
        pin_memory: bool = True,
        distributed: bool = False,
        drop_last: bool = False,
        prefetch_factor: int = DEFAULT_PREFETCH_FACTOR,
        persistent_workers: bool = DEFAULT_PERSISTENT_WORKERS,
    ) -> DataLoader:
        """
        Get DataLoader for a subset of the dataset.

        Args:
            indices: Sample indices (into ``self.dataset``) to include. A
                NumPy array is expected; any array-like accepted by
                ``np.asarray`` (e.g. a list of ints) also works.
            batch_size: Batch size
            shuffle: Whether to shuffle (forwarded to the DistributedSampler
                when ``distributed`` is true)
            num_workers: Number of workers
            pin_memory: Pin memory
            distributed: Use DistributedSampler (shards across ranks)
            drop_last: Drop last incomplete batch (default: False)
            prefetch_factor: Batches to prefetch per worker (default: 4);
                ignored when num_workers == 0
            persistent_workers: Keep workers alive between epochs
                (default: True); ignored when num_workers == 0

        Returns:
            A DataLoader over ``Subset(self.dataset, indices)``.
        """
        # np.asarray is a no-op for ndarrays and lets plain sequences through.
        subset = Subset(self.dataset, np.asarray(indices).tolist())

        return self._make_loader(
            subset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory,
            distributed=distributed,
            drop_last=drop_last,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )

    def iter_chunks(
        self,
        chunk_size: int = 200_000,
        batch_size: int = 64,
        num_workers: int = 8,
    ):
        """
        Iterate over dataset in chunks for memory-efficient processing.

        The dataset is cut into consecutive index ranges of at most
        ``chunk_size`` samples; each range gets its own non-shuffled,
        non-distributed subset loader (drop_last=False, so no sample is
        skipped).

        Args:
            chunk_size: Maximum number of samples per chunk; must be > 0
            batch_size: Batch size of each chunk loader
            num_workers: Worker count of each chunk loader

        Yields:
            Tuple of (chunk_indices, chunk_loader)

        Raises:
            ValueError: If ``chunk_size`` is not positive. Because this is a
                generator, the error surfaces on the first iteration.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        total = self.num_samples

        for start_idx in range(0, total, chunk_size):
            end_idx = min(start_idx + chunk_size, total)
            indices = np.arange(start_idx, end_idx)

            loader = self.get_subset_loader(
                indices,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
            )

            yield indices, loader

    def __len__(self) -> int:
        """Return the number of samples in the (possibly truncated) dataset."""
        return self.num_samples


class IndexedDataset(Dataset):
    """
    Dataset adapter that prepends the lookup index to every sample.

    Each item of the wrapped dataset must unpack into ``(image, label)``;
    this wrapper returns ``(index, image, label)`` instead, where ``index``
    is the position that was requested from the wrapped dataset.

    Useful for tracking which samples are processed.
    """

    def __init__(self, dataset: Dataset):
        # Keep a reference only; the wrapped dataset is never copied.
        self.dataset = dataset

    def __len__(self) -> int:
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[int, torch.Tensor, int]:
        """Fetch sample ``idx`` from the wrapped dataset and tag it with ``idx``."""
        sample = self.dataset[idx]
        image, label = sample
        return (idx, image, label)


def get_indexed_loader(
    root: str,
    split: str = "train",
    batch_size: int = 64,
    shuffle: bool = False,
    num_workers: int = DEFAULT_NUM_WORKERS,
    indices: Optional[np.ndarray] = None,
    prefetch_factor: int = DEFAULT_PREFETCH_FACTOR,
    persistent_workers: bool = DEFAULT_PERSISTENT_WORKERS,
) -> DataLoader:
    """
    Build a DataLoader whose batches are (idx, image, label) triples.

    The ImageFolder at ``<root>/<split>`` is optionally narrowed to
    ``indices`` and then wrapped in ``IndexedDataset``. Because the wrapping
    happens after the narrowing, the returned ``idx`` is the position inside
    the wrapped dataset: with ``indices`` given it counts from 0 within the
    subset rather than being the original folder index.

    Args:
        root: ImageNet root directory
        split: 'train' or 'val' (any value other than 'train' selects the
            validation transform)
        batch_size: Batch size
        shuffle: Whether to shuffle
        num_workers: Number of workers
        indices: Optional subset indices
        prefetch_factor: Batches to prefetch per worker (default: 4);
            ignored when num_workers == 0
        persistent_workers: Keep workers alive between epochs
            (default: True); ignored when num_workers == 0

    Returns:
        A DataLoader (pin_memory=True) over the indexed dataset.
    """
    if split == "train":
        transform = get_train_transform()
    else:
        transform = get_val_transform()

    split_dir = os.path.join(root, split)
    source = ImageFolder(split_dir, transform=transform)
    if indices is not None:
        source = Subset(source, indices.tolist())

    # Worker-only options must be switched off when loading in-process.
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,
        "persistent_workers": persistent_workers and num_workers > 0,
        "prefetch_factor": prefetch_factor if num_workers > 0 else None,
    }
    return DataLoader(IndexedDataset(source), **loader_options)


def load_imagenet(
    root: str,
    batch_size: int = 64,
    num_workers: int = 8,
    subset_n: Optional[int] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Load ImageNet train and validation loaders.

    The train loader reads ``<root>/train`` and shuffles; the validation
    loader reads ``<root>/val`` and does not shuffle.

    Args:
        root: Path to ImageNet root (containing 'train' and 'val' folders)
        batch_size: Batch size for both loaders
        num_workers: Number of data loading workers
        subset_n: If given, each split is truncated to its first
            ``subset_n`` samples

    Returns:
        train_loader, val_loader
    """
    train_dataset = ImageNetDataset(root, split="train", subset_n=subset_n)
    # The validation loader must read the 'val' split; building it from
    # 'train' would make evaluation run on the training images.
    val_dataset = ImageNetDataset(root, split="val", subset_n=subset_n)

    train_loader = train_dataset.get_loader(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # NOTE(review): get_loader defaults to drop_last=True, so the final
    # partial validation batch is dropped - confirm this is intended.
    val_loader = val_dataset.get_loader(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader


def make_distributed_subset_loader(
    root: str,
    split: str,
    indices: np.ndarray,
    batch_size: int = 64,
    num_workers: int = DEFAULT_NUM_WORKERS,
    shuffle: bool = False,
    drop_last: bool = False,
    prefetch_factor: int = DEFAULT_PREFETCH_FACTOR,
    persistent_workers: bool = DEFAULT_PERSISTENT_WORKERS,
) -> Tuple[DataLoader, int, int]:
    """
    Build a subset loader that shards across ranks when DDP is active.

    When ``is_ddp()`` is true, a ``DistributedSampler`` splits the subset
    between ranks; otherwise a plain DataLoader over the whole subset is
    returned. Intended for curvature subset loading in DDP mode.

    NOTE(review): per the PyTorch docs, DistributedSampler with
    drop_last=False pads shards by repeating samples so every rank has the
    same length; shards are then not strictly disjoint - confirm callers
    tolerate this.

    Args:
        root: ImageNet root directory
        split: 'train' or 'val' (any value other than 'train' selects the
            validation transform)
        indices: Global indices to include (same on all ranks)
        batch_size: Per-rank batch size
        num_workers: Data loading workers
        shuffle: Whether to shuffle (handled by the sampler under DDP)
        drop_last: Drop last incomplete batch (required for CUDA Graph);
            under DDP it is passed to both the sampler and the DataLoader
        prefetch_factor: Batches to prefetch per worker (default: 4);
            ignored when num_workers == 0
        persistent_workers: Keep workers alive between epochs
            (default: True); ignored when num_workers == 0

    Returns:
        loader: DataLoader for this rank's shard
        local_num_samples: Number of samples on this rank (``len(sampler)``
            under DDP, otherwise the full subset size)
        global_num_samples: Size of the subset before sharding
    """
    if split == "train":
        transform = get_train_transform()
    else:
        transform = get_val_transform()
    folder = ImageFolder(os.path.join(root, split), transform=transform)

    # Restrict the folder dataset to the requested global indices.
    subset = Subset(folder, indices.tolist())
    global_num_samples = len(subset)

    # Worker-only options are disabled when loading in the main process.
    use_persistent = persistent_workers and num_workers > 0
    use_prefetch = prefetch_factor if num_workers > 0 else None

    # Without DDP there is no sampler and this rank sees the whole subset.
    sampler = None
    local_num_samples = global_num_samples
    if is_ddp():
        sampler = DistributedSampler(
            subset,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        local_num_samples = len(sampler)

    loader = DataLoader(
        subset,
        batch_size=batch_size,
        # A sampler owns the ordering, so the loader must not shuffle too.
        shuffle=shuffle if sampler is None else False,
        num_workers=num_workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=drop_last,
        persistent_workers=use_persistent,
        prefetch_factor=use_prefetch,
    )

    mode_prefix = 'distributed ' if is_ddp() else ''
    logger.info(f"Created {mode_prefix}subset loader: "
                f"global={global_num_samples}, local={local_num_samples}, "
                f"batch_size={batch_size}, prefetch={use_prefetch}, persistent={use_persistent}")

    return loader, local_num_samples, global_num_samples


if __name__ == "__main__":
    # Manual smoke test: pass the ImageNet root as the first CLI argument.
    # It loads one batch from the plain loader and one from the indexed
    # loader and prints their shapes.
    import sys

    if len(sys.argv) <= 1:
        # No dataset path supplied: show usage and do nothing else.
        print("Usage: python imagenet_loader.py /path/to/imagenet")
    else:
        imagenet_root = sys.argv[1]

        print("Testing ImageNet loader...")
        train_dataset = ImageNetDataset(imagenet_root, split="train")
        loader = train_dataset.get_loader(batch_size=32, shuffle=False)

        # Only the first batch is inspected.
        for images, labels in loader:
            print(f"Batch shape: {images.shape}, Labels: {labels[:5]}")
            break

        print("\nTesting indexed loader...")
        indexed_loader = get_indexed_loader(imagenet_root, split="train", batch_size=32)

        # Indexed batches carry the sample indices as their first element.
        for idx, images, labels in indexed_loader:
            print(f"Indices: {idx[:5]}, Batch shape: {images.shape}")
            break
