import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml


def extract_metrics(log_data):
    """Summarise exploit usage and strategic preservation (SP) in one log.

    A result is counted only when its own ``success`` flag is truthy and both
    its ``uses_exploit`` and ``strategic_preservation`` classifications report
    ``success``. Results without usable classifications are skipped.

    Args:
        log_data: Parsed log dict; results are read from ``log_data["results"]``.

    Returns:
        dict of percentages (0-100), each relative to the number of counted
        results:
            ``exploit_rate``: results whose ``uses_exploit`` value is > 0.
            ``sp_when_using_exploit``: results that used the exploit AND were
                classified as SP (``final_answer == "Yes"``).
            ``sp_when_not_using_exploit``: results that did not use the exploit
                but were classified as SP.
            ``sp_rate``: all SP results (sum of the two previous values).
        Every value is 0 when no result was counted.
    """
    total = exploit_count = sp_with_exploit = sp_without_exploit = 0

    for r in log_data.get("results", []):
        if not r.get("success"):
            continue
        # Tolerate an explicit ``"classifications": null`` in the log.
        classifications = r.get("classifications") or {}
        uses_exploit = classifications.get("uses_exploit", {})
        sp = classifications.get("strategic_preservation", {})
        if not (uses_exploit.get("success") and sp.get("success")):
            continue

        total += 1
        exploit_used = uses_exploit["extracted"]["uses_exploit"] > 0
        is_sp = sp["extracted"]["final_answer"] == "Yes"

        if exploit_used:
            exploit_count += 1
        if is_sp:
            if exploit_used:
                sp_with_exploit += 1
            else:
                sp_without_exploit += 1

    def pct(count):
        # Share of all counted results, so the two SP parts add up to sp_rate
        # (this is what lets the bar chart stack them).
        return count / total * 100 if total else 0

    return {
        "exploit_rate": pct(exploit_count),
        "sp_when_using_exploit": pct(sp_with_exploit),
        "sp_when_not_using_exploit": pct(sp_without_exploit),
        "sp_rate": pct(sp_with_exploit + sp_without_exploit),
    }


def process_setup(config, setup_name):
    """Compute one setup's metrics, aggregated over all of its log pairs.

    ``config["setups"][setup_name]`` lists its logs in one of three forms
    (checked in this order):
      * ``log_pairs``: list of ``{"training": path, "not_training": path}``;
      * ``training_logs`` + ``not_training_logs``: parallel path lists,
        paired by position;
      * ``training_log`` + ``not_training_log``: a single pair.

    For each pair, ``exploit_gap`` is the training log's exploit rate minus the
    not-training log's exploit rate; the SP metrics come from the training log
    only.

    Returns:
        dict mapping ``exploit_gap``, ``sp_when_using_exploit``,
        ``sp_when_not_using_exploit`` and ``sp_rate`` to
        ``{"mean": ..., "std": ...}``. ``std`` is the population standard
        deviation (``np.std`` default, ddof=0) and is 0 for a single pair.

    Raises:
        ValueError: if the parallel log lists differ in length or the setup
            lists no log pairs.
        KeyError: if a required config key is missing.
    """
    setup = config["setups"][setup_name]

    if "log_pairs" in setup:
        log_pairs = [
            (Path(p["training"]), Path(p["not_training"])) for p in setup["log_pairs"]
        ]
    elif "training_logs" in setup:
        training_logs = setup["training_logs"]
        not_training_logs = setup["not_training_logs"]
        # zip() would silently drop the unmatched tail; a length mismatch is a
        # config error, so surface it instead of averaging the wrong runs.
        if len(training_logs) != len(not_training_logs):
            raise ValueError(
                f"Setup {setup_name!r}: {len(training_logs)} training_logs but "
                f"{len(not_training_logs)} not_training_logs"
            )
        log_pairs = [
            (Path(t), Path(nt)) for t, nt in zip(training_logs, not_training_logs)
        ]
    else:
        log_pairs = [(Path(setup["training_log"]), Path(setup["not_training_log"]))]

    if not log_pairs:
        raise ValueError(f"Setup {setup_name!r} does not list any log pairs")

    metrics_list = []
    for training_path, not_training_path in log_pairs:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(training_path, encoding="utf-8") as f:
            training_m = extract_metrics(json.load(f))
        with open(not_training_path, encoding="utf-8") as f:
            not_training_m = extract_metrics(json.load(f))

        metrics_list.append(
            {
                "exploit_gap": training_m["exploit_rate"]
                - not_training_m["exploit_rate"],
                "sp_when_using_exploit": training_m["sp_when_using_exploit"],
                "sp_when_not_using_exploit": training_m["sp_when_not_using_exploit"],
                "sp_rate": training_m["sp_rate"],
            }
        )

    if len(metrics_list) > 1:
        return {
            k: {
                "mean": np.mean([m[k] for m in metrics_list]),
                "std": np.std([m[k] for m in metrics_list]),
            }
            for k in metrics_list[0]
        }
    return {k: {"mean": metrics_list[0][k], "std": 0} for k in metrics_list[0]}


