import os
from glob import glob
from itertools import product

import yaml
import pandas as pd
from exp01_hps import load_template, get_all_combinations, save_exp_yaml, replace_value
from exp02_clf import get_df_hps_result_path

# Root directory under which fetch_exp_info looks for
# "ssl-clf-exp{NN}s/exp{NNNN}/<run-dir>" result folders.
# Placeholder value: set this to the real results location before running.
root = "PATH/TO/HPS-RESULT"

def load_from_exp_yaml(exp_id, keys):
    """Read the requested parameters from the local experiment YAML file.

    Used as a fallback when no result directory is available for ``exp_id``.
    The file is resolved relative to the current working directory as
    ``exp{NN}s/exp{NNNN}.yaml`` where ``NN`` is ``exp_id // 100``.

    Args:
        exp_id (int): Experiment number.
        keys (list[str]): Parameter names to extract. Each corresponding YAML
            entry is expected to be a mapping holding the value under
            ``"param_val"``.

    Returns:
        dict: The requested parameters plus ``exp_id``; both result paths
        are set to ``"N/A"`` because no results were found.
    """
    group_dir = f"exp{exp_id // 100:02d}s"
    yaml_path = os.path.join(group_dir, f"exp{exp_id:04d}.yaml")
    with open(yaml_path, "r") as stream:
        config = yaml.safe_load(stream)

    record = {}
    for name in keys:
        record[name] = config[name]["param_val"]
    record.update(
        exp_id=exp_id,
        result_path_hps="N/A",
        result_path_clf="N/A",
    )
    return record

def fetch_exp_info(exp_id, keys):
    """Collect selected config values and result-file paths for one experiment.

    Searches ``<root>/ssl-clf-exp{NN}s/exp{NNNN}`` for run directories matching
    the glob ``??????-??????`` and uses the last one in sorted order. If no
    run directory exists, or it has no ``exp_config_src.yaml``, the values
    are read via ``load_from_exp_yaml`` instead.

    Args:
        exp_id (int): Experiment number.
        keys (list[str]): Top-level keys to copy from ``exp_config_src.yaml``.

    Returns:
        dict: The requested keys plus ``exp_id``, ``result_path_hps`` and
        ``result_path_clf``. Each result path is ``"N/A"`` when the
        corresponding CSV file does not exist.
    """
    exp_dir = os.path.join(
        root,
        f"ssl-clf-exp{exp_id // 100:02d}s",
        f"exp{exp_id:04d}",
    )
    run_dirs = sorted(glob(os.path.join(exp_dir, "??????-??????")))
    if not run_dirs:
        return load_from_exp_yaml(exp_id, keys)
    # NOTE(review): directory names look like timestamps, so the last sorted
    # entry is presumably the most recent run -- confirm naming scheme.
    latest_dir = run_dirs[-1]

    config_path = os.path.join(latest_dir, "exp_config_src.yaml")
    if not os.path.exists(config_path):
        return load_from_exp_yaml(exp_id, keys)

    with open(config_path, "r") as f:
        exp_info = yaml.safe_load(f)

    # Extract only the keys specified in the argument.
    info = {key: exp_info[key] for key in keys}
    info["exp_id"] = exp_id
    # Both result files get the same treatment: report the path only if the
    # file actually exists, otherwise "N/A" (which callers check for).
    for field, filename in (
        ("result_path_hps", "ResultTableHPS.csv"),
        ("result_path_clf", "ResultTableMultiSeed.csv"),
    ):
        path = os.path.join(latest_dir, filename)
        info[field] = path if os.path.exists(path) else "N/A"
    return info

def get_df_hps_result_path(exp_ids, keys):
    """Build a table of config values and result paths for several experiments.

    Args:
        exp_ids (Iterable[int]): Experiment numbers to look up.
        keys (list[str]): Config keys to include as columns (see
            ``fetch_exp_info``).

    Returns:
        pd.DataFrame: One row per experiment, with the columns ``keys``
        followed by ``exp_id``, ``result_path_hps`` and ``result_path_clf``.
        When ``exp_ids`` is empty, an empty frame with those columns is
        returned so that column-based filtering downstream still works.
    """
    info_list = [fetch_exp_info(exp_id, keys) for exp_id in exp_ids]
    if not info_list:
        # pd.DataFrame([]) has no columns; keep the schema instead.
        columns = list(dict.fromkeys(
            [*keys, "exp_id", "result_path_hps", "result_path_clf"]
        ))
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(info_list)

