import logging
import math
import os
import pickle
from random import random, shuffle

import numpy as np
from collections import defaultdict

from federatedscope.core.auxiliaries.utils import setup_seed

import federatedscope.register as register

logger = logging.getLogger(__name__)


def load_toy_data(config=None):
    """Build (or load) the toy linear-regression dataset.

    In ``standalone`` mode (``config.federate.mode``) the data for the server
    and all clients are generated on the fly; otherwise the data of a single
    participant are un-pickled from ``config.distribute.data_file``.

    Args:
        config: configuration object; the fields read here are
            ``federate.mode``, ``federate.client_num``, ``eval.save_data``
            and ``distribute.data_file``.

    Returns:
        A tuple ``(data, config)``; ``config`` is returned unchanged.
    """

    generate = config.federate.mode.lower() == 'standalone'

    def _generate_data(client_num=5,
                       instance_num=1000,
                       feature_num=5,
                       save_data=False):
        """
        Generate data in FedRunner format
        Args:
            client_num: number of clients (ids ``1..client_num``)
            instance_num: number of instances per split
            feature_num: dimension of the feature vector
            save_data: whether to pickle the per-participant data into
                the ``data/`` directory

        Returns:
            {
                '{client_id}': {
                    'train': {
                        'x': ...,
                        'y': ...
                    },
                    'test': {
                        'x': ...,
                        'y': ...
                    },
                    'val': {
                        'x': ...,
                        'y': ...
                    }
                }
            }
            Key ``0`` holds the server data, whose 'train' entry is None.

        """
        # NOTE: the order of the np.random calls below is kept fixed so that
        # a given seed keeps producing the same data.
        weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
        bias = np.random.normal(loc=0.0, scale=1.0)

        def _sample(scale):
            # Draw features and compute the noiseless linear target.
            x = np.random.normal(loc=0.0,
                                 scale=scale,
                                 size=(instance_num, feature_num))
            y = np.expand_dims(np.sum(x * weights, axis=-1) + bias, -1)
            return {'x': x, 'y': y}

        client_ids = range(1, client_num + 1)
        data = dict()
        # train data: the feature scale grows with the client id
        for each_client in client_ids:
            data[each_client] = {'train': _sample(0.5 * each_client)}

        # test data (the same object is shared by all participants)
        test_data = _sample(1.0)
        for each_client in client_ids:
            data[each_client]['test'] = test_data

        # val data (the same object is shared by all participants)
        val_data = _sample(1.0)
        for each_client in client_ids:
            data[each_client]['val'] = val_data

        # server_data
        data[0] = dict()
        data[0]['train'] = None
        data[0]['val'] = val_data
        data[0]['test'] = test_data

        if save_data:
            # Make sure the target directory exists before writing.
            os.makedirs('data', exist_ok=True)

            for client_idx in range(0, client_num + 1):
                if client_idx == 0:
                    filename = 'data/server_data'
                else:
                    filename = 'data/client_{:d}_data'.format(client_idx)
                # A split may be None (the server has no training data);
                # store it as None, which the loading branch below handles.
                save_client_data = {
                    split: ({
                        k: v.tolist()
                        for k, v in data[client_idx][split].items()
                    } if data[client_idx][split] is not None else None)
                    for split in ('train', 'val', 'test')
                }
                with open(filename, 'wb') as f:
                    pickle.dump(save_client_data, f)

        return data

    if generate:
        data = _generate_data(client_num=config.federate.client_num,
                              save_data=config.eval.save_data)
    else:
        # SECURITY NOTE: pickle.load executes arbitrary code contained in the
        # file; only point `distribute.data_file` to trusted files.
        with open(config.distribute.data_file, 'rb') as f:
            data = pickle.load(f)
        for key in data.keys():
            data[key] = {k: np.asarray(v)
                         for k, v in data[key].items()
                         } if data[key] is not None else None

    return data, config


