#!/usr/bin/env python3
"""
Analysis script to compare MMD vs W1h results in TRACE framework using the full results dataset.
This script analyzes the main results folder structure.
"""

import argparse
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr, kendalltau
from sklearn.metrics import roc_auc_score, average_precision_score

def load_and_prepare_data(csv_path):
    """Read the deployment gate results CSV and list what it contains.

    Prints the number of rows and a 1-based numbered list of column names,
    then returns the DataFrame exactly as ``pandas.read_csv`` produced it.

    Args:
        csv_path: Path to the results CSV file.

    Returns:
        The loaded ``pandas.DataFrame``.
    """
    frame = pd.read_csv(csv_path)

    print(f"Loaded {len(frame)} candidates")
    print("Available columns:")
    for position, column in enumerate(frame.columns, start=1):
        print(f"  {position}. {column}")

    return frame

def compute_correlations(df, predictors, target_col='delta_R_true'):
    """Compute rank correlations between predictor columns and the absolute target.

    Args:
        df: DataFrame holding the predictor columns and ``target_col``.
        predictors: Mapping (or any iterable) of column names in ``df``. Only
            the keys are used; names that are not columns of ``df`` are
            skipped.
        target_col: Column whose absolute value is the correlation target.

    Returns:
        DataFrame with one row per usable predictor and the columns
        ``predictor``, ``spearman_rho``, ``kendall_tau`` and ``n_valid``.
        The columns are present even when no predictor is usable.
    """
    columns = ['predictor', 'spearman_rho', 'kendall_tau', 'n_valid']

    target = df[target_col].abs().values  # Use absolute risk difference
    target_ok = ~pd.isna(target)

    results = []
    for name in predictors:
        if name not in df.columns:
            continue

        scores = df[name].values
        # Keep only rows where BOTH score and target are present: a single
        # NaN on either side would otherwise turn the correlation into NaN.
        mask = ~pd.isna(scores) & target_ok
        n_valid = int(mask.sum())
        if n_valid < 2:
            continue

        try:
            rho = spearmanr(scores[mask], target[mask]).correlation
            kendall = kendalltau(scores[mask], target[mask]).correlation
        except (TypeError, ValueError):
            # Non-numeric or otherwise unusable column: report it as NaN
            # rather than aborting the whole comparison.
            rho = np.nan
            kendall = np.nan

        results.append({
            'predictor': name,
            'spearman_rho': rho,
            'kendall_tau': kendall,
            'n_valid': n_valid,
        })

    return pd.DataFrame(results, columns=columns)

def create_comparison_plots(df, outdir):
    """Render two-panel scatter figures comparing W1-based quantities with MMD.

    ``outdir`` is created if needed. Up to two PNG files are written there:

    * ``trace_vs_mmd_comparison.png`` when ``df`` has both ``trace_score``
      and ``mmd_score``;
    * ``w1_vs_mmd_distance_comparison.png`` when ``df`` has both ``trace_w1``
      and ``mmd_score``.

    In each figure the left panel plots the two quantities against each other
    with a dashed y=x reference line, and the right panel plots both of them
    against the absolute value of ``delta_R_true``.

    Args:
        df: DataFrame with the score columns and ``delta_R_true``.
        outdir: Directory receiving the PNG files.
    """
    os.makedirs(outdir, exist_ok=True)

    # Plot styling applied before any figure is drawn
    plt.style.use('default')
    sns.set_palette("husl")

    # One entry per figure; the two figures differ only in the column on the
    # x side and in their labels / output file name.
    figure_specs = (
        # 1. trace (TRACE W1) vs MMD comparison
        {
            'x_col': 'trace_score',
            'left_xlabel': 'trace Score (TRACE W1)',
            'left_ylabel': 'MMD Score',
            'left_title': 'trace vs MMD Scores',
            'x_legend': 'trace (TRACE W1)',
            'mmd_legend': 'MMD',
            'right_xlabel': 'Score',
            'right_title': 'Scores vs True Risk Difference',
            'filename': 'trace_vs_mmd_comparison.png',
        },
        # 2. Component analysis: W1 distance vs MMD distance
        {
            'x_col': 'trace_w1',
            'left_xlabel': 'W1 Distance Component',
            'left_ylabel': 'MMD Distance',
            'left_title': 'W1 Distance vs MMD Distance',
            'x_legend': 'W1 Distance',
            'mmd_legend': 'MMD Distance',
            'right_xlabel': 'Distance Metric',
            'right_title': 'Distance Metrics vs True Risk Difference',
            'filename': 'w1_vs_mmd_distance_comparison.png',
        },
    )

    for spec in figure_specs:
        x_col = spec['x_col']
        if x_col not in df.columns or 'mmd_score' not in df.columns:
            continue

        x_values = df[x_col]
        mmd_values = df['mmd_score']

        _, (left, right) = plt.subplots(1, 2, figsize=(15, 6))

        # Left panel: the two quantities against each other
        left.scatter(x_values, mmd_values, alpha=0.7)
        left.set_xlabel(spec['left_xlabel'])
        left.set_ylabel(spec['left_ylabel'])
        left.set_title(spec['left_title'])
        left.grid(True, alpha=0.3)

        # Diagonal reference spanning the joint range of both quantities
        lower = min(x_values.min(), mmd_values.min())
        upper = max(x_values.max(), mmd_values.max())
        left.plot([lower, upper], [lower, upper], 'r--', alpha=0.5, label='y=x')
        left.legend()

        # Right panel: each quantity against the absolute true risk difference
        abs_risk = df['delta_R_true'].abs()
        right.scatter(x_values, abs_risk, alpha=0.7, label=spec['x_legend'], s=50)
        right.scatter(mmd_values, abs_risk, alpha=0.7, label=spec['mmd_legend'], s=50)
        right.set_xlabel(spec['right_xlabel'])
        right.set_ylabel('|ΔR| (True Risk Difference)')
        right.set_title(spec['right_title'])
        right.legend()
        right.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(outdir, spec['filename']), dpi=300, bbox_inches='tight')
        plt.close()

