"""Script for training a regression model from fmri to music embeddings.

Example invocation:

python -m fmri2music.scripts.train_regressor \
    --model_name "gtzan-regression-window10s-stride1_5s-mv101-avg-s1-noxval" \
    --train_target_emb "window10s-stride1_5s-mv101-avg" \
    --eval_emb_name "window10s-stride1_5s-mv101-avg" \
    --fma_size "small" \
    --subject_name "Subject01" \
    --num_xval_splits 1

"""

import argparse

from dotenv import load_dotenv, find_dotenv

from fmri2music import quant_eval, training, data_const, fmri_loader, predictor


def _evaluate_split(model, clip_names, config, header):
    """Evaluate ``model`` on one data split and print the per-result reports.

    Args:
        model: Trained predictor returned by ``training.train_predictor``.
        clip_names: Clip names identifying the split to evaluate on.
        config: ``quant_eval.QuantEvalConfig`` with the evaluation settings.
        header: Line printed before the reports, e.g. ``"Train split:"``.

    Returns:
        List of every result produced by ``quant_eval.evaluate_model``.
    """
    # Materialize the (possibly lazy) results: they are printed here and
    # indexed again by the caller when exporting predictions.
    results = list(quant_eval.evaluate_model(model, clip_names, config))
    print(header)
    print(*[r.report() for r in results], sep="\n")
    return results


def main(args):
    """Train an fMRI-to-embedding predictor, evaluate it and export predictions.

    Steps:
      1. Train a predictor for ``args.subject_name`` targeting
         ``args.train_target_emb``.
      2. Evaluate it on the train and validation clip splits, printing a
         report for every evaluation result.
      3. Export the model's predictions together with the evaluated clip
         names under the file name ``f"{args.model_name}.npz"``.

    Args:
        args: ``argparse.Namespace`` holding ``model_name``, ``subject_name``,
            ``train_target_emb``, ``eval_emb_name`` (a list, collected with
            ``action="append"``), ``fma_size``, ``excluded_genre`` and
            ``num_xval_splits``.
    """
    # Fixed training hyperparameters; only the excluded genre is taken from
    # the command line.
    hparams = {
        "haemodynamic_resp_delay": 2,
        "nroi_used_for_ensemble": 6,
        "voxel_num_limit": 5,
        "excluded_genre": args.excluded_genre,
    }

    model: predictor.OnlinePredictor = training.train_predictor(
        args.model_name,
        args.train_target_emb,
        args.subject_name,
        args.num_xval_splits,
        hparams,
    )

    config = quant_eval.QuantEvalConfig(
        fma_size=args.fma_size, eval_emb_names=args.eval_emb_name
    )

    trn_eval_results = _evaluate_split(
        model, fmri_loader.get_trn_clip_names(), config, "Train split:"
    )
    val_eval_results = _evaluate_split(
        model, fmri_loader.get_val_clip_names(), config, "Validation split:"
    )

    # Clip-name lists are read from the first result of each split only;
    # presumably they are identical across all evaluated embeddings -- TODO
    # confirm against quant_eval.evaluate_model.
    fmri_loader.export_predictions(
        file_name=f"{args.model_name}.npz",
        emb_name=model.emb_name,
        gtzan_slice_names=model.gtzan_keys,
        gtzan_preds=model.preds,
        gtzan_clip_names_trn=trn_eval_results[0].gtzan_clip_names,
        fma_clip_names_trn=trn_eval_results[0].fma_clip_names,
        gtzan_clip_names_val=val_eval_results[0].gtzan_clip_names,
        fma_clip_names_val=val_eval_results[0].fma_clip_names,
    )


if __name__ == "__main__":
    # Load environment variables from the nearest .env file (if one is found)
    # before the command line is parsed and training starts.
    load_dotenv(find_dotenv())

    arg_parser = argparse.ArgumentParser(
        description=(
            "Train a regression model to predict GTZAN clip embeddings "
            "from fmri data."
        )
    )
    add_arg = arg_parser.add_argument

    # Options are registered in the order they should appear in --help.
    add_arg(
        "--model_name",
        help="Name of the trained model.",
        required=True,
        type=str,
    )
    add_arg(
        "--subject_name",
        help="Name of the subject to train the model on.",
        required=True,
        choices=data_const.SUBJECTS,
    )
    add_arg(
        "--train_target_emb",
        help="Training target embedding .npz file name.",
        required=True,
        type=str,
    )
    # action="append" collects every occurrence into a list, so the flag can
    # be repeated to evaluate against several embeddings.
    add_arg(
        "--eval_emb_name",
        help="Name of the embeddings to use for evaluation.",
        required=True,
        action="append",
    )
    add_arg(
        "--fma_size",
        help="Size of the FMA dataset to evaluate on.",
        required=True,
        choices=["small", "large"],
    )
    # NOTE(review): the "" default is not one of the allowed choices (argparse
    # does not check defaults against choices); presumably "" means "no genre
    # excluded" -- confirm against training.train_predictor.
    add_arg(
        "--excluded_genre",
        help=(
            "Specify to exclude a specific category for category "
            "generalization training"
        ),
        required=False,
        choices=data_const.GTZAN_GENRES,
        default="",
    )
    add_arg(
        "--num_xval_splits",
        help="Number of cross validation splits to use.",
        type=int,
        default=1,
    )

    cli_args = arg_parser.parse_args()
    main(cli_args)
