
from exp01_hps import (
    generate_exp_yaml, 
    get_all_combinations, 
    load_template,
    replace_value,
    save_exp_yaml
)
from exp02_clf import (
    get_df_hps_result_path,
    load_df_hps_result_path
)

def generate_exp_yaml(
    exp_id_start,
    template_id,
    val_replace_dict,
    df_hps_result_path
):
    """Generate and save one experiment YAML per value combination.

    The template identified by ``template_id`` is loaded once. For every
    combination returned by ``get_all_combinations(val_replace_dict)``:

    1. each placeholder key in the combination is replaced by its value;
    2. the hyperparameter-search result path for the combination's settings
       is looked up and inserted as ``VAL04``;
    3. the YAML is saved under ID ``exp_id_start + <combination index>``.

    Combinations whose lookup returns ``"N/A"`` are printed and skipped.
    Their IDs are left unused, so the IDs of the remaining experiments do
    not shift.

    Args:
        exp_id_start (int): Experiment ID assigned to the first combination.
        template_id: Template identifier passed to ``load_template``.
        val_replace_dict (dict): Mapping of placeholder name to a list of
            candidate values (as built in ``__main__``). The values of each
            combination are read positionally as ``[pt_id, dataset, dx]``, so
            the key order matters (presumably VAL01, VAL02, VAL03 — confirm
            against callers).
        df_hps_result_path: Lookup table passed to
            ``load_df_hps_result_path``.

    Returns:
        tuple: ``(exp_id, exp_yaml)`` of the last experiment actually saved,
        or ``(None, None)`` if no experiment was saved.
    """
    template = load_template(template_id)

    # Track the last *saved* experiment. Returning the loop variables directly
    # would report a skipped (unsaved) experiment, or raise UnboundLocalError
    # when there are no combinations at all.
    last_exp_id, last_exp_yaml = None, None

    all_combinations = get_all_combinations(val_replace_dict)
    for n_proc, comb in enumerate(all_combinations):
        # The ID follows the combination index, so skipped combinations leave gaps.
        exp_id = exp_id_start + n_proc

        exp_yaml = template
        for key, val in comb.items():
            exp_yaml = replace_value(exp_yaml, key, val)

        # Insert hps_path. Values are read positionally: [pt_id, dataset, dx].
        settings = list(comb.values())
        # `DX-lead_XX` -> `DX`
        settings[2] = settings[2].split("-")[0]
        hps_path = load_df_hps_result_path(df_hps_result_path, settings)
        if hps_path == "N/A":
            # No hyperparameter-search result for this setting: report and skip.
            print(exp_id, settings)
            continue
        exp_yaml = replace_value(exp_yaml, "VAL04", hps_path)
        save_exp_yaml(exp_yaml, exp_id)
        last_exp_id, last_exp_yaml = exp_id, exp_yaml
    return last_exp_id, last_exp_yaml

if __name__ == "__main__":
    # Generated experiment IDs start here; all are built from template 3.
    exp_yaml_start = 4001
    template_id = 3

    # Load hyperparameter search results for experiments 1-104, using the
    # settings named in `target_keys`.
    target_keys = ["target_dx", "dataset", "finetune_target"]
    df_hps_result_path = get_df_hps_result_path(
        list(range(1, 105)),
        target_keys,
    )

    # Build every "<dx>-lead_<lead>" name. The dict key order ("lead" first,
    # then "dx") is kept as-is because it determines the combination order,
    # and therefore which experiment ID each combination receives.
    dx_list = [
        "af", "asmi", "abqrs", "crbbb", "imi", "irbbb", "isc", "lafb",
        "lvh", "pac", "pvc", "std", "1avb",
    ]
    # NOTE(review): lead "ii" is not listed — confirm the omission is intended.
    lead_list = [
        "i", "iii", "avr", "avl", "avf",
        "v1", "v2", "v3", "v4", "v5", "v6",
    ]
    dx_lead_names = [
        f"{c['dx']}-lead_{c['lead']}"
        for c in get_all_combinations({"lead": lead_list, "dx": dx_list})
    ]

    # Placeholder values for the template. The insertion order
    # (VAL01, VAL02, VAL03) matches the positional [pt_id, dataset, dx]
    # reading inside generate_exp_yaml.
    pt_id = "pt0006"
    val_replace_dict = {
        "VAL01": [pt_id],
        "VAL02": ["ptbxl"],
        "VAL03": dx_lead_names,
    }

    last_id, sample_yaml = generate_exp_yaml(
        exp_yaml_start, template_id, val_replace_dict, df_hps_result_path
    )

    separator = "-" * 80
    print(separator)
    print("Sample YAML:")
    print(sample_yaml)
    print(separator)
    print(f"Last ID for {pt_id}: {last_id}")
