import torch
from scipy.stats import norm


def omniscient_attack_on_server(byz_mode, byz_num, rank, current_vector, good_vectors, epsilon=0):
    """
    Compute gradients with omniscient attack on server.

    A worker with ``rank < byz_num`` is treated as Byzantine and its received
    gradient is replaced according to ``byz_mode``. For any other worker, and
    for the attack modes that are not simulated here ('noAtk', 'labelFlip',
    'bitFlip', 'RD_atk', 'NG_atk'), ``current_vector`` is returned unchanged
    (the very same object).

    :param byz_mode: type of Byzantine attack ('FoE' and 'ALIE' are computed here)
    :param byz_num: total number of Byzantine workers
    :param rank: the number of the worker that the gradient is received from
    :param current_vector: currently received gradient
    :param good_vectors: history gradients received from non-Byzantine workers (omniscient knowledge);
        must be a non-empty sequence of same-shaped tensors for 'FoE'/'ALIE' (``torch.stack`` requires it)
    :param epsilon: hyper-parameter of FoE attack

    :return: gradients after omniscient attack
    :raises ValueError: if a Byzantine worker uses an unrecognised ``byz_mode``
    """
    passthrough_modes = ('noAtk', 'labelFlip', 'bitFlip', 'RD_atk', 'NG_atk')
    if rank >= byz_num or byz_mode in passthrough_modes:
        return current_vector

    if byz_mode == 'FoE':
        # Send the mean of the honest gradients scaled by -epsilon.
        return torch.mean(torch.stack(good_vectors), dim=0).mul_(-epsilon)

    if byz_mode == 'ALIE':
        stacked = torch.stack(good_vectors)  # build once; reused for mean and std
        good_num = len(good_vectors)
        # With n = good_num + byz_num workers, s = floor(n/2 + 1) - byz_num.
        s = int((good_num + byz_num) / 2 + 1) - byz_num
        # NOTE(review): if s <= 0 or s >= good_num the ppf argument leaves (0, 1)
        # and z becomes inf/nan; with a single good vector sigma is nan.
        z = norm.ppf((good_num - s) / good_num)
        mu = torch.mean(stacked, dim=0)
        sigma = torch.std(stacked, dim=0)
        # Equivalent to the deprecated mu.add(-z, sigma) overload. float() turns the
        # NumPy scalar returned by norm.ppf into a plain Python float.
        return mu - float(z) * sigma

    raise ValueError("undefined byz_mode!")
