import numpy as np
import sys
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from PIL import Image
from torch.utils.data import Dataset


class ReducedImageDataset(Dataset):
    """Keep only the samples whose label lies in ``0 .. n_classes - 1``.

    Used to reduce a multi-class image dataset to a smaller (by default binary)
    classification task. The wrapped dataset must expose ``data``, ``targets``,
    ``transform`` and ``target_transform`` attributes (as torchvision's CIFAR
    datasets do); a ``torch.utils.data.Subset`` does not expose these.
    """

    def __init__(self, dataset, n_classes=2) -> None:
        super().__init__()
        if not hasattr(dataset, "data") or not hasattr(dataset, "targets"):
            sys.exit(
                "[ReducedImageDataset]: Passed in dataset is missing data and/or targets attributes."
            )
        if len(np.unique(dataset.targets)) < n_classes:
            sys.exit(
                "[ReducedImageDataset]: Requested number of classes that is greater than underlying dataset's number of classes."
            )

        class_labels = set(range(n_classes))
        self.transform = dataset.transform
        self.target_transform = dataset.target_transform
        self.data = []
        self.targets = []
        for i in range(len(dataset)):
            label = dataset.targets[i]
            # int() is required for 0-d tensor labels (e.g. torchvision MNIST
            # targets): torch.Tensor hashes by identity, so `tensor(0) in {0, 1}`
            # is False and every sample would otherwise be dropped.
            if int(label) in class_labels:
                self.data.append(dataset.data[i])
                # The original label object is kept unchanged.
                self.targets.append(label)

    def __getitem__(self, index):
        """Return ``(img, target)`` with the wrapped dataset's transforms applied."""
        # Adapted from torchvision CIFAR10 image loading.
        img, target = self.data[index], self.targets[index]
        # np.asarray is a no-op for ndarray images and lets tensor-backed images
        # (e.g. MNIST's uint8 tensors) be converted to PIL as well.
        img = Image.fromarray(np.asarray(img))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

class PatchedImageDataset(Dataset):
    """Stamp a class-specific patch into the top-left corner of every image.

    Each class ``y`` gets a random ``patch_width x patch_width x 3`` patch
    (values in ``[0, 256)``) whose element ``[0, 0, 0]`` is overwritten with
    ``y``. Patches of different classes are therefore guaranteed to differ in
    that one value. The patch is written into ``data[i][:patch_width, :patch_width, :]``,
    so ``dataset.data`` is assumed to be indexable as (N, H, W, C) — TODO confirm
    for non-CIFAR inputs. The wrapped dataset is not modified (data is copied).
    """

    def __init__(self, dataset, patch_width=1) -> None:
        super().__init__()

        targets = np.copy(dataset.targets)
        # Size the patch table by the largest label, not the number of distinct
        # labels: patches are indexed by label value, so a dataset missing some
        # class (e.g. labels {0, 2}) would otherwise index past the table's end.
        n_classes = int(targets.max()) + 1 if targets.size else 0
        if n_classes > 256:
            sys.exit("[PatchedImageDataset]: Current patching scheme only works for <= 256 classes.")

        self.patch_width = patch_width
        self.class_patches = np.random.randint(0, 256, (n_classes, patch_width, patch_width, 3))
        # We only differentiate the patches by a single color coordinate.
        self.class_patches[:, 0, 0, 0] = np.arange(n_classes)

        self.transform = dataset.transform
        self.target_transform = dataset.target_transform
        self.data = np.copy(dataset.data)
        self.targets = targets
        if len(self.data):
            # Vectorized: every image receives the patch of its own class.
            self.data[:, :patch_width, :patch_width, :] = self.class_patches[self.targets]

    def __getitem__(self, index):
        """Return ``(img, target)`` with the wrapped dataset's transforms applied."""
        # Adapted from torchvision CIFAR10 image loading.
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)
    