def load_external_data(config=None):
    r""" Based on the configuration file, this function imports external datasets and applies train/valid/test splits
    and split by some specific `splitter` into the standard FederatedScope input data format.

    The dataset is selected by ``config.data.type``, written as
    ``'<name>@<package>'``; ``<package>`` picks one of the loaders defined
    below (torchvision, torchtext, torchaudio, torch_geometric,
    huggingface_datasets, openml).

    Args:
        config: `CN` from `federatedscope/core/configs/config.py`

    Returns:
        data_local_dict: dict of split dataloader.
                        Format:
                            {
                                'client_id': {
                                    'train': DataLoader(),
                                    'test': DataLoader(),
                                    'val': DataLoader()
                                }
                            }
        modified_config: `CN` from `federatedscope/core/configs/config.py`, which might be modified in the function.

    """

    import torch
    import inspect
    from importlib import import_module
    from torch.utils.data import DataLoader
    from federatedscope.core.auxiliaries.splitter_builder import get_splitter
    from federatedscope.core.auxiliaries.transform_builder import get_transform

    def get_func_args(func):
        # Names of all parameters accepted by `func`.
        return {
            val.name
            for val in inspect.signature(func).parameters.values()
        }

    def filter_dict(func, kwarg):
        # Keep only the entries of `kwarg` that `func` accepts by name.
        sign = get_func_args(func)
        common_args = sign.intersection(kwarg.keys())
        filtered_dict = {key: kwarg[key] for key in common_args}
        return filtered_dict

    def load_torchvision_data(name, splits=None, config=None):
        # Load `torchvision.datasets.<name>`; the way the raw data are split
        # depends on which keyword the dataset constructor exposes.
        dataset_func = getattr(import_module('torchvision.datasets'), name)
        transform_funcs = get_transform(config, 'torchvision')
        if config.data.args:
            raw_args = config.data.args[0]
        else:
            raw_args = {}
        if 'download' not in raw_args.keys():
            raw_args.update({'download': True})
        filtered_args = filter_dict(dataset_func.__init__, raw_args)
        func_args = get_func_args(dataset_func.__init__)

        # Perform split on different dataset
        if 'train' in func_args:
            # Split train to (train, val)
            dataset_train = dataset_func(root=config.data.root,
                                         train=True,
                                         **filtered_args,
                                         **transform_funcs)
            dataset_val = None
            dataset_test = dataset_func(root=config.data.root,
                                        train=False,
                                        **filtered_args,
                                        **transform_funcs)
            if splits:
                train_size = int(splits[0] * len(dataset_train))
                val_size = len(dataset_train) - train_size
                lengths = [train_size, val_size]
                dataset_train, dataset_val = torch.utils.data.dataset.random_split(
                    dataset_train, lengths)

        elif 'split' in func_args:
            # Use raw split
            dataset_train = dataset_func(root=config.data.root,
                                         split='train',
                                         **filtered_args,
                                         **transform_funcs)
            dataset_val = dataset_func(root=config.data.root,
                                       split='valid',
                                       **filtered_args,
                                       **transform_funcs)
            dataset_test = dataset_func(root=config.data.root,
                                        split='test',
                                        **filtered_args,
                                        **transform_funcs)
        elif 'classes' in func_args:
            # Use raw split
            dataset_train = dataset_func(root=config.data.root,
                                         classes='train',
                                         **filtered_args,
                                         **transform_funcs)
            dataset_val = dataset_func(root=config.data.root,
                                       classes='valid',
                                       **filtered_args,
                                       **transform_funcs)
            dataset_test = dataset_func(root=config.data.root,
                                        classes='test',
                                        **filtered_args,
                                        **transform_funcs)
        else:
            # Use config.data.splits
            dataset = dataset_func(root=config.data.root,
                                   **filtered_args,
                                   **transform_funcs)
            train_size = int(splits[0] * len(dataset))
            val_size = int(splits[1] * len(dataset))
            test_size = len(dataset) - train_size - val_size
            lengths = [train_size, val_size, test_size]
            dataset_train, dataset_val, dataset_test = torch.utils.data.dataset.random_split(
                dataset, lengths)

        data_dict = {
            'train': dataset_train,
            'val': dataset_val,
            'test': dataset_test
        }

        return data_dict

    def load_torchtext_data(name, splits=None, config=None):
        # Load `torchtext.datasets.<name>`, tokenize/embed all raw splits at
        # once and cut the result back into train/val/test.
        from torch.nn.utils.rnn import pad_sequence
        from federatedscope.nlp.dataset.utils import label_to_index

        dataset_func = getattr(import_module('torchtext.datasets'), name)
        if config.data.args:
            raw_args = config.data.args[0]
        else:
            raw_args = {}
        assert 'max_len' in raw_args, "Miss key 'max_len' in `config.data.args`."
        filtered_args = filter_dict(dataset_func.__init__, raw_args)
        dataset = dataset_func(root=config.data.root, **filtered_args)

        # torchtext.transforms requires >= 0.12.0 and torch = 1.11.0,
        # so we do not use `get_transform` in torchtext.

        # Merge all data and tokenize
        x_list = []
        y_list = []
        for data_iter in dataset:
            data, targets = [], []
            # Each item is indexed as (target, text).
            for item in data_iter:
                data.append(item[1])
                targets.append(item[0])
            x_list.append(data)
            y_list.append(targets)

        x_all, y_all = [], []
        for split_x, split_y in zip(x_list, y_list):
            x_all += split_x
            y_all += split_y

        if config.model.type.endswith('transformers'):
            from transformers import AutoTokenizer

            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    config.model.type.split('@')[0],
                    local_files_only=False,
                    cache_dir=os.path.join(os.getcwd(), "huggingface"))
            except Exception:
                # Without a tokenizer nothing below can work, so report the
                # failure and propagate it instead of failing later with an
                # unrelated "tokenizer is not defined" error.
                logger.error(
                    f"Failed to load the tokenizer for model type {config.model.type}"
                )
                raise

            x_all = tokenizer(x_all,
                              return_tensors='pt',
                              padding=True,
                              truncation=True,
                              max_length=raw_args['max_len'])
            data = [{key: value[i]
                     for key, value in x_all.items()}
                    for i in range(len(next(iter(x_all.values()))))]
            if 'classification' in config.model.task.lower():
                targets = label_to_index(y_all)
            else:
                y_all = tokenizer(y_all,
                                  return_tensors='pt',
                                  padding=True,
                                  truncation=True,
                                  max_length=raw_args['max_len'])
                targets = [{key: value[i]
                            for key, value in y_all.items()}
                           for i in range(len(next(iter(y_all.values()))))]
        else:
            from torchtext.data import get_tokenizer
            tokenizer = get_tokenizer("basic_english")
            if len(config.data.transform) == 0:
                raise ValueError(
                    "`transform` must be one pretrained Word Embeddings from \
                    ['GloVe', 'FastText', 'CharNGram']")
            if len(config.data.transform) == 1:
                config.data.transform.append({})
            vocab = getattr(import_module('torchtext.vocab'),
                            config.data.transform[0])(
                                dim=config.model.in_channels,
                                **config.data.transform[1])

            if 'classification' in config.model.task.lower():
                data = [
                    vocab.get_vecs_by_tokens(tokenizer(x),
                                             lower_case_backup=True)
                    for x in x_all
                ]
                targets = label_to_index(y_all)
            else:
                data = [
                    vocab.get_vecs_by_tokens(tokenizer(x),
                                             lower_case_backup=True)
                    for x in x_all
                ]
                targets = [
                    vocab.get_vecs_by_tokens(tokenizer(y),
                                             lower_case_backup=True)
                    for y in y_all
                ]
                targets = pad_sequence(targets).transpose(
                    0, 1)[:, :raw_args['max_len'], :]
            data = pad_sequence(data).transpose(0,
                                                1)[:, :raw_args['max_len'], :]
        # Split data to raw
        num_items = [len(ds) for ds in x_list]
        data_list, cnt = [], 0
        for num in num_items:
            data_list.append([
                (x, y)
                for x, y in zip(data[cnt:cnt + num], targets[cnt:cnt + num])
            ])
            cnt += num

        if len(data_list) == 3:
            # Use raw splits
            data_dict = {
                'train': data_list[0],
                'val': data_list[1],
                'test': data_list[2]
            }
        elif len(data_list) == 2:
            # Split train to (train, val)
            data_dict = {
                'train': data_list[0],
                'val': None,
                'test': data_list[1]
            }
            if splits:
                train_size = int(splits[0] * len(data_dict['train']))
                val_size = len(data_dict['train']) - train_size
                lengths = [train_size, val_size]
                data_dict['train'], data_dict[
                    'val'] = torch.utils.data.dataset.random_split(
                        data_dict['train'], lengths)
        else:
            # Use config.data.splits
            data_dict = {}
            train_size = int(splits[0] * len(data_list[0]))
            val_size = int(splits[1] * len(data_list[0]))
            test_size = len(data_list[0]) - train_size - val_size
            lengths = [train_size, val_size, test_size]
            data_dict['train'], data_dict['val'], data_dict[
                'test'] = torch.utils.data.dataset.random_split(
                    data_list[0], lengths)

        return data_dict

    def load_torchaudio_data(name, splits=None, config=None):
        # Placeholder: torchaudio datasets are not supported yet.
        import torchaudio

        dataset_func = getattr(import_module('torchaudio.datasets'), name)
        raise NotImplementedError

    def load_torch_geometric_data(name, splits=None, config=None):
        # Placeholder: torch_geometric datasets are not supported yet.
        import torch_geometric

        dataset_func = getattr(import_module('torch_geometric.datasets'), name)
        raise NotImplementedError

    def load_huggingface_datasets_data(name, splits=None, config=None):
        # Load a dataset via huggingface `datasets` (or from disk), tokenize
        # the 'sentence' field of every split and optionally re-arrange the
        # splits according to the flags found in `config.data.args[0]`.
        from datasets import load_dataset
        from datasets import load_from_disk

        if config.data.args:
            raw_args = config.data.args[0]
        else:
            raw_args = {}
        assert 'max_len' in raw_args, "Miss key 'max_len' in `config.data.args`."
        filtered_args = filter_dict(load_dataset, raw_args)
        logger.info("To load huggingface dataset")
        if "hg_cache_dir" in raw_args:
            hugging_face_path = raw_args["hg_cache_dir"]
        else:
            hugging_face_path = os.getcwd()

        if "load_disk_dir" in raw_args:
            dataset = load_from_disk(raw_args["load_disk_dir"])
        else:
            dataset = load_dataset(path=config.data.root,
                                   name=name,
                                   **filtered_args)
        if config.model.type.endswith('transformers'):
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            from transformers import AutoTokenizer
            logger.info("To load huggingface tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(
                config.model.type.split('@')[0],
                local_files_only=False,
                cache_dir=os.path.join(hugging_face_path, "transformers"))
        else:
            # The tokenizer is only created for transformers models; fail
            # with an explicit message rather than an unbound-name error.
            raise ValueError(
                f"Loading huggingface datasets requires a model type ending with 'transformers', but got {config.model.type}"
            )

        for split in dataset:
            x_all = [i['sentence'] for i in dataset[split]]
            targets = [i['label'] for i in dataset[split]]

            if split == "train" and "used_train_ratio" in raw_args and 1 > raw_args[
                    'used_train_ratio'] > 0:
                selected_idx = [i for i in range(len(dataset[split]))]
                shuffle(selected_idx)
                # A set gives O(1) membership tests in the filters below; the
                # selected examples keep their original relative order.
                selected_idx = set(selected_idx[:int(
                    len(selected_idx) * raw_args['used_train_ratio'])])
                x_all = [
                    element for i, element in enumerate(x_all)
                    if i in selected_idx
                ]
                targets = [
                    element for i, element in enumerate(targets)
                    if i in selected_idx
                ]

            x_all = tokenizer(x_all,
                              return_tensors='pt',
                              padding=True,
                              truncation=True,
                              max_length=raw_args['max_len'])
            data = [{key: value[i]
                     for key, value in x_all.items()}
                    for i in range(len(next(iter(x_all.values()))))]
            dataset[split] = (data, targets)
        data_dict = {
            'train': [(x, y)
                      for x, y in zip(dataset['train'][0], dataset['train'][1])
                      ],
            'val': [(x, y) for x, y in zip(dataset['validation'][0],
                                           dataset['validation'][1])],
            # The test split is dropped when all of its labels are -1.
            'test': [
                (x, y) for x, y in zip(dataset['test'][0], dataset['test'][1])
            ] if (set(dataset['test'][1]) - set([-1])) else None,
        }
        original_train_size = len(data_dict["train"])

        if "half_val_dummy_test" in raw_args and raw_args[
                "half_val_dummy_test"]:
            # since the "test" set from GLUE dataset may be masked, we need to submit to get the ground-truth,
            # for fast FL experiments, we split the validation set into two parts with the same size as new test/val
            original_val = [(x, y) for x, y in zip(dataset['validation'][0],
                                                   dataset['validation'][1])]
            data_dict["val"], data_dict[
                "test"] = original_val[:len(original_val) //
                                       2], original_val[len(original_val) //
                                                        2:]
        if "val_as_dummy_test" in raw_args and raw_args["val_as_dummy_test"]:
            # use the validation set as tmp test set, and partial training set as validation set
            data_dict["test"] = data_dict["val"]
            data_dict["val"] = []
        if "part_train_dummy_val" in raw_args and 1 > raw_args[
                "part_train_dummy_val"] > 0:
            new_val_part = int(original_train_size *
                               raw_args["part_train_dummy_val"])
            data_dict["val"].extend(data_dict["train"][:new_val_part])
            data_dict["train"] = data_dict["train"][new_val_part:]
        if "part_train_dummy_test" in raw_args and 1 > raw_args[
                "part_train_dummy_test"] > 0:
            new_test_part = int(original_train_size *
                                raw_args["part_train_dummy_test"])
            # NOTE(review): "test" is bound to the very same list object as
            # "val" here, so the extend below grows both -- confirm that this
            # aliasing is intended.
            data_dict["test"] = data_dict["val"]
            if data_dict["test"] is not None:
                data_dict["test"].extend(data_dict["train"][:new_test_part])
            else:
                data_dict["test"] = (data_dict["train"][:new_test_part])
            data_dict["train"] = data_dict["train"][new_test_part:]

        return data_dict

    def load_openml_data(tid, splits=None, config=None):
        # Fetch the dataset behind the OpenML task `tid` and split it into
        # train/val/test with the ratios given by `splits`.
        import openml
        from sklearn.model_selection import train_test_split

        task = openml.tasks.get_task(int(tid))
        did = task.dataset_id
        dataset = openml.datasets.get_dataset(did)
        data, targets, _, _ = dataset.get_data(
            dataset_format="array", target=dataset.default_target_attribute)

        train_data, test_data, train_targets, test_targets = train_test_split(
            data, targets, train_size=splits[0], random_state=config.seed)
        # The second split works on the remainder, hence the re-scaled ratio.
        val_data, test_data, val_targets, test_targets = train_test_split(
            test_data,
            test_targets,
            train_size=splits[1] / (1. - splits[0]),
            random_state=config.seed)
        data_dict = {
            'train': [(x, y) for x, y in zip(train_data, train_targets)],
            'val': [(x, y) for x, y in zip(val_data, val_targets)],
            'test': [(x, y) for x, y in zip(test_data, test_targets)]
        }
        return data_dict

    DATA_LOAD_FUNCS = {
        'torchvision': load_torchvision_data,
        'torchtext': load_torchtext_data,
        'torchaudio': load_torchaudio_data,
        'torch_geometric': load_torch_geometric_data,
        'huggingface_datasets': load_huggingface_datasets_data,
        'openml': load_openml_data
    }

    modified_config = config.clone()

    # Load dataset
    splits = modified_config.data.splits
    name, package = modified_config.data.type.split('@')

    dataset = DATA_LOAD_FUNCS[package.lower()](name, splits, modified_config)
    splitter = get_splitter(modified_config)

    data_local_dict = {
        x: {}
        for x in range(1, modified_config.federate.client_num + 1)
    }

    # Build dict of Dataloader
    for split in dataset:
        if dataset[split] is None:
            continue
        all_ds = splitter(dataset[split])
        for i, ds in enumerate(all_ds):
            # Client ids start from 1; only the training data are shuffled.
            data_local_dict[i + 1][split] = DataLoader(
                ds,
                batch_size=modified_config.data.batch_size,
                shuffle=(split == 'train'),
                num_workers=modified_config.data.num_workers)

    return data_local_dict, modified_config


