"""Script that retrieves FMA music clips for GTZAN clips using one or more embeddings.

Example invocation:

python -m fmri2music.scripts.retrieve_music \
    --gtzan_clip_name blues.00000_15s.wav \
    --gtzan_clip_name classical.00003_15s.wav \
    --gtzan_clip_name country.00006_15s.wav \
    --gtzan_clip_name disco.00000_15s.wav \
    --gtzan_clip_name hiphop.00002_15s.wav \
    --gtzan_clip_name jazz.00003_15s.wav \
    --gtzan_clip_name metal.00002_15s.wav \
    --gtzan_clip_name pop.00001_15s.wav \
    --gtzan_clip_name reggae.00001_15s.wav \
    --gtzan_clip_name rock.00001_15s.wav \
    --emb_name "window1_5s-stride1_5s-soundstream-avg" \
    --emb_name "window5s-stride1_5s-w2vbert-avg" \
    --emb_name "window10s-stride1_5s-mv101-avg" \
    --fma_size "small"

"""

import argparse

from dotenv import load_dotenv, find_dotenv

from fmri2music import emb_loader, quant_eval


def main(args):
    """Retrieve an FMA clip for every (GTZAN clip, embedding) combination.

    Embedding names are processed in sorted order. For each GTZAN clip and
    each embedding, the grouped GTZAN embeddings for that clip are loaded and
    handed to ``quant_eval.retrieve_fma_clip_name`` along with
    ``args.fma_size`` (no intersection with another embedding is requested).
    A line is printed for every retrieval. At the end, a table is printed as
    a list of rows: the header row is ``['GTZAN', <sorted emb names>...]``.
    Each following row is the GTZAN clip name followed by the retrieved FMA
    clip name for each embedding, in header order.

    Args:
        args: Parsed command-line namespace providing ``gtzan_clip_name``
            (list of clip names), ``emb_name`` (list of embedding names) and
            ``fma_size``.
    """
    ordered_embs = sorted(args.emb_name)
    table = [["GTZAN", *ordered_embs]]

    for clip in args.gtzan_clip_name:
        # Keyed by embedding name, so a repeated --emb_name keeps only the
        # last retrieval for that name when the row is assembled below.
        retrieved_by_emb = {}
        for name in ordered_embs:
            loaded = emb_loader.get_grouped_gtzan_embs_for_clip_name(name, clip)
            grouped_embs, _unused = loaded
            retrieved = quant_eval.retrieve_fma_clip_name(
                args.fma_size,
                name,
                grouped_embs,
                intersect_w_emb_name=None,
                min_num_slices_intersection=None,
            )
            print(f"Emb: {name}; GTZAN: {clip}; retrieved FMA: {retrieved}")
            retrieved_by_emb[name] = retrieved
        table.append([clip, *(retrieved_by_emb[name] for name in ordered_embs)])

    print(table)


if __name__ == "__main__":
    load_dotenv(find_dotenv())

    parser = argparse.ArgumentParser(
        description="Train a regression model to predict GTZAN clip embeddings from fmri data."
    )

    parser.add_argument(
        "--gtzan_clip_name",
        action="append",
        required=True,
        help="Name of the GTZAN clips to retrieve FMA clips for.",
    )

    parser.add_argument(
        "--emb_name",
        action="append",
        required=True,
        help="Name of the embeddings to use for retrieval.",
    )

    parser.add_argument(
        "--fma_size",
        choices=["small", "large"],
        required=True,
        help="Size of the FMA dataset to evaluate on.",
    )

    main(parser.parse_args())
