#!/usr/bin/python3
"""
Script to analyze experimental results.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import click
import os
import numpy as np
import scipy.stats as stats
from enum import Enum
from gymnasium import registry
from importlib import import_module
from pathlib import Path
from typing import Dict, Tuple, Union

import leon
from leon.optim import get_optimizers


class Metrics(str, Enum):
    """
    Quantitative metrics reported by this analysis script.

    Each member's value is the key under which `parse_results` stores the
    corresponding quantity in the dictionary it returns.
    """

    # Extremes of the scores selected under the oracle budget.
    MAX = "max_score"
    MIN = "min_score"
    # Location and spread of the selected scores.
    MEAN = "mean_score"
    STD = "std_score"
    MEDIAN = "median_score"
    PERCENTILE_90 = "90th_percentile_score"
    # Token counters read from the results file (-1 when absent there).
    NUM_INPUT_TOKENS = "num_input_tokens"
    NUM_OUTPUT_TOKENS = "num_output_tokens"


@click.command()
@click.option(
    "--task",
    "-t",
    "task_name",
    type=click.Choice(list(leon.registry.keys())),
    required=True,
    multiple=True,
    help="Biomedical zero-shot optimization task."
)
@click.option(
    "--optimizer",
    "-o",
    type=click.Choice(get_optimizers()),
    required=True,
    help="Backbone optimizer."
)
@click.option(
    "--suffix",
    "-s",
    type=str,
    default="",
    show_default=True,
    help="Suffix to add to the results directory."
)
@click.option(
    "--results-dir",
    type=click.Path(
        exists=True,
        file_okay=False,
        dir_okay=True,
        readable=True,
        resolve_path=True,
        path_type=str
    ),
    default="results",
    show_default=True,
    help="Local directory where the experimental results are saved."
)
@click.option(
    "--budget",
    "-k",
    type=int,
    required=True,
    help="Oracle evaluation budget."
)
@click.option(
    "--n-surrogate-calls",
    "-n",
    type=int,
    default=-1,
    show_default=True,
    help="Number of allowed surrogate calls (all by default)."
)
@click.option(
    "--metrics",
    "-m",
    # Plain string values (not Enum members): newer click versions match
    # Enum choices by member *name*, which would change the accepted CLI
    # spelling from e.g. `max_score` to `MAX`.
    type=click.Choice([m.value for m in Metrics] + ["all"]),
    default=("all",),
    show_default=True,
    multiple=True,
    help="Quantitative metrics to compute."
)
@click.option(
    "--unnormalize/--normalize",
    default=False,
    help="Whether to unnormalize the metric values."
)
@click.option(
    "--verbose/--quiet",
    default=True,
    help="Whether to display outputs to `stdout`."
)
def main(
    task_name: Tuple[str, ...],
    optimizer: str,
    suffix: str,
    results_dir: Union[Path, str],
    budget: int,
    n_surrogate_calls: int,
    metrics: Tuple[str, ...],
    unnormalize: bool,
    verbose: bool
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Experimental results analysis script.

    For every requested task, reads all `.npz` result files stored under
    `<results_dir>/<optimizer name>[-<suffix>]/<task>`, parses each one with
    `parse_results`, and collects one value per file for every requested
    metric. Tasks without a results directory are skipped.
    Input:
        task_name: the optimization task(s) to analyze.
        optimizer: the backbone optimizer identifier in `leon.optim`.
        suffix: optional suffix appended to the optimizer results directory.
        results_dir: the root directory of the experimental results.
        budget: the oracle evaluation budget.
        n_surrogate_calls: the number of allowed surrogate calls.
        metrics: the metric names to compute, or "all".
        unnormalize: whether to unnormalize the metric values.
        verbose: whether to print the mean and standard error of each metric.
    Output:
        A mapping from task name to a mapping from metric name (a plain
        string such as "max_score") to a float array with one entry per
        results file.
    """
    optimizer = getattr(leon.optim, optimizer).optimizer_name
    if suffix:
        optimizer += f"-{suffix}"

    # Resolve the requested metrics once, in the canonical enum order. Plain
    # string keys keep the printed labels stable across Python versions
    # (f-strings render `(str, Enum)` members as `Metrics.MAX` on 3.11+).
    selected = tuple(
        m.value for m in Metrics if "all" in metrics or m.value in metrics
    )

    all_results = {}
    for task in task_name:
        savedir: str = os.path.join(str(results_dir), optimizer, task)
        if not os.path.isdir(savedir):
            continue
        fns = sorted(
            os.path.join(savedir, fn)
            for fn in os.listdir(savedir)
            if fn.lower().endswith(".npz")
        )

        # Accumulate in lists and convert once, instead of re-stacking an
        # array for every file.
        collected: Dict[str, list] = {m: [] for m in selected}
        for fn in fns:
            data = parse_results(
                fn, task, budget, n_surrogate_calls, unnormalize
            )
            for m in selected:
                collected[m].append(data[m])

        metric_vals: Dict[str, np.ndarray] = {
            m: np.asarray(vals, dtype=float) for m, vals in collected.items()
        }

        if verbose:
            print(f"Task: {task}")
            print(f"  N = {len(fns)}")
            for m, res in metric_vals.items():
                # Ignore NaN entries when summarizing.
                y_adj = res[~np.isnan(res)]
                print(f"  {m}: {y_adj.mean():.12f},{stats.sem(y_adj):.12f}")

        all_results[task] = metric_vals
    return all_results


