from .comlib import *
from .worker_model_v1_1 import WorkerModelR

class FedLearn(training.Train11):
    """Federated-learning trainer that aggregates sampled workers' gradients each round.

    Every round samples ``chosenWorkerNum`` distinct workers, collects their
    flat gradients through ``WorkerModelR``, combines them with ``agg_rule``
    and applies the result via ``optimizer.stepByGradTuple``. Logging helpers
    (``saveObj``, ``get_save_flag``, ``get_weight_norm``, ``save_grad_norm``)
    and attributes such as ``max_epoch``/``batch_size``/``evalObjs`` are
    presumably provided by ``training.Train11`` (not visible here).
    """

    def __init__(self, evalObjs, save_folder, batch_size, max_epoch, chosenWorkerNum, agg_rule, save_round_inteval=None, save_epoch_inteval=None):
        """
        Args:
            evalObjs: evaluator lookup; this class indexes it with "grad_agg",
                "train", "test" and "worker" and calls ``.get(...)`` on the values.
            save_folder, batch_size, max_epoch, save_round_inteval, save_epoch_inteval:
                forwarded unchanged to ``training.Train11.__init__``.
            chosenWorkerNum: number of distinct workers sampled per round.
            agg_rule: object with ``aggregate(tensors, save_func)`` that combines
                the chosen workers' flat gradients into a single flat gradient.
        """
        super().__init__(evalObjs, save_folder, batch_size, max_epoch, save_round_inteval, save_epoch_inteval)
        self.agg_rule = agg_rule
        self.chosenWorkerNum = chosenWorkerNum

    def get_agg_grad_tuple(self, grads: util.FlatModelParaS, form_dict, save_flag, chosen_workers) -> tuple[torch.Tensor, ...]:
        """Aggregate the chosen workers' gradients and reshape the result.

        Args:
            grads: the chosen workers' flat gradients; ``grads.tensors`` is what
                ``agg_rule.aggregate`` receives.
            form_dict: layout forwarded to ``util.FlatModelPara.get_tuple``
                (presumably the model's parameter shapes — confirm).
            save_flag: when truthy, the aggregator gets a callback that logs
                ``evalObjs["grad_agg"]`` results; otherwise it gets None.
            chosen_workers: sampled worker indices, forwarded to that callback.

        Returns:
            Whatever ``util.FlatModelPara(...).get_tuple(form_dict)`` yields for
            the aggregated gradient.
        """
        save_func = self.get_save_func("grad_agg", save_flag, chosen_workers)
        agg_flat = self.agg_rule.aggregate(grads.tensors, save_func)
        return util.FlatModelPara(agg_flat).get_tuple(form_dict)

    def get_save_func(self, key, flag, *args):
        """Return a logging callback for ``evalObjs[key]``, or None if ``flag`` is falsy.

        The returned ``save_func(x)`` calls ``self.evalObjs[key].get(x, *args)``
        and appends the result to the current row of ``self.saveObj``.
        """
        if not flag:
            return None

        def save_func(x):
            r = self.evalObjs[key].get(x, *args)
            self.saveObj.add_data_to_row(r)

        return save_func

    def save_model_stat(self, modelSetup, flag):
        """When ``flag`` is truthy, log the weight norm and train/test/worker evaluations."""
        if not flag:
            return
        self.saveObj.add_data_to_row({"weight_norm": self.get_weight_norm(modelSetup)})
        for key in ("train", "test", "worker"):
            self.saveObj.add_data_to_row(self.evalObjs[key].get(modelSetup))

    def train(self, worker_dataset, modelSetup: training.ModelSetup,
              optimizer, criterion):
        """Run ``max_epoch`` epochs of federated training.

        Each epoch has ``worker_num // chosenWorkerNum`` rounds. A round logs
        model statistics (before the update), samples ``chosenWorkerNum``
        distinct workers uniformly at random, applies their aggregated
        gradient with ``optimizer.stepByGradTuple`` and then calls
        ``saveObj.save_cur_row()``.

        Raises:
            ValueError: if ``chosenWorkerNum`` is not within ``1..worker_num``.
        """
        worker_num = worker_dataset.worker_num
        # Out-of-range values would otherwise run zero rounds silently
        # (or divide by zero), so fail loudly before building the worker model.
        if not 1 <= self.chosenWorkerNum <= worker_num:
            raise ValueError(
                f"chosenWorkerNum must be in [1, {worker_num}], got {self.chosenWorkerNum}"
            )
        workerModel = WorkerModelR(modelSetup, worker_dataset)
        # An epoch draws about worker_num worker samples in total.
        round_per_epoch = worker_num // self.chosenWorkerNum

        for epoch in range(self.max_epoch):
            for round_idx in range(round_per_epoch):
                flag = self.get_save_flag(epoch, round_per_epoch, round_idx)
                # Statistics reflect the model before this round's step.
                self.save_model_stat(modelSetup, flag)

                chosen_workers = random.sample(range(worker_num), k=self.chosenWorkerNum)
                grads = workerModel.getChosenWorkersGrad(criterion, chosen_workers, self.batch_size, flat=True)
                grad = self.get_agg_grad_tuple(grads, modelSetup.form_dict, flag, chosen_workers)
                self.save_grad_norm(grad, flag)

                optimizer.stepByGradTuple(grad)
                self.saveObj.save_cur_row()


class MeanAgg:
    """Aggregation rule: coordinate-wise mean of the workers' gradients."""

    def __init__(self):
        pass

    def aggregate(self, grad_tuples: torch.Tensor, *args) -> torch.Tensor:
        """Return the mean of ``grad_tuples`` along dimension 0.

        Args:
            grad_tuples: gradients stacked along dim 0, one entry per worker
                (presumably shape ``(num_workers, num_params)`` — confirm), or a
                sequence of equally-shaped tensors, which is stacked first.
            *args: ignored. Callers may pass a save callback here; this rule
                never invokes it, so it logs nothing.

        Returns:
            A tensor with dim 0 reduced away.
        """
        if not isinstance(grad_tuples, torch.Tensor):
            grad_tuples = torch.stack(list(grad_tuples))
        return torch.mean(grad_tuples, dim=0)