"""
Visualization Script for Paper Figures

Creates publication-ready figures from experiment results.
"""

import os
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Publication style. matplotlib 3.6 renamed the bundled seaborn styles with a
# "seaborn-v0_8-" prefix, so older installs only know the legacy name; fall
# back instead of failing at import time.
for _style in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break
else:
    logger.warning("No seaborn whitegrid style available; using matplotlib defaults")

# Fonts and default figure geometry shared by every figure in this script.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'legend.fontsize': 10,
    'figure.figsize': (8, 5),
    'figure.dpi': 150,
})


def load_results(results_dir: str) -> Dict:
    """
    Load every ``*_results.json`` file found directly in ``results_dir``.

    Each file is keyed by its ``experiment_name`` field, or by the filename
    with the trailing ``_results.json`` removed when that field is absent.
    Files are read in sorted filename order so the outcome is deterministic;
    if two files share an experiment name, the later one wins (with a
    warning). Unreadable files, invalid JSON, and JSON that is not an object
    are skipped with a warning instead of aborting the whole run.

    Call this once all experiments have finished writing their result files.

    Args:
        results_dir: Directory to scan (not recursive).

    Returns:
        Mapping from experiment name to the parsed JSON object.

    Raises:
        OSError: if ``results_dir`` itself cannot be listed.
    """
    suffix = '_results.json'
    results = {}

    for filename in sorted(os.listdir(results_dir)):
        if not filename.endswith(suffix):
            continue

        filepath = os.path.join(results_dir, filename)
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning("Skipping unreadable results file %s: %s", filepath, err)
            continue

        if not isinstance(data, dict):
            logger.warning("Skipping %s: expected a JSON object, got %s",
                           filepath, type(data).__name__)
            continue

        # Strip only the trailing suffix (str.replace would remove every occurrence).
        exp_name = data.get('experiment_name', filename[:-len(suffix)])
        if exp_name in results:
            logger.warning("Duplicate experiment name %r in %s; overriding earlier result",
                           exp_name, filepath)
        results[exp_name] = data

    return results


def plot_experiment1_temperature_comparison(results: Dict, output_dir: str):
    """
    Plot Experiment 1 (temperature matching): hyperfitting ≠ temperature scaling.

    Left panel: mean type-token ratio (TTR, with std error bars) for the
    original model (greedy), the original model at the matched temperature,
    and the hyperfitted model (greedy). Right panel: the prediction entropy
    reported for the same three configurations.

    Writes ``figure1_temperature_comparison.pdf`` and ``.png`` to
    ``output_dir``. Logs a warning and returns without plotting when the
    required results or fields are missing.

    Args:
        results: Mapping produced by ``load_results``; the
            ``'temperature_matching'`` entry is read.
        output_dir: Existing directory to write the figure files into.
    """
    exp1 = results.get('temperature_matching', {})

    if not exp1 or 'results' not in exp1:
        logger.warning("No Experiment 1 results found")
        return

    exp1_results = exp1['results']
    comparison = exp1_results.get('generation_comparison', {}).get('aggregated', {})

    if not comparison:
        logger.warning("No generation comparison data found")
        return

    conditions = ['original_greedy', 'original_matched_temp', 'hyperfitted_greedy']
    try:
        models = [
            'Original\n(greedy)',
            f'Original\n(T={exp1_results["matched_temperature"]:.2f})',
            'Hyperfitted\n(greedy)',
        ]
        ttrs = [comparison[c]['mean_ttr'] for c in conditions]
        stds = [comparison[c]['std_ttr'] for c in conditions]
        entropies = [
            exp1_results['original_entropy_default'],
            exp1_results['original_entropy_matched'],
            exp1_results['hyperfitted_entropy'],
        ]
    except (KeyError, TypeError, ValueError) as err:
        logger.warning("Experiment 1 results are incomplete (%r); skipping figure", err)
        return

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot 1: TTR comparison
    ax1 = axes[0]
    bars = ax1.bar(models, ttrs, yerr=stds, capsize=5, color=colors, alpha=0.8)
    ax1.set_ylabel('Type-Token Ratio (TTR)')
    ax1.set_title('Generation Quality Comparison\n(Higher TTR = Less Repetition)')
    # TTR is in [0, 1], but mean + std can exceed 1; leave room for the labels.
    ax1.set_ylim(0, max(1.0, max(t + s for t, s in zip(ttrs, stds)) + 0.08))

    # Put each value label above its error bar so the two do not overlap.
    for bar, ttr, std in zip(bars, ttrs, stds):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + std + 0.02,
                 f'{ttr:.3f}', ha='center', va='bottom', fontsize=10)

    # Plot 2: Entropy comparison
    ax2 = axes[1]
    bars2 = ax2.bar(models, entropies, color=colors, alpha=0.8)
    ax2.set_ylabel('Prediction Entropy')
    ax2.set_title('Distribution Entropy\n(Matched between Original+T and Hyperfitted)')

    for bar, ent in zip(bars2, entropies):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.05,
                 f'{ent:.2f}', ha='center', va='bottom', fontsize=10)

    fig.tight_layout()

    output_path = os.path.join(output_dir, 'figure1_temperature_comparison.pdf')
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    # Swap only the extension; str.replace would also rewrite '.pdf' in directory names.
    fig.savefig(os.path.splitext(output_path)[0] + '.png', bbox_inches='tight', dpi=300)
    logger.info(f"Saved: {output_path}")

    # Close this specific figure rather than whatever pyplot considers "current".
    plt.close(fig)


