#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Compare results from all three approaches:
1. Mixed benchmark (original)
2. Single benchmark 
3. Improved mixed benchmark

Generate comprehensive comparison visualizations
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for plots.
# Apply seaborn's whitegrid theme FIRST: seaborn's style dictionary includes
# "font.family" (reset to ["sans-serif"]), so calling set_style after the
# rcParams assignments below would silently discard the font preference.
sns.set_style("whitegrid")

# Preferred font families, tried in order by matplotlib.
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
# Draw negative tick labels with an ASCII hyphen-minus instead of U+2212.
plt.rcParams['axes.unicode_minus'] = False

def load_metrics_data(base_dir, metric_name):
    """
    Read the summary CSV for one metric of one result directory.

    The file is expected at ``<base_dir>/04_metrics/<metric_name>_summary.csv``.
    Returns the parsed DataFrame, or None (after printing a warning) when the
    file does not exist.
    """
    csv_path = os.path.join(base_dir, "04_metrics", f"{metric_name}_summary.csv")
    if not os.path.exists(csv_path):
        print(f"Warning: {csv_path} not found")
        return None
    return pd.read_csv(csv_path)

def create_comparison_plot(mixed_df, single_df, improved_df, metric_name, output_dir):
    """
    Plot one metric against the training-data ratio for every available approach.

    Each summary frame needs a ``Train_Ratio`` column plus score columns whose
    cells are either plain numbers or ``"mean ± std"`` strings; only the mean
    is plotted.

    Args:
        mixed_df: Mixed-benchmark summary (``multibench_irt`` column) or None.
        single_df: Single-benchmark summary or None. Only drawn when
            ``metric_name == 'auc'``: ``irt_1pl`` is plotted, falling back to
            ``irt_2pl`` when ``irt_1pl`` is absent.
        improved_df: Improved mixed-benchmark summary (``multibench_irt``) or None.
        metric_name: Metric key; used for axis labels and the output filename.
        output_dir: Directory that receives ``<metric_name>_comparison.png``.
    """

    def _means(column):
        # pandas parses purely numeric cells as floats, which have no .split();
        # only "mean ± std" strings need to be split.
        return [float(cell.split(' ± ')[0]) if isinstance(cell, str) else float(cell)
                for cell in column]

    # Define styles for each approach: (color, linestyle)
    styles = {
        "Mixed Benchmark": ("red", "solid"),
        "Single Benchmark": ("blue", "dashed"),
        "Improved Mixed Benchmark": ("green", "dashdot")
    }

    # Curves to draw: (frame, column, label, color, linestyle, marker)
    series = []
    if mixed_df is not None and 'multibench_irt' in mixed_df.columns:
        series.append((mixed_df, 'multibench_irt', "Mixed Benchmark (Multi-IRT)",
                       *styles["Mixed Benchmark"], "o"))
    if single_df is not None and metric_name == 'auc':
        if 'irt_1pl' in single_df.columns:
            series.append((single_df, 'irt_1pl', "Single Benchmark (IRT-1PL)",
                           *styles["Single Benchmark"], "s"))
        elif 'irt_2pl' in single_df.columns:
            series.append((single_df, 'irt_2pl', "Single Benchmark (IRT-2PL)",
                           "purple", "--", "^"))
    if improved_df is not None and 'multibench_irt' in improved_df.columns:
        series.append((improved_df, 'multibench_irt', "Improved Mixed Benchmark (Multi-IRT)",
                       *styles["Improved Mixed Benchmark"], "d"))

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        plotted_ratios = set()
        for frame, column, label, color, linestyle, marker in series:
            ax.plot(frame['Train_Ratio'], _means(frame[column]),
                    label=label, color=color, linestyle=linestyle,
                    linewidth=2, marker=marker, markersize=8)
            # Only ratios that actually appear on the plot become x-ticks.
            plotted_ratios.update(frame['Train_Ratio'])

        # Formatting
        ax.set_xlabel("Training Data Ratio", fontsize=14)
        ax.set_ylabel(metric_name.upper(), fontsize=14)
        ax.set_title(f"{metric_name.upper()} Comparison Across Different Approaches", fontsize=16)
        if series:
            # Avoid matplotlib's "No artists with labels" warning on an empty plot.
            ax.legend(loc="best", fontsize=12)
        ax.grid(True, alpha=0.3)

        sorted_ratios = sorted(plotted_ratios)
        ax.set_xticks(sorted_ratios)
        ax.set_xticklabels([f"{r:.1f}" for r in sorted_ratios], fontsize=12)
        ax.tick_params(axis='y', labelsize=12)

        # Save plot
        fig.tight_layout()
        output_path = os.path.join(output_dir, f"{metric_name}_comparison.png")
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails, so repeated calls don't leak.
        plt.close(fig)
    print(f"Saved {metric_name} comparison plot to: {output_path}")

