# %%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch


def setup_plot_style():
    """
    Apply the shared Matplotlib style used by every schematic figure.

    Prefers the seaborn "whitegrid" look. Matplotlib 3.6 renamed it to
    ``seaborn-v0_8-whitegrid``; older releases only know ``seaborn-whitegrid``.
    If neither style is available, the default style is used with a light
    grid enabled. Every setting goes through rcParams, so it applies to all
    figures created afterwards, and no figure is created here.
    """
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        try:
            plt.style.use(style_name)
            break
        except OSError:
            # Unknown style name for this Matplotlib version; try the next.
            continue
    else:
        plt.style.use('default')
        # Use rcParams rather than plt.grid(): plt.grid() only affects the
        # *current* Axes. It would implicitly create a stray empty figure
        # and would not reach the figures built later with plt.subplots().
        plt.rcParams.update({'axes.grid': True, 'grid.alpha': 0.3})

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 11,
        'figure.titlesize': 16
    })


def plot_stability_basin(ax):
    """
    Illustrate Theorem 5.1: the phase transition with respect to LoRA norm.

    Intuition: a potential well. Small perturbations return; large ones
    escape. The landscape is V(x) = 0.5 x^2 - 0.05 x^4. Its minimum at
    x = 0 stands for the pre-trained weights. Its barrier peaks at
    x = sqrt(5), where V = 1.25.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; it is modified in place.
    """
    from matplotlib.path import Path

    def potential(v):
        """Height of the schematic landscape at position(s) ``v``."""
        return 0.5 * v**2 - 0.05 * v**4

    x = np.linspace(-1.5, 3.5, 400)

    # Plot potential landscape
    ax.plot(x, potential(x), color='#333333', lw=2,
            label='Optimization Landscape')

    # Equilibrium point at the bottom of the well
    ax.plot(0, 0, 'go', markersize=10,
            label=r'Pre-trained Weights ($X_*$)', zorder=5)

    # 1. Small perturbation (stable): stays on the inner wall of the well.
    ax.annotate('', xy=(1.0, 0.5), xytext=(0.1, 0.1),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2.5,
                                mutation_scale=20))
    ax.text(0.6, 0.7, 'Small LoRA\n(Stable)',
            color='blue', ha='center', fontweight='bold')

    # 2. Large perturbation (unstable): climbs over the barrier and lands
    # on the far slope.
    # The trajectory is built in data coordinates as V(x) plus a sine-shaped
    # lift that is zero at both ends. It therefore always clears the barrier
    # and ends exactly on the curve. An "arc3" connection bends in display
    # space, so its clearance would depend on the figure's aspect ratio.
    x_start, x_end = 0.1, 3.0
    x_escape = np.linspace(x_start, x_end, 100)
    lift = 0.6 * np.sin(np.pi * (x_escape - x_start) / (x_end - x_start))
    escape_path = Path(np.column_stack([x_escape, potential(x_escape) + lift]))
    ax.add_patch(FancyArrowPatch(path=escape_path, arrowstyle='->',
                                 color='red', lw=2.5, mutation_scale=20))

    # Label sits above the highest point of the escape trajectory (~1.75).
    ax.text(2.6, 2.0, 'Large LoRA\n(Catastrophic\nForgetting)',
            color='red', ha='center', fontweight='bold')

    # Styling
    ax.set_ylim(-1.0, 3.5)  # headroom for the escape trajectory and label
    ax.set_xlim(-1.5, 3.5)
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel("Representation Space")
    ax.set_ylabel("Potential Energy")
    ax.legend(loc='upper left', frameon=True)


