#!/usr/bin/env python3
"""
Analysis of iteration depth differences in tracker records.
Analyzes how perplexity, topk values and indices change with respect to iter depth.
"""

import pandas as pd
import ast
import statistics
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
from pathlib import Path
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def load_and_parse_data(csv_path):
    """Read the tracker-records CSV and decode its list-valued columns.

    The ``topk_values`` and ``topk_indices`` cells are stored as Python
    literal strings (e.g. ``"[0.5, 0.25]"``) and are decoded with
    ``ast.literal_eval``.

    If ``csv_path`` does not exist, an error is printed to stderr and the
    process exits with status 1.
    """
    if os.path.exists(csv_path):
        frame = pd.read_csv(csv_path)
        list_columns = ('topk_values', 'topk_indices')
        for column_name in list_columns:
            frame[column_name] = frame[column_name].apply(ast.literal_eval)
        return frame

    print(f'CSV file not found: {csv_path}', file=sys.stderr)
    sys.exit(1)

def get_multi_depth_steps(df):
    """Return the step indices recorded at more than one distinct iter depth.

    The result is a plain Python list in the (sorted) key order produced by
    ``groupby('step_index')``.
    """
    distinct_depths = df.groupby('step_index')['iter_depth'].nunique()
    keep = distinct_depths.to_numpy() > 1
    return distinct_depths.index[keep].tolist()

def _pair_metrics(row_from, row_to):
    """Compute change metrics going from one iter-depth record to another.

    Both rows must carry ``perplexity`` and non-empty ``topk_values`` /
    ``topk_indices`` lists (already decoded, not strings). Position 0 of each
    list is treated as the top-1 entry. Returned keys, in order:
    perp_delta, avg_val_delta, first_val_delta, overlap, top1_overlap.
    """
    from_indices = set(row_from['topk_indices'])
    to_indices = set(row_to['topk_indices'])
    return {
        'perp_delta': row_to['perplexity'] - row_from['perplexity'],
        'avg_val_delta': statistics.mean(row_to['topk_values']) - statistics.mean(row_from['topk_values']),
        'first_val_delta': row_to['topk_values'][0] - row_from['topk_values'][0],
        # Share of the source top-k indices still present (|A ∩ B| / |A|),
        # not a Jaccard index.
        'overlap': len(from_indices & to_indices) / len(from_indices),
        'top1_overlap': 1.0 if row_from['topk_indices'][0] == row_to['topk_indices'][0] else 0.0,
    }


def compute_delta_statistics(df, steps_multi):
    """Compute per-step deltas first-to-last and between consecutive iter depths.

    For every step in ``steps_multi`` that has at least two distinct
    ``iter_depth`` values, the record at the lowest depth is compared with
    the record at the highest depth, and every pair of adjacent depths is
    compared as well. If a (step, depth) pair appears more than once, the
    row that comes first in ``df`` is used. Steps not present in ``df`` or
    with a single depth are skipped.

    Returns a 7-tuple:
        records: one dict per step with step_index, max_depth, perp_delta,
            avg_val_delta, first_val_delta, overlap, top1_overlap.
        perp_deltas, val_deltas, first_val_deltas, overlaps, top1_overlaps:
            parallel lists of the first-to-last metrics.
        all_progressions: one dict per step with the first-to-last totals
            (``*_total`` keys) and a ``progression`` list of consecutive
            deltas (from_depth, to_depth plus the metrics above).
    """
    records = []
    perp_deltas = []
    val_deltas = []
    first_val_deltas = []
    overlaps = []
    top1_overlaps = []
    all_progressions = []

    # Split the frame once instead of re-scanning it with a mask per step.
    groups = {step: group for step, group in df.groupby('step_index')}

    for step in steps_multi:
        sub = groups.get(step)
        if sub is None:
            continue
        depths = sorted(sub['iter_depth'].unique())
        if len(depths) < 2:
            continue

        # One row per depth: the first occurrence in file order.
        first_rows = sub.drop_duplicates(subset='iter_depth').set_index('iter_depth')
        row_at = {depth: first_rows.loc[depth] for depth in depths}

        first_depth, last_depth = depths[0], depths[-1]
        total = _pair_metrics(row_at[first_depth], row_at[last_depth])

        perp_deltas.append(total['perp_delta'])
        val_deltas.append(total['avg_val_delta'])
        first_val_deltas.append(total['first_val_delta'])
        overlaps.append(total['overlap'])
        top1_overlaps.append(total['top1_overlap'])

        progression = []
        for depth_curr, depth_next in zip(depths, depths[1:]):
            step_delta = {'from_depth': depth_curr, 'to_depth': depth_next}
            step_delta.update(_pair_metrics(row_at[depth_curr], row_at[depth_next]))
            progression.append(step_delta)

        all_progressions.append({
            'step_index': step,
            'max_depth': last_depth,
            'perp_delta_total': total['perp_delta'],
            'avg_val_delta_total': total['avg_val_delta'],
            'first_val_delta_total': total['first_val_delta'],
            'overlap_total': total['overlap'],
            'progression': progression,
        })

        records.append({'step_index': step, 'max_depth': last_depth, **total})

    return records, perp_deltas, val_deltas, first_val_deltas, overlaps, top1_overlaps, all_progressions

