from .comlib import *
from .test_mnist_parallel import conduct_rl_mnist_parallel,create_random_seed

# Results of this experiment module are written under
# ./tests_result/<stem of this module's file name>.
name = Path(__file__).resolve().stem
FOLDER = "./tests_result/" + name

def create_nbWorkersArgs_from_normalArg(normalArg, data_num=30, byz_num=50,
                                        poison_ratios=(0.3, 0.6, 0.9)):
    """Pair ``normalArg`` with one Byzantine-worker config per ratio.

    For every value in ``poison_ratios`` a
    ``ConcatArgs(data_num, byz_num, "s", "swl", ratio)`` is built. Each one
    is then wrapped as ``NBWorkersArgsRef(normalArg, <that ConcatArgs>,
    byz_num * 2)``.

    Args:
        normalArg: normal-worker configuration. Callers in this module pass a
            ``WorkersArgs`` or ``NonIidClassWorkersArgs`` instance.
        data_num: forwarded as the first argument of ``ConcatArgs``.
        byz_num: forwarded as the second argument of ``ConcatArgs``. Twice
            this value is passed as the third argument of ``NBWorkersArgsRef``.
        poison_ratios: values passed as the last argument of ``ConcatArgs``
            (presumably the poisoning ratio; confirm against ConcatArgs). One
            result is produced per value, in order.

    Returns:
        list of ``NBWorkersArgsRef``, one per entry of ``poison_ratios``.
    """
    # Literal ratios on purpose: computing them as ``x * 0.3`` gives
    # 0.8999999999999999 (not 0.9) for x == 3 in IEEE-754 doubles. That
    # artefact would leak into result labels and into any downstream
    # ``int(ratio * n)`` truncation.
    byz_args = [
        ConcatArgs(data_num, byz_num, "s", "swl", ratio)
        for ratio in poison_ratios
    ]
    # NOTE(review): the meaning of ``byz_num * 2`` is defined by
    # NBWorkersArgsRef; confirm against its definition.
    return [NBWorkersArgsRef(normalArg, byz_arg, byz_num * 2) for byz_arg in byz_args]


def create_nbWorkersArgs_(conf:Literal["iid","non_iid"],data_num=30,normal_num=200,byz_num=50):
    """Build the worker configurations for one data-distribution setting.

    Args:
        conf: ``"iid"`` selects ``WorkersArgs("iid", data_num, normal_num)``.
            ``"non_iid"`` selects
            ``NonIidClassWorkersArgs(data_num, normal_num, alpha=1.0)``.
        data_num: forwarded to the normal-worker constructor and to
            ``create_nbWorkersArgs_from_normalArg``.
        normal_num: forwarded to the normal-worker constructor.
        byz_num: forwarded to ``create_nbWorkersArgs_from_normalArg``.

    Returns:
        Concatenation of ``create_nbWorkersArgs_from_normalArg`` results for
        every normal-worker config.

    Raises:
        ValueError: if ``conf`` is neither ``"iid"`` nor ``"non_iid"``.
    """
    if conf == "iid":
        normalArgs = [WorkersArgs("iid", data_num, normal_num)]
    elif conf == "non_iid":
        normalArgs = [NonIidClassWorkersArgs(data_num, normal_num, alpha=1.0)]
    else:
        # Reject explicitly. Falling through to the non-IID branch would let
        # a typo such as "noniid" silently run the wrong experiment.
        raise ValueError(f"conf must be 'iid' or 'non_iid', got {conf!r}")

    r = []
    for normalArg in normalArgs:
        r.extend(create_nbWorkersArgs_from_normalArg(normalArg, data_num, byz_num))
    return r

def create_nbWorkersArgs(conf):
    """Return a fresh list of worker configurations for ``conf``.

    Thin wrapper around ``create_nbWorkersArgs_`` with fixed sizes:
    data_num=30, normal_num=200, byz_num=50.
    """
    sizes = {"data_num": 30, "normal_num": 200, "byz_num": 50}
    return list(create_nbWorkersArgs_(conf, **sizes))


def create_weightUpdaterArg():
    """Return one ``WeightUpdaterArg`` per point of the discriminator grid.

    The grid is the cartesian product of layer counts, hidden sizes and ECR
    biases below. With the current single-valued grids it yields one entry.
    """
    layer_grid = [2]
    hidden_grid = [128]
    ecr_bias_grid = [0.0]
    return [
        WeightUpdaterArg(
            DiscriminatorIFArgs(1.0, ecr_bias, True, False),
            FedLearnArgs(20, 0.2, save_epoch_inteval=4),
            GradArgs(MvArgs(0.01), LinearArgs('ce')),
            DnnArgDiscrim(n_layers, n_hidden),
            {"name": "Adam", "lr": 0.001},  # fresh optimizer dict per entry
        )
        for n_layers, n_hidden, ecr_bias in product(layer_grid, hidden_grid, ecr_bias_grid)
    ]

def create_trainFLArg():
    """Return one ``TrainFLArg`` per point of the FL-training grid.

    The grid is the cartesian product of (layer count, learning rate) pairs,
    hidden sizes and chosen ratios below. The learning rate of each pair
    feeds the Adam optimizer dict.
    """
    layer_lr_grid = [(2, 0.01)]
    hidden_grid = [128]
    chosen_ratio_grid = [0.2]
    combos = product(layer_lr_grid, hidden_grid, chosen_ratio_grid)
    return [
        TrainFLArg(
            FedLearnArgs(20, ratio, save_epoch_inteval=4),
            DnnArg('default', n_layers, n_hidden),
            {"name": "Adam", "lr": learning_rate},
        )
        for (n_layers, learning_rate), n_hidden, ratio in combos
    ]

def create_grad_args():
    """Return the gradient aggregators grouped into a single-element list.

    The three ``GradAggArg`` objects form ONE inner list. Every tuple of the
    cartesian product built in ``test_poison_ratio`` therefore carries all
    three aggregators together, not one aggregator per tuple.
    """
    # NOTE(review): unlike the other create_* helpers this returns a nested
    # list. Presumably conduct_rl_mnist_parallel expects a list of
    # aggregators per job; confirm before flattening.
    aggregator_names = ("krum", "iterative_filtering", "geo_median_w")
    return [[GradAggArg(agg_name) for agg_name in aggregator_names]]

    
def test_poison_ratio(conf:Literal["iid","non_iid"]):
    """Run the MNIST RL experiment grid for one data-distribution setting.

    Builds the cartesian product of the worker, weight-updater, FL-training
    and gradient-aggregation configs and ``create_random_seed(5)``. It then
    hands that list to ``conduct_rl_mnist_parallel`` with ``n_jobs=30``,
    using ``FOLDER`` as the root folder.
    """
    # NOTE(review): the ``test_`` prefix makes pytest collect this function;
    # ``conf`` must then come from a fixture/parametrize defined elsewhere.
    # Confirm.
    mnist_data = task.MnistTask().load_mnist_data()
    experiment_grid = list(
        product(
            create_nbWorkersArgs(conf),
            create_weightUpdaterArg(),
            create_trainFLArg(),
            create_grad_args(),
            create_random_seed(5),
        )
    )
    # os.path.join with a single argument returns it unchanged, so FOLDER is
    # passed directly.
    conduct_rl_mnist_parallel(mnist_data, experiment_grid, n_jobs=30, root_folder=FOLDER)
  
