import copy
import inspect
import logging
import os
import re
import ssl
import urllib.request

import numpy as np
import os.path as osp

from random import shuffle
from collections import defaultdict

logger = logging.getLogger(__name__)





def load_external_data(config=None):
    """
    Based on the configuration file, this function imports external \
    datasets and applies train/valid/test.

    The loader is picked from ``config.data.type``, which must have the \
    form ``<name>@<package>``. ``<package>`` is matched case-insensitively \
    against ``huggingface_datasets`` and ``openml``.

    Args:
        config: `CN` from `federatedscope/core/configs/config.py`

    Returns:
        (data, modified_config): tuple of ML split dataset, i.e., \
        ``(train, val, test)``, and `CN` from \
        `federatedscope/core/configs/config.py` (a clone of ``config``), \
        which might be modified in the function.

    Raises:
        KeyError: if ``<package>`` is not one of the supported loaders.
        ValueError: if ``config.data.type`` does not contain exactly one \
            ``@``, or if the huggingface loader is used with a model type \
            that does not end with ``transformers``.
    """
    def load_huggingface_datasets_data(name, splits=None, config=None):
        """
        Load a huggingface dataset, tokenize its ``sentence`` field and \
        build the ``train``/``val``/``test`` lists of ``(x, y)`` pairs.

        Keys read from ``config.data.args[0]``: ``max_len`` (required), \
        ``hg_cache_dir``, ``load_disk_dir``, ``used_train_ratio``, \
        ``half_val_dummy_test``, ``val_as_dummy_test``, \
        ``part_train_dummy_val`` and ``part_train_dummy_test``. \
        ``splits`` is accepted for a uniform loader signature but unused.
        """
        import datasets
        from datasets import load_from_disk

        raw_args = config.data.args[0] if config.data.args else {}
        assert 'max_len' in raw_args, "Miss key 'max_len' in " \
                                      "`config.data.args`."
        # Only forward the options that `datasets.load_dataset` declares
        filtered_args = filter_dict(datasets.load_dataset, raw_args)
        logger.info("Begin to load huggingface dataset")
        if "hg_cache_dir" in raw_args:
            hugging_face_path = raw_args["hg_cache_dir"]
        else:
            hugging_face_path = os.getcwd()

        if "load_disk_dir" in raw_args:
            load_path = raw_args["load_disk_dir"]
            try:
                dataset = load_from_disk(load_path)
            except Exception as e:
                logger.error(f"When loading cached dataset form "
                             f"{load_path}, we faced the exception: \n "
                             f"{str(e)}")
                # Without a dataset nothing below can work, so surface the
                # real cause instead of failing later on an unbound name
                raise
        else:
            dataset = datasets.load_dataset(path=config.data.root,
                                            name=name,
                                            **filtered_args)
        if config.model.type.endswith('transformers'):
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            from transformers import AutoTokenizer
            logger.info("To load huggingface tokenizer")
            tokenizer = AutoTokenizer.from_pretrained(
                config.model.type.split('@')[0],
                local_files_only=True,
                cache_dir=os.path.join(hugging_face_path, "transformers"))
        else:
            # The tokenization below has no other source for a tokenizer
            raise ValueError(
                f"Loading huggingface datasets requires a model type "
                f"ending with 'transformers', but got "
                f"`config.model.type`={config.model.type}.")

        for split in dataset:
            x_all = [i['sentence'] for i in dataset[split]]
            targets = [i['label'] for i in dataset[split]]

            if split == "train" and "used_train_ratio" in raw_args and \
                    1 > raw_args['used_train_ratio'] > 0:
                # Randomly keep a fraction of the training examples; the
                # kept examples stay in their original relative order
                selected_idx = list(range(len(dataset[split])))
                shuffle(selected_idx)
                kept_num = int(
                    len(selected_idx) * raw_args['used_train_ratio'])
                # A set keeps the membership tests below O(1)
                selected_idx = set(selected_idx[:kept_num])
                x_all = [
                    element for i, element in enumerate(x_all)
                    if i in selected_idx
                ]
                targets = [
                    element for i, element in enumerate(targets)
                    if i in selected_idx
                ]

            x_all = tokenizer(x_all,
                              return_tensors='pt',
                              padding=True,
                              truncation=True,
                              max_length=raw_args['max_len'])
            # Turn the batched tokenizer output into one dict per example
            data = [{key: value[i]
                     for key, value in x_all.items()}
                    for i in range(len(next(iter(x_all.values()))))]
            dataset[split] = (data, targets)

        data_split_dict = {
            'train': list(zip(dataset['train'][0], dataset['train'][1])),
            'val': list(
                zip(dataset['validation'][0], dataset['validation'][1])),
            # A test split whose labels are all -1 is treated as unlabeled
            'test': list(zip(dataset['test'][0], dataset['test'][1]))
            if (set(dataset['test'][1]) - {-1}) else None,
        }
        original_train_size = len(data_split_dict["train"])

        if "half_val_dummy_test" in raw_args and raw_args[
                "half_val_dummy_test"]:
            # since the "test" set from GLUE dataset may be masked, we need to
            # submit to get the ground-truth, for fast FL experiments,
            # we split the validation set into two parts with the same size as
            # new test/val data
            original_val = list(
                zip(dataset['validation'][0], dataset['validation'][1]))
            half = len(original_val) // 2
            data_split_dict["val"] = original_val[:half]
            data_split_dict["test"] = original_val[half:]
        if "val_as_dummy_test" in raw_args and raw_args["val_as_dummy_test"]:
            # use the validation set as tmp test set,
            # and partial training set as validation set
            data_split_dict["test"] = data_split_dict["val"]
            data_split_dict["val"] = []
        if "part_train_dummy_val" in raw_args and 1 > raw_args[
                "part_train_dummy_val"] > 0:
            new_val_part = int(original_train_size *
                               raw_args["part_train_dummy_val"])
            data_split_dict["val"].extend(
                data_split_dict["train"][:new_val_part])
            data_split_dict["train"] = data_split_dict["train"][new_val_part:]
        if "part_train_dummy_test" in raw_args and 1 > raw_args[
                "part_train_dummy_test"] > 0:
            new_test_part = int(original_train_size *
                                raw_args["part_train_dummy_test"])
            # NOTE(review): "test" is bound to the very same list object as
            # "val" here, so the extend below grows both splits and any
            # previous "test" content is dropped -- confirm this is intended
            data_split_dict["test"] = data_split_dict["val"]
            if data_split_dict["test"] is not None:
                data_split_dict["test"].extend(
                    data_split_dict["train"][:new_test_part])
            else:
                data_split_dict["test"] = (
                    data_split_dict["train"][:new_test_part])
            data_split_dict["train"] = data_split_dict["train"][new_test_part:]

        return data_split_dict

    def load_openml_data(tid, splits=None, config=None):
        """
        Fetch the dataset behind the OpenML task ``tid`` and split it into \
        ``train``/``val``/``test`` lists of ``(x, y)`` pairs.

        ``splits[0]`` and ``splits[1]`` are the train and val fractions of \
        the whole dataset, the remainder becomes the test split; both \
        splits are seeded with ``config.seed``.
        """
        import openml
        from sklearn.model_selection import train_test_split

        task = openml.tasks.get_task(int(tid))
        did = task.dataset_id
        dataset = openml.datasets.get_dataset(did)
        data, targets, _, _ = dataset.get_data(
            dataset_format="array", target=dataset.default_target_attribute)

        train_data, test_data, train_targets, test_targets = train_test_split(
            data, targets, train_size=splits[0], random_state=config.seed)
        # The val fraction is rescaled w.r.t. what is left after "train"
        val_data, test_data, val_targets, test_targets = train_test_split(
            test_data,
            test_targets,
            train_size=splits[1] / (1. - splits[0]),
            random_state=config.seed)
        data_split_dict = {
            'train': list(zip(train_data, train_targets)),
            'val': list(zip(val_data, val_targets)),
            'test': list(zip(test_data, test_targets))
        }
        return data_split_dict

    DATA_LOAD_FUNCS = {
        'huggingface_datasets': load_huggingface_datasets_data,
        'openml': load_openml_data
    }

    modified_config = config.clone()

    # Load dataset
    splits = modified_config.data.splits
    name, package = modified_config.data.type.split('@')

    # Comply with the original train/val/test
    dataset = DATA_LOAD_FUNCS[package.lower()](name, splits, modified_config)
    data_split_tuple = (dataset.get('train'), dataset.get('val'),
                        dataset.get('test'))

    return data_split_tuple, modified_config