def get_data(config):
    """Load the dataset selected by ``config.data.type`` and log statistics.

    Registered data functions (``register.data_dict``) are tried first; the
    first one returning a non-None result wins and its result is returned
    as-is. Otherwise the built-in loader matching ``config.data.type`` is
    used, the number of samples per client and split is collected, optionally
    plotted (``config.data.plot_boxplot``) and logged.

    Args:
        config: the configuration object.

    Returns:
        A tuple ``(data, modified_config)`` as produced by the chosen loader.

    Raises:
        ValueError: if ``config.data.type`` matches no known loader.
    """
    # fix the seed for data generation, will restore the user-specified on after the generation
    setup_seed(12345)
    for func in register.data_dict.values():
        data_and_config = func(config)
        if data_and_config is not None:
            # Restore the user-specified seed on this path as well.
            setup_seed(config.seed)
            return data_and_config

    data_type = config.data.type.lower()
    if data_type == 'toy':
        data, modified_config = load_toy_data(config)
    elif data_type in ['femnist', 'celeba']:
        from federatedscope.cv.dataloader import load_cv_dataset
        data, modified_config = load_cv_dataset(config)
    elif data_type in ['shakespeare', 'twitter', 'subreddit', 'synthetic']:
        from federatedscope.nlp.dataloader import load_nlp_dataset
        data, modified_config = load_nlp_dataset(config)
    elif data_type in [
            'cora',
            'citeseer',
            'pubmed',
            'dblp_conf',
            'dblp_org',
    ] or data_type.startswith('csbm'):
        from federatedscope.gfl.dataloader import load_nodelevel_dataset
        data, modified_config = load_nodelevel_dataset(config)
    elif data_type in ['ciao', 'epinions', 'fb15k-237', 'wn18']:
        from federatedscope.gfl.dataloader import load_linklevel_dataset
        data, modified_config = load_linklevel_dataset(config)
    elif data_type in ['hiv', 'proteins', 'imdb-binary'
                       ] or config.data.type.startswith('graph_multi_domain'):
        from federatedscope.gfl.dataloader import load_graphlevel_dataset
        data, modified_config = load_graphlevel_dataset(config)
    elif data_type == 'vertical_fl_data':
        from federatedscope.vertical_fl.dataloader import load_vertical_data
        data, modified_config = load_vertical_data(config, generate=True)
    elif 'movielens' in data_type:
        from federatedscope.mf.dataloader import load_mf_dataset
        data, modified_config = load_mf_dataset(config)
    elif '@' in data_type:
        data, modified_config = load_external_data(config)
    else:
        raise ValueError('Data {} not found.'.format(config.data.type))

    setup_seed(config.seed)

    # get the statistics about the used data
    data_num_all_client = defaultdict(list)
    logger.info(
        f"For data={config.data.type} with subsample={config.data.subsample}, the client_num is {len(data)}"
    )
    for client_id, ds_ci in data.items():
        if client_id == 0:
            # skip the data holds on server
            continue
        if isinstance(ds_ci, dict):
            for split_name, ds in ds_ci.items():
                # Best effort: the imports (or the size queries) may fail,
                # in which case only plain lists are still counted.
                try:
                    import torch
                    from federatedscope.mf.dataloader import MFDataLoader
                    if isinstance(ds, (torch.utils.data.Dataset, list)):
                        data_num_all_client[split_name].append(len(ds))

                    elif isinstance(ds, torch.utils.data.DataLoader):
                        data_num_all_client[split_name].append(len(ds.dataset))
                        if config.data.labelwise_boxplot and "cifar" in config.data.type.lower():
                            from collections import Counter
                            all_labels = [ds.dataset[i][1] for i in range(len(ds.dataset))]
                            label_wise_cnt = Counter(all_labels)
                            for label, cnt in label_wise_cnt.items():
                                data_num_all_client[label].append(cnt)

                    elif isinstance(ds, MFDataLoader):
                        data_num_all_client[split_name].append(ds.n_rating)
                except Exception:
                    if isinstance(ds, list):
                        data_num_all_client[split_name].append(len(ds))
        # Compare the lower-cased type, consistent with the dispatch above.
        if data_type in ["cora", "citeseer", "pubmed"]:
            # node-wise classification
            from torch_geometric.data.data import Data
            if isinstance(ds_ci, Data):
                for split_name in ["train_mask", "val_mask", "test_mask"]:
                    num_nodes = sum(ds_ci[split_name]).item()
                    data_num_all_client[split_name.split("_")[0]].append(
                        num_nodes)

    if config.data.plot_boxplot:
        index = []
        data_num_list = []
        for key, val in data_num_all_client.items():
            if config.data.labelwise_boxplot and key in ["train", "test", "val"]:
                continue
            index.append(key)
            data_num_list.append(val)
        # Show the splits in the order train/val/test; the length check
        # avoids an IndexError when fewer than three series were collected.
        if len(index) >= 3 and index[1] == "test" and index[2] == "val":
            index[1], index[2] = index[2], index[1]
            data_num_list[1], data_num_list[2] = data_num_list[
                2], data_num_list[1]
        import matplotlib.pyplot as plt
        import matplotlib.pylab as pylab
        plt.clf()
        label_size = 18.5
        ticks_size = 17
        title_size = 22.5
        legend_size = 17
        params = {
            'legend.fontsize': legend_size,
            'axes.labelsize': label_size,
            'axes.titlesize': title_size,
            'xtick.labelsize': ticks_size,
            'ytick.labelsize': ticks_size
        }
        if config.data.labelwise_boxplot:
            index_order = np.argsort(np.array(index))
            index = [index[i] for i in index_order]
            data_num_list = [data_num_list[i] for i in index_order]

        pylab.rcParams.update(params)
        ax = plt.subplot()
        ax.violinplot(data_num_list)
        ax.set_xticks(range(1, len(index) + 1))
        ax.set_xticklabels(index)
        ax.set_ylabel("#Samples Per Client")
        fig_name = f"{config.outdir}/visual_{config.data.type}.pdf"
        if config.data.labelwise_boxplot:
            fig_name = f"{config.outdir}/visual_{config.data.type}_label.pdf"
        plt.savefig(fig_name,
                    bbox_inches='tight',
                    pad_inches=0)
        plt.show()

    from scipy import stats
    # Position-wise sum of the per-client counts over all collected series.
    all_split_merged_num = []
    for v in data_num_all_client.values():
        if all_split_merged_num == []:
            all_split_merged_num.extend(v)
        else:
            # zip stops at the shorter list, so a series with more entries
            # than the running total no longer raises an IndexError.
            all_split_merged_num = [
                acc + num for acc, num in zip(all_split_merged_num, v)
            ]
    data_num_all_client["all"] = all_split_merged_num
    for k, v in data_num_all_client.items():
        if len(v) == 0:
            logger.warning(
                "The data distribution statistics info are nor correctly logged, maybe you used a data type we haven't support"
            )
        else:
            stats_res = stats.describe(v)
            if stats_res.minmax[1] == 0:
                logger.warning(
                    f"For data split {k}, the max sample num in the client is 0. Please check w    hether this is as you would like it to be"
                )
            logger.info(
                f"For data split {k}, the stats_res over all client is {stats_res}, the meadian is {sorted(v)[len(v) // 2]}, std is {math.sqrt(stats_res.variance)}"
            )

    return data, modified_config


