import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split
from typing import Optional, List, Union, Tuple, Dict
import random
import matplotlib.pyplot as plt


class LabelConditionalInversion(torch.utils.data.Dataset):
    """
    MNIST wrapper that inverts the images of one chosen label.

    Every sample is fetched from an underlying torchvision MNIST dataset.
    Samples whose label equals ``label_to_invert`` are returned as
    ``1.0 - image``; all other samples are passed through untouched.

    NOTE(review): the inversion is plain arithmetic on the transformed sample,
    so it presumably expects ``transform`` to yield a tensor scaled to [0, 1]
    (e.g. via ``ToTensor``) -- confirm against callers.

    Attributes:
        label_to_invert (int): Label whose images get inverted.
        base_dataset (torchvision.datasets.MNIST): The wrapped MNIST dataset.
    """

    def __init__(self, root, label_to_invert, train, download, transform):
        self.label_to_invert = label_to_invert
        mnist_kwargs = dict(root=root, train=train, download=download, transform=transform)
        self.base_dataset = torchvision.datasets.MNIST(**mnist_kwargs)

    def __len__(self):
        # The wrapper exposes exactly the samples of the wrapped dataset.
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample, target = self.base_dataset[idx]

        # Pass through everything that is not the designated label.
        if target != self.label_to_invert:
            return sample, target

        return 1.0 - sample, target

    @property
    def targets(self):
        # Forward the label container so code that reads `.targets`
        # keeps working on the wrapper.
        return self.base_dataset.targets


