import os

import fire
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from evaluate_hh import EvalConfig, evaluate
from loguru import logger


def _result_file_path(args, use_beaver: bool) -> str:
    """Return the JSONL cache path for one (meta, weights) evaluation run.

    Beaver-judged runs are stored under ``eval_pareto_beaver`` so they never
    collide with runs judged by the default true-reward model.
    """
    # Pick the directory component directly instead of calling ``str.replace``
    # on the finished path: that would also rewrite any "eval_pareto"
    # substring that happens to occur in the experiment or run name.
    subdir = "eval_pareto_beaver" if use_beaver else "eval_pareto"
    return f"results/hh/{subdir}/{args.experiment_name}/{args.run_name}.jsonl"


def _load_or_evaluate(args, weights, result_file_path: str):
    """Return evaluation results for *args*, using the JSONL cache when present.

    On a cache hit the file is read back as a column-oriented dict
    (``{column: [row values...]}``). On a miss, ``evaluate(args)`` is run and
    its result is written to *result_file_path* as one JSON record per row
    before being returned unchanged.
    """
    if os.path.exists(result_file_path):
        logger.info(f"Loading results from {result_file_path}")
        results_df = pd.read_json(result_file_path, lines=True)
        return results_df.to_dict(orient="list")

    logger.info(f"Evaluating model with meta={args.meta}, weights={weights}")
    results = evaluate(args)
    os.makedirs(os.path.dirname(result_file_path), exist_ok=True)
    pd.DataFrame(results).to_json(result_file_path, lines=True, orient="records")
    logger.info(f"Saved results to {result_file_path}")
    return results


def _summarize(results, use_true_reward: bool) -> dict:
    """Compute the mean and std of the helpful/harmless scores of one run.

    The returned keys are ``mean_helpful``, ``std_helpful``, ``mean_harmless``
    and ``std_harmless``, plus their best-of-n counterparts (e.g.
    ``mean_bon_helpful``). With *use_true_reward*, values are read from the
    ``true_*`` / ``bon_true_*`` columns instead of the proxy-reward columns.
    """
    reward_prefix = "true_" if use_true_reward else ""
    stats = {}
    for bon_prefix in ("", "bon_"):
        for objective in ("helpful", "harmless"):
            values = results[f"{bon_prefix}{reward_prefix}{objective}"]
            stats[f"mean_{bon_prefix}{objective}"] = np.mean(values)
            stats[f"std_{bon_prefix}{objective}"] = np.std(values)
    return stats