def plot_experiment2_rank_analysis(results: Dict, output_dir: str):
    """
    Plot Experiment 2 (rank analysis): how token rankings shift after hyperfitting.

    Left panel: the rate at which the hyperfitted model's top-1 token appears
    in the original model's top-K, for K in {1, 5, 10, 50, 100}. K=1 is the
    plain top-1 agreement rate. Missing rates are plotted as 0. Right panel:
    up to 15 entries of ``promoted_tokens`` with their ``rank_improvement``.

    Writes ``figure2_rank_analysis.pdf`` and ``.png`` to ``output_dir``.
    Logs a warning and returns when the required results are missing.

    Args:
        results: Mapping produced by ``load_results``; the ``'rank_analysis'``
            entry is read.
        output_dir: Existing directory to write the figure files into.
    """
    exp2 = results.get('rank_analysis', {})

    if not exp2 or 'results' not in exp2:
        logger.warning("No Experiment 2 results found")
        return

    exp2_results = exp2['results']
    top1_comp = exp2_results.get('top1_comparison', {})

    if not top1_comp:
        logger.warning("No top-1 comparison data found")
        return

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot 1: Top-1 in Original Top-K
    ax1 = axes[0]
    k_values = [1, 5, 10, 50, 100]
    rate_keys = ['top1_agreement'] + [f'hyper_top1_in_orig_top{k}' for k in k_values[1:]]
    rates = [top1_comp.get(key, 0) for key in rate_keys]

    ax1.bar([str(k) for k in k_values], rates, color='steelblue', alpha=0.8)
    ax1.set_xlabel('Original Model Top-K')
    ax1.set_ylabel('Rate')
    ax1.set_title("Hyperfitted Model's Top-1 Token\nAppearance in Original Model's Top-K")
    ax1.set_ylim(0, 1)

    # Categorical bars sit at x = 0, 1, 2, ...
    for i, rate in enumerate(rates):
        ax1.text(i, rate + 0.02, f'{rate:.3f}', ha='center', va='bottom', fontsize=10)

    ax1.axhline(y=0.8, color='red', linestyle='--', alpha=0.5, label='80% threshold')
    ax1.legend()

    # Plot 2: Top Promoted Tokens
    ax2 = axes[1]
    promoted = exp2_results.get('promoted_tokens', [])[:15]

    if promoted:
        # Escape line breaks/tabs so whitespace tokens stay on a single tick label.
        tokens = [
            str(p['token_str']).replace('\n', '\\n').replace('\t', '\\t')[:10]
            for p in promoted
        ]
        improvements = [p['rank_improvement'] for p in promoted]

        y_pos = np.arange(len(tokens))
        ax2.barh(y_pos, improvements, color='forestgreen', alpha=0.8)
        ax2.set_yticks(y_pos)
        ax2.set_yticklabels(tokens)
        ax2.set_xlabel('Rank Improvement')
        ax2.set_title('Most Promoted Tokens\n(Moved up in ranking after hyperfitting)')
        ax2.invert_yaxis()  # largest promotion at the top
    else:
        # Make the missing panel explicit instead of leaving an empty grid.
        ax2.set_axis_off()
        ax2.text(0.5, 0.5, 'No promoted-token data', ha='center', va='center',
                 transform=ax2.transAxes)

    fig.tight_layout()

    output_path = os.path.join(output_dir, 'figure2_rank_analysis.pdf')
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    # Swap only the extension; str.replace would also rewrite '.pdf' in directory names.
    fig.savefig(os.path.splitext(output_path)[0] + '.png', bbox_inches='tight', dpi=300)
    logger.info(f"Saved: {output_path}")

    plt.close(fig)


