from datasets import load_dataset, DatasetDict, Dataset
import argparse
import os
import json
from collections import Counter
import itertools
import numpy as np
from nltk.tokenize import word_tokenize
from typing import Dict, Optional, Union, Hashable
import math
import random

from util.verification import data_verification

# Number of target classes for every supported dataset key. The "gen-*" keys
# (presumably generated variants) carry the same class count as the dataset
# whose name they extend.
data2numcls = {
    "sst2": 2,
    "gen-sst2": 2,
    "sst5": 5,
    "gen-sst5": 5,
    "amazon": 5,
    "yelp": 5,
    "gen-yelp": 5,
    "mr": 2,
    "cr": 2,
    "agnews": 4,
    "trec": 6,
    "dbpedia": 14,
    "yahoo": 10,
    "mnli": 3,
    "snli": 3,
    "rte": 2,
    "subj": 2,
    "gen-subj": 2,
    "cola": 2,
}

# Hugging Face Hub repository id for each loadable dataset key. Every repo
# follows the "KaiLv/UDR_<suffix>" naming scheme; insertion order is kept
# because callers iterate over this mapping.
data2usename = {
    data_key: f"KaiLv/UDR_{hub_suffix}"
    for data_key, hub_suffix in (
        ("sst2", "SST-2"),
        ("sst5", "SST-5"),
        ("yelp", "Yelp"),
        ("amazon", "Amazon"),
        ("mr", "MR"),
        ("cr", "CR"),
        ("agnews", "AGNews"),
        ("trec", "TREC"),
        ("dbpedia", "DBPedia"),
        ("yahoo", "Yahoo"),
        ("mnli", "MNLI"),
        ("snli", "SNLI"),
        ("rte", "RTE"),
        ("subj", "Subj"),
        ("cola", "COLA"),
    )
}


def _check_label_space(label_cnt, label_space):
    # for label in label_cnt:
    #     assert label in label_space
    for label in label_space:
        assert label in label_cnt


def _check_dataset_type(dataset, split):
    """Return the split of ``dataset`` that is going to be partitioned.

    Args:
        dataset: a ``datasets.Dataset`` (returned unchanged) or a
            ``datasets.DatasetDict`` (``dataset[split]`` is returned).
        split: split name used when ``dataset`` is a ``DatasetDict``.

    Raises:
        ValueError: if ``dataset`` is neither of those two types.
    """
    if isinstance(dataset, Dataset):
        return dataset
    if isinstance(dataset, DatasetDict):
        return dataset[split]
    raise ValueError(
        f"dataset should be either datasets.Dataset or datasets.DatasetDict, rather than {type(dataset)}."
    )


def _check_remain_sample(sample_num_cumsum_per_class, label_cnt, label_space):
    resi_label_cnt = {}
    for label in label_space:
        resi_label_cnt[label] = (
            label_cnt[label] - sample_num_cumsum_per_class[label][-1]
        )
    print(f"Remained sample number for each class: {resi_label_cnt}")


def assign_label_to_client(num_classes, num_clients, major_classes_num):
    """Assign ``major_classes_num`` labels to every client (non-IID setup).

    The labels ``0 .. num_classes - 1`` are shuffled once with the global
    ``random`` module. The shuffled order is repeated cyclically until there
    is one entry per shard (``major_classes_num * num_clients`` shards).
    Consecutive runs of ``major_classes_num`` entries are then handed to
    clients ``0, 1, ...`` in order. Because every window of that length over
    the cyclic sequence is duplicate-free when
    ``major_classes_num <= num_classes``, a client never gets a label twice.

    Args:
        num_classes: number of distinct labels.
        num_clients: number of clients.
        major_classes_num: number of labels (shards) per client.

    Returns:
        ``(per_client_label_space, label_space_cnt)``:
        ``per_client_label_space[cid]`` is the list of labels of client
        ``cid``; ``label_space_cnt`` is a ``Counter`` mapping each label to
        the number of shards (clients) that hold it.
    """
    total_shards_num = int(major_classes_num * num_clients)
    tmp_label_space = list(range(num_classes))
    random.shuffle(tmp_label_space)

    # Whole cycles of the shuffled labels, then a partial cycle for the rest.
    full_cycles, remainder = divmod(total_shards_num, num_classes)
    candidate_label_seq = tmp_label_space * full_cycles + tmp_label_space[:remainder]
    label_space_cnt = Counter(candidate_label_seq)

    per_client_label_space = [
        candidate_label_seq[cid * major_classes_num : (cid + 1) * major_classes_num]
        for cid in range(num_clients)
    ]
    return per_client_label_space, label_space_cnt