def create_detailed_analysis_table(df, outdir):
    """Create a detailed analysis table comparing MMD vs W1 approaches.

    For every known predictor column present in ``df``, computes Spearman and
    Kendall rank correlations against ``|delta_R_true|`` plus basic
    descriptive statistics. Rows where either the score or the target is
    missing are excluded, and all reported numbers (including ``N Valid``)
    refer to that jointly valid subset.

    Two CSV files are written into ``outdir`` (created if needed):
    ``mmd_vs_w1_detailed_analysis.csv`` with every column, and
    ``mmd_vs_w1_summary.csv`` with the correlations rounded to 3 decimals.

    Args:
        df: DataFrame with ``delta_R_true`` and any of the predictor columns.
        outdir: Directory receiving the CSV files.

    Returns:
        The detailed table sorted by descending Spearman correlation. When no
        predictor is usable the table is empty but still has all its columns.
    """
    # Define the predictors to compare (display name -> column name)
    predictors = {
        'trace Score (TRACE W1)': 'trace_score',
        'W1 Distance Component': 'trace_w1',
        'Output Discrepancy': 'trace_output_dist',
        'MMD Distance': 'mmd_score',
        'MSP Score': 'msp_score',
        'Energy Score': 'energy_score'
    }
    columns = [
        'Predictor', 'Column', 'Spearman ρ', 'Kendall τ',
        'Mean', 'Std', 'Min', 'Max', 'N Valid',
    ]

    target = df['delta_R_true'].abs().values
    target_ok = ~pd.isna(target)
    results = []

    for name, col in predictors.items():
        if col not in df.columns:
            continue

        scores = df[col].values
        # A NaN on either side would make the rank correlations NaN, so only
        # rows with both a score and a target are used.
        mask = ~pd.isna(scores) & target_ok
        n_valid = int(mask.sum())
        if n_valid < 2:
            continue

        valid_scores = scores[mask]
        valid_target = target[mask]
        try:
            row = {
                'Predictor': name,
                'Column': col,
                'Spearman ρ': spearmanr(valid_scores, valid_target).correlation,
                'Kendall τ': kendalltau(valid_scores, valid_target).correlation,
                'Mean': valid_scores.mean(),
                'Std': valid_scores.std(),
                'Min': valid_scores.min(),
                'Max': valid_scores.max(),
                'N Valid': n_valid,
            }
        except (TypeError, ValueError):
            # Column cannot be treated as numeric scores; leave it out.
            continue
        results.append(row)

    # Passing explicit columns keeps the sort and the column selection below
    # valid even when ``results`` is empty.
    results_df = pd.DataFrame(results, columns=columns)
    results_df = results_df.sort_values('Spearman ρ', ascending=False)

    os.makedirs(outdir, exist_ok=True)

    # Save to CSV
    results_path = os.path.join(outdir, 'mmd_vs_w1_detailed_analysis.csv')
    results_df.to_csv(results_path, index=False)

    # Create a summary table for the paper
    summary_results = results_df[['Predictor', 'Spearman ρ', 'Kendall τ']].copy()
    summary_results = summary_results.round(3)
    summary_path = os.path.join(outdir, 'mmd_vs_w1_summary.csv')
    summary_results.to_csv(summary_path, index=False)

    return results_df