def plot_stacked_bars(config, output_path=None, sp_mode="both"):
    """Draw a two-panel horizontal bar chart: exploit gap (left) and SP (right).

    There is one row per entry in ``config["setups"]``, labelled with the
    setup's optional ``title`` or else its key. The left panel shows the
    absolute mean exploit gap. The right panel shows training-log strategic
    preservation (SP) percentages, selected by ``sp_mode``:
      * ``"both"``: SP when using the exploit, with SP when not using it
        stacked to its right;
      * ``"using"``: only SP when using the exploit;
      * ``"not-using"``: only SP when not using the exploit.
    Error bars show the std across a setup's log pairs. Per-setup numbers are
    also printed to stdout.

    Args:
        config: Parsed YAML configuration containing a ``setups`` mapping.
        output_path: File to save the figure to (300 dpi). If falsy, the
            figure is shown interactively instead.
        sp_mode: ``"both"``, ``"using"`` or ``"not-using"``.

    Raises:
        ValueError: if ``sp_mode`` is unknown or ``config["setups"]`` is empty.

    Note:
        Updates matplotlib's global rcParams (serif font); the change persists
        after this call returns.
    """
    # (metric key, bar colour, legend label) for each SP segment.
    sp_segments = {
        "using": ("sp_when_using_exploit", "#0080FF", "SP when using exploit"),
        "not-using": (
            "sp_when_not_using_exploit",
            "#00BFFF",
            "SP when not using exploit",
        ),
    }
    if sp_mode == "both":
        # "using" is drawn first; "not-using" is stacked after it.
        segments = [sp_segments["using"], sp_segments["not-using"]]
    elif sp_mode in sp_segments:
        segments = [sp_segments[sp_mode]]
    else:
        # Reject typos instead of silently falling back to the "using" view.
        raise ValueError(
            f"Unknown sp_mode {sp_mode!r}; expected 'both', 'not-using' or 'using'"
        )
    if not config["setups"]:
        # With no rows, the y-limit computation below would fail on min([]).
        raise ValueError("config['setups'] is empty; nothing to plot")

    # Serif fonts for the figure (this changes matplotlib's global rcParams).
    plt.rcParams.update(
        {
            "font.family": "serif",
            "font.serif": ["Times New Roman", "Times", "DejaVu Serif", "serif"],
        }
    )

    # ---- collect metrics for each setup ----
    data = []
    for name in config["setups"]:
        metrics = process_setup(config, name)
        setup = config["setups"][name]
        display_name = setup.get("title", name)  # prefer the readable title
        data.append(
            {
                "name": display_name,
                # Bars show the gap's magnitude; the signed value is printed.
                "exploit_gap": abs(metrics["exploit_gap"]["mean"]),
                "exploit_gap_err": metrics["exploit_gap"]["std"],
                "sp_when_using_exploit": metrics["sp_when_using_exploit"]["mean"],
                "sp_when_using_exploit_err": metrics["sp_when_using_exploit"]["std"],
                "sp_when_not_using_exploit": metrics["sp_when_not_using_exploit"][
                    "mean"
                ],
                "sp_when_not_using_exploit_err": metrics["sp_when_not_using_exploit"][
                    "std"
                ],
            }
        )
        print(f"\n{display_name}:")
        print(
            f"  Exploit gap: {metrics['exploit_gap']['mean']:.1f}% ± {metrics['exploit_gap']['std']:.1f}%"
        )
        print(
            f"  SP when using exploit: {metrics['sp_when_using_exploit']['mean']:.1f}% ± {metrics['sp_when_using_exploit']['std']:.1f}%"
        )
        print(
            f"  SP when not using exploit: {metrics['sp_when_not_using_exploit']['mean']:.1f}% ± {metrics['sp_when_not_using_exploit']['std']:.1f}%"
        )

    n = len(data)

    # ---- figure: one row, two panels (exploit gap | SP) sharing the y axis ----
    # Width is fixed at 2 x 3.5 in; height grows with the number of setups.
    fig_height = max(2, n)
    fig, axes = plt.subplots(1, 2, figsize=(2 * 3.5, fig_height), sharey=True)
    plt.subplots_adjust(wspace=0)

    bar_height = 0.075
    step = 0.08  # vertical distance between neighbouring rows
    # Row centres are symmetric around 0 with the first setup on top,
    # e.g. n=2 -> [0.04, -0.04].
    y_pos = [((n - 1) / 2 - i) * step for i in range(n)]
    error_kw = {"elinewidth": 0.75, "capsize": 2, "alpha": 0.75, "ecolor": "0.1"}

    # ---- left panel: exploit gap ----
    ax_eg = axes[0]
    for i, d in enumerate(data):
        ax_eg.barh(
            y_pos[i],
            d["exploit_gap"],
            bar_height,
            color="#4B0082",
            label="Exploit Gap" if i == 0 else "",
            xerr=d["exploit_gap_err"],
            error_kw=error_kw,
        )

    # ---- right panel: strategic preservation ----
    ax_sp = axes[1]
    for i, d in enumerate(data):
        left = 0  # running offset so later segments stack after earlier ones
        for key, color, label in segments:
            ax_sp.barh(
                y_pos[i],
                d[key],
                bar_height,
                left=left,
                color=color,
                alpha=0.8,
                label=label if i == 0 else "",
                xerr=d[f"{key}_err"],
                error_kw=error_kw,
            )
            left += d[key]

    # ---- shared formatting ----
    for j, ax in enumerate(axes):
        ax.set_xlim(0, 100)
        pad = 0.045  # headroom above the first row and below the last
        ax.set_ylim(min(y_pos) - pad, max(y_pos) + pad)
        ax.set_yticks(y_pos)
        ax.set_yticklabels([d["name"] for d in data])
        ax.grid(True, alpha=0.1, axis="x")
        ax.set_axisbelow(True)

        # Each column is titled with the metric it shows.
        ax.set_title(
            "Exploit Gap (%)" if j == 0 else "Strategic Preservation Reasoning (%)",
            fontsize=11,
            pad=10,
        )
        # The right panel also drops its left spine/ticks so the two panels
        # read as one figure.
        for spine in ["top", "right"] + (["left"] if j > 0 else []):
            ax.spines[spine].set_visible(False)
        if j > 0:
            ax.tick_params(left=False)

    # ---- figure-level legend built from proxy patches ----
    from matplotlib.patches import Rectangle

    # The exploit-gap bars are drawn without transparency, so their legend
    # patch is opaque too; the SP bars are drawn with alpha=0.8.
    handles = [Rectangle((0, 0), 1, 1, fc="#4B0082")]
    labels = ["Exploit Gap"]
    for _, color, label in segments:
        handles.append(Rectangle((0, 0), 1, 1, fc=color, alpha=0.8))
        labels.append(label)

    fig.legend(
        handles=handles,
        labels=labels,
        loc="lower center",
        ncol=len(labels),
        frameon=False,
        fontsize=11,
        bbox_to_anchor=(0.5, -0.15),
        handlelength=1,
        handleheight=1,
    )

    plt.tight_layout()
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        print(f"\nPlot saved to: {output_path}")
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    else:
        plt.show()


