import numpy as np
from scipy.optimize import minimize
from scipy.optimize import NonlinearConstraint
import matplotlib.pyplot as plt

def vector_2norm(v):
    """Return ``np.dot(v, v)``.

    For a 1-D vector this is the *squared* Euclidean norm (sum of squared
    entries); no square root is taken.
    """
    squared_norm = np.dot(v, v)
    return squared_norm

def matrix_F_distance(A,B):
    """Return ``vector_2norm`` of the difference of the flattened arrays.

    Both arguments must expose ``.flatten()``. Because ``vector_2norm`` does
    not take a square root, the result is the sum of squared element-wise
    differences (the squared Frobenius distance for matrices).
    """
    difference = A.flatten() - B.flatten()
    return vector_2norm(difference)

def instance_generation_honest(d=100, sample_num=20,client_num=10):
    """Draw standard-normal samples per client and return each client's mean.

    Args:
        d: dimension of every sample.
        sample_num: number of samples drawn for each client.
        client_num: number of clients (rows of the result).

    Returns:
        Array of shape ``(client_num, d)``; row ``i`` is the average of
        client ``i``'s ``sample_num`` draws from ``np.random.randn``.
    """
    raw_samples = np.random.randn(client_num, sample_num, d)
    client_means = np.zeros((client_num, d))
    for idx, samples in enumerate(raw_samples):
        # Average over the sample axis, leaving one d-vector per client.
        client_means[idx] = np.mean(samples, axis=0)
    return client_means

def instance_generation_Byzantine(d=100, sample_num=20,client_num=10):
    """Draw normal samples shifted by +5 per client and return each client's mean.

    Args:
        d: dimension of every sample.
        sample_num: number of samples drawn for each client.
        client_num: number of clients (rows of the result).

    Returns:
        Array of shape ``(client_num, d)``; row ``i`` is the average of
        client ``i``'s ``sample_num`` draws from ``np.random.randn(...) + 5``.
    """
    shifted_samples = np.random.randn(client_num, sample_num, d) + 5
    client_means = np.zeros((client_num, d))
    for idx, samples in enumerate(shifted_samples):
        # Average over the sample axis, leaving one d-vector per client.
        client_means[idx] = np.mean(samples, axis=0)
    return client_means

def instance_generation(d=100, sample_num=20,client_num=10,Byzantine_client_num=10):
    """Build a mixed instance: honest client rows followed by Byzantine rows.

    Args:
        d: dimension of every sample.
        sample_num: samples averaged per client (both kinds).
        client_num: number of honest clients.
        Byzantine_client_num: number of Byzantine clients.

    Returns:
        Array of shape ``(client_num + Byzantine_client_num, d)`` with the
        honest rows first.
    """
    honest_rows = instance_generation_honest(
        d=d, sample_num=sample_num, client_num=client_num)
    byzantine_rows = instance_generation_Byzantine(
        d=d, sample_num=sample_num, client_num=Byzantine_client_num)
    return np.concatenate((honest_rows, byzantine_rows), axis=0)



# np.concatenate(())

def solve_saddle_point(client_num,d,alpha,A,c,xs,w0=None):
    """Optimise per-singular-direction weights for the active rows ``xs[A]``.

    The active rows are centred, decomposed with an SVD, and SLSQP minimises
    ``max_i (s_i * (1 - w_i))**2`` subject to ``r - 1 - vector_2norm(w) >= 0``
    with ``r = (4 - alpha) / ((2 + alpha) * alpha)``.

    Args:
        client_num, d, c: accepted but not used inside this function.
        alpha: scalar that determines the constraint bound ``r``.
        A: integer indices of the active rows of ``xs``.
        xs: 2-D array with one row per client.
        w0: optional starting point for the optimiser; zeros when ``None``.

    Returns:
        ``(W, tau, f)`` where ``W`` is a ``len(A) x len(A)`` matrix,
        ``tau`` is a list with one score per active row, and ``f`` is the
        optimised objective value.
    """
    n_active = len(A)
    r = (4 - alpha) / ((2 + alpha) * alpha)

    # P averages over the active rows; (I - P) removes that average.
    P = np.ones((n_active, n_active)) / n_active
    identity = np.eye(n_active)
    centred = np.dot(identity - P, xs[A])
    u, s, _ = np.linalg.svd(centred)  # columns of u are orthonormal

    def norm_budget(weights):
        # SLSQP 'ineq' constraints require this value to stay non-negative.
        # NOTE(review): an earlier variant used r**2 - 1 as the bound here;
        # confirm which one is intended.
        return r - 1 - vector_2norm(weights)

    def worst_direction(weights):
        squared = [(s[k] * (1 - weights[k])) ** 2 for k in range(len(s))]
        return np.max(np.array(squared))

    start = np.zeros(len(s)) if w0 is None else w0
    result = minimize(
        worst_direction,
        start,
        method='SLSQP',
        constraints=[{'type': 'ineq', 'fun': norm_budget}],
        options={'ftol': 1e-9, 'disp': False},
    )
    w_opt = result.x
    f = result.fun

    # Pick the direction with the largest (unsquared) residual and score each
    # active row by its squared coordinate along it, scaled by f.
    residual = np.array([s[k] * (1 - w_opt[k]) for k in range(len(s))])
    top = np.argmax(residual)
    tau = [u[row, top] ** 2 * f for row in range(n_active)]

    # There may be fewer singular values than active rows; pad with zeros.
    padded = np.zeros(n_active)
    padded[:len(w_opt)] = w_opt
    W_prime = np.linalg.multi_dot([u, np.diag(padded), u.T])
    W_prime = np.dot(W_prime, identity - P)
    return P + W_prime, tau, f