def convert_data_mode(data, config):
    """
    Pick the data that matches the running mode of the FL course.

    Args:
        data: ``StandaloneDataDict``
        config: configuration of FL course, see `federate.core.configs`

    Returns:
        ``data`` itself in ``standalone`` mode, or when \
        ``config.distribute.data_idx`` is ``-1``; otherwise the entry of \
        ``data`` selected by ``config.distribute.data_idx`` (a randomly \
        sampled entry if that index is not a key of ``data``).
    """
    if config.federate.mode.lower() == 'standalone':
        return data

    requested_idx = config.distribute.data_idx
    # -1 means "do not select", so the whole dict is handed back
    if requested_idx == -1:
        return data

    if requested_idx in data.keys():
        return data[requested_idx]

    # Unknown index: fall back to a randomly drawn one and tell the user
    data_idx = np.random.choice(list(data.keys()))
    logger.warning(
        f"The provided data_idx={config.distribute.data_idx} is "
        f"invalid, so that we randomly sample a data_idx as {data_idx}"
    )
    return data[data_idx]


def get_func_args(func):
    """
    Collect the parameter names declared in the signature of a callable.

    Args:
        func: callable whose signature is inspected

    Returns:
        Set with the name of every parameter of ``func``
    """
    parameters = inspect.signature(func).parameters
    return {param.name for param in parameters.values()}


