import copy
import numpy as np
import torch

def median(l):
    """Return the upper median of ``l`` without modifying ``l``.

    The result is the element at position ``len(l) // 2`` of the sorted
    values. That is the middle element for odd lengths and the larger of the
    two middle elements for even lengths, so the result is always a member
    of ``l``.

    Args:
        l: A non-empty 1-D list, NumPy array or torch tensor of comparable
            values.

    Returns:
        The selected element, with the element type of ``l``.

    Raises:
        IndexError: If ``l`` is empty.
    """
    # Sort a copy so the caller's data is left untouched. ``Tensor.sort()``
    # is not in-place (it returns (values, indices)), so tensors need
    # ``torch.sort(...).values``. ``np.sort`` keeps NumPy's ordering rules
    # (e.g. NaN last) for arrays.
    if torch.is_tensor(l):
        ordered = torch.sort(l).values
    elif isinstance(l, np.ndarray):
        ordered = np.sort(l)
    else:
        ordered = sorted(l)
    return ordered[len(l) // 2]

def kth_number(l, k):
    """Return the ``k``-th smallest element (0-based) of ``l`` without modifying it.

    Args:
        l: A 1-D list, NumPy array or torch tensor of comparable values.
        k: Position in ascending sorted order. Python indexing rules apply,
            so ``k = -1`` selects the largest element.

    Returns:
        The selected element, with the element type of ``l``.

    Raises:
        IndexError: If ``k`` is out of range for ``l``.
    """
    # Sort a copy rather than the caller's data. ``Tensor.sort()`` is not
    # in-place, so tensors go through ``torch.sort(...).values``.
    if torch.is_tensor(l):
        ordered = torch.sort(l).values
    elif isinstance(l, np.ndarray):
        ordered = np.sort(l)
    else:
        ordered = sorted(l)
    return ordered[k]

def trimmed_mean(l, byzantine_ratio):
    """Mean of ``l`` after dropping the ``int(len(l) * byzantine_ratio)``
    smallest and as many largest values.

    Args:
        l: 1-D torch tensor, NumPy array, or sequence of numbers / 0-d tensors.
        byzantine_ratio: Fraction of values trimmed from *each* end.

    Returns:
        0-d floating-point tensor. It is NaN when trimming leaves no values
        (e.g. ``byzantine_ratio >= 0.5``).
    """
    # Non-tensor inputs are converted first: list.sort() / ndarray.sort()
    # sort in place and return None, so ``.sort().values`` only works on tensors.
    values = l if torch.is_tensor(l) else torch.as_tensor(l)
    # torch.mean rejects integer dtypes.
    if not values.is_floating_point():
        values = values.float()
    n = len(values)
    trim = int(n * byzantine_ratio)
    ordered = values.sort().values
    return torch.mean(ordered[trim:n - trim])

def trimmed_mean_index(l, index_low, index_high):
    """Mean of the sorted values of ``l`` at positions ``[index_low, index_high)``.

    Args:
        l: 1-D torch tensor, NumPy array, or sequence of numbers / 0-d tensors.
        index_low: First sorted position included.
        index_high: First sorted position excluded.

    Returns:
        0-d floating-point tensor. It is NaN when the window is empty.
    """
    # ndarray.sort() / list.sort() return None, so convert to a tensor
    # before using ``.sort().values``.
    values = l if torch.is_tensor(l) else torch.as_tensor(l)
    # torch.mean rejects integer dtypes.
    if not values.is_floating_point():
        values = values.float()
    return torch.mean(values.sort().values[index_low:index_high])

def robust_mean_median(losss, need_var):
    """Median-based location (and optional spread) estimate of ``losss``.

    Args:
        losss: Sequence of numeric losses; converted to a NumPy array.
        need_var: If truthy, also return the median of the squared
            deviations from the location estimate.

    Returns:
        ``mean_esti`` alone, or ``(mean_esti, loss_vars_esti)`` when
        ``need_var`` is truthy. Both come from ``median`` (upper median).
    """
    losss = np.array(losss)
    mean_esti = median(losss)
    if not need_var:
        return mean_esti
    # Element-wise squared deviations from the robust centre.
    loss_vars = (losss - mean_esti) ** 2
    loss_vars_esti = median(loss_vars)
    return mean_esti, loss_vars_esti

def robust_mean_trimmed(losss, byzantine_ratio, need_var):
    """Trimmed-mean location (and optional spread) estimate of ``losss``.

    Args:
        losss: Sequence of numeric losses; converted with ``torch.tensor``.
        byzantine_ratio: Fraction of values trimmed from each end.
        need_var: If truthy, also return the trimmed mean of the squared
            deviations from the location estimate.

    Returns:
        ``mean_esti`` (0-d tensor), or ``(mean_esti, loss_vars_esti)`` when
        ``need_var`` is truthy.
    """
    losss = torch.tensor(losss)
    # torch.mean (inside trimmed_mean) rejects integer dtypes.
    if not losss.is_floating_point():
        losss = losss.float()
    mean_esti = trimmed_mean(losss, byzantine_ratio)
    if not need_var:
        return mean_esti
    # Keep the squared deviations as a tensor: trimmed_mean sorts via
    # ``.sort().values``, which a Python list does not support.
    loss_vars = (losss - mean_esti) ** 2
    loss_vars_esti = trimmed_mean(loss_vars, byzantine_ratio)
    return mean_esti, loss_vars_esti

def robust_mean_two_equ_class_trimmed(losss, byzantine_ratio, need_var):
    """Mean over an upper window of the sorted losses.

    Let ``c = len(losss)`` and ``b = int(c * byzantine_ratio)``. The estimate
    is the mean of sorted positions ``[(c + b) // 2, c - b)``. It drops the
    lowest ``(c + b) // 2`` values and the highest ``b`` values. The window is
    empty, giving NaN, when ``(c + b) // 2 >= c - b``.

    Args:
        losss: Sequence of numeric losses; converted with ``torch.tensor``.
        byzantine_ratio: Fraction of losses treated as Byzantine.
        need_var: If truthy, also return the mean of the squared deviations
            over the same sorted-position window.

    Returns:
        ``mean_esti`` (0-d tensor), or ``(mean_esti, loss_vars_esti)`` when
        ``need_var`` is truthy.
    """
    # trimmed_mean_index sorts via ``.sort().values``, which needs a tensor
    # (ndarray.sort() returns None).
    losss = torch.tensor(losss)
    # torch.mean rejects integer dtypes.
    if not losss.is_floating_point():
        losss = losss.float()
    c = len(losss)
    b = int(c * byzantine_ratio)
    index_low = (c + b) // 2
    index_high = c - b
    mean_esti = trimmed_mean_index(losss, index_low, index_high)
    if not need_var:
        return mean_esti
    loss_vars = (losss - mean_esti) ** 2
    # NOTE: a different window [b, (c - b) // 2) was once considered for the
    # variance; the mean's window is used here.
    loss_vars_esti = trimmed_mean_index(loss_vars, index_low, index_high)
    return mean_esti, loss_vars_esti

def robust_mean(losss, byzantine_ratio=None, method='median', need_var=False, k=None):
    """Dispatch ``losss`` to one of the robust aggregation rules.

    Args:
        losss: Sequence of numeric losses.
        byzantine_ratio: Trimming fraction, used by the ``'trimmed_mean'``
            and ``'two_equ_class_trimmed_mean'`` methods.
        method: One of ``'median'``, ``'trimmed_mean'``,
            ``'two_equ_class_trimmed_mean'`` or ``'kth_number'``.
        need_var: Forwarded to the median/trimmed methods. Ignored by
            ``'kth_number'``.
        k: 0-based sorted position, required when ``method='kth_number'``.

    Returns:
        Whatever the selected method returns.

    Raises:
        ValueError: If ``method`` is unknown, or ``'kth_number'`` is
            requested without ``k``.
    """
    if method == 'median':
        return robust_mean_median(losss, need_var=need_var)
    if method == 'trimmed_mean':
        return robust_mean_trimmed(losss, byzantine_ratio, need_var=need_var)
    if method == 'two_equ_class_trimmed_mean':
        return robust_mean_two_equ_class_trimmed(losss, byzantine_ratio, need_var=need_var)
    if method == 'kth_number':
        if k is None:
            raise ValueError("method='kth_number' requires the k argument")
        # Pass a list copy so the caller's losses are never sorted in place.
        return kth_number(list(losss), k)
    raise ValueError(f"unknown robust_mean method: {method!r}")

def client_space(local_vectors):
    """QR-factorise the given vectors, stacked as columns.

    The input is converted with ``np.array`` and transposed, so each row of
    the input becomes a column. The result is factorised with
    ``np.linalg.qr`` in its default (reduced) mode.

    Args:
        local_vectors: Array-like with one vector per row (assumed shape
            ``(m, d)`` — TODO confirm against callers).

    Returns:
        ``(Q, R.T, n_rows)``:
        - ``Q`` is the orthonormal factor.
        - ``R.T`` is the transposed upper-triangular factor, so
          ``R.T @ Q.T`` reconstructs the input.
        - ``n_rows`` is the number of rows of ``R``. In reduced mode this is
          ``min(m, d)``, not necessarily the numerical rank.
    """
    columns = np.array(local_vectors).T
    basis, triangular = np.linalg.qr(columns)
    n_rows = triangular.shape[0]
    return basis, triangular.T, n_rows