def get_per_class_samples(dataset, label_space):
    """Group the rows of ``dataset`` by label.

    Args:
        dataset: a dataset with a ``"label"`` column that supports
            ``select`` (e.g. ``datasets.Dataset``).
        label_space: labels to extract. A label with no rows maps to an empty
            selection. Rows whose label is outside ``label_space`` are ignored.

    Returns:
        Dict mapping each label of ``label_space`` to the sub-dataset holding
        exactly the rows whose ``"label"`` equals it, in their original order.
    """
    # Scan the label column once and bucket row positions, instead of one
    # full filter pass (and one closure over the loop variable) per label.
    row_idxs_per_class = {label: [] for label in label_space}
    for row_idx, label in enumerate(dataset["label"]):
        bucket = row_idxs_per_class.get(label)
        if bucket is not None:
            bucket.append(row_idx)

    return {
        label: dataset.select(row_idxs)
        for label, row_idxs in row_idxs_per_class.items()
    }


def cumsum_for_each(sample_num_per_class, label_space=None):
    """Turn per-client sample counts into cumulative split points, per label.

    Args:
        sample_num_per_class: mapping ``label -> [count for client 0, 1, ...]``.
        label_space: labels to process; defaults to every key of
            ``sample_num_per_class``.

    Returns:
        Dict mapping each label to an integer ``np.ndarray`` with the running
        sum of its per-client counts.
    """
    labels = sample_num_per_class.keys() if label_space is None else label_space
    return {
        label: np.cumsum(sample_num_per_class[label]).astype(int)
        for label in labels
    }


def split_indices(
    samples_per_class,
    sample_num_cumsum_per_class,
    num_clients=None,
    label_space=None,
    index_key="idx",
):
    """Distribute the sample ids of each class over the clients.

    For every label, the ids in ``samples_per_class[label][index_key]`` are
    shuffled with the global ``random`` module. They are then cut at the
    cumulative split points ``sample_num_cumsum_per_class[label]``, and chunk
    ``cid`` goes to client ``cid``. Ids past the last split point are left
    unassigned.

    Args:
        samples_per_class: mapping ``label -> per-class data``. Indexing it
            with ``index_key`` must yield the sample ids.
        sample_num_cumsum_per_class: mapping ``label -> cumulative per-client
            counts`` (as produced by ``cumsum_for_each``).
        num_clients: number of clients; defaults to the length of the first
            cumulative-count entry.
        label_space: labels to distribute; defaults to the keys of
            ``samples_per_class``.
        index_key: column holding the sample ids.

    Returns:
        Dict mapping client id to the list of sample ids assigned to it.
    """
    if label_space is None:
        label_space = list(samples_per_class.keys())
    if num_clients is None:
        num_clients = len(list(sample_num_cumsum_per_class.values())[0])

    indices_client_dict = {cid: [] for cid in range(num_clients)}
    for label in label_space:
        # Shuffle a private list copy. The column may be a read-only sequence
        # (e.g. a tuple, or a lazy column object in newer `datasets`
        # releases) that random.shuffle cannot permute in place, and the copy
        # keeps the caller's data untouched.
        indices = list(samples_per_class[label][index_key])
        random.shuffle(indices)
        indices_partition = np.split(indices, sample_num_cumsum_per_class[label])
        for cid in range(num_clients):
            indices_client_dict[cid].extend(indices_partition[cid].tolist())

    return indices_client_dict


