import random
import os
from typing import *

from constants import *

import torch
import torchvision

# Per-channel mean / std handed to Normalize by both pipelines below.
# NOTE(review): presumably the commonly quoted CIFAR-10 statistics — confirm
# before reusing them for another dataset.
_CIFAR_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop after 4 pixels of padding and a random
# horizontal flip, followed by tensor conversion and normalization.
cifar_img_transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=_CIFAR_NORM_MEAN, std=_CIFAR_NORM_STD
    ),
])

# Evaluation pipeline: no random augmentation, only tensor conversion and the
# same normalization as the training pipeline.
cifar_img_transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=_CIFAR_NORM_MEAN, std=_CIFAR_NORM_STD
    ),
])


class CIFAR10Dataset(torch.utils.data.Dataset):
    """CIFAR-10 samples restricted to selected labels and cached up front.

    At construction the wrapped torchvision dataset is iterated once; every
    sample whose label is in ``digits`` is moved to ``DEVICE`` and kept in
    memory. Items are then served in a random order fixed at construction.

    NOTE(review): because samples are materialized once, whatever the
    transform produced (including the output of random augmentations such as
    the default training crop/flip) is cached and returned unchanged on every
    access. Confirm this is intended.
    """

    def __init__(
        self,
        root: str,
        digits: List[int],
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        """Load CIFAR-10, keep the requested labels and cache them on DEVICE.

        Args:
            root: Directory handed to ``torchvision.datasets.CIFAR10``.
            digits: Labels to keep; a sample is retained when its target is
                ``in digits``.
            train: Selects the train or test split and, when ``transform`` is
                not given, the matching default transform.
            transform: Optional image transform overriding the defaults.
            target_transform: Optional label transform, forwarded as-is.
            download: Forwarded to torchvision.
        """
        # An explicit transform wins; otherwise use the split-specific default
        # (augmenting pipeline for train, plain pipeline for test).
        if transform is not None:
            actual_transform = transform
        elif train:
            actual_transform = cifar_img_transform_train
        else:
            actual_transform = cifar_img_transform_test

        self.cifar10_dataset = torchvision.datasets.CIFAR10(
            root,
            train=train,
            transform=actual_transform,
            target_transform=target_transform,
            download=download,
        )

        # Single pass over the dataset: keep only the requested labels and
        # move each kept image to DEVICE (expected to be provided by the
        # star-import of `constants`).
        self.relevant_digits = [
            (image.to(DEVICE), label)
            for image, label in self.cifar10_dataset
            if label in digits
        ]

        # Random permutation that defines the order in which items are served.
        self.index_map = list(range(len(self.relevant_digits)))
        random.shuffle(self.index_map)
        self.shuffled_digits = [self.relevant_digits[idx]
                                for idx in self.index_map]

        # Labels in served (shuffled) order, so targets[i] matches self[i][1].
        self.targets = torch.tensor(
            [sample[1] for sample in self.shuffled_digits], device=DEVICE)

    def __len__(self):
        """Return the number of retained samples."""
        return len(self.shuffled_digits)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at shuffled position ``idx``."""
        return self.relevant_digits[self.index_map[idx]]

    @staticmethod
    def collate_fn(batch):
        """Collate a batch into stacked tensors.

        A sequence of tensors is stacked along a new leading dimension, as
        before. A sequence of ``(image, label)`` pairs, which is what
        ``__getitem__`` yields, is collated into ``(images, labels)``:
        the images are stacked and the labels are gathered into a tensor on
        the images' device, built the same way as ``self.targets``.
        Previously such a batch raised ``TypeError`` inside ``torch.stack``.
        """
        if len(batch) > 0 and isinstance(batch[0], (tuple, list)):
            images = torch.stack([sample[0] for sample in batch])
            labels = torch.tensor(
                [sample[1] for sample in batch], device=images.device)
            return images, labels
        return torch.stack(batch)


def get_data(
    train: bool,
    digits: Optional[List[int]] = None,
):
    """Build the CIFAR-10 subset and an index of sample positions per label.

    Args:
        train: Whether to load the training split (otherwise the test split).
        digits: Labels to keep. ``None`` (the default) selects ``0..9``.

    Returns:
        A ``(data, ids_of_digit)`` tuple. ``data`` is the ``CIFAR10Dataset``;
        ``ids_of_digit`` maps each requested label to a 1-D tensor holding the
        positions ``i`` for which ``data.targets[i]`` equals that label. A
        label with no samples maps to an empty tensor.
    """
    # Fresh list per call instead of a mutable default shared across calls.
    if digits is None:
        digits = list(range(10))

    # abspath() normalizes the "..", so this resolves to a `data` directory
    # next to the directory containing this file.
    data_dir = os.path.abspath(os.path.join(
        os.path.abspath(__file__), "../../data"))
    data = CIFAR10Dataset(
        data_dir,
        digits=digits,
        train=train,
        download=True,
    )

    # Sort labels once; `idxs` maps each sorted position back to the dataset.
    values, idxs = torch.sort(data.targets)

    # Select by mask rather than by slice bounds: the earlier
    # `idxs[t[0]:t[-1]]` dropped the last sample of every label (exclusive
    # slice end) and raised IndexError for a label with no samples.
    ids_of_digit = {digit: idxs[values == digit] for digit in digits}
    return (data, ids_of_digit)
