from .comlib import *
from .test_mnist_item import conduct_rl_mnist,class_name_only,get_string_hash,run_with_oom_wait
from .test_mnist_parallel import conduct_rl_mnist_parallel,create_random_seed

# Each test module writes its results to ./tests_result/<module file stem>,
# so runs from different modules never share an output folder.
name = Path(__file__).resolve().stem
FOLDER = "./tests_result/" + name

def get_keyargs_hash_other(c):
    keyarg=c[0].byzantine_args.data_num
    c2=copy.deepcopy(c)
    c2[0].byzantine_args.data_num=0
    c2[0].normal_args.data_num=0
    s_c2='-'.join(map(str, c2[:3]))
    return f"dn{keyarg:.1e}",get_string_hash(s_c2)


def create_nbWorkersArgs():
    normal_num,byz_num=200,50
    data_nums=[30*i for i in range(1,5)]
    r=[]
    for data_num in data_nums:
        normalArg=WorkersArgs("iid",data_num,normal_num)
        byzArg=ConcatArgs(data_num,byz_num,"s","swl",1.0)
        r.append(
            NBWorkersArgsRef(normalArg,byzArg,byz_num*2)
        )

    return r


def create_weightUpdaterArg():
    rs=[]
    layer_nums= [2]
    hidden_nums=[128]
    ecr_biases=[0.0]
    cartesian = list(product(layer_nums, hidden_nums,ecr_biases))
    for layer_num,hidden_num,ecr in cartesian:
        r=WeightUpdaterArg(
            DiscriminatorIFArgs(1.0,ecr,True,False),
            FedLearnArgs(20,0.2,save_epoch_inteval=4),
            GradArgs(MvArgs(0.01),LinearArgs('ce')),
            DnnArgDiscrim(layer_num,hidden_num),
            {"name":"Adam","lr":0.001}
        )
        rs.append(r)
    return rs

def create_trainFLArg():
    rs=[]
    layer_num_lrs= [(2,0.01)]
    hidden_nums=[128]
    chosen_ratios=[0.2] 
    cartesian = list(product(layer_num_lrs, hidden_nums,chosen_ratios))
    for (layer_num,lr),hidden_num,chosen_ratio in cartesian:
        r=TrainFLArg(
            FedLearnArgs(20,chosen_ratio,save_epoch_inteval=4),
            DnnArg('default',layer_num,hidden_num),
            {"name":"Adam","lr":lr}
        )
        rs.append(r)
    return rs

def create_grad_args():
    rs=[]
    r=[(GradAggArg("krum")),
    (GradAggArg("iterative_filtering")),
    (GradAggArg("geo_median_w")),]
    rs.append(r)
    return rs



    
def test_data_num():
    """Run the MNIST sweep over data-num settings.

    Loads the MNIST data once. It then takes the cartesian product of:
    - the worker configs (one per data num);
    - the weight-updater configs;
    - the FL training configs;
    - the gradient-aggregator list;
    - ``create_random_seed(5)``.

    All combinations are passed to ``conduct_rl_mnist_parallel`` with
    ``n_jobs=30``, ``FOLDER`` as ``root_folder``, and
    ``get_keyargs_hash_other`` as the per-combination key/hash function.
    """
    mnist_data = task.MnistTask().load_mnist_data()
    cartesian = list(product(
        create_nbWorkersArgs(),
        create_weightUpdaterArg(),
        create_trainFLArg(),
        create_grad_args(),
        create_random_seed(5),
    ))
    conduct_rl_mnist_parallel(
        mnist_data,
        cartesian,
        n_jobs=30,
        root_folder=FOLDER,
        get_keyargs_hash_other=get_keyargs_hash_other,
    )
  
