import glob
import json
import os
import sys

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np  # Ensure numpy is imported
import pandas as pd
from matplotlib.gridspec import GridSpec

# --- Style and IO Utilities (Adapted from your provided code) ---


def apply_nature_style():
    """Apply a Nature-like plotting style with larger fonts for readability.

    Updates the global ``matplotlib.rcParams`` in place, so it affects every
    figure created afterwards in this process:

    * embeds TrueType fonts (type 42) in PDF/PS output so text stays editable;
    * uses a sans-serif font stack with STIX sans for math text;
    * keeps all four spines, outward ticks, and a uniform 1.2 pt line weight
      for axes and tick marks on both axes;
    * turns off legend frames and the default grid.
    """
    mpl.rcParams.update(
        {
            "figure.dpi": 150,
            "savefig.dpi": 300,
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
            "font.family": "sans-serif",
            "font.sans-serif": ["Arial", "Helvetica", "DejaVu Sans"],
            "mathtext.fontset": "stixsans",
            "axes.spines.right": True,
            "axes.spines.top": True,
            "axes.linewidth": 1.2,
            "axes.labelsize": 24,
            "axes.titlesize": 24,
            "xtick.labelsize": 24,
            "ytick.labelsize": 24,
            "xtick.direction": "out",
            "ytick.direction": "out",
            "xtick.major.size": 4,
            # Previously a duplicated "ytick.major.width" key, which left the
            # x-tick width at the matplotlib default.
            "xtick.major.width": 1.2,
            "ytick.major.size": 4,
            "ytick.major.width": 1.2,
            "legend.frameon": False,
            "legend.fontsize": 15,
            "axes.grid": False,
        }
    )


def add_panel_label(ax, label, x=-0.1, y=1.05):
    """
    Draw a bold panel letter (e.g. "a", "b") on *ax*.

    ``x`` and ``y`` are in axes-fraction coordinates (via ``ax.transAxes``).
    The defaults put the label slightly left of and above the plotting area,
    anchored at its bottom-left corner.
    """
    text_style = dict(
        transform=ax.transAxes,
        fontsize=22,
        fontweight="bold",
        va="bottom",
        ha="left",
    )
    ax.text(x, y, label, **text_style)


def find_file(base_path, pattern):
    """Return the first path matching *pattern* under *base_path*, or ``None``.

    *pattern* may contain glob wildcards. *base_path* is treated literally:
    it is escaped, so a directory such as ``run[1]`` is not misread as a
    character class. Matches are sorted, so when several files qualify the
    same one is chosen on every run and filesystem.
    """
    matches = sorted(glob.glob(os.path.join(glob.escape(base_path), pattern)))
    return matches[0] if matches else None


# --- Plotting Functions for Each Panel ---


def _discover_layer_indices(base_path):
    """Return the sorted indices ``k`` of the ``L<k>`` directories in *base_path*."""
    indices = set()
    for path in glob.glob(os.path.join(glob.escape(base_path), "L[0-9]*")):
        suffix = os.path.basename(path)[1:]
        # Skip look-alikes such as "L3_backup" or plain files named like a layer.
        if suffix.isdigit() and os.path.isdir(path):
            indices.add(int(suffix))
    return sorted(indices)


def _load_causal_effects(base_path, alpha):
    """Collect each layer's steering result at *alpha* and derive its causal effect.

    Returns a DataFrame sorted by ``layer`` with an added ``causal_effect``
    column, or ``None`` if no layer has a result at that steering strength.
    """
    rows = []
    for idx in _discover_layer_indices(base_path):
        causality_file = find_file(
            base_path, f"L{idx}/causality/valence_causality_L{idx}.json"
        )
        if not causality_file:
            continue
        with open(causality_file, "r") as f:
            data = json.load(f)
        rows.extend(
            {**res, "layer": data["layer"]}
            for res in data["results"]
            if np.isclose(res["alpha"], alpha)
        )

    if not rows:
        return None
    df = pd.DataFrame(rows).sort_values("layer")
    # Mean logit shift of double/triple bonds, relative to the single-bond shift.
    df["causal_effect"] = (
        (df["dlogit_double"] + df["dlogit_triple"]) / 2
    ) - df["dlogit_single"]
    return df