def parse_results(
    fn: str,
    task_name: str,
    budget: int,
    n_surrogate_calls: int,
    unnormalize: bool
) -> Dict[str, float]:
    """
    Parse the experimental results from a single file.
    Input:
        fn: the `.npz` file to parse experimental results from. It must
            contain `predictions` and `scores` arrays of equal length and may
            optionally contain `num_input_tokens` and `num_output_tokens`.
        task_name: the name of the optimization task.
        budget: the oracle evaluation budget, i.e., how many candidates with
            the highest predictions have their scores kept. Must satisfy
            0 < budget <= number of considered candidates.
        n_surrogate_calls: the number of allowed surrogate calls. Only the
            first `n_surrogate_calls` candidates are considered; any
            non-positive value means all candidates are considered.
        unnormalize: whether to unnormalize the metric values.
    Output:
        A dictionary containing the experimental results. Token counts are
        -1 when they are absent from the file.
    Raises:
        ValueError: if `predictions` and `scores` differ in length, or if
            `n_surrogate_calls` / `budget` are out of range.
    """
    # Use the NpzFile as a context manager so the archive's file handle is
    # released; every array needed below is materialized inside the block.
    with np.load(fn) as data:
        # `atleast_1d` keeps single-candidate files indexable after squeeze.
        predictions = np.atleast_1d(data["predictions"].squeeze())
        scores = np.atleast_1d(data["scores"].squeeze())
        num_input_tokens = data.get("num_input_tokens", -1)
        num_output_tokens = data.get("num_output_tokens", -1)

    if len(scores) != len(predictions):
        raise ValueError(
            f"{fn}: got {len(predictions)} predictions but {len(scores)} "
            "scores."
        )

    if n_surrogate_calls > 0:
        if n_surrogate_calls > len(predictions):
            raise ValueError(
                f"{fn}: n_surrogate_calls={n_surrogate_calls} exceeds the "
                f"{len(predictions)} available predictions."
            )
        predictions = predictions[:n_surrogate_calls]
        scores = scores[:n_surrogate_calls]

    # A budget of 0 would turn `[-budget:]` into `[0:]` and silently select
    # every candidate, so only strictly positive budgets are accepted.
    if not 0 < budget <= len(scores):
        raise ValueError(
            f"{fn}: budget={budget} must be in [1, {len(scores)}]."
        )
    # Keep the scores of the `budget` candidates with the highest predictions.
    idxs = np.argsort(predictions)[-budget:]
    scores = scores[idxs]

    if unnormalize:
        mod, attr = registry[task_name].kwargs["dataset"].split(":", 1)
        task_cls = getattr(import_module(mod), attr)
        if all(x is not NotImplemented for x in [task_cls._mu, task_cls._std]):
            scores = (scores * task_cls._std) + task_cls._mu

    return {
        "max_score": np.max(scores),
        "min_score": np.min(scores),
        "mean_score": np.mean(scores),
        "std_score": np.std(scores),
        "median_score": np.median(scores),
        "90th_percentile_score": np.percentile(scores, 90),
        "num_input_tokens": num_input_tokens,
        "num_output_tokens": num_output_tokens
    }


if __name__ == "__main__":
    # Run as a script: click parses the command-line arguments and invokes
    # the `main` command defined above.
    main()
