import itertools

from configs.exp_configs import data_list, data_list_linked, baselines, linked_baselines
from configs.path_configs import path_configs


# Baselines that are left out of the generated experiment grids.
_EXCLUDED_BASELINES = ("llm-row_llama-3.1-70b", "kg_kgt5-small_", "kg_kgt5_")


def _write_missing_experiments(
    base_models, datasets, results_subdir, num_train, device, output_file
):
    """Write every incomplete (dataset, method, num_train) run to ``output_file``.

    Shared implementation behind ``generate_experiments`` and
    ``generate_linked_experiments``; the two only differ in the baseline list,
    the dataset list, the results sub-directory and the output file.
    """
    # Each retained baseline is evaluated with the "_xgb-pca" suffix appended.
    methods = [
        base + "_xgb-pca" for base in base_models if base not in _EXCLUDED_BASELINES
    ]

    args_dict = {
        "data_name": datasets,
        "method": methods,
        "num_train": list(num_train),
    }

    missing_exps = get_missing_experiments(
        args_dict,
        path_configs["results"] / results_subdir,
        random_states=range(1, 11),
    )

    # sorted() is stable, so runs sharing a num_train keep their
    # (dataset, method) grid order.
    grid_sorted = sorted(missing_exps, key=lambda cfg: cfg["num_train"])

    # One line per experiment: "<data_name> <method> <num_train> <device>".
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(
            f"{cfg['data_name']} {cfg['method']} {cfg['num_train']} {device}\n"
            for cfg in grid_sorted
        )

    print(f"Saved {len(grid_sorted)} experiments to {output_file}")


def generate_experiments(
    num_train=(64, 256, 1024), device="cpu", output_file="experiments.txt"
):
    """
    Write the incomplete (dataset, model, num_train) combinations to a file.

    The grid is built from ``data_list`` and ``baselines`` (minus the excluded
    baselines), and completeness is checked against the
    ``llm_kg_comparison`` results directory for random states 1..10.
    Random state is NOT included in the output (it will be handled inside
    script_evaluate.py).

    Parameters
    ----------
    num_train : iterable of int, optional
        Training-set sizes to include in the grid.
    device : str, optional
        Device label appended to every output line.
    output_file : str, optional
        Path of the text file to (over)write.
    """
    _write_missing_experiments(
        base_models=baselines,
        datasets=data_list,
        results_subdir="llm_kg_comparison",
        num_train=num_train,
        device=device,
        output_file=output_file,
    )


def generate_linked_experiments(
    num_train=(64, 256, 1024), device="cpu", output_file="linked_experiments.txt"
):
    """
    Write the incomplete linked (dataset, model, num_train) combinations to a file.

    Same as ``generate_experiments`` but the grid is built from
    ``data_list_linked`` and ``linked_baselines``, and completeness is checked
    against the ``llm_kg_comparison_linked`` results directory.
    Random state is NOT included in the output (it will be handled inside
    script_evaluate.py).

    Parameters
    ----------
    num_train : iterable of int, optional
        Training-set sizes to include in the grid.
    device : str, optional
        Device label appended to every output line.
    output_file : str, optional
        Path of the text file to (over)write.
    """
    _write_missing_experiments(
        base_models=linked_baselines,
        datasets=data_list_linked,
        results_subdir="llm_kg_comparison_linked",
        num_train=num_train,
        device=device,
        output_file=output_file,
    )


def get_missing_experiments(args_dict, results_dir_path, random_states=range(1, 11)):
    """
    Return a list of experiments (dataset, method, num_train) that are incomplete.
    An experiment is considered incomplete if at least one random_state result file is missing.

    Result files are looked up at
    ``<results_dir_path>/<data_name>/score/<data_name>|<method>|nt-<num_train>|rs-<rs>.csv``.

    Parameters
    ----------
    args_dict : dict
        Dictionary with the keys "data_name", "method" and "num_train", each
        mapping to an iterable of values; the full cartesian product is checked.
    results_dir_path : str or pathlib.Path
        Path to the results directory.
    random_states : iterable, optional
        Random states to check. Default is range(1, 11). One-shot iterators
        (e.g. generators) are supported.

    Returns
    -------
    missing_experiments : list of dict
        Each dict describes a missing experiment with keys:
        {"data_name", "method", "num_train"}
    """
    from pathlib import Path

    # The "/" operator below needs a path object; accept plain strings too.
    if isinstance(results_dir_path, str):
        results_dir_path = Path(results_dir_path)

    # Materialize once: the states are re-iterated for every experiment, so a
    # one-shot iterator would be exhausted after the first experiment and all
    # later ones would wrongly look complete.
    states = tuple(random_states)

    missing_experiments = []

    for data_name, method, num_train in itertools.product(
        args_dict["data_name"], args_dict["method"], args_dict["num_train"]
    ):
        dataset_dir = results_dir_path / data_name / "score"

        # A missing score folder means every experiment of that dataset is
        # missing; otherwise stop at the first absent random-state file.
        is_complete = dataset_dir.exists() and all(
            (dataset_dir / f"{data_name}|{method}|nt-{num_train}|rs-{rs}.csv").exists()
            for rs in states
        )
        if not is_complete:
            missing_experiments.append(
                {
                    "data_name": data_name,
                    "method": method,
                    "num_train": num_train,
                }
            )

    return missing_experiments


if __name__ == "__main__":
    # Script entry point: build both experiment lists using default arguments.
    for build_experiment_list in (generate_experiments, generate_linked_experiments):
        build_experiment_list()
