import os
import xml.etree.ElementTree as ET
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np

class VOC2012Dataset(Dataset):
    """PASCAL VOC2012 images paired with multi-hot object-class label vectors.

    Each sample is a dict with:
      * ``'image'``: the image after ``self.transform``;
      * ``'labels'``: a float32 tensor of length ``num_classes`` holding 1.0
        for every class that appears at least once in the image's XML
        annotation and 0.0 otherwise;
      * ``'image_id'``: the file stem shared by the JPEG and the XML file.
    """

    def __init__(self,
                 root_dir='/home/gyf/iclr/recognize-anything/VOC2012',
                 image_size=384,
                 split='train',
                 transform=None):
        """
        VOC2012 dataset for multi-label classification.

        Args:
            root_dir: path to the VOC2012 directory. Images are read from
                ``<root_dir>/JPEGImages``, annotations from
                ``<root_dir>/Annotations`` and split lists from
                ``<root_dir>/splits``.
            image_size: side length used by the default transform (ignored
                when ``transform`` is given).
            split: 'train', 'val', or 'trainval'.
            transform: callable applied to the RGB PIL image. When None, a
                resize + ToTensor + ImageNet-normalization pipeline is used.

        Raises:
            ValueError: if ``split`` is not one of the supported names.
        """
        self.root_dir = root_dir
        self.image_size = image_size
        self.split = split

        # The 20 VOC object classes. There is no background entry: labels are
        # a multi-hot vector over object classes only.
        self.classes = [
            'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
            'bus', 'car', 'cat', 'chair', 'cow',
            'diningtable', 'dog', 'horse', 'motorbike', 'person',
            'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
        ]
        self.num_classes = len(self.classes)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Paths. Split lists are read from <root>/splits rather than the
        # standard VOC location <root>/ImageSets/Main.
        self.images_dir = os.path.join(root_dir, 'JPEGImages')
        self.annotations_dir = os.path.join(root_dir, 'Annotations')
        self.splits_dir = os.path.join(root_dir, "splits")

        # Load image list for the split
        self.image_ids = self._load_split()

        # Default transform if none provided
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

    def _load_split(self):
        """Return the list of image IDs for ``self.split``.

        Each non-blank line of the split file is either a bare image ID or a
        ``JPEGImages/<id>.jpg,SegmentationClassAug/<id>.png`` pair, from which
        the stem of the first path is taken. Blank lines are skipped.

        If the split file does not exist, every ``*.xml`` file in the
        annotations directory is used instead, in sorted order so that
        dataset indices (and any index-based split built on them) are
        reproducible.

        Raises:
            ValueError: if ``self.split`` is not 'train', 'val' or 'trainval'.
        """
        split_files = {
            'train': 'train.txt',
            'val': 'val.txt',
            'trainval': 'trainval.txt',
        }
        if self.split not in split_files:
            raise ValueError(f"Unknown split: {self.split}")
        split_file = os.path.join(self.splits_dir, split_files[self.split])
        print(f"Loading split file: {split_file}")

        # If split files don't exist, fall back to using all available annotations
        if not os.path.exists(split_file):
            print(f"Split file {split_file} not found, using all available annotations")
            # os.listdir returns entries in arbitrary order; sort for stability.
            return sorted(
                os.path.splitext(filename)[0]
                for filename in os.listdir(self.annotations_dir)
                if filename.endswith('.xml')
            )

        image_ids = []
        with open(split_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # An empty ID would later resolve to "<dir>/.jpg" and fail.
                    continue
                if ',' in line:
                    # Format: JPEGImages/image_id.jpg,SegmentationClassAug/image_id.png
                    image_path = line.split(',')[0]
                    image_id = os.path.splitext(os.path.basename(image_path))[0]
                else:
                    # Simple format: just image_id
                    image_id = line
                image_ids.append(image_id)

        return image_ids

    def _parse_annotation(self, annotation_path):
        """Return a float32 multi-hot vector of the classes in one XML file.

        Every ``<object>/<name>`` is compared against ``self.classes`` after
        stripping surrounding whitespace and lower-casing. Objects whose name
        is missing, empty, or not a known class are ignored.
        """
        root = ET.parse(annotation_path).getroot()

        # Initialize multi-label vector (all zeros)
        labels = np.zeros(self.num_classes, dtype=np.float32)

        for obj in root.findall('object'):
            # findtext gives None for a missing <name> and '' for an empty one.
            class_name = (obj.findtext('name') or '').strip().lower()
            class_idx = self.class_to_idx.get(class_name)
            if class_idx is not None:
                labels[class_idx] = 1.0

        return labels

    def __len__(self):
        """Number of image IDs in the loaded split."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load the image and labels at position ``idx``.

        Returns:
            dict with 'image' (transformed image), 'labels' (float32 tensor of
            length ``num_classes``) and 'image_id' (str).
        """
        image_id = self.image_ids[idx]

        image_path = os.path.join(self.images_dir, f"{image_id}.jpg")
        # Open in a context manager so the file handle is closed right away;
        # convert() returns an independent, fully loaded image.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        annotation_path = os.path.join(self.annotations_dir, f"{image_id}.xml")
        labels = self._parse_annotation(annotation_path)

        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'labels': torch.from_numpy(labels),
            'image_id': image_id
        }

