import argparse
import os
from pathlib import Path

from helper import get_hyperparams, parse_datasets
from project_location import DATASETS_ROOT, FEATURES_ROOT, MODELS_ROOT, RESULTS_ROOT
from slurm import run_job

from src.argparser import rep_loss_type

def _build_parser():
    """Assemble the command-line parser for the combined-models evaluation launcher.

    Returns:
        An ``argparse.ArgumentParser`` with every option of this script registered.
    """
    cli = argparse.ArgumentParser()

    # Models configuration file; its absolute path is forwarded to src/cli.py.
    cli.add_argument(
        "--models_config",
        type=str,
        default="./scripts/configs/models_config_wo_alignment.json",
    )

    # NOTE: the default is a single path string, whereas values supplied on the
    # command line arrive as a list (nargs="+"); both are handed to parse_datasets.
    cli.add_argument(
        "--downstream_datasets",
        type=str,
        nargs="+",
        default="./scripts/configs/webdatasets_w_in1k.txt",
        help="datasets can be a list of dataset names or a file (e.g., webdatasets.txt) containing dataset names.",
    )

    # Only forwarded to the job when --combination is "mvae_eval".
    cli.add_argument(
        "--pretrained_dataset",
        type=str,
        default="wds/imagenet1k",
        help="Dataset used as a pretraining for MVAE models. ",
    )

    # The help text is assembled from fragments separated by single spaces.
    rep_loss_help = " ".join(
        (
            "Loss function for rep2rep task. Options:",
            "(1) Single losses: 'mse', 'mae', 'cosine_distance', 'glocal_TEMPT' or 'glocal_TEMPT_TEMPS'",
            "(where TEMPT/TEMPS are floats - single value sets same temperature for both teacher and student,",
            "two values set teacher to first and student to second),",
            "'cka_linear' or 'cka_rbf_SIGMA' (where SIGMA is a float);",
            "(2) Combined losses: 'combinedALPHA__L1__L2'",
            "where ALPHA (between 0 and 1) determines the weight as ALPHA*L1 + (1-ALPHA)*L2.",
        )
    )
    cli.add_argument("--rep_loss", type=rep_loss_type, default="mse", help=rep_loss_help)

    # Text file listing the model sets to evaluate: one set per line, keys separated by ';'.
    cli.add_argument(
        "--models_combination",
        type=str,
        default="./scripts/configs/rep2rep_combinations.txt",
    )

    # Paired on/off switches that share the "normalize" destination (on by default).
    cli.add_argument(
        "--normalize",
        action="store_true",
        default=True,
        help="enable features normalization",
    )
    cli.add_argument(
        "--no-normalize",
        dest="normalize",
        action="store_false",
        help="disable features normalization",
    )

    cli.add_argument(
        "--combination",
        type=str,
        default="concat",
        choices=["ensemble", "concat", "concat_pca", "mvae_eval"],
        help="Model combination to use",
    )
    return cli


parser = _build_parser()
args = parser.parse_args()

# Values interpolated into every submitted job command.
MODELS_CONFIG = args.models_config
DATASETS = " ".join(parse_datasets(args.downstream_datasets))


def _parse_model_combinations(lines):
    """Parse model combinations from an iterable of text lines.

    Each line describes one combination as model keys separated by ``;``.
    Whitespace around keys is stripped, empty keys (e.g. produced by a trailing
    or doubled ``;``) are dropped, and lines without any key are skipped.

    Args:
        lines: Iterable of strings, one combination per line (an open text
            file works).

    Returns:
        A list of combinations, each a non-empty list of model-key strings.
    """
    combinations = []
    for line in lines:
        keys = [key.strip() for key in line.split(";") if key.strip()]
        if keys:
            combinations.append(keys)
    return combinations


if __name__ == "__main__":
    # Retrieve the model sets we intend to evaluate (one ';'-separated set per line).
    if not os.path.isfile(args.models_combination):
        raise ValueError(
            f"The models combination file does not exist: {args.models_combination}"
        )
    with open(args.models_combination, "r") as f:
        model_combinations = _parse_model_combinations(f)

    # Extracting hyperparameters for evaluation: learning rate, few-shot k samples, epoch numbers, and seeds.
    hyper_params, num_jobs = get_hyperparams(num_seeds=1, size="small")

    val_proportion = 0.2

    print("We evaluate the following hyperparameter", hyper_params)

    # Both concatenation variants run in "combined_models" mode; every other
    # combination name is used as the mode directly.
    mode = (
        "combined_models"
        if args.combination in ["concat", "concat_pca"]
        else args.combination
    )

    # Translate the combination into the combiner name passed to the CLI;
    # combinations without an entry are passed through unchanged.
    feature_combiner_map = {
        "mvae_eval": "tuple",
        "ensemble": "concat",
    }
    feature_combiner = feature_combiner_map.get(args.combination, args.combination)

    # Everything below is identical for all model sets, so compute it once.
    models_config_path = Path(MODELS_CONFIG).absolute()
    fewshot_ks = " ".join(hyper_params["fewshot_ks"])
    initial_lrs = " ".join(hyper_params["initial_lrs"])
    epochs = " ".join(hyper_params["epochs"])
    regularization = " ".join(hyper_params["regularization"])
    seeds = " ".join(hyper_params["seeds"])
    # NOTE(review): this script's own flag is spelled "--no-normalize", while
    # "--no_normalize" is forwarded here - confirm the spelling src/cli.py expects.
    normalize_flag = "--normalize" if args.normalize else "--no_normalize"
    # The pretraining dataset is only forwarded for "mvae_eval"; otherwise the
    # literal "FALSE" is sent.
    pretrained_dataset = (
        args.pretrained_dataset if args.combination == "mvae_eval" else "FALSE"
    )
    is_ensemble = args.combination == "ensemble"
    mem = 32 if is_ensemble else 128
    partition = "cpu-5h" if is_ensemble else "gpu-2d"

    # Run evaluation for each model set
    for model_set in model_combinations:
        model_keys = " ".join(model_set)
        print(f"Submitting Job with model_key: {model_keys}")
        # A backslash before the newline joins the lines inside the string
        # literal, so the options below end up on a single command line.
        job_cmd = f"""python src/cli.py \
            --dataset {DATASETS} \
            --dataset_root {DATASETS_ROOT} \
            --feature_root {FEATURES_ROOT} \
            --output_root {RESULTS_ROOT} \
            --models_config_file {models_config_path} \
            --task=linear_probe \
            --mode={mode} \
            --feature_combiner={feature_combiner} \
            --model_key {model_keys} \
            --batch_size=2048 \
            --fewshot_k {fewshot_ks} \
            --initial_lr {initial_lrs} \
            --epochs {epochs} \
            --reg_lambda {hyper_params["reg_lambda"]} \
            --regularization {regularization} \
            --train_split train \
            --test_split test \
            --val_proportion {val_proportion} \
            --seed {seeds} \
            --rep_loss {args.rep_loss} \
            {normalize_flag} \
            --skip_existing \
            --pretrained_dataset {pretrained_dataset}
            """

        run_job(
            job_name="combined_eval",
            job_cmd=job_cmd,
            partition=partition,
            log_dir=f"{RESULTS_ROOT}/logs",
            num_jobs_in_array=num_jobs,
            mem=mem,
        )