def plot_depth_bifurcation(ax):
    """
    Illustrate depth-wise bifurcation of token dynamics under a LoRA update.

    The base trajectory is a tanh transition with a small sinusoidal wobble.
    The perturbed trajectory coincides with it up to the bifurcation depth
    T* (the shaded "stability window"). After T* a tanh ramp pulls it down
    towards a different cluster.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; it is modified in place.
    """
    t = np.linspace(0, 15, 300)
    traj_orig = np.tanh(t - 5) + 0.1 * np.sin(t)

    t_bifurcation = 8.0
    # Zero up to T*, then a smooth ramp towards -2.5.
    divergence = np.zeros_like(t)
    mask = t > t_bifurcation
    divergence[mask] = -2.5 * np.tanh(0.8 * (t[mask] - t_bifurcation))
    traj_lora = traj_orig + divergence

    ax.plot(t, traj_orig, 'k-', lw=3, label='Base Model Dynamics', alpha=0.7)
    ax.plot(t, traj_lora, 'r--', lw=3, label='LoRA Perturbed Dynamics')

    ax.axvspan(0, t_bifurcation, color='green', alpha=0.1,
               label=r'Stability Window ($W_2$ small)')
    ax.axvline(t_bifurcation, color='grey', linestyle=':', lw=2)
    ax.text(t_bifurcation + 0.2, 0, r'$T^*$ (Bifurcation)',
            rotation=90, va='center')

    # Anchor each cluster label to its trajectory's value at the label
    # position and offset it away from the curve: above for the base
    # model, below for the perturbed one. The old fixed baseline at
    # y = -1.5 let the two-line "Cluster B" text extend up through the
    # LoRA curve, which levels off near -1.4.
    label_t = 14
    label_gap = 0.12
    y_orig = np.interp(label_t, t, traj_orig)
    y_lora = np.interp(label_t, t, traj_lora)
    ax.text(label_t, y_orig + label_gap, 'Cluster A\n(Correct)',
            color='black', ha='center', va='bottom', fontweight='bold')
    ax.text(label_t, y_lora - label_gap, 'Cluster B\n(Forgot)',
            color='red', ha='center', va='top', fontweight='bold')

    ax.set_xlabel(r"Depth (Time $t$)")
    ax.set_ylabel("Token Position / Clustering")
    ax.legend(loc='lower left', frameon=True)
    ax.set_ylim(-2, 2)
    ax.set_yticks([])


def plot_spectral_gap(ax):
    """
    Contrast a large and a small spectral gap as a steep vs. shallow well.

    Both wells are parabolas centred at the origin: 4 x^2 (large gap) and
    0.8 x^2 (small gap). An arrow from the bottom of each well marks a
    displacement along it. The arrow is short in the steep well and long
    in the shallow one.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; it is modified in place.
    """
    x = np.linspace(-2, 2, 200)
    steep_well = 4 * x**2
    ax.plot(x, steep_well, 'b-', lw=2,
            label=r'Large Gap ($\lambda_1 \gg \lambda_2$)')
    shallow_well = 0.8 * x**2
    ax.plot(x, shallow_well, color='orange', lw=2, ls='-',
            label=r'Small Gap ($\lambda_1 \approx \lambda_2$)')

    ax.annotate('', xy=(0.5, 4*0.5**2), xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='blue', lw=2,
                                mutation_scale=20))
    ax.annotate('', xy=(1.5, 0.8*1.5**2), xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color='orange', lw=2,
                                mutation_scale=20))

    # Keep the label right of the steep well. At y ~ 2.5-2.9 the curve 4x^2
    # sits at x ~ 0.79-0.85, so a left edge at x = 0.6 put the curve
    # straight through the text. x = 0.95 also stays clear of the shallow
    # well, which is at x > 1.7 at that height.
    ax.text(0.95, 2.5, 'Robust\nConfinement',
            color='blue', fontsize=10, fontweight='bold')
    ax.text(1.6, 1.0, 'Easy\nEscape', color='orange',
            fontsize=10, fontweight='bold')

    ax.set_xlabel("Eigenspace Direction")
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_ylim(0, 5)
    ax.legend(loc='upper right', frameon=True)


def generate_schematic_plots():
    """
    Render every schematic diagram to its own PDF file.

    Applies the shared plot style once. Then, for each output file in turn,
    it creates a fresh single-Axes figure, lets the matching drawing function
    populate it, and saves it under a path relative to the working directory.
    Each figure is closed afterwards so figures do not accumulate, and a
    confirmation line is printed.
    """
    setup_plot_style()

    page_size = (7, 5)  # all schematics share one figure size, in inches
    # Insertion order of the dict fixes the order the files are produced.
    outputs = {
        "schematic_norm_phase_transition.pdf": plot_stability_basin,
        "schematic_depth_bifurcation.pdf": plot_depth_bifurcation,
        "schematic_spectral_gap.pdf": plot_spectral_gap,
    }

    for out_path, draw in outputs.items():
        figure, axes = plt.subplots(figsize=page_size)
        draw(axes)
        figure.tight_layout()
        figure.savefig(out_path)
        plt.close(figure)
        print(f"Diagram successfully saved to '{out_path}'")


if __name__ == "__main__":
    # Running the file as a script writes all three schematic PDFs
    # (relative file names, so into the current working directory).
    generate_schematic_plots()

# %%
