import logging
import numpy as np
import random

from .splitter_builder import get_splitter
from .base_data import ClientData, StandaloneDataDict

# Module-level logger; the translator uses it to emit warnings while splitting
# and subsampling datasets.
logger = logging.getLogger(__name__)


class BaseDataTranslator:
    """
    Translator is a tool to convert a centralized dataset to \
    ``StandaloneDataDict``, which is the input data of runner.

    Notes:
        The ``Translator`` consists of several stages:

        Dataset -> ML split (``split_train_val_test()``) -> \
        FL split (``split_to_client()``) -> ``StandaloneDataDict``

        Optional per-client subsampling is controlled by
        ``federated.train_client_size``, ``federated.eval_client_size`` and
        ``federated.random_seed`` in the global config.
    """
    def __init__(self, global_cfg, client_cfgs=None):
        """
        Store the configs and build the FL splitter.

        Args:
            global_cfg: global CfgNode
            client_cfgs: client cfg `Dict` keyed by ``'client_<id>'``; each
                entry is merged over a clone of ``global_cfg``.
        """
        self.global_cfg = global_cfg
        self.client_cfgs = client_cfgs
        self.splitter = get_splitter(global_cfg)

    def __call__(self, dataset):
        """
        Translate ``dataset`` into per-client data.

        Args:
            dataset: `torch.utils.data.Dataset`, `List` of (feature, label),
                or a `tuple` of already split datasets ``(train, val, test)``.

        Returns:
            datadict: `dict` of ``ClientData`` keyed by client index (``0``
            is the server), as produced by ``split()``. It is not wrapped in
            ``StandaloneDataDict`` here.
        """
        return self.split(dataset)

    def split(self, dataset):
        """
        Perform ML split and FL split.

        Args:
            dataset: see ``__call__``.

        Returns:
            dict of ``ClientData`` with client_idx as key to build \
            ``StandaloneDataDict``
        """
        train, val, test = self.split_train_val_test(dataset)
        return self.split_to_client(train, val, test)

    def split_train_val_test(self, dataset, cfg=None):
        """
        Split dataset to train, val, test if not provided.

        Args:
            dataset: `torch.utils.data.Dataset`, `List` of samples, or a
                `tuple` ``(train, val, test)`` that is returned unchanged.
            cfg: optional CfgNode whose ``data.splits`` is used instead of
                the one in ``self.global_cfg``.

        Returns:
             The three splits ``train, val, test`` (a ``list`` for tuple
             input, otherwise a ``tuple``). Splits are ``Subset`` objects for
             a ``Dataset`` input and lists of samples otherwise. ``train``
             and ``val`` take ``int(ratio * len(dataset))`` samples of a
             random permutation; ``test`` receives the remainder.
        """
        from torch.utils.data import Dataset, Subset

        if isinstance(dataset, tuple):
            # No need to split train/val/test for tuple dataset.
            error_msg = 'If dataset is tuple, it must contains ' \
                        'train, valid and test split.'
            assert len(dataset) == len(['train', 'val', 'test']), error_msg
            return [dataset[0], dataset[1], dataset[2]]

        splits = (cfg if cfg is not None else self.global_cfg).data.splits
        num_samples = len(dataset)
        index = np.random.permutation(np.arange(num_samples))
        train_end = int(splits[0] * num_samples)
        val_end = train_end + int(splits[1] * num_samples)
        train_idx = index[:train_end]
        val_idx = index[train_end:val_end]
        test_idx = index[val_end:]

        if isinstance(dataset, Dataset):
            return (Subset(dataset, train_idx), Subset(dataset, val_idx),
                    Subset(dataset, test_idx))
        return ([dataset[i] for i in train_idx],
                [dataset[i] for i in val_idx],
                [dataset[i] for i in test_idx])

    def split_to_client(self, train, val, test):
        """
        Split dataset to clients and build ``ClientData``.

        If ``federated.train_client_size`` / ``federated.eval_client_size``
        are set, each client's train / val+test split is randomly subsampled
        to that size and the server's split is rebuilt by merging the reduced
        client splits.

        Args:
            train: train split from ``split_train_val_test``.
            val: val split from ``split_train_val_test``.
            test: test split from ``split_train_val_test``.

        Returns:
            dict: dict of ``ClientData`` with ``client_idx`` as key; key ``0``
            holds the server's data.
        """
        fed_cfg = self.global_cfg.federated
        client_num = fed_cfg.num_clients
        train_client_size = fed_cfg.get('train_client_size', None)
        eval_client_size = fed_cfg.get('eval_client_size', None)
        random_seed = fed_cfg.get('random_seed', None)

        # Independent placeholder lists (``[[None] * n] * 3`` would alias a
        # single list under all three names).
        split_train = [None] * client_num
        split_val = [None] * client_num
        split_test = [None] * client_num
        train_label_distribution = None

        if len(train) > 0:
            split_train = self.splitter(train)
            if self.global_cfg.data.consistent_label_distribution:
                try:
                    # Assumes samples are (feature, label) pairs.
                    train_label_distribution = [[j[1] for j in x]
                                                for x in split_train]
                except Exception:
                    # Best effort: fall back to splitting val/test without
                    # a label prior.
                    logger.warning(
                        'Cannot access train label distribution for '
                        'splitter, split dataset without considering '
                        'train label.')
        if len(val) > 0:
            split_val = self.splitter(val, prior=train_label_distribution)
        if len(test) > 0:
            split_test = self.splitter(test, prior=train_label_distribution)

        # Subsample each client's train split, then rebuild the server's
        # train split from the reduced client splits. Skipped when there was
        # no train data (the client entries are ``None`` placeholders).
        if train_client_size is not None and len(train) > 0:
            split_train = self._reduce_client_splits(
                split_train, train_client_size, 'train', random_seed)
            train = self._merge_datasets(split_train)

        if eval_client_size is not None:
            if len(val) > 0:
                split_val = self._reduce_client_splits(
                    split_val, eval_client_size, 'val', random_seed)
                val = self._merge_datasets(split_val)
            if len(test) > 0:
                split_test = self._reduce_client_splits(
                    split_test, eval_client_size, 'test', random_seed)
                test = self._merge_datasets(split_test)

        # Build data dict with `ClientData`, key `0` for server.
        data_dict = {
            0: ClientData(self.global_cfg, train=train, val=val, test=test)
        }
        for client_id in range(1, client_num + 1):
            if self.client_cfgs is not None:
                client_cfg = self.global_cfg.clone()
                client_cfg.merge_from_other_cfg(
                    self.client_cfgs.get(f'client_{client_id}'))
            else:
                client_cfg = self.global_cfg
            data_dict[client_id] = ClientData(
                client_cfg,
                train=split_train[client_id - 1],
                val=split_val[client_id - 1],
                test=split_test[client_id - 1])
        return data_dict

    def _reduce_client_splits(self, splits, target_size, split_name,
                              random_seed=None):
        """
        Apply ``_reduce_dataset`` to every per-client split.

        Args:
            splits: list of per-client datasets (entries may be ``None``).
            target_size: desired size of each client's dataset.
            split_name: prefix used in log messages, e.g. ``'train'``.
            random_seed: seed forwarded to ``_reduce_dataset`` (optional).

        Returns:
            A new list with each split reduced; ``None`` entries are kept.
        """
        return [
            self._reduce_dataset(x, target_size, f'{split_name} {i}',
                                 random_seed)
            for i, x in enumerate(splits)
        ]

    def _reduce_dataset(self, dataset, target_size, dataset_name,
                        random_seed=None):
        """
        Reduce the size of a dataset to the target size using a random seed
        for reproducibility.

        A given seed always draws the same sample. The global ``random``
        module is used only when ``random_seed`` is ``None``, and it is never
        reseeded.

        Args:
            dataset: The dataset to reduce (``Dataset``, sequence, or
                ``None`` for a client without data).
            target_size: The desired size of the dataset.
            dataset_name: Name of the dataset for logging purposes.
            random_seed: Random seed for reproducibility (optional).

        Returns:
            A reduced dataset: a ``Subset`` for ``Dataset`` input, a ``list``
            for sequence input. ``dataset`` is returned unchanged when it is
            ``None`` or already no larger than ``target_size``.
        """
        from torch.utils.data import Dataset, Subset

        if dataset is None:
            return dataset

        if len(dataset) <= target_size:
            logger.warning(f"{dataset_name} size already smaller than the size required.")
            return dataset

        # A private generator seeded like ``random.seed(seed)`` yields the
        # same draws without clobbering process-wide RNG state.
        rng = random.Random(random_seed) if random_seed is not None \
            else random

        if isinstance(dataset, Dataset):
            indices = rng.sample(range(len(dataset)), target_size)
            return Subset(dataset, indices)
        return rng.sample(dataset, target_size)

    def _merge_datasets(self, datasets):
        """
        Merge multiple datasets into a single dataset.

        Args:
            datasets: List of datasets to merge; ``None`` entries are
                ignored.

        Returns:
            A ``ConcatDataset`` if the first non-``None`` entry is a
            ``Dataset``, otherwise a flat ``list`` of all items. Returns
            ``[]`` when there is nothing to merge.
        """
        from torch.utils.data import Dataset, ConcatDataset

        parts = [d for d in datasets if d is not None]
        if not parts:
            return []

        if isinstance(parts[0], Dataset):
            return ConcatDataset(parts)
        return [item for part in parts for item in part]
