from .comlib import *
from .test_mnist_item import conduct_rl_mnist,class_name_only,get_string_hash,run_with_oom_wait

# Name of the directory that contains this file; it namespaces the results.
parent_name = Path(__file__).resolve().parent.name
# Default root folder under which this item's outputs are stored.
FOLDER = "./tests_result/" + parent_name + "/item1"

def get_keyargs_hash_other(c):
    """Return a label for the key argument and a hash of the remaining config.

    The key argument is ``c[0].byzantine_args.args2_ratio``. To make the hash
    independent of it, the configuration is deep-copied, the ratio is reset
    to 0 on the copy, and the string forms of the first three entries are
    joined with ``'-'`` before being passed to ``get_string_hash``. The
    caller's ``c`` is not modified.

    Args:
        c: Indexable configuration whose first entry exposes
            ``byzantine_args.args2_ratio``.

    Returns:
        Tuple ``(label, digest)`` where ``label`` is ``"br"`` followed by the
        ratio in ``.1e`` notation and ``digest`` is whatever
        ``get_string_hash`` returns for the joined string.
    """
    ratio = c[0].byzantine_args.args2_ratio

    # Work on a copy so that zeroing the ratio never leaks back to the caller.
    neutral = copy.deepcopy(c)
    neutral[0].byzantine_args.args2_ratio = 0

    joined = '-'.join(str(part) for part in neutral[:3])
    label = f"br{ratio:.1e}"
    return label, get_string_hash(joined)


def save_config(config_folder,c):
    """Dump the dataclass entries of ``c`` to ``<config_folder>/config.json``.

    The folder is created (parents included) when it is missing, and an
    existing ``config.json`` is overwritten.

    Args:
        config_folder: Directory that receives ``config.json``.
        c: Iterable of configuration objects. Only the entries for which
            ``is_dataclass`` is true are converted with ``asdict`` and
            written; everything else is skipped.
    """
    # Diagnostic output kept as is: reports whether the folder already existed.
    print(os.path.exists(config_folder))

    # Previously the write was gated on the folder already existing, so a
    # fresh folder never received a config.json and the makedirs call inside
    # that branch could never create anything. Create it unconditionally.
    os.makedirs(config_folder, exist_ok=True)

    lst = [asdict(dc) for dc in c if is_dataclass(dc)]
    with open(os.path.join(config_folder, 'config.json'), 'w', encoding='utf-8') as f:
        # `default` is consulted by json for values it cannot encode natively.
        json.dump(lst, f, indent=2, default=class_name_only)


# def get_save_folder(root_folder,c):
#     byz_ratio,kwhash=get_keyargs_hash_other(c)
#     # print(kwhash)
#     key_str = f"br{byz_ratio:.1e}"
#     save_folder = os.path.join(root_folder,kwhash, key_str,c[4])
#     return save_folder
# def run(cmd):
#     subprocess.run(cmd, shell=True, check=True)

def conduct_rl_mnist_parallel(load_mnist_data,cartesian,n_jobs=20,root_folder=FOLDER,get_keyargs_hash_other=get_keyargs_hash_other):
    """Run ``conduct_rl_mnist`` once per configuration using joblib workers.

    Each configuration ``c`` is read positionally: ``c[0]`` is passed as
    ``nBWorkersArgs``, ``c[1]`` as ``weightUpdaterArg``, ``c[2]`` as
    ``trainFLArg``, ``c[3]`` as ``grad_args`` and ``c[4]`` as ``random_seed``.
    ``c[4]`` is also used as the name of the per-run sub-folder. Results go
    to ``<root_folder>/<hash>/<key>/<c[4]>``, where ``(key, hash)`` is the
    pair returned by ``get_keyargs_hash_other(c)``.

    Args:
        load_mnist_data: Forwarded unchanged to ``conduct_rl_mnist``.
        cartesian: Sized iterable of configurations to run.
        n_jobs: Number of joblib workers.
        root_folder: Root directory for all outputs.
        get_keyargs_hash_other: Callable mapping a configuration to a
            ``(key_str, kwhash)`` pair used to build the folder path.
    """
    print("item nums",len(cartesian))

    def run(idx, c):
        # The position is passed in rather than recovered with
        # cartesian.index(c): that was a linear scan per task, reported the
        # first match for duplicated configurations, and made this closure
        # capture the whole list.
        print(idx)
        key_str, kwhash = get_keyargs_hash_other(c)
        print(kwhash)
        config_folder = os.path.join(root_folder, kwhash, key_str)
        save_config(config_folder, c)

        save_folder = os.path.join(config_folder, c[4])
        print(save_folder)
        os.makedirs(save_folder, exist_ok=True)

        run_with_oom_wait(
            conduct_rl_mnist,
            save_folder=save_folder, load_mnist_data=load_mnist_data,
            weightUpdaterArg=c[1], nBWorkersArgs=c[0],
            trainFLArg=c[2], grad_args=c[3], random_seed=c[4])

    Parallel(n_jobs=n_jobs)(
        delayed(run)(idx, c) for idx, c in enumerate(cartesian)
    )
    

  
def create_random_seed(n):
    """Return the ``n`` seed labels ``"test0"`` ... ``"test<n-1>"`` as a list.

    The labels are deterministic strings, not random numbers.
    """
    labels = []
    for idx in range(n):
        labels.append("test" + str(idx))
    return labels