from datasets import concatenate_datasets, load_dataset
from datasets import Dataset, DatasetDict
from typing import List, Union, Optional, Tuple, Dict
from openicl.icl_retriever import BaseRetriever
import numpy as np
from tqdm import trange
from queue import Queue


def collect_samples(
    idx_dict: Dict[int, List[List[int]]], retrievers: List[BaseRetriever]
) -> List[List[Dataset]]:
    """Collect selected samples from each client's local dataset using ``idx_dict``.

    Client ids are used both as keys of ``idx_dict`` and as positions in
    ``retrievers``, so they are expected to be ``0 .. num_clients - 1``. The number
    of queries is taken from client ``0``; every other client must provide at least
    that many index lists.

    Args:
        idx_dict (Dict[int, List[List[int]]]): Maps each client id to one list of selected sample indices per query:
            ``{cid1: [[idx1, idx2, ...], [idx3, idx4, ...], ...], cid2: [[idx5, idx6, ...], [idx7, idx8, ...], ...], ...}``
        retrievers (List[BaseRetriever]): List of data retrievers. The ``index_ds`` attribute of ``retrievers[cid]`` is the
            dataset the indices of client ``cid`` are selected from.

    Returns:
        List[List[Dataset]]: One inner list per query, holding one selected ``Dataset`` per client. An empty list is
        returned when ``idx_dict`` is empty.

        .. code-block:: python

            [[q1_client1_ice_Dataset, q1_client2_ice_Dataset, q1_client3_ice_Dataset, ...],  # for query 1
             [q2_client1_ice_Dataset, q2_client2_ice_Dataset, q2_client3_ice_Dataset, ...],  # for query 2
             [q3_client1_ice_Dataset, q3_client2_ice_Dataset, q3_client3_ice_Dataset, ...],  # for query 3
             ...]
    """
    # Without any client there is nothing to collect; bail out before
    # ``idx_dict[0]`` would raise a KeyError.
    if not idx_dict:
        return []

    num_clients = len(idx_dict)
    query_num = len(idx_dict[0])
    # The local datasets do not change between queries, so look them up once.
    client_datasets = [retrievers[cid].index_ds for cid in range(num_clients)]

    samples_for_all_query = []
    for query_idx in trange(
        query_num, disable=False, desc="Collect ICE samples from each test query: "
    ):
        samples_for_all_query.append(
            [
                client_datasets[cid].select(idx_dict[cid][query_idx])
                for cid in range(num_clients)
            ]
        )
    return samples_for_all_query


def merge_concatenate(
    subsets: List[Dataset],
    chunk_sample_num: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dataset:
    """Merge multiple subsets into one ``Dataset`` by interleaving them round-robin.

    The first sample of every subset comes first (in subset order), then the second
    sample of every subset, and so on; subsets that run out of samples are skipped.
    For example, if ``subsets`` is
    ``[Dataset(sample_1, sample_2), Dataset(sample_3, sample_4), Dataset(sample_5, sample_6, sample_7)]``,
    the result is
    ``Dataset(sample_1, sample_3, sample_5, sample_2, sample_4, sample_6, sample_7)``.

    The interleaving is computed on integer row indices and applied with a single
    ``select`` on the concatenated subsets, so no per-row ``Dataset`` objects are
    built and the result keeps the features of the input subsets.

    Args:
        subsets (List[Dataset]): List of ``Dataset``, with each represents the selected samples from single client.
        chunk_sample_num (Optional[int], optional): If given and smaller than the total number of samples, only the
            first ``chunk_sample_num`` samples of the interleaved order are kept. Defaults to None (keep everything).
        seed (Optional[int], optional): Unused by this function. Defaults to None.

    Returns:
        Dataset: The interleaved (and possibly truncated) dataset. It is empty when every subset is empty.
    """
    # Concatenate first: row ``i`` of subset ``k`` then lives at ``offsets[k] + i``.
    # An empty ``subsets`` list is rejected here by ``concatenate_datasets``.
    dataset = concatenate_datasets(subsets)

    sizes = [len(subset) for subset in subsets]
    offsets = []
    running_total = 0
    for size in sizes:
        offsets.append(running_total)
        running_total += size

    # Round ``r`` takes the r-th sample of every subset that still has one.
    order = [
        offsets[cid] + round_idx
        for round_idx in range(max(sizes, default=0))
        for cid, size in enumerate(sizes)
        if round_idx < size
    ]

    if chunk_sample_num is not None and chunk_sample_num < len(order):
        # Clamp at zero so a negative value yields an empty selection rather
        # than slicing from the end of the list.
        order = order[: max(chunk_sample_num, 0)]

    return dataset.select(order)


def simple_concatenate(
    subsets: List[Dataset],
    chunk_sample_num: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dataset:
    """Join the subsets back to back with ``datasets.concatenate_datasets``.

    The samples keep the order of ``subsets`` and, inside each subset, their
    original order. No truncation or shuffling is applied.

    Args:
        subsets (List[Dataset]): List of ``Dataset``, with each represents the selected samples from single client.
        chunk_sample_num (Optional[int], optional): Unused by this function. Defaults to None.
        seed (Optional[int], optional): Unused by this function. Defaults to None.

    Returns:
        Dataset: The concatenation of all ``subsets``.
    """
    return concatenate_datasets(subsets)


def random_concatenate(
    subsets: List[Dataset], chunk_sample_num: Optional[int] = None, seed: int = 1
) -> Dataset:
    """Merge multiple subsets into one ``Dataset`` in random order.

    The subsets are concatenated with ``datasets.concatenate_datasets``, shuffled
    using ``seed``, and optionally truncated. The same ``seed`` and inputs
    therefore give the same result on every call.

    Args:
        subsets (List[Dataset]): List of ``Dataset``, with each represents the selected samples from single client.
        chunk_sample_num (Optional[int], optional): If given and smaller than the total number of samples, only
            ``chunk_sample_num`` randomly chosen samples are kept. Defaults to None (keep everything).
        seed (int, optional): Seed passed to ``Dataset.shuffle``. Defaults to 1.

    Returns:
        Dataset: The shuffled (and possibly truncated) dataset.
    """
    dataset = concatenate_datasets(subsets)
    # Seed the shuffle so that ``seed`` actually controls the outcome.
    shuffled_dataset = dataset.shuffle(seed=seed)
    if chunk_sample_num is not None and chunk_sample_num < len(dataset):
        # The rows are already in random order, so the leading rows are a
        # uniform random sample; no second random draw is needed.
        shuffled_dataset = shuffled_dataset.select(list(range(chunk_sample_num)))
    return shuffled_dataset
