import torch
import os
import random
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from collections import Counter
from src.utils.ntk_util import get_full_data
from src.core.datasets import Cifar10, Cifar100, \
    ImbalancedCifar10, MNIST, ImbalancedMNIST, FashionMNIST, SVHN, TinyImageNet


# Registry that maps a config's ``data_config['name']`` to its dataset class.
DATASETS = dict(
    cifar100=Cifar100,
    cifar10=Cifar10,
    imb_cifar10=ImbalancedCifar10,
    mnist=MNIST,
    imb_mnist=ImbalancedMNIST,
    fashionmnist=FashionMNIST,
    svhn=SVHN,
    tiny_imagenet=TinyImageNet,
)

# Loader modes, in the order `build` iterates over them.
MODES = ['train', 'unlabeled', 'val']


class SubsetSequentialSampler(torch.utils.data.Sampler):
    r"""Emit a fixed list of dataset indices one at a time, in stored order.

    No shuffling is applied, so every pass over the sampler yields the same
    sequence.

    Arguments:
        indices (sequence): the dataset indices to emit, in order.
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        total = len(self.indices)
        return (self.indices[pos] for pos in range(total))


# Per-dataset settings used by `build`:
#   (size of the training index pool the initial split is drawn from,
#    dataset mode passed when constructing the 'unlabeled' loader's dataset,
#    whether the 'val' loader uses `num_workers` instead of DataLoader's default 0).
_BUILD_SPECS = {
    'cifar10': (50000, 'unlabeled', False),
    'cifar100': (50000, 'unlabeled', False),
    'imb_cifar10': (27500, 'unlabeled', False),
    'tiny_imagenet': (100000, 'train', False),
    'mnist': (60000, 'unlabeled', True),
    'imb_mnist': (32462, 'unlabeled', False),
    'fashionmnist': (60000, 'unlabeled', False),
    'svhn': (73257, 'unlabeled', True),
}


def _initial_split(num_samples, begin_index, num_subset):
    """Randomly split ``range(num_samples)`` for the first active-learning cycle.

    Returns ``(labeled_set, unlabeled_set, subset)``:
    ``labeled_set`` is the first ``begin_index`` shuffled indices,
    ``unlabeled_set`` is the rest (shuffled a second time), and ``subset``
    is the first ``num_subset`` entries of ``unlabeled_set``.
    """
    indices = list(range(num_samples))
    random.shuffle(indices)
    labeled_set = indices[:begin_index]
    unlabeled_set = indices[begin_index:]
    random.shuffle(unlabeled_set)
    subset = unlabeled_set[:num_subset]
    return labeled_set, unlabeled_set, subset


def build(data_config, logger):
    """Create the initial labeled/unlabeled split and the data loaders.

    Args:
        data_config: dict with keys 'name', 'root', 'batch_size',
            'num_workers', 'transform' and 'al_params', plus an optional
            'test_batch_size' (defaults to 'batch_size'). 'al_params' needs
            'add_num' and 'num_subset'; an optional 'init_num' overrides
            'add_num' as the initial labeled-set size.
        logger: passed through to the dataset classes and transforms; also
            receives an error message for an unknown dataset name.

    Returns:
        dict holding the index lists 'labeled_set', 'unlabeled_set' and
        'subset', and the DataLoaders 'train' (random order over
        labeled_set), 'unlabeled' (batch size 1, sequential over subset) and
        'val'.

    Raises:
        ValueError: if ``data_config['name']`` is not a supported dataset.
    """
    data_name = data_config['name']
    root = data_config['root']
    batch_size = data_config['batch_size']
    test_batch_size = data_config.get('test_batch_size', batch_size)
    num_workers = data_config['num_workers']
    transform_config = data_config['transform']
    al_params = data_config['al_params']

    # Fail fast. An unknown name previously fell through to an
    # UnboundLocalError further down.
    if data_name not in _BUILD_SPECS:
        msg = 'Specify valid data name, one of {}; got {!r}'.format(
            sorted(_BUILD_SPECS), data_name)
        logger.error(msg)
        raise ValueError(msg)
    num_samples, unlabeled_mode, val_uses_workers = _BUILD_SPECS[data_name]

    begin_index = al_params.get('init_num', al_params['add_num'])
    labeled_set, unlabeled_set, subset = _initial_split(
        num_samples, begin_index, al_params['num_subset'])

    dataloaders = {
        'subset': subset,
        'labeled_set': labeled_set,
        'unlabeled_set': unlabeled_set,
    }

    for mode in MODES:
        transform = compose_transforms(data_name, transform_config, mode, logger)
        dataset_mode = unlabeled_mode if mode == 'unlabeled' else mode
        dataset = DATASETS[data_name](
            root, dataset_mode, download=True, logger=logger, transform=transform)
        if mode == 'train':
            dataloader = DataLoader(dataset, batch_size=batch_size,
                                    sampler=SubsetRandomSampler(labeled_set),
                                    num_workers=num_workers)
        elif mode == 'unlabeled':
            dataloader = DataLoader(dataset, batch_size=1,
                                    sampler=SubsetSequentialSampler(subset),
                                    num_workers=num_workers)
        elif mode == 'val':
            if data_name == 'mnist':
                # NOTE: the result of this shuffle is unused. It is kept only so
                # that mnist runs consume the global `random` state exactly as
                # before, which keeps seeded experiments reproducible.
                random.shuffle(list(range(len(dataset))))
            dataloader = DataLoader(
                dataset, batch_size=test_batch_size,
                num_workers=num_workers if val_uses_workers else 0)
        else:
            raise ValueError('Unsupported loader mode {!r}'.format(mode))
        dataloaders[mode] = dataloader

    # Invariants of the split: labeled points are never candidates, and the
    # candidate subset does not reappear in the rest of the unlabeled pool.
    assert len(set(labeled_set) & set(subset)) == 0
    assert len(set(subset) & set(unlabeled_set[len(subset):])) == 0

    # Report the class histogram of the initial labeled set, ordered by label.
    _, y_train = get_full_data(dataloaders['train'].dataset, labeled_set)
    count_train = Counter(y_train.numpy())
    ordered_count = sorted(count_train.items(), key=lambda i: i[0])
    print('y train: {}'.format(str(ordered_count)), flush=True)

    return dataloaders


def update(cycle, dataloaders, arg, data_config, model_config, writer, save_dir):
    """Add the top-ranked candidates to the labeled set and rebuild the loaders.

    ``subset`` (the current candidate pool) is reordered by ``arg``. The last
    ``add_num`` entries of that ordering join the labeled set, and the first
    ``add_num`` entries are saved as the "worst" queries. The remaining
    candidates go back into the unlabeled pool, which is reshuffled to draw a
    new candidate subset of size ``num_subset``.

    Args:
        cycle: active-learning cycle number, used in the saved file names.
        dataloaders: dict as produced by ``build``; updated in place and returned.
        arg: index order applied to ``subset``. Presumably an ascending argsort
            of acquisition scores; TODO confirm against the caller.
        data_config: dict with 'batch_size', 'num_workers' and 'al_params'
            ('add_num', 'num_subset').
        model_config: unused; kept for interface compatibility.
        writer: unused; kept for interface compatibility.
        save_dir: directory under which ``query_info/*.npy`` files are written
            (``train_cycle_0`` on cycle 0, then ``train_cycle_{cycle+1}`` and
            ``worst_cycle_{cycle+1}``).

    Returns:
        The same ``dataloaders`` dict with new 'labeled_set', 'unlabeled_set',
        'subset', 'train' and 'unlabeled' entries.

    Raises:
        AssertionError: if the new labeled set and the new subset overlap.
    """
    al_params = data_config['al_params']
    add_num, num_subset = al_params['add_num'], al_params['num_subset']
    batch_size, num_workers = data_config['batch_size'], data_config['num_workers']
    labeled_set = dataloaders['labeled_set']
    unlabeled_set = dataloaders['unlabeled_set']
    subset = dataloaders['subset']

    ranked = torch.tensor(subset)[arg]
    # Use an explicit, clamped split point. With add_num == 0, `ranked[-add_num:]`
    # would select the whole subset and `ranked[:-add_num]` would select nothing.
    split = max(len(ranked) - add_num, 0)
    queried_labeled_set = list(ranked[split:].numpy())
    worst_queries = list(ranked[:add_num].numpy())
    new_labeled_set = list(labeled_set) + queried_labeled_set
    # `subset` is the first `num_subset` entries of `unlabeled_set`, so the
    # rest of the pool starts at index `num_subset`.
    new_unlabeled_set = list(ranked[:split].numpy()) + list(unlabeled_set[num_subset:])
    random.shuffle(new_unlabeled_set)
    new_subset = new_unlabeled_set[:num_subset]

    query_dir = Path(save_dir).joinpath('query_info')
    query_dir.mkdir(parents=True, exist_ok=True)
    if cycle == 0:
        np.save(str(query_dir.joinpath('train_cycle_0').absolute()), labeled_set)
    np.save(str(query_dir.joinpath('train_cycle_{}'.format(cycle + 1)).absolute()),
            new_labeled_set)
    np.save(str(query_dir.joinpath('worst_cycle_{}'.format(cycle + 1)).absolute()),
            worst_queries)

    dataloaders['labeled_set'] = new_labeled_set
    dataloaders['unlabeled_set'] = new_unlabeled_set
    dataloaders['subset'] = new_subset

    # Fail loudly instead of opening an interactive debugger, which would hang
    # unattended runs.
    overlap = set(new_labeled_set) & set(new_subset)
    if overlap:
        raise AssertionError(
            'Problematic new_labeled_set and new_subset, they have overlapping '
            'points! ({} shared indices)'.format(len(overlap)))

    # Rebuild the train and unlabeled loaders over the new index sets.
    train_dataset = dataloaders['train'].dataset
    unlabeled_dataset = dataloaders['unlabeled'].dataset
    dataloaders['train'] = DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=SubsetRandomSampler(new_labeled_set), num_workers=num_workers)
    dataloaders['unlabeled'] = DataLoader(
        unlabeled_dataset, batch_size=1,
        sampler=SubsetSequentialSampler(new_subset), num_workers=num_workers)

    return dataloaders


def override(dataloaders, arg, data_config, add_num):
    """Replace the labeled set with the top-ranked candidates and rebuild loaders.

    The previous labeled set is discarded: the new labeled set is exactly the
    last ``add_num`` entries of ``subset`` after reordering by ``arg``. The
    remaining candidates rejoin the unlabeled pool, which is reshuffled to draw
    a new candidate subset of size ``num_subset``. Nothing is written to disk.

    Args:
        dataloaders: dict as produced by ``build``; updated in place and returned.
        arg: index order applied to ``subset``. Presumably an ascending argsort
            of acquisition scores; TODO confirm against the caller.
        data_config: dict with 'batch_size', 'num_workers' and 'al_params'
            ('num_subset').
        add_num: number of top-ranked candidates that form the new labeled set.

    Returns:
        The same ``dataloaders`` dict with new 'labeled_set', 'unlabeled_set',
        'subset', 'train' and 'unlabeled' entries.
    """
    num_subset = data_config['al_params']['num_subset']
    batch_size, num_workers = data_config['batch_size'], data_config['num_workers']
    unlabeled_set = dataloaders['unlabeled_set']
    subset = dataloaders['subset']

    ranked = torch.tensor(subset)[arg]
    # Use an explicit, clamped split point. With add_num == 0, `ranked[-add_num:]`
    # would select the whole subset and `ranked[:-add_num]` would select nothing.
    split = max(len(ranked) - add_num, 0)
    new_labeled_set = list(ranked[split:].numpy())
    # `subset` is the first `num_subset` entries of `unlabeled_set`, so the
    # rest of the pool starts at index `num_subset`.
    new_unlabeled_set = list(ranked[:split].numpy()) + list(unlabeled_set[num_subset:])
    random.shuffle(new_unlabeled_set)
    new_subset = new_unlabeled_set[:num_subset]

    dataloaders['labeled_set'] = new_labeled_set
    dataloaders['unlabeled_set'] = new_unlabeled_set
    dataloaders['subset'] = new_subset

    # Rebuild the train and unlabeled loaders over the new index sets.
    dataloaders['train'] = DataLoader(
        dataloaders['train'].dataset, batch_size=batch_size,
        sampler=SubsetRandomSampler(new_labeled_set), num_workers=num_workers)
    dataloaders['unlabeled'] = DataLoader(
        dataloaders['unlabeled'].dataset, batch_size=1,
        sampler=SubsetSequentialSampler(new_subset), num_workers=num_workers)

    return dataloaders


def normalization_params(data_name, logger):
    """Return the per-channel ``(mean, std)`` later fed to ``transforms.Normalize``.

    Matching is by substring for two families:
    - 'cifar10' and 'svhn' variants (including 'cifar100' and 'imb_cifar10')
      share one set of 3-channel statistics.
    - Every name containing 'mnist' (including 'imb_mnist' and 'fashionmnist')
      uses mean 0 / std 1, i.e. no rescaling.

    Args:
        data_name: dataset name from the config.
        logger: receives an error message for an unsupported name.

    Returns:
        ``(mean, std)`` tuple.

    Raises:
        ValueError: if ``data_name`` matches none of the supported datasets.
    """
    if 'cifar10' in data_name or 'svhn' in data_name:
        return [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    if 'mnist' in data_name:  # also covers 'imb_mnist' and 'fashionmnist'
        return (0.0,), (1.0,)
    if data_name == 'tiny_imagenet':
        return [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    # An unknown name previously fell through to an UnboundLocalError.
    msg = 'Specify valid data name, got {!r}'.format(data_name)
    logger.error(msg)
    raise ValueError(msg)


def compose_transforms(data_name, transform_config, mode, logger):
    """Build the torchvision transform pipeline for one loader mode.

    Args:
        data_name: dataset name. Must belong to a supported family (any name
            containing 'cifar10' or 'mnist', or exactly 'svhn' or
            'tiny_imagenet').
        transform_config: dict with 'crop_size' and the optional boolean flags
            'augment' and 'flatten' (both default to False).
        mode: loader mode. Augmentation (ToPILImage, random horizontal flip,
            random crop of ``crop_size`` with padding 4) is prepended only when
            'augment' is set and mode is 'train' or 'unlabeled'.
        logger: passed to ``normalization_params``; also receives an error
            message for an unsupported name.

    Returns:
        ``transforms.Compose`` ending with ToTensor + Normalize, followed by a
        flatten step when 'flatten' is set.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset.
    """
    mean, std = normalization_params(data_name, logger)
    crop_size = transform_config['crop_size']
    augment = transform_config.get('augment', False)
    flatten = transform_config.get('flatten', False)

    # 'fashionmnist' and 'imb_mnist' are covered by the 'mnist' substring test.
    supported = ('cifar10' in data_name or 'mnist' in data_name
                 or data_name in ('svhn', 'tiny_imagenet'))
    if not supported:
        # An unknown name previously fell through to an UnboundLocalError on t_list.
        msg = 'Specify valid data name, got {!r}'.format(data_name)
        logger.error(msg)
        raise ValueError(msg)

    t_list = []
    if augment and mode in ('train', 'unlabeled'):
        t_list += [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=crop_size, padding=4),
        ]
    t_list += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    if flatten:
        t_list.append(transforms.Lambda(torch.flatten))
    return transforms.Compose(t_list)


