from .comlib import *

class IFAgg():
    """Aggregator that iteratively filters out workers using a gradient kernel.

    Each filtering round:
      1. takes the centered kernel of the still-remaining workers;
      2. projects it onto the eigenvector returned by ``util.Kernel.get_top_eigen(k=1)``;
      3. squares the projection values and divides them by their maximum;
      4. feeds them to a ``discriminator.Score`` tracker; and
      5. keeps only workers whose tracked score is below ``eta``.

    Rounds continue while at least ``end_condition_ratio * worker_num`` workers
    remain. The last worker set that still satisfied that bound is passed to
    ``grad_strategy``.
    """

    def __init__(self,grad_strategy,kernel_strategy,eta,end_condition_ratio):
        """
        Args:
            grad_strategy: object whose ``get(remain_workers, worker_dataset, modelSetup)``
                produces the returned gradient.
            kernel_strategy: object whose ``get(worker_model, chosen_workers)`` returns
                the kernel matrix used for filtering.
            eta: score threshold; workers whose score is ``>= eta`` are dropped.
            end_condition_ratio: filtering stops once fewer than
                ``end_condition_ratio * worker_num`` workers remain.
        """
        self.grad_strategy=grad_strategy
        self.kernel_strategy=kernel_strategy
        self.eta=eta
        self.end_condition_ratio=end_condition_ratio

    @staticmethod
    def maximizeAndGetValue(kernel):
        """Project ``kernel`` onto its top eigenvector and return column 0 of the projection.

        If the eigen-decomposition raises ``torch.linalg.LinAlgError``, a random unit
        vector is used instead so the round can still proceed.
        """
        try:
            _, max_eigenvector=util.Kernel.get_top_eigen(kernel,k=1)
        except torch.linalg.LinAlgError:
            # NOTE(review): this fallback is non-deterministic (global RNG) and 1-D;
            # confirm util.Kernel.get_projection accepts it like get_top_eigen's output.
            max_eigenvector=torch.randn(len(kernel),device=kernel.device)
            max_eigenvector=max_eigenvector/max_eigenvector.norm()
        return util.Kernel.get_projection(kernel,max_eigenvector)[:,0]

    def getRemainWorkersFromValues(self,values,cur_remain_workers,scores):
        """Update ``scores`` with max-normalized squared ``values``; return workers scoring below ``eta``.

        ``values[i]`` is assumed to correspond to ``cur_remain_workers[i]``.
        """
        cur_scores=values**2
        # NOTE: if every value is 0 this division yields NaN; NaN < eta is False, so
        # all workers are dropped and update_weight then falls back to the previous set.
        nomalized_cur_scores=cur_scores/torch.max(cur_scores)
        scores.update(cur_remain_workers,nomalized_cur_scores)
        remain_workers_score=scores.get_scores(cur_remain_workers)
        return cur_remain_workers[remain_workers_score<self.eta]

    def conduct_weight_update_round(self,remainWorkerKernel:"util.Kernel",cur_remain_workers,scores):
        """Run one filtering round on the remaining workers' kernel; return the survivors."""
        values=self.maximizeAndGetValue(remainWorkerKernel.get_centered_kernel())
        return self.getRemainWorkersFromValues(values,cur_remain_workers,scores)

    def update_weight(self,kernel):
        """Iteratively filter workers and return the indices (int64 tensor) that survive.

        Returned indices are positions into ``kernel``. When no filtering round can
        run (no workers, or ``end_condition_ratio * worker_num > worker_num``), every
        worker index is returned unchanged.
        """
        device=kernel.device
        worker_num=len(kernel)
        # arange keeps an int64 dtype even when worker_num == 0
        # (torch.tensor([]) would be float32 and unusable as an index).
        cur_remain_workers=torch.arange(worker_num,device=device)
        end_condition_num=self.end_condition_ratio*worker_num
        # Guard: with no round to run, the old code indexed remain_workers[-1] (KeyError)
        # or ran torch.max on an empty tensor.
        if worker_num==0 or worker_num<end_condition_num:
            return cur_remain_workers

        round_idx=0
        remainWorkerKernel=util.Kernel(kernel)
        remain_workers={round_idx:cur_remain_workers}
        scores=discriminator.Score(worker_num,device)

        # Also stop on an empty set: with end_condition_num == 0 another round would
        # otherwise call torch.max on an empty tensor.
        while len(cur_remain_workers)>0 and len(cur_remain_workers)>=end_condition_num:
            self.cur_round=round_idx
            cur_remain_workers=self.conduct_weight_update_round(remainWorkerKernel,cur_remain_workers,scores)
            remain_workers[round_idx+1]=cur_remain_workers
            remainWorkerKernel=remainWorkerKernel.get_sub(list(cur_remain_workers))
            round_idx+=1
        # The set produced by the last round fell below the bound; return the one before it.
        return remain_workers[round_idx-1]

    def get(self,chosen_workers,worker_dataset,modelSetup,save_func):
        """Filter ``chosen_workers`` via the kernel, optionally save them, and return the gradient.

        Args:
            chosen_workers: worker ids (sequence or tensor) considered this step.
            worker_dataset: per-worker data passed to ``fed_learning.WorkerModelKernel``.
            modelSetup: model configuration forwarded to the kernel and grad strategies.
            save_func: optional callable receiving the surviving worker ids, or ``None``.
        """
        worker_model=fed_learning.WorkerModelKernel(modelSetup,worker_dataset)
        kernel=self.kernel_strategy.get(worker_model,chosen_workers)
        remain_idx=self.update_weight(kernel)
        # Map kernel positions back to actual worker ids; as_tensor avoids the
        # copy-construct warning when chosen_workers is already a tensor.
        remain_workers=torch.as_tensor(chosen_workers)[remain_idx.cpu()]
        if save_func is not None:
            save_func(remain_workers)
        return self.grad_strategy.get(remain_workers,worker_dataset,modelSetup)


