#!/usr/bin/env python3
"""
Analysis script to compare MMD vs W1h results in TRACE framework.
This script extracts and compares the performance of MMD-based vs W1-based TRACE diagnostics.
"""

import argparse
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import roc_auc_score, average_precision_score

def load_and_prepare_data(csv_path):
    """Read the results CSV and group its column names by metric family.

    Column names are matched case-insensitively on substrings:
    ``mmd`` -> MMD group, ``w1`` (without ``mmd``) -> W1 group,
    ``trace`` -> TRACE group. A column may land in more than one group.
    Each group is echoed to stdout.

    Args:
        csv_path: Path of the CSV file passed to ``pd.read_csv``.

    Returns:
        Tuple ``(df, mmd_cols, w1_cols, trace_cols)`` where the last three
        are lists of column names in their original file order.
    """
    frame = pd.read_csv(csv_path)

    # Single pass over the header; each column is tested against all groups.
    mmd_names, w1_names, trace_names = [], [], []
    for column in frame.columns:
        lowered = column.lower()
        has_mmd = 'mmd' in lowered
        if has_mmd:
            mmd_names.append(column)
        if 'w1' in lowered and not has_mmd:
            w1_names.append(column)
        if 'trace' in lowered:
            trace_names.append(column)

    report = (
        ("MMD-related columns found:", mmd_names),
        ("\nW1-related columns found:", w1_names),
        ("\nTRACE-related columns found:", trace_names),
    )
    for heading, names in report:
        print(heading)
        for column in names:
            print(f"  - {column}")

    return frame, mmd_names, w1_names, trace_names

def compute_correlations(df, predictors, target_col='delta_R_true'):
    """Compute rank correlations between predictor columns and ``|target|``.

    Args:
        df: DataFrame holding the predictor columns and ``target_col``.
        predictors: Iterable of column names in ``df`` to evaluate. A dict is
            accepted for backward compatibility; only its keys are used and
            its values are ignored. Names absent from ``df`` are skipped.
        target_col: Name of the signed target column; its absolute value is
            the quantity the predictors are correlated against.

    Returns:
        DataFrame with one row per evaluated predictor and the columns
        ``predictor``, ``spearman_rho``, ``kendall_tau`` and ``n_valid``
        (number of rows where both predictor and target are non-NaN).
        The columns are present even when no predictor could be evaluated.

    Raises:
        KeyError: If ``target_col`` is not a column of ``df``.
    """
    target = df[target_col].abs().to_numpy(dtype=float)  # absolute risk difference

    rows = []
    for name in predictors:
        if name not in df.columns:
            continue

        # Coerce so that stray non-numeric cells become NaN instead of
        # making np.isnan fail on an object array.
        scores = pd.to_numeric(df[name], errors='coerce').to_numpy(dtype=float)

        # A row is usable only when BOTH values are present; a NaN target
        # would otherwise propagate and turn the whole correlation into NaN.
        valid = ~(np.isnan(scores) | np.isnan(target))
        n_valid = int(valid.sum())
        if n_valid < 2:
            continue

        try:
            rho = spearmanr(scores[valid], target[valid]).correlation
            tau = kendalltau(scores[valid], target[valid]).correlation
        except ValueError:
            # Keep the row but flag the statistics as unavailable.
            rho = np.nan
            tau = np.nan

        rows.append({
            'predictor': name,
            'spearman_rho': rho,
            'kendall_tau': tau,
            'n_valid': n_valid,
        })

    return pd.DataFrame(
        rows,
        columns=['predictor', 'spearman_rho', 'kendall_tau', 'n_valid'],
    )

def _plot_metric_pair(df, x_col, y_col, *, x_axis_label, y_axis_label,
                      pair_title, x_legend, y_legend, risk_xlabel,
                      risk_title, out_path):
    """Save a two-panel figure: ``x_col`` vs ``y_col``, and both vs ``|delta_R_true|``."""
    fig, (ax_pair, ax_risk) = plt.subplots(1, 2, figsize=(15, 6))
    try:
        # Left panel: direct scatter of the two metrics.
        ax_pair.scatter(df[x_col], df[y_col], alpha=0.7)
        ax_pair.set_xlabel(x_axis_label)
        ax_pair.set_ylabel(y_axis_label)
        ax_pair.set_title(pair_title)
        ax_pair.grid(True, alpha=0.3)

        # y=x reference line spanning the joint range of both metrics.
        low = min(df[x_col].min(), df[y_col].min())
        high = max(df[x_col].max(), df[y_col].max())
        ax_pair.plot([low, high], [low, high], 'r--', alpha=0.5, label='y=x')
        ax_pair.legend()

        # Right panel: each metric against the absolute true risk difference.
        target = df['delta_R_true'].abs()
        ax_risk.scatter(df[x_col], target, alpha=0.7, label=x_legend, s=50)
        ax_risk.scatter(df[y_col], target, alpha=0.7, label=y_legend, s=50)
        ax_risk.set_xlabel(risk_xlabel)
        ax_risk.set_ylabel('|ΔR| (True Risk Difference)')
        ax_risk.set_title(risk_title)
        ax_risk.legend()
        ax_risk.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if a column lookup above failed.
        plt.close(fig)


