import torch
import torch.nn as nn
from utils.utils import net2vec
from IPython import embed
from utils.utils import getFloatSubModules, getNetMeta

def element2neuro(vec, net, func):
    '''
    Reduce element-wise vectors to one value per "neuron" (leading-dim slice).

    Each column ``vec[:, i]`` is split into per-parameter chunks using the
    sizes reported by ``getNetMeta`` and reshaped to the parameter shapes.
    Every chunk is then reduced:

    * 1-D tensors: element-wise absolute value (one entry per element);
    * tensors with 2 or more dims: flattened to ``(size(0), -1)`` and passed
      to ``func``, which must return one value per row.

    The same reduction is applied once to the tensors stored in ``net``.

    vec  : 2-D tensor indexed as ``vec[:, i]``; one column per candidate.
    net  : mapping indexed as ``net[name]`` (the caller in this file passes a
           ``state_dict()``). The exact return types of the project helpers
           ``getFloatSubModules`` / ``getNetMeta`` are not visible here;
           they are assumed to be an iterable of names and two name-keyed
           mappings (shapes, sizes) -- TODO confirm.
    func : callable mapping a 2-D tensor to a 1-D tensor with one entry per row.

    return
        deltas : (total neurons, vec.size(-1))
        params : (total neurons, 1)

    Raises ValueError for 0-dim tensors, which have no neuron axis.
    '''
    param_float = getFloatSubModules(net)
    shapes, sizes = getNetMeta(net)
    partition = [sizes[name] for name in param_float]

    def per_neuron(tensor, name):
        # 1-D (bias-like) tensors: every element is its own neuron.
        if tensor.dim() == 1:
            return tensor.abs().reshape(-1, 1)
        # Any tensor with a leading neuron axis (linear, conv1d, conv2d, ...):
        # collapse the trailing axes so func sees one row per neuron.
        if tensor.dim() >= 2:
            rows = tensor.reshape(tensor.size(0), -1)
            return func(rows).reshape(-1, 1)
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError(
            "cannot reduce 0-dim tensor {!r} to neurons".format(name))

    # The parameter magnitudes do not depend on the column index, so they are
    # computed once up front.
    params = torch.cat(
        [per_neuron(net[name], name) for name in param_float], dim=0)

    deltas = []
    for i in range(vec.size(-1)):
        flat_parts = torch.split(vec[:, i], partition)
        column = [
            per_neuron(flat.reshape(shapes[name]), name)
            for name, flat in zip(param_float, flat_parts)
        ]
        deltas.append(torch.cat(column, dim=0))
    deltas = torch.cat(deltas, dim=1)
    return deltas, params

def getKrum_func(vec, f, K):
    '''
    Pick the Krum winner among the candidates stored along the last axis.

    input
        vec : tensor viewed as (vector dimension, n); one candidate per column
        f   : number of candidates excluded from each neighbourhood
        K   : number of largest coordinate gaps averaged into a pair distance

    The distance between two candidates is the mean of their K largest
    absolute coordinate differences. Each candidate is scored by the sum of
    its n - f smallest distances (its zero self-distance is part of that set).

    return
        i_star : 0-dim index tensor of the candidate with the lowest score
    '''
    num_candidates = vec.shape[-1]
    num_neighbours = num_candidates - f

    # (1, n, d): one row per candidate.
    candidates = vec.view(1, -1, num_candidates).transpose(1, 2)

    # Broadcast to an (n, n, d) table of coordinate-wise absolute gaps.
    gaps = (candidates - candidates.view(num_candidates, 1, -1)).abs()
    flat_gaps = gaps.view(num_candidates * num_candidates, -1)

    # Pair distance = mean of the K largest gaps.
    largest_gaps = flat_gaps.topk(K, dim=1)[0]
    cdist = largest_gaps.mean(dim=1).view(num_candidates, num_candidates)

    # Score every candidate by its closest neighbours and keep the best one.
    closest = torch.topk(cdist, num_neighbours, largest=False)[0]
    scores = closest.sum(1)
    return torch.argmin(scores)
    