def create_performance_comparison_plot(df, outdir):
    """Create a bar chart of each predictor's Spearman correlation with |ΔR|.

    Every known predictor column present in ``df`` is correlated with
    ``|delta_R_true|`` using only rows where both values are present. Bars are
    coloured by predictor family (W1-based, MMD-based, Other) and the figure
    is saved as ``performance_comparison_mmd_vs_w1.png`` in ``outdir``
    (created if needed).

    If no predictor column is usable, a notice is printed and no figure is
    produced.

    Args:
        df: DataFrame with ``delta_R_true`` and any of the predictor columns.
        outdir: Directory receiving the PNG file.
    """
    # Display name -> column name
    predictors = {
        'trace (TRACE W1)': 'trace_score',
        'W1 Distance': 'trace_w1',
        'Output Discrepancy': 'trace_output_dist',
        'MMD Distance': 'mmd_score',
        'MSP Score': 'msp_score',
        'Energy Score': 'energy_score'
    }

    target = df['delta_R_true'].abs().values
    target_ok = ~pd.isna(target)
    results = []

    for name, col in predictors.items():
        if col not in df.columns:
            continue

        scores = df[col].values
        # Drop rows missing either the score or the target
        mask = ~pd.isna(scores) & target_ok
        if mask.sum() < 2:
            continue

        try:
            rho = spearmanr(scores[mask], target[mask]).correlation
        except (TypeError, ValueError):
            # Column cannot be ranked as numeric scores; leave it out.
            continue

        # Family is derived from the display name
        if 'W1' in name or 'trace' in name:
            family = 'W1-based'
        elif 'MMD' in name:
            family = 'MMD-based'
        else:
            family = 'Other'

        results.append({'Predictor': name, 'Spearman ρ': rho, 'Type': family})

    if not results:
        # Without rows the frame would have no 'Type' column to group by.
        print("  No usable predictor columns found; skipping performance comparison plot.")
        return

    results_df = pd.DataFrame(results)
    os.makedirs(outdir, exist_ok=True)

    # Create the plot
    plt.figure(figsize=(12, 8))
    colors = {'W1-based': 'blue', 'MMD-based': 'red', 'Other': 'green'}

    for pred_type in results_df['Type'].unique():
        data = results_df[results_df['Type'] == pred_type]
        plt.bar(data['Predictor'], data['Spearman ρ'],
                color=colors[pred_type], alpha=0.7, label=pred_type)

    plt.xlabel('Predictor')
    plt.ylabel('Spearman ρ (Correlation with |ΔR|)')
    plt.title('Performance Comparison: MMD vs W1-based TRACE Diagnostics')
    plt.xticks(rotation=45, ha='right')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, 'performance_comparison_mmd_vs_w1.png'), dpi=300, bbox_inches='tight')
    plt.close()

