"""
Linear Convergence Plots: Error vs Sample Size

Generates log-log convergence plots showing how mean and covariance errors
decrease with increasing sample size n. Each curve corresponds to a specific
(dataset, layer) configuration with different horizon H values.

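Each legend entry reports the fitted convergence rate beta in error ~ n^beta
(beta < 0 when the error decreases with n). Both axes are normalized by each
curve's first (smallest-n) point, so every curve passes through (1, 1).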
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# ============================================================
# CONFIGURATION
# ============================================================
# Repository root: four directory levels above this script. The run
# directories listed below are resolved relative to it.
BASE_DIR = Path(__file__).parent.parent.parent.parent
# Figures (PNG + PDF) are written to a "figures" folder next to this script.
OUTPUT_DIR = Path(__file__).parent / "figures"

# Points to show in convergence plots
# Each entry describes one curve:
#   path           - run directory containing hyperparameters.csv and
#                    convergence_data.csv
#   label          - legend text
#   dataset, layer - descriptive metadata (not read by the loading/plotting
#                    code shown here)
#   n_max          - rows with n_vals above this are dropped before plotting
#                    (a falsy value disables the cut)
# Optional overrides:
#   display_H          - H shown in the legend; also used to order curves
#   display_slope_mean - beta shown in the mean-error legend instead of the fit
#   display_slope_cov  - beta shown in the covariance legend instead of the fit
CONVERGENCE_POINTS = [
    {
        'path': BASE_DIR / 'results_from_cluster/run_20260126_155841/lang_german/layer0',
        'label': 'German L0',
        'dataset': 'lang_german',
        'layer': 0,
        'n_max': 3000,
    },
    {
        'path': BASE_DIR / 'results_lang_german_k500/run_20260127_121300/lang_german/layer1',
        'label': 'German L1',
        'dataset': 'lang_german',
        'layer': 1,
        'n_max': 3000,
    },
    {
        'path': BASE_DIR / 'results_lang_german/run_20260127_031701/lang_german/layer11',
        # NOTE(review): label says "L5" but the path and 'layer' are 11 --
        # confirm the label is intended.
        'label': 'German L5',
        'dataset': 'lang_german',
        'layer': 11,
        'n_max': 3000,
        # NOTE(review): hand-set legend values replace the ones computed from
        # the data; confirm they are still current for this run.
        'display_H': 19,
        'display_slope_mean': -0.47,
        'display_slope_cov': -0.47,
    },
]

# Color scheme: red -> orange -> yellow
# Colors are assigned to curves in ascending (display) H order.
CONVERGENCE_COLORS = ['#B81D24', '#D97030', '#FFAB20']


# ============================================================
# DATA LOADING
# ============================================================
def load_convergence_data(points=None, colors=None):
    """Load convergence data and hyperparameters for each point.

    Args:
        points: Point configurations to load; defaults to CONVERGENCE_POINTS.
            Each needs a 'path' directory containing hyperparameters.csv
            (columns 'H', 'k') and convergence_data.csv (column 'n_vals' plus
            the error columns). An optional truthy 'n_max' drops rows with
            larger n_vals; an optional 'display_H' overrides H for ordering.
        colors: Colors assigned in ascending-H order; defaults to
            CONVERGENCE_COLORS. Colors are reused cyclically when there are
            more points than colors.

    Returns:
        A new list of new dicts (shallow copies of the inputs, so the
        configuration passed in is not mutated) with added 'H', 'k',
        'conv_data' and 'color' keys, sorted by display_H (falling back to H).

    Raises:
        ValueError: if `colors` is empty, or a point has no convergence rows
            left after applying n_max.
    """
    if points is None:
        points = CONVERGENCE_POINTS
    if colors is None:
        colors = CONVERGENCE_COLORS
    if not colors:
        raise ValueError("At least one color is required")

    loaded = []
    for spec in points:
        pt = dict(spec)
        path = Path(pt['path'])

        hp = pd.read_csv(path / 'hyperparameters.csv')
        pt['H'] = float(hp['H'].iloc[0])
        pt['k'] = int(hp['k'].iloc[0])

        conv = pd.read_csv(path / 'convergence_data.csv')
        n_max = pt.get('n_max')
        if n_max:
            conv = conv[conv['n_vals'] <= n_max]
        # An empty frame would otherwise fail later with an opaque IndexError
        # when the plot normalizes by the first row.
        if conv.empty:
            raise ValueError(
                f"No convergence data to plot in {path} (n_max={n_max!r})"
            )
        pt['conv_data'] = conv.reset_index(drop=True)
        loaded.append(pt)

    # Sort by H value (ascending); a display override takes precedence.
    loaded.sort(key=lambda p: p.get('display_H', p['H']))

    # Assign colors in sorted order, cycling if the palette is too short.
    for i, pt in enumerate(loaded):
        pt['color'] = colors[i % len(colors)]

    return loaded


# ============================================================
# PLOTTING FUNCTIONS
# ============================================================
def _normalize_curve(pt, err_col, std_col):
    """Return one point's data as (n/n0, err/err0, sem/err0) arrays.

    n0 and err0 come from the first row of ``pt['conv_data']``, so the
    returned curve starts at (1, 1). The std column is divided by
    sqrt(pt['k']) before normalization (presumably k is the number of
    repetitions the std was computed over -- confirm).

    Raises:
        ValueError: if the point has no convergence rows.
    """
    conv = pt['conv_data']
    if len(conv) == 0:
        raise ValueError(
            f"No convergence data to plot for {pt.get('label', pt.get('path'))!r}"
        )
    n_vals = conv['n_vals'].to_numpy(dtype=float)
    err_mean = conv[err_col].to_numpy(dtype=float)
    err_std = conv[std_col].to_numpy(dtype=float) / np.sqrt(pt['k'])
    n0 = n_vals[0]
    err0 = err_mean[0]
    return n_vals / n0, err_mean / err0, err_std / err0


def _fit_loglog_slope(n_normalized, err_normalized):
    """Least-squares slope of log(err) vs log(n) for a line through (1, 1).

    Both inputs are normalized to their first point, so the intercept is
    fixed at log(1) = 0 and only the slope is fitted. Returns NaN when there
    is no spread in n (e.g. a single data point) instead of dividing by zero.
    """
    log_n = np.log(n_normalized)
    log_err = np.log(err_normalized)
    denom = np.sum(log_n ** 2)
    if denom == 0:
        return float('nan')
    return float(np.sum(log_n * log_err) / denom)


def create_convergence_plot(points, err_col, std_col, ylabel, output_dir, x_max=None):
    """
    Create convergence plot for a single error type.

    For every point this draws its normalized data (scatter), the fitted
    power law n**beta through (1, 1), and a shaded band around the fit. The
    figure is saved as convergence_<mean|covariance>.png and .pdf.

    Args:
        points: List of convergence point configurations (needs 'conv_data',
            'k', 'H', 'label' and 'color', as set by load_convergence_data)
        err_col: Column name for error mean (e.g., 'mean_err_mean')
        std_col: Column name for error std (e.g., 'mean_err_std')
        ylabel: Y-axis label
        output_dir: Output directory for figures
        x_max: Upper end of the normalized-n axis and of the fitted lines.
            Defaults to the larger of 3 and the largest normalized n in the
            data, so no data point falls outside the axes.

    Raises:
        ValueError: if a point has no convergence data.
    """
    output_dir = Path(output_dir)

    # Both the legend-override key and the output file name depend only on
    # whether this is the mean-error plot; anything else counts as covariance.
    is_mean = err_col == 'mean_err_mean'
    display_slope_key = 'display_slope_mean' if is_mean else 'display_slope_cov'
    name = 'mean' if is_mean else 'covariance'

    # Normalize all curves up front so the shared x-range can cover the data.
    curves = [_normalize_curve(pt, err_col, std_col) for pt in points]
    if x_max is None:
        # Keep at least the [1, 3] range, but widen it when the data extends
        # further so no scatter point is clipped.
        data_max = max((float(np.max(n_norm)) for n_norm, _, _ in curves), default=1.0)
        x_max = max(3.0, data_max)
    n_line_normalized = np.logspace(0, np.log10(x_max), 100)

    fig, ax = plt.subplots(figsize=(7, 5))

    for pt, (n_normalized, err_normalized, err_std_normalized) in zip(points, curves):
        color = pt['color']

        # Compute slope (beta) via log-log linear regression, then the fitted line
        slope = _fit_loglog_slope(n_normalized, err_normalized)
        err_line = n_line_normalized ** slope

        # err_std/err is a first-order approximation of the std of log(err);
        # its mean, times 3, sets a constant-width band in log space.
        avg_log_std = np.mean(err_std_normalized / err_normalized) * 3
        err_upper = err_line * np.exp(avg_log_std)
        err_lower = err_line * np.exp(-avg_log_std)

        # Display values (allow override for specific points)
        display_H = pt.get('display_H', pt['H'])
        display_slope = pt.get(display_slope_key, slope)

        ax.plot(n_line_normalized, err_line, '-', color=color, linewidth=2.4,
                label=f"{pt['label']} (H={display_H:.0f}, β={display_slope:.2f})")
        ax.fill_between(n_line_normalized, err_lower, err_upper,
                        color=color, alpha=0.10, linewidth=0)
        ax.scatter(n_normalized, err_normalized, color=color, s=35, alpha=0.85,
                   edgecolors='white', linewidths=0.5, zorder=5)

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(1, x_max)
    ax.set_xlabel('Sample size n', fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, which='major', alpha=0.25)
    ax.grid(True, which='minor', alpha=0.10)

    # Operate on this figure explicitly rather than on pyplot's "current" one.
    fig.tight_layout()
    fig.savefig(output_dir / f'convergence_{name}.png', dpi=150, bbox_inches='tight')
    fig.savefig(output_dir / f'convergence_{name}.pdf', bbox_inches='tight')
    print(f"Saved: {output_dir}/convergence_{name}.png")
    plt.close(fig)


# ============================================================
# MAIN
# ============================================================
def main():
    """Script entry point: load every configured run and write both figures.

    Produces the mean-error plot first and then the covariance-error plot,
    both saved under OUTPUT_DIR.
    """
    banner = "=" * 60
    for header_line in (banner, "Generating Linear Convergence Plots", banner):
        print(header_line)

    # The figures folder may not exist on a fresh checkout.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("\nLoading convergence data...")
    points = load_convergence_data()
    print(f"Loaded {len(points)} convergence curves")

    print("\nGenerating plots...")
    plot_specs = (
        ('mean_err_mean', 'mean_err_std', 'Mean Error'),
        ('cov_err_mean', 'cov_err_std', 'Covariance Error'),
    )
    for err_col, std_col, ylabel in plot_specs:
        create_convergence_plot(points, err_col, std_col, ylabel, OUTPUT_DIR)

    print("\nDone!")

# Generate the figures only when run as a script, not when imported.
if __name__ == "__main__":
    main()