def summary_stats(data, name):
    """Summarise a numeric sequence.

    Returns ``{}`` when ``data`` is empty. Otherwise returns a dict with the
    keys name, mean, median, std (sample standard deviation, or 0 for a
    single value), min, max and count.
    """
    if data:
        return dict(
            name=name,
            mean=statistics.mean(data),
            median=statistics.median(data),
            std=statistics.stdev(data) if len(data) > 1 else 0,
            min=min(data),
            max=max(data),
            count=len(data),
        )
    return {}

def print_summary_report(df, steps_multi, perp_deltas, val_deltas, first_val_deltas, overlaps, top1_overlaps, all_progressions=None):
    """Print a text report of how metrics change across iteration depths.

    Args:
        df: Parsed tracker records; ``iter_depth`` is read for the overview.
        steps_multi: Step indices recorded at more than one iter depth.
        perp_deltas, val_deltas, first_val_deltas, overlaps, top1_overlaps:
            Parallel per-step first-to-last metric lists.
        all_progressions: Optional per-step dicts whose ``progression`` list
            holds consecutive-depth deltas (``from_depth``/``to_depth`` plus
            ``perp_delta`` and ``top1_overlap``).

    When ``perp_deltas`` is empty, only the data overview is printed.
    Without this check, the statistics below would fail on empty input.
    """
    from collections import Counter, defaultdict

    print("=" * 80)
    print("ITERATION DEPTH ANALYSIS REPORT")
    print("=" * 80)

    # Basic data info (tolist() so numpy scalars print as plain numbers)
    print("\nDATA OVERVIEW:")
    print(f"Total records: {len(df)}")
    print(f"Unique iter_depths: {sorted(df['iter_depth'].unique().tolist())}")
    print("Iter_depth distribution:")
    print(df['iter_depth'].value_counts().sort_index())
    print(f"Steps with multiple iter depths: {len(steps_multi)}")
    print(f"Computed delta pairs (first to last): {len(perp_deltas)}")

    # Analyze max depths distribution
    if all_progressions:
        depth_counts = Counter(p['max_depth'] for p in all_progressions)
        print("Max depth distribution:")
        for depth in sorted(depth_counts):
            print(f"  Max depth {depth}: {depth_counts[depth]} steps")

    if not perp_deltas:
        print("\nNo first-to-last delta pairs to analyze.")
        return

    n_pairs = len(perp_deltas)

    # Perplexity analysis
    print("\nPERPLEXITY ANALYSIS (First to Last Iter):")
    perp_stats = summary_stats(perp_deltas, "Perplexity Delta")
    print(f"  Mean delta: {perp_stats['mean']:.4f}")
    print(f"  Median delta: {perp_stats['median']:.4f}")
    print(f"  Range: [{perp_stats['min']:.3f}, {perp_stats['max']:.3f}]")

    perp_decreased = sum(1 for d in perp_deltas if d < 0)
    perp_increased = sum(1 for d in perp_deltas if d > 0)
    print(f"  Perplexity decreased: {perp_decreased}/{n_pairs} ({perp_decreased/n_pairs*100:.1f}%)")
    print(f"  Perplexity increased: {perp_increased}/{n_pairs} ({perp_increased/n_pairs*100:.1f}%)")

    # Consecutive step analysis
    if all_progressions:
        print("\nCONSECUTIVE STEP ANALYSIS:")
        # Keyed by the numeric (from_depth, to_depth) pair so transitions sort
        # numerically; sorting "a→b" strings would put "10→11" before "2→3".
        perp_by_transition = defaultdict(list)
        top1_by_transition = defaultdict(list)
        all_consecutive_perp = []
        all_consecutive_top1 = []
        for prog in all_progressions:
            for step_delta in prog['progression']:
                key = (step_delta['from_depth'], step_delta['to_depth'])
                perp_by_transition[key].append(step_delta['perp_delta'])
                top1_by_transition[key].append(step_delta['top1_overlap'])
                all_consecutive_perp.append(step_delta['perp_delta'])
                all_consecutive_top1.append(step_delta['top1_overlap'])

        if all_consecutive_perp:
            n_consec = len(all_consecutive_perp)
            print("  Overall consecutive step performance:")
            consec_decreased = sum(1 for d in all_consecutive_perp if d < 0)
            print(f"    Steps that improved: {consec_decreased}/{n_consec} ({consec_decreased/n_consec*100:.1f}%)")
            print(f"    Mean consecutive delta: {statistics.mean(all_consecutive_perp):.4f}")

            # Top-1 consistency across consecutive steps
            top1_maintained = sum(all_consecutive_top1)
            print(f"    Top-1 token maintained: {top1_maintained:.0f}/{n_consec} ({top1_maintained/n_consec*100:.1f}%)")

            print("  Performance by step type:")
            for from_d, to_d in sorted(perp_by_transition):
                step_deltas = perp_by_transition[(from_d, to_d)]
                improved = sum(1 for d in step_deltas if d < 0)
                top1_maintained_pct = statistics.mean(top1_by_transition[(from_d, to_d)]) * 100
                print(f"    {from_d}→{to_d}: {improved}/{len(step_deltas)} improved ({improved/len(step_deltas)*100:.1f}%), mean: {statistics.mean(step_deltas):.4f}, top-1 maintained: {top1_maintained_pct:.1f}%")

    # TopK values analysis
    print("\nTOPK VALUES ANALYSIS (First to Last Iter):")
    val_stats = summary_stats(val_deltas, "Avg TopK Value Delta")
    print(f"  Average topk delta mean: {val_stats['mean']:.6f}")
    print(f"  Average topk delta median: {val_stats['median']:.6f}")

    first_val_stats = summary_stats(first_val_deltas, "Top-1 Value Delta")
    print(f"  Top-1 value delta mean: {first_val_stats['mean']:.6f}")
    print(f"  Top-1 value delta median: {first_val_stats['median']:.6f}")

    first_val_decreased = sum(1 for d in first_val_deltas if d < 0)
    print(f"  Top-1 value decreased: {first_val_decreased}/{len(first_val_deltas)} ({first_val_decreased/len(first_val_deltas)*100:.1f}%)")

    # Index overlap analysis
    print("\nTOPK INDICES ANALYSIS (First to Last Iter):")
    overlap_stats = summary_stats(overlaps, "Index Overlap")
    print(f"  TopK overlap mean: {overlap_stats['mean']:.4f}")
    print(f"  TopK overlap median: {overlap_stats['median']:.4f}")
    print(f"  TopK overlap range: [{overlap_stats['min']:.2f}, {overlap_stats['max']:.2f}]")

    # Top-1 index overlap analysis
    if top1_overlaps:
        top1_kept = sum(top1_overlaps)
        print(f"  Top-1 token maintained (first to last): {top1_kept:.0f}/{len(top1_overlaps)} ({top1_kept/len(top1_overlaps)*100:.1f}%)")

    # Correlations
    if n_pairs > 1:
        # corrcoef yields nan when a series is constant; report the nan
        # without the accompanying RuntimeWarning noise.
        with np.errstate(invalid='ignore', divide='ignore'):
            corr_perp_val = np.corrcoef(perp_deltas, val_deltas)[0, 1]
            corr_perp_first = np.corrcoef(perp_deltas, first_val_deltas)[0, 1]
            corr_perp_overlap = np.corrcoef(perp_deltas, overlaps)[0, 1]

        print("\nCORRELATIONS (First to Last Iter):")
        print(f"  Perplexity delta vs Avg topk value delta: {corr_perp_val:.4f}")
        print(f"  Perplexity delta vs Top-1 value delta: {corr_perp_first:.4f}")
        print(f"  Perplexity delta vs Index overlap: {corr_perp_overlap:.4f}")

