# import sys 
# sys.path.append("..") 
# sys.path.append("../robustFL") 

import geom_median
from geom_median.numpy import compute_geometric_median

import scipy.optimize
from .util import client_space
from .if2 import iterative_filtering, iterative_filtering_sigma_unknown
from .util import robust_mean, median
import copy
import numpy as np
from typing import Literal


import scipy

def print_norm_angle(next_vector_reduced,local_vectors_reduced,byzantine_ratio):
    """Print norm and cosine diagnostics for an aggregated vector.

    The rows of ``local_vectors_reduced`` are split by position: the first
    ``honest_num`` rows are averaged into the "true" gradient and the
    remaining rows into the "Byzantine" gradient. Two lines are printed:

    * the norms of the aggregated vector, the true gradient and the
      Byzantine gradient;
    * the dot products of the unit-normalised aggregated vector with the
      unit-normalised true and Byzantine gradients (i.e. cosines).

    Args:
        next_vector_reduced: Aggregated vector, same dimensionality as one
            row of ``local_vectors_reduced``.
        local_vectors_reduced: One vector per client.
            NOTE(review): the honest-first ordering is assumed by the
            slicing below -- confirm against the caller.
        byzantine_ratio: Fraction of clients treated as Byzantine; the count
            is truncated with ``int``.

    Returns:
        None. Output goes to stdout only.
    """
    # The client count is the number of rows, not the length of the
    # aggregated vector (the two only coincide when the reduced dimension
    # happens to equal the number of clients).
    client_num = len(local_vectors_reduced)
    byzantine_num = int(client_num * byzantine_ratio)
    honest_num = client_num - byzantine_num

    true_grad = np.average(local_vectors_reduced[:honest_num], 0)
    byzantine_grad = np.average(local_vectors_reduced[honest_num:], 0)

    next_norm = np.linalg.norm(next_vector_reduced)
    true_norm = np.linalg.norm(true_grad)
    byzantine_norm = np.linalg.norm(byzantine_grad)

    normalized_true_grad = true_grad / true_norm
    normalized_byzantine_grad = byzantine_grad / byzantine_norm
    normalized_next_dir = next_vector_reduced / next_norm

    print('next_vector_reduced_norm', next_norm, true_norm, byzantine_norm)
    print(
        'next_vector_reduced_angle',
        np.dot(normalized_next_dir, normalized_true_grad),
        np.dot(normalized_next_dir, normalized_byzantine_grad),
    )

def FL_mnist_next_iterative_filtering(local_vectors,byzantine_ratio): #robust gradient
    """Aggregate client vectors with iterative filtering (unknown sigma).

    The client vectors are first passed through ``client_space``, the
    filtering runs on the reduced vectors, and the result is mapped back
    with ``np.dot(Q, ...)``.

    Args:
        local_vectors: One update vector per client.
        byzantine_ratio: Fraction of clients assumed Byzantine; the filter
            is called with ``alpha = 1 - byzantine_ratio``.

    Returns:
        Tuple ``(next_vector, remain_workers)`` where ``remain_workers`` is
        the second value returned by the filter, converted with
        ``.tolist()``.
        NOTE(review): presumably the indices of the clients kept by the
        filter -- confirm in ``if2``.
    """
    basis, reduced_vectors, _ = client_space(local_vectors)

    honest_fraction = 1 - byzantine_ratio
    aggregate_reduced, kept_workers = iterative_filtering_sigma_unknown(
        reduced_vectors, alpha=honest_fraction
    )

    # Map the aggregate from the reduced coordinates back to the full space.
    return np.dot(basis, aggregate_reduced), kept_workers.tolist()

def FL_mnist_next_geo_median(local_vectors,byzantine_ratio): #robust gradient
    """Aggregate client vectors by numerically minimising the distance sum.

    Minimises the sum of Euclidean distances to all reduced client vectors
    with SciPy's Nelder-Mead method, starting from one randomly chosen
    client vector. A single optimisation attempt is made; if it reports
    failure, ``'Fail'`` is printed and the last iterate is used anyway.

    Args:
        local_vectors: One update vector per client.
        byzantine_ratio: Not used by this rule.

    Returns:
        Tuple ``(next_vector, None)``; this rule selects no worker subset.
    """
    basis, reduced_vectors, _ = client_space(local_vectors)
    client_num = len(local_vectors)

    def total_distance(candidate):
        # Objective: sum of Euclidean distances from `candidate` to every
        # reduced client vector.
        total = 0
        for k in range(client_num):
            total += np.linalg.norm(reduced_vectors[k] - candidate)
        return total

    # Start the search from a uniformly random client vector.
    start = reduced_vectors[np.random.randint(client_num)]
    res = scipy.optimize.minimize(
        total_distance, start, method='Nelder-Mead', options={'disp': True}
    )
    if not res.success:
        print('Fail')

    # Map the optimiser's result back to the full space.
    return np.dot(basis, res.x), None