# --- 1. Data Loading Function ---
def load_mnist_data(
    root_dir: str = './data',
    selected_labels: Optional[List[int]] = None,
    num_samples_per_label: Optional[Union[int, List[int]]] = None,
    train: bool = True,
    dataset_name: str = 'mnist'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Loads the MNIST or Fashion MNIST dataset with options for label filtering and sample limiting.

    Images are converted to 3-channel grayscale tensors with values in [0, 1]
    (no mean/std normalization is applied). All arguments are validated
    before the dataset is downloaded or read, so invalid input fails fast.

    Per-label sampling draws from Python's global ``random`` module; call
    ``random.seed(...)`` beforehand for reproducible subsets.

    Args:
        root_dir (str): Directory to download/load dataset.
        selected_labels (Optional[List[int]]): A list of class labels (0-9) to load.
                                              If None or empty, loads all labels.
        num_samples_per_label (Optional[Union[int, List[int]]]):
            - If None: Loads all available samples for the selected labels.
            - If int (N): Loads up to N samples for EACH selected label.
            - If List[int]: Loads the specified number of samples for each
                            corresponding label in selected_labels. Must have the
                            same length as selected_labels (length 10 when no
                            labels are selected).
        train (bool): If True, loads the training set. Otherwise, loads the test set.
        dataset_name (str): Name of dataset to load. Either 'mnist' or 'fashion_mnist'
                            (case-insensitive).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - images (torch.Tensor): Tensor of loaded images (N, 3, 28, 28).
            - labels (torch.Tensor): Tensor of corresponding labels (N,), dtype long.
            When nothing matches, empty tensors of shape (0, 3, 28, 28) and (0,)
            are returned.

    Raises:
        ValueError: If num_samples_per_label is a list but its length doesn't
                    match selected_labels.
        ValueError: If selected_labels contains invalid labels (not 0-9).
        TypeError: If num_samples_per_label has an invalid type.
        ValueError: If dataset_name is not 'mnist' or 'fashion_mnist'.
    """
    dataset_name = dataset_name.lower()
    if dataset_name not in ('mnist', 'fashion_mnist'):
        raise ValueError("dataset_name must be either 'mnist' or 'fashion_mnist'")

    print(f"--- Loading {dataset_name.upper()} {'Train' if train else 'Test'} Data ---")

    # --- Argument validation (done up front, before any download / disk I/O) ---
    filter_by_label = selected_labels is not None and len(selected_labels) > 0
    if filter_by_label:
        if not all(0 <= lbl <= 9 for lbl in selected_labels):
            raise ValueError("selected_labels contains invalid MNIST labels (must be 0-9).")
        labels_in_order = list(selected_labels)
    else:
        # Consider all labels for the sampling step.
        labels_in_order = list(range(10))

    # Normalize num_samples_per_label into one requested count per label.
    if num_samples_per_label is None:
        requested_counts = None
    elif isinstance(num_samples_per_label, int):
        requested_counts = [num_samples_per_label] * len(labels_in_order)
    elif isinstance(num_samples_per_label, list):
        if len(num_samples_per_label) != len(labels_in_order):
            raise ValueError(
                "Length of num_samples_per_label list must match length of selected_labels."
            )
        requested_counts = list(num_samples_per_label)
    else:
        raise TypeError("num_samples_per_label must be None, an integer, or a list of integers.")

    # --- Dataset construction ---
    # The channel count is shared by the transform and the empty-result shape
    # so the two can never disagree.
    num_channels = 3
    empty_image_shape = (0, num_channels, 28, 28)

    # Normalization often helps training, but for visualization and noise
    # addition the pixel values are deliberately kept in [0, 1].
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=num_channels),
        transforms.ToTensor(),
    ])

    # NOTE: to experiment with label-specific inversion, the MNIST branch can be
    # swapped for LabelConditionalInversion(root=..., label_to_invert=..., ...).
    if dataset_name == 'mnist':
        dataset_cls = torchvision.datasets.MNIST
    else:  # 'fashion_mnist'
        dataset_cls = torchvision.datasets.FashionMNIST
    full_dataset = dataset_cls(root=root_dir, train=train, download=True, transform=transform)

    # --- Label lookup ---
    # Labels are only needed for filtering or per-label sampling. Reading them
    # from `targets` avoids decoding/transforming every image just for its label.
    targets = None
    if filter_by_label or requested_counts is not None:
        try:
            targets = full_dataset.targets
        except AttributeError:  # Fallback for datasets that do not expose `targets`
            print("Attribute 'targets' not found, iterating through dataset (slower)...")
            targets = [label for _, label in full_dataset]
        if isinstance(targets, torch.Tensor):
            targets = targets.tolist()
        else:
            targets = list(targets)

    # --- Label Filtering ---
    if filter_by_label:
        print(f"Filtering for labels: {selected_labels}")
        wanted_labels = set(labels_in_order)
        indices_to_keep = [i for i, label in enumerate(targets) if label in wanted_labels]
        if not indices_to_keep:
            print("Warning: No samples found for the selected labels.")
            return torch.empty(empty_image_shape), torch.empty((0,), dtype=torch.long)
        print(f"Found {len(indices_to_keep)} samples matching selected labels.")
    else:
        print("No specific labels selected, loading all available data.")
        indices_to_keep = list(range(len(full_dataset)))

    # --- Sample Number Filtering ---
    if requested_counts is not None:
        print(f"Sampling based on num_samples_per_label: {num_samples_per_label}")

        # Group the surviving dataset indices by label (dataset order preserved).
        indices_by_label: Dict[int, List[int]] = {lbl: [] for lbl in labels_in_order}
        for idx in indices_to_keep:
            bucket = indices_by_label.get(targets[idx])
            if bucket is not None:
                bucket.append(idx)

        if isinstance(num_samples_per_label, int):
            print(f"Selecting up to {num_samples_per_label} samples per selected label...")
        else:
            print(f"Selecting specific counts per label: {dict(zip(labels_in_order, requested_counts))}")

        chosen_indices: List[int] = []
        for label, requested in zip(labels_in_order, requested_counts):
            available_indices = indices_by_label[label]
            num_to_select = min(requested, len(available_indices))
            if num_to_select < requested:
                print(f"Warning: Only found {num_to_select} samples for label {label} (requested {requested}).")
            if num_to_select > 0:
                chosen_indices.extend(random.sample(available_indices, num_to_select))

        if not chosen_indices:
            print("Warning: No samples selected after applying num_samples_per_label.")
            return torch.empty(empty_image_shape), torch.empty((0,), dtype=torch.long)

        sampled_dataset = Subset(full_dataset, chosen_indices)
        print(f"Selected a total of {len(sampled_dataset)} samples.")
    else:
        # No sampling per label needed, use the label-filtered dataset
        print(f"Loading all {len(indices_to_keep)} samples for the selected labels.")
        sampled_dataset = Subset(full_dataset, indices_to_keep)

    # --- Extract Tensors ---
    # Single pass: each __getitem__ call decodes and transforms the image, so
    # iterating twice (once for images, once for labels) would double the cost.
    image_list = []
    label_list = []
    for img, lbl in sampled_dataset:
        image_list.append(img)
        label_list.append(lbl)

    if not image_list:
        # torch.stack rejects an empty list; return a well-formed empty result.
        return torch.empty(empty_image_shape), torch.empty((0,), dtype=torch.long)

    images = torch.stack(image_list, dim=0)
    labels = torch.tensor(label_list, dtype=torch.long)

    print(f"Final loaded data shape: Images - {images.shape}, Labels - {labels.shape}")
    print("-" * 30)

    return images, labels

# --- 2. Stratified Splitting Function ---
def split_data_stratified(
    images: torch.Tensor,
    labels: torch.Tensor,
    random_state: int,
    target_split_ratio: float = 0.2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Splits the data into source and target sets using stratified sampling.

    The split is stratified on ``labels`` via scikit-learn's
    ``train_test_split``. If stratification is impossible (e.g. a class with a
    single sample), a plain random split with the same seed is used instead.

    Boundary handling:
        - An empty input yields four empty tensors.
        - A float ratio of exactly 0.0 puts every sample in the source set and
          a ratio of exactly 1.0 puts every sample in the target set. These
          cases bypass ``train_test_split`` (which rejects them) and keep the
          original sample order.

    Args:
        images (torch.Tensor): Input image tensor; the first dimension indexes samples.
        labels (torch.Tensor): Input label tensor; the first dimension indexes samples.
        random_state (int): Seed for reproducibility.
        target_split_ratio (float): Proportion of data to put in the target set (0.0 to 1.0).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            - source_images, source_labels, target_images, target_labels

    Raises:
        ValueError: Propagated from ``train_test_split`` when even the
                    non-stratified fallback cannot perform the split
                    (e.g. a ratio outside [0, 1]).
    """
    print(f"--- Splitting Data (Target Ratio: {target_split_ratio:.2f}) ---")

    # Check for empty dataset
    if images.shape[0] == 0:
        print("Warning: Cannot split empty dataset.")
        # Separate tensors for source and target so callers never receive
        # two names aliasing the same object.
        return (
            torch.empty_like(images),
            torch.empty_like(labels),
            torch.empty_like(images),
            torch.empty_like(labels),
        )

    # Only float endpoints are special-cased: an integer test_size means an
    # absolute sample count to train_test_split, which must keep that meaning.
    is_float_ratio = isinstance(target_split_ratio, float)

    if is_float_ratio and target_split_ratio == 0.0:
        # Everything stays in the source set.
        source_images, source_labels = images.clone(), labels.clone()
        target_images, target_labels = images[:0].clone(), labels[:0].clone()
    elif is_float_ratio and target_split_ratio == 1.0:
        # Everything moves to the target set.
        source_images, source_labels = images[:0].clone(), labels[:0].clone()
        target_images, target_labels = images.clone(), labels.clone()
    else:
        # train_test_split needs indices or numpy arrays
        indices = np.arange(images.shape[0])
        labels_np = labels.cpu().numpy()  # Stratify works best on numpy array

        # Split: separate source and target data
        try:
            source_indices, target_indices = train_test_split(
                indices,
                test_size=target_split_ratio,
                stratify=labels_np,
                random_state=random_state
            )
        except ValueError as e:
            print(f"Warning: Stratified split failed ({e}). This might happen if a class "
                  f"has fewer samples than required for splitting. Falling back to non-stratified split.")
            # Fallback to non-stratified split if stratification isn't possible
            # (e.g., only 1 sample of a certain class)
            source_indices, target_indices = train_test_split(
                indices,
                test_size=target_split_ratio,
                random_state=random_state
            )

        # Create the final tensors
        source_images = images[source_indices]
        source_labels = labels[source_indices]
        target_images = images[target_indices]
        target_labels = labels[target_indices]

    # Print summary
    print(f"Source data shape: Images - {source_images.shape}, Labels - {source_labels.shape}")
    print(f"Target data shape: Images - {target_images.shape}, Labels - {target_labels.shape}")
    print("-" * 30)

    return source_images, source_labels, target_images, target_labels

# --- 3. Batching Function ---
def create_dataloaders(
    images: torch.Tensor,
    labels: Optional[torch.Tensor] = None,
    batch_size: int = 64,
    shuffle: bool = True,
    drop_last: bool = False,
    num_workers: int = 0
) -> DataLoader:
    """
    Wraps an image tensor (and, optionally, a second tensor) in a DataLoader.

    Args:
        images (torch.Tensor): Input image tensor; its first dimension is the sample axis.
        labels (Optional[torch.Tensor]): Optional second tensor paired with ``images``
            sample by sample. When given, each batch is ``(images, labels)``;
            otherwise each batch holds the images only.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Whether to shuffle data each epoch.
        drop_last (bool): Whether to drop the final batch when it is smaller
                         than ``batch_size``.
        num_workers (int): Number of subprocesses used for data loading.
                         0 loads the data in the main process.

    Returns:
        DataLoader: Loader over the given tensors. For an empty ``images`` tensor
        the loader is built with DataLoader defaults (only ``num_workers`` is
        forwarded), so ``batch_size``, ``shuffle`` and ``drop_last`` are ignored.
    """
    print(f"--- Creating DataLoader (Batch Size: {batch_size}, Workers: {num_workers}) ---")

    with_labels = labels is not None
    tensors = (images, labels) if with_labels else (images,)

    if images.shape[0] == 0:
        # Empty input: build a default-configured loader that yields no batches.
        print("Dataset is empty, creating an empty DataLoader.")
        loader = DataLoader(TensorDataset(*tensors), num_workers=num_workers)
    else:
        dataset = TensorDataset(*tensors)
        if with_labels:
            print("DataLoader created with images and labels.")
        else:
            print("DataLoader created with images only.")

        loader_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_workers,
        )
        loader = DataLoader(dataset, **loader_options)
        print(f"DataLoader created with {len(loader)} batches.")

    print("-" * 30)
    return loader