#!/usr/bin/env python3

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import argparse
from pathlib import Path
from adjustText import adjust_text


def load_tau_data(dataset_name, models, tau_values, results_dir="experiments/plot_data/tau"):
    """Load per-model metrics and aggregate them for each tau value.

    For every model, reads ``{results_dir}/tau_{dataset_name}_{model}.csv``
    and, for each requested tau, computes the mean and the pandas-default
    (ddof=1) standard deviation of every metric over the matching rows.

    Args:
        dataset_name: Dataset identifier used in the CSV file name.
        models: Iterable of model names. A missing or unreadable file is
            reported on stdout and that model is skipped.
        tau_values: Tau thresholds to aggregate. Rows are matched with
            ``np.isclose`` so taus carrying float rounding noise in the CSV
            (e.g. ``0.015000000000000001``) are not silently dropped.
        results_dir: Directory containing the per-model CSV files.

    Returns:
        DataFrame with one row per (model, tau) that had data, with columns
        ``model``, ``tau``, the metric means, then ``<metric>_std`` columns.
        An empty DataFrame if nothing could be loaded.
    """
    metrics = ('shd', 'sid', 'precision', 'recall', 'F1', 'fdr', 'tpr', 'fpr', 'fnr')
    all_data = []

    for model in models:
        filename = f"{results_dir}/tau_{dataset_name}_{model}.csv"

        if not os.path.exists(filename):
            print(f"Missing file: {filename}")
            continue

        try:
            df = pd.read_csv(filename)
            print(f"Loaded {filename}: {len(df)} rows")

            for tau in tau_values:
                # Tolerant comparison: exact float equality misses taus that
                # were computed or serialised with rounding noise.
                tau_data = df[np.isclose(df['tau'], tau)]
                if tau_data.empty:
                    continue
                row = {'model': model, 'tau': tau}
                row.update({m: tau_data[m].mean() for m in metrics})
                # std is NaN when only one row matches (ddof=1).
                row.update({f'{m}_std': tau_data[m].std() for m in metrics})
                all_data.append(row)

        except Exception as e:
            # Best effort: a malformed file (e.g. a missing metric column)
            # skips only this model instead of aborting the whole dataset.
            print(f"Error loading {filename}: {e}")
            continue

    if not all_data:
        print(f"No data loaded for {dataset_name}")
        return pd.DataFrame()

    return pd.DataFrame(all_data)

def _target_tau_for(dataset_name):
    """Return the tau highlighted for a dataset family (ER checked before SF), or None."""
    name = dataset_name.upper()
    if 'ER' in name:
        return 0.01
    if 'SF' in name:
        return 0.025
    return None


def _format_tau(tau):
    """Format a tau label with two decimals, widening to three when two would mislabel it.

    ``f'{0.015:.2f}'`` yields '0.01', which would put the wrong tau next to a point.
    """
    text = f'{tau:.2f}'
    return text if np.isclose(float(text), tau) else f'{tau:.3f}'


def _rows_at_tau(model_data, tau):
    """Return the rows of *model_data* whose 'tau' matches *tau* within float tolerance."""
    return model_data[np.isclose(model_data['tau'], tau)]


def _highlight_target(ax, model_data, target_tau, x_col, y_col, color):
    """Draw a star at the (x_col, y_col) point of the target tau, if that tau is present."""
    if target_tau is None:
        return
    target = _rows_at_tau(model_data, target_tau)
    if target.empty:
        return
    ax.scatter(target[x_col].iloc[0], target[y_col].iloc[0], s=250, marker='*',
               color=color, edgecolor='black', linewidth=1.0, zorder=10)


def _style_axis(ax, xlabel, ylabel):
    """Apply the label, grid and tick styling shared by every panel."""
    ax.set_xlabel(xlabel, fontsize=16, fontweight='normal')
    ax.set_ylabel(ylabel, fontsize=16, fontweight='normal')
    ax.grid(True, alpha=0.3)
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)


