#!/usr/bin/env python3
"""
Singular Value Analysis Visualization for ICLR Paper
Final integrated version with improved readability
"""

import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import warnings
# NOTE(review): this silences every warning process-wide (e.g. numpy
# RuntimeWarnings about empty or degenerate data); consider narrowing it.
warnings.simplefilter('ignore')

# Publication-quality defaults with large fonts for readability.
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 14,
    'axes.labelsize': 16,
    'axes.titlesize': 17,  # slightly reduced; titles are not bold
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'legend.fontsize': 15,
    'figure.titlesize': 18,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
})

# ICLR-appropriate color palette, keyed by the role each color plays in the plots.
COLORS = dict(
    lora='#2E86AB',      # deep blue: LoRA series
    seal='#A23B72',      # deep magenta: SEAL series
    diff='#F18F01',      # orange: SEAL - LoRA differences
    neutral='#7F7F7F',   # gray: reference / baseline lines
    positive='#D62728',  # red: SEAL > LoRA
    negative='#1F77B4',  # blue: LoRA > SEAL
)


def load_json_data(file_path: str) -> Dict:
    """Load singular value data from a JSON file.

    Args:
        file_path: Path (``str`` or ``pathlib.Path``) of the JSON file to read.

    Returns:
        The decoded JSON document. The rest of this script expects a mapping
        of layer name -> per-layer metrics, but that is not validated here.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _collect_method_metrics(
    method_json: Optional[Dict], max_modes: int = 32
) -> Tuple[Optional[List], Optional[List], Optional[List], Optional[List]]:
    """Extract (r_eff, S_k, p_i, layer_names) lists from one method's JSON document.

    Entries whose value is not a dict are skipped. ``S_k`` and ``p_i`` rows are
    truncated to their first ``max_modes`` entries. A metric list only gets an
    entry for layers that carry that key, so the metric lists need not be
    index-aligned with ``layer_names`` when some layers lack a metric.
    Returns four ``None`` values when ``method_json`` is ``None`` (missing input).
    """
    if method_json is None:
        return None, None, None, None

    r_eff, s_k, p_i, layer_names = [], [], [], []
    for layer_name, layer_data in method_json.items():
        if not isinstance(layer_data, dict):
            continue
        # Metrics are pre-computed upstream; only read (and truncate) them here.
        if 'r_eff' in layer_data:
            r_eff.append(layer_data['r_eff'])
        if 'S_k' in layer_data:
            s_k.append(layer_data['S_k'][:max_modes])
        if 'p_i' in layer_data:
            p_i.append(layer_data['p_i'][:max_modes])
        layer_names.append(layer_name)
    return r_eff, s_k, p_i, layer_names


def process_data_for_plots(lora_data: List[Dict], seal_data: List[Dict],
                           model_names: List[str], max_modes: int = 32) -> Dict:
    """
    Process loaded JSON data for all plots.

    Uses the pre-computed metrics stored in each JSON document (``r_eff``,
    ``S_k``, ``p_i``); nothing is recomputed from raw singular values.

    Args:
        lora_data: One decoded LoRA JSON document per model, or ``None`` if missing.
        seal_data: One decoded SEAL JSON document per model, or ``None`` if missing.
        model_names: Display name per model; same order as the data lists.
        max_modes: Number of leading ``S_k`` / ``p_i`` entries kept per layer.

    Returns:
        ``{'models': model_names, 'r_eff' | 'S_k' | 'p_i' | 'layer_names':
        {'lora': [...], 'seal': [...]}}`` where each inner list has one entry
        per model (``None`` for a missing document).

    Raises:
        ValueError: if the three input lists have different lengths.
    """
    # zip() would silently drop models; the plots index these lists per model.
    if not (len(lora_data) == len(seal_data) == len(model_names)):
        raise ValueError(
            "lora_data, seal_data and model_names must have the same length "
            f"(got {len(lora_data)}, {len(seal_data)}, {len(model_names)})"
        )

    processed = {
        'models': model_names,
        'r_eff': {'lora': [], 'seal': []},
        'S_k': {'lora': [], 'seal': []},
        'p_i': {'lora': [], 'seal': []},
        'layer_names': {'lora': [], 'seal': []}
    }

    for method, per_model_json in (('lora', lora_data), ('seal', seal_data)):
        for method_json in per_model_json:
            r_eff, s_k, p_i, layer_names = _collect_method_metrics(method_json, max_modes)
            processed['r_eff'][method].append(r_eff)
            processed['S_k'][method].append(s_k)
            processed['p_i'][method].append(p_i)
            processed['layer_names'][method].append(layer_names)

    return processed


def _violin_values(values) -> np.ndarray:
    """Return per-layer r_eff values as a 1-D float array (empty for ``None``)."""
    return np.asarray([] if values is None else values, dtype=float).ravel()


def _violin_has_spread(values: np.ndarray) -> bool:
    """Whether ``values`` can be passed to ``Axes.violinplot``.

    matplotlib's violin KDE rejects single-value samples and fails on
    zero-variance ones, so such samples are shown as points only.
    """
    return bool(values.size > 1 and values.max() > values.min())


def _draw_half_violin(ax, values: np.ndarray, color: str, side: str) -> None:
    """Draw the ``'left'`` or ``'right'`` half of a violin for ``values`` at x=0."""
    if not _violin_has_spread(values):
        return
    parts = ax.violinplot([values], positions=[0], widths=0.8,
                          showmeans=False, showmedians=False)
    for body in parts['bodies']:
        body.set_facecolor(color)
        body.set_alpha(0.6)
        # Clip the symmetric KDE outline to one side of x=0 to form a split violin.
        verts = body.get_paths()[0].vertices
        if side == 'left':
            verts[:, 0] = np.minimum(verts[:, 0], 0)
        else:
            verts[:, 0] = np.maximum(verts[:, 0], 0)


def plot_r_eff_violin(processed_data: Dict, output_path: str = 'r_eff_violin.pdf'):
    """
    Main Plot 1: split violin plot of the per-layer r_eff distribution with medians annotated.

    One panel per model. When both methods have data, LoRA is the left
    half-violin and SEAL the right half, with jittered per-layer points, solid
    median lines (value annotated) and dashed 25th/75th percentile lines. When
    only one method has data, a full violin for that method is drawn instead.

    Args:
        processed_data: Output of ``process_data_for_plots``; reads ``'models'``
            and ``'r_eff'``.
        output_path: Destination file; the format follows its extension.
    """
    models = processed_data['models']
    n_models = len(models)

    # 1x3 strip for exactly 3 models; otherwise a 2-column grid (2x2 for up to
    # 4 models, growing by rows so 5+ models still get a panel each).
    if n_models == 3:
        fig, axes = plt.subplots(1, 3, figsize=(15, 3))
    else:
        n_rows = max(2, (n_models + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows))
    axes = np.asarray(axes).flatten()

    for idx, model_name in enumerate(models):
        ax = axes[idx]
        lora_vals = _violin_values(processed_data['r_eff']['lora'][idx])
        seal_vals = _violin_values(processed_data['r_eff']['seal'][idx])

        # Missing documents (None) and documents without r_eff ([]) both mean "no data".
        if lora_vals.size == 0 and seal_vals.size == 0:
            ax.text(0.5, 0.5, f'{model_name}\n(No data)',
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            continue

        # Fresh generator per panel: jitter is reproducible and independent of
        # panel order, without reseeding numpy's global RNG.
        rng = np.random.default_rng(42)

        if lora_vals.size and seal_vals.size:
            _draw_half_violin(ax, lora_vals, COLORS['lora'], 'left')
            _draw_half_violin(ax, seal_vals, COLORS['seal'], 'right')

            ax.scatter(rng.normal(-0.15, 0.05, lora_vals.size), lora_vals,
                       alpha=0.5, s=14, c=COLORS['lora'], label='LoRA (points)')
            ax.scatter(rng.normal(0.15, 0.05, seal_vals.size), seal_vals,
                       alpha=0.5, s=14, c=COLORS['seal'], label='SEAL (points)')

            lora_median = np.median(lora_vals)
            lora_q1, lora_q3 = np.percentile(lora_vals, [25, 75])
            seal_median = np.median(seal_vals)
            seal_q1, seal_q3 = np.percentile(seal_vals, [25, 75])

            # LoRA statistics on the left half, SEAL on the right half.
            ax.hlines(lora_median, -0.4, 0, colors=COLORS['lora'], linewidth=2.5,
                      linestyles='solid', label='Median')
            ax.text(-0.42, lora_median, f'{lora_median:.2f}', va='center', ha='right',
                    fontsize=12, color=COLORS['lora'], weight='bold')
            ax.hlines([lora_q1, lora_q3], -0.35, -0.05, colors=COLORS['lora'],
                      linewidth=1, linestyles='dashed', alpha=0.5, label='25-75%')

            ax.hlines(seal_median, 0, 0.4, colors=COLORS['seal'], linewidth=2.5,
                      linestyles='solid')
            ax.text(0.42, seal_median, f'{seal_median:.2f}', va='center', ha='left',
                    fontsize=12, color=COLORS['seal'], weight='bold')
            ax.hlines([seal_q1, seal_q3], 0.05, 0.35, colors=COLORS['seal'],
                      linewidth=1, linestyles='dashed', alpha=0.5)

            # Only the first panel carries the legend.
            if idx == 0:
                ax.legend(loc='upper right', frameon=True, fancybox=True, shadow=False, fontsize=14)
        else:
            # Exactly one method has data (LoRA-only or SEAL-only): full violin.
            if lora_vals.size:
                vals, color, label = lora_vals, COLORS['lora'], 'LoRA'
            else:
                vals, color, label = seal_vals, COLORS['seal'], 'SEAL'

            if _violin_has_spread(vals):
                parts = ax.violinplot([vals], positions=[0], widths=0.4,
                                      showmeans=False, showmedians=True)
                for body in parts['bodies']:
                    body.set_facecolor(color)
                    body.set_alpha(0.6)

            ax.scatter(rng.normal(0, 0.05, vals.size), vals, alpha=0.3, s=10, c=color, label=label)

            median = np.median(vals)
            ax.text(0.22, median, f'{median:.2f}', va='center', ha='left',
                    fontsize=12, color=color, weight='bold')
            ax.legend(loc='upper right', frameon=True, fontsize=10)

        ax.set_xlim(-0.5, 0.5)
        ax.set_ylabel('Effective Rank ($r_{eff}$)')
        ax.set_title(model_name.title())  # not bold
        ax.set_xticks([])
        ax.grid(axis='y', alpha=0.3, linestyle='--')

    # Hide grid cells that have no model.
    for ax in axes[n_models:]:
        ax.set_visible(False)

    # No suptitle - the description lives in the LaTeX caption.
    fig.tight_layout()
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    print(f"Saved r_eff violin plot to {output_path}")
    plt.close(fig)


def _stack_s_k_rows(rows, max_k: int) -> np.ndarray:
    """Stack per-layer S_k rows into a 2-D float array cut to the shortest row (at most ``max_k``)."""
    width = min(max_k, min(len(row) for row in rows))
    return np.array([np.asarray(row, dtype=float)[:width] for row in rows])


def plot_delta_s_k_curves(processed_data: Dict, output_path: str = 'delta_s_k_curves.pdf'):
    """
    Main Plot 2: ΔS_k curves showing the cumulative energy difference SEAL - LoRA.

    For each model, the per-layer difference S_k(SEAL) - S_k(LoRA) is plotted as
    its mean over layers, with a 25-75% percentile ribbon. Values at
    k in {1, 4, 8, 16, 32} (those within range) are marked and annotated.

    Layers are paired by position. If the two methods report different numbers of
    layers, only the first ``min`` layers are compared (a warning is printed).
    Rows of different lengths are cut to the shortest one.

    Args:
        processed_data: Output of ``process_data_for_plots``; reads ``'models'``
            and ``'S_k'``.
        output_path: Destination file; the format follows its extension.
    """
    models = processed_data['models']
    n_models = len(models)

    # 1x3 strip for exactly 3 models; otherwise a 2-column grid (2x2 for up to
    # 4 models, growing by rows so 5+ models still get a panel each).
    if n_models == 3:
        fig, axes = plt.subplots(1, 3, figsize=(15, 3))
    else:
        n_rows = max(2, (n_models + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2, figsize=(12, 3 * n_rows))
    axes = np.asarray(axes).flatten()

    max_k = 32
    highlight_k = [1, 4, 8, 16, 32]

    for idx, model_name in enumerate(models):
        ax = axes[idx]

        lora_rows = processed_data['S_k']['lora'][idx]
        seal_rows = processed_data['S_k']['seal'][idx]

        # Per-layer ΔS_k = S_k(SEAL) - S_k(LoRA); stays None when either side has no usable data.
        delta = None
        if (lora_rows is not None and len(lora_rows) > 0
                and seal_rows is not None and len(seal_rows) > 0):
            lora_arr = _stack_s_k_rows(lora_rows, max_k)
            seal_arr = _stack_s_k_rows(seal_rows, max_k)

            if lora_arr.shape[0] != seal_arr.shape[0]:
                print(f"  Warning: {model_name} has different layer counts (LoRA: {lora_arr.shape[0]}, SEAL: {seal_arr.shape[0]})")
                min_layers = min(lora_arr.shape[0], seal_arr.shape[0])
                print(f"  Using first {min_layers} layers for comparison")
                lora_arr = lora_arr[:min_layers]
                seal_arr = seal_arr[:min_layers]

            min_k = min(lora_arr.shape[1], seal_arr.shape[1])
            if min_k > 0:
                delta = seal_arr[:, :min_k] - lora_arr[:, :min_k]

        if delta is None:
            ax.text(0.5, 0.5, f'{model_name}\n(Incomplete data)',
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_xlim(0, 32)
            ax.set_ylim(-0.1, 0.1)
            continue

        min_k = delta.shape[1]
        k_values = np.arange(1, min_k + 1)
        delta_mean = delta.mean(axis=0)
        delta_q1, delta_q3 = np.percentile(delta, [25, 75], axis=0)

        # Mean curve with 25-75% ribbon across layers.
        ax.plot(k_values, delta_mean, color=COLORS['diff'], linewidth=2.5, label='Mean ΔS$_k$')
        ax.fill_between(k_values, delta_q1, delta_q3,
                        color=COLORS['diff'], alpha=0.2, label='25-75%')

        shown_k = [k for k in highlight_k if k <= min_k]
        for k in shown_k:
            val = delta_mean[k - 1]
            ax.scatter(k, val, s=80, c=COLORS['diff'],
                       edgecolors='black', linewidth=1.5, zorder=5)

            # Offsets are in points, not data units, so label spacing does not
            # depend on the magnitude of ΔS_k.
            if k == shown_k[0]:
                offset, ha, va = (8, -8), 'center', 'top'      # first: right and below
            elif k == shown_k[-1]:
                offset, ha, va = (-6, -8), 'right', 'center'   # last: to the left
            elif k in (4, 16):
                offset, ha, va = (0, 8), 'center', 'bottom'    # alternate: above
            else:
                offset, ha, va = (0, -8), 'center', 'top'      # alternate: below
            ax.annotate(f'{val:.3f}', xy=(k, val), xytext=offset, textcoords='offset points',
                        fontsize=11, weight='bold', ha=ha, va=va, zorder=10)

        # Reference line: no difference between methods.
        ax.axhline(y=0, color=COLORS['neutral'], linewidth=2, linestyle='--', alpha=0.7, label='Baseline')

        ax.set_xlabel('k (number of modes)')
        ax.set_ylabel('ΔS$_k$')
        ax.set_title(model_name.title())  # not bold
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_xlim(0, min_k + 1)
        ax.legend(loc='upper right', frameon=True, fontsize=10)

    # Hide grid cells that have no model.
    for ax in axes[n_models:]:
        ax.set_visible(False)

    # No suptitle - the description lives in the LaTeX caption.
    fig.tight_layout()
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    print(f"Saved ΔS_k curves to {output_path}")
    plt.close(fig)


def _stack_p_i_rows(rows) -> np.ndarray:
    """Stack per-layer p_i rows into a 2-D float array cut to the shortest row."""
    width = min(len(row) for row in rows)
    return np.array([np.asarray(row, dtype=float)[:width] for row in rows])


def plot_mode_energy_lollipop(processed_data: Dict, output_path: str = 'mode_energy_lollipop.pdf'):
    """
    Supplementary Plot: lollipop chart of per-mode energy differences (replaces heatmap).

    For each model, plots Δp_i = median_layers(p_i SEAL) - median_layers(p_i LoRA)
    per mode index. Positive stems are red, negative blue, and exact ties gray.
    Layers are paired by position (truncated to the smaller count, with a
    warning). Modes are truncated to the smaller per-method mode count.

    Args:
        processed_data: Output of ``process_data_for_plots``; reads ``'models'``
            and ``'p_i'``.
        output_path: Destination file; the format follows its extension.
    """
    from matplotlib.lines import Line2D  # only needed for the custom legend

    models = processed_data['models']
    n_models = len(models)

    # 1x3 strip for exactly 3 models; otherwise a 2-column grid (2x2 for up to
    # 4 models, growing by rows so 5+ models still get a panel each).
    if n_models == 3:
        fig, axes = plt.subplots(1, 3, figsize=(15, 3.5))
    else:
        n_rows = max(2, (n_models + 1) // 2)
        fig, axes = plt.subplots(n_rows, 2, figsize=(14, 5 * n_rows))
    axes = np.asarray(axes).flatten()

    for idx, model_name in enumerate(models):
        ax = axes[idx]

        lora_rows = processed_data['p_i']['lora'][idx]
        seal_rows = processed_data['p_i']['seal'][idx]

        # Median-over-layers difference per mode; stays None without usable data on both sides.
        diff = None
        if (lora_rows is not None and len(lora_rows) > 0
                and seal_rows is not None and len(seal_rows) > 0):
            lora_arr = _stack_p_i_rows(lora_rows)
            seal_arr = _stack_p_i_rows(seal_rows)

            if lora_arr.shape[0] != seal_arr.shape[0]:
                print(f"  Warning: {model_name} has different layer counts for p_i (LoRA: {lora_arr.shape[0]}, SEAL: {seal_arr.shape[0]})")
                min_layers = min(lora_arr.shape[0], seal_arr.shape[0])
                print(f"  Using first {min_layers} layers for comparison")
                lora_arr = lora_arr[:min_layers]
                seal_arr = seal_arr[:min_layers]

            # Compare only modes both methods report (avoids a broadcast error).
            n_modes = min(lora_arr.shape[1], seal_arr.shape[1])
            if n_modes > 0:
                diff = (np.median(seal_arr[:, :n_modes], axis=0)
                        - np.median(lora_arr[:, :n_modes], axis=0))

        if diff is None:
            ax.text(0.5, 0.5, f'{model_name}\n(Incomplete data)',
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_xlim(0, 33)
            ax.set_ylim(-0.05, 0.05)
            continue

        n_modes = diff.size
        x = np.arange(1, n_modes + 1)

        # Zero baseline spanning exactly the plotted modes.
        ax.hlines(0, 0.5, n_modes + 0.5, color=COLORS['neutral'], linewidth=1.5, alpha=0.5)

        point_colors = [
            COLORS['positive'] if v > 0
            else COLORS['negative'] if v < 0
            else COLORS['neutral']  # exact ties favour neither method
            for v in diff
        ]
        ax.vlines(x, 0, diff, colors=point_colors, linewidth=2, alpha=0.7)
        ax.scatter(x, diff, c=point_colors, s=40, edgecolor='black', linewidth=0.5, zorder=5)

        # Custom legend, first panel only.
        if idx == 0:
            legend_elements = [
                Line2D([0], [0], marker='o', color='w', markerfacecolor=COLORS['positive'],
                       markersize=10, label='SEAL > LoRA'),
                Line2D([0], [0], marker='o', color='w', markerfacecolor=COLORS['negative'],
                       markersize=10, label='LoRA > SEAL')
            ]
            ax.legend(handles=legend_elements, loc='upper right', frameon=True, fontsize=14)

        ax.set_xlabel('Mode index')
        ax.set_ylabel('Δp$_i$ = p$_i$(SEAL) - p$_i$(LoRA)')
        ax.set_title(model_name.upper())  # not bold
        ax.set_xlim(0.5, n_modes + 0.5)
        ax.grid(axis='y', alpha=0.3, linestyle='--')

    # Hide grid cells that have no model.
    for ax in axes[n_models:]:
        ax.set_visible(False)

    # No suptitle - the description lives in the LaTeX caption.
    fig.tight_layout()
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
    print(f"Saved mode energy lollipop chart to {output_path}")
    plt.close(fig)


def create_summary_table(processed_data: Dict, output_path: str = 'summary_table.tex'):
    """
    Supplementary Table: Summary statistics
    """
    summary_data = []

    for idx, model_name in enumerate(processed_data['models']):
        row = {'Model': model_name.upper()}

        # r_eff statistics
        lora_r_eff = processed_data['r_eff']['lora'][idx]
        seal_r_eff = processed_data['r_eff']['seal'][idx]

        if lora_r_eff is not None:
            row['r_eff (LoRA)'] = f"{np.median(lora_r_eff):.2f}"
        else:
            row['r_eff (LoRA)'] = "—"

        if seal_r_eff is not None:
            row['r_eff (SEAL)'] = f"{np.median(seal_r_eff):.2f}"
        else:
            row['r_eff (SEAL)'] = "—"

        if lora_r_eff is not None and seal_r_eff is not None:
            row['Δr_eff'] = f"{np.median(seal_r_eff) - np.median(lora_r_eff):.2f}"
        else:
            row['Δr_eff'] = "—"

        # S_k at specific values
        for k in [1, 4, 8, 16]:
            if processed_data['S_k']['lora'][idx] is not None:
                lora_s_k = np.array(processed_data['S_k']['lora'][idx])
                if k <= lora_s_k.shape[1]:
                    row[f'S_{k} (LoRA)'] = f"{np.median(lora_s_k[:, k-1]):.3f}"
                else:
                    row[f'S_{k} (LoRA)'] = "—"
            else:
                row[f'S_{k} (LoRA)'] = "—"

            if processed_data['S_k']['seal'][idx] is not None:
                seal_s_k = np.array(processed_data['S_k']['seal'][idx])
                if k <= seal_s_k.shape[1]:
                    row[f'S_{k} (SEAL)'] = f"{np.median(seal_s_k[:, k-1]):.3f}"
                else:
                    row[f'S_{k} (SEAL)'] = "—"
            else:
                row[f'S_{k} (SEAL)'] = "—"

        summary_data.append(row)

    # Create DataFrame and save
    df = pd.DataFrame(summary_data)

    # Save as LaTeX table
    latex_table = df.to_latex(index=False, escape=False, column_format='l' + 'c' * (len(df.columns) - 1))
    with open(output_path, 'w') as f:
        f.write(latex_table)
    print(f"Saved summary table to {output_path}")

    # Also save as CSV for convenience
    csv_path = str(output_path).replace('.tex', '.csv')
    df.to_csv(csv_path, index=False)
    print(f"Saved summary table to {csv_path}")

    return df


def create_latex_captions(output_dir: Path):
    """
    Create LaTeX captions for all plots; each caption is self-contained with metric definitions.

    Writes ``latex_captions.tex`` into ``output_dir``. Each caption is preceded
    by a ``% Caption for <name>`` comment line.

    Args:
        output_dir: Existing directory to write into (``Path`` or ``str``).
    """
    captions = {
        'r_eff_violin': r"""