def filter_dict(func, kwarg):
    """
    Keep only the entries of ``kwarg`` whose key is a parameter name of \
    ``func``.

    Args:
        func: function whose parameter names decide which keys are kept
        kwarg: dict to filter

    Returns:
        New dict holding the entries of ``kwarg`` that ``func`` declares \
        as parameters.
    """
    accepted_names = get_func_args(func)
    filtered = {}
    for key in accepted_names.intersection(kwarg.keys()):
        filtered[key] = kwarg[key]
    return filtered


def merge_data(all_data, merged_max_data_id=None, specified_dataset_name=None):
    """
    Combine the data of clients ``1`` to ``merged_max_data_id`` found in \
    ``all_data`` into one dataset per split.

    Args:
        all_data: ``StandaloneDataDict``
        merged_max_data_id: max merged data index; when ``None``, \
            ``len(all_data) - 1`` is used
        specified_dataset_name: split name (or list of split names) to be \
            merged; when ``None``, every split of client 1 is merged

    Returns:
        Merged data, a dict keyed by split name. Each value is a \
        ``WrapDataset`` when the clients' loaders wrap a ``WrapDataset``, \
        otherwise a ``torch.utils.data.ConcatDataset``.

    Raises:
        NotImplementedError: if the inspected client split is not a \
            ``torch.utils.data.DataLoader``.
    """
    import torch.utils.data
    from .wrap_dataset import WrapDataset

    # Resolve defaults and validate the request
    if merged_max_data_id is None:
        merged_max_data_id = len(all_data) - 1
    assert merged_max_data_id >= 1

    if specified_dataset_name is None:
        dataset_names = list(all_data[1].keys())  # e.g., train, test, val
    elif isinstance(specified_dataset_name, list):
        dataset_names = specified_dataset_name
    else:
        dataset_names = [specified_dataset_name]
    assert len(dataset_names) >= 1, \
        "At least one sub-dataset is required in client 1"

    client_ids = range(1, merged_max_data_id + 1)

    # The split used to inspect the data type, "test" being preferred
    data_name = dataset_names[0]
    if "test" in dataset_names:
        data_name = "test"

    # Find the first client that owns every requested split, and report
    # each split that a client lacks
    reference_id = -1
    for client_id in client_ids:
        has_every_split = True
        for split_name in dataset_names:
            if split_name in all_data[client_id]:
                continue
            has_every_split = False
            logger.warning(f'Client {client_id} does not contain '
                           f'dataset key {split_name}.')
        if has_every_split and reference_id == -1:
            reference_id = client_id
    assert reference_id != -1, \
        "At least one client within [1, merged_max_data_id] should contain " \
        "all the key for expected dataset names."

    reference = all_data[reference_id][data_name]
    if not issubclass(type(reference), torch.utils.data.DataLoader):
        raise NotImplementedError(
            "Un-supported type when merging data across different clients."
            f"Your data type is "
            f"{type(reference)}. "
            f"Currently we only support the following forms: "
            " 1): {data_id: {train: {x:ndarray, y:ndarray}} }"
            " 2): {data_id: {train: DataLoader }")

    if isinstance(reference.dataset, WrapDataset):
        # e.g., x, y
        elem_names = list(reference.dataset.dataset.keys())
        merged_data = {name: defaultdict(list) for name in dataset_names}
        # Gather the per-client arrays of each element, in client order
        for data_id in client_ids:
            for d_name in dataset_names:
                if d_name in all_data[data_id]:
                    for elem_name in elem_names:
                        merged_data[d_name][elem_name].append(
                            all_data[data_id]
                            [d_name].dataset.dataset[elem_name])
        # Concatenate the gathered arrays and wrap them again
        for d_name in dataset_names:
            gathered = merged_data[d_name]
            for elem_name in elem_names:
                gathered[elem_name] = np.concatenate(gathered[elem_name])
            merged_data[d_name] = WrapDataset(gathered)
        return merged_data

    # Generic datasets are chained lazily rather than concatenated
    per_split = {key: [] for key in all_data[reference_id].keys()}
    for data_id in client_ids:
        for d_name in dataset_names:
            if d_name in all_data[data_id]:
                per_split[d_name].append(all_data[data_id][d_name].dataset)
    return {
        key: torch.utils.data.ConcatDataset(per_split[key])
        for key in dataset_names
    }