def create_comparison_plots(df, outdir):
    """Create comparison plots between MMD and W1 approaches.

    Two figures are written into ``outdir`` (created if missing), each only
    when both of its input columns exist in ``df``:

    * ``trace_mmd_vs_w1_comparison.png`` from ``trace_w1_avgpool`` and
      ``trace_mmd_avgpool``.
    * ``mmd_vs_w1_distance_comparison.png`` from ``w1_avgpool`` and
      ``mmd_avgpool``.

    Args:
        df: DataFrame with the metric columns and ``delta_R_true``.
        outdir: Directory the PNG files are saved into.

    Raises:
        KeyError: If a figure is drawn but ``delta_R_true`` is not in ``df``.
    """
    os.makedirs(outdir, exist_ok=True)

    # Set up the plotting style (global matplotlib/seaborn state).
    plt.style.use('default')
    sns.set_palette("husl")

    # 1. TRACE MMD vs TRACE W1 comparison
    if 'trace_mmd_avgpool' in df.columns and 'trace_w1_avgpool' in df.columns:
        _plot_metric_pair(
            df, 'trace_w1_avgpool', 'trace_mmd_avgpool',
            x_axis_label='TRACE W1 (avgpool)',
            y_axis_label='TRACE MMD (avgpool)',
            pair_title='TRACE MMD vs TRACE W1',
            x_legend='TRACE W1',
            y_legend='TRACE MMD',
            risk_xlabel='TRACE Score',
            risk_title='TRACE Scores vs True Risk Difference',
            out_path=os.path.join(outdir, 'trace_mmd_vs_w1_comparison.png'),
        )

    # 2. Individual MMD vs W1 distance comparison
    if 'mmd_avgpool' in df.columns and 'w1_avgpool' in df.columns:
        _plot_metric_pair(
            df, 'w1_avgpool', 'mmd_avgpool',
            x_axis_label='W1 Distance (avgpool)',
            y_axis_label='MMD Distance (avgpool)',
            pair_title='MMD vs W1 Distance',
            x_legend='W1',
            y_legend='MMD',
            risk_xlabel='Distance Metric',
            risk_title='Distance Metrics vs True Risk Difference',
            out_path=os.path.join(outdir, 'mmd_vs_w1_distance_comparison.png'),
        )