\caption{
    \textbf{Effective rank distribution comparison between LoRA and SEAL adapters.}
    The effective rank $r_{\text{eff}} = (\sum_{i} \sigma_i)^2 / \sum_{i} \sigma_i^2$ quantifies the dimensionality of the adaptation,
    where $\sigma_i$ are singular values of the adapter weight matrices.
    Split violin plots show the distribution across all layers: left (blue) for LoRA, right (magenta) for SEAL.
    Thick lines indicate median values with numerical annotations, dashed lines show 25-75\% range, and jittered points represent individual layers.
    Lower $r_{\text{eff}}$ indicates stronger spectral concentration, suggesting more efficient low-rank representation.
}
\label{fig:r_eff_violin}
""",
        'delta_s_k_curves': r"""
\caption{
    \textbf{Cumulative spectral energy difference between adaptation methods.}
    We define the cumulative energy $S_k = \sum_{i=1}^k \sigma_i^2 / \sum_{i=1}^r \sigma_i^2$, representing the fraction of total spectral energy
    captured by the top-$k$ singular modes.
    The difference $\Delta S_k = S_k(\text{SEAL}) - S_k(\text{LoRA})$ quantifies relative energy concentration.
    Positive values indicate SEAL concentrates more energy in top modes.
    Solid lines show mean across layers, shaded regions indicate 25-75\% range.
    Highlighted points at $k \in \{1, 4, 8, 16, 32\}$ provide key reference values.
}
\label{fig:delta_s_k}
""",
        'mode_energy_lollipop': r"""
