import argparse
import os
from pathlib import Path

from scripts.helper import get_hyperparams, load_models, parse_datasets
from scripts.project_location import CLUSTERING_ROOT, DATASETS_ROOT, FEATURES_ROOT, MODELS_ROOT, RESULTS_ROOT
from scripts.slurm import run_job

# Command-line interface of the launcher. The options are parsed at module level so
# that `args` and `MODELS_CONFIG` are available to the code further down the file.
parser = argparse.ArgumentParser()

# Path of the models configuration file.
parser.add_argument(
    "--models_config",
    default="./scripts/configs/models_config_single_model_layer_combination.json",
    type=str,
)

# Dataset selection: one or more names, or a file that lists them.
parser.add_argument(
    "--datasets",
    nargs="+",
    type=str,
    help="datasets can be a list of dataset names or a file (e.g., webdatasets.txt) containing dataset names.",
    # Alternative default kept for reference:
    # default="./scripts/configs/webdataset_configs/webdatasets_part_one_experiments.txt",
    default=["wds/vtab/pcam"],
)

# Files with the model combinations to probe. Every entry of the default list is
# currently commented out, so nothing is selected unless files are passed explicitly.
parser.add_argument(
    "--models_combination",
    help="File containing model combinations to evaluate with an attentive probe.",
    type=str,
    nargs="+",
    default=[
        # "./scripts/configs/model_combinations_last_layer_cls_ap.txt",
        # "./scripts/configs/model_combinations_layers_all_blocks_cls_ap_B.txt",
        # "./scripts/configs/model_combinations_all_tokens_last_layer_B.txt",
        # "./scripts/configs/model_combinations_layers_all_blocks_cls_ap_S_L.txt",
        # "./scripts/configs/model_combinations_all_tokens_last_layer_S_L.txt",
        # "./scripts/configs/model_combinations_layers_all_blocks_cls_ap_rerun.txt",
        # "./scripts/configs/model_combinations_layers_all_blocks_cls_ap_rerun_2.txt",
    ],
)

args = parser.parse_args()

# Shorthand for the models configuration path chosen on the command line.
MODELS_CONFIG = args.models_config

# Datasets for which probe jobs are submitted.
# NOTE(review): `--datasets` is parsed in main() but does not select the datasets
# below; jobs always target these two ImageNet variants. Confirm whether that is
# intended before wiring the CLI option through.
PROBE_DATASETS = ("imagenet-subset-50k", "wds/imagenet1k")

# Upper bound on the model count used for the memory estimate and the job name.
MAX_MODELS_FOR_MEMORY = 40


def parse_model_combinations(lines):
    """Parse model combinations from an iterable of text lines.

    Each line holds one combination whose model keys are separated by ``;``.
    Surrounding whitespace is stripped from every key. Empty keys (for example
    caused by a trailing ``;``) and lines without any key are dropped.

    Args:
        lines: Iterable of strings, e.g. an open text file.

    Returns:
        A list of combinations, each a non-empty list of model keys.
    """
    combinations = []
    for line in lines:
        keys = [key.strip() for key in line.split(";")]
        keys = [key for key in keys if key]
        if keys:
            combinations.append(keys)
    return combinations


def unknown_model_keys(model_set, models):
    """Return the keys of ``model_set`` whose model id is not in ``models``.

    The model id is the part of a key before the first ``@``.
    """
    return [key for key in model_set if key.split("@")[0] not in models]


def job_memory_gb(nr_models, base_gb=40, per_model_gb=10, max_gb=300):
    """Return the memory request in GB: ``base_gb + nr_models * per_model_gb``, capped at ``max_gb``."""
    return min(int(base_gb + nr_models * per_model_gb), max_gb)


