"""
Data Loading Utilities for Multi-Task Learning

This module provides utilities for loading and preprocessing datasets for multi-task learning.
"""

import os
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from typing import Tuple, Optional


def load_face_dataset(
    image_size: int = 64,
    normalize: bool = True,
    data_root: str = './data'
) -> Tuple[datasets.CelebA, datasets.CelebA, datasets.CelebA]:
    """
    Load the CelebA train/valid/test splits for multi-task learning experiments.

    Every image is resized so its shorter edge equals ``image_size`` and then
    center-cropped to ``image_size x image_size``, regardless of ``normalize``.

    Args:
        image_size: Side length of the square output images.
        normalize: If True, map each channel from [0, 1] to [-1, 1]
            (mean 0.5, std 0.5) after ``ToTensor``.
        data_root: Root directory passed to ``datasets.CelebA``; the data is
            not downloaded, so it must already be present there.

    Returns:
        Tuple of (train_dataset, val_dataset, test_dataset).
    """
    steps = [
        transforms.Resize(image_size),
        # Crop in both modes: Resize(int) only matches the shorter edge, so
        # without this the non-normalized images would not be square.
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    transform = transforms.Compose(steps)

    loaded = []
    for split, label in (('train', 'Training'), ('valid', 'Validation'), ('test', 'Test')):
        dataset = datasets.CelebA(
            data_root,
            download=False,
            split=split,
            transform=transform
        )
        print(f"{label} dataset size: {len(dataset)}")
        loaded.append(dataset)

    train_dataset, val_dataset, test_dataset = loaded
    return train_dataset, val_dataset, test_dataset


def create_face_dataloaders(
    train_dataset: Dataset,
    val_dataset: Dataset,
    test_dataset: Dataset,
    task_names: list,
    task_labels: list,
    training_size: int,
    val_size: int,
    test_size: int,
    batch_size: int = 512,
    seed: int = 42,
    mask: Optional[torch.Tensor] = None,
    adversarial: bool = False,
    adversarial_size: float = 0.1,
    no_stratify: bool = False,
    num_workers: int = 4,
) -> Tuple[dict, dict, dict]:
    """
    Create per-task data loaders for a multi-task face-attribute setup.

    For every task, each split is wrapped in a shallow copy of the input
    dataset whose ``attr`` holds only that task's attribute column, reshaped
    to ``(N, 1)``. A subset of indices is then drawn and wrapped in a
    ``DataLoader``. The input datasets themselves are never modified.

    Args:
        train_dataset: Training dataset; must expose a 2-D ``attr`` tensor
            indexed as ``attr[sample, attribute]`` (e.g. torchvision CelebA).
        val_dataset: Validation dataset (same requirement).
        test_dataset: Test dataset (same requirement).
        task_names: Task names; used as keys of the returned dicts.
        task_labels: Attribute column index for each entry of ``task_names``.
        training_size: Number of training samples per task (before ``mask``).
        val_size: Number of validation samples per task.
        test_size: Number of test samples per task.
        batch_size: Batch size for all loaders.
        seed: Seed for all index sampling and label corruption.
        mask: Optional per-task selector; ``mask[task_idx]`` is used to index
            the drawn training indices.
        adversarial: If True, flip a fraction of the training labels.
        adversarial_size: Fraction of the full training split whose label is
            replaced by ``1 - label`` when ``adversarial`` is True.
        no_stratify: If True, draw a uniform random subset instead of
            sampling each class in proportion to its overall frequency.
        num_workers: Worker processes per ``DataLoader``.

    Returns:
        Tuple of (train_loaders, val_loaders, test_loaders), each a dict
        mapping task name to ``DataLoader``.
    """
    import copy

    import numpy as np
    from torch.utils.data import Subset

    def task_view(dataset, task_attr):
        # Shallow copy shares the rest of ``dataset`` but owns its ``attr``,
        # so the caller's dataset object is left untouched.
        view = copy.copy(dataset)
        view.attr = task_attr
        return view

    def samples_per_class(proportions, total):
        # Floor each class share; the rounding remainder goes to the last
        # class so the counts sum to ``total``.
        counts = (proportions * total).astype(int)
        counts[-1] += total - counts.sum()
        return counts

    def stratified_sampling(labels, unique_labels, per_class):
        # Local RandomState(seed): reproducible without reseeding NumPy's
        # global RNG behind the caller's back.
        rng = np.random.RandomState(seed)
        chosen = []
        for label, n_wanted in zip(unique_labels, per_class):
            candidates = np.where(labels == label)[0]
            n_take = min(int(n_wanted), len(candidates))
            chosen.extend(rng.choice(candidates, n_take, replace=False))
        rng.shuffle(chosen)
        # Build int64 directly; going through torch.Tensor (float32) would
        # lose precision for indices above 2**24.
        return torch.as_tensor(np.asarray(chosen, dtype=np.int64))

    def random_subset(n_items, size):
        # Dedicated seeded generator so the unstratified path honours ``seed``.
        generator = torch.Generator().manual_seed(seed)
        return torch.randperm(n_items, generator=generator)[:size]

    def make_loader(dataset, indices):
        return DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
        )

    train_loaders = {}
    val_loaders = {}
    test_loaders = {}

    for task_idx, task_name in enumerate(task_names):
        label_idx = task_labels[task_idx]
        # clone() detaches each column from the caller's tensor so the label
        # flipping below cannot write back into it.
        train_attr = train_dataset.attr[:, label_idx].clone().reshape(-1, 1)
        val_attr = val_dataset.attr[:, label_idx].clone().reshape(-1, 1)
        test_attr = test_dataset.attr[:, label_idx].clone().reshape(-1, 1)

        if adversarial:
            rng = np.random.RandomState(seed)
            n_train = len(train_attr)
            corrupt = torch.as_tensor(
                rng.choice(n_train, int(adversarial_size * n_train), replace=False),
                dtype=torch.long,
            )
            train_attr[corrupt] = 1 - train_attr[corrupt]

        # reshape(-1) rather than squeeze() so a one-sample split stays 1-D.
        train_labels = train_attr.reshape(-1).numpy()
        val_labels = val_attr.reshape(-1).numpy()
        test_labels = test_attr.reshape(-1).numpy()

        # Class proportions over all three splits (after any label flipping).
        all_labels = np.concatenate([train_labels, val_labels, test_labels])
        unique_labels, counts = np.unique(all_labels, return_counts=True)
        proportions = counts / counts.sum()

        if no_stratify:
            train_indices = random_subset(len(train_attr), training_size)
            val_indices = random_subset(len(val_attr), val_size)
            test_indices = random_subset(len(test_attr), test_size)
        else:
            train_indices = stratified_sampling(
                train_labels, unique_labels, samples_per_class(proportions, training_size)
            )
            val_indices = stratified_sampling(
                val_labels, unique_labels, samples_per_class(proportions, val_size)
            )
            test_indices = stratified_sampling(
                test_labels, unique_labels, samples_per_class(proportions, test_size)
            )

        if mask is not None:
            train_indices = train_indices[mask[task_idx]]

        train_loaders[task_name] = make_loader(task_view(train_dataset, train_attr), train_indices)
        val_loaders[task_name] = make_loader(task_view(val_dataset, val_attr), val_indices)
        test_loaders[task_name] = make_loader(task_view(test_dataset, test_attr), test_indices)

    return train_loaders, val_loaders, test_loaders