def main(
    use_true_reward: bool = False,
    n=4,
    model="alpaca7b",
    target_kl=0.1,
    true_reward_dir_base="results/hh/models/qwen32b",
    step=1000,
    use_beaver: bool = False,
    eval_size=1000,
):
    """Sweep helpful/harmless weightings and plot the resulting trade-off.

    The sweep covers every helpful weight ``w1`` in {0.3, 0.4, 0.5, 0.6, 0.7},
    with harmless weight ``1 - w1``. Each weighting runs with ``meta=False``
    (baseline) and ``meta=True`` (IAMA). Results are loaded from the JSONL
    cache or produced with ``evaluate``. The mean helpful/harmless scores of
    the regular and best-of-n outputs are scattered on one plot, which is
    saved under ``./figs/hh/``.

    Args:
        use_true_reward: Plot the ``true_*`` reward columns instead of the
            proxy-reward columns.
        n: Forwarded to ``EvalConfig`` (presumably the best-of-n sample
            count) and embedded in the figure filename.
        model: Forwarded to ``EvalConfig``.
        target_kl: Forwarded to ``EvalConfig``.
        true_reward_dir_base: Forwarded to ``EvalConfig``. Overridden when
            ``use_beaver`` is set.
        step: Forwarded to ``EvalConfig``.
        use_beaver: Use ``PKU-Alignment/beaver-7b-v1.0`` as the true reward
            model. Its cached results and figure get distinct names so they
            do not overwrite the default runs.
        eval_size: Forwarded to ``EvalConfig``.
    """
    if use_beaver:
        true_reward_dir_base = "PKU-Alignment/beaver-7b-v1.0"

    all_results = []
    for helpful_weight in [0.3, 0.4, 0.5, 0.6, 0.7]:
        harmless_weight = 1.0 - helpful_weight
        weights = [helpful_weight, harmless_weight]
        for meta in [False, True]:
            args = EvalConfig().from_dict(
                {
                    "n": n,
                    "meta": meta,
                    "weights": weights,
                    "target_kl": target_kl,
                    "eval_size": eval_size,
                    "model": model,
                    "true_reward_dir_base": true_reward_dir_base,
                    "step": step,
                }
            )
            result_file_path = _result_file_path(args, use_beaver)
            results = _load_or_evaluate(args, weights, result_file_path)
            all_results.append((meta, weights, results))

    colors = {True: "blue", False: "orange"}
    labels = {True: "IAMA (ours)", False: "Baseline"}
    labels_bon = {True: "IAMA BoN (ours)", False: "Baseline BoN"}
    # Offset (in data units) of each w_1 annotation from its marker.
    shift_x = 0.02
    shift_y = 0.02

    fig_name = "pareto_true_reward" if use_true_reward else "pareto"
    if use_beaver:
        # Keep beaver-judged plots from overwriting the default-judge ones,
        # mirroring the separate results cache directory.
        fig_name += "_beaver"
    fig_dir = "./figs/hh"

    with plt.style.context("./config/paper.mplstyle"):
        fig = plt.figure(figsize=(4, 3))
        ax = fig.add_subplot(1, 1, 1)
        for meta, weights, results in all_results:
            stats = _summarize(results, use_true_reward)
            logger.info(
                f"meta={meta}, weights={weights} => "
                f"mean_helpful={stats['mean_helpful']}, mean_harmless={stats['mean_harmless']}, "
                f"mean_bon_helpful={stats['mean_bon_helpful']}, mean_bon_harmless={stats['mean_bon_harmless']}"
            )
            logger.info(
                f"std_helpful={stats['std_helpful']}, std_harmless={stats['std_harmless']}, "
                f"std_bon_helpful={stats['std_bon_helpful']}, std_bon_harmless={stats['std_bon_harmless']}"
            )
            ax.scatter(
                stats["mean_helpful"],
                stats["mean_harmless"],
                label=labels[meta],
                marker="o",
                color=colors[meta],
                s=10,
            )
            ax.scatter(
                stats["mean_bon_helpful"],
                stats["mean_bon_harmless"],
                label=labels_bon[meta],
                marker="x",
                color=colors[meta],
                s=10,
            )
            weight_label = rf"$w_1 = {weights[0]}$"
            ax.text(
                stats["mean_helpful"] + shift_x,
                stats["mean_harmless"] + shift_y,
                weight_label,
                fontsize=6,
            )
            ax.text(
                stats["mean_bon_helpful"] + shift_x,
                stats["mean_bon_harmless"] + shift_y,
                weight_label,
                fontsize=6,
            )

        ax.set_xlabel("Helpfulness")
        ax.set_ylabel("Harmlessness")
        ax.set_title("Evaluation Results")

        # Every (meta, weights) pair re-adds the same labels; keep one legend
        # entry per label.
        handles, legend_labels = ax.get_legend_handles_labels()
        unique = dict(zip(legend_labels, handles))
        ax.legend(unique.values(), unique.keys(), loc="upper left")  # type: ignore

        # Save inside the style context so any savefig.* rcParams from the
        # style file (dpi, facecolor, ...) apply. Create the output directory
        # first so a fresh checkout does not fail.
        os.makedirs(fig_dir, exist_ok=True)
        fig.savefig(f"{fig_dir}/{fig_name}_n={n}.png", bbox_inches="tight")
    plt.close(fig)


if __name__ == "__main__":
    # Expose `main` as a command-line interface: Fire maps CLI flags
    # (e.g. --n=8 --use_true_reward) onto main's keyword arguments.
    fire.Fire(component=main)
