# -*- coding: utf-8 -*-
"""Plot results of single-variable experiments
from multiple folders and save figures.

Supported tasks: reasoning_retrieval, reasoning_DAG.

"""

import argparse
import numpy as np
import matplotlib.pyplot as plt

from plot_exp_results import load_data


# Figure size passed to plt.subplots(figsize=...) for every metric figure.
FIGSIZE = (3.5, 3.5)

# Per-folder plot styling, indexed by the position of the folder in the
# list given to plot_metrics (two entries each, i.e. two folders).
# Format strings used for the curve of each folder.
LINE_STYLES = ["--o", "--*"]
# Colors used for metrics whose name starts with "error_".
ERROR_COLORS = ["b", "m"]
# Colors used for all other metrics.
COST_COLORS = ["r", "y"]

# Cost metrics shared by all tasks: metric name -> Y-axis label.
dict_metrics_common = {
    "llm_calls": "LLM calls",
    "prefilling_tokens_total": "Prefilling tokens",
    "decoding_tokens_total": "Decoding tokens",
    "latency": "Latency (sec)",
    "latency_finite_parallel": "Latency, p=4 (sec)",
    "latency_ideal_parallel": r"Latency, p=$\infty$ (sec)",
}

# Error metrics that appear in both task-specific tables below.
_dict_metrics_error = {
    "error_EM": "Exact-match error",
    "error_abs": "Absolute error",
    "error_missed_coverage": "Missed-coverage ratio",
}

# reasoning_retrieval: error metrics, number of rounds, and a
# "<name>_per_round" variant of every common cost metric.
dict_metrics_reasoning_retrieval = {
    **_dict_metrics_error,
    "num_passes": "Number of rounds",
    **{
        f"{metric}_per_round": f"{label} per round"
        for metric, label in dict_metrics_common.items()
    },
}

# reasoning_DAG: error metrics plus the ideal number of LLM calls.
dict_metrics_reasoning_DAG = {
    **_dict_metrics_error,
    "ideal_llm_calls": "Ideal number of LLM calls",
}

# Task name -> all metrics plotted for that task (common ones first).
dict_metrics_all_tasks = {
    "reasoning_retrieval": {
        **dict_metrics_common,
        **dict_metrics_reasoning_retrieval,
    },
    "reasoning_DAG": {
        **dict_metrics_common,
        **dict_metrics_reasoning_DAG,
    },
}

# Name of the varied config parameter -> X-axis label.
dict_name_xlabel = {
    "n": "Task size $n$",
    "m": "Sub-task size $m$",
    "depth": "Depth $d$",
    "width": "Width $w$",
    "degree": "Degree $g$",
}

# Convert config parameter to actual legend label in plots.
dict_legend_label = {
    "decomposition_cyclic": "Cyclic",
    "decomposition_parallel": "Parallel",
    "answer_directly": "Answer directly",
    "reason_step_by_step": "Calculate step by step",
}


def preprocess_metrics_data(
    configs_results: list,
    metric_names: list,
    variable_name: str,
    legend_variable: str,
) -> tuple:
    """Preprocess data of metrics

    For every config, each metric is aggregated over the trials of that
    config (mean and std via numpy). A metric named "<base>_per_round"
    is derived per trial as "<base>" divided by "num_passes". All outputs
    are sorted by the value of the X-axis variable.

    Args:
        configs_results: list of config_result; each one is indexed with
            "config" and "trials_results"
        metric_names: list of metric names to be plotted
        variable_name: name of variable for the X-axis
        legend_variable: name of variable for legend

    Returns:
        lst_variable: sorted values of variable for X-axis (array)
        metrics_mean: dict of metric name -> array of means
        metrics_std: dict of metric name -> array of stds
        legend_label: label for legend

    Raises:
        ValueError: if configs_results is empty, since neither the
            X-axis values nor the legend label can be determined.

    """
    if len(configs_results) == 0:
        raise ValueError("configs_results is empty, nothing to preprocess.")

    per_round_suffix = "_per_round"

    lst_variable = []
    metrics_mean = {name: [] for name in metric_names}
    metrics_std = {name: [] for name in metric_names}

    for config_result in configs_results:
        lst_variable.append(config_result["config"][variable_name])

        trials_results = config_result["trials_results"]
        for name in metric_names:
            if name.endswith(per_round_suffix):  # for reasoning_retrieval
                name_original = name.removesuffix(per_round_suffix)
                lst = [
                    rst[name_original] / rst["num_passes"]
                    for rst in trials_results
                ]
            else:
                lst = [rst[name] for rst in trials_results]
            metrics_mean[name].append(np.mean(lst))
            metrics_std[name].append(np.std(lst))

    # The legend value is read from the last config of the folder.
    # Values without a predefined pretty label are shown as-is instead of
    # failing with a KeyError.
    raw_label = configs_results[-1]["config"][legend_variable]
    legend_label = dict_legend_label.get(raw_label, str(raw_label))

    # Reorder everything by increasing value of the X-axis variable.
    idx_sort = np.argsort(lst_variable)
    lst_variable = np.array(lst_variable)[idx_sort]
    for name in metric_names:
        metrics_mean[name] = np.array(metrics_mean[name])[idx_sort]
        metrics_std[name] = np.array(metrics_std[name])[idx_sort]
    return lst_variable, metrics_mean, metrics_std, legend_label


