#! -*- coding: utf-8
import typing

import numpy as np
import pandas as pd


def dirichlet(targets: np.ndarray, ndistribute: int,
              alpha: float = 1.0, seed=0, num_auxiliary_workers=10,
              min_size_rate: float = 0.5, label_sort: bool = False) -> typing.List[typing.List[int]]:
    """Partition sample indices among workers with Dirichlet-skewed labels.

    Adapted from Tao Lin's ``partition_data.py``, as used in
    https://github.com/epfml/relaysgd/tree/89719198ba227ebbff9a6bf5b61cb9baada167fd

    The shuffled samples are cut into consecutive chunks, one per group of at
    most ``num_auxiliary_workers`` workers. Within a chunk, each class is split
    among the group's workers with proportions drawn from
    ``Dirichlet(alpha, ..., alpha)``. The whole draw is repeated until every
    worker in the group holds at least
    ``int(min_size_rate * chunk_size / group_size)`` samples.

    Args:
        targets: 1-D array with the label of each sample. Labels do not have
            to be the contiguous range ``0..K-1``.
        ndistribute: total number of workers (output partitions).
        alpha: Dirichlet concentration. Smaller values give more skewed label
            distributions per worker.
        seed: seed for the ``np.random.RandomState`` behind every random draw.
        num_auxiliary_workers: maximum number of workers sharing one chunk.
        min_size_rate: fraction of a group's average share that every worker
            in the group must reach for a draw to be accepted. NOTE: if this
            bound is unreachable, the redraw loop does not terminate.
        label_sort: if True, reorder the partitions by their per-class counts.
            Classes are compared in a random order, alternating descending and
            ascending.

    Returns:
        One list of indices into ``targets`` per worker.
    """
    random_state = np.random.RandomState(seed=seed)

    classes = np.unique(targets)
    # Row i is (sample index, label). Rows are shuffled together so that every
    # chunk receives a random subset of the samples.
    indices2targets = np.array(list(enumerate(targets)))
    random_state.shuffle(indices2targets)

    idx_batch = []
    for chunk, n_workers in _split_into_groups(indices2targets, ndistribute, num_auxiliary_workers):
        idx_batch += _dirichlet_split(chunk, n_workers, classes, alpha, min_size_rate, random_state)

    if label_sort:
        idx_batch = _sort_by_label_counts(idx_batch, targets, classes, random_state)
    return idx_batch


def _split_into_groups(rows: np.ndarray, ndistribute: int,
                       num_auxiliary_workers: int) -> typing.List[typing.Tuple[np.ndarray, int]]:
    """Cut ``rows`` into consecutive chunks, one per group of workers.

    Returns ``(chunk, n_workers)`` pairs. Every group has
    ``num_auxiliary_workers`` workers except the last one, which takes the
    remaining workers. Every chunk except the last holds
    ``int(num_auxiliary_workers / ndistribute * len(rows))`` rows; the last
    chunk takes all remaining rows.
    """
    num_indices = len(rows)
    group_starts = range(0, ndistribute, num_auxiliary_workers)
    groups = []
    from_index = 0
    for group_idx, first_worker in enumerate(group_starts):
        n_workers = min(num_auxiliary_workers, ndistribute - first_worker)
        if group_idx == len(group_starts) - 1:
            to_index = num_indices
        else:
            to_index = from_index + int(num_auxiliary_workers / ndistribute * num_indices)
        groups.append((rows[from_index:to_index], n_workers))
        from_index = to_index
    return groups


def _dirichlet_split(rows: np.ndarray, n_workers: int, classes: np.ndarray, alpha: float,
                     min_size_rate: float, random_state: np.random.RandomState) -> typing.List[typing.List[int]]:
    """Spread one chunk of ``(index, label)`` rows over ``n_workers`` workers.

    Each class is divided among the workers according to a fresh Dirichlet
    draw. The whole assignment is redrawn until the smallest worker holds at
    least ``int(min_size_rate * len(rows) / n_workers)`` samples. The loop
    always runs at least once, so a zero threshold still produces a partition.
    """
    size = len(rows)
    # Per-worker share used by the balancing mask below.
    capacity = size / n_workers
    min_required = int(min_size_rate * size / n_workers)

    while True:
        batches = [[] for _ in range(n_workers)]
        for label in classes:
            # Original sample indices (column 0) of this class within the chunk.
            class_indices = rows[rows[:, 1] == label, 0].astype(int)

            weights = random_state.dirichlet(np.repeat(alpha, n_workers))
            # Balance: only workers still below their average share may
            # receive samples of this class.
            open_workers = np.array([len(batch) < capacity for batch in batches])
            masked = weights * open_workers
            total = masked.sum()
            if total > 0:
                weights = masked / total
            # Otherwise every worker with non-zero weight is already full.
            # Keep the raw draw instead of dividing by zero, which would
            # produce NaN cut points.

            cut_points = (np.cumsum(weights) * len(class_indices)).astype(int)[:-1]
            batches = [
                batch + part.tolist()
                for batch, part in zip(batches, np.split(class_indices, cut_points))
            ]

        if min(len(batch) for batch in batches) >= min_required:
            return batches


def _sort_by_label_counts(idx_batch: typing.List[typing.List[int]], targets: np.ndarray,
                          classes: np.ndarray,
                          random_state: np.random.RandomState) -> typing.List[typing.List[int]]:
    """Reorder partitions by their per-class sample counts.

    Classes are compared in a random order. Sort directions alternate
    descending/ascending, which also works for an odd number of classes.
    """
    targets = np.asarray(targets)
    num_classes = len(classes)
    counts = np.zeros((len(idx_batch), num_classes), dtype=int)
    for row, idxs in enumerate(idx_batch):
        labels, cnts = np.unique(targets[idxs], return_counts=True)
        # Map label values to column positions, so labels need not be 0..K-1.
        counts[row, np.searchsorted(classes, labels)] = cnts

    order = list(random_state.permutation(range(num_classes)))
    ascending = [position % 2 == 1 for position in range(num_classes)]
    sort_indexs = np.array(pd.DataFrame(counts).sort_values(
        by=order, ascending=ascending).index)
    return [idx_batch[i] for i in sort_indexs]
