import numpy as np
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST, SVHN
import random


def join_datasets(data_1, labels_1, data_2, labels_2, num_classes=10):
    """Pair up two labelled datasets class by class.

    For every class ``c`` in ``range(num_classes)`` the samples labelled ``c``
    are located in both datasets; the first ``k`` of them are kept from each,
    where ``k`` is the smaller of the two per-class counts.  The i-th sample
    returned from the first dataset therefore has the same label as the i-th
    sample returned from the second one.

    Args:
        data_1: Tensor holding the first dataset, indexed along dim 0.
        labels_1: Labels of ``data_1``; must support ``labels_1 == c``
            yielding something ``np.where`` accepts.
        data_2: Tensor holding the second dataset, indexed along dim 0.
        labels_2: Labels of ``data_2`` (same requirements as ``labels_1``).
        num_classes: Number of consecutive integer classes, starting at 0,
            to pair up.  Must be at least 1 (``torch.cat`` rejects an empty
            list).  Defaults to 10.

    Returns:
        A tuple ``(paired_1, paired_2, labels)``.  Both data tensors are
        divided by 255.0, and ``labels`` is a 1-D float tensor.  All three
        outputs are ordered by class: every class-0 pair first, then
        class 1, and so on.
    """
    paired_1 = []
    paired_2 = []
    paired_labels = []
    for cls in range(num_classes):
        idx_1 = np.where(labels_1 == cls)[0]
        idx_2 = np.where(labels_2 == cls)[0]
        # Keep only as many samples as both datasets can provide.
        count = min(len(idx_1), len(idx_2))
        paired_1.append(data_1[idx_1[:count]] / 255.0)
        paired_2.append(data_2[idx_2[:count]] / 255.0)
        paired_labels.append(torch.ones(count) * cls)
    return (
        torch.cat(paired_1, dim=0),
        torch.cat(paired_2, dim=0),
        torch.cat(paired_labels, dim=0),
    )