def merge_data(all_data, merged_max_data_id):
    """Merge the data of the clients with ids ``1..merged_max_data_id``.

    The layout of the data is decided from ``all_data[1]["test"]``:

    * dict of arrays: for every split and every element (e.g. x, y) the
      arrays of all merged clients are concatenated with ``np.concatenate``;
    * ``torch.utils.data.DataLoader``: the loaders of client 1 are re-used
      and the datasets of the other clients are appended to them via
      ``dataset.extend`` -- i.e. client 1's loaders are modified in place;
    * ``MFDataLoader`` (that is not a torch DataLoader): not merged, ``None``
      is returned.

    Args:
        all_data: dict mapping data id -> {split name -> data}.
        merged_max_data_id: the largest client id to be merged (inclusive).

    Returns:
        The merged data, or ``None`` when the data type cannot be merged.

    Raises:
        AssertionError: if ``all_data[1]["test"]`` has an unsupported type.
    """
    dataset_names = list(all_data[1].keys())  # e.g., train, test, val
    reference = all_data[1]["test"]

    if isinstance(reference, dict):
        data_elem_names = list(reference.keys())  # e.g., x, y
        merged_data = {name: defaultdict(list) for name in dataset_names}
        # `merged_max_data_id` is itself a valid client id, hence the +1.
        for data_id in range(1, merged_max_data_id + 1):
            for d_name in dataset_names:
                for elem_name in data_elem_names:
                    merged_data[d_name][elem_name].append(
                        all_data[data_id][d_name][elem_name])

        for d_name in dataset_names:
            for elem_name in data_elem_names:
                merged_data[d_name][elem_name] = np.concatenate(
                    merged_data[d_name][elem_name])
        return merged_data

    # The loader types are only needed (and imported) for non-dict data.
    import torch.utils.data
    from federatedscope.mf.dataloader import MFDataLoader
    assert isinstance(reference, (torch.utils.data.DataLoader, MFDataLoader)), \
            "the data should be organized as the format similar to the following format" \
            "1): {data_id: {train: {x:ndarray, y:ndarray}} }" \
            "2): {data_id: {train: DataLoader }"

    if isinstance(reference, torch.utils.data.DataLoader):
        # Start from client 1's loaders and append the remaining clients.
        merged_data = {name: all_data[1][name] for name in dataset_names}
        for data_id in range(2, merged_max_data_id + 1):
            for d_name in dataset_names:
                merged_data[d_name].dataset.extend(
                    all_data[data_id][d_name].dataset)
        return merged_data

    return None
