from .comlib import *

class IFAgg():
    """Aggregate worker gradients after iteratively filtering out workers.

    Each round computes the top eigen-direction of the centered Gram matrix
    of the remaining gradients, scores every remaining worker by its squared
    component along that direction, and keeps only workers whose score
    (as tracked by ``discriminator.Score``) is below ``eta``.

    Attributes:
        eta: Score threshold; workers with score ``>= eta`` are dropped.
        end_condition_ratio: Filtering rounds continue while the number of
            remaining workers is at least ``end_condition_ratio * worker_num``.
        cur_round: Index of the most recent filtering round; only set once
            ``update_weight`` has executed at least one round.
    """

    def __init__(self, eta, end_condition_ratio):
        self.eta = eta
        self.end_condition_ratio = end_condition_ratio

    @staticmethod
    def centerize(vecs):
        """Return ``vecs`` with its mean over dim 0 subtracted."""
        return vecs - vecs.mean(dim=0)

    @staticmethod
    def get_kernel(vecs_centered):
        """Return ``vecs_centered @ vecs_centered.T`` divided by the row count.

        Assumes a 2-D ``(workers, features)`` tensor, in which case the
        result is ``(workers, workers)`` -- TODO confirm with callers.
        """
        return (vecs_centered @ vecs_centered.T) / vecs_centered.shape[0]

    @staticmethod
    def get_max_eigen(matrix):
        """Return the largest eigenvalue of ``matrix`` and its eigenvector.

        Uses ``torch.linalg.eigh``, so ``matrix`` is treated as
        symmetric/Hermitian.
        """
        eigenvalues, eigenvectors = torch.linalg.eigh(matrix)
        # Index of the largest eigenvalue.
        top = torch.argmax(eigenvalues)
        # Eigenvectors are stored column-wise.
        return eigenvalues[top], eigenvectors[:, top]

    @staticmethod
    def _top_component(grads):
        """Center ``grads``, build its kernel and find the top eigen-pair.

        Returns:
            ``(grads_centered, kernel, max_eigenvalue, max_eigenvector)``.
        """
        grads_centered = IFAgg.centerize(grads)
        kernel = IFAgg.get_kernel(grads_centered)
        max_eigenvalue, max_eigenvector = IFAgg.get_max_eigen(kernel)
        return grads_centered, kernel, max_eigenvalue, max_eigenvector

    @staticmethod
    def maximize(grads):
        """Return the top eigenvalue/eigenvector of the centered kernel."""
        _, _, max_eigenvalue, max_eigenvector = IFAgg._top_component(grads)
        return max_eigenvalue, max_eigenvector

    @staticmethod
    def maximizeAndGetValue(grads, normalize=False):
        """Return one value per row of ``grads`` along the top direction.

        Args:
            grads: Tensor of stacked gradients, one row per worker.
            normalize: If True, map the kernel eigenvector back to feature
                space (``grads_centered.T @ v``), scale it to unit norm and
                project every centered row onto it. If False, return
                ``v @ kernel``; the kernel is symmetric by construction, so
                this is the eigenvector scaled by its eigenvalue.

        Returns:
            1-D tensor with one entry per row of ``grads``.
        """
        grads_centered, kernel, _, max_eigenvector = IFAgg._top_component(grads)
        if normalize:
            direction = grads_centered.T @ max_eigenvector
            direction = direction / torch.norm(direction)
            return grads_centered @ direction
        return max_eigenvector @ kernel

    def update_weight(self, grads):
        """Run filtering rounds and return the indices of the kept workers.

        Rounds continue while the surviving set is non-empty and holds at
        least ``end_condition_ratio * worker_num`` workers. The returned set
        is the one that was fed into the final round, i.e. the last set that
        still satisfied the loop condition. If no round runs at all, every
        worker index is returned.

        Args:
            grads: Tensor of stacked gradients, one row per worker.

        Returns:
            1-D index tensor on ``grads.device``.
        """
        device = grads.device
        worker_num = len(grads)
        # arange keeps an integer dtype even when there are zero workers.
        current = torch.arange(worker_num, device=device)
        previous = current
        round_idx = 0
        remain_grads = grads
        scores = discriminator.Score(worker_num, device)
        end_condition_num = self.end_condition_ratio * worker_num

        # The non-empty check prevents running the eigen-decomposition on an
        # empty set when end_condition_num is 0.
        while len(current) > 0 and len(current) >= end_condition_num:
            self.cur_round = round_idx
            previous = current
            current = self.conduct_weight_update_round(remain_grads, current, scores)
            remain_grads = grads[current]
            round_idx += 1
        return previous

    def conduct_weight_update_round(self, remainWorkerGrad, cur_remain_workers, scores):
        """Run one filtering round and return the surviving worker indices."""
        values = self.maximizeAndGetValue(remainWorkerGrad)
        return self.getRemainWorkersFromValues(values, cur_remain_workers, scores)

    def getRemainWorkersFromValues(self, values, cur_remain_workers, scores):
        """Update ``scores`` from ``values`` and keep workers scoring below eta.

        The per-round score is ``values ** 2`` divided by its maximum, so the
        largest entry is 1.

        NOTE(review): if every value is zero the maximum is zero and the
        division yields NaN. ``NaN < eta`` is False, so -- presuming ``Score``
        passes NaN through unchanged (TODO confirm) -- every worker is
        dropped in that round.
        """
        squared = values ** 2
        normalized_scores = squared / torch.max(squared)
        scores.update(cur_remain_workers, normalized_scores)
        tracked = scores.get_scores(cur_remain_workers)
        return cur_remain_workers[tracked < self.eta]

    def aggregate(self, vecs, save_func=None):
        """Return the mean gradient over the workers kept by ``update_weight``.

        Args:
            vecs: Tensor of stacked gradients, one row per worker.
            save_func: Optional callable invoked with the kept worker indices
                before the mean is computed.

        Returns:
            Mean over dim 0 of the kept rows of ``vecs``.
        """
        remain_workers = self.update_weight(vecs)
        if save_func is not None:
            save_func(remain_workers)
        return torch.mean(vecs[remain_workers], dim=0)