def plot_experiment3_synthetic(results: Dict, output_dir: str):
    """
    Plot Experiment 3 (synthetic hyperfitting): TTR versus correction scale.

    Draws the mean TTR for each entry of ``synthetic_by_scale``, sorted by
    numeric scale. Dashed horizontal lines show the original and hyperfitted
    baselines when they are present.

    Writes ``figure3_synthetic.pdf`` and ``.png`` to ``output_dir``. Logs a
    warning and returns when the required results are missing or malformed.

    Args:
        results: Mapping produced by ``load_results``; the
            ``'synthetic_hyperfitting'`` entry is read.
        output_dir: Existing directory to write the figure files into.
    """
    exp3 = results.get('synthetic_hyperfitting', {})

    if not exp3 or 'results' not in exp3:
        logger.warning("No Experiment 3 results found")
        return

    exp3_results = exp3['results']
    baselines = exp3_results.get('baselines', {})
    synthetic = exp3_results.get('synthetic_by_scale', {})

    if not synthetic:
        logger.warning("No synthetic results found")
        return

    # JSON object keys are strings whose spelling need not equal str(float(key))
    # (e.g. "1" vs "1.0", "1e-3" vs "0.001"). Pair each parsed scale with its
    # own entry rather than reconstructing the key from the float.
    try:
        points = sorted((float(key), stats['mean_ttr']) for key, stats in synthetic.items())
    except (KeyError, TypeError, ValueError) as err:
        logger.warning("Malformed synthetic results (%r); skipping figure", err)
        return
    scales = [scale for scale, _ in points]
    ttrs = [ttr for _, ttr in points]

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(scales, ttrs, 'o-', color='purple', linewidth=2, markersize=8, label='Synthetic Correction')

    # Baseline reference lines; skip a missing baseline instead of drawing a
    # misleading line at TTR = 0.
    for key, color, name in (('original', 'blue', 'Original Model'),
                             ('hyperfitted', 'green', 'Hyperfitted Model')):
        baseline_ttr = baselines.get(key, {}).get('mean_ttr')
        if baseline_ttr is None:
            logger.warning("No %s baseline TTR found for Experiment 3", key)
            continue
        ax.axhline(y=baseline_ttr, color=color, linestyle='--', linewidth=2,
                   label=f'{name} ({baseline_ttr:.3f})')

    ax.set_xlabel('Correction Scale')
    ax.set_ylabel('Type-Token Ratio (TTR)')
    ax.set_title('Synthetic Hyperfitting: Can Rank Corrections Replicate the Effect?')
    ax.legend()

    # A log axis cannot show scale <= 0. Use symlog (linear around zero) in
    # that case so those points stay visible.
    nonzero = [abs(s) for s in scales if s != 0]
    if all(s > 0 for s in scales):
        ax.set_xscale('log')
    elif nonzero:
        ax.set_xscale('symlog', linthresh=min(nonzero))

    fig.tight_layout()

    output_path = os.path.join(output_dir, 'figure3_synthetic.pdf')
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    # Swap only the extension; str.replace would also rewrite '.pdf' in directory names.
    fig.savefig(os.path.splitext(output_path)[0] + '.png', bbox_inches='tight', dpi=300)
    logger.info(f"Saved: {output_path}")

    plt.close(fig)


