import argparse
import os
from pathlib import Path

from scripts.helper import get_hyperparams, load_models, parse_datasets
from scripts.project_location import CLUSTERING_ROOT, DATASETS_ROOT, FEATURES_ROOT, MODELS_ROOT, RESULTS_ROOT
from scripts.slurm import run_job

def _build_parser():
    """Create the argument parser for submitting attentive-probe jobs.

    Arguments are registered in a fixed order, which is also the order
    they appear in ``--help`` output.
    """
    num_clusters_help = " ".join(
        [
            "Number of clusters to use for the attentive probe.",
            "Models/Layers of one model will be clustered based on their representational similarity,",
            "and only one representative model from each cluster will be used by the probe.",
            "If -1, clustering is disabled and all models are used.",
        ]
    )

    cli = argparse.ArgumentParser()

    # Alternative config file:
    #   ./scripts/configs/models_config_single_model_layer_combination.json
    cli.add_argument(
        "--models_config",
        type=str,
        default="./scripts/configs/models_config_single_model_all_tokens_last_layer_old.json",
    )

    # NOTE(review): the default is a single path string, while values passed on
    # the command line arrive as a list (nargs="+"); parse_datasets presumably
    # accepts both forms — confirm.
    cli.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default="./scripts/configs/webdataset_configs/webdatasets_evaluation_probes_small.txt",
        help="datasets can be a list of dataset names or a file (e.g., webdatasets.txt) containing dataset names.",
    )

    # Alternative combination files:
    #   ./scripts/configs/model_combinations_equally_spaced_dinov2.txt
    #   ./scripts/configs/model_combinations_all_tokens_last_layer_dinov2_B.txt
    cli.add_argument(
        "--models_combination",
        type=str,
        default="./scripts/configs/model_combinations_all_tokens_last_layer_clip_b_openai.txt",
        help="File containing model combinations to evaluate with an attentive probe.",
    )

    cli.add_argument(
        "--num_clusters",
        type=int,
        default=-1,
        help=num_clusters_help,
    )

    cli.add_argument(
        "--clustering_similarity_method",
        type=str,
        default="cka_kernel_linear_unbiased",
        help="Method to use for model similarity task during clustering.",
    )

    # Other candidate values: 4, 8, 24, 48.
    cli.add_argument(
        "--num_heads",
        nargs="+",
        type=int,
        default=[24],
        help="Number of attention heads for the attentive probe.",
    )

    return cli


parser = _build_parser()
args = parser.parse_args()

# Settings derived from the parsed command line.
MODELS_CONFIG = args.models_config
DATASETS = " ".join(parse_datasets(args.datasets))

if __name__ == "__main__":
    if os.path.isfile(args.models_combination):
        with open(args.models_combination, "r") as f:
            model_combinations = [[m.strip() for m in line.split(";")] for line in f if line.strip()]
    else:
        raise ValueError("The file does not exist", args.models_combination)

    models, n_models = load_models(MODELS_CONFIG)

    hyper_params, num_jobs = get_hyperparams(num_seeds=1, size="imagenet1k")

    val_proportion = 0.2

    dim_alignment = "zero_padding"

    for n_heads in args.num_heads:
        for proj_drop, att_drop in hyper_params["attention_dropout"]:
            for model_set in model_combinations:
                assert all([key.split("@")[0] in models.keys() for key in model_set])

                print(f"\nRunning attentive probe for {model_set}\n")
                model_keys = " ".join(model_set)

                if "_at@" in model_set[0]:
                    # dim = models[model_set[0].split("@")[0]]["embedding_dim"] * 2
                    dim = models[model_set[0].split("@")[0]]["embedding_dim"]
                    nr_models = models[model_set[0].split("@")[0]]["n_tokens"]
                    feature_combiner = "already_stacked_zero_pad"
                else:
                    # dim = max([models[mid.split("@")[0]]["embedding_dim"] for mid in model_set]) * 2
                    dim = max([models[mid.split("@")[0]]["embedding_dim"] for mid in model_set])
                    nr_models = len(model_set)
                    feature_combiner = "stacked_zero_pad"

                job_cmd = f"""python src/cli.py --dataset {DATASETS} \
                          --dataset_root {Path(DATASETS_ROOT).absolute()} \
                          --feature_root {Path(FEATURES_ROOT).absolute()} \
                          --model_root {Path(MODELS_ROOT).absolute()} \
                          --output_root {Path(RESULTS_ROOT).absolute()} \
                          --clustering_root {Path(CLUSTERING_ROOT).absolute()} \
                          --task=attentive_probe \
                          --mode=combined_models \
                          --feature_combiner {feature_combiner} \
                          --model_key {model_keys} \
                          --models_config_file {MODELS_CONFIG} \
                          --batch_size=2048 \
                          --dim {dim} \
                          --num_heads {n_heads} \
                          --dimension_alignment {dim_alignment} \
                          --fewshot_k {" ".join(hyper_params["fewshot_ks"])} \
                          --initial_lr {" ".join(hyper_params["initial_lrs"])} \
                          --epochs {" ".join(hyper_params["epochs"])} \
                          --reg_lambda {hyper_params["reg_lambda"]} \
                          --regularization {" ".join(hyper_params["regularization"])} \
                          --train_split train \
                          --test_split test \
                          --val_proportion {val_proportion} \
                          --seed {" ".join(hyper_params["seeds"])} \
                          --num_clusters {args.num_clusters} \
                          --clustering_similarity_method {args.clustering_similarity_method} \
                          --attention_dropout {proj_drop} {att_drop} \
                          --grad_norm_clip 5 \
                          --jitter_p 0.5
          """

                run_job(
                    job_name=f"attn_probe_with_n{nr_models}",
                    job_cmd=job_cmd,
                    partition="gpu-2d",
                    log_dir=f"{Path(RESULTS_ROOT).absolute()}/logs",
                    num_jobs_in_array=num_jobs,
                    mem=int(40 + (0.5 * nr_models)),
                )