def _plot_metric_trend(ax, df, models, metric, ylabel, model_colors, model_markers, target_tau):
    """Plot the mean of *metric* versus tau per model, shading +/- one standard deviation."""
    for model in models:
        model_data = df[df['model'] == model].sort_values('tau')
        if model_data.empty:
            continue
        color = model_colors[model]
        x = model_data['tau']
        y = model_data[metric]
        # std is NaN for single-run groups; draw those without a band.
        y_std = model_data[f'{metric}_std'].fillna(0)

        ax.plot(x, y, model_markers[model] + '-', color=color,
                linewidth=2, markersize=8, alpha=0.8)
        ax.fill_between(x, y - y_std, y + y_std, color=color, alpha=0.2)
        _highlight_target(ax, model_data, target_tau, 'tau', metric, color)

    _style_axis(ax, 'Threshold (tau)', ylabel)


def _plot_tradeoff_curve(ax, df, models, x_col, y_col, xlabel, ylabel,
                         model_colors, model_markers, target_tau,
                         default_align, offset_align, reference_diagonal=False):
    """Plot a curve traced by tau (ROC or PR), annotate tau values and zoom to the data.

    ``default_align`` / ``offset_align`` are (ha, va) pairs for the tau labels;
    DiffAN and CaPS use ``offset_align`` so their labels sit on the other side.
    """
    all_x, all_y = [], []
    texts = []  # annotations handed to adjustText

    for model in models:
        model_data = df[df['model'] == model].sort_values('tau')
        if model_data.empty:
            continue
        color = model_colors[model]
        x = model_data[x_col]
        y = model_data[y_col]
        tau_vals = model_data['tau']
        all_x.extend(x.tolist())
        all_y.extend(y.tolist())

        ax.plot(x, y, '-', color=color, linewidth=2, alpha=0.8,
                label=rf'{model}-$\mathrm{{P}}^2$')
        ax.scatter(x, y, s=40, marker=model_markers[model], color=color, alpha=0.8, zorder=5)

        ha, va = offset_align if model in ('DiffAN', 'CaPS') else default_align
        # Label every (n // 4)-th point plus the last one to limit clutter.
        step = max(1, len(tau_vals) // 4)
        last = len(tau_vals) - 1
        for i, (px, py, tau) in enumerate(zip(x, y, tau_vals)):
            if i % step == 0 or i == last:
                texts.append(ax.annotate(
                    _format_tau(tau), (px, py),
                    fontsize=9, alpha=0.9, color=color,
                    fontweight='normal', ha=ha, va=va,
                    bbox=dict(boxstyle="round,pad=0.2", facecolor='white',
                              alpha=0.8, edgecolor=color, linewidth=0.5)))

        _highlight_target(ax, model_data, target_tau, x_col, y_col, color)

    if reference_diagonal:
        # Random-classifier reference line.
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)

    _style_axis(ax, xlabel, ylabel)

    if all_x and all_y:
        margin = 0.15  # generous so the tau labels stay inside the axes
        # Rates live in [0, 1]; clamp both ends on both axes.
        ax.set_xlim(max(0, min(all_x) - margin), min(1, max(all_x) + margin))
        ax.set_ylim(max(0, min(all_y) - margin), min(1, max(all_y) + margin))

    if texts:
        adjust_text(texts, ax=ax,
                    arrowprops=dict(arrowstyle='->', color='gray', alpha=0.5, lw=0.5),
                    force_points=0.2, force_text=0.5)


