import abc
import os
import pathlib

import filelock
import numpy as np
import torch
import torchvision


class DatasetLoader(abc.ABC):
    """Abstract interface for a dataset source.

    Concrete loaders implement a raw-data preparation step plus accessors that return the
    train and validation splits as ``(images, targets)`` tensor pairs, and expose the
    dataset's ``(mean, std)`` as a property. Instantiating a subclass that leaves any of
    these unimplemented raises ``TypeError``.
    """

    @abc.abstractmethod
    def prepare_raw_data(self) -> None:
        """Make the raw dataset files available locally (e.g. by downloading them)."""

    @abc.abstractmethod
    def load_train_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the training split as ``(images, targets)``."""

    @abc.abstractmethod
    def load_val_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the validation split as ``(images, targets)``."""

    @property
    @abc.abstractmethod
    def dataset_mean_std(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` tensors describing the dataset's image statistics."""


class CIFAR10Loader(DatasetLoader):
    """CIFAR-10 via torchvision: channels-first float32 images scaled by 1/255 and int64 labels."""

    def prepare_raw_data(self) -> None:
        """Ensure the CIFAR-10 raw files exist under the data root (torchvision downloads them).

        The download runs under a file lock stored *inside* the data root. Every process that
        shares the same data root therefore contends on the same lock, regardless of its
        current working directory.
        """
        data_root = _get_data_root()
        # The lock file lives in the data root, so the directory has to exist first.
        data_root.mkdir(parents=True, exist_ok=True)
        with filelock.FileLock(str(data_root / "cifar10.lock")):
            torchvision.datasets.CIFAR10(root=data_root, train=True, download=True)

    def load_train_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the CIFAR-10 training split as ``(images, targets)``."""
        return self._load_dataset(train=True)

    def load_val_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the CIFAR-10 test split (used for validation) as ``(images, targets)``."""
        return self._load_dataset(train=False)

    @property
    def dataset_mean_std(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Hard-coded per-channel ``(mean, std)``; presumably computed on the training split."""
        return torch.tensor([0.4914, 0.4822, 0.4465]), torch.tensor([0.2470, 0.2435, 0.2616])

    def _load_dataset(self, train: bool) -> tuple[torch.Tensor, torch.Tensor]:
        """Load one split from the data root without downloading (call ``prepare_raw_data`` first)."""
        raw_dataset = torchvision.datasets.CIFAR10(root=_get_data_root(), train=train, download=False)
        # Raw data is channels-last; permute to (N, C, H, W) and scale (assumes 8-bit pixel values).
        images = torch.tensor(raw_dataset.data).permute(0, 3, 1, 2).to(torch.float32) / 255.0
        targets = torch.as_tensor(raw_dataset.targets, dtype=torch.int64)
        return images, targets


class MNISTLoader(DatasetLoader):
    """MNIST via torchvision: single-channel float32 images scaled by 1/255 and int64 labels."""

    def prepare_raw_data(self) -> None:
        """Ensure the MNIST raw files exist under the data root (torchvision downloads them).

        The download runs under a file lock stored *inside* the data root. Every process that
        shares the same data root therefore contends on the same lock, regardless of its
        current working directory.
        """
        data_root = _get_data_root()
        # The lock file lives in the data root, so the directory has to exist first.
        data_root.mkdir(parents=True, exist_ok=True)
        with filelock.FileLock(str(data_root / "mnist.lock")):
            torchvision.datasets.MNIST(root=data_root, train=True, download=True)

    def load_train_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the MNIST training split as ``(images, targets)``."""
        return self._load_dataset(train=True)

    def load_val_data(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the MNIST test split (used for validation) as ``(images, targets)``."""
        return self._load_dataset(train=False)

    @property
    def dataset_mean_std(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Hard-coded single-channel ``(mean, std)``; presumably computed on the training split."""
        return torch.tensor([0.1307]), torch.tensor([0.3081])

    def _load_dataset(self, train: bool) -> tuple[torch.Tensor, torch.Tensor]:
        """Load one split from the data root without downloading (call ``prepare_raw_data`` first)."""
        raw_dataset = torchvision.datasets.MNIST(root=_get_data_root(), train=train, download=False)
        # Insert a channel dimension at dim 1 and scale (assumes 8-bit pixel values).
        images = raw_dataset.data.unsqueeze(1).to(torch.float32) / 255.0
        # `targets` is already a tensor here; torch.as_tensor avoids torch.tensor's
        # copy-construct warning while still yielding int64 labels.
        targets = torch.as_tensor(raw_dataset.targets, dtype=torch.int64)
        return images, targets


def generate_full_canary_membership_masks(
    num_canaries: int,
    num_non_canaries: int,
    num_models_target: int,
    num_models_shadow: int,
    sample_non_canaries: bool,
    global_seed: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the canary-only slice of the target and shadow membership masks.

    The full masks come from ``generate_full_membership_masks`` with identical arguments, so
    for a given seed the canary membership here matches the canary part of the full masks.

    Returns:
        ``(masks_target, masks_shadow)`` with shapes ``(num_models_target, num_canaries)`` and
        ``(num_models_shadow, num_canaries)``. Entry ``[m, i]`` is True when canary ``i`` is a
        member of model ``m``.
    """
    membership_masks_targets, membership_masks_shadow = generate_full_membership_masks(
        num_canaries=num_canaries,
        num_non_canaries=num_non_canaries,
        num_models_target=num_models_target,
        num_models_shadow=num_models_shadow,
        sample_non_canaries=sample_non_canaries,
        global_seed=global_seed,
    )

    # The full masks are laid out as (num_models, num_samples), and canaries occupy sample
    # indices 0..num_canaries-1. They are therefore the leading *columns*, not rows.
    masks_target = membership_masks_targets[:, :num_canaries]
    masks_shadow = membership_masks_shadow[:, :num_canaries]

    return masks_target, masks_shadow


def generate_full_membership_masks(
    num_canaries: int,
    num_non_canaries: int,
    num_models_target: int,
    num_models_shadow: int,
    sample_non_canaries: bool,
    global_seed: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample membership masks over all samples for the target and the shadow models.

    Two child generators are spawned from ``global_seed``: the first drives the target masks
    and the second drives the shadow masks. Each mask is produced by
    ``_generate_membership_masks``.

    Returns:
        ``(target_masks, shadow_masks)`` with shapes ``(num_models_target, N)`` and
        ``(num_models_shadow, N)``, where ``N = num_canaries + num_non_canaries``.
    """
    target_rng, shadow_rng = np.random.default_rng(global_seed).spawn(2)

    def masks_for(num_models: int, rng: np.random.Generator) -> torch.Tensor:
        # All sampling settings except the model count and RNG are shared between groups.
        return _generate_membership_masks(
            num_canaries=num_canaries,
            num_non_canaries=num_non_canaries,
            num_models=num_models,
            sample_non_canaries=sample_non_canaries,
            rng=rng,
        )

    target_masks = masks_for(num_models_target, target_rng)
    shadow_masks = masks_for(num_models_shadow, shadow_rng)
    return target_masks, shadow_masks


def _generate_membership_masks(
    num_canaries: int,
    num_non_canaries: int,
    num_models: int,
    sample_non_canaries: bool,
    rng: np.random.Generator,
) -> torch.Tensor:
    """Sample a boolean membership matrix of shape ``(num_models, num_canaries + num_non_canaries)``.

    Columns ``0..num_canaries-1`` are canaries and the remaining columns are non-canaries.
    Entry ``[m, i]`` is True when sample ``i`` is a member of model ``m``.
    - Every canary is a member of exactly ``num_models // 2`` models.
    - With ``sample_non_canaries`` set, non-canaries are drawn the same way (after the
      canaries, from the same ``rng``). Otherwise every non-canary is in every model.

    Raises:
        ValueError: If ``num_models`` is odd.
    """
    # Each sampled sample is IN for exactly half of the models, which requires an even model
    # count (consequently, individual models may see different numbers of training samples).
    if num_models % 2 != 0:
        raise ValueError(f"Number of models must be even but is {num_models}")

    half = num_models // 2

    def sample_half_membership(num_rows: int) -> np.ndarray:
        # Ranking i.i.d. uniforms gives a random ordering of models per row; the first `half`
        # models in that ordering contain the sample.
        chosen_models = np.argsort(rng.uniform(size=(num_rows, num_models)), axis=1)[:, :half]
        rows = np.zeros((num_rows, num_models), dtype=bool)
        np.put_along_axis(rows, chosen_models, values=True, axis=1)
        assert np.all(rows.sum(axis=1) == half)
        return rows

    canary_rows = sample_half_membership(num_canaries)

    if sample_non_canaries:
        non_canary_rows = sample_half_membership(num_non_canaries)
    else:
        # All non-canaries are in all models
        non_canary_rows = np.ones((num_non_canaries, num_models), dtype=bool)
        assert np.all(non_canary_rows.sum(axis=1) == num_models)

    per_sample = np.concatenate([canary_rows, non_canary_rows], axis=0)
    return torch.from_numpy(per_sample.T)


def build_train_data(
    train_images: torch.Tensor,
    train_targets: torch.Tensor,
    membership_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    
    if membership_mask is None:
        images = train_images.contiguous()
        targets = train_targets.contiguous()
    else:
        assert membership_mask.shape == (train_images.shape[0],)

        images = train_images[membership_mask].contiguous()
        targets = train_targets[membership_mask].contiguous()

    return images, targets


def select_canary_indices(
    num_canaries: int,
    num_samples: int,
    global_seed: int,
    manual_selection: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    rng = np.random.default_rng(global_seed)

    if manual_selection is not None:
        canary_indices = np.array(manual_selection).astype(np.int64)
    else:
        canary_indices = rng.choice(num_samples, size=num_canaries, replace=False)

    non_canary_indices = np.setdiff1d(np.arange(num_samples), canary_indices)

    return torch.from_numpy(canary_indices), torch.from_numpy(non_canary_indices)


def validate_dataset(
    images: torch.Tensor, targets: torch.Tensor, image_shape: tuple[int, int, int], num_samples: int, num_classes: int
) -> None:
    """Check shapes, dtypes and value ranges of an ``(images, targets)`` dataset.

    Args:
        images: Expected shape ``(num_samples, *image_shape)``, dtype float32, values in ``[0, 1]``.
        targets: Expected shape ``(num_samples,)``, dtype int64, values in ``[0, num_classes)``.
        image_shape: Expected per-sample image shape.
        num_samples: Expected number of samples.
        num_classes: Number of classes; every label must be strictly below it.

    Raises:
        ValueError: Describing the first failed check. Unlike ``assert``, these checks still
            run under ``python -O``.
    """
    expected_image_shape = (num_samples, *image_shape)
    if tuple(images.shape) != expected_image_shape:
        raise ValueError(f"images must have shape {expected_image_shape}, got {tuple(images.shape)}")
    if images.dtype != torch.float32:
        raise ValueError(f"images must have dtype torch.float32, got {images.dtype}")
    if tuple(targets.shape) != (num_samples,):
        raise ValueError(f"targets must have shape {(num_samples,)}, got {tuple(targets.shape)}")
    if targets.dtype != torch.int64:
        raise ValueError(f"targets must have dtype torch.int64, got {targets.dtype}")

    # min()/max() raise on empty tensors, so range checks only run when there is data.
    # The checks are phrased as `not (x >= lo and x <= hi)` so NaN values are rejected as well.
    if images.numel() > 0 and not (images.min() >= 0.0 and images.max() <= 1.0):
        raise ValueError("images must have all values in [0, 1]")
    if targets.numel() > 0 and not (targets.min() >= 0 and targets.max() < num_classes):
        raise ValueError(f"targets must have all values in [0, {num_classes})")


def _get_data_root() -> pathlib.Path:
    """Return the directory that holds raw datasets.

    If the ``DATA_ROOT`` environment variable is set to any value (including an empty string),
    that value is used. Otherwise the directory is the ``data`` folder next to this module.
    """
    override = os.environ.get("DATA_ROOT")
    if override is not None:
        return pathlib.Path(override)
    return pathlib.Path(__file__).parent / "data"