def plot_panel_a(ax, base_path=None, alpha=2.0):
    """
    Panel A: Representation vs. Causal Impact.

    Draws per-layer probe accuracy (left axis, blue, shown as a percentage) and
    the steering "causal effect" at strength *alpha* (right twin axis, red).
    Inputs come from:

    * ``<base_path>/budget/valence_probe_by_layer.csv`` (columns ``layer``, ``acc``);
    * ``<base_path>/L<k>/causality/valence_causality_L<k>.json`` for each layer
      directory ``L<k>`` found.

    If either input is missing, a warning is printed to stderr and the axes
    are left empty.

    Args:
        ax: Matplotlib axes to draw on.
        base_path: Root of the valence suite; defaults to ``sys.argv[1]``.
        alpha: Steering strength whose results are plotted (default 2.0).
    """
    if base_path is None:
        base_path = sys.argv[1]

    probe_file = find_file(base_path, "budget/valence_probe_by_layer.csv")
    if not probe_file:
        print(
            f"[WARN] Panel A skipped: no budget/valence_probe_by_layer.csv under '{base_path}'",
            file=sys.stderr,
        )
        return
    probe_df = pd.read_csv(probe_file)

    causality_df = _load_causal_effects(base_path, alpha)
    if causality_df is None:
        print(
            f"[WARN] Panel A skipped: no causality results at alpha={alpha} under '{base_path}'",
            file=sys.stderr,
        )
        return

    # Left axis: representation strength (probe accuracy).
    color1 = "royalblue"
    ax.set_xlabel("Transformer Layer")
    ax.set_ylabel("Probe Acc. (%)", color=color1)
    ax.plot(
        probe_df["layer"],
        probe_df["acc"],
        color=color1,
        marker="o",
        markersize=8,
        linewidth=2.5,
        label="Representation (Probe Acc.)",
    )
    ax.tick_params(axis="y", colors=color1)
    # round() instead of int(): int() truncates float noise such as 56.999... to 56.
    # The "+ 0.0" turns a negative zero into a positive one.
    ax.yaxis.set_major_formatter(
        mticker.FuncFormatter(lambda y, _: f"{round(y * 100, 1) + 0.0:g}")
    )
    # Focus on the 80-100% band and keep a firm top so the 100% tick appears.
    # Widen downward (to a multiple of 5%) if any layer falls below 80%, so
    # no point is clipped off the axes.
    lowest_acc = probe_df["acc"].min()
    bottom = 0.80
    if np.isfinite(lowest_acc) and lowest_acc < bottom:
        bottom = np.floor(lowest_acc * 20) / 20
    ax.set_ylim(bottom=bottom, top=1.0)
    ax.grid(True, axis="y", linestyle="--", linewidth=0.6, color="#cccccc")

    # Right axis: causal impact of steering along the valence direction.
    ax2 = ax.twinx()
    color2 = "crimson"
    ax2.set_ylabel("Causal Effect (Δ Logit)", color=color2)
    ax2.plot(
        causality_df["layer"],
        causality_df["causal_effect"],
        color=color2,
        marker="s",
        markersize=8,
        linewidth=2.5,
        linestyle="--",
        label="Causal Impact",
    )
    ax2.tick_params(axis="y", colors=color2, length=0)
    ax2.axhline(0, color="grey", linestyle=":", linewidth=1.2)
    # Tick placement on ax2 is left to matplotlib's default locator, which
    # spaces ticks so their labels do not overlap.

    ax.set_title("Probe accuracy vs. causality", y=1.04)
    ax.set_xticks(probe_df["layer"])

    # Merge the handles of both axes into a single legend.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc="lower right")
    add_panel_label(ax, "a")


def _cap_yticks_at_max(ax, max_value):
    """Replace the top auto-generated y-ticks of *ax* with one tick at *max_value*."""
    if not np.isfinite(max_value):
        # Nothing to show (e.g. all-NaN data); keep matplotlib's ticks.
        return
    ticks = np.asarray(ax.get_yticks(), dtype=float)
    ymin, ymax = ax.get_ylim()
    tol = 1e-9 * (abs(ymax - ymin) or 1.0)
    step = float(np.min(np.diff(ticks))) if ticks.size > 1 else 0.0
    # The locator may return ticks outside the view. Keeping one below ymin
    # would make set_yticks expand the view and add blank space. Also drop
    # auto ticks within half a step of the max tick, or their labels overlap.
    keep = ticks[(ticks >= ymin - tol) & (ticks < max_value - 0.5 * step)]
    ax.set_yticks(np.unique(np.append(keep, max_value)))
    # With an arbitrary float tick, ScalarFormatter would print every tick with
    # three decimals. Show at most two and trim trailing zeros. The "+ 0.0"
    # turns -0.0 into 0.0, and "\u2212" mimics matplotlib's typographic minus.
    ax.yaxis.set_major_formatter(
        mticker.FuncFormatter(
            lambda v, _: f"{round(v, 2) + 0.0:g}".replace("-", "\u2212")
        )
    )


