from __future__ import annotations

import warnings
from typing import Any, Sequence

import lightning as L
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.transforms import v2 as transforms


class InterleavedDataset(torch.utils.data.Dataset):
    """Interleaved dataset.

    Item ``i`` is read from dataset ``i % k`` at position ``i // k``, where ``k`` is the number of
    datasets, so consecutive items cycle through the datasets in order.

    Useful to ensure that each batch contains examples from all datasets (assuming the batch size
    is larger than the number of datasets). This assumes all datasets are of the same size; if
    they are not, every dataset is truncated to the length of the smallest one.

    :param datasets: Sequence of (map-style) datasets to interleave.
    :raises ValueError: If ``datasets`` is empty.
    """

    def __init__(self, datasets: Sequence[torch.utils.data.Dataset]) -> None:
        # Fail early with a clear message instead of failing later inside ``min()`` in ``__len__``.
        if len(datasets) == 0:
            raise ValueError(
                f"{self.__class__.__name__} requires at least one dataset."
            )
        reference_len = len(datasets[0])
        if any(len(dataset) != reference_len for dataset in datasets):
            warnings.warn(
                f"{self.__class__.__name__} assumes all datasets are of the same size. "
                "This may lead to data being omitted in the interleaved dataset.",
                stacklevel=2,
            )
        super().__init__()
        self.datasets = datasets

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return the item at position ``index`` of the interleaved sequence.

        :param index: Position in the interleaved dataset; negative values count from the end.
        :raises IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        position = index + length if index < 0 else index
        # Without this check, indices past the end would still resolve into the larger datasets
        # when sizes differ, so iteration would yield more items than ``len(self)`` reports.
        if not 0 <= position < length:
            raise IndexError(
                f"index {index} is out of range for {self.__class__.__name__} "
                f"of length {length}"
            )
        item_index, dataset_index = divmod(position, len(self.datasets))
        return self.datasets[dataset_index][item_index]

    def __len__(self) -> int:
        """Return the size of the smallest dataset times the number of datasets."""
        len_smallest_dataset = min(len(dataset) for dataset in self.datasets)
        return len_smallest_dataset * len(self.datasets)


class _WithOODFlag:
    """Target transform that pairs a target with a fixed in-/out-of-distribution flag.

    Implemented as a module-level callable rather than a lambda so that datasets carrying it
    can be pickled, which data-loader worker processes started with ``spawn`` require.

    :param flag: Value returned as the second element of the transformed target.
    """

    def __init__(self, flag: int) -> None:
        self.flag = flag

    def __call__(self, target: Any) -> tuple[Any, int]:
        return target, self.flag


class OODDataset(L.LightningDataModule):
    """Dataset to benchmark out-of-distribution detection.

    Wraps two data modules and exposes their splits as combined datasets whose targets are
    ``(original_target, flag)`` pairs, with ``flag`` being 0 for in-distribution and 1 for
    out-of-distribution examples.

    Both data modules are expected to expose ``data_train``, ``data_val`` and ``data_test``
    after ``setup``, plus the loader settings read by the ``*_dataloader`` methods
    (``batch_size``, ``batch_size_test``, ``pin_memory``, ``num_workers``,
    ``persistent_workers`` and ``generator``).

    :param id_dataset: In-distribution dataset.
    :param ood_dataset: Out-of-distribution dataset.
    """

    def __init__(
        self,
        id_dataset: L.LightningDataModule,
        ood_dataset: L.LightningDataModule,
    ) -> None:
        super().__init__()
        self.id_dataset = id_dataset
        self.ood_dataset = ood_dataset

    def prepare_data(self) -> None:
        """Run ``prepare_data`` of both wrapped data modules."""
        self.id_dataset.prepare_data()
        self.ood_dataset.prepare_data()

    @staticmethod
    def _flag_targets(
        dataset: torch.utils.data.Dataset, flag: int
    ) -> torch.utils.data.Dataset:
        """Replace ``dataset.target_transform`` so targets become ``(target, flag)``."""
        # NOTE: this overwrites any target transform previously set on the dataset.
        dataset.target_transform = transforms.Lambda(_WithOODFlag(flag))
        return dataset

    def _concat_flagged(
        self,
        id_data: torch.utils.data.Dataset,
        ood_data: torch.utils.data.Dataset,
    ) -> ConcatDataset:
        """Flag both datasets and concatenate them, in-distribution data first."""
        return ConcatDataset(
            [self._flag_targets(id_data, 0), self._flag_targets(ood_data, 1)]
        )

    def setup(self, stage: str | None = None) -> None:
        """Set up both wrapped data modules and build the combined datasets.

        :param stage: ``"fit"`` builds the training and validation sets, ``"validate"`` only
            the validation set, ``"test"`` only the test set, and ``None`` builds all of them.
            Any other value only forwards ``setup`` to the wrapped data modules.
        :raises ZeroDivisionError: If the test stage is set up while the smaller of the two
            test datasets is empty.
        """
        self.id_dataset.setup(stage)
        self.ood_dataset.setup(stage)

        if stage in ("fit", None):
            self.data_train = self._concat_flagged(
                self.id_dataset.data_train, self.ood_dataset.data_train
            )

        # Also built for "validate" so that ``val_dataloader`` works when only validation is
        # run; this relies on the wrapped data modules providing ``data_val`` for that stage.
        if stage in ("fit", "validate", None):
            self.data_val = self._concat_flagged(
                self.id_dataset.data_val, self.ood_dataset.data_val
            )

        if stage in ("test", None):
            id_test = self._flag_targets(self.id_dataset.data_test, 0)
            ood_test = self._flag_targets(self.ood_dataset.data_test, 1)

            # Roughly equal number of in- and out-of-distribution examples: repeat the smaller
            # dataset a whole number of times. Integer floor division keeps the repeat count
            # exact, avoiding float rounding.
            n_id, n_ood = len(id_test), len(ood_test)
            if n_ood > n_id:
                id_test = ConcatDataset([id_test] * (n_ood // n_id))
            else:
                ood_test = ConcatDataset([ood_test] * (n_id // n_ood))

            # Interleave in- and out-of-distribution examples
            self.data_test = InterleavedDataset([id_test, ood_test])

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling loader over the combined training set.

        Loader settings are taken from the in-distribution data module.
        """
        return DataLoader(
            self.data_train,
            batch_size=self.id_dataset.batch_size,
            shuffle=True,
            pin_memory=self.id_dataset.pin_memory,
            num_workers=self.id_dataset.num_workers,
            persistent_workers=self.id_dataset.persistent_workers,
            generator=self.id_dataset.generator,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling loader over the combined validation set.

        Loader settings are taken from the in-distribution data module.
        """
        return DataLoader(
            self.data_val,
            batch_size=self.id_dataset.batch_size_test,
            shuffle=False,
            pin_memory=self.id_dataset.pin_memory,
            num_workers=self.id_dataset.num_workers,
            persistent_workers=self.id_dataset.persistent_workers,
            generator=self.id_dataset.generator,
        )

    def test_dataloader(self) -> DataLoader:
        """Return a non-shuffling loader over the interleaved test set.

        Unlike the training and validation loaders, the settings are taken from the
        out-of-distribution data module.
        """
        return DataLoader(
            self.data_test,
            batch_size=self.ood_dataset.batch_size_test,
            shuffle=False,
            pin_memory=self.ood_dataset.pin_memory,
            num_workers=self.ood_dataset.num_workers,
            persistent_workers=self.ood_dataset.persistent_workers,
            generator=self.ood_dataset.generator,
        )
