import torch
import numpy as np
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
from sklearn.model_selection import LeaveOneOut
from typing import Tuple, List, Optional
import random


class FlattenTransform:
    """Callable transform that returns ``x.flatten()``.

    Meant to be placed after ``transforms.ToTensor()`` inside a
    ``transforms.Compose`` pipeline so each sample comes out one-dimensional
    (e.g. for MLP inputs).
    """

    def __call__(self, x):
        flat = x.flatten()
        return flat

def get_mnist_data(data_dir: str = "data/raw", *, flatten: bool = True) -> Tuple[datasets.MNIST, datasets.MNIST]:
    """Download (if needed) and return the MNIST train and test splits.

    Args:
        data_dir: Root directory handed to ``datasets.MNIST`` for download/storage.
        flatten: When True each sample is flattened to (784,) for MLPs; when
            False it keeps the (1, 28, 28) ``ToTensor`` layout for CNNs.

    Returns:
        ``(train_dataset, test_dataset)``, both sharing the same transform.
    """
    pipeline = [transforms.ToTensor()]
    pipeline += [FlattenTransform()] if flatten else []
    transform = transforms.Compose(pipeline)

    # Built in the same order as before: train split first, then test split.
    train_dataset, test_dataset = (
        datasets.MNIST(data_dir, train=is_train_split, download=True, transform=transform)
        for is_train_split in (True, False)
    )
    return train_dataset, test_dataset


def get_fashion_mnist_data(data_dir: str = "data/raw", *, flatten: bool = False) -> Tuple[datasets.FashionMNIST, datasets.FashionMNIST]:
    """Download (if needed) and return the FashionMNIST train and test splits.

    Args:
        data_dir: Root directory handed to ``datasets.FashionMNIST`` for
            download/storage.
        flatten: When True each sample is flattened to (784,) for MLPs; when
            False (the default here) it keeps the (1, 28, 28) layout for CNNs.

    Returns:
        ``(train_dataset, test_dataset)``, both sharing the same transform.
    """
    if flatten:
        transform = transforms.Compose([transforms.ToTensor(), FlattenTransform()])
    else:
        transform = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.FashionMNIST(
        data_dir, train=True, download=True, transform=transform
    )
    test_dataset = datasets.FashionMNIST(
        data_dir, train=False, download=True, transform=transform
    )
    return train_dataset, test_dataset


def create_leave_one_out_splits(dataset: "datasets.MNIST",
                                train_size: int,
                                leave_one_out_ratio: float,
                                random_seed: int = 42,
                                max_splits: int = 5) -> List[Tuple[Subset, Subset]]:
    """
    Create k-fold style validation splits over a random sample of ``dataset``.

    Despite the name this is not true leave-one-out. ``train_size`` indices
    are sampled without replacement from ``dataset``. They are then cut into
    ``k = min(int(1 / leave_one_out_ratio), max_splits)`` disjoint, contiguous
    validation folds of ``int(train_size * leave_one_out_ratio)`` indices
    each. For every fold, the remaining sampled indices form the training
    subset.

    Note: as a side effect the global ``random``, NumPy and PyTorch RNGs are
    all re-seeded with ``random_seed``.

    Args:
        dataset: Full dataset to sample from (must support ``len``).
        train_size: Number of samples drawn from ``dataset`` (train + val).
        leave_one_out_ratio: Fraction of the sampled data held out in each
            fold; must satisfy ``0 < leave_one_out_ratio < 1``.
        random_seed: Random seed for reproducibility.
        max_splits: Upper bound on the number of folds; must be >= 1.

    Returns:
        List of ``(train_subset, val_subset)`` tuples, one per fold.

    Raises:
        ValueError: If ``leave_one_out_ratio`` is outside ``(0, 1)``,
            ``max_splits < 1``, ``train_size > len(dataset)``, or the
            validation fold size ``int(train_size * leave_one_out_ratio)``
            is smaller than 1.
    """
    if not 0 < leave_one_out_ratio < 1:
        raise ValueError(
            f"leave_one_out_ratio must be in (0, 1), got {leave_one_out_ratio}"
        )
    if max_splits < 1:
        raise ValueError(f"max_splits must be >= 1, got {max_splits}")
    if train_size > len(dataset):
        raise ValueError(
            f"train_size ({train_size}) exceeds dataset size ({len(dataset)})"
        )

    val_size = int(train_size * leave_one_out_ratio)
    if val_size < 1:
        # Without this check every fold would get an empty validation set.
        raise ValueError(
            f"train_size={train_size} with leave_one_out_ratio={leave_one_out_ratio} "
            "yields an empty validation fold"
        )

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # Sample train_size indices from the dataset (same shuffle as before, so
    # previously saved splits remain reproducible).
    indices = list(range(len(dataset)))
    random.shuffle(indices)
    selected_indices = indices[:train_size]

    # k-fold instead of true leave-one-out for efficiency.
    k = min(max(1, int(1 / leave_one_out_ratio)), max_splits)
    splits = []

    for fold in range(k):
        val_start = fold * val_size
        if val_start >= train_size:
            break
        val_end = min(val_start + val_size, train_size)

        # The fold is a contiguous slice of unique indices, so the training
        # part is simply everything before and after it (order preserved).
        val_indices = selected_indices[val_start:val_end]
        train_indices = selected_indices[:val_start] + selected_indices[val_end:]

        splits.append((Subset(dataset, train_indices), Subset(dataset, val_indices)))

    return splits


