from typing import Any, NamedTuple, TypeAlias

import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST

from research.wsl_ece.metric.ddi2013 import BinarizedDDI2013
from research.wsl_ece.metric.sized_dataset import SizedDataset


class BinarizedMNIST(SizedDataset):
    """
    A binarized MNIST dataset where digits less than ``threshold`` (default ``5``) are
    labelled ``0`` and digits greater than or equal to ``threshold`` are labelled ``1``.

    Args:
        root: Data directory forwarded to :class:`torchvision.datasets.MNIST`.
        train: Load the training split if ``True``, otherwise the test split.
        download: Forwarded to :class:`torchvision.datasets.MNIST`.
        threshold: Smallest digit that is mapped to the positive label ``1``.
    """

    def __init__(self, root: str, train: bool = True, download: bool = True, *, threshold: int = 5):
        self.threshold = threshold
        # NOTE(review): 0.1307 / 0.3081 look like the commonly quoted MNIST pixel mean/std — confirm.
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # A bound method (instead of a lambda) keeps this dataset picklable, which
        # DataLoader workers started with the "spawn" method require.
        self.target_transform = transforms.Lambda(self._binarize_target)
        self.dataset = MNIST(
            root, train=train, transform=self.transform, target_transform=self.target_transform, download=download
        )

    def _binarize_target(self, target: int) -> int:
        """Return ``1`` if ``target`` is at least ``self.threshold``, else ``0``."""
        return int(target >= self.threshold)

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return the transformed ``(image, binary_label)`` pair at ``index``."""
        return self.dataset[index]

    def __len__(self) -> int:
        """Return the number of examples in the selected split."""
        return len(self.dataset)


class BinarizedCIFAR10(SizedDataset):
    """
    A binarized CIFAR-10 dataset: images whose class is in ``positive_classes`` are
    labelled ``1`` and images of every other class are labelled ``0``.

    Args:
        root: Data directory forwarded to :class:`torchvision.datasets.CIFAR10`.
        train: Load the training split (with random flip/crop augmentation) if ``True``,
            otherwise the test split (no augmentation).
        download: Forwarded to :class:`torchvision.datasets.CIFAR10`.
        positive_classes: CIFAR-10 class names (keys of ``CIFAR10.class_to_idx``) that
            are mapped to label ``1``.

    Raises:
        KeyError: If a name in ``positive_classes`` is not a CIFAR-10 class.
    """

    def __init__(
        self,
        root: str,
        train: bool = True,
        download: bool = True,
        positive_classes: set[str] | frozenset[str] = frozenset({"airplane", "automobile", "ship", "truck"}),
    ):
        # NOTE(review): these look like the commonly quoted CIFAR-10 per-channel mean/std — confirm.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616))
        if train:
            self.transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(size=32, padding=int(32 * 0.125), padding_mode="reflect"),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            self.transform = transforms.Compose([transforms.ToTensor(), normalize])

        # Build the dataset once; the class-name -> index mapping is only available
        # after construction, so the label transform is attached afterwards.
        self.dataset = CIFAR10(root, train=train, transform=self.transform, download=download)
        class_to_idx = self.dataset.class_to_idx
        requested = set(positive_classes)
        unknown = sorted(c for c in requested if c not in class_to_idx)
        if unknown:
            raise KeyError(f"Unknown CIFAR-10 class name(s) {unknown}; expected names from {sorted(class_to_idx)}")
        # frozenset gives O(1) membership tests in the per-item label transform.
        self._positive_idx = frozenset(class_to_idx[c] for c in requested)
        # A bound method (instead of a lambda closure) keeps this dataset picklable,
        # which DataLoader workers started with the "spawn" method require.
        self.target_transform = transforms.Lambda(self._binarize_target)
        # torchvision's CIFAR10.__getitem__ applies ``self.target_transform`` at access
        # time, so assigning it here avoids checking and loading the data files a second time.
        # NOTE(review): the combined ``self.dataset.transforms`` object is not rebuilt;
        # CIFAR10.__getitem__ does not use it — re-check when upgrading torchvision.
        self.dataset.target_transform = self.target_transform

    def _binarize_target(self, target: int) -> int:
        """Return ``1`` if ``target`` is one of the positive class indices, else ``0``."""
        return int(target in self._positive_idx)

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return the transformed ``(image, binary_label)`` pair at ``index``."""
        return self.dataset[index]

    def __len__(self) -> int:
        """Return the number of examples in the selected split."""
        return len(self.dataset)


