r"""Script for running hyperparameter tuning of decoder models.

Example invocation:

python -m fmri2music.scripts.tune_hparams \
    --config_name "TEST_EXPERIMENT" \
    --log_path "data/logs/tune-test7.csv" \
    --num_threads 1 \
    --eval_emb_name "window10s-stride1_5s-mv101-avg" \
    --eval_emb_name "window10s-stride1_5s-mv109-avg" \
    --fma_size "small" \
    --metric_name "mean_identification_accuracy" \
    --all_subjects False

"""

import inspect
from typing import Callable

import argparse
import os

from dotenv import load_dotenv, find_dotenv

from fmri2music import (
    quant_eval,
    training,
    hparam_tuner,
    hparam_ranges,
    data_const,
    fmri_loader,
    utils,
)


def get_tuning_target_fn(
    tuning_name: str,
    eval_fma_size: str,
    eval_emb_names: list[str],
    subject_name: str,
    num_xval_splits: int,
) -> Callable[[hparam_tuner.HParams], hparam_tuner.TargetFn]:
    """Build the objective function handed to the hyperparameter tuner.

    Args:
        tuning_name: Prefix for every run name produced by the objective.
        eval_fma_size: FMA dataset size forwarded to ``QuantEvalConfig``.
        eval_emb_names: Embedding names forwarded to ``QuantEvalConfig``.
        subject_name: Subject whose data the predictor is trained on.
        num_xval_splits: Number of cross-validation splits for training.

    Returns:
        A closure that, for one hparams dict, trains a predictor, evaluates it
        on the training and validation clips, exports the predictions to
        ``<run name>.npz`` and returns a flat dict of metrics whose keys are
        prefixed with ``trn-`` or ``val-``.
    """

    def target_fn(hparams: hparam_tuner.HParams) -> hparam_tuner.TargetFn:
        """Train and evaluate one hparams configuration and return its metrics."""
        emb_name = hparams["emb_name"]
        hparams_hash = hparam_tuner.hash_hparams(hparams)
        run_name = (
            f"{tuning_name}_hparam_sweep_{hparams_hash}"
            f"_{subject_name}_fma-{eval_fma_size}"
        )
        model = training.train_predictor(
            run_name, emb_name, subject_name, num_xval_splits, hparams
        )

        eval_config = quant_eval.QuantEvalConfig(
            fma_size=eval_fma_size, eval_emb_names=eval_emb_names
        )
        trn_results = quant_eval.evaluate_model(
            model, fmri_loader.get_trn_clip_names(), eval_config
        )
        val_results = quant_eval.evaluate_model(
            model, fmri_loader.get_val_clip_names(), eval_config
        )

        # The clip-name lists are taken from the first result of each split.
        fmri_loader.export_predictions(
            file_name=f"{run_name}.npz",
            emb_name=model.emb_name,
            gtzan_slice_names=model.gtzan_keys,
            gtzan_preds=model.preds,
            gtzan_clip_names_trn=trn_results[0].gtzan_clip_names,
            fma_clip_names_trn=trn_results[0].fma_clip_names,
            gtzan_clip_names_val=val_results[0].gtzan_clip_names,
            fma_clip_names_val=val_results[0].fma_clip_names,
        )

        # Merge all per-result dicts; training entries first, then validation.
        metrics = {}
        for prefix, split_results in (("trn-", trn_results), ("val-", val_results)):
            for split_result in split_results:
                metrics.update(
                    utils.add_key_prefix(prefix, split_result.get_result_dict())
                )
        return metrics

    return target_fn


def get_experiment_ranges(experiment_name: str) -> hparam_tuner.HParamRanges:
    """Return the hyperparameter ranges named ``experiment_name``.

    The name is resolved as an attribute of ``fmri2music.hparam_ranges``.

    Args:
        experiment_name: Attribute name of the ranges in ``hparam_ranges``.

    Returns:
        The attribute found on ``hparam_ranges``.

    Raises:
        ValueError: If ``hparam_ranges`` has no such attribute. Before raising,
            the upper-case member names of ``hparam_ranges`` are printed as a
            hint of the available experiments.
    """
    try:
        return getattr(hparam_ranges, experiment_name)
    except AttributeError as err:
        print("Available experiments:")
        for name, _ in inspect.getmembers(hparam_ranges):
            if name.isupper():  # convention for constants in Python
                print(name)
        # Chain explicitly so the traceback shows the lookup failure as the
        # direct cause rather than as an error raised "during handling".
        raise ValueError(f"No experiment found for: {experiment_name}") from err


