import torch
import torch.nn as nn
from IPython import embed

'''
Krum aggregation
- find the point closest to its neighborhood

Reference:
Blanchard, Peva, Rachid Guerraoui, and Julien Stainer. "Machine learning with adversaries: Byzantine tolerant gradient descent." Advances in Neural Information Processing Systems. 2017.
`https://papers.nips.cc/paper/6617-machine-learning-with-adversaries-byzantine-tolerant-gradient-descent.pdf`

'''


def getKrum(input, f):
    '''
    Select the Krum candidate of input. O(dn^2)

    Every candidate is scored by the sum of the Euclidean distances to its
    k = n - f nearest candidates. The distances are taken between the
    candidate set and itself, so a candidate's own distance is part of the
    pool the k smallest are drawn from. The candidate with the smallest
    score wins.

    input : batchsize * vector dimension * n
    f     : number of candidates left out of each neighbourhood (k = n - f)

    return
        i_star : 0-dim integer tensor, position of the selected candidate
                 along the last axis of input (0 <= i_star < n)

    A single index is returned for the whole batch. The per-candidate scores
    are summed over the batch axis before taking the argmin, so the result is
    always a valid position along the last axis; for batchsize 1 this is the
    plain argmin of the scores.
    '''

    n = input.shape[-1]
    k = n - f

    # pairwise distances between candidates: batchsize * n * n
    points = input.permute(0, 2, 1)
    pairwise = torch.cdist(points, points, p=2)

    # the k smallest distances of each candidate: batchsize * n * k
    nearest, _ = torch.topk(pairwise, k, largest=False)

    # score of each candidate per batch entry: batchsize * n
    scores = nearest.sum(dim=2)

    # An argmin without `dim` would index the flattened (batchsize, n) matrix
    # and could exceed n - 1 when batchsize > 1, so reduce over the batch first.
    i_star = torch.argmin(scores.sum(dim=0))
    return i_star


class Net(nn.Module):
    '''
    Parameter-free aggregation layer built on top of getKrum.

    Two stages:
    1. Selection: getKrum is applied theta = n - 2f times, each time moving
       the chosen candidate from the remaining pool into the selected set.
    2. Aggregation: for every coordinate, the beta = theta - 2f selected
       values closest to the coordinate-wise median are averaged.

    NOTE(review): this looks like the Bulyan scheme rather than plain Krum
    (the 4f + 3 <= n requirement matches it) - confirm against the caller.
    '''

    def __init__(self):
        super(Net, self).__init__()

    def forward(self, input):
        '''
        input: batchsize * vector dimension * n

        return
            out : batchsize * vector dimension * 1

        The computation runs on a detached copy of input, so out carries no
        gradient back to input.

        raises
            ValueError : if input has no candidates (n == 0)
        '''
        n = input.shape[-1]
        if n == 0:
            raise ValueError("input must contain at least one candidate along the last axis")

        # Largest f with 4*f + 3 <= n. Clamped at 0: for n < 3 the floor
        # division goes negative, which would request more selections than
        # there are candidates. With f = 0 every candidate is selected and
        # the result is the plain mean.
        f = max((n - 3) // 4, 0)
        theta = n - 2 * f
        beta = theta - 2 * f

        remaining = input.clone().detach()
        selected = []
        for _ in range(theta):
            i_star = int(getKrum(remaining, f))
            # list indexing copies the column and keeps the last axis
            selected.append(remaining[:, :, [i_star]])
            # drop the chosen candidate from the pool
            remaining = torch.cat(
                (remaining[:, :, :i_star], remaining[:, :, i_star + 1:]), dim=-1)
        S = torch.cat(selected, dim=-1)

        # coordinate-wise median of the selected candidates
        m = torch.median(S, dim=-1, keepdim=True)[0]
        dis = (S - m).abs()

        # distance of the beta-th closest value to the median, per coordinate
        v, _ = dis.topk(beta, dim=-1, largest=False)
        threshold = v[:, :, -1:] + 1e-10

        # Keep every value within the threshold. Ties at the threshold may
        # keep more than beta values, hence the division by the actual count.
        mask = (dis <= threshold).float()
        out = (S * mask).sum(dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True)
        return out