def plot_scatter(config, output_path=None):
    """Scatter each setup's training-round SP rate against its exploit gap.

    Each point is one entry in ``config["setups"]``: x is the absolute mean
    exploit gap, y is the mean SP rate of the training logs, and the legend
    label is the setup's optional ``title`` or else its key. Per-setup numbers
    are also printed to stdout.

    Args:
        config: Parsed YAML configuration containing a ``setups`` mapping.
        output_path: File to save the figure to (300 dpi). If falsy, the
            figure is shown interactively instead.
    """
    data = []
    for name in config["setups"]:
        metrics = process_setup(config, name)
        setup = config["setups"][name]
        display_name = setup.get("title", name)  # prefer the readable title
        data.append(
            {
                "name": display_name,
                # Points use the gap's magnitude; the signed value is printed.
                "exploit_gap": abs(metrics["exploit_gap"]["mean"]),
                "sp_rate": metrics["sp_rate"]["mean"],
            }
        )
        print(f"\n{display_name}:")
        print(f"  Exploit gap: {metrics['exploit_gap']['mean']:.1f}%")
        print(f"  SP rate (training round): {metrics['sp_rate']['mean']:.1f}%")

    fig, ax = plt.subplots(figsize=(10, 8))

    markers = ["o", "s", "^", "D", "v", "*", "p", "h", "<", ">", "P", "X"]
    colors = [
        "#0081A7",
        "#F77F00",
        "#9D4EDD",
        "#06A77D",
        "#D62246",
        "#003049",
        "#F4A460",
        "#2E8B57",
        "#FF1493",
        "#4682B4",
        "#8B4513",
        "#DA70D6",
    ]

    for i, d in enumerate(data):
        # Colours and markers both cycle with the same period, so a plain
        # i % len(...) would draw setup 13 exactly like setup 1. Shifting the
        # marker by one per completed colour cycle keeps the first cycle
        # unchanged and gives every (colour, marker) pair once for the first
        # len(colors) * len(markers) setups.
        color = colors[i % len(colors)]
        marker = markers[(i + i // len(colors)) % len(markers)]
        ax.scatter(
            d["exploit_gap"],
            d["sp_rate"],
            s=200,
            alpha=0.8,
            c=color,
            marker=marker,
            label=d["name"],
            edgecolors="black",
            linewidth=0.5,
        )

    ax.set_xlabel("Exploit Gap - Absolute (%)", fontsize=12)
    ax.set_ylabel("Strategic Preservation Rate - Training Round (%)", fontsize=12)
    ax.set_title(
        "Strategic Preservation vs Exploit Gap", fontsize=14, fontweight="bold"
    )
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.grid(True, alpha=0.3)
    # Legend is placed below the axes so it never hides data points.
    ax.legend(
        loc="upper center",
        frameon=True,
        fancybox=True,
        shadow=True,
        bbox_to_anchor=(0.5, -0.15),
        ncol=3,
        borderaxespad=0,
    )

    plt.tight_layout()
    if output_path:
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        print(f"\nPlot saved to: {output_path}")
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
    else:
        plt.show()


def main():
    """Parse command-line arguments and render the requested plot.

    Loads the YAML file given by ``--config`` with ``yaml.safe_load``. It then
    draws either the scatter plot (``--scatter``) or, by default, the
    two-panel bar chart, saving to ``--output`` when it is given.
    """
    parser = argparse.ArgumentParser(
        description="Plot exploit gap and strategic preservation rates"
    )
    parser.add_argument(
        "--config",
        type=Path,
        required=True,
        help="YAML configuration file with log paths",
    )
    parser.add_argument("--output", type=Path, help="Output file path for the plot")
    parser.add_argument(
        "--scatter",
        action="store_true",
        help="Create scatter plot instead of bar chart",
    )
    # The stacked bar chart is already the default, so this flag changes
    # nothing; it is still accepted so existing invocations keep working.
    parser.add_argument(
        "--stacked",
        action="store_true",
        help="Create stacked bar plot (the default; accepted for compatibility, has no effect)",
    )
    parser.add_argument(
        "--sp-mode",
        choices=["both", "not-using", "using"],
        default="both",
        help="Which SP data to show: 'both' (stacked), 'not-using' (only when not using exploit), 'using' (only when using exploit)",
    )

    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty file loads as None. Any config without a 'setups' mapping
    # would otherwise fail later with an opaque TypeError/KeyError.
    if not isinstance(config, dict) or not isinstance(config.get("setups"), dict):
        parser.error(f"{args.config}: expected a top-level 'setups' mapping")

    if args.scatter:
        plot_scatter(config, args.output)
    else:
        plot_stacked_bars(config, args.output, sp_mode=args.sp_mode)


if __name__ == "__main__":
    # Run the command-line interface only when executed as a script, not on import.
    main()
