import os
import numpy as np
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader
import random

class CUBICCDataset(torch.utils.data.Dataset):
    """Paired image/caption dataset read from files saved under ``datadir``.

    ``datadir`` must contain ``images.pt``, ``captions.pt`` and ``labels.pt``
    (indexable collections loaded with ``torch.load``) plus the index arrays
    ``train_split.npy``, ``validation_split.npy`` and ``test_split.npy``.

    Modes:
      * ``"sup"``   -- the first ``int(sup_frac * len(train_split))`` training
        samples.
      * ``"unsup"`` -- the remaining training samples. When ``sup_frac != 1.0``
        one modality is randomly permuted relative to the other: the captions
        when ``missing == 'image'``, the images when ``missing == 'sentence'``.
        The labels are permuted together with the permuted modality. Any other
        ``missing`` value leaves the pairs untouched.
      * ``"test"``  -- the validation and test splits concatenated, in that
        order.

    Each item is ``(image, caption, label)``. Extra ``**kwargs`` are accepted
    and ignored.
    """

    # Training-split cache shared by all instances. It is filled on the first
    # "sup"/"unsup" construction, presumably so that the sup and unsup datasets
    # are carved out of the same split and the same random permutation.
    # NOTE(review): the cache is not keyed on datadir/sup_frac/missing. Later
    # instances built with different arguments silently reuse the first split.
    # Set ``CUBICCDataset.train_data = None`` to force a re-split.
    train_data_sup, train_labels_sup = None, None
    train_data_unsup, train_labels_unsup = None, None
    train_data, train_labels = None, None

    def __init__(self, datadir, mode, sup_frac=1.0, missing='image', **kwargs):
        # Validate before touching the disk so a bad mode fails fast.
        assert mode in ["sup", "unsup", "test"], "invalid train/test option values"

        self.images = torch.load(os.path.join(datadir, 'images.pt'))
        self.captions = torch.load(os.path.join(datadir, 'captions.pt'))
        self.labels = torch.load(os.path.join(datadir, 'labels.pt'))
        self.train_split = np.load(os.path.join(datadir, 'train_split.npy'))
        self.validation_split = np.load(os.path.join(datadir, 'validation_split.npy'))
        self.test_split = np.load(os.path.join(datadir, 'test_split.npy'))

        if mode in ["sup", "unsup"]:
            if CUBICCDataset.train_data is None:
                print("Splitting dataset")
                self._split_train(sup_frac, missing)

            if mode == "sup":
                self.data, self.labels = CUBICCDataset.train_data_sup, CUBICCDataset.train_labels_sup
                print("Num sup samples %i" % len(self.labels))
            else:
                self.data, self.labels = CUBICCDataset.train_data_unsup, CUBICCDataset.train_labels_unsup
                print("Num unsup samples %i" % len(self.labels))
        else:
            index = np.concatenate((self.validation_split, self.test_split))
            # The right-hand side reads the full ``self.labels`` before it is
            # replaced by the selected subset.
            self.data, self.labels = self._gather(index)

    def _gather(self, index):
        """Return ``((images, captions), labels)`` lists for the given indices."""
        images = [self.images[i] for i in index]
        captions = [self.captions[i] for i in index]
        labels = [self.labels[i] for i in index]
        return (images, captions), labels

    def _split_train(self, sup_frac, missing):
        """Fill the class-level train / sup / unsup caches from ``train_split``."""
        (images, captions), labels = self._gather(self.train_split)
        CUBICCDataset.train_data = (images, captions)
        CUBICCDataset.train_labels = labels

        num_sup_samples = int(sup_frac * len(labels))
        CUBICCDataset.train_data_sup = (images[:num_sup_samples], captions[:num_sup_samples])
        CUBICCDataset.train_labels_sup = labels[:num_sup_samples]

        unsup_images = images[num_sup_samples:]
        unsup_captions = captions[num_sup_samples:]
        unsup_labels = labels[num_sup_samples:]

        if sup_frac != 1.0:
            # ``np.arange`` replaces ``np.linspace(...).astype(np.int)``. The
            # ``np.int`` alias was removed in NumPy 1.24. The values (and so
            # the RNG draws made by ``shuffle``) are identical.
            perm = np.arange(len(unsup_labels))
            np.random.shuffle(perm)
            if missing == 'image':
                # Break image/caption pairing: captions (and labels) move.
                unsup_captions = [unsup_captions[i] for i in perm]
                unsup_labels = [unsup_labels[i] for i in perm]
            elif missing == 'sentence':
                # Break image/caption pairing: images (and labels) move.
                unsup_images = [unsup_images[i] for i in perm]
                unsup_labels = [unsup_labels[i] for i in perm]

        CUBICCDataset.train_data_unsup = (unsup_images, unsup_captions)
        CUBICCDataset.train_labels_unsup = unsup_labels

    def __getitem__(self, idx):
        """Return ``(image, caption, label)`` at position ``idx``."""
        return self.data[0][idx], self.data[1][idx], self.labels[idx]

    def __len__(self):
        return len(self.labels)
    