def cls_iid_partition(
    dataset: Union[Dataset, DatasetDict],
    split: Optional[str] = "train",
    data_name: Optional[str] = "sst2",
    num_clients: Optional[int] = 5,
    test_split: Optional[str] = "test",
    subset_num: Optional[int] = None,
    seed: Optional[int] = 0,
    verbose: Optional[bool] = True,
):
    """Partition one split of a classification dataset IID across clients.

    Every client receives ``count_c // num_clients`` samples of each class
    ``c``; the per-class remainder is reported and left unassigned. With
    ``num_clients == 1`` the whole split goes to the single client.

    Args:
        dataset: a ``Dataset``, or a ``DatasetDict`` whose ``split`` is
            partitioned.
        split: name of the split to partition.
        data_name: key into ``data2numcls`` giving the number of classes.
        num_clients: number of clients.
        test_split: name of the split that is optionally sub-sampled.
        subset_num: if given, size of the random subset kept from
            ``test_split``. It is drawn by ``datasplit_subset`` with that
            function's own default seed, not with ``seed``.
        seed: seed applied to both the ``numpy`` and ``random`` global RNGs.
        verbose: print a per-client label report; also forwarded to
            ``datasplit_subset``.

    Returns:
        ``DatasetDict`` with keys ``"{split}-client{cid}"``. When ``dataset``
        is a ``DatasetDict``, it also holds every other split with samples
        whose label is outside ``range(num_classes)`` removed.

    Raises:
        KeyError: if ``data_name`` is not in ``data2numcls``.
        AssertionError: if some class of the label space has no sample.
    """
    np.random.seed(seed)
    random.seed(seed)

    data_to_split = _check_dataset_type(dataset, split)
    num_samples = len(data_to_split)
    # Row positions let the per-class selections below be mapped back onto
    # rows of data_to_split via select().
    order_idx = list(range(num_samples))
    data_to_split = data_to_split.add_column("order_idx", order_idx)

    label_cnt = Counter(data_to_split["label"])
    print(f"Overall {split} sample number per class: {label_cnt}")

    num_classes = data2numcls[data_name]
    label_space = list(range(num_classes))
    _check_label_space(label_cnt, label_space)
    print(f"label_space: {label_space}")

    if num_clients == 1:
        data_dict = {
            f"{split}-client0": data_to_split.select(
                list(range(num_samples))
            ).remove_columns("order_idx")
        }
    else:
        # IID: every client gets the same number of samples of each class.
        sample_num_per_class = {
            label: [label_cnt[label] // num_clients] * num_clients
            for label in label_space
        }
        sample_num_cumsum_per_class = cumsum_for_each(
            sample_num_per_class, label_space
        )  # {label_1: [...], label_2: [...], ...}

        # report the per-class samples that are left unassigned
        _check_remain_sample(sample_num_cumsum_per_class, label_cnt, label_space)

        # group the rows of each class
        samples_per_class = get_per_class_samples(data_to_split, label_space)

        # shuffle each class and cut it at the cumulative per-client counts
        indices_client_dict = split_indices(
            samples_per_class,
            sample_num_cumsum_per_class,
            num_clients,
            label_space,
            index_key="order_idx",
        )

        # build each client's dataset from its assigned row positions
        data_dict = {
            f"{split}-client{cid}": data_to_split.select(
                indices_client_dict[cid]
            ).remove_columns("order_idx")
            for cid in range(num_clients)
        }

    if verbose:
        print(f"Parition report: ")
        for cid in range(num_clients):
            client_data = data_dict[f"{split}-client{cid}"]
            per_client_cnt = Counter(client_data["label"])
            total_per_client = len(client_data)
            print(f"Client {cid}: {per_client_cnt} \t Overall: {total_per_client}")

    if isinstance(dataset, DatasetDict):
        for other_split in dataset.keys():
            if other_split != split:
                data_dict[other_split] = datasplit_verification(
                    dataset[other_split], num_classes
                )

                if other_split == test_split:
                    data_dict[other_split] = datasplit_subset(
                        data_dict[other_split],
                        subset_num=subset_num,
                        split=test_split,
                        verbose=verbose,
                    )

    fed_dataset = DatasetDict(data_dict)
    return fed_dataset


def cls_noniid_partition(
    dataset: Union[Dataset, DatasetDict],
    split: Optional[str] = "train",
    data_name: Optional[str] = "sst2",
    num_clients: Optional[int] = 5,
    major_classes_num: Optional[int] = 1,
    test_split: Optional[str] = "test",
    subset_num: Optional[int] = None,
    seed=0,
    verbose=True,
):
    """Partition one split so that each client holds only a few classes.

    ``assign_label_to_client`` gives every client ``major_classes_num``
    labels. Each class is then cut into as many equal shards as there are
    clients holding it, ``floor(count / holders)`` samples per shard. The
    per-class remainder is reported and left unassigned. With
    ``num_clients == 1`` the whole split goes to the single client.

    Args:
        dataset: a ``Dataset``, or a ``DatasetDict`` whose ``split`` is
            partitioned.
        split: name of the split to partition.
        data_name: key into ``data2numcls`` giving the number of classes.
        num_clients: number of clients.
        major_classes_num: number of classes assigned to each client.
        test_split: name of the split that is optionally sub-sampled.
        subset_num: if given, size of the random subset kept from
            ``test_split`` (drawn by ``datasplit_subset`` with its own
            default seed).
        seed: seed applied to both the ``numpy`` and ``random`` global RNGs.
        verbose: print a per-client label report; also forwarded to
            ``datasplit_subset``.

    Returns:
        ``DatasetDict`` with keys ``"{split}-client{cid}"``. When ``dataset``
        is a ``DatasetDict``, it also holds every other split with samples
        whose label is outside ``range(num_classes)`` removed.

    Raises:
        AssertionError: if ``major_classes_num > num_classes``, if the
            clients' shards cannot cover every class, or if some class has no
            sample.
    """
    np.random.seed(seed)
    random.seed(seed)

    num_classes = data2numcls[data_name]
    # A client cannot hold more classes than exist, and all clients together
    # must cover every class at least once.
    assert num_classes >= major_classes_num
    assert major_classes_num * num_clients >= num_classes

    source = _check_dataset_type(dataset, split)
    num_samples = len(source)
    # Row positions used to map per-class selections back onto `source`.
    source = source.add_column("order_idx", list(range(num_samples)))
    label_cnt = Counter(source["label"])
    print(f"Overall {split} sample number per class: {label_cnt}")

    label_space = list(range(num_classes))
    print(f"label_cnt: {label_cnt};\t label_space: {label_space}")
    _check_label_space(label_cnt, label_space)
    print(f"label_space: {label_space}")

    if num_clients == 1:
        whole_split = source.select(list(range(num_samples)))
        data_dict = {f"{split}-client0": whole_split.remove_columns("order_idx")}
    else:
        per_client_label_space, label_space_cnt = assign_label_to_client(
            num_classes, num_clients, major_classes_num
        )

        # Clients holding a class share it equally; the others get none of it.
        sample_num_per_class = {}
        for label in label_space:
            shard_size = int(math.floor(label_cnt[label] / label_space_cnt[label]))
            sample_num_per_class[label] = [
                shard_size if label in per_client_label_space[cid] else 0
                for cid in range(num_clients)
            ]

        sample_num_cumsum_per_class = cumsum_for_each(sample_num_per_class, label_space)
        _check_remain_sample(sample_num_cumsum_per_class, label_cnt, label_space)

        samples_per_class = get_per_class_samples(source, label_space)
        indices_client_dict = split_indices(
            samples_per_class,
            sample_num_cumsum_per_class,
            num_clients,
            label_space,
            index_key="order_idx",
        )

        data_dict = {}
        for cid in range(num_clients):
            client_rows = source.select(indices_client_dict[cid])
            data_dict[f"{split}-client{cid}"] = client_rows.remove_columns("order_idx")

    if verbose:
        print(f"Parition report: ")
        for cid in range(num_clients):
            client_labels = data_dict[f"{split}-client{cid}"]["label"]
            print(f"Client {cid}: {Counter(client_labels)} \t Overall: {len(client_labels)}")

    if isinstance(dataset, DatasetDict):
        for other_split in dataset.keys():
            if other_split == split:
                continue
            checked = datasplit_verification(dataset[other_split], num_classes)
            if other_split == test_split:
                checked = datasplit_subset(
                    checked,
                    subset_num=subset_num,
                    split=test_split,
                    verbose=verbose,
                )
            data_dict[other_split] = checked

    return DatasetDict(data_dict)


def datasplit_verification(datasplit: Dataset, num_classes=-1):
    """Drop samples whose label is not one of ``0 .. num_classes - 1``.

    Labels are compared after ``int()`` conversion. With the default
    ``num_classes=-1`` the valid range is empty, so every sample is removed.
    NOTE(review): confirm no caller relies on that default.

    Args:
        datasplit: a dataset with a ``"label"`` column that supports
            ``select``.
        num_classes: number of valid classes.

    Returns:
        ``datasplit`` itself when every label is valid. Otherwise a
        ``select`` of the valid rows, in their original order.
    """
    valid_labels = set(range(num_classes))
    # Read the label column once and test membership against a set, so the
    # scan stays linear in the number of rows.
    keep_idxs = [
        idx
        for idx, label in enumerate(datasplit["label"])
        if int(label) in valid_labels
    ]

    if len(keep_idxs) < len(datasplit):
        reduced_datasplit = datasplit.select(keep_idxs)
        print("Reduce the dataset size... (Remove wrong samples...)")
    else:
        reduced_datasplit = datasplit

    return reduced_datasplit


def datasplit_subset(
    datasplit: Dataset,
    subset_num: Optional[int] = None,
    seed: Optional[int] = 1,
    verbose: Optional[bool] = True,
    split: Optional[str] = "test",
    return_remain: Optional[bool] = False,
):
    """Draw a reproducible random subset of ``subset_num`` samples.

    Indices are drawn without replacement from a local
    ``np.random.default_rng(seed)``, independent of the global numpy RNG.
    They are then sorted, so the subset keeps the original sample order.

    Args:
        datasplit: dataset with a ``"label"`` column that supports ``select``.
        subset_num: subset size. ``None``, or the full size, keeps every
            sample.
        seed: seed of the local random generator.
        verbose: print the label distribution before and after sub-sampling.
        split: split name, only used in the printed report.
        return_remain: also return the samples that are not in the subset.

    Returns:
        The subset, or ``(subset, remainder)`` when ``return_remain`` is set.
        Without sub-sampling, the subset is ``datasplit`` itself and the
        remainder is an empty selection.

    Raises:
        AssertionError: if ``subset_num`` exceeds the number of samples.
    """
    orig_sample_num = len(datasplit)
    # Everything is in the subset unless we actually sub-sample below.
    remain_idxs = []
    if subset_num is None or orig_sample_num == subset_num:
        subset_data = datasplit
    else:
        assert orig_sample_num > subset_num, (
            f"subset_num ({subset_num}) must not exceed the number of samples "
            f"({orig_sample_num})"
        )

        rng = np.random.default_rng(seed=seed)
        subset_idxs = sorted(
            rng.choice(orig_sample_num, size=subset_num, replace=False, shuffle=True)
        )
        subset_data = datasplit.select(subset_idxs)

        if return_remain:
            # Set lookup keeps building the complement linear in the size.
            chosen = {int(i) for i in subset_idxs}
            remain_idxs = [i for i in range(orig_sample_num) if i not in chosen]

    if verbose:
        label_cnt = Counter(datasplit["label"])
        subset_label_cnt = Counter(subset_data["label"])
        print(
            f"Original {split} set sample number: {orig_sample_num}; label distibution: {label_cnt}"
        )
        print(
            f"Subset {split} set sample number: {len(subset_data)}; label distribution: {subset_label_cnt}"
        )

    if return_remain:
        return subset_data, datasplit.select(remain_idxs)
    return subset_data


def _check_sanity_noniid():
    """Smoke-test ``cls_noniid_partition`` on every dataset in ``data2usename``.

    Partitions each train split over 7 clients with 2 major classes each.
    Prints the total and unique counts of the assigned ``idx`` values next to
    the size of the original split.
    """
    split, num_clients, major_classes_num = "train", 7, 2
    for data_name, hub_name in data2usename.items():
        print(f"=========== {data_name}")
        dataset = load_dataset(hub_name)
        fed_dataset = cls_noniid_partition(
            dataset,
            split=split,
            data_name=data_name,
            num_clients=num_clients,
            major_classes_num=major_classes_num,
            seed=0,
        )
        print(fed_dataset)

        total_indices = list(
            itertools.chain.from_iterable(
                fed_dataset[f"{split}-client{cid}"]["idx"]
                for cid in range(num_clients)
            )
        )
        print(f"Total sample num after partition: {len(total_indices)}")
        print(f"Total unique sample num after partition: {len(set(total_indices))}")
        print(f"Original total sample num: {len(dataset[split]['label'])}")
        print("\n\n")


def _check_sanity_iid():
    """Smoke-test ``cls_iid_partition`` (4 clients, seed 0) on every dataset
    listed in ``data2usename``, printing each resulting ``DatasetDict``."""
    for data_name, hub_name in data2usename.items():
        print(f"=========== {data_name}")
        fed_dataset = cls_iid_partition(
            load_dataset(hub_name),
            split="train",
            data_name=data_name,
            num_clients=4,
            seed=0,
        )
        print(fed_dataset)
        print("\n\n")


if __name__ == "__main__":
    # Run the IID smoke test first, then the non-IID one.
    for sanity_check in (_check_sanity_iid, _check_sanity_noniid):
        sanity_check()
    _check_sanity_noniid()
