from typing import Optional, Tuple

import torch
from torch.utils.data import DataLoader, Subset, ConcatDataset, RandomSampler, Dataset
from torchvision.datasets import MNIST

from experiments.datasets import DataLoaders
from experiments.utils.active_learning_data import get_balanced_sample_indices, ActiveLearningData
from experiments.utils.train_validation_split import train_validation_split

# Device that FastMNIST moves its preprocessed data and targets to.
# Fall back to the CPU when CUDA is not available, so the module can still be
# used on CPU-only machines instead of failing on the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class FastMNIST(MNIST):
    """MNIST split that is preprocessed once up front and kept on ``device``.

    After the parent constructor has loaded the split, the images are cast to
    float, rescaled from pixel values (presumably 0-255) to [0, 1], and
    standardized with the usual MNIST mean/std (0.1307 / 0.3081). Images and
    targets are then moved to the module-level ``device``. Item access is a
    plain tensor index into those preprocessed tensors.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Insert a singleton axis at dim 1 (presumably the channel axis) and
        # rescale to [0, 1].
        unit_range = self.data.unsqueeze(1).float().div(255)
        # Standardize in place with the usual MNIST mean and std.
        standardized = unit_range.sub_(0.1307).div_(0.3081)

        # Keep both images and labels on the target device ahead of time.
        self.data = standardized.to(device)
        self.targets = self.targets.to(device)

    def __getitem__(self, index):
        """Return ``(image, target)`` at ``index``.

        ``target`` is the class index. The tensors are returned exactly as
        stored; no ``transform``/``target_transform`` is applied here.
        """
        return self.data[index], self.targets[index]


# Module-level caches for the FastMNIST train/test splits. Both start as None
# and are filled in lazily by init_FastMNIST().
global_train_dataset: Optional[Dataset] = None
global_test_dataset: Optional[Dataset] = None


def init_FastMNIST():
    """Populate the module-level FastMNIST train/test caches if still empty.

    Each split is constructed from ``"data/MNIST"`` (with ``download=True``)
    only when its cache is still None. Later calls reuse the cached datasets,
    so every loader factory can call this unconditionally.
    """
    global global_train_dataset, global_test_dataset
    # Compare against the None sentinel explicitly: truthiness of a Dataset
    # goes through its ``__len__`` (when defined), which is not the same
    # question as "has this split been loaded yet?".
    if global_train_dataset is None:
        global_train_dataset = FastMNIST("data/MNIST", train=True, download=True)

    if global_test_dataset is None:
        global_test_dataset = FastMNIST("data/MNIST", train=False, download=True)


def dataloaders(
        train_batch_size, test_batch_size, *, swap_train_test=False, train_only=False, validation_size: int = 0
) -> DataLoaders:
    """Create train, test, train-eval and (optional) validation loaders over FastMNIST.

    Args:
        train_batch_size: Batch size of the shuffled training loader. Its last
            incomplete batch is dropped.
        test_batch_size: Batch size of the unshuffled test, train-eval and
            validation loaders.
        swap_train_test: If True, train on the MNIST test split and test on
            the train split.
        train_only: If True, the test loader iterates over the training split
            (as chosen after any swap, and before the validation hold-out).
        validation_size: Number of training samples to hold out for
            validation; 0 disables it. The value is integer-divided by 10
            (presumably samples per class over the 10 MNIST classes), so
            values that are not multiples of 10 are rounded down.

    Returns:
        ``DataLoaders(train_loader, test_loader, train_eval_loader,
        validation_loader)``. ``validation_loader`` is None when no validation
        set was extracted, or when the extracted set is falsy (e.g. empty, if
        it defines ``__len__``).
    """
    init_FastMNIST()

    train_dataset = global_train_dataset
    test_dataset = global_test_dataset
    if swap_train_test:
        train_dataset, test_dataset = test_dataset, train_dataset
    if train_only:
        test_dataset = train_dataset

    validation_dataset = None
    if validation_size > 0:
        per_class = validation_size // 10
        held_out_indices = get_balanced_sample_indices(train_dataset.targets, 10, per_class)
        active_learning_data = ActiveLearningData(train_dataset)
        validation_dataset = active_learning_data.extract_dataset_from_pool_from_indices(held_out_indices)
        # Whatever was not held out for validation stays in the pool and becomes the training set.
        train_dataset = active_learning_data.pool_dataset

    def make_eval_loader(dataset):
        # Deterministic, single-process iteration for evaluation passes.
        return DataLoader(dataset, batch_size=test_batch_size, shuffle=False, num_workers=0)

    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, drop_last=True, num_workers=0)
    test_loader = make_eval_loader(test_dataset)
    train_eval_loader = make_eval_loader(train_dataset)
    validation_loader = make_eval_loader(validation_dataset) if validation_dataset else None

    return DataLoaders(train_loader, test_loader, train_eval_loader, validation_loader)


def combined_dataloaders(train_batch_size, test_batch_size) -> DataLoaders:
    """Create loaders that all iterate over the union of the MNIST train and test splits.

    The training loader is shuffled and drops its last incomplete batch. The
    test and train-eval loaders are unshuffled over the same concatenated
    dataset. The validation slot of the returned ``DataLoaders`` is None.
    """
    init_FastMNIST()

    merged = ConcatDataset([global_train_dataset, global_test_dataset])

    def make_sequential_loader():
        # Unshuffled, single-process pass over the merged data.
        return DataLoader(merged, batch_size=test_batch_size, shuffle=False, num_workers=0)

    shuffled_loader = DataLoader(merged, batch_size=train_batch_size, shuffle=True, drop_last=True, num_workers=0)
    test_loader = make_sequential_loader()
    train_eval_loader = make_sequential_loader()

    return DataLoaders(shuffled_loader, test_loader, train_eval_loader, None)