def main():
    """Submit one attentive-probe job per model combination, dropout pair and dataset.

    All combination files and model keys are validated before the first job is
    submitted, so a bad input cannot leave a partially submitted sweep behind.

    Raises:
        ValueError: If a combinations file does not exist or a combination
            references a model id missing from the models configuration.
    """
    if not isinstance(args.models_combination, list):
        args.models_combination = [args.models_combination]

    prepared_datasets = sorted(set(parse_datasets(args.datasets)))
    ignored_datasets = [name for name in prepared_datasets if name not in PROBE_DATASETS]
    if ignored_datasets:
        # Make the mismatch between --datasets and the submitted jobs visible.
        print(f"Note: ignoring --datasets entries {ignored_datasets}; jobs are submitted for {list(PROBE_DATASETS)}.")

    # Collect the combinations of all files up front, preserving file and line order.
    model_combinations = []
    for combinations_file in args.models_combination:
        if not os.path.isfile(combinations_file):
            raise ValueError(f"The file does not exist: {combinations_file}")
        with open(combinations_file, "r", encoding="utf-8") as f:
            model_combinations.extend(parse_model_combinations(f))

    if not model_combinations:
        print("No model combinations to evaluate; pass one or more files via --models_combination.")
        return

    # Loaded once: neither depends on the combinations file being processed.
    models, _ = load_models(MODELS_CONFIG)
    hyper_params, num_jobs = get_hyperparams(num_seeds=1, size="imagenet1k")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    for model_set in model_combinations:
        missing = unknown_model_keys(model_set, models)
        if missing:
            raise ValueError(f"Unknown model keys {missing} in combination {model_set}; not found in {MODELS_CONFIG}")

    val_proportion = 0.2
    dim_alignment = "zero_padding"

    # Values that are identical for every submitted job.
    dataset_root = Path(DATASETS_ROOT).absolute()
    feature_root = Path(FEATURES_ROOT).absolute()
    model_root = Path(MODELS_ROOT).absolute()
    output_root = Path(RESULTS_ROOT).absolute()
    clustering_root = Path(CLUSTERING_ROOT).absolute()
    fewshot_ks = " ".join(map(str, hyper_params["fewshot_ks"]))
    initial_lrs = " ".join(map(str, hyper_params["initial_lrs"]))
    epochs = " ".join(map(str, hyper_params["epochs"]))
    regularization = " ".join(map(str, hyper_params["regularization"]))
    seeds = " ".join(map(str, hyper_params["seeds"]))

    for model_set in model_combinations:
        model_keys = " ".join(model_set)

        if "_at@" in model_set[0]:
            # Branch for keys containing "_at@": sizes come from the first model's config entry.
            first_model = model_set[0].split("@")[0]
            dim = models[first_model]["embedding_dim"]
            nr_models = models[first_model]["set_length"]
            n_heads = 8  # TODO: what is the best number of heads here???
            feature_combiner = "already_stacked_zero_pad"
        else:
            # Several models are combined: use the largest embedding dim and one head per model.
            dim = max(models[mid.split("@")[0]]["embedding_dim"] for mid in model_set)
            nr_models = len(model_set)
            n_heads = len(model_set)
            feature_combiner = "stacked_zero_pad"

        # The capped count is also what ends up in the job name.
        nr_models = min(nr_models, MAX_MODELS_FOR_MEMORY)
        mem = job_memory_gb(nr_models)

        for proj_drop, att_drop in hyper_params["attention_dropout"]:
            for ds in PROBE_DATASETS:
                print(f"\nRunning attentive probe for:\n{model_set=},\n{proj_drop=} dropout,\n{att_drop=} dropout,\n{n_heads=} heads,\n{ds=} datasets,\nand {mem}GB memory\n")

                # Backslash-newline inside the literal joins everything into one command line.
                job_cmd = f"""python src/cli.py --dataset {ds} \
                                --dataset_root {dataset_root} \
                                --feature_root {feature_root} \
                                --model_root {model_root} \
                                --output_root {output_root} \
                                --clustering_root {clustering_root} \
                                --task=attentive_probe \
                                --mode=combined_models \
                                --feature_combiner {feature_combiner} \
                                --model_key {model_keys} \
                                --models_config_file {MODELS_CONFIG} \
                                --batch_size=2048 \
                                --dim {dim} \
                                --num_heads {n_heads} \
                                --dimension_alignment {dim_alignment} \
                                --fewshot_k {fewshot_ks} \
                                --initial_lr {initial_lrs} \
                                --epochs {epochs} \
                                --reg_lambda {hyper_params["reg_lambda"]} \
                                --regularization {regularization} \
                                --train_split train \
                                --test_split test \
                                --val_proportion {val_proportion} \
                                --seed {seeds} \
                                --attention_dropout {proj_drop} {att_drop} \
                                --grad_norm_clip 5 \
                                --jitter_p 0.5 \
                                --skip_existing
                """

                run_job(
                    job_name=f"attn_probe_with_n{nr_models}",
                    job_cmd=job_cmd,
                    partition="gpu-2d",
                    log_dir=f"{output_root}/logs",
                    num_jobs_in_array=num_jobs,
                    mem=mem,
                )


if __name__ == "__main__":
    main()