def plot_metrics(
    folders: list,
    variable_name: str,
    legend_variable: str,
    dict_metrics: dict,
) -> None:
    """Plot error and cost metrics with legend

    One figure is drawn per metric in dict_metrics, holding one curve per
    folder (mean, with a shaded band of one std on each side). Each
    figure is saved as "<metric name>.pdf" inside the first folder and
    closed afterwards, so figures do not pile up across metrics.

    Args:
        folders: experiment folders to compare; each is read with
            load_data
        variable_name: name of variable for the X-axis
        legend_variable: name of variable for legend
        dict_metrics: maps metric name to its Y-axis label

    Raises:
        ValueError: if folders is empty, or contains more folders than
            there are predefined line styles / colors.

    """
    # Styles and colors are looked up by folder index, so fail early with
    # a clear message rather than midway through plotting.
    max_folders = min(len(LINE_STYLES), len(ERROR_COLORS), len(COST_COLORS))
    if not 1 <= len(folders) <= max_folders:
        raise ValueError(
            f"expected between 1 and {max_folders} folders, "
            f"got {len(folders)}.",
        )

    metric_names = list(dict_metrics.keys())

    lst_configs_results = [load_data(folder) for folder in folders]
    lst_preprocessed_data = [
        preprocess_metrics_data(
            configs_results,
            metric_names,
            variable_name,
            legend_variable,
        )
        for configs_results in lst_configs_results
    ]

    # Plot
    # Variables without a predefined label use their own name as X-label.
    name_xlabel = dict_name_xlabel.get(variable_name, variable_name)

    for name in metric_names:
        fig, _ = plt.subplots(figsize=FIGSIZE)

        colors = ERROR_COLORS if name.startswith("error_") else COST_COLORS

        for idx, preprocessed_data in enumerate(lst_preprocessed_data):
            (
                lst_variable,
                metrics_mean,
                metrics_std,
                legend_label,
            ) = preprocessed_data

            plt.plot(
                lst_variable,
                metrics_mean[name],
                LINE_STYLES[idx],
                color=colors[idx],
                label=legend_label,
            )
            plt.fill_between(
                lst_variable,
                metrics_mean[name] - metrics_std[name],
                metrics_mean[name] + metrics_std[name],
                alpha=0.1,
                facecolor=colors[idx],
            )

        # special plot: compare llm_calls with ideal_llm_calls
        if (name == "llm_calls") and ("ideal_llm_calls" in dict_metrics):
            # The reference curve is taken from the last folder's data.
            ref_variable, ref_mean, ref_std, _ = lst_preprocessed_data[-1]
            plt.plot(
                ref_variable,
                ref_mean["ideal_llm_calls"],
                "--",
                color="lightgray",
                label="Ideal number of LLM calls",
            )
            plt.fill_between(
                ref_variable,
                ref_mean["ideal_llm_calls"] - ref_std["ideal_llm_calls"],
                ref_mean["ideal_llm_calls"] + ref_std["ideal_llm_calls"],
                alpha=0.1,
                facecolor="lightgray",
            )

        plt.grid(True)
        plt.xlabel(name_xlabel)
        plt.ylabel(dict_metrics[name])
        plt.legend()
        plt.tight_layout()
        fig_path = folders[0] + f"/{name}.pdf"
        plt.savefig(fig_path, format="pdf")
        # Release the figure once it is on disk; otherwise one figure per
        # metric stays open for the lifetime of the process.
        plt.close(fig)


def get_task_from_folder(folder: str) -> str:
    """Infer the task name from the folder name.

    The folder name is searched for "exp_<task>" for each supported task,
    and the first task found is returned.

    Raises:
        ValueError: if no supported task appears in the folder name.
    """
    for task in ("reasoning_retrieval", "reasoning_DAG"):
        if f"exp_{task}" in folder:
            return task
    raise ValueError("task name not found.")


def get_variable_name_from_folder(folder: str) -> str:
    """Infer the name of the varied variable from the folder name.

    Returns the first key of dict_name_xlabel for which "vary_<key>"
    occurs in the folder name. If there is none, a message is printed
    and the string "null" is returned.
    """
    matched = next(
        (key for key in dict_name_xlabel if f"vary_{key}" in folder),
        None,
    )
    if matched is not None:
        return matched

    print("Name of X-label cannot be inferred from folder.")
    return "null"


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments

    All three options are mandatory: the script indexes configs with
    legend_variable and reads both folders, so a missing option is
    rejected here with a usage message instead of failing later on a
    None value.

    Returns:
        Namespace with attributes folder1, folder2 and legend_variable.
    """
    parser = argparse.ArgumentParser(
        description="Plot results of single-variable experiments "
        "from two folders and save figures.",
    )

    parser.add_argument(
        "--folder1",
        type=str,
        required=True,
        help="first experiment folder; task and X-axis variable are "
        "inferred from its name, and figures are saved into it",
    )
    parser.add_argument(
        "--folder2",
        type=str,
        required=True,
        help="second experiment folder to compare against",
    )
    parser.add_argument(
        "--legend_variable",
        type=str,
        required=True,
        help="name of the config parameter whose value labels each curve",
    )

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: compare the results of two experiment folders.
    args_main = parse_args()
    folders_main = [args_main.folder1, args_main.folder2]

    # Task and X-axis variable are both inferred from the first folder.
    task_main = get_task_from_folder(folders_main[0])
    variable_name_main = get_variable_name_from_folder(folders_main[0])

    plot_metrics(
        folders_main,
        variable_name_main,
        args_main.legend_variable,
        dict_metrics_all_tasks[task_main],
    )
