from .comlib import *
import scipy.linalg
from .gagg import FL_mnist_next

def sqrtm(A, tol=1e-6):
    """Return the symmetric positive-semidefinite square root of ``A``.

    ``A`` is treated as symmetric: it is decomposed with ``np.linalg.eigh``,
    which by default reads only its lower triangle. Eigenvalues below ``tol``
    are set to zero before taking the square root. This includes small negative
    eigenvalues produced by round-off, so the result is always real.

    Args:
        A: Square, symmetric matrix (array-like).
        tol: Eigenvalues strictly below this value are treated as 0.
            Defaults to 1e-6, the previously hard-coded threshold.

    Returns:
        Symmetric matrix ``S`` with ``S @ S`` equal to ``A`` restricted to the
        eigenspace of retained eigenvalues.
    """
    w, V = np.linalg.eigh(A)
    w = np.where(w < tol, 0, w)
    # Scaling the columns of V equals V @ np.diag(sqrt(w)) without the extra
    # dense n x n matrix product.
    return (V * np.sqrt(w)) @ V.T

def inv(A, tol=1e-6):
    """Return the eigenvalue-truncated (pseudo-)inverse of symmetric ``A``.

    ``A`` is decomposed with ``np.linalg.eigh``. Eigenvalues with
    ``|w| < tol`` are dropped (their inverse is set to 0), and the rest are
    inverted. For a nonsingular, well-conditioned matrix this equals the
    ordinary inverse.

    Args:
        A: Square, symmetric matrix (array-like).
        tol: Eigenvalues whose absolute value is below this threshold are
            treated as zero. Defaults to 1e-6, the previously hard-coded value.

    Returns:
        Symmetric matrix ``V @ diag(inv_w) @ V.T``.
    """
    w, V = np.linalg.eigh(A)
    keep = np.abs(w) >= tol
    inv_w = np.zeros_like(w)
    # Divide only where the eigenvalue is kept. The old np.where(..., 1/w)
    # evaluated 1/w on exact zeros (e.g. those sqrtm produces) and emitted
    # divide-by-zero warnings, or raised under np.errstate(divide="raise").
    np.divide(1.0, w, out=inv_w, where=keep)
    # Column scaling equals V @ np.diag(inv_w) without the extra matmul.
    return (V * inv_w) @ V.T