class BinaryModImageDataset(Dataset):
    """Encode a binary label directly into the first element of each input tensor.

    For every ``(x, y)`` drawn from ``dataset``, a copy of ``x`` has element
    ``[0, 0, 0]`` replaced by ``(2*y - 1) * patch_scale``. When
    ``low_var > 0``, an extra ``(2*y - 1) * low_var * U`` is added, with ``U``
    drawn from ``np.random.rand()``. The sign flip ``2*y - 1`` only separates two
    labels, so ``y`` is presumably 0/1 — this wrapper is for binary tasks only.
    Targets are stored as float32 arrays of shape ``(1,)``.
    """

    def __init__(self, dataset, patch_scale=1e-3, low_var=0) -> None:
        super().__init__()

        self.data = []
        self.targets = []
        for image, label in dataset:
            sign = 2 * label - 1
            patched = image.clone()  # never mutate the wrapped dataset's tensors
            patched[0, 0, 0] = sign * patch_scale
            if low_var > 0:
                patched[0, 0, 0] += sign * low_var * np.random.rand()
            self.data.append(patched)
            self.targets.append(np.array([label], dtype=np.float32))

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)
    

class MultiModImageDataset(Dataset):
    """Multi-class counterpart of ``BinaryModImageDataset``.

    One Gaussian feature of shape ``(3, patch_height, patch_width)`` is drawn
    per class. For each ``(x, y)``, a copy of ``x`` has its top-left
    ``[:, :patch_height, :patch_width]`` region overwritten by
    ``patch_scale * feature[y]``. The random per-class features are only
    approximately orthogonal. The leading 3 assumes 3-channel image tensors.
    """

    def __init__(self, dataset, patch_scale=1e-3, patch_height=32, patch_width=16, n_classes=10) -> None:
        super().__init__()

        class_features = torch.randn(n_classes, 3, patch_height, patch_width)
        stamped = []
        for image, label in dataset:
            image_copy = image.clone()
            image_copy[:, :patch_height, :patch_width] = patch_scale * class_features[label]
            stamped.append((image_copy, label))
        self.data = [img for img, _ in stamped]
        self.targets = [lbl for _, lbl in stamped]

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)
    

class ConcatImageDataset(Dataset):
    """Append class-dependent feature channels to every input.

    Each stored input is ``torch.cat([x, f_y + noise], dim=0)``, where:

    * ``f_y`` is a per-class tensor with the same shape as ``x``, drawn once as
      ``feature_scale * torch.randn``.
    * ``noise`` is a fresh ``low_var * torch.randn`` per sample. It is added
      only when ``low_var > 0`` and ``test_data`` is False.

    With ``test_data=True`` the class features are all zeros and no noise is
    added, so the extra channels carry no label information. ``n_classes`` must
    exceed every label in ``dataset``. An empty ``dataset`` yields an empty
    result.
    """

    def __init__(self, dataset, feature_scale=1e-3, n_classes=10, test_data=False, low_var=1e-2) -> None:
        super().__init__()

        self.data, self.targets = [], []
        first = next(iter(dataset), None)
        if first is None:
            # Nothing to infer the feature shape from; keep the dataset empty.
            return
        data_shape = first[0].shape

        if test_data:
            features = torch.zeros(n_classes, *data_shape)
        else:
            features = feature_scale * torch.randn(n_classes, *data_shape)

        for x, y in dataset:
            concat_feature = features[y]
            if low_var > 0 and not test_data:
                # Out-of-place add: `features[y]` is a view, so an in-place `+=`
                # would write the noise back into the class feature and make it
                # drift (random walk) across all samples of that class.
                concat_feature = concat_feature + low_var * torch.randn(concat_feature.shape)
            # torch.cat copies, so stored samples never alias `features`.
            self.data.append(torch.cat([x, concat_feature], dim=0))
            self.targets.append(y)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return len(self.data)