def create_data_loaders(train_dataset: Subset,
                        val_dataset: Subset,
                        test_dataset: datasets.MNIST,
                        batch_size: int,
                        num_workers: int = 0) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Wrap the training, validation and test datasets in DataLoaders.

    All three loaders share ``batch_size`` and ``num_workers``. Only the
    training loader shuffles; the validation and test loaders iterate in
    dataset order.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    # persistent_workers is pinned to False to avoid multiprocessing issues.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        persistent_workers=False,
    )

    train_loader = DataLoader(train_dataset, shuffle=True, **shared_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_options)

    return train_loader, val_loader, test_loader


def save_data_splits(splits: List[Tuple[Subset, Subset]],
                     save_dir: str = "data/splits") -> None:
    """Save data split indices for reproducibility.

    Writes one ``split_<i>.json`` file per split into ``save_dir``. Each file
    contains ``{"train_indices": [...], "val_indices": [...]}`` as plain
    Python ints.

    Args:
        splits: ``(train_subset, val_subset)`` pairs.
        save_dir: Output directory; created if it does not exist.
    """
    import json
    import os

    os.makedirs(save_dir, exist_ok=True)

    for i, (train_subset, val_subset) in enumerate(splits):
        # Subset.indices may be a plain list (as built by
        # create_leave_one_out_splits), a NumPy array, or a tensor. Plain
        # lists have no .tolist(), so convert element-wise. int() also turns
        # NumPy/tensor scalars into JSON-serializable ints.
        split_data = {
            'train_indices': [int(idx) for idx in train_subset.indices],
            'val_indices': [int(idx) for idx in val_subset.indices],
        }

        split_path = os.path.join(save_dir, f"split_{i}.json")
        with open(split_path, 'w', encoding='utf-8') as f:
            json.dump(split_data, f)


def save_experiment_data_splits(splits: List[Tuple[Subset, Subset]],
                                experiment_dir: str,
                                config: dict) -> None:
    """Save data split indices and a metadata summary for one experiment.

    Writes into ``<experiment_dir>/data_splits/``:

    * ``split_<i>.json``: ``split_idx``, ``train_indices``, ``val_indices``,
      ``train_size`` and ``val_size`` for split ``i``.
    * ``experiment_metadata.json``: dataset name, ``train_size``,
      ``leave_one_out_ratio``, ``random_seed``, ``num_splits``, and the
      per-split sizes.

    Args:
        splits: ``(train_subset, val_subset)`` pairs.
        experiment_dir: Experiment root directory.
        config: Experiment config. It must provide
            ``config['data']['dataset']``, ``config['data']['train_size']``,
            ``config['data']['leave_one_out_ratio']`` and
            ``config['validation']['random_seed']``.

    Raises:
        KeyError: If ``config`` lacks one of the keys above. This is raised
            before anything is written to disk.
    """
    import json
    import os

    # Read config first so a malformed config does not leave an empty
    # directory behind.
    experiment_metadata = {
        'dataset': config['data']['dataset'],
        'train_size': config['data']['train_size'],
        'leave_one_out_ratio': config['data']['leave_one_out_ratio'],
        'random_seed': config['validation']['random_seed'],
        'num_splits': len(splits),
        'split_info': []
    }

    # Create data splits directory within experiment
    data_splits_dir = os.path.join(experiment_dir, 'data_splits')
    os.makedirs(data_splits_dir, exist_ok=True)

    for i, (train_subset, val_subset) in enumerate(splits):
        # Subset.indices may be a list, NumPy array or tensor; NumPy arrays,
        # tensors and their scalars are not JSON-serializable, so normalize
        # them to plain ints.
        split_data = {
            'split_idx': i,
            'train_indices': [int(idx) for idx in train_subset.indices],
            'val_indices': [int(idx) for idx in val_subset.indices],
            'train_size': len(train_subset),
            'val_size': len(val_subset)
        }

        # Save individual split
        split_file = os.path.join(data_splits_dir, f'split_{i}.json')
        with open(split_file, 'w', encoding='utf-8') as f:
            json.dump(split_data, f, indent=2)

        experiment_metadata['split_info'].append({
            'split_idx': i,
            'train_size': len(train_subset),
            'val_size': len(val_subset)
        })

    # Save experiment metadata
    metadata_file = os.path.join(data_splits_dir, 'experiment_metadata.json')
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(experiment_metadata, f, indent=2)

    print(f"Data splits saved to: {data_splits_dir}")


def load_data_split(split_file: str) -> Tuple[List[int], List[int]]:
    """Read a split JSON file and return ``(train_indices, val_indices)``.

    Args:
        split_file: Path to a JSON file containing ``train_indices`` and
            ``val_indices`` keys.

    Returns:
        The two index lists exactly as stored in the file.

    Raises:
        KeyError: If either key is missing.
    """
    import json

    with open(split_file) as handle:
        payload = json.load(handle)

    train_indices = payload['train_indices']
    val_indices = payload['val_indices']
    return train_indices, val_indices