class RGAStrategyKernel:
    """Kernel-based robust gradient aggregation strategy.

    ``kernel_strategy`` produces a worker kernel. The kernel's symmetric square
    root (``sqrtm``) is passed to ``FL_mnist_next`` together with the gradient
    selection rule and the Byzantine ratio. The resulting aggregation vector is
    mapped back to per-worker weights, and ``grad_strategy`` turns those weights
    into the aggregated gradient.
    """

    def __init__(self,grad_strategy:fed_learning.LinearGrad,
                 kernel_strategy,byzantine_ratio,grad_sele_rule,ker_cened=True):
        # grad_strategy: builds the gradient via
        #   .get(weights, chosen_workers, worker_dataset, modelSetup).
        # kernel_strategy: builds the kernel via .get(worker_model, chosen_workers).
        # byzantine_ratio, grad_sele_rule: forwarded unchanged to FL_mnist_next.
        # ker_cened: when True, 1/n is added to every worker weight in
        #   get_remain_workers_weight (presumably because the kernel was built
        #   from centred quantities -- TODO confirm).
        self.grad_strategy = grad_strategy
        self.kernel_strategy = kernel_strategy
        self.byzantine_ratio = byzantine_ratio
        self.grad_sele_rule = grad_sele_rule
        self.ker_cened = ker_cened

    def _select(self, kernel: torch.Tensor):
        """Run ``FL_mnist_next`` on the square root of ``kernel``.

        Returns:
            ``(sr_kernel, agg_vec, remain_workers)``: the NumPy square-root
            kernel followed by ``FL_mnist_next``'s two outputs.
        """
        # detach() lets a kernel that is still attached to an autograd graph
        # be converted; .numpy() raises on tensors that require grad.
        kernel = kernel.detach().cpu().numpy()
        sr_kernel = sqrtm(kernel)
        agg_vec, remain_workers = FL_mnist_next(
            sr_kernel, self.grad_sele_rule, self.byzantine_ratio
        )
        return sr_kernel, agg_vec, remain_workers

    @staticmethod
    def _as_index(indices):
        """Convert ``indices`` to a tensor usable for indexing.

        Integer and boolean inputs keep their dtype, so their indexing meaning
        is unchanged. An empty sequence would otherwise become a float tensor,
        which PyTorch rejects as an index, so it is cast to ``long``.
        """
        idx = torch.as_tensor(indices)
        if idx.numel() == 0:
            idx = idx.long()
        return idx

    def get_agg_vec(self, kernel: torch.Tensor):
        """Return the aggregation vector ``FL_mnist_next`` computes for ``kernel``."""
        _, agg_vec, _ = self._select(kernel)
        return agg_vec

    def get_remain_workers_weight(self, kernel: torch.Tensor):
        """Compute per-worker weights and the selected workers for ``kernel``.

        The weights are ``inv(sqrtm(kernel)) @ agg_vec``. When ``ker_cened`` is
        set, ``1/n`` is added to each weight, where ``n = len(kernel)``.

        Returns:
            ``(weights, remain_workers)``: ``weights`` is a float32 torch tensor.
            ``remain_workers`` is returned unchanged from ``FL_mnist_next`` and
            may be ``None`` (``get`` checks for that).
        """
        n = len(kernel)
        sr_kernel, agg_vec, remain_workers = self._select(kernel)
        weight = inv(sr_kernel) @ agg_vec
        if self.ker_cened:
            weight = weight + np.ones(n, dtype=np.float32) / n
        return torch.from_numpy(weight).to(dtype=torch.float), remain_workers

    def get(self, chosen_workers, worker_dataset, modelSetup, save_func):
        """Build the kernel for ``chosen_workers`` and return the aggregated gradient.

        Args:
            chosen_workers: Sequence of worker ids participating this round.
            worker_dataset, modelSetup: Forwarded to ``WorkerModelKernel`` and
                ``grad_strategy.get``.
            save_func: Optional callable. When given, and the selection is not
                ``None``, it receives a tensor of the selected workers' ids
                (taken from ``chosen_workers``).
        """
        worker_model = fed_learning.WorkerModelKernel(modelSetup, worker_dataset)
        kernel = self.kernel_strategy.get(worker_model, chosen_workers)
        workers_weight, remain_workers_local = self.get_remain_workers_weight(kernel)
        if save_func is not None and remain_workers_local is not None:
            # Map positions within chosen_workers to the workers' global ids.
            remain_workers = torch.tensor(chosen_workers)[
                self._as_index(remain_workers_local)
            ]
            save_func(remain_workers)

        return self.grad_strategy.get(
            workers_weight, chosen_workers, worker_dataset, modelSetup
        )

    def test_get_remain_workers(self, kernel: torch.Tensor):
        """Return uniform weights over the workers ``FL_mnist_next`` selects.

        The result is a float tensor of length ``len(kernel)``. Each selected
        position holds ``1/k``, where ``k`` is the number of selections, and
        every other position is 0.
        """
        _, _, remain_workers = self._select(kernel)
        one_hot = torch.zeros(len(kernel))
        one_hot[self._as_index(remain_workers)] = 1 / len(remain_workers)
        return one_hot

    def test_get(self, chosen_workers, worker_dataset, modelSetup, save_func):
        """Variant of ``get`` that averages the selected workers uniformly.

        The weights come from ``test_get_remain_workers`` instead of the
        kernel-derived weights. All other steps match ``get``.
        """
        worker_model = fed_learning.WorkerModelKernel(modelSetup, worker_dataset)
        kernel = self.kernel_strategy.get(worker_model, chosen_workers)
        weights = self.test_get_remain_workers(kernel)
        if save_func is not None:
            # The weights are floats and cannot be used as an index directly.
            # Select the workers that received non-zero weight instead.
            remain_workers = torch.tensor(chosen_workers)[weights.cpu() > 0]
            save_func(remain_workers)

        # Same call shape as ``get``: (weights, chosen_workers, dataset, setup).
        return self.grad_strategy.get(
            weights, chosen_workers, worker_dataset, modelSetup
        )

    

# class RGAStrategyCpu():
#     def __init__(self,grad_strategy,kernel_strategy,byzantine_ratio,grad_sele_rule):
#         self.grad_strategy=grad_strategy
#         self.kernel_strategy=kernel_strategy
#         self.byzantine_ratio=byzantine_ratio
#         self.grad_sele_rule=grad_sele_rule

#     def get_remain_workers(self,kernel:torch.Tensor):
#         kernel=kernel.cpu().numpy
#         sr_kernel=scipy.linalg.sqrtm(kernel)
#         agg_vec,remain_workers1=FL_mnist_next(sr_kernel,self.grad_sele_rule,self.byzantine_ratio)
#         remain_workers=scipy.linalg.inv(sr_kernel)@agg_vec
#         print(remain_workers1,remain_workers)
#         return torch.from_numpy(remain_workers) 

    
#     def get(self,chosen_workers,worker_dataset,modelSetup,save_func):
#         worker_model=fed_learning.WorkerModelKernel(modelSetup,worker_dataset)
#         kernel=self.kernel_strategy.get(worker_model,chosen_workers)
#         remain_workers=self.get_remain_workers(kernel)
#         remain_workers=torch.tensor(chosen_workers)[remain_workers.cpu()]
#         if save_func is not None:
#             save_func(remain_workers)

#         grad=self.grad_strategy.get(remain_workers,worker_dataset,modelSetup)
#         return grad