class KernelStrategy:
    """Computes the worker-by-worker gradient kernel for a set of chosen workers.

    Stores the settings forwarded to ``worker_model.getChosenWorkersGradKernel``.
    """

    def __init__(self, criterion, batch_size, seg_len, centered=True):
        # All four settings are passed through unchanged when the kernel is built.
        self.criterion = criterion
        self.batch_size = batch_size
        self.seg_len = seg_len
        self.centered = centered

    def get(self, worker_model, chosen_workers):
        """Return the kernel for ``chosen_workers`` (always requested with ``normalize=True``)."""
        build_kernel = worker_model.getChosenWorkersGradKernel
        return build_kernel(
            self.criterion,
            chosen_workers,
            self.batch_size,
            self.seg_len,
            priorList=None,
            centered=self.centered,
            normalize=True,
        )





class AggEval:
    """Counts how many normal vs. byzantine workers were kept by the aggregator.

    Worker ids strictly below ``normal_worker_num`` count as normal; all others
    count as byzantine.
    """

    def __init__(self, normal_worker_num=None):
        self.normal_worker_num = normal_worker_num

    def get(self, remain_workers):
        """Return the number of kept normal and byzantine workers as Python ints."""
        threshold = self.normal_worker_num
        n_normal = torch.sum(remain_workers < threshold).item()
        n_byzantine = torch.sum(remain_workers >= threshold).item()
        return {"chosen_normal": n_normal, "chosen_byzantine": n_byzantine}

    def get_header(self):
        """Column names produced by ``get``."""
        return ["chosen_normal", "chosen_byzantine"]

    @staticmethod
    def get_fig(df, normal_num, byzantine_num, x_name="round"):
        """Build lines showing the fraction of each worker group kept per ``x_name``."""
        return {
            "chosen_normal": save_log.Line(df.loc[:, x_name], df["chosen_normal"] / normal_num),
            "chosen_byzantine": save_log.Line(df.loc[:, x_name], df["chosen_byzantine"] / byzantine_num),
        }

    @staticmethod
    def plot(df, save_folder, normal_num, byzantine_num, x_name="round"):
        """Render the chosen-worker figure to ``<save_folder>/chosen_worker.png``."""
        fig_name = "chosen_worker"
        lines = AggEval.get_fig(df, normal_num, byzantine_num, x_name)
        save_log.plot1(lines, f"{save_folder}/{fig_name}.png")
    


class PlotEval(fed_learning.PlotEval):
    """Plot evaluator that splits metrics into normal and byzantine worker groups."""

    def __init__(self, nbworkers: worker_with_byzantine.NormalByzantineConf):
        # nbworkers is stored before the base initializer runs.
        self.nbworkers = nbworkers
        super().__init__(nbworkers.worker_num)

    def get_evalObjs(self, train_dataset, test_dataset, workers_dataset: dataset.AggWorkersDatasetFromConf, batch_size):
        """Return the base evaluators plus an ``AggEval`` registered as ``"grad_agg"``."""
        eval_objs = super().get_evalObjs(train_dataset, test_dataset, workers_dataset, batch_size)
        eval_objs["grad_agg"] = AggEval(self.nbworkers.normal_num)
        return eval_objs

    def get_figs(self, save_folder):
        """Assemble the figure dictionary from the logged dataframe in ``save_folder``."""
        x_name = "round"
        figs = {name: {} for name in ("loss", "acc", "out_norm", "weight_norm", "grad_norm")}
        df = self.get_df(save_folder)

        def split_lines(cri_name):
            # One line per dataset split, keyed "train" then "test".
            return {split: training.Eval11.get_line(df, x_name, cri_name, split) for split in ("train", "test")}

        figs["out_norm"] = split_lines("out_norm")

        for cri_name in self.criterions_unreduced:
            lines = split_lines(cri_name)
            # Add per-group worker lines after the train/test lines.
            for group in ("normal", "byzantine"):
                group_ids = self.nbworkers.get_ids(group)
                lines[group] = fed_learning.Eval.get_line(df, x_name, cri_name, "worker", group_ids)
            figs[cri_name] = lines

        figs["chosen_worker"] = AggEval.get_fig(
            df, self.nbworkers.normal_num, self.nbworkers.byzantine_num, x_name
        )

        for fig_name in ("weight_norm", "grad_norm"):
            figs[fig_name][0] = save_log.get_line_from_df(df, fig_name, x_name)
        return figs