def create_trace_components_analysis(df, outdir, mmd_scale=1000):
    """Analyze the TRACE components separately and print their correlations.

    Builds a "TRACE MMD" score as ``trace_output_dist + mmd_score * mmd_scale``
    and stores it in ``df`` as the new column ``trace_mmd_score`` (``df`` is
    modified in place). It then prints the Spearman correlation with
    ``|delta_R_true|`` for that score and for each individual component that
    is available. Nothing happens unless ``df`` has both ``mmd_score`` and
    ``trace_output_dist``.

    Args:
        df: DataFrame with ``delta_R_true`` and the component columns.
        outdir: Unused; kept so the signature matches the other analysis steps.
        mmd_scale: Multiplier applied to the MMD distance before it is added
            to the output discrepancy. The default of 1000 is the ad-hoc value
            chosen to bring MMD up to a comparable magnitude.
    """
    if 'mmd_score' not in df.columns or 'trace_output_dist' not in df.columns:
        return

    target = df['delta_R_true'].abs().values
    target_ok = ~pd.isna(target)

    def rank_correlation(values):
        # Spearman rho over the rows where both the values and target exist.
        values = np.asarray(values)
        mask = ~pd.isna(values) & target_ok
        if mask.sum() < 2:
            return np.nan
        try:
            return spearmanr(values[mask], target[mask]).correlation
        except (TypeError, ValueError):
            return np.nan

    mmd_component = df['mmd_score'].values
    output_component = df['trace_output_dist'].values

    # Create a TRACE MMD score (similar to trace but with MMD instead of W1),
    # using a simple linear combination.
    trace_mmd_score = output_component + mmd_component * mmd_scale

    # Add to dataframe
    df['trace_mmd_score'] = trace_mmd_score

    # (label, values, source column); values is None when the optional column
    # is absent so it can be reported instead of raising a KeyError.
    report = [
        ('trace Score (TRACE W1)',
         df['trace_score'].values if 'trace_score' in df.columns else None,
         'trace_score'),
        ('TRACE MMD Score', trace_mmd_score, 'trace_mmd_score'),
        ('Output Discrepancy', output_component, 'trace_output_dist'),
        ('W1 Distance',
         df['trace_w1'].values if 'trace_w1' in df.columns else None,
         'trace_w1'),
        ('MMD Distance', mmd_component, 'mmd_score'),
    ]

    print("\nTRACE Component Analysis:")
    for label, values, column in report:
        if values is None:
            print(f"  {label}: column '{column}' not available")
            continue
        print(f"  {label}: ρ = {rank_correlation(values):.3f}")

def main(argv=None):
    """Run the full MMD vs W1 analysis from the command line.

    Loads the results CSV, writes the comparison plots and correlation tables
    into the output directory, prints the TRACE component analysis and a
    textual summary of the best predictors.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, ``argparse`` reads them from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Analyze MMD vs W1 results in TRACE framework using full results")
    parser.add_argument("--csv", type=str, required=True, help="Path to deployment_gate_results.csv")
    parser.add_argument("--outdir", type=str, required=True, help="Output directory for analysis results")
    args = parser.parse_args(argv)

    # Load data
    print("Loading data...")
    df = load_and_prepare_data(args.csv)

    # Create output directory
    os.makedirs(args.outdir, exist_ok=True)

    # Create comparison plots
    print("\nCreating comparison plots...")
    create_comparison_plots(df, args.outdir)

    # Create detailed analysis table
    print("Creating detailed analysis table...")
    results_df = create_detailed_analysis_table(df, args.outdir)

    # Create performance comparison plot
    print("Creating performance comparison plot...")
    create_performance_comparison_plot(df, args.outdir)

    # Create TRACE components analysis
    print("Creating TRACE components analysis...")
    create_trace_components_analysis(df, args.outdir)

    # Print summary
    print("\n" + "="*60)
    print("SUMMARY OF MMD vs W1 COMPARISON (FULL RESULTS)")
    print("="*60)

    if results_df.empty:
        print("\nNo usable predictor columns were found in the input data.")
    else:
        print("\nTop performing predictors (by Spearman ρ):")
        top_predictors = results_df.head(5)
        for _, row in top_predictors.iterrows():
            print(f"  {row['Predictor']}: ρ = {row['Spearman ρ']:.3f}, τ = {row['Kendall τ']:.3f}")

        # Compare W1 vs MMD specifically (the pattern is a regex alternation)
        w1_results = results_df[results_df['Predictor'].str.contains('W1|trace')]
        mmd_results = results_df[results_df['Predictor'].str.contains('MMD')]

        if len(w1_results) > 0:
            print("\nW1-based predictors:")
            for _, row in w1_results.iterrows():
                print(f"  {row['Predictor']}: ρ = {row['Spearman ρ']:.3f}")

        if len(mmd_results) > 0:
            print("\nMMD-based predictors:")
            for _, row in mmd_results.iterrows():
                print(f"  {row['Predictor']}: ρ = {row['Spearman ρ']:.3f}")

    print(f"\nAnalysis results saved to: {args.outdir}")
    print("Files created:")
    expected_files = (
        'mmd_vs_w1_detailed_analysis.csv',
        'mmd_vs_w1_summary.csv',
        'trace_vs_mmd_comparison.png',
        'w1_vs_mmd_distance_comparison.png',
        'performance_comparison_mmd_vs_w1.png',
    )
    # Some plots are only produced when their input columns exist, so list
    # only the files that are actually present in the output directory.
    for filename in expected_files:
        if os.path.exists(os.path.join(args.outdir, filename)):
            print(f"  - {filename}")

if __name__ == "__main__":
    main()