def main(args):
    """Run the hyperparameter grid search for one subject or for all subjects.

    One tuner run is performed per subject. Each run logs to its own CSV file,
    derived from ``args.log_path`` by appending ``_<subject_name>`` to the file
    stem (e.g. ``data/logs/tune.csv`` -> ``data/logs/tune_Subject01.csv``).

    Args:
        args: Parsed command-line namespace; reads ``log_path``,
            ``all_subjects``, ``subject_name``, ``config_name``, ``fma_size``,
            ``eval_emb_name``, ``num_xval_splits`` and ``num_threads``.
    """
    log_path = args.log_path
    log_dir = os.path.dirname(log_path)
    # A bare file name has no directory part, and os.makedirs("") raises.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    if args.all_subjects:
        subject_names = data_const.SUBJECTS
        print("Evaluating on all subjects.")
    else:
        subject_names = [args.subject_name]
        print(f"Evaluating only on subject {args.subject_name}.")

    # None of these depend on the subject, so compute them once.
    ranges = get_experiment_ranges(args.config_name)
    # splitext strips whatever extension is present instead of assuming the
    # path always ends in the four characters ".csv".
    log_path_wo_file_ext, _ = os.path.splitext(log_path)
    tuning_name = os.path.basename(log_path_wo_file_ext)
    # NOTE(review): ``--metric_name`` is parsed but never forwarded; the tuner
    # receives an empty target metric name. Confirm which key (reported
    # metrics carry a "trn-"/"val-" prefix) HyperparamTuner expects before
    # wiring ``args.metric_name`` through.
    target_metric_name = ""

    all_log_paths = []
    for subject_name in subject_names:
        log_path_w_suffix = f"{log_path_wo_file_ext}_{subject_name}.csv"
        target_fn = get_tuning_target_fn(
            tuning_name,
            args.fma_size,
            args.eval_emb_name,
            subject_name,
            args.num_xval_splits,
        )

        tuner = hparam_tuner.HyperparamTuner(
            ranges, target_fn, target_metric_name, log_path_w_suffix
        )
        best_hparam_hash, best_score = tuner.grid_search(args.num_threads)

        print(
            f"Hyperparameter tuning done for {subject_name}. "
            f"Best score: {best_score} for hparams: {best_hparam_hash}"
        )
        all_log_paths.append(log_path_w_suffix)

    print("Results written to:")
    print(*all_log_paths, sep="\n")

    print("Done for all subjects! Nice :-)")


def parse_bool_arg(value: str) -> bool:
    """Parse a boolean command-line value such as ``True``/``false``/``1``.

    ``argparse`` with ``type=bool`` treats every non-empty string as true
    (``bool("False") is True``), so ``--all_subjects False`` would enable all
    subjects. This converter interprets the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string stays False, matching the previous bool("") behavior.
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    load_dotenv(find_dotenv())

    parser = argparse.ArgumentParser(
        description="Train a regression model to predict GTZAN clip embeddings from fmri data."
    )

    parser.add_argument(
        "--config_name",
        type=str,
        required=True,
        help="Name of the config in fmri2music/hparam_ranges.py to use.",
    )

    parser.add_argument(
        "--log_path",
        type=str,
        required=True,
        help=(
            "Path to which hparam tuning logs will be written. If already present, "
            "the tuning will resume from the stored state."
        ),
    )

    parser.add_argument(
        "--num_threads", type=int, default=1, help="Number of threads to use."
    )

    parser.add_argument(
        "--eval_emb_name",
        action="append",
        required=True,
        help="Name of the embeddings to use for evaluation.",
    )

    parser.add_argument(
        "--fma_size",
        choices=["small", "large"],
        required=True,
        help="Size of the FMA dataset to evaluate on.",
    )

    parser.add_argument(
        "--metric_name",
        type=str,
        default="mean-identification-accuracy",
        help="Name of the metric to optimize.",
    )

    parser.add_argument(
        "--all_subjects",
        type=parse_bool_arg,
        # Accept both "--all_subjects True/False" and the bare flag.
        nargs="?",
        const=True,
        default=False,
        help=(
            "Whether to evaluate on all subjects. Accepts true/false; "
            "the bare flag means true."
        ),
    )

    parser.add_argument(
        "--subject_name",
        type=str,
        default="Subject01",
        help="Name of the subject to evaluate on. Used if --all_subjects is False.",
    )

    parser.add_argument(
        "--num_xval_splits",
        type=int,
        default=1,
        help="Number of cross validation splits to use.",
    )

    main(parser.parse_args())
