import torch


def multi_krum(vectors, byz_num):
    """
    Compute the aggregated result with multi-Krum.

    Each vector is given a Krum score: the sum of squared Euclidean distances
    to its ``len(vectors) - byz_num - 2`` nearest *other* vectors (clamped at
    zero). The ``len(vectors) - 2 * byz_num`` vectors with the lowest scores
    are averaged element-wise.

    All work happens on the device and dtype of the inputs, so CPU and GPU
    tensors are both supported.

    :param vectors: non-empty sequence of floating-point tensors that share
        one shape, dtype and device
    :param byz_num: assumed number of Byzantine vectors; must satisfy
        ``byz_num >= 0`` and ``2 * byz_num < len(vectors)``
    :return: tensor with the shape of a single input vector (element-wise
        mean of the selected vectors)
    :raises ValueError: if ``byz_num`` is out of range for ``len(vectors)``
    """
    length = len(vectors)
    if 2 * byz_num >= length or byz_num < 0:
        raise ValueError("byz_num error")

    stacked = torch.stack(vectors)
    # Flatten each vector so the distance is the Euclidean norm over all
    # elements, matching torch.norm(v_i - v_j, 2) on the original shapes.
    flat = stacked.reshape(length, -1)

    # Pairwise squared distances, shape (length, length). Computed row by row
    # on the inputs' device/dtype; the diagonal is exactly zero.
    distance = torch.stack([(flat - row).pow(2).sum(dim=1) for row in flat])

    # Krum uses the n - f - 2 closest vectors *other than itself*. Every row
    # contains its own 0 self-distance, which is always among the smallest
    # entries, so take one extra entry; the zero adds nothing to the sum.
    # Clamp at 0 so the case length == 1, byz_num == 0 is valid.
    neighbours = max(length - byz_num - 2, 0)
    nearest = torch.topk(distance, neighbours + 1, dim=1, largest=False, sorted=False)[0]
    score = nearest.sum(dim=1)

    # Keep the length - 2f lowest-scoring vectors and average them.
    selected = torch.topk(score, length - 2 * byz_num, dim=0, largest=False, sorted=False)[1]
    return stacked[selected].mean(dim=0)
