import os
import gc
import json
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from multiprocessing import pool

from .dataset_leaf import TinyImageNetDataset, FEMNISTDataset, ShakespeareDataset, LabelNoiseDataset

# construct a client dataset (training & test set)
def get_dataset(args):
    """
    Retrieve requested datasets.

    Args:
        args (argparser): arguments; this function reads `dataset`, `data_path`, `test_fraction` and `n_jobs`,
            plus `label_noise` for the torchvision datasets and `split_type` for the LEAF datasets

    Returns:
        metadata: {0: [indices_1], 1: [indices_2], ... , K: [indices_K]}
            (for the LEAF datasets this is whatever `LEAFParser.get_datasets()` reports as its split map)
        global_testset: (optional) test set located at the central server; `None` for the LEAF datasets
        client datasets: [pair(local_training_set[indices_1], local_test_set[indices_1]), pair(local_training_set[indices_2], local_test_set[indices_2]), ...]

    Raises:
        NotImplementedError: if `args.dataset` is not one of the supported dataset names
        AssertionError: if a LEAF dataset is requested with a split type other than `realistic`
    """
    def construct_dataset(indices):
        # restrict the raw training set to one client's indices, then hold out
        # `test_fraction` of those samples as that client's local test set
        subset = torch.utils.data.Subset(raw_train, indices)
        test_size = int(len(subset) * args.test_fraction)
        return torch.utils.data.random_split(subset, [len(subset) - test_size, test_size])

    def make_basic_transform():
        # MNIST variants are only converted to tensors; every other dataset handled here is resized to 28 first
        if 'MNIST' in args.dataset:
            return torchvision.transforms.ToTensor()
        return torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(28),
                torchvision.transforms.ToTensor()
            ]
        )

    def build_client_datasets(split_map):
        # one (train, test) pair per entry of the split map, built by a pool of threads
        with pool.ThreadPool(processes=args.n_jobs) as workhorse:
            return workhorse.map(
                construct_dataset,
                tqdm(split_map.values(), desc=f'[INFO] ...create datasets [{args.dataset}]!')
            )

    if args.dataset in ['MNIST', 'CIFAR10', 'CIFAR100']:
        # call raw datasets
        dataset_class = torchvision.datasets.__dict__[args.dataset]
        raw_train = dataset_class(
            root=args.data_path,
            train=True,
            transform=make_basic_transform(),
            download=True
        )
        raw_test = dataset_class(
            root=args.data_path,
            train=False,
            transform=make_basic_transform(),
            download=True
        )
        if args.label_noise:
            # wrap the training set only; the central test set is left untouched
            raw_train = LabelNoiseDataset(
                args,
                dataset=raw_train,
                transform=make_basic_transform()
            )

        # get split indices
        split_map = split_data(args, raw_train)

        # construct client datasets
        client_datasets = build_client_datasets(split_map)
        return split_map, raw_test, client_datasets

    elif args.dataset == 'TinyImageNet':
        # call raw dataset (augmentation is applied to the training images only)
        raw_train = torchvision.datasets.ImageFolder(
            os.path.join(args.data_path, 'tiny-imagenet-200', 'train'),
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomRotation(20),
                    torchvision.transforms.RandomHorizontalFlip(0.5),
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ]
            )
        )
        raw_test = TinyImageNetDataset(
            img_path=os.path.join(args.data_path, 'tiny-imagenet-200', 'val', 'images'),
            gt_path=os.path.join(args.data_path, 'tiny-imagenet-200', 'val', 'val_annotations.txt'),
            class_to_idx=raw_train.class_to_idx.copy(),
            transform=torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ]
            )
        )

        # get split indices
        split_map = split_data(args, raw_train)

        # construct client datasets
        client_datasets = build_client_datasets(split_map)
        return split_map, raw_test, client_datasets

    elif args.dataset in ['FEMNIST', 'Shakespeare']:
        assert args.split_type == 'realistic', '[ERROR] LEAF benchmark dataset is only supported for `realistic` split scenario!'

        # parse dataset
        parser = LEAFParser(args)

        # construct client datasets
        split_map, client_datasets = parser.get_datasets()
        return split_map, None, client_datasets

    # fail loudly instead of returning `None`, which callers would only notice when unpacking the result
    raise NotImplementedError(f'[ERROR] {args.dataset} is not supported yet!')