def setup_CUBICC_loaders(batch_size, sup_frac=1.0, missing='image', root=None, **kwargs):
    """Build DataLoaders for the CUBICC training and test data.

    Args:
        batch_size: batch size used by every loader.
        sup_frac: fraction of the training split used as supervised data.
            ``0.0`` builds only an ``"unsup"`` loader, ``1.0`` only a ``"sup"``
            loader, and any other value builds both.
        missing: forwarded to ``CUBICCDataset``.
        root: dataset directory. Required in practice: ``None`` makes the
            dataset's path construction fail.
        **kwargs: extra ``DataLoader`` keyword arguments. If ``num_workers`` is
            not supplied, ``num_workers=2`` and ``pin_memory=True`` are used as
            defaults. Any keys the caller did supply take precedence.

    Returns:
        dict mapping each built training mode (``"sup"`` / ``"unsup"``) and
        ``"test"`` to its ``DataLoader``. Training loaders shuffle; the test
        loader does not.
    """
    if 'num_workers' not in kwargs:
        # Fill in defaults without discarding the caller's other DataLoader
        # options (e.g. ``pin_memory=False`` or ``persistent_workers``).
        kwargs = {'num_workers': 2, 'pin_memory': True, **kwargs}

    if sup_frac == 0.0:
        modes = ["unsup"]
    elif sup_frac == 1.0:
        modes = ["sup"]
    else:
        modes = ["unsup", "sup"]

    loaders = {}
    for mode in modes:
        dataset = CUBICCDataset(datadir=root, mode=mode, sup_frac=sup_frac, missing=missing)
        loaders[mode] = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False, **kwargs)

    test_dataset = CUBICCDataset(datadir=root, mode='test', sup_frac=sup_frac, missing=missing)
    loaders['test'] = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, **kwargs)
    return loaders


def get_8_CUBICC_samples(CUBICC, num_testing_images, device, num_classes=8):
    """Return one sample per class, stacked per modality.

    For each class ``c`` in ``0 .. num_classes - 1``, scans items
    ``0 .. num_testing_images - 1`` of ``CUBICC`` in order and takes the first
    whose target compares equal to ``c``.

    Args:
        CUBICC: indexable dataset yielding ``(image, sentence, target)``.
        num_testing_images: number of leading items to search.
        device: device the selected image/sentence tensors are moved to.
        num_classes: number of classes to collect (default 8).

    Returns:
        ``[images, sentences]``: two tensors stacked along a new first
        dimension, in class order. ``torch.stack`` requires equal shapes
        within each modality.

    Raises:
        IndexError: if some class has no item among the searched prefix.
    """
    samples = []
    for cls in range(num_classes):
        for r in range(num_testing_images):
            image, sentence, target = CUBICC[r]
            if target == cls:
                samples.append((image.to(device), sentence.to(device)))
                break
        else:
            raise IndexError(
                f"no sample with target {cls} among the first {num_testing_images} items"
            )
    images = torch.stack([image for image, _ in samples])
    sentences = torch.stack([sentence for _, sentence in samples])
    return [images, sentences]


def get_some_CUBICC_samples(CUBICC, num_testing_images, num, device):
    """Draw ``num`` random items (with replacement) and stack them per modality.

    Each draw uses ``random.randint(0, num_testing_images - 1)``, so results
    follow the global ``random`` state. Targets are read but not returned.

    Returns:
        ``[images, sentences]``: two tensors on ``device``, each stacked along
        a new first dimension of length ``num``.
    """
    picked_images, picked_sentences = [], []
    for _ in range(num):
        pick = random.randint(0, num_testing_images - 1)
        image, sentence, _target = CUBICC.__getitem__(pick)
        picked_images.append(image.to(device))
        picked_sentences.append(sentence.to(device))
    return [torch.stack(picked_images), torch.stack(picked_sentences)]