class ColoredMNIST(torch.utils.data.Dataset):
    """MNIST rendered as RGB with a label-dependent fill on zero-valued pixels.

    The underlying MNIST split is downloaded into ``"data"`` on construction.
    Every pixel whose first (red) channel is 0 — presumably the digit
    background — is painted with a grey level chosen from ``self.colors``:

    * train split: index ``target``;
    * test split: index ``9 - target``, reversing the colour/label pairing.

    Images are converted with ``transforms.ToTensor()`` when ``to_tensor`` is set.
    """

    def __init__(self, train=True, to_tensor=True):
        self.dataset = datasets.MNIST("data", train=train, download=True)
        self.train = train
        self.to_tensor = to_tensor
        self.transform = transforms.ToTensor()
        # Grey levels (1,1,1) .. (10,10,10); could be replaced with random colours.
        self.colors = [level * np.ones(3) for level in range(1, 11)]

    def __getitem__(self, index):
        digit, target = self.dataset[index]
        pixels = np.array(digit.convert("RGB"))

        color_idx = target if self.train else 9 - target
        pixels[(pixels[:, :, 0] == 0), :] = self.colors[color_idx]
        img = Image.fromarray(pixels, mode="RGB")

        if self.to_tensor:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.dataset)

    
def load_cifar10():
    """Return the CIFAR-10 ``(train, test)`` datasets, downloading into ``data``.

    Both splits use the same pipeline: ``ToTensor`` followed by per-channel
    ``Normalize``. No augmentation is applied.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    # NOTE(review): these std values are a set commonly copied between public
    # CIFAR-10 training scripts; confirm they are the intended statistics.
    channel_std = (0.2023, 0.1994, 0.2010)

    def build_pipeline():
        # A fresh Compose per split, so train and test never share one object.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    train_pipeline = build_pipeline()
    test_pipeline = build_pipeline()

    train_split = datasets.CIFAR10("data", train=True, download=True, transform=train_pipeline)
    test_split = datasets.CIFAR10("data", train=False, download=True, transform=test_pipeline)
    return train_split, test_split


def load_cifar100():
    """Return the CIFAR-100 ``(train, test)`` datasets, downloading into ``data``.

    Both splits use the same pipeline: ``ToTensor`` followed by per-channel
    ``Normalize``. No augmentation is applied.
    """
    channel_mean = (0.5070746, 0.4865490, 0.4409179)
    channel_std = (0.2673342, 0.2564385, 0.2761506)

    def build_pipeline():
        # A fresh Compose per split, so train and test never share one object.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    train_pipeline = build_pipeline()
    test_pipeline = build_pipeline()

    train_split = datasets.CIFAR100("data", train=True, download=True, transform=train_pipeline)
    test_split = datasets.CIFAR100("data", train=False, download=True, transform=test_pipeline)
    return train_split, test_split


def generate_synthetic(n_sample=5000, feature_scale=1e-3, core_d=1, spur_d=9, low_var=False):
    """Generate a synthetic binary classification dataset.

    Each row has ``core_d`` small "core" features followed by ``spur_d`` large
    "spurious" features.

    * Class 1 (first half): core = ``+feature_scale``; spurious drawn from ``[1, 100)``.
    * Class 0 (second half): core = ``-feature_scale``; spurious drawn from ``(-100, -1]``.

    With ``low_var`` the core features get one-sided uniform jitter of up to
    ``feature_scale / 10``, pushing them away from zero. Each class has
    ``n_sample // 2`` rows, so an odd ``n_sample`` yields ``n_sample - 1`` rows.

    Returns:
        TensorDataset of ``X`` with shape ``(2 * (n_sample // 2), core_d + spur_d)``
        and ``y`` with shape ``(2 * (n_sample // 2), 1)``.
    """
    half = n_sample // 2

    core_pos = feature_scale * torch.ones((half, core_d))
    core_neg = -core_pos
    if low_var:
        jitter = feature_scale / 10
        core_pos = core_pos + jitter * torch.rand((half, core_d))
        core_neg = core_neg - jitter * torch.rand((half, core_d))

    # Large, well-separated spurious features of opposite sign per class.
    spur_pos = 99 * torch.rand((half, spur_d)) + 1
    spur_neg = -99 * torch.rand((half, spur_d)) - 1

    rows_pos = torch.concatenate([core_pos, spur_pos], dim=1)
    rows_neg = torch.concatenate([core_neg, spur_neg], dim=1)
    features = torch.concatenate([rows_pos, rows_neg], dim=0)

    labels = torch.concatenate([torch.ones(half), torch.zeros(half)], dim=0)
    return torch.utils.data.TensorDataset(features, labels.unsqueeze(dim=1))


def _random_subset(data, size):
    """Return a uniformly random ``Subset`` of ``size`` distinct samples of ``data``."""
    indices = np.random.choice(len(data), size=size, replace=False)
    return torch.utils.data.Subset(data, indices)


def load_dataset(
    dataset: str,
    subsample: int = 0,
    binary: bool = False,
    add_patches: bool = False,
    low_var: float = 0,
):
    """Loads dataset specified by provided string.

    Args:
        dataset (str): Dataset name: "CIFAR10", "CIFAR100" or "MNIST" (coloured
            MNIST). For "MNIST" the remaining options are ignored and the
            coloured train/test splits are returned directly. Any other name
            exits the program.
        subsample (int, optional): Number of training samples to keep, drawn at
            random without replacement. The test set is subsampled to
            ``int(0.2 * subsample)``. Defaults to 0 (no subsampling).
        binary (bool, optional): Keep only classes 0 and 1 (``out_dim`` becomes 2).
            The reduction happens before subsampling.
        add_patches (bool, optional): Inject class-dependent features.
            * Binary: ``BinaryModImageDataset`` on the training split only.
            * Otherwise: ``ConcatImageDataset`` on both splits, with zero
              features on the test split.
        low_var (float, optional): Noise scale forwarded to the patching dataset.

    Returns:
        Tuple ``(train_data, test_data, out_dim)``.
    """
    out_dim = 10
    if dataset == "CIFAR10":
        train_data, test_data = load_cifar10()
    elif dataset == "CIFAR100":
        out_dim = 100
        train_data, test_data = load_cifar100()
    elif dataset == "MNIST":
        train_data = ColoredMNIST(train=True)
        test_data = ColoredMNIST(train=False)
        return train_data, test_data, out_dim
    else:
        sys.exit(f"Dataset {dataset} is an invalid dataset.")

    if binary:
        # Must run before subsampling: ReducedImageDataset reads the raw
        # .data/.targets/.transform attributes, which torch Subset lacks.
        train_data = ReducedImageDataset(train_data, n_classes=2)
        test_data = ReducedImageDataset(test_data, n_classes=2)
        out_dim = 2

    if subsample > 0:
        train_data = _random_subset(train_data, subsample)
        test_data = _random_subset(test_data, int(0.2 * subsample))

    if binary:
        if add_patches:
            train_data = BinaryModImageDataset(train_data, patch_scale=1e-1, low_var=low_var)
    elif add_patches:
        train_data = ConcatImageDataset(train_data, feature_scale=5e-3, n_classes=out_dim, test_data=False, low_var=low_var)
        test_data = ConcatImageDataset(test_data, feature_scale=5e-3, n_classes=out_dim, test_data=True)

    return train_data, test_data, out_dim


def split_train_into_val(train_data, val_prop: float = 0.1):
    """Randomly partition a training dataset into ``(train, val)`` subsets.

    The validation part gets ``int(val_prop * len(train_data))`` samples and the
    training part gets the rest. Randomness comes from torch's global generator.

    Args:
        train_data: Training dataset.
        val_prop: Proportion of data to use for validation.
    """
    total = len(train_data)
    n_val = int(val_prop * total)
    n_train = total - n_val
    pieces = torch.utils.data.random_split(train_data, [n_train, n_val])
    return pieces[0], pieces[1]