def load_df_result_path(df, settings, load_hps_result=True):
    """Look up one result path for a (finetune_target, dataset, target_dx) triple.

    Args:
        df (pd.DataFrame): Table with ``finetune_target``, ``dataset``,
            ``target_dx``, ``result_path_hps`` and ``result_path_clf`` columns.
        settings (Sequence): ``[ft_target, dataset, target_dx]``.
        load_hps_result (bool): If True return the HPS result path,
            otherwise the classification result path.

    Returns:
        str | None: The stored path (possibly ``"N/A"``) from the first
        matching row, or ``None`` when no row matches.
    """
    ft_target, dataset, target_dx = settings
    is_match = (
        (df["finetune_target"] == ft_target)
        & (df["dataset"] == dataset)
        & (df["target_dx"] == target_dx)
    )
    matches = df.loc[is_match]
    if matches.empty:
        return None
    column = "result_path_hps" if load_hps_result else "result_path_clf"
    return matches[column].values[0]

def generate_exp_yaml(
    exp_id_start, 
    template_id, 
    val_replace_dict,
    df_hps_result_path
):
    """Write experiment YAMLs for combinations that still lack a classification result.

    Every combination from ``get_all_combinations(val_replace_dict)`` gets a
    consecutive experiment id starting at ``exp_id_start``; ids are consumed
    even for skipped combinations. For each combination the template
    placeholders are replaced, the matching HPS result path is looked up in
    ``df_hps_result_path`` and substituted for ``VAL04``, and the YAML is saved
    only when no classification result exists yet.

    Args:
        exp_id_start (int): Id assigned to the first combination.
        template_id (int): Template id passed to ``load_template``.
        val_replace_dict (dict): Placeholder -> candidate values. The values
            of each combination are passed to ``load_df_result_path`` in
            iteration order, so they must correspond to
            ``[finetune_target, dataset, target_dx]``.
        df_hps_result_path (pd.DataFrame): Output of
            ``get_df_hps_result_path``.

    Returns:
        tuple: ``(last_id, prepared_exp_ids)``. ``last_id`` is the id of the
        last combination (``exp_id_start - 1`` when there are none, so callers
        can always continue at ``last_id + 1``); ``prepared_exp_ids`` lists
        the ids whose YAML was written.
    """
    template = load_template(template_id)

    prepared_exp_ids = []
    # Initialised so an empty combination set returns a valid id instead of
    # raising UnboundLocalError.
    exp_id = exp_id_start - 1
    all_combinations = get_all_combinations(val_replace_dict)
    for n_proc, comb in enumerate(all_combinations):
        exp_id = exp_id_start + n_proc
        # NOTE(review): assumes replace_value returns a new object rather than
        # mutating `template` in place -- confirm in exp01_hps.
        exp_yaml = template
        for key, val in comb.items():
            exp_yaml = replace_value(exp_yaml, key, val)

        # Insert hps_path.
        settings = list(comb.values())
        hps_path = load_df_result_path(
            df_hps_result_path, settings, load_hps_result=True)
        # None means no matching row, "N/A" means no HPS result file; in both
        # cases there is nothing to reference, so report and skip.
        if hps_path is None or hps_path == "N/A":
            print(exp_id, settings)
            continue

        # We only execute exp_ids with no clf result.
        clf_path = load_df_result_path(
            df_hps_result_path, settings, load_hps_result=False)
        if clf_path != "N/A":
            continue
        exp_yaml = replace_value(exp_yaml, "VAL04", hps_path)
        save_exp_yaml(exp_yaml, exp_id)
        prepared_exp_ids.append(exp_id)
    return exp_id, prepared_exp_ids
    
if __name__ == "__main__":
    exp_yaml_start = 601
    template_id = 2

    # Config values and result paths for the existing HPS experiments.
    target_keys = ["target_dx", "dataset", "finetune_target"]
    df_hps_result_path = get_df_hps_result_path(
        list(range(401, 521)),
        target_keys,
    )

    # Pretrained models evaluated on every dataset below.
    pretrained_models = [
        "pt0001", "pt0002", "pt0003",
        "pt0005", "pt0006",
        "pt0007", "pt0008",
        "pt0010",
    ]
    # (dataset, target diagnoses) processed in this order with consecutive ids.
    dataset_targets = [
        ("ptbxl", ["wpw", "aflt"]),                                       # exp01b
        ("g12ec", ["af", "pvc", "lvh", "irbbb", "iavb", "pac", "rbbb"]),  # exp01c
        ("cpsc", ["af", "iavb", "pac", "pvc", "std", "rbbb"]),            # exp01d
    ]

    prepared_exp_ids = []
    next_id = exp_yaml_start
    for dataset, target_dxs in dataset_targets:
        val_replace_dict = {
            "VAL01": list(pretrained_models),
            "VAL02": [dataset],
            "VAL03": list(target_dxs),
        }
        last_id, new_ids = generate_exp_yaml(
            next_id,
            template_id,
            val_replace_dict,
            df_hps_result_path,
        )
        prepared_exp_ids.extend(new_ids)
        # The next dataset continues right after the last id handed out.
        next_id = last_id + 1

    print(len(prepared_exp_ids))
    print(",".join(str(exp_id) for exp_id in prepared_exp_ids))