def create_combined_comparison_plot(mixed_df, single_df, improved_df, output_dir,
                                    single_dir="yourpath/result_single_benchmark",
                                    improved_dir="yourpath/result_improved_mixed_benchmark"):
    """
    Create a 2x2 grid comparing AUC, Accuracy, MSE and F1 across approaches.

    Args:
        mixed_df: Mixed-benchmark summary (``multibench_irt`` column) or None.
            It is drawn on the AUC panel only.
        single_df: Unused; kept for signature compatibility. Single-benchmark
            curves are read per metric from ``<single_dir>/04_metrics``.
        improved_df: Improved mixed-benchmark AUC summary (``multibench_irt``)
            or None, used for the AUC panel. The other panels read
            ``<improved_dir>/04_metrics/<metric>_summary.csv`` instead.
        output_dir: Directory that receives ``combined_comparison.png``.
        single_dir: Result directory of the single-benchmark run.
        improved_dir: Result directory of the improved mixed-benchmark run.
    """

    def _means(column):
        # Cells are "mean ± std" strings or plain numbers; keep only the mean.
        return [float(cell.split(' ± ')[0]) if isinstance(cell, str) else float(cell)
                for cell in column]

    def _load_summary(result_dir, metric):
        # Per-metric summary CSV, or None when the run did not produce one.
        path = os.path.join(result_dir, "04_metrics", f"{metric}_summary.csv")
        return pd.read_csv(path) if os.path.exists(path) else None

    def _first_present(df, candidates):
        # First candidate column that exists in df, else None.
        return next((c for c in candidates if c in df.columns), None)

    metrics = ['auc', 'accuracy', 'mse', 'f1']
    metric_names = ['AUC', 'Accuracy', 'MSE', 'F1 Score']

    # Define styles for each approach: (color, linestyle)
    styles = {
        "Mixed Benchmark": ("red", "solid"),
        "Single Benchmark": ("blue", "dashed"),
        "Improved Mixed Benchmark": ("green", "dashdot")
    }

    # Single-benchmark score column per metric (candidates in priority order)
    # and its legend label. Both MSE header spellings used in this script are accepted.
    single_columns = {
        'auc': (('irt_1pl',), "Single Benchmark (IRT-1PL)"),
        'f1': (('irt_1pl',), "Single Benchmark (IRT-1PL)"),
        'accuracy': (('model_mean',), "Single Benchmark (Model Mean)"),
        'mse': (('model_mean', 'Model Mean'), "Single Benchmark (Model Mean)"),
    }

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    try:
        fig.suptitle("Performance Comparison Across Different Approaches", fontsize=18)

        # axes.flat is row-major, matching the metric order above.
        for ax, metric, metric_name in zip(axes.flat, metrics, metric_names):
            # Curves for this panel: (frame, column, label, style key, marker)
            series = []
            if metric == 'auc' and mixed_df is not None and 'multibench_irt' in mixed_df.columns:
                series.append((mixed_df, 'multibench_irt', "Mixed Benchmark (Multi-IRT)",
                               "Mixed Benchmark", "o"))

            single_metric_df = _load_summary(single_dir, metric)
            if single_metric_df is not None:
                candidates, label = single_columns[metric]
                column = _first_present(single_metric_df, candidates)
                if column is not None:
                    series.append((single_metric_df, column, label, "Single Benchmark", "s"))

            # improved_df holds AUC scores; plotting it on the other panels would
            # mislabel AUC as Accuracy/MSE/F1, so those panels load their own summary.
            improved_metric_df = improved_df if metric == 'auc' else _load_summary(improved_dir, metric)
            if improved_metric_df is not None and 'multibench_irt' in improved_metric_df.columns:
                series.append((improved_metric_df, 'multibench_irt',
                               "Improved Mixed Benchmark (Multi-IRT)",
                               "Improved Mixed Benchmark", "d"))

            ratios = set()
            for frame, column, label, style_key, marker in series:
                color, linestyle = styles[style_key]
                ax.plot(frame['Train_Ratio'], _means(frame[column]),
                        label=label, color=color, linestyle=linestyle,
                        linewidth=2, marker=marker, markersize=6)
                ratios.update(frame['Train_Ratio'])

            # Formatting
            ax.set_xlabel("Training Data Ratio", fontsize=12)
            ax.set_ylabel(metric_name, fontsize=12)
            ax.set_title(metric_name, fontsize=14)
            ax.grid(True, alpha=0.3)

            sorted_ratios = sorted(ratios)
            ax.set_xticks(sorted_ratios)
            ax.set_xticklabels([f"{r:.1f}" for r in sorted_ratios], fontsize=10)
            ax.tick_params(axis='y', labelsize=10)

        # The mixed curve exists only on the AUC panel, so gather legend entries
        # from every panel (deduplicated by label) rather than the last one only.
        legend_entries = {}
        for ax in axes.flat:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                legend_entries.setdefault(label, handle)
        if legend_entries:
            axes[1, 1].legend(list(legend_entries.values()), list(legend_entries.keys()),
                              loc="center left", bbox_to_anchor=(1, 0.5), fontsize=10)

        # Save plot
        fig.tight_layout()
        output_path = os.path.join(output_dir, "combined_comparison.png")
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved combined comparison plot to: {output_path}")