def bulyan_func(input, vec, K):
    '''
    Bulyan-style aggregation: repeated Krum selection, then a trimmed mean.

    input : tensor indexed as input[:, :, j]; candidate j lives on the last axis
    vec   : 2-D tensor indexed as vec[:, j]; the per-candidate vectors used for
            the Krum distance (may differ from ``input``, e.g. after rescaling)
    K     : forwarded to getKrum_func (number of largest coordinate gaps
            averaged into a pair distance)

    With n candidates, f = (n - 3) // 4 (so that 4*f + 3 <= n), clamped at 0.
    theta = n - 2*f candidates are selected one at a time by getKrum_func,
    each winner being removed from the pool. Per coordinate, the selected
    values whose distance to the median is within the beta-th smallest
    distance (beta = theta - 2*f) are averaged; ties at the threshold are
    all included, so more than beta values may contribute.

    return
        out : trimmed mean, same shape as input with last axis of size 1
        std : unbiased std of the theta selected candidates, same shape as out
    '''
    n = input.shape[-1]
    # For n < 3 the floor division yields -1, which would make theta exceed n
    # and break the selection loop; fall back to "no Byzantine candidates".
    f = max((n - 3) // 4, 0)  # 4*f+3 <= n
    theta = n - 2 * f
    beta = theta - 2 * f

    # Work on detached copies: the pools shrink as winners are removed.
    left_vec = vec.clone().detach()
    left_input = input.clone().detach()

    selected = []
    for _ in range(theta):
        i_star = int(getKrum_func(left_vec, f, K))
        selected.append(left_input[:, :, i_star:i_star + 1].clone())
        left_input = torch.cat(
            (left_input[:, :, :i_star], left_input[:, :, i_star + 1:]), dim=-1)
        left_vec = torch.cat(
            (left_vec[:, :i_star], left_vec[:, i_star + 1:]), dim=-1)
    S = torch.cat(selected, dim=-1)

    # Coordinate-wise trimmed mean around the median.
    m = torch.median(S, dim=-1, keepdim=True)[0]
    dis = (S - m).abs()
    v = dis.topk(beta, dim=-1, largest=False)[0]
    # Small epsilon so the beta-th closest value itself passes the comparison.
    mask = (dis <= v[:, :, -1:] + 1e-10).float()
    out = (S * mask).sum(dim=-1, keepdim=True) / mask.sum(dim=-1, keepdim=True)
    std = S.std(dim=-1, unbiased=True, keepdim=True)
    return out, std
   
def krum_func(input, vec, K, mode, history):
    '''
    Aggregate the candidates on the last axis of ``input`` with a Krum-family
    or distance-to-mean rule.

    input   : tensor indexed as input[:, :, j]; candidate j on the last axis
    vec     : tensor reshaped to (vector dimension, n); the per-candidate
              vectors used to measure distances
    K       : number of largest per-coordinate distances averaged into one
              scalar distance
    mode    : one of
              'mdis'  - squared deviation from the coordinate-wise mean
              'ddis'  - same, divided by the smoothed coordinate variance
              'ldis'  - same, multiplied by the smoothed coordinate variance
              'krum'  - Krum: return the single winning candidate
              'mkrum' - Multi-Krum: average the winner's neighbourhood
              'dkrum' - Multi-Krum, pair distances divided by a smoothed
                        median of squared coordinates
              'lkrum' - Multi-Krum, pair distances multiplied by that scale
    history : None, or a tensor holding a running score that is decayed by
              0.9 and updated IN PLACE with the current distances.

    With n candidates, f = (n - 1) // 2 and k = n - f candidates are kept.

    return
        out : aggregate, same shape as input with last axis of size 1
        std : unbiased std over the kept candidates, same shape as out

    Raises ValueError for an unknown mode.
    '''
    dis_modes = ('mdis', 'ddis', 'ldis')
    krum_modes = ('krum', 'mkrum', 'dkrum', 'lkrum')
    # Validate up front with a real exception: an assert would be stripped
    # under `python -O` and leave `out` undefined further down.
    if mode not in dis_modes and mode not in krum_modes:
        raise ValueError("unknown aggregation mode: {!r}".format(mode))

    SMOOTH_RATIO = 0.1
    n = input.shape[-1]
    f = (n - 1) // 2  # 2*f+1 <= n
    k = n - f

    if mode in dis_modes:
        # Distance of every candidate to the coordinate-wise mean.
        x = vec.t()
        sq_dev = (x - x.mean(dim=0, keepdim=True)) ** 2
        if mode == 'mdis':
            dis = sq_dev
        else:
            x_var = x.std(dim=0, unbiased=True, keepdim=True) ** 2
            # Blend each coordinate's variance with the mean variance. The
            # previous code raised x_var to the power (1 - SMOOTH_RATIO);
            # the multiplicative form matches the dkrum/lkrum smoothing below.
            x_var = x_var * (1 - SMOOTH_RATIO) + x_var.mean() * SMOOTH_RATIO
            dis = sq_dev / x_var if mode == 'ddis' else sq_dev * x_var
        dis = dis.topk(K, dim=1)[0].mean(dim=1)
        if history is not None:
            if history.dim() == 2:
                # NOTE(review): this only rebinds the local name, so a 2-D
                # history passed by the caller is neither read nor updated.
                history = dis
            else:
                history.mul_(0.9).add_(dis)
            dis = history
        keep = torch.topk(dis, k, largest=False)[1].view(-1)
        kept = input[:, :, keep]
        out = kept.mean(dim=2, keepdim=True)
        std = kept.std(dim=2, unbiased=True, keepdim=True)
        return out, std

    # Krum family: (n*n, d) table of squared coordinate differences.
    x = vec.reshape(1, -1, n).permute(0, 2, 1)
    cdist = ((x - x.reshape(n, 1, -1)) ** 2).reshape(n * n, -1)
    if mode in ('dkrum', 'lkrum'):
        # Per-coordinate scale: median squared value over the candidates,
        # blended with its mean across coordinates.
        x_sigma_2 = torch.median(x ** 2, dim=1)[0]
        scale = x_sigma_2 * (1 - SMOOTH_RATIO) + x_sigma_2.mean() * SMOOTH_RATIO
        cdist = cdist / scale if mode == 'dkrum' else cdist * scale
    # Pair distance = mean of the K largest per-coordinate distances.
    cdist = cdist.topk(K, dim=1)[0].mean(dim=1).view(n, n)
    if history is not None:
        history.mul_(0.9).add_(cdist)
        cdist = history

    # find the k nbh of each point
    nbhDist, nbh = torch.topk(cdist, k, largest=False)
    i_star = torch.argmin(nbhDist.sum(1))
    print(nbh[i_star, :], flush=True)
    neighbours = input[:, :, nbh[i_star, :].view(-1)]
    if mode == 'krum':  # krum: m = 1
        out = input[:, :, [i_star]]
    else:  # Multi-Krum: m = n - f
        out = neighbours.mean(2, keepdim=True)
    std = neighbours.std(dim=-1, unbiased=True, keepdim=True)
    return out, std
    
class Net(nn.Module):
    '''
    Aggregation module: combines candidate updates with a robust rule, then
    optionally clips the norm of the result and adds Gaussian noise.
    '''

    # Modes handled by krum_func; 'bulyan' is dispatched separately.
    _KRUM_FUNC_MODES = ('krum', 'mkrum', 'dkrum', 'lkrum', 'mdis', 'ddis', 'ldis')

    def __init__(self, model, vocab):
        '''
        model : object whose state_dict() is read in forward when scale == 'neuro'
        vocab : object exposing ``itos`` (index -> token sequence) and
                ``freqs`` (token -> count mapping)
        '''
        super(Net, self).__init__()
        self.model = model
        self.vocab = vocab
        # Token counts in index order, kept as a plain attribute.
        freqs = torch.zeros(len(vocab.itos)).long()
        for i, token in enumerate(vocab.itos):
            freqs[i] = vocab.freqs[token]
        # Index 0 is overwritten with a very large count.
        freqs[0] = int(1e10)
        self.freqs = freqs

    def forward(self, input, scale, t, epochs, sigma, K_div, rou_k, rou_b, mode='mkrum', history=None):
        '''
        input   : tensor indexed as input[0] and input[:, :, j]; candidates
                  on the last axis
        scale   : 'element' (measure distances on raw elements) or 'neuro'
                  (measure on per-neuron relative changes via element2neuro)
        t       : current step; drives the clipping bound and noise switch
        epochs  : noise is only added while t < epochs
        sigma   : noise level as a STRING. '0' disables noise; a value
                  containing 'a' (e.g. '0.1a') scales the noise by the
                  per-coordinate std of the kept candidates; otherwise the
                  value is used as an absolute std.
        K_div   : K = vec.size(0) // K_div coordinates enter each distance
        rou_k, rou_b : clipping bound rou = rou_k * t + rou_b; no clipping
                  when rou <= 0
        mode    : 'krum', 'mkrum', 'dkrum', 'lkrum', 'mdis', 'ddis', 'ldis'
                  or 'bulyan'
        history : optional running-score tensor forwarded to krum_func
                  (ignored for 'bulyan')

        return the aggregated tensor (last axis of size 1).
        Raises ValueError for an unknown scale or mode.
        '''
        vec = input[0]
        if scale == 'neuro':
            delta, param = element2neuro(vec, self.model.state_dict(), lambda x: (x**2).mean(dim=1).sqrt())
            # Relative change per neuron; epsilon avoids division by zero.
            vec = delta / (param + 1e-8)
        elif scale != 'element':
            raise ValueError("unknown scale: {!r}".format(scale))
        # At least one coordinate must enter the distance: K == 0 would make
        # every distance the mean of an empty selection (NaN).
        K = max(1, vec.size(0) // K_div)

        if mode in self._KRUM_FUNC_MODES:
            out, std = krum_func(input, vec, K, mode, history)
        elif mode == 'bulyan':
            out, std = bulyan_func(input, vec, K)
        else:
            raise ValueError("unknown aggregation mode: {!r}".format(mode))

        # bound begin
        rou = rou_k*t+rou_b
        if rou > 0:
            # Norm over axis 1 of the first batch entry.
            L2 = torch.norm(out, dim=1, p=2)[0, 0].item()
            # An all-zero aggregate needs no clipping and would divide by 0.
            if L2 > 0:
                out = out * min(rou, L2) / L2
        # bound end

        # noise begin
        if (sigma != '0') and (t < epochs):
            # randn_like keeps the noise on out's device and dtype.
            noise = torch.randn_like(out)
            if 'a' in sigma:
                out = out + float(sigma.replace('a', '')) * std * noise
            else:
                out = out + float(sigma) * noise
        # noise end
        return out
        