import dataclasses

from .comlib import *
from .train_v1_1 import FedLearn, MeanGrad
# from .eval import PlotEval,Eval
from .eval_v1_1 import get_evalObjs,get_default_criterions


@util.repr_alias(attr_name=False)
@dataclass
class FedLearnArgs:
    """Scheduling options for a federated-learning run.

    Every field is forwarded verbatim as a keyword argument to ``FedLearn``
    (``TrainFLArg.fedlearn`` does ``FedLearn(..., **asdict(fedArgs))``), so the
    field names here must stay in sync with ``FedLearn``'s parameter names.
    """

    # Presumably the number of training epochs -- confirm against FedLearn.
    max_epoch: int
    # Presumably the fraction of workers selected per round -- confirm against FedLearn.
    chosenWorkerRatio: float
    # NOTE: the "inteval" spelling cannot be corrected here alone, because these
    # names are used as FedLearn keyword arguments. None means "not set".
    save_round_inteval: int | None = None
    save_epoch_inteval: int | None = None


@util.repr_alias(attr_name=False)
@dataclass
class TrainFLArg():
    """Everything needed to launch one federated-learning training run.

    Attributes:
        fedArgs: scheduling options, forwarded as keyword arguments to ``FedLearn``.
        modelArg: model description; ``modelArg.get(t, device, random_seed)``
            builds the model setup used for training.
        optimizer_args: keyword arguments for ``training.Optimizer``.
        criterion: loss handed to the ``MeanGrad`` gradient strategy by
            ``fedavg``. Defaults to a new ``nn.CrossEntropyLoss(reduction='none')``.
    """
    fedArgs: FedLearnArgs
    modelArg: training.DnnArg
    optimizer_args: dict
    # default_factory gives each instance its own loss module. Otherwise one
    # module, created at import time, would be shared by every instance.
    criterion: nn.Module | Callable = dataclasses.field(
        default_factory=lambda: nn.CrossEntropyLoss(reduction='none'))

    def fedlearn(self, grad_strategy, t, device, random_seed,
                 save_folder, save_name,
                 eval_datasets: dict, worker_dataset: dataset.AggWorkersDatasetFromConf,
                 batch_size=4000):
        """Build the evaluators, model and optimizer, then run ``FedLearn.train``.

        Args:
            grad_strategy: gradient strategy object passed to ``FedLearn``
                (e.g. ``MeanGrad``).
            t, device, random_seed: forwarded to ``self.modelArg.get`` to build
                the model setup.
            save_folder, save_name: forwarded to ``FedLearn``.
            eval_datasets: datasets handed to ``get_evalObjs`` for evaluation.
            worker_dataset: per-worker training data. It is used both for
                evaluation objects and as the training input to ``FedLearn.train``.
            batch_size: batch size used when building the evaluation objects.
        """
        # Evaluation uses the library's default criterions, not self.criterion.
        eval_criterions, worker_criterions = get_default_criterions()
        evalObjs = get_evalObjs(eval_datasets, eval_criterions, worker_dataset,
                                worker_criterions, batch_size)

        fedLearn = FedLearn(evalObjs, save_folder, grad_strategy=grad_strategy,
                            save_name=save_name, **asdict(self.fedArgs))
        modelSetup = self.modelArg.get(t, device, random_seed)
        optimizer = training.Optimizer(modelSetup=modelSetup, **self.optimizer_args)
        fedLearn.train(worker_dataset, modelSetup, optimizer)

    def fedavg(self, t, device, random_seed, save_folder, save_name, eval_datasets,
               worker_dataset: dataset.AggWorkersDatasetFromConf, nbworkers, batch_size=4000):
        """Run ``fedlearn`` with a ``MeanGrad`` strategy built from ``self.criterion``.

        ``nbworkers`` is kept in the signature for backward compatibility but is
        not used. ``fedlearn`` has no such parameter, and forwarding it
        positionally overflowed ``fedlearn``'s parameter list (``TypeError``).
        """
        grad_strategy = MeanGrad(self.criterion, batch_size)
        self.fedlearn(grad_strategy, t, device, random_seed, save_folder, save_name,
                      eval_datasets, worker_dataset, batch_size=batch_size)