class AggEval():
    """Count how many kept workers are normal versus Byzantine.

    A worker id is counted as normal when it is below ``normal_worker_num``
    and as Byzantine otherwise.
    """

    def __init__(self, normal_worker_num=None):
        self.normal_worker_num = normal_worker_num

    def get(self, remain_workers, chosen_workers):
        """Return counts of normal and Byzantine workers among the kept ones.

        Args:
            remain_workers: Index tensor into ``chosen_workers`` selecting the
                workers that were kept; moved to CPU before indexing.
            chosen_workers: Worker ids taking part in this round.

        Returns:
            Dict with integer entries ``chosen_normal`` and
            ``chosen_byzantine``.
        """
        worker_ids = torch.tensor(chosen_workers)
        kept_ids = worker_ids[remain_workers.to("cpu")]
        normal_count = torch.sum(kept_ids < self.normal_worker_num).item()
        byzantine_count = torch.sum(kept_ids >= self.normal_worker_num).item()
        return {
            "chosen_normal": normal_count,
            "chosen_byzantine": byzantine_count,
        }

    def get_header(self):
        """Return the column names produced by ``get``."""
        return ["chosen_normal", "chosen_byzantine"]

    @staticmethod
    def get_fig(df, normal_num, byzantine_num, x_name="round"):
        """Build one ``save_log.Line`` per column, each divided by its total.

        Args:
            df: Table holding ``x_name``, ``chosen_normal`` and
                ``chosen_byzantine`` columns.
            normal_num: Divisor applied to ``chosen_normal``.
            byzantine_num: Divisor applied to ``chosen_byzantine``.
            x_name: Column used for the x axis.

        Returns:
            Dict mapping column name to its ``save_log.Line``.
        """
        fig = {}
        totals = (("chosen_normal", normal_num), ("chosen_byzantine", byzantine_num))
        for column, total in totals:
            fig[column] = save_log.Line(df.loc[:, x_name], df[column] / total)
        return fig

    @staticmethod
    def plot(df, save_folder, normal_num, byzantine_num, x_name="round"):
        """Plot the chosen-worker ratios to ``<save_folder>/chosen_worker.png``."""
        name = "chosen_worker"
        lines = AggEval.get_fig(df, normal_num, byzantine_num, x_name)
        save_log.plot1(lines, f"{save_folder}/{name}.png")
    

def get_evalObjs_gradagg(
    train_dataset,
    test_dataset,
    workers_dataset: dataset.AggWorkersDatasetFromConf,
    normal_worker_num,
    batch_size,
):
    """Return the federated-learning evaluators plus a ``grad_agg`` entry.

    Starts from whatever ``fed_learning.get_evalObjs_fl`` returns and adds an
    ``AggEval`` built with ``normal_worker_num`` under the key ``"grad_agg"``.
    """
    evaluators = fed_learning.get_evalObjs_fl(
        train_dataset, test_dataset, workers_dataset, batch_size
    )
    evaluators["grad_agg"] = AggEval(normal_worker_num)
    return evaluators