def FL_mnist_next_geo_median_Weiszfeld(local_vectors,byzantine_ratio): #robust gradient
    """Aggregate client vectors with ``compute_geometric_median``.

    The reduced client vectors are handed to the ``geom_median`` package
    with uniform weights (``weights=None``) and ``per_component=False``;
    the resulting ``.median`` is mapped back with ``np.dot(Q, ...)``.
    NOTE(review): per the function name the solver is presumably
    Weiszfeld's algorithm -- confirm in the ``geom_median`` docs.

    Args:
        local_vectors: One update vector per client.
        byzantine_ratio: Not used by this rule.

    Returns:
        Tuple ``(next_vector, None)``; this rule selects no worker subset.
    """
    basis, reduced_vectors, _ = client_space(local_vectors)

    result = compute_geometric_median(
        reduced_vectors,
        weights=None,
        per_component=False,
        skip_typechecks=False,
        eps=1e-6,
        maxiter=1000,
        ftol=1e-20,
    )

    # Map the median from the reduced coordinates back to the full space.
    return np.dot(basis, result.median), None

def FL_mnist_next_krum(local_vectors,byzantine_ratio): #robust gradient
    """Select a single client vector with a Krum-style score.

    For every client, the squared Euclidean distances to all clients
    (including itself, which contributes 0) are sorted and the smallest
    ``client_num - byzantine_num - 1`` of them are summed. The client with
    the lowest sum is selected.

    Args:
        local_vectors: One update vector per client.
        byzantine_ratio: Fraction of clients assumed Byzantine; the count
            is truncated with ``int``.

    Returns:
        Tuple ``(next_vector, [index])`` with the selected client's vector
        mapped back via ``np.dot(Q, ...)`` and its index as returned by
        ``np.argmin``.
    """
    basis, reduced_vectors, _ = client_space(local_vectors)
    n_clients = len(local_vectors)
    n_byzantine = int(n_clients * byzantine_ratio)
    n_closest = n_clients - n_byzantine - 1

    scores = np.zeros(n_clients)
    for candidate in range(n_clients):
        squared_distances = sorted(
            np.linalg.norm(reduced_vectors[other] - reduced_vectors[candidate]) ** 2
            for other in range(n_clients)
        )
        # range() (not a slice) so a non-positive n_closest adds nothing.
        for rank in range(n_closest):
            scores[candidate] += squared_distances[rank]

    winner = np.argmin(scores)
    return np.dot(basis, reduced_vectors[winner]), [winner]

def FL_mnist_next_coordinate_wise(local_vectors,byzantine_ratio=None,robust_mean_method="median"):
    """Aggregate client vectors one coordinate at a time.

    For each coordinate, the values reported by all clients are collected
    into a list and passed to ``robust_mean``; the first element of its
    return value becomes that coordinate of the aggregate.

    Args:
        local_vectors: One update vector per client; all are indexed up to
            the length of the first one.
        byzantine_ratio: Forwarded unchanged to ``robust_mean``. May be
            left as ``None``.
            NOTE(review): whether ``robust_mean`` accepts ``None`` for a
            given method is not visible here -- confirm in ``util``.
        robust_mean_method: Forwarded to ``robust_mean`` as ``method``.

    Returns:
        Tuple ``(next_vector, None)``; this rule selects no worker subset.
    """
    # No Byzantine count is computed here: it was never used, and
    # int(client_num * None) raised TypeError for the default argument.
    client_num = len(local_vectors)
    para_dim = len(local_vectors[0])
    next_vector = np.zeros(para_dim)
    for i in range(para_dim):
        coordinate_values = [local_vectors[j][i] for j in range(client_num)]
        m, _ = robust_mean(
            coordinate_values,
            byzantine_ratio=byzantine_ratio,
            method=robust_mean_method,
        )
        next_vector[i] = m
    return next_vector, None

def FL_mnist_next(
        local_vectors,
        grad_sele_rule:Literal["krum","coordinatewise","iterative_filtering","geo_median","geo_median_w"],
        byzantine_ratio=None,robust_mean_method='median'):
    """Dispatch to the aggregation rule named by ``grad_sele_rule``.

    Args:
        local_vectors: One update vector per client, forwarded unchanged.
        grad_sele_rule: Name of the rule: ``"krum"``, ``"coordinatewise"``,
            ``"iterative_filtering"``, ``"geo_median"`` or
            ``"geo_median_w"``.
        byzantine_ratio: Forwarded to the selected rule.
        robust_mean_method: Only forwarded to the ``"coordinatewise"`` rule.

    Returns:
        Whatever the selected rule returns (a ``(next_vector, workers)``
        tuple). For an unrecognised rule name a message is printed and
        ``None`` is returned, so callers that unpack the result should
        validate the rule name first.
    """
    if grad_sele_rule == 'krum':
        return FL_mnist_next_krum(local_vectors, byzantine_ratio)

    if grad_sele_rule == 'coordinatewise':
        return FL_mnist_next_coordinate_wise(
            local_vectors,
            byzantine_ratio=byzantine_ratio,
            robust_mean_method=robust_mean_method,
        )

    if grad_sele_rule == 'iterative_filtering':
        return FL_mnist_next_iterative_filtering(local_vectors, byzantine_ratio=byzantine_ratio)

    if grad_sele_rule == 'geo_median':
        return FL_mnist_next_geo_median(local_vectors, byzantine_ratio)

    if grad_sele_rule == 'geo_median_w':
        return FL_mnist_next_geo_median_Weiszfeld(local_vectors, byzantine_ratio)

    # Unknown rule: keep the historical best-effort behaviour (report and
    # return None) instead of raising, so existing callers are unaffected.
    print('grad_sele_rule not valid:', grad_sele_rule)
    return None

