import torch
from .cw_trmean import *
from .cw_median import *
from .geomed import *
from .krum import *
from .centered_clipping import *
from .multi_krum import *


def robust_aggregate(rule, vectors, byz_num=-1, weights=-1, reference_point=-1, iterations=-1, nu=-1, tau=-1):
    """
    Compute the result of robust aggregation.

    Only the keyword arguments relevant to the selected rule are forwarded; the
    others are ignored. A value of -1 is the "not supplied" placeholder and is
    passed through unchanged, so how it is interpreted is up to the selected
    aggregator (NOTE(review): confirm each aggregator handles -1 as intended).

    :param rule: the rule applied for robust aggregation; one of 'mean',
        'cwMedian', 'cwTrmean', 'geoMed', 'Krum', 'multiKrum', 'CC'
    :param vectors: vectors to be aggregated; for 'mean' they are combined with
        torch.stack, so they must be tensors of identical shape
    :param byz_num: assumed number of Byzantine vectors (applied for 'cwTrmean'
        as the trim count, and for 'Krum' & 'multiKrum')
    :param weights: weights parameter for each vector (applied for 'geoMed' rule)
    :param reference_point: estimation of the aggregated result in advance, used
        as the initial point (applied for 'geoMed' & 'CC' rules)
    :param iterations: number of iterations (applied for 'geoMed' & 'CC' rules)
    :param nu: smoothness hyper-parameter (applied for 'geoMed' rule)
    :param tau: clipping radius hyper-parameter (applied for 'CC' rule)
    :return: robust aggregation result
    :raises ValueError: if ``rule`` is not one of the supported rule names
    """
    # Ordered (name, thunk) pairs. Each thunk defers the call, so only the selected
    # aggregator runs, and rules are matched with `==` in the same order as before.
    dispatch = (
        ('mean', lambda: torch.mean(torch.stack(vectors), dim=0)),
        ('cwMedian', lambda: cw_median(vectors)),
        ('cwTrmean', lambda: cw_trmean(vectors, trim_num=byz_num)),
        ('geoMed', lambda: geomed(vectors, alphas=weights, initial_point=reference_point,
                                  iterations=iterations, nu=nu)),
        ('Krum', lambda: krum(vectors, byz_num=byz_num)),
        ('multiKrum', lambda: multi_krum(vectors, byz_num=byz_num)),
        ('CC', lambda: centered_clipping(vectors, tau, initial_point=reference_point,
                                         iterations=iterations)),
    )

    for name, aggregate in dispatch:
        if rule == name:
            return aggregate()

    # Keep the original message as a prefix, but say what was received and what
    # is accepted so a misspelled rule name is easy to spot.
    supported = ', '.join(repr(name) for name, _ in dispatch)
    raise ValueError(f"robust rule undefined! Got {rule!r}; expected one of: {supported}")