def plot_experiment4_layers(results: Dict, output_dir: str):
    """
    Plot Experiment 4: layer-wise comparison of hidden states.

    Draws a 2x2 grid, one panel per metric, from the entries of
    ``layer_wise_analysis``:
    - mean cosine similarity
    - mean L2 distance
    - mean effective-dimension change
    - mean norm change

    Writes ``figure4_layer_analysis.pdf`` and ``.png`` to ``output_dir``.
    Logs a warning and returns when the required results are missing or an
    entry lacks a metric.

    Args:
        results: Mapping produced by ``load_results``; the
            ``'representation_analysis'`` entry is read.
        output_dir: Existing directory to write the figure files into.
    """
    exp4 = results.get('representation_analysis', {})

    if not exp4 or 'results' not in exp4:
        logger.warning("No Experiment 4 results found")
        return

    exp4_results = exp4['results']
    layer_analysis = exp4_results.get('layer_wise_analysis', [])

    if not layer_analysis:
        logger.warning("No layer analysis data found")
        return

    try:
        layers = [entry['layer'] for entry in layer_analysis]
        cos_sim = [entry['mean_cosine_sim'] for entry in layer_analysis]
        l2_dist = [entry['mean_l2_dist'] for entry in layer_analysis]
        dim_change = [entry['mean_dim_change'] for entry in layer_analysis]
        norm_change = [entry['mean_norm_change'] for entry in layer_analysis]
    except (KeyError, TypeError) as err:
        logger.warning("Layer analysis entries are incomplete (%r); skipping figure", err)
        return

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Plot 1: Cosine Similarity
    ax1 = axes[0, 0]
    ax1.plot(layers, cos_sim, 'o-', color='steelblue', linewidth=2)
    ax1.set_xlabel('Layer')
    ax1.set_ylabel('Cosine Similarity')
    ax1.set_title('Hidden State Similarity\n(Lower = More Changed)')
    # Cosine similarity lies in [-1, 1]; extend below 0 rather than clip negatives.
    ax1.set_ylim(min(0.0, min(cos_sim)), 1)

    # Plot 2: L2 Distance
    ax2 = axes[0, 1]
    ax2.plot(layers, l2_dist, 'o-', color='coral', linewidth=2)
    ax2.set_xlabel('Layer')
    ax2.set_ylabel('L2 Distance')
    ax2.set_title('Hidden State Distance\n(Higher = More Changed)')

    # Plot 3: Dimension Change
    ax3 = axes[1, 0]
    ax3.bar(layers, dim_change, color='forestgreen', alpha=0.8)
    ax3.set_xlabel('Layer')
    ax3.set_ylabel('Effective Dimension Change')
    ax3.set_title('Change in Effective Dimensionality')
    ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    # Plot 4: Norm Change
    ax4 = axes[1, 1]
    ax4.bar(layers, norm_change, color='purple', alpha=0.8)
    ax4.set_xlabel('Layer')
    ax4.set_ylabel('Norm Change')
    ax4.set_title('Change in Hidden State Norm')
    ax4.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    fig.tight_layout()

    output_path = os.path.join(output_dir, 'figure4_layer_analysis.pdf')
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    # Swap only the extension; str.replace would also rewrite '.pdf' in directory names.
    fig.savefig(os.path.splitext(output_path)[0] + '.png', bbox_inches='tight', dpi=300)
    logger.info(f"Saved: {output_path}")

    plt.close(fig)