def create_performance_table(mixed_df, single_df, improved_df, output_dir,
                             mixed_dir="yourpath/result_mixed_benchmark",
                             single_dir="yourpath/result_single_benchmark",
                             improved_dir="yourpath/result_improved_mixed_benchmark"):
    """
    Write a CSV comparing AUC, Accuracy and MSE at each approach's highest train ratio.

    Args:
        mixed_df: Mixed-benchmark summary with a ``multibench_irt`` column
            (reported as AUC), or None to skip that approach.
        single_df: Single-benchmark AUC summary (``irt_1pl`` column), or None.
        improved_df: Improved mixed-benchmark AUC summary (``multibench_irt``),
            or None.
        output_dir: Directory that receives ``detailed_performance_comparison.csv``.
        mixed_dir, single_dir, improved_dir: Result directories whose
            ``04_metrics/<metric>_summary.csv`` files supply Accuracy and MSE.

    Any value that cannot be located (missing file, column or ratio row) is
    written as ``'N/A'`` rather than borrowed from another metric.
    """

    def _load_summary(result_dir, metric):
        # Per-metric summary CSV, or None when the run did not produce one.
        path = os.path.join(result_dir, "04_metrics", f"{metric}_summary.csv")
        return pd.read_csv(path) if os.path.exists(path) else None

    def _max_ratio(df):
        # Highest training ratio in df; 1.0 when the column is missing.
        return df['Train_Ratio'].max() if 'Train_Ratio' in df.columns else 1.0

    def _value_at(df, ratio, candidates):
        # Value of the first present candidate column in the row for `ratio`.
        if df is None or 'Train_Ratio' not in df.columns:
            return 'N/A'
        column = next((c for c in candidates if c in df.columns), None)
        if column is None:
            return 'N/A'
        # isclose tolerates float round-off between separately written CSVs.
        rows = df[np.isclose(df['Train_Ratio'].astype(float), ratio)]
        return rows.iloc[0][column] if not rows.empty else 'N/A'

    comparison_data = []

    # Add mixed benchmark data
    if mixed_df is not None and 'multibench_irt' in mixed_df.columns:
        ratio = _max_ratio(mixed_df)
        comparison_data.append({
            'Approach': 'Mixed Benchmark (Multi-IRT)',
            'Train_Ratio': ratio,
            'AUC': _value_at(mixed_df, ratio, ('multibench_irt',)),
            # No accuracy source is available for the mixed run (it would
            # need to be derived from MSE or obtained another way).
            'Accuracy': 'N/A',
            'MSE': _value_at(_load_summary(mixed_dir, 'mse'), ratio,
                             ('Multi-benchmark IRT', 'multibench_irt')),
        })

    # Add single benchmark data
    if single_df is not None:
        ratio = _max_ratio(single_df)
        comparison_data.append({
            'Approach': 'Single Benchmark (IRT-1PL)',
            'Train_Ratio': ratio,
            'AUC': _value_at(single_df, ratio, ('irt_1pl',)),
            'Accuracy': _value_at(_load_summary(single_dir, 'accuracy'), ratio, ('model_mean',)),
            'MSE': _value_at(_load_summary(single_dir, 'mse'), ratio, ('Model Mean', 'model_mean')),
        })

    # Add improved mixed benchmark data
    if improved_df is not None and 'multibench_irt' in improved_df.columns:
        ratio = _max_ratio(improved_df)
        comparison_data.append({
            'Approach': 'Improved Mixed Benchmark (Multi-IRT)',
            'Train_Ratio': ratio,
            'AUC': _value_at(improved_df, ratio, ('multibench_irt',)),
            'Accuracy': _value_at(_load_summary(improved_dir, 'accuracy'), ratio, ('multibench_irt',)),
            'MSE': _value_at(_load_summary(improved_dir, 'mse'), ratio, ('multibench_irt',)),
        })

    # Create and save the comparison table
    df_comparison = pd.DataFrame(comparison_data)
    output_file = os.path.join(output_dir, "detailed_performance_comparison.csv")
    df_comparison.to_csv(output_file, index=False)
    print(f"Saved detailed performance comparison to: {output_file}")

def main():
    """
    Script entry point.

    Loads the summary for each of the three runs, then writes the single-metric
    AUC plot, the combined 2x2 plot and the performance table to the output folder.
    """
    output_dir = "yourpath/comparison_results"
    os.makedirs(output_dir, exist_ok=True)

    print("Loading metrics data...")

    # NOTE(review): the mixed run is read from its *MSE* summary (the original
    # comment says it "contains the mixed benchmark data"), yet it is passed on
    # as AUC data below — confirm that file really holds the values wanted here.
    frames = (
        load_metrics_data("yourpath/result_mixed_benchmark", "mse"),
        load_metrics_data("yourpath/result_single_benchmark", "auc"),
        load_metrics_data("yourpath/result_improved_mixed_benchmark", "auc"),
    )

    print("Creating comparison plots...")
    create_comparison_plot(*frames, "auc", output_dir)

    print("Creating combined comparison plot...")
    create_combined_comparison_plot(*frames, output_dir)

    print("Creating performance comparison table...")
    create_performance_table(*frames, output_dir)

    print(f"\nAll comparison results saved to: {output_dir}")

if __name__ == "__main__":
    # Run the full comparison only when executed as a script, not on import.
    main()