import argparse
import os

from helper import get_hyperparams, load_models, parse_datasets
from project_location import DATASETS_ROOT, FEATURES_ROOT, MODELS_ROOT, RESULTS_ROOT
from slurm import run_job

# Command-line interface of this job-submission script. The options are
# declared as (flag, keyword-arguments) pairs and registered in order.
_CLI_OPTIONS = (
    # JSON file describing the available models.
    (
        "--models_config",
        {
            "type": str,
            "default": "./scripts/configs/models_config_wo_alignment.json",
        },
    ),
    # One or more dataset names, or a text file listing them.
    # NOTE(review): with nargs="+" values given on the command line arrive as
    # a list, whereas the default is a single path string; both forms are
    # handed to parse_datasets unchanged -- confirm it accepts both.
    (
        "--datasets",
        {
            "type": str,
            "nargs": "+",
            "default": "./scripts/configs/webdatasets_attentive_probe.txt",
            "help": "datasets can be a list of dataset names or a file (e.g., webdatasets.txt) containing dataset names.",
        },
    ),
    # Text file with one ';'-separated model combination per line.
    (
        "--models_combination",
        {
            "type": str,
            "default": "./scripts/configs/exp_4_number_heads.txt",
        },
    ),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Arguments are parsed at import time; the constants below are derived from them.
args = parser.parse_args()

MODELS_CONFIG = args.models_config
# Space-separated dataset names, ready to be spliced into the job command line.
DATASETS = " ".join(parse_datasets(args.datasets))

def read_model_combinations(path):
    """Read model combinations from a text file.

    Each non-blank line of the file describes one combination as a
    ';'-separated list of model keys; surrounding whitespace is stripped
    from every key.

    Args:
        path: Path of the combinations file.

    Returns:
        A list with one list of model keys per non-blank line.

    Raises:
        ValueError: If ``path`` is not an existing file.
    """
    if not os.path.isfile(path):
        raise ValueError(f"The models combination file does not exist: {path}")

    with open(path, "r") as f:
        return [[m.strip() for m in line.split(";")] for line in f if line.strip()]


def check_model_keys(model_combinations, known_models):
    """Ensure every model key used in a combination is a known model.

    Args:
        model_combinations: Iterable of lists of model keys.
        known_models: Container (e.g. the models dict) supporting ``in``.

    Raises:
        ValueError: If at least one key is not contained in ``known_models``;
            the message lists all offending keys.
    """
    unknown = sorted(
        {
            key
            for model_set in model_combinations
            for key in model_set
            if key not in known_models
        }
    )
    if unknown:
        raise ValueError(f"Unknown model keys in model combinations: {unknown}")


if __name__ == "__main__":
    model_combinations = read_model_combinations(args.models_combination)

    # The second return value of load_models is not needed by this script.
    models, _ = load_models(MODELS_CONFIG)

    hyper_params, num_jobs = get_hyperparams(num_seeds=1, size="n_heads")

    val_proportion = 0.2

    # Validate every combination up front (with a real exception rather than an
    # assert, which is removed under ``python -O``) so that a bad line cannot
    # abort the sweep after part of the jobs has already been submitted.
    check_model_keys(model_combinations, models)

    for model_set in model_combinations:
        # Both values depend only on the model set, not on the number of heads.
        model_keys = " ".join(model_set)
        # The largest embedding dimension of the set is passed as --dim
        # (the command uses --dimension_alignment zero_padding).
        dim = max(models[mid]["embedding_dim"] for mid in model_set)

        # One job array is submitted per (model set, number of heads) pair.
        for n_heads in [1, 2, 4, 16, 32]:
            job_cmd = f"""export XLA_PYTHON_CLIENT_PREALLOCATE=false && \
            export XLA_PYTHON_CLIENT_ALLOCATOR=platform && \
            sim_consistency --dataset {DATASETS} \
                            --dataset_root {DATASETS_ROOT} \
                            --feature_root {FEATURES_ROOT} \
                            --model_root {MODELS_ROOT} \
                            --output_root {RESULTS_ROOT} \
                            --task=attentive_probe \
                            --mode=combined_models \
                            --feature_combiner tuple \
                            --model_key {model_keys} \
                            --models_config_file {MODELS_CONFIG} \
                            --batch_size=2048 \
                            --dim {dim} \
                            --num_heads {n_heads} \
                            --dimension_alignment zero_padding \
                            --fewshot_k {" ".join(hyper_params["fewshot_ks"])} \
                            --initial_lr {" ".join(hyper_params["initial_lrs"])} \
                            --epochs {" ".join(hyper_params["epochs"])} \
                            --reg_lambda {hyper_params["reg_lambda"]} \
                            --regularization {" ".join(hyper_params["regularization"])} \
                            --train_split train \
                            --test_split test \
                            --val_proportion {val_proportion} \
                            --seed {" ".join(hyper_params["seeds"])}
            """

            run_job(
                job_name=f"attn_probe_with_heads{n_heads}",
                job_cmd=job_cmd,
                partition="gpu-2d",
                log_dir=f"{RESULTS_ROOT}/logs",
                num_jobs_in_array=num_jobs,
                mem=100,
            )
