#!/usr/bin/env python3
"""
Create publication-ready tightness analysis figures

Generates:
1. Ratio distribution histogram (main figure)
2. Summary table (LaTeX)
3. Optional: Log-scale scatter plot (appendix)
"""

import json
import numpy as np
import argparse
from pathlib import Path
from typing import Dict, List
import matplotlib.pyplot as plt
import seaborn as sns

def load_json(json_path: str) -> Dict:
    """Load and return the parsed contents of a JSON file.

    Args:
        json_path: Path to the file (a ``str`` or ``pathlib.Path`` both work).

    Returns:
        The decoded JSON document; callers in this script treat it as a dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification (RFC 8259); don't depend on the
    # platform's locale encoding, which differs e.g. on Windows.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_pseudo_radius_data(directory: str, sigma: float) -> Dict[int, Dict]:
    """Load pseudo-true radius results for one noise level.

    Reads every ``pseudo_radius_sigma{sigma}*.json`` file in *directory* and
    collects the entries of each file's ``results`` list, keyed by
    ``sample_idx``. Files are processed in sorted name order, so when the same
    ``sample_idx`` appears in several files, the entry from the last file wins.

    Args:
        directory: Directory containing the pseudo-radius JSON files.
        sigma: Noise level; its ``str()`` form is used in the filename prefix.

    Returns:
        Mapping ``sample_idx -> {'R_true_raw', 'hit_R_max', 'test_dataset_idx'}``.
        Entries lacking ``sample_idx`` or ``R_true_raw`` are skipped. Files that
        cannot be read or parsed are reported with a warning and skipped.
    """
    directory = Path(directory)
    prefix = f"pseudo_radius_sigma{sigma}"
    # The glob alone lets e.g. sigma=0.1 also match "...sigma0.125.json";
    # require the character right after the sigma value to be a non-digit.
    files = sorted(
        path for path in directory.glob(f"{prefix}*.json")
        if not path.name[len(prefix):len(prefix) + 1].isdigit()
    )
    if not files:
        print(f"⚠ Warning: No files matching {prefix}*.json found in {directory}")

    pseudo_data = {}
    for file in files:
        try:
            data = load_json(file)
            n_missing_radius = 0
            for result in data.get('results', []):
                sample_idx = result.get('sample_idx')
                if sample_idx is None:
                    continue
                r_true_raw = result.get('R_true_raw')
                if r_true_raw is None:
                    # No ratio can be formed without a pseudo-true radius, and a
                    # None here would poison the numeric arrays built downstream.
                    n_missing_radius += 1
                    continue
                # 'info' may be absent or explicitly null in the JSON.
                info = result.get('info') or {}
                pseudo_data[sample_idx] = {
                    'R_true_raw': r_true_raw,
                    'hit_R_max': info.get('hit_R_max', False),
                    'test_dataset_idx': result.get('test_dataset_idx'),
                }
            if n_missing_radius:
                print(f"⚠ Warning: Skipped {n_missing_radius} results without "
                      f"R_true_raw in {file}")
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # Best-effort loading: an unreadable/malformed file is reported and
            # skipped rather than aborting the whole analysis.
            print(f"⚠ Warning: Failed to load {file}: {e}")

    return pseudo_data

def load_certified_radius_data(file: str, method: str = 'with_gradient') -> Dict[int, float]:
    """Read per-sample certified radii for one certification method.

    Each entry of the file's ``results`` list contributes
    ``sample_idx -> float(entry['radius_<method>'])``; entries lacking either
    field (or holding null) are ignored. A later entry with the same
    ``sample_idx`` overwrites an earlier one.

    Args:
        file: Path to the certified-radius JSON file.
        method: Suffix of the radius field to read, e.g. ``'with_gradient'``
            reads ``radius_with_gradient``.

    Returns:
        Mapping from sample index to certified radius.
    """
    # Every method, including the two the CLI exposes ('with_gradient',
    # 'variance_mean'), stores its radius under "radius_" + method.
    radius_key = f'radius_{method}'

    radii = {}
    for entry in load_json(file).get('results', []):
        idx = entry.get('sample_idx')
        radius = entry.get(radius_key)
        if idx is None or radius is None:
            continue
        radii[idx] = float(radius)
    return radii

def create_ratio_distribution_plot(
    pseudo_data: Dict[int, Dict],
    certified_radii: Dict[int, float],
    sigma: float,
    method: str,
    output_file: Path,
    exclude_capped: bool = True
) -> Dict:
    """Create the ratio distribution histogram (main figure).

    Matches the samples present in both inputs, computes the per-sample ratio
    ``R_true_raw / certified radius`` and saves a histogram annotated with the
    mean, the median and a ratio = 1.0 reference line.

    Args:
        pseudo_data: ``sample_idx -> {'R_true_raw', 'hit_R_max', ...}``.
        certified_radii: ``sample_idx -> certified radius``.
        sigma: Noise level, shown in the title.
        method: Certification method name, shown in the title.
        output_file: Path the figure is written to (300 dpi).
        exclude_capped: If True, drop samples flagged ``hit_R_max`` before
            plotting and computing the statistics.

    Returns:
        Dict with the plotted ``ratios`` (list), their ``mean`` and ``median``,
        ``n_samples`` plotted and ``n_excluded`` (capped samples dropped).

    Raises:
        ValueError: If no sample is left to plot.
    """
    common_indices = sorted(set(pseudo_data.keys()) & set(certified_radii.keys()))

    # Explicit dtypes: without them an empty selection becomes a float array
    # (on which ``~`` raises) and a null value yields an object array.
    pseudo_raw = np.array(
        [pseudo_data[idx]['R_true_raw'] for idx in common_indices], dtype=float)
    certified = np.array(
        [certified_radii[idx] for idx in common_indices], dtype=float)
    hit_cap = np.array(
        [pseudo_data[idx]['hit_R_max'] for idx in common_indices], dtype=bool)

    ratios = pseudo_raw / certified

    if exclude_capped:
        ratios_plot = ratios[~hit_cap]
        n_excluded = int(np.sum(hit_cap))
    else:
        ratios_plot = ratios
        n_excluded = 0

    if ratios_plot.size == 0:
        # A histogram of nothing (with NaN mean/median) is not a usable figure.
        raise ValueError(
            f"No samples to plot for {method} (sigma = {sigma}): "
            f"{len(common_indices)} matched, {n_excluded} excluded as capped."
        )

    mean_ratio = float(np.mean(ratios_plot))
    median_ratio = float(np.median(ratios_plot))

    sns.set_style("whitegrid")
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        ax.hist(ratios_plot, bins=25, alpha=0.7, edgecolor='black',
                color='steelblue', linewidth=1.2)

        ax.axvline(mean_ratio, color='red', linestyle='--', linewidth=2,
                   label=f'Mean: {mean_ratio:.2f}×')
        ax.axvline(median_ratio, color='green', linestyle='--', linewidth=2,
                   label=f'Median: {median_ratio:.2f}×')
        ax.axvline(1.0, color='black', linestyle=':', linewidth=1.5, alpha=0.5,
                   label='Ratio = 1.0')

        ax.set_xlabel('Ratio (Pseudo-True / Certified)', fontsize=13, fontweight='bold')
        ax.set_ylabel('Frequency', fontsize=13, fontweight='bold')
        ax.set_title(f'Tightness Analysis: {method} (σ = {sigma})', fontsize=14, fontweight='bold')
        ax.legend(fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, axis='y')

        # Text box with the sample count (and how many capped were dropped).
        stats_text = f"n = {len(ratios_plot)} samples"
        if n_excluded > 0:
            stats_text += f"\n{n_excluded} capped (excluded)"
        ax.text(0.98, 0.98, stats_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.5', facecolor='white', alpha=0.8,
                          edgecolor='gray', linewidth=0.5))

        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Release this figure even if drawing or saving fails.
        plt.close(fig)
    print(f"✓ Saved ratio distribution plot: {output_file}")

    return {
        'ratios': ratios_plot.tolist(),
        'mean': mean_ratio,
        'median': median_ratio,
        'n_samples': len(ratios_plot),
        'n_excluded': n_excluded
    }

# LaTeX special characters and their escaped forms, for free-text table cells.
_LATEX_SPECIAL_CHARS = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}


def _escape_latex(text: str) -> str:
    """Return *text* with LaTeX special characters escaped."""
    return ''.join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in text)


def create_summary_table(
    analyses: List[Dict],
    output_file: Path,
    r_max: float = 5.0
):
    """Write a LaTeX table with one row of summary statistics per analysis.

    The table uses ``\\toprule``/``\\midrule``/``\\bottomrule``, so the
    including document needs the ``booktabs`` package.

    Args:
        analyses: Dicts with keys ``method``, ``sigma``, ``mean_certified``,
            ``mean_pseudo_uncapped``, ``mean_ratio_uncapped``,
            ``median_ratio_uncapped`` and ``pct_capped``.
        output_file: Destination ``.tex`` path (overwritten, UTF-8).
        r_max: Pseudo-radius cap quoted in the caption (default 5.0).
    """
    table_lines = [
        "% Table: Tightness Analysis Summary",
        "% Auto-generated from tightness analysis",
        "",
        "\\begin{table}[t]",
        "    \\centering",
        "    \\caption{Tightness analysis comparing certified radii with pseudo-true radii. ",
        f"    Ratios are computed for uncapped samples only (samples that hit R\\textsubscript{{max}}={r_max} are excluded).}}",
        "    \\label{tab:tightness_analysis}",
        "    \\begin{tabular}{lcccccc}",
        "        \\toprule",
        "        Method & $\\sigma$ & Mean Cert. & Mean Pseudo & Mean Ratio & Median Ratio & \\% Capped \\\\",
        "        \\midrule"
    ]

    for analysis in analyses:
        method = analysis['method']
        sigma = analysis['sigma']
        mean_cert = analysis['mean_certified']
        mean_pseudo = analysis['mean_pseudo_uncapped']
        mean_ratio = analysis['mean_ratio_uncapped']
        median_ratio = analysis['median_ratio_uncapped']
        pct_capped = analysis['pct_capped']

        # Known methods get their paper notation; anything else falls back to
        # a title-cased name, escaped so it cannot break the LaTeX source.
        if 'gradient' in method:
            method_label = '$(E, C, G) + M$'
        elif 'variance' in method:
            method_label = '$(E, C) + M$'
        else:
            method_label = _escape_latex(method.replace('_', ' ').title())

        table_lines.append(
            f"        {method_label} & {sigma} & {mean_cert:.3f} & {mean_pseudo:.3f} & "
            f"{mean_ratio:.2f}$\\times$ & {median_ratio:.2f}$\\times$ & {pct_capped:.1f}\\% \\\\"
        )

    table_lines.extend([
        "        \\bottomrule",
        "    \\end{tabular}",
        "\\end{table}"
    ])

    with open(output_file, 'w', encoding='utf-8') as f:
        # Trailing newline so the file is a well-formed text file for \input.
        f.write('\n'.join(table_lines) + '\n')

    print(f"✓ Saved summary table: {output_file}")

def main():
    """Command-line entry point.

    Loads pseudo-true radii (``--pseudo_dir``) and certified radii
    (``--certified_file``) for one ``--sigma`` / ``--method``. It then writes
    three files to ``--output_dir``: the ratio histogram (PNG), a LaTeX
    summary table and a JSON statistics file. Exits with an error message if
    the two sources share no samples or if every matched sample hit the
    pseudo-radius cap.
    """
    parser = argparse.ArgumentParser(description='Create tightness analysis figures')
    parser.add_argument('--pseudo_dir', type=str, default='pseudo_radius_results')
    parser.add_argument('--certified_file', type=str, required=True)
    parser.add_argument('--sigma', type=float, required=True)
    parser.add_argument('--method', type=str, default='with_gradient',
                        choices=['with_gradient', 'variance_mean'])
    parser.add_argument('--output_dir', type=str, default='figures/tightness_analysis')

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data
    print("Loading data...")
    pseudo_data = load_pseudo_radius_data(args.pseudo_dir, args.sigma)
    certified_radii = load_certified_radius_data(args.certified_file, args.method)

    # Match samples present in both sources.
    common_indices = sorted(set(pseudo_data.keys()) & set(certified_radii.keys()))
    if not common_indices:
        raise SystemExit(
            f"✗ Error: no samples in common between {args.pseudo_dir} "
            f"(sigma={args.sigma}) and {args.certified_file}"
        )

    # Explicit dtypes keep the arrays numeric/boolean even if a value is null,
    # so ``~hit_cap`` below is a logical negation rather than a TypeError.
    pseudo_raw = np.array(
        [pseudo_data[idx]['R_true_raw'] for idx in common_indices], dtype=float)
    certified = np.array(
        [certified_radii[idx] for idx in common_indices], dtype=float)
    hit_cap = np.array(
        [pseudo_data[idx]['hit_R_max'] for idx in common_indices], dtype=bool)

    if not np.any(~hit_cap):
        raise SystemExit(
            f"✗ Error: all {len(common_indices)} matched samples hit R_max; "
            f"no uncapped ratios to analyse"
        )

    n_nonpositive = int(np.count_nonzero(certified <= 0))
    if n_nonpositive:
        print(f"⚠ Warning: {n_nonpositive} matched samples have a non-positive "
              f"certified radius; their ratios are not finite")

    ratios = pseudo_raw / certified
    ratios_uncapped = ratios[~hit_cap]

    # Create ratio distribution plot
    plot_file = output_dir / f'tightness_ratio_dist_sigma{args.sigma}_{args.method}.png'
    create_ratio_distribution_plot(
        pseudo_data, certified_radii, args.sigma, args.method, plot_file
    )

    # NOTE(review): 'mean_certified' averages over all matched samples, while
    # the pseudo/ratio statistics use uncapped samples only; confirm this is
    # intended before comparing the "Mean Cert." and "Mean Pseudo" columns.
    analysis = {
        'method': args.method,
        'sigma': args.sigma,
        'mean_certified': float(np.mean(certified)),
        'mean_pseudo_uncapped': float(np.mean(pseudo_raw[~hit_cap])),
        'mean_ratio_uncapped': float(np.mean(ratios_uncapped)),
        'median_ratio_uncapped': float(np.median(ratios_uncapped)),
        'pct_capped': float(100 * np.sum(hit_cap) / len(common_indices))
    }

    # Create table
    table_file = output_dir / f'tightness_table_sigma{args.sigma}_{args.method}.tex'
    create_summary_table([analysis], table_file)

    # Save statistics
    stats_file = output_dir / f'tightness_stats_sigma{args.sigma}_{args.method}.json'
    with open(stats_file, 'w', encoding='utf-8') as f:
        json.dump(analysis, f, indent=2)

    print("\n✓ Analysis complete!")
    print(f"  Plot: {plot_file}")
    print(f"  Table: {table_file}")
    print(f"  Stats: {stats_file}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