def _format_metric(value, spec: str = '.4f') -> str:
    """Format ``value`` with ``spec``; return 'N/A' for None or ``str(value)`` if not numeric."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        # e.g. the 'N/A' placeholder: format('N/A', '.4f') raises ValueError.
        return 'N/A' if value is None else str(value)


def create_summary_table(results: Dict, output_dir: str):
    """
    Write a plain-text summary of all experiments and echo it to stdout.

    Each experiment section is included only when its results are present.
    A metric missing from the results is shown as 'N/A' instead of raising.
    The text is written as UTF-8 to ``<output_dir>/summary.txt``; it
    contains non-ASCII symbols such as '≠' and '→'.

    Args:
        results: Mapping produced by ``load_results``.
        output_dir: Existing directory to write ``summary.txt`` into.
    """
    summary = []
    summary.append("=" * 80)
    summary.append("EXPERIMENT SUMMARY")
    summary.append("=" * 80)

    # Experiment 1
    exp1 = results.get('temperature_matching', {}).get('results', {})
    if exp1:
        comp = exp1.get('generation_comparison', {}).get('aggregated', {})
        if comp:
            summary.append("\n[Experiment 1: Temperature Matching]")
            summary.append(f"  Matched Temperature: {_format_metric(exp1.get('matched_temperature', 'N/A'))}")
            summary.append(f"  Original Model TTR: {_format_metric(comp.get('original_greedy', {}).get('mean_ttr', 'N/A'))}")
            summary.append(f"  Original + Matched Temp TTR: {_format_metric(comp.get('original_matched_temp', {}).get('mean_ttr', 'N/A'))}")
            summary.append(f"  Hyperfitted Model TTR: {_format_metric(comp.get('hyperfitted_greedy', {}).get('mean_ttr', 'N/A'))}")

            hyper_ttr = comp.get('hyperfitted_greedy', {}).get('mean_ttr', 0)
            matched_ttr = comp.get('original_matched_temp', {}).get('mean_ttr', 0)
            # 0.05 is the TTR gap this script treats as a meaningful difference.
            if hyper_ttr > matched_ttr + 0.05:
                summary.append("  → CONCLUSION: Hyperfitting ≠ Temperature (TTR difference significant)")

    # Experiment 2
    exp2 = results.get('rank_analysis', {}).get('results', {})
    if exp2:
        top1 = exp2.get('top1_comparison', {})
        summary.append("\n[Experiment 2: Rank Analysis]")
        summary.append(f"  Top-1 Agreement Rate: {_format_metric(top1.get('top1_agreement', 'N/A'))}")
        summary.append(f"  Hyperfitted Top-1 in Original Top-10: {_format_metric(top1.get('hyper_top1_in_orig_top10', 'N/A'))}")

        if top1.get('top1_agreement', 1) < 0.8:
            summary.append("  → CONCLUSION: Significant rank changes detected")

    # Experiment 3
    exp3 = results.get('synthetic_hyperfitting', {}).get('results', {})
    if exp3:
        baselines = exp3.get('baselines', {})
        synthetic = exp3.get('synthetic_by_scale', {})
        if baselines and synthetic:
            best_scale = max(synthetic.keys(), key=lambda s: synthetic[s].get('mean_ttr', 0))

            summary.append("\n[Experiment 3: Synthetic Hyperfitting]")
            summary.append(f"  Original TTR: {_format_metric(baselines.get('original', {}).get('mean_ttr', 'N/A'))}")
            summary.append(f"  Hyperfitted TTR: {_format_metric(baselines.get('hyperfitted', {}).get('mean_ttr', 'N/A'))}")
            summary.append(f"  Best Synthetic TTR (scale={best_scale}): {_format_metric(synthetic[best_scale].get('mean_ttr', 'N/A'))}")

    # Experiment 4
    exp4 = results.get('representation_analysis', {}).get('results', {})
    if exp4:
        most_changed = exp4.get('most_changed_layers', [])
        if most_changed:
            summary.append("\n[Experiment 4: Representation Analysis]")
            summary.append(f"  Most Changed Layer: {most_changed[0].get('layer', 'N/A')}")
            summary.append(f"  Max L2 Distance: {_format_metric(most_changed[0].get('mean_l2_dist', 'N/A'))}")

    summary.append("\n" + "=" * 80)

    # Write to file. An explicit UTF-8 encoding is needed because the text
    # contains '≠'/'→', which locale encodings such as cp1252 cannot encode.
    summary_text = "\n".join(summary)
    output_path = os.path.join(output_dir, 'summary.txt')
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(summary_text)

    print(summary_text)
    logger.info(f"Summary saved to: {output_path}")


def main():
    """
    Command-line entry point: load results, draw all figures, write the summary.

    Figures and ``summary.txt`` go to ``--output_dir``, which defaults to
    ``<results_dir>/figures`` and is created if needed. Exits with an
    argparse usage error if ``--results_dir`` is not an existing directory.
    """
    parser = argparse.ArgumentParser(description="Generate figures from experiment results")
    parser.add_argument("--results_dir", type=str, required=True, help="Directory containing results JSON files")
    parser.add_argument("--output_dir", type=str, default=None, help="Output directory for figures")

    args = parser.parse_args()

    # Fail with a clear CLI error instead of a traceback from os.listdir.
    if not os.path.isdir(args.results_dir):
        parser.error(f"--results_dir is not an existing directory: {args.results_dir}")

    if args.output_dir is None:
        args.output_dir = os.path.join(args.results_dir, 'figures')

    os.makedirs(args.output_dir, exist_ok=True)

    # Load results
    results = load_results(args.results_dir)
    if not results:
        logger.warning("No *_results.json files found in %s", args.results_dir)
    logger.info(f"Loaded results for experiments: {list(results.keys())}")

    # Generate figures (each one logs and skips itself if its data is missing)
    plot_experiment1_temperature_comparison(results, args.output_dir)
    plot_experiment2_rank_analysis(results, args.output_dir)
    plot_experiment3_synthetic(results, args.output_dir)
    plot_experiment4_layers(results, args.output_dir)

    # Create summary
    create_summary_table(results, args.output_dir)

    logger.info(f"\nAll figures saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
