from .comlib import *
from . iterative_filter_v1_1 import KernelStrategy,IFAgg,PlotEval


def create_agg_kernel(criterion,batch_size,seg_len,byzantine_ratio,
                      name,centered_kernel,**kwargs):
    """Build a kernel-based gradient aggregation strategy.

    Args:
        criterion: loss criterion forwarded to the kernel and gradient strategies.
        batch_size: batch size forwarded to the kernel and gradient strategies.
        seg_len: segment length forwarded to ``KernelStrategy``.
        byzantine_ratio: for ``"ifagg"`` it sets the base end-condition ratio
            ``1 - 2*byzantine_ratio``; for any other rule it is forwarded to
            ``robust_grad_agg2.RGAStrategyKernel``.
        name: aggregation rule. ``"ifagg"`` builds an ``IFAgg``; any other
            value is forwarded as-is to ``robust_grad_agg2.RGAStrategyKernel``.
        centered_kernel: forwarded to ``KernelStrategy`` (and, for non-ifagg
            rules, to ``RGAStrategyKernel``).
        **kwargs: extra options read only by the ``"ifagg"`` branch:
            ``eta`` (forwarded to ``IFAgg``) and ``end_condition_ratio_bias``
            (added to the end-condition ratio; treated as 0 when missing or None).

    Returns:
        The constructed ``IFAgg`` or ``RGAStrategyKernel`` instance.
    """
    kernel_strategy=KernelStrategy(criterion,batch_size,seg_len,centered_kernel)
    if name=="ifagg":
        grad_strategy=fed_learning.MeanGrad(criterion,batch_size)
        # A plain GradAggArg(name="ifagg") does not carry this key, so
        # kwargs.get would return None and the addition below would raise
        # TypeError. Default to 0, matching IFArg's declared default.
        bias=kwargs.get("end_condition_ratio_bias")
        if bias is None:
            bias=0
        end_condition_ratio=1-2*byzantine_ratio+bias
        return IFAgg(grad_strategy,kernel_strategy,kwargs.get("eta"),end_condition_ratio)
    # Every other rule name is delegated to the generic kernel RGA strategy.
    grad_strategy=fed_learning.LinearGradR(criterion,batch_size)
    return robust_grad_agg2.RGAStrategyKernel(
        grad_strategy,kernel_strategy,byzantine_ratio,name,centered_kernel)

@util.repr_alias(attr_name=False)
@dataclass
class GradAggArg:
    """Configuration naming a gradient-aggregation rule.

    Attributes:
        name: identifier of the aggregation rule handed to ``create_agg_kernel``.
        centered_kernel: flag forwarded to ``create_agg_kernel``.
    """
    name: Literal["ifagg", "krum", "coordinatewise", "iterative_filtering", "geo_median", "geo_median_w"]
    centered_kernel: bool = field(default=True)

    def get(self, criterion, batch_size, seg_len, byzantine_ratio):
        """Instantiate the aggregation strategy this configuration describes.

        Every dataclass field of ``self`` (including fields added by
        subclasses) is passed to ``create_agg_kernel`` as a keyword argument.
        """
        options = asdict(self)
        return create_agg_kernel(
            criterion, batch_size, seg_len, byzantine_ratio, **options
        )

@util.repr_alias(attr_name=False)
@dataclass
class IFArg(GradAggArg):
    """Configuration for the ``"ifagg"`` aggregation rule.

    ``name`` and ``centered_kernel`` are pinned (excluded from ``__init__``),
    so callers supply only ``eta`` and, optionally, ``end_condition_ratio_bias``.
    """
    # Rule identifier is fixed; not accepted by __init__.
    name: str = field(default="ifagg", init=False)
    # Step parameter handed to IFAgg as its ``eta`` argument.
    eta: float
    # Offset added to the base end-condition ratio 1 - 2*byzantine_ratio.
    end_condition_ratio_bias: float = 0
    # Kernel centering is always on for this rule; not accepted by __init__.
    centered_kernel: bool = field(default=True, init=False)
    

def robustfedlearn(trainFLArg:fed_learning.TrainFLArg, grad_arg:GradAggArg,
                   t, device, random_seed, save_folder, save_name,
                   eval_datasets, worker_dataset:dataset.AggWorkersDatasetFromConf,
                   nbworkers, batch_size=4000, seg_len=100):
    """Run one federated-learning training session with a robust aggregator.

    Args:
        trainFLArg: training configuration; its ``criterion``, ``fedArgs``,
            ``modelArg`` and ``optimizer_args`` attributes are used here.
        grad_arg: aggregation-rule configuration; ``grad_arg.get`` builds the
            gradient strategy handed to ``FedLearn11``.
        t, device, random_seed: forwarded to ``trainFLArg.modelArg.get``.
        save_folder, save_name: forwarded to ``fed_learning.FedLearn11``.
        eval_datasets: datasets passed to ``fed_learning.get_evalObjs``.
        worker_dataset: used both for evaluation objects and for training.
        nbworkers: object exposing ``get_byzantine_ratio()``.
        batch_size: forwarded to ``get_evalObjs`` and ``grad_arg.get``.
        seg_len: forwarded to ``grad_arg.get``.
    """
    # Build evaluation objects from the default criterion pair.
    eval_crits, worker_crits = fed_learning.get_default_criterions()
    eval_objs = fed_learning.get_evalObjs(
        eval_datasets, eval_crits, worker_dataset, worker_crits, batch_size)

    # Aggregation strategy, parameterised by the Byzantine worker ratio.
    byz_ratio = nbworkers.get_byzantine_ratio()
    strategy = grad_arg.get(trainFLArg.criterion, batch_size, seg_len, byz_ratio)

    fed_kwargs = asdict(trainFLArg.fedArgs)
    learner = fed_learning.FedLearn11(
        eval_objs, save_folder, grad_strategy=strategy, save_name=save_name, **fed_kwargs)

    model_setup = trainFLArg.modelArg.get(t, device, random_seed)
    optimizer = training.Optimizer(modelSetup=model_setup, **trainFLArg.optimizer_args)
    learner.train(worker_dataset, model_setup, optimizer)



# def get(self, criterion,batch_size,seg_len):
#         grad_strategy=fed_learning.MeanGrad(criterion,batch_size)
#         kernel_strategy=KernelStrategy(criterion,batch_size,seg_len)

#         return IFAgg(grad_strategy,kernel_strategy,self.eta,self.end_condition_ratio)