import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import yaml

class ADE20KDataset(Dataset):
    """Multi-label classification view of an ADE20K-style segmentation dataset.

    Each sample pairs an RGB image with a multi-hot vector marking which class
    IDs occur anywhere in that image's segmentation map.
    """

    def __init__(self, 
                 root_dir='/home/gyf/iclr/recognize-anything/ADE20K',
                 image_size=384,
                 split='val',
                 transform=None,
                 num_classes=150):
        """
        ADE20K dataset for multi-class classification (object detection from segmentation)
        
        Args:
            root_dir: path to ADE20K directory; expected to contain
                'list/<split>.txt' and, optionally, 'ade_class.yaml'
            image_size: input image size (used only by the default transform)
            split: 'train' or 'val'
            transform: image transformations; when None, a resize +
                to-tensor + normalize pipeline is used
            num_classes: number of classes (150 for ADE20K)

        Raises:
            ValueError: if split is not 'train' or 'val', or the split file
                contains a line with more than one comma
            FileNotFoundError: if the split file does not exist
        """
        super().__init__()
        self.root_dir = root_dir
        self.image_size = image_size
        self.split = split
        self.num_classes = num_classes
        
        # Load class information
        self.class_names = self._load_class_names()
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.class_names)}
        
        # Paths
        self.splits_dir = os.path.join(root_dir, 'list')
        
        # Load image list for the split
        self.image_pairs = self._load_split()
        
        # Default transform if none provided
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                   std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform
    
    def _load_class_names(self):
        """Load class names from ade_class.yaml

        Returns:
            A list of num_classes names for class IDs 1..num_classes (ID 0,
            void, is skipped). Any ID without a usable entry in the file, or
            every ID when the file is missing, gets the placeholder
            'class_<id>'.
        """
        class_file = os.path.join(self.root_dir, 'ade_class.yaml')
        class_ids = range(1, self.num_classes + 1)
        
        if not os.path.exists(class_file):
            # Fallback to default class names
            return [f"class_{i}" for i in class_ids]
        
        with open(class_file, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)
        # An empty YAML document parses to None (and other layouts to
        # non-mappings); treat those as "no names available" instead of
        # failing on the membership test below.
        class_dict = loaded if isinstance(loaded, dict) else {}
        
        class_names = []
        for i in class_ids:
            raw_name = class_dict.get(i)
            if raw_name is None:
                class_names.append(f"class_{i}")
            else:
                # Handle special cases like 'bed ' (with trailing space)
                class_names.append(str(raw_name).strip())
        return class_names
    
    def _load_split(self):
        """Load image and annotation pairs for the specified split

        Returns:
            A list of (image_path, annotation_path) tuples, both relative to
            root_dir. Lines without a comma (including blank lines) are
            skipped.

        Raises:
            ValueError: if the split name is unknown or a line does not have
                exactly two comma-separated fields
            FileNotFoundError: if the split file does not exist
        """
        if self.split == 'train':
            split_file = os.path.join(self.splits_dir, 'train.txt')
        elif self.split == 'val':
            split_file = os.path.join(self.splits_dir, 'val.txt')
        else:
            raise ValueError(f"Unknown split: {self.split}")
        
        print(f"Loading split file: {split_file}")
        
        if not os.path.exists(split_file):
            raise FileNotFoundError(f"Split file {split_file} not found")
        
        image_pairs = []
        with open(split_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if ',' not in line:
                    continue
                # Format: images/validation/ADE_val_00000001.jpg,annotations/validation/ADE_val_00000001.png
                # Strip each field so "a.jpg, b.png" does not yield a path
                # with a leading space.
                fields = [field.strip() for field in line.split(',')]
                if len(fields) != 2:
                    raise ValueError(
                        f"{split_file}:{line_no}: expected 'image,annotation' "
                        f"but got {line!r}"
                    )
                image_pairs.append((fields[0], fields[1]))
        
        return image_pairs
    
    def _labels_from_seg_map(self, seg_map):
        """Convert an array of class IDs into a multi-hot label vector

        Args:
            seg_map: array of integer class IDs (any shape)

        Returns:
            float32 array of length num_classes where entry k is 1.0 when
            class ID k + 1 appears in seg_map; IDs outside 1..num_classes
            (such as 0, void) are ignored.
        """
        labels = np.zeros(self.num_classes, dtype=np.float32)
        
        # Widen before comparing/subtracting so narrow unsigned dtypes cannot
        # wrap around or clash with a large num_classes.
        class_ids = np.unique(seg_map).astype(np.int64)
        present = class_ids[(class_ids >= 1) & (class_ids <= self.num_classes)]
        
        # Subtract 1 because class IDs start from 1, but we index from 0
        labels[present - 1] = 1.0
        return labels
    
    def _parse_segmentation(self, annotation_path):
        """Parse segmentation annotation to extract class labels

        Args:
            annotation_path: annotation file path relative to root_dir

        Returns:
            float32 multi-hot vector of length num_classes.
        """
        full_annotation_path = os.path.join(self.root_dir, annotation_path)
        
        # Load segmentation map; the context manager releases the file handle
        # as soon as the pixel data has been copied into the array.
        # NOTE(review): assumes pixel values are class IDs directly — confirm
        # against the annotation format in use.
        with Image.open(full_annotation_path) as seg_image:
            seg_map = np.array(seg_image)
        
        return self._labels_from_seg_map(seg_map)
    
    def __len__(self):
        """Return the number of (image, annotation) pairs in the split"""
        return len(self.image_pairs)
    
    def __getitem__(self, idx):
        """Load one sample

        Args:
            idx: index into the split's list of pairs

        Returns:
            dict with keys 'image' (the transformed image), 'labels' (tensor
            built from the multi-hot label vector) and 'image_id' (image file
            name without directory or extension).
        """
        # Get image and annotation paths
        image_path, annotation_path = self.image_pairs[idx]
        
        # Load image; convert() returns a new image, so the source file can be
        # closed right away.
        full_image_path = os.path.join(self.root_dir, image_path)
        with Image.open(full_image_path) as raw_image:
            image = raw_image.convert('RGB')
        
        # Load annotation labels
        labels = self._parse_segmentation(annotation_path)
        
        # Apply transforms
        if self.transform:
            image = self.transform(image)
        
        # Extract image ID from path
        image_id = os.path.splitext(os.path.basename(image_path))[0]
        
        return {
            'image': image,
            'labels': torch.from_numpy(labels),
            'image_id': image_id
        }

def create_ade20k_dataloaders(root_dir='/home/gyf/iclr/recognize-anything/ADE20K',
                              image_size=384,
                              batch_size=16,
                              num_workers=4):
    """
    Build the ADE20K training and validation dataloaders.

    Args:
        root_dir: path to ADE20K directory, forwarded to ADE20KDataset
        image_size: side length images are resized to (square)
        batch_size: batch size used by both loaders
        num_workers: worker process count used by both loaders

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    target_size = (image_size, image_size)
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Training adds random flip and color jitter between the resize and the
    # tensor conversion; validation uses the plain pipeline.
    train_steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
    val_steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
    pipelines = (
        ('train', transforms.Compose(train_steps)),
        ('val', transforms.Compose(val_steps)),
    )

    # Build both datasets before any loader, train first.
    datasets = [
        ADE20KDataset(
            root_dir=root_dir,
            image_size=image_size,
            split=split_name,
            transform=pipeline,
        )
        for split_name, pipeline in pipelines
    ]

    train_loader, val_loader = (
        torch.utils.data.DataLoader(
            split_dataset,
            batch_size=batch_size,
            shuffle=should_shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )
        for split_dataset, should_shuffle in zip(datasets, (True, False))
    )

    return train_loader, val_loader

# Smoke test: build the dataset with its defaults and inspect the first sample
if __name__ == "__main__":
    dataset = ADE20KDataset()
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of classes: {dataset.num_classes}")
    print(f"First 10 classes: {dataset.class_names[:10]}")

    # Pull one sample and report its image and label tensors
    first_sample = dataset[0]
    image_tensor = first_sample['image']
    label_vector = first_sample['labels']
    print(f"Image shape: {image_tensor.shape}")
    print(f"Labels shape: {label_vector.shape}")
    print(f"Labels sum: {label_vector.sum()}")

    # Names of the classes whose label entry is set
    active_classes = [
        dataset.class_names[position]
        for position, flag in enumerate(label_vector)
        if flag == 1
    ]
    print(f"Active classes: {active_classes}")