class MNIST_SVHN_Dataset(torch.utils.data.Dataset):
    """Dataset of label-matched MNIST / SVHN image pairs.

    MNIST and SVHN are loaded through torchvision and paired class by class
    with ``join_datasets``.  In ``"sup"`` / ``"unsup"`` mode the paired
    training data is split once into a supervised and an unsupervised part;
    in ``"test"`` mode the paired test data is used as is.

    Each item is a tuple ``(svhn_image, mnist_image, label)``.

    NOTE: the training split is cached in class attributes the first time a
    ``"sup"`` or ``"unsup"`` instance is created.  Later instances reuse that
    cache, so the ``sup_frac`` / ``missing`` values passed to them are ignored.
    """

    # Class-level cache shared by every instance (see the class docstring).
    train_data_sup, train_labels_sup = None, None
    train_data_unsup, train_labels_unsup = None, None
    train_data, train_labels = None, None

    def __init__(self, root, mode, download=True, sup_frac=1.0, missing='mnist', **kwargs):
        """Load the raw datasets and select the subset for ``mode``.

        Args:
            root: Directory handed to the torchvision ``MNIST`` and ``SVHN``
                constructors.
            mode: One of ``"sup"``, ``"unsup"`` or ``"test"``.
            download: Forwarded to the torchvision dataset constructors.
            sup_frac: Fraction of the paired training data assigned to the
                supervised subset (only used when the split is first built).
            missing: ``'mnist'`` or ``'svhn'``; selects which modality is left
                unaligned in the unsupervised subset (see ``_split_train``).
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        super().__init__()
        # Validate before any download / disk access.  An explicit raise is
        # used because ``assert`` is stripped when Python runs with ``-O``.
        if mode not in ("sup", "unsup", "test"):
            raise ValueError("invalid train/test option values")

        # Strings are compared by value; ``is not`` relied on interning.
        is_test = mode == "test"
        mnist_set = MNIST(root, train=not is_test, download=download)
        svhn_set = SVHN(root, split='test' if is_test else 'train', download=download)

        self.mnist_data = mnist_set.data.view(-1, 1, 28, 28).float()
        self.mnist_labels = mnist_set.targets
        self.svhn_data = torch.tensor(svhn_set.data).float()
        self.svhn_labels = svhn_set.labels

        if is_test:
            mnist, svhn, labels = join_datasets(self.mnist_data,
                                                self.mnist_labels,
                                                self.svhn_data,
                                                self.svhn_labels)
            self.data = (mnist, svhn)
            self.labels = labels
            return

        if MNIST_SVHN_Dataset.train_data is None:
            print("Splitting dataset")
            mnist, svhn, labels = join_datasets(self.mnist_data,
                                                self.mnist_labels,
                                                self.svhn_data,
                                                self.svhn_labels)
            MNIST_SVHN_Dataset.train_labels = labels
            MNIST_SVHN_Dataset.train_data = (mnist, svhn)
            (MNIST_SVHN_Dataset.train_data_sup,
             MNIST_SVHN_Dataset.train_labels_sup,
             MNIST_SVHN_Dataset.train_data_unsup,
             MNIST_SVHN_Dataset.train_labels_unsup) = MNIST_SVHN_Dataset._split_train(
                 mnist, svhn, labels, sup_frac, missing)

        if mode == "sup":
            self.data = MNIST_SVHN_Dataset.train_data_sup
            self.labels = MNIST_SVHN_Dataset.train_labels_sup
            print("Num sup samples %i" % self.labels.shape[0])
        else:
            self.data = MNIST_SVHN_Dataset.train_data_unsup
            self.labels = MNIST_SVHN_Dataset.train_labels_unsup
            print("Num unsup samples %i" % self.labels.shape[0])

    @staticmethod
    def _split_train(mnist, svhn, labels, sup_frac, missing):
        """Split paired training tensors into supervised / unsupervised parts.

        The first ``int(sup_frac * N)`` pairs become the supervised subset and
        the remainder the unsupervised subset.  Unless ``sup_frac == 1.0``,
        one modality of the unsupervised subset is permuted together with the
        labels, so the other modality (the one named by ``missing``) no
        longer lines up with them.

        Returns:
            ``(sup_data, sup_labels, unsup_data, unsup_labels)`` where each
            ``*_data`` is a ``(mnist, svhn)`` tuple.
        """
        # NOTE(review): ``join_datasets`` returns samples grouped by class, so
        # for 0 < sup_frac < 1 this slice puts only the lowest digit classes
        # into the supervised subset -- confirm this is intended.
        num_sup = int(sup_frac * labels.shape[0])
        sup_data = (mnist[:num_sup], svhn[:num_sup])
        sup_labels = labels[:num_sup]
        unsup_mnist, unsup_svhn = mnist[num_sup:], svhn[num_sup:]
        unsup_labels = labels[num_sup:]

        if sup_frac != 1.0:
            # Integer permutation: a float index array (as produced by
            # ``np.linspace``) is not a valid tensor index.
            perm = torch.from_numpy(np.random.permutation(unsup_labels.shape[0])).long()
            if missing == 'mnist':
                unsup_svhn = unsup_svhn[perm]
                unsup_labels = unsup_labels[perm]
            elif missing == 'svhn':
                unsup_mnist = unsup_mnist[perm]
                unsup_labels = unsup_labels[perm]

        return sup_data, sup_labels, (unsup_mnist, unsup_svhn), unsup_labels

    def __len__(self):
        """Return the number of pairs in the selected subset."""
        return self.labels.shape[0]

    def __getitem__(self, index):
        """Return ``(svhn_image, mnist_image, label)`` for ``index``."""
        # ``self.data`` is stored as (mnist, svhn); items are yielded SVHN first.
        return self.data[1][index], self.data[0][index], self.labels[index]


def setup_MNIST_SVHN_loaders(batch_size, sup_frac=1.0, missing='mnist', root=None, **kwargs):
    """Build the DataLoaders for the paired MNIST/SVHN dataset.

    Args:
        batch_size: Batch size used for every loader.
        sup_frac: Fraction of the training data that is supervised.  ``0.0``
            builds only the ``"unsup"`` training loader, ``1.0`` only the
            ``"sup"`` one, and any other value builds both.
        missing: Forwarded to ``MNIST_SVHN_Dataset``.
        root: Dataset root directory, forwarded to ``MNIST_SVHN_Dataset``.
        **kwargs: Extra keyword arguments forwarded to ``DataLoader``.  When
            ``num_workers`` is not given, ``num_workers=4`` and
            ``pin_memory=True`` are used as defaults; any other option that
            was passed is kept.

    Returns:
        A dict mapping mode name to ``DataLoader``.  It always contains
        ``"test"`` plus ``"sup"`` and/or ``"unsup"`` depending on
        ``sup_frac``.  Training loaders shuffle; the test loader does not.
    """
    if 'num_workers' not in kwargs:
        # Merge the defaults *under* the caller's options rather than
        # replacing the whole dict, which silently dropped e.g. an explicit
        # ``pin_memory=False``.
        kwargs = {'num_workers': 4, 'pin_memory': True, **kwargs}

    if sup_frac == 0.0:
        modes = ["unsup"]
    elif sup_frac == 1.0:
        modes = ["sup"]
    else:
        modes = ["unsup", "sup"]

    loaders = {}
    for mode in modes:
        train_set = MNIST_SVHN_Dataset(root=root, mode=mode, sup_frac=sup_frac, missing=missing)
        loaders[mode] = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False, **kwargs)

    test_set = MNIST_SVHN_Dataset(root=root, mode='test', sup_frac=sup_frac, missing=missing)
    loaders['test'] = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, **kwargs)
    return loaders


def get_10_mnist_svhn_samples(mnist_svhn, num_testing_images, device):
    """Collect one (SVHN, MNIST) example for each digit 0-9.

    The first ``num_testing_images`` items of ``mnist_svhn`` are scanned once
    and, for every digit, the first item carrying that label is kept.  The
    scan stops as soon as all ten digits have been found.

    Args:
        mnist_svhn: Indexable collection whose items are
            ``(svhn, mnist, target)`` tuples.
        num_testing_images: Number of leading items that may be inspected.
        device: Device the selected tensors are moved to.

    Returns:
        A list ``[svhn_batch, mnist_batch]``; each entry stacks ten tensors,
        with row ``d`` holding the example found for digit ``d``.

    Raises:
        IndexError: If some digit does not occur among the inspected items.
    """
    first_seen = {}
    for idx in range(num_testing_images):
        svhn, mnist, target = mnist_svhn[idx]
        digit = int(target)
        # Skip non-integral labels, labels outside 0-9 and digits already found.
        if target != digit or not 0 <= digit <= 9 or digit in first_seen:
            continue
        first_seen[digit] = (svhn.to(device), mnist.to(device))
        if len(first_seen) == 10:
            break

    absent = [d for d in range(10) if d not in first_seen]
    if absent:
        raise IndexError(
            "no sample found for digit(s) %s within the first %d items"
            % (absent, num_testing_images)
        )

    return [
        torch.stack([first_seen[d][modality] for d in range(10)])
        for modality in range(2)
    ]

def get_some_mnist_svhn_samples(mnist_svhn, num_testing_images, num, device):
    """Draw ``num`` random (SVHN, MNIST) examples, with replacement.

    Indices are drawn with ``random.randint`` from
    ``[0, num_testing_images - 1]``; the label of each drawn item is
    discarded.

    Args:
        mnist_svhn: Indexable collection whose items are
            ``(svhn, mnist, target)`` tuples.
        num_testing_images: Number of leading items eligible for sampling.
        num: How many examples to draw.
        device: Device the drawn tensors are moved to.

    Returns:
        A list ``[svhn_batch, mnist_batch]`` of stacked tensors, each with
        ``num`` rows in draw order.
    """
    svhn_picks = []
    mnist_picks = []
    for _ in range(num):
        pick = random.randint(0, num_testing_images - 1)
        svhn_img, mnist_img, _target = mnist_svhn[pick]
        svhn_picks.append(svhn_img.to(device))
        mnist_picks.append(mnist_img.to(device))
    return [torch.stack(svhn_picks), torch.stack(mnist_picks)]