# Type alias for the dataset kinds this module hands out: the project's ``SizedDataset``
# and ``torch.utils.data.Subset`` (used by ``PositiveUnlabeledDatasets.from_dataset`` for
# the positive split). The explicit ``TypeAlias`` annotation tells type checkers that this
# assignment defines an alias rather than an ordinary module-level variable. ``Subset`` is
# reached as an attribute of the ``torch`` package (it is not imported separately), and the
# ``|`` union is evaluated at import time, which requires Python 3.10+.
SizedDatasetType: TypeAlias = SizedDataset | torch.utils.data.Subset


class PositiveUnlabeledDatasets(NamedTuple):
    """
    A named tuple that contains an unlabeled dataset, a positive dataset, and a prior probability.

    Fields:
        unlabeled: The source dataset, unmodified; it still contains the positive examples.
        positive: A subset of ``unlabeled`` made only of examples whose label equals ``1``.
        prior: Fraction of examples in ``unlabeled`` whose label equals ``1``.
    """

    unlabeled: SizedDatasetType
    positive: SizedDatasetType
    prior: float

    @classmethod
    def from_dataset(
        cls, dataset: SizedDataset, num_positive: int, rng: np.random.Generator | None = None
    ) -> "PositiveUnlabeledDatasets":
        """
        Build a positive-unlabeled split from a labelled dataset.

        Every index whose label (``dataset[i][1]``) equals ``1`` is a candidate positive.
        When there are more candidates than ``num_positive``, a random subset is drawn
        without replacement.

        Args:
            dataset: Dataset indexed as ``(input, label)`` pairs; labels are compared against ``1``.
            num_positive: Number of positive examples to put in ``positive``.
            rng: Optional NumPy generator for reproducible sampling. When ``None``, the
                global ``np.random`` state is used (the historical behavior).

        Returns:
            ``PositiveUnlabeledDatasets(dataset, positive_subset, prior)``.

        Raises:
            ValueError: If ``num_positive`` is negative, ``dataset`` is empty, or the dataset
                has fewer than ``num_positive`` positive examples.
        """
        if num_positive < 0:
            raise ValueError(f"num_positive must be non-negative, got {num_positive}.")
        n = len(dataset)
        if n == 0:
            # Guard the prior computation below against division by zero.
            raise ValueError("Cannot build positive-unlabeled datasets from an empty dataset.")
        # NOTE: ``dataset[i]`` runs the dataset's full __getitem__ (including any input
        # transforms) just to read the label, so this scan can be slow on large datasets.
        positive_indices = [i for i in range(n) if dataset[i][1] == 1]
        prior = len(positive_indices) / n
        if len(positive_indices) < num_positive:
            raise ValueError(
                f"Requested {num_positive} positive instances, "
                f"but only {len(positive_indices)} are available in the dataset."
            )
        if len(positive_indices) > num_positive:
            choose = np.random.choice if rng is None else rng.choice
            positive_indices = choose(positive_indices, num_positive, replace=False).tolist()  # type: ignore
        positive = torch.utils.data.Subset(dataset, positive_indices)  # type: ignore
        return cls(dataset, positive, prior)


def load_dataset(
    dataset_name: str, root: str, train: bool = True, download: bool = True, tokenizer=None, **kwargs
) -> SizedDataset:
    if dataset_name == "mnist":
        return BinarizedMNIST(root, train=train, download=download, **kwargs)
    if dataset_name == "cifar10":
        return BinarizedCIFAR10(root, train=train, download=download, **kwargs)
    elif dataset_name == "ddi2013":
        return BinarizedDDI2013(root, train=train, download=download, tokenizer=tokenizer, **kwargs)
    raise ValueError(f"Unknown dataset: {dataset_name}")
