from .comlib import *
from fed_learning import WorkerModelKernel,FedLearn


# def create_train(iid_worker_num,grad_strategy):
#     plotEval=fed_learning.PlotEval(iid_worker_num)
#     evalObjList=plotEval.get_evalObjs(train_dataset, test_dataset, workersDataset,batch_size)
#     # NOTE(review): this rebinds the `grad_strategy` parameter, so the caller's value is ignored.
#     grad_strategy=fed_learning.MeanGrad(criterion,batch_size)
#     trainObj=fed_learning.FedLearn11(evalObjList,save_folder,max_epoch,chosenWorkerNum,grad_strategy,save_epoch_inteval=30)
#     trainObj.train(workersDataset,modelSetup,
#           optimizer,)
    
#     plotEval.plot(save_folder)


# class FedLearn3(FedLearn):
#     def __init__(self, evalObjs, save_folder, batch_size, max_epoch, chosenWorkerNum, agg_rule,save_round_inteval=None, save_epoch_inteval=None):
#         super().__init__(evalObjs, save_folder, batch_size, max_epoch, save_round_inteval, save_epoch_inteval)
#         self.agg_rule=agg_rule
#         self.chosenWorkerNum=chosenWorkerNum

#     def train(self,worker_dataset,modelSetup:training.ModelSetup,
#           optimizer,criterion):
#         worker_num=worker_dataset.worker_num
#         workerModel=WorkerModelKernel(modelSetup,worker_dataset)
#         round_per_epoch=int(worker_num/self.chosenWorkerNum)

#         for epoch in range(self.max_epoch):
#             for round in range(round_per_epoch):
#                 flag=self.get_save_flag(epoch,round_per_epoch,round)
#                 self.save_model_stat(modelSetup,flag)

#                 chosen_workers=random.sample(range(worker_num), k=self.chosenWorkerNum)
#                 kernel=workerModel.getChosenWorkersGradKernel(criterion,chosen_workers,self.batch_size,flat=True)
#                 chosen_workers=self.agg_rule(kernel)

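#                 # NOTE(review): `grads` is never assigned in this draft (only `kernel` is computed above);
#                 # presumably it should be derived from `kernel` — confirm before re-enabling.
#                 grad=self.get_agg_grad_tuple(grads,modelSetup.form_dict,flag,chosen_workers)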
#                 self.save_grad_norm(grad,flag)

#                 optimizer.stepByGradTuple(grad)
#                 self.saveObj.save_cur_row()

