from functools import reduce

import numpy as np

from data_utils import load_dset, split_dset_iid, split_dset_diri
from utils import SubsetWithTargets, count_cls


class Sampler:
    """Partition training data between federated-learning clients and a server.

    The client pool and server pool are drawn with ``split_dset_iid``; the client pool is then
    divided among ``cfg.fl.n_c`` clients with ``split_dset_diri`` (Dirichlet, ``cfg.fl.diri_alpha``).
    """

    def __init__(self, cfg):
        """Load datasets, partition indices to clients and server, and build the matching subsets.

        Args:
            cfg (DictConfig): (local) config for this run. This method reads ``dset.path``,
                ``dset.name``, ``dset_s.path``, ``dset_s.name``, ``fl.n_c_ratio``, ``fl.n_s_ratio``,
                ``fl.n_c`` and ``fl.diri_alpha``; the whole config is also kept as ``self.cfg``.

        Attributes:
            cfg: the config passed in.
            dset_test: test split returned by ``load_dset`` for ``cfg.dset``.
                Example (CIFAR-10): a torchvision CIFAR10 dataset whose items are
                ``(Tensor(3, 32, 32) float, np.int64 label)``.
            data_idxs_c (np.ndarray): indices into the ``cfg.dset`` training set forming the client pool.
            dset_train_c (SubsetWithTargets): the client pool as a subset of the ``cfg.dset`` training set.
            data_idxs_s (np.ndarray): indices forming the server pool. They index the ``cfg.dset_s``
                training set when ``cfg.dset_s.name != cfg.dset.name``, otherwise the ``cfg.dset``
                training set (disjoint from ``data_idxs_c``).
            dset_chunk_s (SubsetWithTargets): the server pool as a subset.
            data_idxs_for_c (list[np.ndarray]): one entry per client, as returned by ``split_dset_diri``;
                the values are positions within ``dset_train_c`` (not within the full training set).
                e.g., len(data_idxs_for_c) = 4, data_idxs_for_c[0].shape = (2500,)
            dset_chunks (list[SubsetWithTargets]): per-client subsets of ``dset_train_c``.
            n_data_for_cls_for_c (list): per-client result of ``count_cls`` on ``dset_chunks``
                (presumably per-class sample counts -- confirm against ``count_cls``).
        """
        self.cfg = cfg

        # Split the training data into a client pool and a server pool.
        dset_train, self.dset_test = load_dset(cfg.dset.path, cfg.dset.name)
        if cfg.dset.name != cfg.dset_s.name:
            # Server uses a different dataset: draw each pool independently from its own dataset.
            self.data_idxs_c, _ = split_dset_iid(dset_train.targets, [cfg.fl.n_c_ratio, 1 - cfg.fl.n_c_ratio])
            dset_train_s, _ = load_dset(cfg.dset_s.path, cfg.dset_s.name)
            self.data_idxs_s, _ = split_dset_iid(dset_train_s.targets, [cfg.fl.n_s_ratio, 1 - cfg.fl.n_s_ratio])
            self.dset_chunk_s = SubsetWithTargets(dset_train_s, self.data_idxs_s)
        else:
            # Same dataset for both: split once so the server and client pools come from one
            # partition; the third (remainder) chunk is discarded.
            self.data_idxs_s, self.data_idxs_c, _ = split_dset_iid(
                dset_train.targets, [cfg.fl.n_s_ratio, cfg.fl.n_c_ratio, 1 - cfg.fl.n_s_ratio - cfg.fl.n_c_ratio]
            )
            self.dset_chunk_s = SubsetWithTargets(dset_train, self.data_idxs_s)
            # Pools must be disjoint and duplicate-free: total length equals size of the union.
            assert len(self.data_idxs_s) + len(self.data_idxs_c) == len(
                set(self.data_idxs_s) | set(self.data_idxs_c)
            )
        # Set in both branches so the attribute layout does not depend on the config.
        self.dset_train_c = SubsetWithTargets(dset_train, self.data_idxs_c)

        # Distribute the client pool among the clients.
        self.data_idxs_for_c = split_dset_diri(self.dset_train_c.targets, cfg.fl.n_c, cfg.fl.diri_alpha)
        # The per-client index lists must cover the client pool exactly, with no index assigned twice.
        assert sum(len(idxs) for idxs in self.data_idxs_for_c) == len(self.dset_train_c)
        assert len(set().union(*self.data_idxs_for_c)) == len(self.dset_train_c)

        # Build per-client datasets from the positions within the client pool.
        self.dset_chunks = [SubsetWithTargets(self.dset_train_c, idxs) for idxs in self.data_idxs_for_c]
        self.n_data_for_cls_for_c = [count_cls(chunk) for chunk in self.dset_chunks]


if __name__ == "__main__":
    # Smoke test for split_dset_iid: an 80/10/10 split must partition the training set
    # (no overlap, nothing lost) and be stratified so each of the 10 classes gets
    # 4000 / 500 / 500 samples in the respective chunks.
    dset_train, _ = load_dset("/workspace/data/cifar10", "cifar10")
    idxs1, idxs2, idxs3 = split_dset_iid(dset_train.targets, [0.8, 0.1, 0.1])
    assert len(set(idxs1) | set(idxs2) | set(idxs3)) == len(dset_train)
    assert len(idxs1) + len(idxs2) + len(idxs3) == len(dset_train)
    for idxs, n_per_cls in ((idxs1, 4000), (idxs2, 500), (idxs3, 500)):
        _, counts = np.unique(SubsetWithTargets(dset_train, idxs).targets, return_counts=True)
        # np.array_equal returns False on a shape mismatch (e.g. a class missing from the chunk),
        # giving a clean assertion failure instead of an element-wise-comparison error.
        assert np.array_equal(counts, np.full(10, n_per_cls, dtype=np.int64)), (n_per_cls, counts)
