"""Script for running an eval on embeddings of MusicLM's predictions.

Example invocation:

python -m fmri2music.scripts.eval_musiclm_pred \
    --log_path "data/pred/tmp/eval-results.csv" \
    --audioset_probs_file "rc3/fma_retrieval_large_SubjectXX_window10s-stride1_5s-mv101-avg-audioset-15s.npz" \
    --pred_emb_file "pred/rc3/fma_retrieval_large_SubjectXX_window10s-stride1_5s-mv101-avg-emb-window1_5s-stride1_5s-soundstream-avg.npz" \
    --pred_emb_file "pred/rc3/fma_retrieval_large_SubjectXX_window10s-stride1_5s-mv101-avg-emb-window10s-stride1_5s-mv101-avg.npz" \
    --pred_emb_file "pred/rc3/fma_retrieval_large_SubjectXX_window10s-stride1_5s-mv101-avg-gen-emb-window5s-stride1.5s-w2vbert-avg.npz"

    
Upper bound eval:

python -m fmri2music.scripts.eval_musiclm_pred \
    --log_path "data/pred/gtzan-upper-bound.csv" \
    --audioset_probs_file "music-emb/gtzan-audioset.npz" \
    --pred_emb_file "pred/gtzan-emb-window1_5s-stride1_5s-soundstream-avg.npz" \
    --pred_emb_file "pred/gtzan-emb-window10s-stride1_5s-mv101-avg.npz" \
    --pred_emb_file "pred/gtzan-gen-emb-window5s-stride1.5s-w2vbert-avg.npz"

"""

import argparse
from collections import defaultdict
import json
import os
import re

from dotenv import load_dotenv, find_dotenv
import numpy as np

from fmri2music import emb_loader, quant_eval, utils


def emb_name_from_file_name(file_name: str) -> str:
    """Derive the embedding name encoded in a prediction file name.

    The directory part and the final extension are dropped, then the text
    between the first "-emb-" marker and the next one (or the end of the name)
    is taken. Every "1.5" in it is rewritten as "1_5".

    Examples:
    * '.../gtzanclassic-emb-window10s-stride1_5s-mv101-avg.npz' -> 'window10s-stride1_5s-mv101-avg'
    * '.../gtzan-gen-emb-window5s-stride1.5s-w2vbert-avg.npz' -> 'window5s-stride1_5s-w2vbert-avg'

    Raises:
        IndexError: If the base name contains no "-emb-" marker.
    """
    base_name = os.path.basename(file_name)
    stem, _extension = os.path.splitext(base_name)
    pieces = stem.split("-emb-")
    return pieces[1].replace("1.5", "1_5")


def subject_id_from_key(key: str) -> int:
    """Extract the numeric subject ID from a key containing "Subject<digits>".

    The first occurrence of "Subject" that is directly followed by digits is
    used. Leading zeros are ignored, so "Subject01" yields 1 and "Subject10"
    yields 10; the ID does not need to be followed by an underscore.

    Args:
        key: Key (typically a file path plus slice name) embedding the subject.

    Returns:
        The subject ID as an integer.

    Raises:
        ValueError: If the key contains no "Subject<digits>" marker.
    """
    match = re.search(r"Subject(\d+)", key)
    if match is None:
        raise ValueError(f"No subject ID ('Subject<digits>') found in key: {key!r}")
    return int(match.group(1))


def split_by_subjects(preds: dict[str, np.ndarray]) -> dict[int, dict[str, np.ndarray]]:
    """Group predictions by the subject ID embedded in their keys.

    Each key is passed to `subject_id_from_key`; entries sharing an ID end up
    in the same inner dict, keeping their original keys and values.

    Returns:
        A plain dict mapping subject ID to that subject's predictions.
    """
    grouped: dict[int, dict[str, np.ndarray]] = {}
    for key, value in preds.items():
        grouped.setdefault(subject_id_from_key(key), {})[key] = value
    return grouped