def create_detailed_analysis_table(df, outdir):
    """Create a detailed analysis table comparing MMD vs W1 approaches.

    For every known predictor column present in ``df``, computes Spearman and
    Kendall rank correlations against ``|delta_R_true|`` together with basic
    descriptive statistics. All figures use only the rows where both the
    predictor and the target are non-NaN; ``N Valid`` is that row count.

    Two CSV files are written into ``outdir`` (created if missing):
    ``mmd_vs_w1_detailed_analysis.csv`` (all columns) and
    ``mmd_vs_w1_summary.csv`` (predictor and correlations, rounded to 3
    decimals).

    Args:
        df: DataFrame with predictor columns and ``delta_R_true``.
        outdir: Directory the CSV files are saved into.

    Returns:
        DataFrame sorted by ``Spearman ρ`` in descending order. When no
        predictor can be evaluated the frame is empty but still carries
        every column, so callers can index it safely.

    Raises:
        KeyError: If ``delta_R_true`` is not a column of ``df``.
    """
    # Display label -> column name in the results CSV.
    predictors = {
        'TRACE W1 (avgpool)': 'trace_w1_avgpool',
        'TRACE MMD (avgpool)': 'trace_mmd_avgpool',
        'TRACE W1 (layer2)': 'trace_w1_layer2',
        'TRACE W1 (layer3)': 'trace_w1_layer3',
        'W1 Distance (avgpool)': 'w1_avgpool',
        'MMD Distance (avgpool)': 'mmd_avgpool',
        'W1 Distance (layer3)': 'w1_layer3',
        'MMD Distance (layer3)': 'mmd_layer3',
        'Output Discrepancy': 'outdisc_l2_mean',
        'trace Score': 'trace_score'
    }
    columns = ['Predictor', 'Column', 'Spearman ρ', 'Kendall τ',
               'Mean', 'Std', 'Min', 'Max', 'N Valid']

    os.makedirs(outdir, exist_ok=True)

    target = df['delta_R_true'].abs().to_numpy(dtype=float)
    rows = []

    for name, col in predictors.items():
        if col not in df.columns:
            continue

        # Non-numeric cells become NaN rather than breaking np.isnan.
        scores = pd.to_numeric(df[col], errors='coerce').to_numpy(dtype=float)

        # Require both values; a NaN target would make the correlation NaN.
        valid = ~(np.isnan(scores) | np.isnan(target))
        n_valid = int(valid.sum())
        if n_valid < 2:
            continue

        used_scores = scores[valid]
        used_target = target[valid]

        try:
            rho = spearmanr(used_scores, used_target).correlation
            kendall = kendalltau(used_scores, used_target).correlation
        except (ValueError, TypeError) as exc:
            # Best effort: report and move on instead of silently dropping.
            print(f"  Skipping {name} ({col}): {exc}")
            continue

        rows.append({
            'Predictor': name,
            'Column': col,
            'Spearman ρ': rho,
            'Kendall τ': kendall,
            'Mean': used_scores.mean(),
            'Std': used_scores.std(),
            'Min': used_scores.min(),
            'Max': used_scores.max(),
            'N Valid': n_valid,
        })

    # Explicit columns keep sort_values and the column selection below valid
    # even when no predictor qualified.
    results_df = pd.DataFrame(rows, columns=columns)
    results_df = results_df.sort_values('Spearman ρ', ascending=False)

    # Save to CSV
    results_path = os.path.join(outdir, 'mmd_vs_w1_detailed_analysis.csv')
    results_df.to_csv(results_path, index=False)

    # Create a summary table for the paper
    summary_results = results_df[['Predictor', 'Spearman ρ', 'Kendall τ']].copy()
    summary_results = summary_results.round(3)
    summary_path = os.path.join(outdir, 'mmd_vs_w1_summary.csv')
    summary_results.to_csv(summary_path, index=False)

    return results_df

def create_performance_comparison_plot(df, outdir):
    """Create a bar chart of Spearman ρ against ``|delta_R_true|`` per predictor.

    Predictors whose column is missing from ``df``, or that have fewer than
    two rows where both predictor and target are non-NaN, are left out. Bars
    are colored by family: labels containing ``W1`` are W1-based, otherwise
    labels containing ``MMD`` are MMD-based, everything else is ``Other``.

    The figure is saved as ``performance_comparison_mmd_vs_w1.png`` in
    ``outdir`` (created if missing). If no predictor can be evaluated, a
    message is printed and no file is written.

    Args:
        df: DataFrame with predictor columns and ``delta_R_true``.
        outdir: Directory the PNG file is saved into.

    Raises:
        KeyError: If ``delta_R_true`` is not a column of ``df``.
    """
    # Display label -> column name in the results CSV.
    predictors = {
        'TRACE W1 (avgpool)': 'trace_w1_avgpool',
        'TRACE MMD (avgpool)': 'trace_mmd_avgpool',
        'W1 Distance': 'w1_avgpool',
        'MMD Distance': 'mmd_avgpool',
        'Output Discrepancy': 'outdisc_l2_mean'
    }

    target = df['delta_R_true'].abs().to_numpy(dtype=float)
    rows = []

    for name, col in predictors.items():
        if col not in df.columns:
            continue

        # Non-numeric cells become NaN rather than breaking np.isnan.
        scores = pd.to_numeric(df[col], errors='coerce').to_numpy(dtype=float)

        # Require both values; a NaN target would make the correlation NaN.
        valid = ~(np.isnan(scores) | np.isnan(target))
        if valid.sum() < 2:
            continue

        try:
            rho = spearmanr(scores[valid], target[valid]).correlation
        except (ValueError, TypeError) as exc:
            # Best effort: report and move on instead of silently dropping.
            print(f"  Skipping {name} ({col}): {exc}")
            continue

        if 'W1' in name:
            family = 'W1-based'
        elif 'MMD' in name:
            family = 'MMD-based'
        else:
            family = 'Other'
        rows.append({'Predictor': name, 'Spearman ρ': rho, 'Type': family})

    if not rows:
        # An empty frame has no 'Type' column; bail out before plotting.
        print("No predictors with enough valid data; skipping performance comparison plot.")
        return

    os.makedirs(outdir, exist_ok=True)
    results_df = pd.DataFrame(rows)

    # Create the plot
    plt.figure(figsize=(12, 8))
    colors = {'W1-based': 'blue', 'MMD-based': 'red', 'Other': 'green'}

    # One bar() call per family so each gets a single legend entry.
    for pred_type in results_df['Type'].unique():
        data = results_df[results_df['Type'] == pred_type]
        plt.bar(data['Predictor'], data['Spearman ρ'],
                color=colors[pred_type], alpha=0.7, label=pred_type)

    plt.xlabel('Predictor')
    plt.ylabel('Spearman ρ (Correlation with |ΔR|)')
    plt.title('Performance Comparison: MMD vs W1-based TRACE Diagnostics')
    plt.xticks(rotation=45, ha='right')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, 'performance_comparison_mmd_vs_w1.png'), dpi=300, bbox_inches='tight')
    plt.close()