def split_data(args, raw_train):
    """Split data indices across clients according to `args.split_type`.

    Args:
        args (argparser): arguments; reads `split_type` and `K`, plus `dataset` and `shard_size` for the
            `pathological` split and `num_classes`, `alpha` and `global_seed` for the `dirichlet` split
        raw_train (dataset): raw dataset object to parse; must support `len()` and, for the label-based
            splits, expose a `targets` attribute

    Returns:
        split_map (dict): dictionary whose key is a client index (~args.K) and whose value holds the sample
            indices assigned to that client (a NumPy array for `iid`, a list otherwise). For `dirichlet`,
            clients that end up with 10 samples or fewer are left out of the dictionary. For `realistic`,
            a notice is printed and `None` is returned.

    Raises:
        AssertionError: if the `pathological` split is requested with an unsupported dataset or with a
            shard size that does not give every client exactly two shards
        ValueError: if `args.split_type` is not one of the supported split scenarios
    """
    # IID split (i.e., statistical homogeneity)
    if args.split_type == 'iid':
        # randomly shuffle sample indices
        shuffled_indices = np.random.permutation(len(raw_train))

        # split shuffled indices by the number of clients
        split_indices = np.array_split(shuffled_indices, args.K)

        # construct a hashmap
        return {i: split_indices[i] for i in range(args.K)}

    # Non-IID split proposed in McMahan et al., 2016 (i.e., each client has samples from at least two different classes)
    elif args.split_type == 'pathological':
        assert args.dataset in ['MNIST',
                                'CIFAR10'], '[ERROR] `pathological non-IID setting` is supported only for `MNIST` or `CIFAR10` dataset!'
        assert len(
            raw_train) / args.shard_size / args.K == 2, '[ERROR] each client should have samples from class at least 2 different classes!'

        # sort data by labels, then cut the sorted indices into shards
        sorted_indices = np.argsort(np.array(raw_train.targets))
        shard_indices = np.array_split(sorted_indices, len(raw_train) // args.shard_size)

        # one growing index list per client
        split_indices = [[] for _ in range(args.K)]

        # deal the label-sorted shards to the clients in round-robin order
        for idx, shard in enumerate(shard_indices):
            split_indices[idx % args.K].extend(shard)

        # construct a hashmap
        return {i: split_indices[i] for i in range(args.K)}

    # Non-IID split proposed in Hsu et al., 2019 (i.e., using Dirichlet distribution to simulate non-IID split)
    # https://github.com/QinbinLi/FedKT/blob/0bb9a89ea266c057990a4a326b586ed3d2fb2df8/experiments.py
    elif args.split_type == 'dirichlet':
        split_map = dict()

        # container
        client_indices_list = [[] for _ in range(args.K)]

        # the label array does not change between classes, so build it once
        targets = np.array(raw_train.targets)

        # iterate through all classes
        for c in range(args.num_classes):
            # get corresponding class indices
            target_class_indices = np.where(targets == c)[0]

            # shuffle class indices
            np.random.shuffle(target_class_indices)

            # get label retrieval probability per each client based on a Dirichlet distribution
            proportions = np.random.dirichlet(np.repeat(args.alpha, args.K))

            # zero out clients that already hold their even share (len(raw_train) / K) of samples
            proportions = np.array(
                [p * (len(idx) < len(raw_train) / args.K) for p, idx in zip(proportions, client_indices_list)])

            # normalize, then turn cumulative proportions into cut points inside this class
            proportions = proportions / proportions.sum()
            proportions = (np.cumsum(proportions) * len(target_class_indices)).astype(int)[:-1]

            # split class indices by proportions
            idx_split = np.array_split(target_class_indices, proportions)
            client_indices_list = [j + idx.tolist() for j, idx in zip(client_indices_list, idx_split)]

        # shuffle finally and create a hashmap
        for j in range(args.K):
            # NOTE: the global NumPy RNG is re-seeded before every client's shuffle, so each
            # shuffle starts from the same RNG state and the global state is reset as a side effect
            np.random.seed(args.global_seed)
            np.random.shuffle(client_indices_list[j])

            # clients with 10 samples or fewer are dropped from the split
            if len(client_indices_list[j]) > 10:
                split_map[j] = client_indices_list[j]
        return split_map

    # LEAF benchmark dataset
    elif args.split_type == 'realistic':
        # `print` returns None, so callers receive None for this scenario
        return print('[INFO] No need to split... use LEAF parser directly!')

    # an unknown scenario used to fall through and silently yield None
    raise ValueError(f'[ERROR] unsupported split type: {args.split_type}!')


class LEAFParser:
    """Parse a LEAF benchmark dataset stored as JSON and build one dataset pair per user.

    On construction, every file under `<data_path>/<dataset>/train` and `<data_path>/<dataset>/test`
    is read, the per-file records are merged, and a (training set, test set) pair is created for each
    user listed in the training data. The result is retrieved with `get_datasets()`.
    """

    def __init__(self, args):
        """
        Args:
            args (argparser): arguments; reads `data_path`, `n_jobs` and `dataset`

        Raises:
            NotImplementedError: if the lower-cased dataset name contains neither `femnist` nor `shakespeare`
        """
        self.root = args.data_path
        self.n_jobs = args.n_jobs
        self.dataset_name = args.dataset.lower()

        # declare appropriate dataset class
        if 'femnist' in self.dataset_name:
            self.dataset_class = FEMNISTDataset
        elif 'shakespeare' in self.dataset_name:
            self.dataset_class = ShakespeareDataset
        else:
            raise NotImplementedError(f'[ERROR] {self.dataset_name} is not supported yet!')

        # set path (`dataset_name` is already lower-cased)
        self.train_root = f'{self.root}/{self.dataset_name}/train'
        self.test_root = f'{self.root}/{self.dataset_name}/test'

        # get raw data
        self.raw_train = self._parse_data(self.train_root, 'train')
        self.raw_test = self._parse_data(self.test_root, 'test')

        # merge raw data
        self.merged_train = self._merge_raw_data(self.raw_train, 'train')
        self.merged_test = self._merge_raw_data(self.raw_test, 'test')

        # the per-file records are no longer needed once merged
        del self.raw_train, self.raw_test
        gc.collect()

        # make dataset for each client
        self.split_map, self.datasets = self._convert_to_dataset(self.merged_train, self.merged_test)
        del self.merged_train, self.merged_test
        gc.collect()

    def _parse_data(self, root, mode):
        """Load every JSON document found in the files directly under `root`.

        Args:
            root (str): directory holding the raw files
            mode (str): label shown in the progress bar (e.g., `train` or `test`)

        Returns:
            list: one decoded JSON object per line of every file
        """
        raw_all = []

        # `os.listdir` returns entries in arbitrary order; sort so that the resulting
        # client order is the same on every run and every machine
        for file in tqdm(sorted(os.listdir(root)),
                         desc=f'[INFO] ...parsing {mode} data (LEAF - {self.dataset_name.upper()})'):
            with open(f'{root}/{file}') as raw_files:
                for raw_file in raw_files:
                    raw_all.append(json.loads(raw_file))
        return raw_all

    def _merge_raw_data(self, data, mode):
        """Merge per-file records into a single record.

        Args:
            data (list): decoded records, each with `users`, `num_samples` and `user_data` entries
            mode (str): label shown in the progress bar (e.g., `train` or `test`)

        Returns:
            dict: `users` and `num_samples` concatenated in input order, and `user_data` merged into one
                dictionary (a later record overrides an earlier one for the same user)
        """
        merged_raw_data = {'users': list(), 'num_samples': list(), 'user_data': dict()}
        for raw_data in tqdm(data, desc=f'[INFO] ...merging raw {mode} data (LEAF - {self.dataset_name.upper()})'):
            merged_raw_data['users'].extend(raw_data['users'])
            merged_raw_data['num_samples'].extend(raw_data['num_samples'])

            # update in place rather than rebuilding the whole dictionary for every record
            merged_raw_data['user_data'].update(raw_data['user_data'])
        return merged_raw_data

    @staticmethod
    def _samples_per_user(merged):
        """Map each user identifier of a merged record to its sample count."""
        return dict(zip(merged['users'], merged['num_samples']))

    @classmethod
    def _build_split_map(cls, merged_train, merged_test):
        """Map each client index to its total (training + test) sample count.

        Test counts are looked up by user identifier, so the result does not depend on the training
        and test records listing their users in the same order.
        """
        test_samples = cls._samples_per_user(merged_test)
        totals = [
            n_train + test_samples[user]
            for user, n_train in zip(merged_train['users'], merged_train['num_samples'])
        ]
        return dict(zip(range(len(merged_train['user_data'])), totals))

    def _convert_to_dataset(self, merged_train, merged_test):
        """
        Build the dataset pair of every user listed in the merged training record.

        Args:
            merged_train (dict): merged training record (see `_merge_raw_data`)
            merged_test (dict): merged test record (see `_merge_raw_data`)

        Returns:
            split_map: {client index: number of training samples + number of test samples}
            datasets: [tuple(local_training_set[indices_1], local_test_set[indices_1]), tuple(local_training_set[indices_2], local_test_set[indices_2]), ...]

        Raises:
            KeyError: if a training user has no entry in the test record
        """
        # test-side counts keyed by user; the positional index of the training list
        # is not guaranteed to line up with the test list
        test_samples = self._samples_per_user(merged_test)

        def construct_leaf(idx, user):
            # instantiate a dataset object for each of the training set and the test set
            tr_dset, te_dset = self.dataset_class(), self.dataset_class()
            tr_dset.train = True
            te_dset.train = False

            # set essential attributes
            tr_dset.identifier = user
            te_dset.identifier = user
            tr_dset.data = merged_train['user_data'][user]
            te_dset.data = merged_test['user_data'][user]
            tr_dset.num_samples = merged_train['num_samples'][idx]
            te_dset.num_samples = test_samples[user]
            tr_dset._make_dataset()
            te_dset._make_dataset()
            return (tr_dset, te_dset)

        with pool.ThreadPool(processes=self.n_jobs) as workhorse:
            jobs = list(tqdm(enumerate(merged_train['users']),
                             desc=f'[INFO] ...create datasets [LEAF - {self.dataset_name.upper()}]!'))
            datasets = workhorse.starmap(construct_leaf, jobs)

        split_map = self._build_split_map(merged_train, merged_test)
        return split_map, datasets

    def get_datasets(self):
        """Return the split map and the per-client dataset pairs built at construction time."""
        assert self.datasets is not None, '[ERROR] dataset is not constructed internally!'
        return self.split_map, self.datasets
