import os
from itertools import product

from exp01_hps import load_template, get_all_combinations, save_exp_yaml, replace_value
from exp02_clf import get_df_hps_result_path, load_df_hps_result_path
# from exp02b_clf import load_df_result_path

def generate_exp_yaml(
    exp_id_start,
    template_id,
    base_pt_model,
    val_replace_dict,
    df_hps_result_path
):
    """Render and save one experiment YAML per placeholder combination.

    For every combination returned by ``get_all_combinations(val_replace_dict)``,
    the steps are:

    1. Substitute each placeholder of template ``template_id``.
    2. Look up the matching HPS result path.
    3. Write that path into the ``VAL04`` placeholder.
    4. Save the YAML under a sequential experiment ID.

    Combinations without an HPS result (lookup returns ``"N/A"``) are printed
    and skipped.

    Args:
        exp_id_start (int): Experiment ID assigned to the first combination.
            Each subsequent combination gets the next integer. Skipped
            combinations still consume their ID, so saved IDs may have gaps.
        template_id: Identifier passed to ``load_template``.
        base_pt_model (str): Value that replaces the first setting before the
            HPS lookup (e.g. ``"pt0006"``).
        val_replace_dict (dict): Placeholder name -> list of candidate values.
        df_hps_result_path: HPS lookup table passed to
            ``load_df_hps_result_path``.

    Returns:
        tuple: ``(last_id, prepared_exp_ids)``.

        - ``last_id`` is the last experiment ID allocated, whether or not it
          was saved. It is ``exp_id_start - 1`` when there are no
          combinations.
        - ``prepared_exp_ids`` lists the IDs whose YAML was actually saved.
    """
    template = load_template(template_id)

    prepared_exp_ids = []
    # Pre-initialise so an empty combination set returns exp_id_start - 1
    # instead of raising UnboundLocalError. Callers chain with `last_id + 1`.
    exp_id = exp_id_start - 1
    all_combinations = get_all_combinations(val_replace_dict)
    for n_proc, comb in enumerate(all_combinations):
        exp_yaml = template
        exp_id = exp_id_start + n_proc

        for key, val in comb.items():
            exp_yaml = replace_value(exp_yaml, key, val)

        # Build the HPS lookup key from the combination's values, in comb's
        # key order (presumably VAL01, VAL02, VAL03 -- TODO confirm this
        # matches the column order expected by load_df_hps_result_path).
        settings = list(comb.values())
        # Swap the progress checkpoint name (e.g. "progress-pt0006-001E6")
        # for the base pretrained model. Presumably this reuses the base
        # model's HPS result for every checkpoint.
        settings[0] = base_pt_model
        hps_path = load_df_hps_result_path(df_hps_result_path, settings)
        if hps_path == "N/A":
            # No HPS result for this setting: report it and skip.
            # The ID is still consumed.
            print(exp_id, settings)
            continue

        exp_yaml = replace_value(exp_yaml, "VAL04", hps_path)
        save_exp_yaml(exp_yaml, exp_id)
        prepared_exp_ids.append(exp_id)
    return exp_id, prepared_exp_ids

if __name__ == "__main__":
    exp_yaml_start = 2001
    template_id = 3
    base_pt_model = "pt0006"  # Syn-MAE

    # Experiments whose HPS results populate the lookup table.
    exp_ids = list(range(1, 105))
    exp_ids += list(range(401, 521))
    target_keys = ["target_dx", "dataset", "finetune_target"]
    df_hps_result_path = get_df_hps_result_path(
        exp_ids,
        target_keys
    )

    # VAL01: pretraining progress checkpoints to fine-tune from.
    val_replace_dict = {
        "VAL01": [
            "progress-pt0006-001E6",
            "progress-pt0006-005E6",
            "progress-pt0006-010E6",
            "progress-pt0006-050E6",
            "progress-pt0006-100E6",
            "progress-pt0006-500E6",
            "progress-pt0006-999E6",
        ]
    }

    # (VAL02 dataset, VAL03 fine-tuning targets), generated in this order with
    # consecutive experiment IDs.
    dataset_targets = [
        ("ptbxl", [
            "af", "asmi", "abqrs", "crbbb", "imi", "irbbb", "isc", "lafb",
            "lvh", "pac", "pvc", "std", "1avb",
        ]),
        ("g12ec", ["af", "pvc", "lvh", "irbbb", "iavb", "pac", "rbbb"]),
        ("cpsc", ["af", "iavb", "pac", "pvc", "std", "rbbb"]),
    ]

    # Start one below the first ID so each iteration can use `last_id + 1`.
    last_id = exp_yaml_start - 1
    for dataset, targets in dataset_targets:
        val_replace_dict["VAL02"] = [dataset]
        val_replace_dict["VAL03"] = targets
        print(val_replace_dict)
        last_id, _ = generate_exp_yaml(
            last_id + 1,
            template_id,
            base_pt_model,
            val_replace_dict,
            df_hps_result_path
        )
        print("->", last_id)

    # Label derived from the actual base model (was hard-coded as "pt0009").
    print(f"Last ID for {base_pt_model}: {last_id}")