def save_local_data(dir_path,
                    train_data=None,
                    train_targets=None,
                    test_data=None,
                    test_targets=None,
                    val_data=None,
                    val_targets=None):
    r"""
    Write the given splits to disk with ``torch.save``. Source: \
    https://github.com/omarfoq/FedEM/blob/main/data/femnist/generate_data.py

    Args:
        dir_path: directory in which the files are written
        train_data: x of train data
        train_targets: y of train data
        test_data: x of test data
        test_targets: y of test data
        val_data: x of validation data
        val_targets: y of validation data

    Note:
        A split is written only when both its data and its targets are \
        not ``None``: ``(`train_data`, `train_targets`)`` goes to \
        ``{dir_path}/train.pt``, ``(`test_data`, `test_targets`)`` to \
        ``{dir_path}/test.pt`` and ``(`val_data`, `val_targets`)`` to \
        ``{dir_path}/val.pt``.
    """
    import torch

    # (file name, x, y) of every split, in the order they are written
    candidates = (
        ("train.pt", train_data, train_targets),
        ("test.pt", test_data, test_targets),
        ("val.pt", val_data, val_targets),
    )
    for file_name, xs, ys in candidates:
        if xs is None or ys is None:
            continue
        torch.save((xs, ys), osp.join(dir_path, file_name))


def download_url(url: str, folder='folder'):
    """
    Downloads the content of an url to a folder. Modified from \
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    The file name is the last path segment of ``url`` without its query \
    string. If a file with that name already exists in ``folder``, it is \
    reused and nothing is downloaded.

    Args:
        url (string): The url of target file.
        folder (string): The target folder, created if missing.

    Returns:
        string: File path of downloaded files.

    Raises:
        ValueError: If no file name can be derived from ``url``.
    """
    import shutil

    file = url.rpartition('/')[2]
    # Drop the query string, unless the segment consists of nothing else
    file = file if file.startswith('?') else file.split('?')[0]
    if not file:
        raise ValueError(f'Cannot infer a file name from url {url}')
    path = osp.join(folder, file)
    if osp.exists(path):
        logger.info(f'File {file} exists, use existing file.')
        return path

    logger.info(f'Downloading {url}')
    os.makedirs(folder, exist_ok=True)
    # NOTE(security): certificate verification is deliberately disabled
    # here, so the downloaded content is not authenticated
    ctx = ssl._create_unverified_context()
    # Download next to the target and move into place only on success, so
    # that an interrupted transfer is never mistaken for a cached file
    tmp_path = f'{path}.part'
    try:
        with urllib.request.urlopen(url, context=ctx) as response, \
                open(tmp_path, 'wb') as f:
            # Stream in chunks instead of holding the payload in memory
            shutil.copyfileobj(response, f)
        os.replace(tmp_path, path)
    finally:
        # Only present when the transfer did not complete
        if osp.exists(tmp_path):
            os.remove(tmp_path)

    return path