def main():
    """Command-line entry point: load the results CSV and run all analyses.

    Parses ``--csv`` and ``--outdir``, loads the data, writes the comparison
    plots, the detailed/summary CSV tables and the performance plot into the
    output directory, then prints a textual summary to stdout.

    Exits through ``argparse`` with an error message when the CSV lacks the
    ``delta_R_true`` column, which every analysis step depends on.
    """
    parser = argparse.ArgumentParser(description="Analyze MMD vs W1 results in TRACE framework")
    parser.add_argument("--csv", type=str, required=True, help="Path to deployment_gate_results.csv")
    parser.add_argument("--outdir", type=str, required=True, help="Output directory for analysis results")
    args = parser.parse_args()

    # Load data
    print("Loading data...")
    df, mmd_cols, w1_cols, trace_cols = load_and_prepare_data(args.csv)

    # Fail early with a readable message instead of a KeyError deep inside
    # one of the analysis helpers.
    if 'delta_R_true' not in df.columns:
        parser.error(f"required column 'delta_R_true' not found in {args.csv}")

    print(f"\nLoaded {len(df)} candidates")
    print(f"Found {len(mmd_cols)} MMD-related columns")
    print(f"Found {len(w1_cols)} W1-related columns")
    print(f"Found {len(trace_cols)} TRACE-related columns")

    # Create output directory
    os.makedirs(args.outdir, exist_ok=True)

    # Create comparison plots
    print("\nCreating comparison plots...")
    create_comparison_plots(df, args.outdir)

    # Create detailed analysis table
    print("Creating detailed analysis table...")
    results_df = create_detailed_analysis_table(df, args.outdir)

    # Create performance comparison plot
    print("Creating performance comparison plot...")
    create_performance_comparison_plot(df, args.outdir)

    # Print summary
    print("\n" + "="*60)
    print("SUMMARY OF MMD vs W1 COMPARISON")
    print("="*60)

    if results_df.empty:
        # Nothing to rank; also avoids column lookups on an empty frame.
        print("\nNo predictors with enough valid data were found.")
    else:
        print("\nTop performing predictors (by Spearman ρ):")
        for _, row in results_df.head(5).iterrows():
            print(f"  {row['Predictor']}: ρ = {row['Spearman ρ']:.3f}, τ = {row['Kendall τ']:.3f}")

        # Compare TRACE variants specifically (case-sensitive match on label)
        trace_results = results_df[results_df['Predictor'].str.contains('TRACE')]
        if len(trace_results) > 0:
            print("\nTRACE variants comparison:")
            for _, row in trace_results.iterrows():
                print(f"  {row['Predictor']}: ρ = {row['Spearman ρ']:.3f}")

        # Compare raw distance metrics
        distance_results = results_df[results_df['Predictor'].str.contains('Distance')]
        if len(distance_results) > 0:
            print("\nDistance metrics comparison:")
            for _, row in distance_results.iterrows():
                print(f"  {row['Predictor']}: ρ = {row['Spearman ρ']:.3f}")

    print(f"\nAnalysis results saved to: {args.outdir}")
    print("Files created:")
    expected_files = (
        "mmd_vs_w1_detailed_analysis.csv",
        "mmd_vs_w1_summary.csv",
        "trace_mmd_vs_w1_comparison.png",
        "mmd_vs_w1_distance_comparison.png",
        "performance_comparison_mmd_vs_w1.png",
    )
    # The plots are only written when their input columns exist, so report
    # what is actually on disk rather than a fixed list.
    for file_name in expected_files:
        if os.path.exists(os.path.join(args.outdir, file_name)):
            print(f"  - {file_name}")

if __name__ == "__main__":
    main()


