"""
Gaussian Experiment: Convergence Rate vs Horizon

Plots the estimated convergence rates (beta) for mean and covariance errors
against the theoretical prediction for Gaussian token distributions.

The theoretical curve: |beta_th| = 1 / (2 * (1 + Psi)),
where Psi depends on the horizon H = ||Sigma^{1/2} A||_2.
As a function of H, the plotted curve evaluates to 1 / (2 * (1 + H^2)).
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# ============================================================
# CONFIGURATION
# ============================================================
# Figures are written to a "figures" folder located next to this script.
OUTPUT_DIR = Path(__file__).parent.joinpath("figures")

# Dimension used in the experiment (kept for reference; not read by the
# code visible in this file).
D = 1


# ============================================================
# PLOTTING FUNCTION
# ============================================================
def plot_gaussian_convergence(df: pd.DataFrame, output_dir: Path):
    """
    Create ICML-format plot of convergence rates vs horizon.

    Draws the negated mean-error and covariance-error slopes (with their
    standard deviations as error bars) against the negated ``sparsity``
    column, overlays the theoretical curve and a reference line at 0.5,
    and writes the result as ``gaussian_convergence.pdf`` and
    ``gaussian_convergence.png``.

    Args:
        df: DataFrame with columns: sparsity, slope_m_mean, slope_m_sd,
            slope_c_mean, slope_c_sd
        output_dir: Directory to save figures (created if it does not
            exist). A ``str`` is accepted as well as a ``Path``.

    Raises:
        KeyError: If ``df`` lacks any of the required columns.
    """
    # Fail early, before any figure is created, with a message that names
    # every missing column at once.
    required = (
        "sparsity",
        "slope_m_mean",
        "slope_m_sd",
        "slope_c_mean",
        "slope_c_sd",
    )
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"DataFrame is missing required column(s): {missing}")

    output_dir = Path(output_dir)
    pdf_path = output_dir / "gaussian_convergence.pdf"
    png_path = output_dir / "gaussian_convergence.png"

    # Theoretical curve. The expression below is algebraically equal to
    # 1 / (2 * (1 + x**2)), i.e. the legend's Psi corresponds to x**2 here.
    x = np.linspace(0.001, 170, 10000)
    kappa_th = 0.5 * (1 - 1 / (1 + 1 / (x**2)))

    # ICML full column width: 6.75 inches, 3:2 aspect ratio
    fig, ax = plt.subplots(figsize=(6.75, 4.5))

    try:
        # Both the x values and the slopes are plotted with their sign
        # flipped (presumably the CSV stores them as negative numbers and
        # the plot shows magnitudes -- confirm against the data producer).

        # Plot mean error slopes (first 10 points as circles)
        ax.errorbar(
            -df["sparsity"].iloc[:10],
            -df["slope_m_mean"].iloc[:10],
            yerr=df["slope_m_sd"].iloc[:10],
            fmt="o",
            markerfacecolor="purple",
            markeredgecolor="purple",
            ecolor="purple",
            label=r"$|\beta_{mean}|$ for mean error",
        )

        # Plot 11th point (mean slope) with a star; it carries no label so
        # it does not add a second legend entry.
        if len(df) > 10:
            ax.errorbar(
                -df["sparsity"].iloc[10:11],
                -df["slope_m_mean"].iloc[10:11],
                yerr=df["slope_m_sd"].iloc[10:11],
                fmt="*",
                markersize=12,
                markerfacecolor="purple",
                markeredgecolor="purple",
                ecolor="purple",
            )

        # Plot covariance error slopes (all rows)
        ax.errorbar(
            -df["sparsity"],
            -df["slope_c_mean"],
            yerr=df["slope_c_sd"],
            fmt="s",
            markerfacecolor="orange",
            markeredgecolor="orange",
            ecolor="orange",
            label=r"$|\beta_{cov}|$ for covariance error",
        )

        # Theoretical curve
        ax.plot(
            x,
            kappa_th,
            label=r"$|\beta_{th}| = \frac{1}{2(1 + \Psi)}$",
            lw=2,
            color="darkred",
        )

        # Reference line at beta = 0.5
        ax.axhline(
            0.5,
            color="grey",
            linestyle="--",
            linewidth=1.5,
            alpha=0.8,
            label=r"$|\beta|=0.5$",
        )

        # Axis settings
        ax.set_xlim(0.1, 8.5)
        ax.set_ylim(0.001, 2)

        ax.set_xlabel(r"Horizon H", fontsize=20)
        ax.set_ylabel(r"$|\beta|$", fontsize=20)
        ax.tick_params(axis="both", which="major", labelsize=12)

        ax.set_xscale("log")
        ax.set_yscale("log")

        # Legend
        ax.legend(
            loc="lower left",
            fontsize=14,
            frameon=True,
            framealpha=0.9,
        )

        fig.tight_layout()

        # Save figures
        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(pdf_path, dpi=300, bbox_inches="tight")
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        print(f"Saved: {pdf_path}")
        print(f"Saved: {png_path}")
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)


# ============================================================
# MAIN
# ============================================================
def main():
    """
    Locate the results CSV, load it, and generate the convergence plot.

    The candidate locations listed in ``candidates`` are tried in order and
    the first one that exists is used; to plot a different file, add its
    path to that list. If none exists, the expected locations and required
    columns are printed and the function returns without plotting.
    """
    banner = "=" * 60
    print(banner)
    print("Gaussian Convergence Plot (ICML Format)")
    print(banner)

    # Candidate data files, in priority order, relative to this script.
    here = Path(__file__).parent
    candidates = [
        here / "data" / "gaussian_results.csv",
        here / "gaussian_results.csv",
        here.parent.parent / "results" / "results_multi_Nexp5_Rrep50_4.csv",
    ]

    # First existing candidate wins; None when nothing is found.
    data_file = next((cand for cand in candidates if cand.exists()), None)

    if data_file is None:
        print("\nNo data file found. Expected locations:")
        for cand in candidates:
            print(f"  - {cand}")
        print("\nPlease provide a CSV with columns:")
        print("  sparsity, slope_m_mean, slope_m_sd, slope_c_mean, slope_c_sd")
        return

    print(f"\nLoading data from: {data_file}")
    results = pd.read_csv(data_file)
    print(f"Loaded {len(results)} data points")

    plot_gaussian_convergence(results, OUTPUT_DIR)
    print("\nDone!")


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