\caption{
    \textbf{Per-mode normalized energy distribution differences.}
    The normalized mode energy $p_i = \sigma_i^2 / \sum_{j=1}^r \sigma_j^2$ represents the fractional energy contribution of mode $i$.
    Lollipop charts show $\Delta p_i = p_i(\text{SEAL}) - p_i(\text{LoRA})$, the median difference across layers.
    Red points indicate higher SEAL concentration, blue points indicate higher LoRA concentration.
    The horizontal baseline at zero facilitates comparison of magnitude and direction.
    This metric complements $S_k$ by showing mode-wise rather than cumulative differences.
}
\label{fig:mode_energy_lollipop}
""",
        'summary_table': r"""
\caption{
    \textbf{Summary statistics for spectral analysis of adapter methods.}
    All values are medians across layers.
    Effective rank: $r_{\text{eff}} = (\sum_i \sigma_i)^2 / \sum_i \sigma_i^2$.
    Cumulative energy: $S_k = \sum_{i=1}^k \sigma_i^2 / \sum_{i=1}^r \sigma_i^2$ for $k \in \{1,4,8,16\}$.
    $\Delta$ denotes SEAL minus LoRA.
    Dash (—) indicates unavailable data.
}
\label{tab:summary_stats}
"""
    }

    caption_file = Path(output_dir) / 'latex_captions.tex'
    # Explicit UTF-8: the em dash above would otherwise be written in the locale
    # encoding (e.g. cp1252 on Windows), which LaTeX's UTF-8 input cannot read.
    with open(caption_file, 'w', encoding='utf-8') as f:
        f.write("% LaTeX captions for singular value analysis plots\n")
        f.write("% Each caption is self-contained with metric definitions\n\n")
        for plot_name, caption in captions.items():
            f.write(f"% Caption for {plot_name}\n")
            f.write(caption)
            f.write("\n\n")

    print(f"Saved LaTeX captions to {caption_file}")


def main():
    """
    Command-line entry point.

    Parses arguments and loads one LoRA and one SEAL JSON file per model. A path
    of ``none`` (any case) marks that input as missing. Then writes the r_eff
    violin plot, ΔS_k curves, mode-energy lollipop chart, summary table and LaTeX
    captions into ``--output-dir``, which is created if needed.

    Raises:
        ValueError: if the numbers of --lora, --seal and --models entries differ.
    """
    parser = argparse.ArgumentParser(description='Generate singular value analysis plots for ICLR paper')
    parser.add_argument('--lora', nargs='+', required=True,
                        help='Paths to LoRA singular value JSON files')
    parser.add_argument('--seal', nargs='+', required=True,
                        help='Paths to SEAL singular value JSON files (use "none" for missing)')
    parser.add_argument('--models', nargs='+', required=True,
                        help='Model abbreviations for labeling')
    parser.add_argument('--output-dir', default='./plots',
                        help='Output directory for plots')
    parser.add_argument('--format', default='pdf', choices=['pdf', 'png', 'svg'],
                        help='Output format')
    args = parser.parse_args()

    if not (len(args.lora) == len(args.seal) == len(args.models)):
        raise ValueError("Number of LoRA files, SEAL files, and model names must match")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    def load_optional(path):
        # "none" (any case) stands in for a missing input file.
        return None if path.lower() == 'none' else load_json_data(path)

    # Load LoRA then SEAL for each model in turn.
    lora_data, seal_data = [], []
    for lora_path, seal_path in zip(args.lora, args.seal):
        lora_data.append(load_optional(lora_path))
        seal_data.append(load_optional(seal_path))

    processed_data = process_data_for_plots(lora_data, seal_data, args.models)
    ext = args.format

    print("Generating main plots...")
    plot_r_eff_violin(processed_data, output_dir / f'r_eff_violin.{ext}')
    plot_delta_s_k_curves(processed_data, output_dir / f'delta_s_k_curves.{ext}')

    print("\nGenerating supplementary plots...")
    plot_mode_energy_lollipop(processed_data, output_dir / f'mode_energy_lollipop.{ext}')

    print("\nGenerating summary table...")
    summary_df = create_summary_table(processed_data, output_dir / 'summary_table.tex')
    print("\nSummary Table:")
    print(summary_df.to_string())

    print("\nGenerating LaTeX captions...")
    create_latex_captions(output_dir)

    print(f"\nAll plots and captions saved to {output_dir}/")


if __name__ == "__main__":
    main()