def create_tau_analysis_plot(df, dataset_name, output_dir):
    """Create, show and save a five-panel tau analysis figure for one dataset.

    Panels: SHD, SID and F1 versus tau (mean +/- one std), then the ROC
    (TPR vs FPR) and PR (precision vs recall) curves traced by tau. For
    dataset names containing 'ER' or 'SF' the reference tau is starred on
    every panel.

    Args:
        df: Output of ``load_tau_data`` (model, tau, metric means, ``*_std``).
            Nothing is drawn when it is empty.
        dataset_name: Selects the target tau and names the output file.
        output_dir: Directory for ``tau_analysis_{dataset_name}.png``;
            created if missing.
    """
    if df.empty:
        return

    target_tau = _target_tau_for(dataset_name)

    colors_2 = sns.color_palette("tab20c", 20)
    colors = sns.color_palette("Paired", 12)
    color_map = {
        "CAM": colors_2[17],
        "SCORE": colors[7],
        "DAS": colors[1],
        "NoGAM": colors[3],
        "DiffAN": colors[9],
        "CaPS": colors[5],
    }
    marker_map = {
        "CAM": 'o',
        "SCORE": 'o',
        "DAS": 'o',
        "NoGAM": 'o',
        "DiffAN": 'o',
        "CaPS": 'o',
    }

    models = df['model'].unique()
    # Models without a fixed style get a fallback instead of raising KeyError.
    fallback_colors = sns.color_palette("tab10", 10)
    model_colors = {m: color_map.get(m, fallback_colors[i % len(fallback_colors)])
                    for i, m in enumerate(models)}
    model_markers = {m: marker_map.get(m, 'o') for m in models}

    fig, axes = plt.subplots(1, 5, figsize=(25, 4))

    trend_panels = (
        ('shd', r'SHD $\downarrow$'),
        ('sid', r'SID $\downarrow$'),
        ('F1', r'F1 $\uparrow$'),
    )
    for ax, (metric, ylabel) in zip(axes[:3], trend_panels):
        _plot_metric_trend(ax, df, models, metric, ylabel,
                           model_colors, model_markers, target_tau)
    axes[2].set_ylim(0, 1)  # F1 is bounded in [0, 1]

    _plot_tradeoff_curve(axes[3], df, models, 'fpr', 'tpr',
                         'False Positive Rate (FPR)', 'True Positive Rate (TPR)',
                         model_colors, model_markers, target_tau,
                         default_align=('left', 'top'), offset_align=('right', 'bottom'),
                         reference_diagonal=True)
    _plot_tradeoff_curve(axes[4], df, models, 'recall', 'precision',
                         'Recall', 'Precision',
                         model_colors, model_markers, target_tau,
                         default_align=('left', 'bottom'), offset_align=('right', 'top'))

    # Single shared legend above all panels.
    handles = [plt.Line2D([0], [0], color=model_colors[m], marker=model_markers[m],
                          linewidth=2, markersize=8, alpha=0.8) for m in models]
    labels = [rf'{m} w/ PEP' for m in models]
    fig.legend(handles, labels, loc='upper center', ncol=len(models),
               bbox_to_anchor=(0.5, 1.02), fontsize=16)

    plt.tight_layout(rect=[0, 0, 1, 0.9])  # reserve the top strip for the legend

    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/tau_analysis_{dataset_name}.png"
    fig.savefig(output_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    # Release the figure; otherwise each dataset processed in a loop keeps one alive.
    plt.close(fig)

    print(f"Tau analysis plot saved: {output_path}")

def main():
    """Command-line entry point: build one tau-analysis figure per dataset.

    ``--dataset`` restricts the run to a single dataset; otherwise the default
    SynER4/SynSF4 pair is processed. Passing ``--combined`` while more than one
    dataset is selected prints a notice and returns without plotting.
    """
    cli = argparse.ArgumentParser(description='Generate tau analysis visualizations')
    cli.add_argument('--dataset', type=str, help='Dataset name (optional, for single dataset analysis)')
    cli.add_argument('--results_dir', type=str, default='experiments/plot_data/tau', help='Results directory')
    cli.add_argument('--output_dir', type=str, default='fig/tau', help='Output directory')
    cli.add_argument('--combined', action='store_true', help='Create combined plot for multiple datasets')
    opts = cli.parse_args()

    dataset_names = [opts.dataset] if opts.dataset else ["SynER4", "SynSF4"]
    model_names = ["CAM", "SCORE", "DAS", "NoGAM", "DiffAN", "CaPS"]
    tau_grid = [0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05]

    print("Tau analysis started")
    print(f"Datasets: {dataset_names}")
    print(f"Models: {model_names}")
    print(f"Tau values: {tau_grid}")
    print("=" * 60)

    # Guard clause: the multi-dataset combined figure is not implemented yet.
    if opts.combined and len(dataset_names) > 1:
        print("Combined plot functionality is currently disabled")
        return

    for name in dataset_names:
        print(f"\nProcessing {name}...")
        frame = load_tau_data(name, model_names, tau_grid, opts.results_dir)
        if frame.empty:
            print(f"No data for {name}, skipping...")
            continue
        create_tau_analysis_plot(frame, name, opts.output_dir)

    print("\nTau analysis completed")
    print(f"Results saved in: {opts.output_dir}")


if __name__ == "__main__":
    main()
