import argparse
import shlex
from pathlib import Path

from scripts.helper import load_models, parse_datasets
from scripts.test_scripts.test_project_location import DATASETS_ROOT, FEATURES_ROOT, RESULTS_ROOT
from scripts.slurm import run_job

# Command-line options: which models config to load and which datasets to
# extract features for. Parsed at import time so MODELS_CONFIG / DATASETS stay
# available as module-level constants.
parser = argparse.ArgumentParser(
    description="Submit one feature-extraction job per model in the models config.",
)
parser.add_argument(
    '--models_config',
    type=str,
    default='./scripts/test_scripts/test_models_config_single_model_layer_combination.json',
    help="Path to the models config file (read by load_models and forwarded to src/cli.py).",
)
parser.add_argument(
    '--datasets',
    type=str,
    nargs='+',
    # With nargs='+' argparse yields a list whenever the flag is given, so the
    # default is a one-element list as well. parse_datasets then receives the
    # same type whether or not the flag was supplied.
    default=['./scripts/test_scripts/test_webdatasets.txt'],
    help="datasets can be a list of dataset names or a file (e.g., webdatasets.txt) containing dataset names.",
)
args = parser.parse_args()

MODELS_CONFIG = args.models_config

# Space-separated dataset names, spliced verbatim after `--dataset` in each job
# command so the CLI receives them as separate arguments.
DATASETS = " ".join(parse_datasets(args.datasets))

if __name__ == "__main__":
    # Retrieve the configuration of all models we intend to evaluate.
    # The model count returned alongside the configs is not needed here.
    models, _ = load_models(MODELS_CONFIG)

    # These locations are identical for every job, so resolve them once.
    # shlex.quote returns ordinary paths unchanged, but keeps paths containing
    # spaces or shell metacharacters intact as a single argument.
    dataset_root = shlex.quote(str(Path(DATASETS_ROOT).absolute()))
    output_root = shlex.quote(str(Path(RESULTS_ROOT).absolute()))
    feature_root = shlex.quote(str(Path(FEATURES_ROOT).absolute()))
    models_config_file = shlex.quote(str(Path(MODELS_CONFIG).absolute()))
    log_dir = f'{Path(FEATURES_ROOT).absolute()}/logs'

    # Extract features for all models and datasets: one job per model key.
    for key in models:
        print(f"Running feature extraction for {key}")
        job_cmd = " ".join([
            "python src/cli.py",
            # DATASETS is already a space-separated list of names and is meant
            # to expand into several arguments, so it is not quoted.
            f"--dataset {DATASETS}",
            f"--dataset_root {dataset_root}",
            f"--output_root {output_root}",
            f"--feature_root {feature_root}",
            "--task=feature_extraction",
            f"--model_key {shlex.quote(key)}",
            f"--models_config_file {models_config_file}",
            "--batch_size=64",
            "--train_split train",
            "--test_split test",
            "--num_workers=0",
        ])

        run_job(
            job_name=f"feat_extr_{key}",
            job_cmd=job_cmd,
            partition='gpu-5h',
            log_dir=log_dir,
            num_jobs_in_array=1,
            mem=64,
        )