def filter_to_first_sample(preds: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Return only the entries whose key contains "_15s_0.wav".

    Presumably the "_0" suffix marks the first generated sample of a clip
    (as the function name suggests); entries are kept in their input order.
    """
    return {name: emb for name, emb in preds.items() if "_15s_0.wav" in name}


def key_to_slice_name(key: str) -> str:
    """Turn a full key (file path plus extra info) into a slice name.

    The part following the first ".npz/" is taken and the sample suffix
    "_15s_0.wav" is rewritten as "_15s.wav". Names ending in "_15s.wav" are
    returned as they are; anything else goes through
    `utils.normalize_slice_name`.

    Raises:
        IndexError: If the key contains no ".npz/" separator.
    """
    slice_name = key.split(".npz/")[1].replace("_15s_0.wav", "_15s.wav")
    if not slice_name.endswith("_15s.wav"):
        return utils.normalize_slice_name(slice_name)
    return slice_name


def main(args):
    """Evaluate each prediction embedding file per subject and log the metrics.

    Every file in ``args.pred_emb_file`` is loaded, split by subject, reduced
    to the keys containing "_15s_0.wav" and re-keyed by slice name. Each
    (file, subject) pair is then evaluated with
    ``quant_eval.evaluate_musiclm_model`` against that subject's AudioSet
    probabilities from ``args.audioset_probs_file``. The metrics of each
    evaluation are written to ``args.log_path`` as one JSON object per line.

    Args:
        args: Parsed command-line arguments providing ``log_path``,
            ``pred_emb_file`` (a list of file names) and
            ``audioset_probs_file``.
    """
    log_path = args.log_path
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    pred_emb_files = args.pred_emb_file

    name_subject_preds: list[tuple[str, int, dict[str, np.ndarray]]] = []
    for file in pred_emb_files:
        all_preds = emb_loader.get_musiclm_pred_emb(file)
        for subject_id, preds in split_by_subjects(all_preds).items():
            preds = filter_to_first_sample(preds)
            preds = {key_to_slice_name(k): v for k, v in preds.items()}
            # The subject suffix makes the name unique per (file, subject);
            # emb_name_from_file_name() drops it together with the extension.
            name_subject_preds.append(
                (file + f"Subject{subject_id}", subject_id, preds)
            )

    audioset_probs = emb_loader.get_audioset_probs_from_file(args.audioset_probs_file)
    audioset_probs = filter_to_first_sample(audioset_probs)
    audioset_probs_per_subject = split_by_subjects(audioset_probs)
    audioset_probs_per_subject = {
        subject_id: {key_to_slice_name(k): v for k, v in probs.items()}
        for subject_id, probs in audioset_probs_per_subject.items()
    }

    # Collect one JSON line per evaluation and join once at the end.
    output_lines: list[str] = []

    for file, subject_id, slice_name_to_emb in name_subject_preds:
        emb_name = emb_name_from_file_name(file)
        print(
            f"Loaded {file} with shape {list(slice_name_to_emb.values())[0].shape}; {emb_name=}"
        )
        slice_name_to_emb = {
            utils.normalize_slice_name(k): v for k, v in slice_name_to_emb.items()
        }
        val_result = quant_eval.evaluate_musiclm_model(
            file,
            eval_emb_name=emb_name,
            slice_name_to_gen_emb=slice_name_to_emb,
            clip_name_to_audioset_probs=audioset_probs_per_subject[subject_id],
        )

        reported_metrics = utils.add_key_prefix("val-", val_result.get_result_dict())
        output_lines.append(json.dumps(reported_metrics) + "\n")

    # Write the collected results to the log file.
    with open(log_path, "w", encoding="utf-8") as f:
        f.write("".join(output_lines))

    print(f"Results written to: {log_path}")
    print("Done!")


if __name__ == "__main__":
    # Load variables from a .env file located by find_dotenv() before parsing arguments.
    load_dotenv(find_dotenv())

    arg_parser = argparse.ArgumentParser(
        description="Script for running an eval on embeddings of MusicLM's predictions."
    )

    # Destination of the JSON-lines results.
    arg_parser.add_argument(
        "--log_path", type=str, required=True, help="Path to which results will be written."
    )

    # May be given several times; each occurrence is appended to a list.
    arg_parser.add_argument(
        "--pred_emb_file",
        action="append",
        required=True,
        help="Name of the predictions embeddings to use for evaluation.",
    )

    arg_parser.add_argument(
        "--audioset_probs_file",
        type=str,
        required=True,
        help="Name of the file containing the AudioSet probabilities.",
    )

    cli_args = arg_parser.parse_args()
    main(cli_args)