def _plot_histogram_panel(ax, data, *, bins, color, xlabel, title, mean_fmt, zero_line=True):
    """Draw a histogram of ``data`` on ``ax`` with optional zero line and mean line."""
    ax.hist(data, bins=bins, alpha=0.7, color=color, edgecolor='black')
    if zero_line:
        ax.axvline(0, color='red', linestyle='--', alpha=0.8, label='No change')
    if len(data) > 0:  # np.mean of an empty list is nan (with a RuntimeWarning)
        mean = np.mean(data)
        ax.axvline(mean, color='orange', linestyle='-', linewidth=2, label=f'Mean: {mean:{mean_fmt}}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def create_visualizations(perp_deltas, val_deltas, first_val_deltas, overlaps, top1_overlaps, output_dir, df=None):
    """Draw the 3x3 analysis figure, save it as a PNG and display it.

    Layout:
        row 0: perplexity-delta histogram, average top-k value delta
               histogram, top-1 token consistency pie.
        row 1: index-overlap histogram, perplexity vs average top-k delta
               scatter, top-1 value delta histogram.
        row 2: t-SNE panels, drawn only when ``df`` is given.

    The delta lists are expected to be per-step last-depth minus first-depth
    values (as produced by ``compute_delta_statistics``). The figure is
    written to ``<output_dir>/iter_depth_analysis.png``; the directory is
    created if missing. The figure is closed after ``plt.show()``.
    """
    plt.style.use('default')
    sns.set_palette("husl")

    fig, axes = plt.subplots(3, 3, figsize=(21, 18))
    fig.suptitle('Iteration Depth Analysis: Changes Across Iterations', fontsize=16, fontweight='bold')

    # Perplexity change between the first and last recorded depth.
    _plot_histogram_panel(axes[0, 0], perp_deltas, bins=30, color='skyblue',
                          xlabel='Perplexity Delta (Last Depth - First Depth)',
                          title='Distribution of Perplexity Changes', mean_fmt='.3f')

    _plot_histogram_panel(axes[0, 1], val_deltas, bins=30, color='lightgreen',
                          xlabel='Avg TopK Value Delta',
                          title='Distribution of Average TopK Value Changes', mean_fmt='.4f')

    # The overlap is |first ∩ last| / |first| over top-k indices (the share of
    # the first depth's top-k still present), which is not a Jaccard index.
    _plot_histogram_panel(axes[1, 0], overlaps, bins=20, color='plum',
                          xlabel='TopK Index Overlap (fraction of first-depth top-k kept)',
                          title='Distribution of TopK Index Overlaps', mean_fmt='.3f', zero_line=False)

    _plot_histogram_panel(axes[1, 2], first_val_deltas, bins=30, color='salmon',
                          xlabel='Top-1 Value Delta',
                          title='Distribution of Top-1 Value Changes', mean_fmt='.4f')

    # Top-1 token consistency (first to last depth)
    if top1_overlaps:
        top1_maintained = sum(top1_overlaps)
        top1_changed = len(top1_overlaps) - top1_maintained

        labels = ['Top-1 Maintained', 'Top-1 Changed']
        sizes = [top1_maintained, top1_changed]
        colors = ['lightgreen', 'lightcoral']

        # A pie with a single full wedge is uninformative; show text instead.
        if top1_changed > 0:
            axes[0, 2].pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
            axes[0, 2].set_title('Top-1 Token Consistency\n(First to Last Iteration)')
        else:
            axes[0, 2].text(0.5, 0.5, f'Top-1 Always Maintained\n({int(top1_maintained)}/{len(top1_overlaps)})',
                            ha='center', va='center', transform=axes[0, 2].transAxes, fontsize=12)
            axes[0, 2].set_title('Top-1 Token Consistency')
        axes[0, 2].axis('equal')

    # Scatter: Perplexity vs Average TopK Value Delta
    ax = axes[1, 1]
    ax.scatter(perp_deltas, val_deltas, alpha=0.6, color='blue', s=20)
    ax.axhline(0, color='red', linestyle='--', alpha=0.5)
    ax.axvline(0, color='red', linestyle='--', alpha=0.5)
    ax.set_xlabel('Perplexity Delta')
    ax.set_ylabel('Avg TopK Value Delta')
    ax.set_title('Perplexity vs Average TopK Value Changes')
    if len(perp_deltas) > 1:
        corr = np.corrcoef(perp_deltas, val_deltas)[0, 1]
        ax.text(0.05, 0.95, f'Correlation: {corr:.3f}', transform=ax.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
    ax.grid(True, alpha=0.3)

    # Bottom row: t-SNE views, only when the raw records are available.
    if df is not None:
        create_tsne_visualizations(df, axes)

    plt.tight_layout()

    output_path = Path(output_dir) / 'iter_depth_analysis.png'
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nVisualization saved to: {output_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def _collect_tsne_points(df, steps_multi, depth0_stride=10, depth0_limit=50):
    """Gather top-k value vectors and their iteration patterns for t-SNE.

    Returns:
        vectors: list of ``topk_values`` lists, one per point.
        patterns: list of ``(depth, max_depth)`` tuples aligned with ``vectors``.
        pairs: list of ``(first_pos, last_pos)`` positions into ``vectors``,
            linking the lowest and highest depth of each multi-depth step.

    Every depth of each step in ``steps_multi`` is included, using the first
    row per depth. Steps recorded only at depth 0 are subsampled: every
    ``depth0_stride``-th step, at most ``depth0_limit`` of them. They are
    tagged 0/0 and never appear in ``pairs``.
    """
    vectors, patterns, pairs = [], [], []

    for step in steps_multi:
        sub = df[df['step_index'] == step]
        depths = sorted(sub['iter_depth'].unique())
        if not depths:
            continue
        max_depth = depths[-1]
        first_pos = len(vectors)
        for depth in depths:
            row = sub[sub['iter_depth'] == depth].iloc[0]
            vectors.append(row['topk_values'])
            patterns.append((depth, max_depth))
        if len(depths) > 1:
            pairs.append((first_pos, len(vectors) - 1))

    only_depth0 = df.groupby('step_index')['iter_depth'].apply(lambda x: set(x) == {0})
    depth0_steps = only_depth0[only_depth0].index.tolist()
    # Subsample depth-0-only steps to avoid overcrowding the plot.
    for step in depth0_steps[::depth0_stride][:depth0_limit]:
        row = df[df['step_index'] == step].iloc[0]
        vectors.append(row['topk_values'])
        patterns.append((0, 0))

    return vectors, patterns, pairs


def create_tsne_visualizations(df, axes):
    """Fill the bottom row of ``axes`` with t-SNE views of the top-k values.

    Panels:
        ``axes[2, 0]``: all points, coloured by iteration pattern ``current/max``.
        ``axes[2, 1]``: first (lowest) vs last (highest) depth of each
            multi-depth step.
        ``axes[2, 2]``: histogram of embedding distances between those pairs.

    Returns early, leaving the row empty, when fewer than 10 points exist.
    t-SNE preserves local neighbourhoods rather than global distances, so the
    distances in the last panel are only a rough indicator.
    """
    steps_multi = get_multi_depth_steps(df)
    tsne_data, patterns, pairs = _collect_tsne_points(df, steps_multi)

    if len(tsne_data) < 10:  # Need minimum data points for t-SNE
        print("Not enough data for t-SNE visualization")
        return

    # np.array needs every record to have the same number of top-k values.
    X_scaled = StandardScaler().fit_transform(np.array(tsne_data))

    print("Computing t-SNE embedding...")
    # sklearn requires perplexity < n_samples; len // 4 keeps it well below.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(X_scaled) // 4))
    X_tsne = tsne.fit_transform(X_scaled)

    # Use one distinct colour per pattern, ordered numerically by (max, current).
    # Set3 provides 12 qualitative colours; beyond that, switch to a continuous
    # map so that patterns do not share colours.
    unique_patterns = sorted(set(patterns), key=lambda p: (p[1], p[0]))
    cmap = plt.cm.Set3 if len(unique_patterns) <= plt.cm.Set3.N else plt.cm.viridis
    palette = cmap(np.linspace(0, 1, len(unique_patterns)))
    color_of = {pattern: palette[i] for i, pattern in enumerate(unique_patterns)}
    point_colors = [color_of[p] for p in patterns]

    # Panel 1: overall t-SNE coloured by pattern
    ax = axes[2, 0]
    ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=point_colors, alpha=0.6, s=30)
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    ax.set_title('t-SNE of TopK Values by Iteration Pattern')
    for depth, max_depth in unique_patterns:
        ax.scatter([], [], color=color_of[(depth, max_depth)], label=f"{depth}/{max_depth}", s=50)
    ax.legend(title='Iter Pattern (current/max)', bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)

    if pairs:
        first_points = X_tsne[[first for first, _ in pairs]]
        last_points = X_tsne[[last for _, last in pairs]]

        # Panel 2: first vs last iterations of the same steps
        ax = axes[2, 1]
        ax.scatter(first_points[:, 0], first_points[:, 1], c='blue', alpha=0.7, s=40, label='First Iter (min depth)')
        ax.scatter(last_points[:, 0], last_points[:, 1], c='red', alpha=0.7, s=40, label='Last Iter (X/X)')
        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.set_title('t-SNE: First vs Last Iterations')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Panel 3: embedding distance between each step's first and last point
        paired_distances = np.linalg.norm(first_points - last_points, axis=1)
        ax = axes[2, 2]
        ax.hist(paired_distances, bins=20, alpha=0.7, color='purple', edgecolor='black')
        ax.set_xlabel('t-SNE Distance Between First and Last Iter')
        ax.set_ylabel('Frequency')
        ax.set_title('Distribution of Changes in TopK Space')
        ax.axvline(np.mean(paired_distances), color='orange', linestyle='-',
                   linewidth=2, label=f'Mean: {np.mean(paired_distances):.3f}')
        ax.legend()
        ax.grid(True, alpha=0.3)

        print(f"Average t-SNE distance between first and last iterations: {np.mean(paired_distances):.3f}")
        print(f"Std deviation of distances: {np.std(paired_distances):.3f}")
    else:
        axes[2, 2].text(0.5, 0.5, 'No paired data available', ha='center', va='center', transform=axes[2, 2].transAxes)
        axes[2, 2].set_title('Distribution of Changes in TopK Space')

def _print_extreme_rows(cases):
    """Print one summary line per row of a records DataFrame slice."""
    # itertuples keeps per-column dtypes; iterrows builds a single-dtype Series
    # per row, upcasting the integer step_index to float ("Step 5.0").
    for row in cases.itertuples(index=False):
        print(f"  Step {row.step_index}: Δperp={row.perp_delta:.3f}, Δavg_val={row.avg_val_delta:.3f}, overlap={row.overlap:.3f}")


def analyze_extreme_cases(records, df, n_extreme=5):
    """Print the steps with the largest perplexity improvements and degradations.

    Args:
        records: Per-step dicts containing at least ``step_index``,
            ``perp_delta``, ``avg_val_delta`` and ``overlap``.
        df: Not used; kept for call compatibility.
        n_extreme: How many cases to list in each direction.

    Prints a short note instead of failing when ``records`` is empty.
    """
    print("\nEXTREME CASES ANALYSIS:")
    if not records:
        print("  No multi-depth steps to analyze.")
        return

    records_df = pd.DataFrame(records)

    print(f"\nTop {n_extreme} perplexity improvements (largest negative deltas):")
    _print_extreme_rows(records_df.nsmallest(n_extreme, 'perp_delta'))

    print(f"\nTop {n_extreme} perplexity degradations (largest positive deltas):")
    _print_extreme_rows(records_df.nlargest(n_extreme, 'perp_delta'))

def main(csv_path='output/analysis/tracker_records.csv', output_dir='output/analysis'):
    """Run the full analysis: load, summarise, plot and save per-step results.

    Args:
        csv_path: Tracker records CSV to analyse.
        output_dir: Directory for ``iter_depth_analysis.png`` and
            ``iter_depth_analysis_results.csv``; created if it does not exist.

    Returns early with a message when no step was recorded at more than one
    iter depth, since there is then nothing to compare.
    """
    print("Loading and parsing data...")
    df = load_and_parse_data(csv_path)

    steps_multi = get_multi_depth_steps(df)
    if not steps_multi:
        print("No steps with multiple iter depths found; nothing to analyze.")
        return

    print("Computing delta statistics...")
    (records, perp_deltas, val_deltas, first_val_deltas,
     overlaps, top1_overlaps, all_progressions) = compute_delta_statistics(df, steps_multi)

    print_summary_report(df, steps_multi, perp_deltas, val_deltas, first_val_deltas,
                         overlaps, top1_overlaps, all_progressions)

    analyze_extreme_cases(records, df)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("\nCreating visualizations...")
    create_visualizations(perp_deltas, val_deltas, first_val_deltas, overlaps, top1_overlaps, output_dir, df)

    # Save detailed per-step results
    results_path = output_path / 'iter_depth_analysis_results.csv'
    pd.DataFrame(records).to_csv(results_path, index=False)
    print(f"Detailed results saved to: {results_path}")

    print("\nAnalysis complete!")

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
