"""Script for running the quantitative evaluation.

Example invocation:

python -m fmri2music.scripts.run_quant_eval \
    --fma_size=small \
    --eval_emb_name "window15s-stride15s-mv101-avg" \
    --eval_emb_name "window15s-stride15s-mv109-avg" \
    --model=model1:gtzan-regression-window10s-stride1_5s-mv101-avg-s1-noxval.npz


Evaluate similarity measures on different embeddings:

python -m fmri2music.scripts.run_quant_eval \
    --fma_size=large \
    --eval_emb_name "window1_5s-stride1_5s-soundstream-avg" \
    --eval_emb_name "window10s-stride1_5s-mv101-avg" \
    --eval_emb_name "window5s-stride1_5s-w2vbert-avg" \
    --use_gtzan_trn_split True

"""

import argparse

from dotenv import load_dotenv, find_dotenv

from fmri2music import fmri_loader, quant_eval


def parse_model_specs(specs):
    """Parse ``name:path`` model specifications into a ``{name: path}`` dict.

    Only the first ``:`` separates the name from the path, so paths that
    themselves contain ``:`` (e.g. ``C:/preds.npz``) are preserved intact.
    A later spec with a name already seen overwrites the earlier one, as
    with ``dict`` construction.

    Args:
        specs: Iterable of strings in the form ``name:path``.

    Returns:
        Dict mapping model name to prediction file path.

    Raises:
        ValueError: If a spec has no ``:`` or an empty name or path.
    """
    models = {}
    for spec in specs:
        name, sep, path = spec.partition(":")
        if not sep or not name or not path:
            raise ValueError(
                f"Invalid --model value {spec!r}; expected the form name:path."
            )
        models[name] = path
    return models


def main(args):
    """Main entrypoint of the script.

    Builds the evaluation config from the parsed CLI ``args`` and selects the
    GTZAN clip names (training split if ``args.use_gtzan_trn_split`` is set,
    otherwise the validation split). It then evaluates all models, printing
    every result's ``get_result_dict()`` followed by every result's
    ``report()``.
    """
    models = parse_model_specs(args.model)
    config = quant_eval.QuantEvalConfig(
        fma_size=args.fma_size, eval_emb_names=args.eval_emb_name
    )
    if args.use_gtzan_trn_split:
        gtzan_clip_names = fmri_loader.get_trn_clip_names()
    else:
        gtzan_clip_names = fmri_loader.get_val_clip_names()

    # Run the (potentially expensive) evaluation exactly once and reuse the
    # results for both output passes, instead of re-evaluating per pass.
    results = list(quant_eval.evaluate_models(models, gtzan_clip_names, config))

    for result in results:
        print(result.get_result_dict())

    for result in results:
        print(result.report())


def parse_bool(value):
    """Parse a command-line boolean string for use as an argparse ``type``.

    ``argparse`` with ``type=bool`` treats any non-empty string (including
    ``"False"``) as True, so an explicit parser is required.

    Args:
        value: A string such as ``"True"``, ``"false"``, ``"yes"``, ``"0"``;
            an actual ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    load_dotenv(find_dotenv())

    parser = argparse.ArgumentParser(description="Run quantitative evaluation.")

    parser.add_argument(
        "--model",
        action="append",
        required=False,
        default=[],
        help=(
            "Model paths in the format --model=name:path where path is an .npz "
            "file with the exported prediction on validation data."
        ),
    )

    parser.add_argument(
        "--fma_size",
        choices=["small", "large"],
        required=True,
        help="Size of the FMA dataset to evaluate on.",
    )

    parser.add_argument(
        "--eval_emb_name",
        action="append",
        required=True,
        help="Name of the embeddings to use for evaluation.",
    )

    parser.add_argument(
        "--use_gtzan_trn_split",
        # A bare flag means True; an explicit value (True/False/yes/no/1/0)
        # is parsed properly instead of via bool(), which treats "False" as True.
        type=parse_bool,
        nargs="?",
        const=True,
        default=False,
        help=(
            "Whether to use the GTZAN training split for evaluation "
            "(accepts True/False; a bare flag means True)."
        ),
    )

    main(parser.parse_args())
