import torch
import torch.distributions
from torchvision import datasets, transforms
import os
from utils.datasets.paths import get_base_data_dir, get_LSUN_scenes_path
from utils.datasets.imagenet_augmentation import get_imageNet_augmentation
from torch.utils.data import DataLoader, SubsetRandomSampler, Sampler, Dataset
import numpy as np

def _generate_subset(dataset, num_classes, points_per_class):
    """Draw a random, fixed-size set of sample indices from every class.

    Class ``i`` is taken to occupy the contiguous index range
    ``[dataset.indices[i - 1], dataset.indices[i])`` (starting at 0 for the
    first class), i.e. ``dataset.indices`` is read as a list of exclusive end
    offsets.  NOTE(review): this matches cumulative class sizes as kept by
    torchvision's LSUN dataset -- confirm for other dataset types.

    Args:
        dataset: Object exposing an ``indices`` sequence with at least
            ``num_classes`` entries.
        num_classes: Number of classes to sample from.
        points_per_class: Number of indices drawn (without replacement) from
            each class.

    Returns:
        A 1-D ``torch.long`` tensor of length
        ``num_classes * points_per_class``; the indices of class ``i`` fill
        positions ``[i * points_per_class, (i + 1) * points_per_class)``.
    """
    chosen = torch.zeros(num_classes * points_per_class, dtype=torch.long)

    class_start = 0
    for class_id in range(num_classes):
        class_end = dataset.indices[class_id]

        # Shuffle the positions inside this class and keep the first few.
        candidates = torch.arange(class_start, class_end)
        picks = torch.randperm(len(candidates))[:points_per_class]

        offset = class_id * points_per_class
        chosen[offset:offset + points_per_class] = candidates[picks]

        # The next class begins where this one ended.
        class_start = class_end

    print(f'LSUN subset generated - samples per class {points_per_class}')
    return chosen


class LSUN_subset(Dataset):
    """View onto an LSUN dataset restricted to a given list of sample indices.

    Position ``ii`` of this dataset maps to item ``idcs[ii]`` of the
    underlying ``datasets.LSUN`` instance.
    """

    def __init__(self, idcs, split='train', transform=None):
        """Open the LSUN dataset and remember which samples to expose.

        Args:
            idcs: Indexable collection of indices into the underlying dataset.
            split: Forwarded as ``classes`` to ``datasets.LSUN``.
            transform: Forwarded as ``transform`` to ``datasets.LSUN``.
        """
        super().__init__()
        self.idcs = idcs
        lsun_root = get_LSUN_scenes_path()
        self.dataset = datasets.LSUN(lsun_root, classes=split, transform=transform)

    def __getitem__(self, ii):
        """Return the underlying dataset item selected by ``idcs[ii]``."""
        return self.dataset[self.idcs[ii]]

    def __len__(self):
        """Return the number of samples in the subset."""
        return len(self.idcs)


def get_LSUN_scenes_subset(split='train', samples_per_class=100_000, batch_size=128, shuffle=True, augm_type='none',
                           augm_class='imagenet', num_workers=8, size=224, config_dict=None):
    """Build a ``DataLoader`` over a cached random subset of LSUN scenes.

    The selected indices are stored in
    ``ssl_unlabeled_lsun_{split}_{samples_per_class}.pt``, a path relative to
    the current working directory.  If that file exists it is loaded;
    otherwise a new subset is drawn with ``_generate_subset`` and saved there.

    Args:
        split: Forwarded as ``classes`` to ``datasets.LSUN``.
        samples_per_class: Number of samples drawn from every class.
        batch_size: Batch size of the returned loader.
        shuffle: Whether the returned loader shuffles.
        augm_type: Augmentation type passed to ``get_imageNet_augmentation``.
        augm_class: Augmentation family; only ``'imagenet'`` is implemented.
        num_workers: Number of loader worker processes.
        size: Output size passed to ``get_imageNet_augmentation``.
        config_dict: Optional mapping that receives a description of the
            dataset setup (modified in place).

    Returns:
        A ``DataLoader`` over an ``LSUN_subset``.

    Raises:
        NotImplementedError: If ``augm_class`` is anything but ``'imagenet'``.
    """
    # 'cifar' and any unknown augmentation family are equally unsupported.
    if augm_class != 'imagenet':
        raise NotImplementedError()

    augm_config = {}
    transform = get_imageNet_augmentation(type=augm_type, out_size=size, config_dict=augm_config)

    cache_file = f'ssl_unlabeled_lsun_{split}_{samples_per_class}.pt'
    if not os.path.isfile(cache_file):
        # No cached selection yet: open the full dataset once to draw one.
        full_dataset = datasets.LSUN(get_LSUN_scenes_path(), classes=split, transform=transform)
        subset_idcs = _generate_subset(full_dataset, len(full_dataset.dbs), samples_per_class)
        torch.save(subset_idcs, cache_file)
    else:
        subset_idcs = torch.load(cache_file)
    subset_idcs = subset_idcs.view(-1).numpy()

    loader = DataLoader(LSUN_subset(subset_idcs, split, transform),
                        batch_size=batch_size,
                        shuffle=shuffle,
                        num_workers=num_workers)

    if config_dict is not None:
        summary = (
            ('Dataset', 'LSUN-Subset'),
            ('Batch size', batch_size),
            ('Samples per class', samples_per_class),
            ('Augmentation', augm_config),
        )
        for key, value in summary:
            config_dict[key] = value

    return loader

if __name__ == "__main__":
    # Export the cached LSUN subset as a single uint8 array of 32x32 images.
    train_loader = get_LSUN_scenes_subset(batch_size=128, shuffle=True, augm_type='none', size=32)

    # The loader's dataset is already the LSUN_subset wrapper, so positions
    # 0..len-1 enumerate exactly the selected samples in their stored order.
    # The loader is created without an explicit sampler, so
    # `train_loader.sampler` is a RandomSampler, which has no `indices`
    # attribute and therefore cannot be used to look the subset up.
    lsun_dataset = train_loader.dataset
    num_samples = len(lsun_dataset)

    lsun_data = np.zeros((num_samples, 32, 32, 3), dtype=np.uint8)
    for i in range(num_samples):
        img, _ = lsun_dataset[i]
        # Channels-first -> channels-last, scaled by 255 and stored as uint8.
        # NOTE(review): assumes the transform yields a 3x32x32 float tensor
        # with values in [0, 1] -- confirm against get_imageNet_augmentation.
        lsun_data[i] = 255. * img.permute(1, 2, 0).numpy()

        if i % 10_000 == 0:
            print(i)

    np.save('lsun_1M.npy', lsun_data)