def plot_panel_b(ax, base_path=None, layer=3):
    """
    Panel B: Causal Control of Bond Prediction.

    For one layer, plots how the single/double/triple-bond logits change
    (``dlogit_*``) as the steering strength ``alpha`` varies. Data comes from
    ``<base_path>/L<layer>/causality/valence_causality_L<layer>.json``. The
    top auto y-tick is replaced by a tick at the largest observed change.
    If the file is missing, a warning is printed to stderr and the axes are
    left empty.

    Args:
        ax: Matplotlib axes to draw on.
        base_path: Root of the valence suite; defaults to ``sys.argv[1]``.
        layer: Layer whose steering sweep is plotted (default 3).
    """
    if base_path is None:
        base_path = sys.argv[1]
    causality_file = find_file(
        base_path, f"L{layer}/causality/valence_causality_L{layer}.json"
    )
    if not causality_file:
        print(
            f"[WARN] Panel B skipped: no causality file for layer {layer} under '{base_path}'",
            file=sys.stderr,
        )
        return

    with open(causality_file, "r") as f:
        data = json.load(f)["results"]

    # Sort by steering strength so each series is drawn as a monotone sweep,
    # even if the JSON lists alphas out of order.
    causality_df = pd.DataFrame(data).sort_values("alpha")

    series = (
        ("dlogit_single", "o", "Single Bond (-)"),
        ("dlogit_double", "s", "Double Bond (=)"),
        ("dlogit_triple", "^", "Triple Bond (#)"),
    )
    for column, marker, label in series:
        ax.plot(
            causality_df["alpha"],
            causality_df[column],
            marker=marker,
            markersize=8,
            linewidth=2.5,
            label=label,
        )

    # Reference lines: no steering (alpha = 0) and no logit change.
    ax.axhline(0, color="black", linestyle="--", linewidth=1.0, alpha=0.8)
    ax.axvline(0, color="black", linestyle="--", linewidth=1.0, alpha=0.8)

    max_dlogit = causality_df[[col for col, _, _ in series]].max().max()
    _cap_yticks_at_max(ax, max_dlogit)

    ax.set_title(f"Causal control (layer {layer})", y=1.04)
    ax.set_xlabel("Steering Strength")
    ax.set_ylabel("Change in Logit (Δ Logit)")
    ax.legend(title="Bond Type")
    ax.grid(True, axis="y", linestyle="--", linewidth=0.6, color="#cccccc")
    add_panel_label(ax, "b")


# --- Main Figure Assembly ---


def create_combined_figure(output_pdf="figure_valence_mechanism.pdf"):
    """
    Build the two-panel valence-mechanism figure and write it to *output_pdf*.

    Applies the shared style first. Panel (a) fills the left cell and panel
    (b) the right cell of a 1x2 grid. The figure is closed after saving and a
    confirmation line is printed.
    """
    apply_nature_style()

    fig = plt.figure(figsize=(20, 6))
    grid = GridSpec(1, 2, figure=fig, wspace=0.7)

    # Create each axes and draw into it before moving on to the next cell.
    for cell, draw_panel in ((grid[0, 0], plot_panel_a), (grid[0, 1], plot_panel_b)):
        draw_panel(fig.add_subplot(cell))

    plt.tight_layout()
    fig.savefig(output_pdf, bbox_inches="tight")
    plt.close(fig)
    print(f"✓ Combined figure saved as '{output_pdf}'")


if __name__ == "__main__":
    # CLI entry point: validate the suite directory argument, then render the
    # figure. Every failure path exits non-zero so shell scripts and CI
    # notice that no figure was produced.
    if len(sys.argv) < 2:
        print(
            "Usage: python generate_final_figure.py <path_to_valence_suite_all_layers>",
            file=sys.stderr,
        )
        sys.exit(1)

    base_path = sys.argv[1]
    if not os.path.isdir(base_path):
        print(f"Error: Directory not found at '{base_path}'", file=sys.stderr)
        sys.exit(1)

    try:
        create_combined_figure()
    except Exception as e:
        # Top-level boundary: report the reason, then fail loudly instead of
        # exiting with status 0.
        print(f"[ERROR] Could not generate figure. Reason: {e}", file=sys.stderr)
        sys.exit(1)