def iterative_filtering(instance,alpha=0.8,sigma_max=1):
    """Iteratively down-weight and drop rows of ``instance``, then aggregate.

    Each round calls ``solve_saddle_point`` on the active rows. The loop stops
    once the returned objective ``f`` is below ``4 * client_num * sigma_max**2``.
    Otherwise every active row's weight is multiplied by
    ``1 - tau_i / max(tau)`` and rows whose weight drops below 1/2 are removed.

    Args:
        instance: 2-D NumPy array with one row per client.
        alpha: forwarded to ``solve_saddle_point``.
        sigma_max: scale used in the stopping threshold.

    Returns:
        ``(WZ, Z)`` where ``WZ = W @ xs[A]`` and ``Z = W0 @ xs[A]`` with
        ``W0 = (W - W1) @ inv(I - W1)``; ``W1`` is the part of ``W`` carried
        by eigenvalues whose real part is below 0.9.

    Raises:
        ValueError: if every row is removed before the stopping threshold
            is reached.
    """
    xs = instance
    client_num = len(xs)
    d = len(xs[0])
    c = np.ones(client_num)      # per-client weights, shrunk every round
    A = np.arange(client_num)    # indices of the rows still active

    while True:
        W, tau, f = solve_saddle_point(client_num, d, alpha, A, c, xs)
        if f < 4*client_num*sigma_max**2:
            break

        tau = np.asarray(tau)
        tau_max = tau.max()
        print('counter')  # one trace line per filtering round
        # The row with the largest tau gets weight 0, so A shrinks each round.
        c[A] = (1 - tau / tau_max) * c[A]
        A = A[c[A] >= 1/2]
        if A.size == 0:
            raise ValueError(
                "iterative_filtering removed every row before the stopping "
                "threshold was reached"
            )

    # np.linalg.eig returns eigenvectors as the COLUMNS of `eigvecs`, so the
    # matrix is rebuilt as V diag(.) V^{-1}. Real parts are used because eig
    # may return complex output when W is only symmetric up to rounding.
    eigvals, eigvecs = np.linalg.eig(W)
    low = np.where(eigvals.real < 0.9, eigvals.real, 0.0)
    W1 = np.real(
        np.linalg.multi_dot([eigvecs, np.diag(low), np.linalg.inv(eigvecs)])
    )
    W0 = np.dot(W - W1, np.linalg.inv(np.eye(len(low)) - W1))

    Z = np.dot(W0, xs[A])
    WZ = np.dot(W, xs[A])
    return WZ, Z


def iterative_filtering_sigma_unknown(instance,alpha=0.8,sigma_max=1):
    """Iteratively drop rows of ``instance`` and return the mean of the rest.

    Each round calls ``solve_saddle_point`` on the active rows, multiplies
    every active row's weight by ``1 - tau_i / max(tau)`` and keeps the rows
    whose weight is still at least 1/2. The loop stops when ``max(tau)`` is
    below ``1e-6`` or when the next active set would contain fewer than
    ``((2 + alpha) * alpha) / (4 - alpha) * client_num`` rows; in the latter
    case the last accepted active set is kept.

    Args:
        instance: 2-D NumPy array with one row per client.
        alpha: forwarded to ``solve_saddle_point`` and used for the minimum
            active-set size.
        sigma_max: accepted for signature compatibility; not used here.

    Returns:
        ``(mean, A)``: the average of the surviving rows and the integer
        index array of those rows.
    """
    xs = instance
    client_num = len(xs)
    d = len(xs[0])
    c = np.ones(client_num)      # per-client weights, shrunk every round
    A = np.arange(client_num)    # indices of the rows still active
    min_A_num = ((2+alpha)*alpha)/(4-alpha)*client_num

    while True:
        # Only the per-row scores are needed; the matrix and objective that
        # solve_saddle_point also returns do not affect this function's result.
        _, tau, _ = solve_saddle_point(client_num, d, alpha, A, c, xs)
        tau = np.asarray(tau)
        tau_max = tau.max()
        if tau_max < 1e-6:
            break

        c[A] = (1 - tau / tau_max) * c[A]
        survivors = A[c[A] >= 1/2]
        if len(survivors) < min_A_num:
            # Too few rows would remain: keep the current active set.
            break
        A = survivors

    return np.average(xs[A], axis=0), A