def create_voc2012_dataloaders(root_dir='/home/gyf/iclr/recognize-anything/VOC2012',
                               image_size=384,
                               batch_size=16,
                               num_workers=4,
                               train_fraction=0.8,
                               seed=42):
    """
    Create train and validation dataloaders for VOC2012.

    Both loaders read the 'trainval' split. Its indices are shuffled with a
    fixed seed and partitioned into disjoint subsets. The first
    ``int(train_fraction * N)`` shuffled indices go to the train loader and
    the rest go to the validation loader. The loaders differ only in their
    transforms: training adds random horizontal flips and color jitter.

    Args:
        root_dir: path to the VOC2012 directory.
        image_size: side length images are resized to.
        batch_size: batch size for both loaders.
        num_workers: worker processes for each loader.
        train_fraction: fraction of 'trainval' used for training (default
            0.8, i.e. an 80/20 split).
        seed: seed for the train/val partition. The defaults reproduce the
            original split exactly.

    Returns:
        ``(train_loader, val_loader)``, both ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: if ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    # Two instances over the same 'trainval' split that differ only in their
    # transform. The samplers below restrict each to a disjoint index subset.
    train_dataset = VOC2012Dataset(
        root_dir=root_dir,
        image_size=image_size,
        split='trainval',
        transform=train_transform
    )
    val_dataset = VOC2012Dataset(
        root_dir=root_dir,
        image_size=image_size,
        split='trainval',
        transform=val_transform
    )
    # The partition is index-based, so both instances must list the same
    # images in the same order.
    val_dataset.image_ids = list(train_dataset.image_ids)

    dataset_size = len(train_dataset)
    indices = list(range(dataset_size))
    # Use a private RNG so the caller's global NumPy random state is not
    # reseeded. RandomState(seed).shuffle yields the same permutation that
    # np.random.seed(seed); np.random.shuffle(...) produced.
    np.random.RandomState(seed).shuffle(indices)

    split_idx = int(train_fraction * dataset_size)
    train_indices = indices[:split_idx]
    val_indices = indices[split_idx:]

    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

    # Pinned host memory only speeds up copies to a CUDA device. Without
    # CUDA it is wasted work and PyTorch warns about it.
    pin_memory = torch.cuda.is_available()

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        sampler=val_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader

# Smoke test: build the default dataset and print a summary plus one sample.
if __name__ == "__main__":
    dataset = VOC2012Dataset()
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of classes: {dataset.num_classes}")
    print(f"Classes: {dataset.classes}")

    # Indexing an empty dataset would raise IndexError; report instead.
    if len(dataset) == 0:
        print("Dataset is empty; no sample to show.")
    else:
        sample = dataset[0]
        print(f"Image shape: {sample['image'].shape}")
        print(f"Labels shape: {sample['labels'].shape}")
        print(f"Labels: {sample['labels']}")
        active_classes = [
            name
            for name, flag in zip(dataset.classes, sample['labels'].tolist())
            if flag == 1
        ]
        print(f"Active classes: {active_classes}")