#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Create comprehensive comparison plots for AUC and Accuracy metrics
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot appearance: preferred font families (in fallback order), the
# 'axes.unicode_minus' switch turned off, then seaborn's "whitegrid" theme.
plt.rcParams.update({
    'font.family': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
    'axes.unicode_minus': False,
})
sns.set_style("whitegrid")

def load_metric_data(benchmark_type, metric_name, base_dir="yourpath"):
    """
    Load the summary CSV of one metric for one benchmark type.

    The file read is
    ``{base_dir}/{result folder}/04_metrics/{metric_name}_summary.csv``.

    Args:
        benchmark_type: ``"IMPROVED_MIXED"`` selects the
            ``result_improved_mixed_benchmark`` folder; any other value
            selects ``result_single_benchmark``.
        metric_name: Prefix of the summary file name (e.g. ``"auc"``).
        base_dir: Root directory that contains the result folders. Defaults
            to the ``"yourpath"`` placeholder used throughout this script.

    Returns:
        The CSV content as a pandas DataFrame, or ``None`` when the file does
        not exist (a "File not found" message is printed in that case).
    """
    if benchmark_type == "IMPROVED_MIXED":
        result_dir = "result_improved_mixed_benchmark"
    else:
        result_dir = "result_single_benchmark"

    # Forward slashes are kept so the path (and the message below) reads the
    # same on every platform.
    file_path = f"{base_dir}/{result_dir}/04_metrics/{metric_name}_summary.csv"

    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None
    return pd.read_csv(file_path)

def create_auc_comparison_plot(improved_mixed_df, single_df, output_dir):
    """
    Create the AUC comparison plot and save it as a PNG.

    One line is drawn per approach whose column is present:
    ``multibench_irt`` from ``improved_mixed_df`` and ``irt_1pl`` /
    ``irt_2pl`` / ``model_mean`` from ``single_df``. Cells are expected to
    look like ``"0.85 ± 0.01"``; only the part before ``±`` is plotted. Cells
    that cannot be parsed (for example missing values) become NaN and leave a
    gap in the line instead of aborting the plot.

    Args:
        improved_mixed_df: Summary DataFrame with a ``Train_Ratio`` column,
            or ``None`` to skip the improved mixed benchmark.
        single_df: Summary DataFrame with a ``Train_Ratio`` column, or
            ``None`` to skip the single benchmark lines.
        output_dir: Directory that receives
            ``comprehensive_auc_comparison.png``.
    """
    # (source frame, column, legend label, color, linestyle, marker, linewidth)
    series_specs = [
        (improved_mixed_df, 'multibench_irt',
         "Improved Mixed Benchmark (Multi-IRT)", "red", "solid", "o", 3),
        (single_df, 'irt_1pl',
         "Single Benchmark (IRT-1PL)", "blue", "dashed", "s", 3),
        (single_df, 'irt_2pl',
         "Single Benchmark (IRT-2PL)", "purple", "dashdot", "^", 3),
        (single_df, 'model_mean',
         "Single Benchmark (Model Mean)", "orange", "dotted", "d", 3),
    ]

    # Union of the train ratios of every supplied frame; drives both the
    # x-axis ticks and the ratio count shown in the title.
    all_ratios = set()
    for frame in (improved_mixed_df, single_df):
        if frame is not None:
            all_ratios.update(frame['Train_Ratio'].values)
    sorted_ratios = sorted(all_ratios)

    fig = plt.figure(figsize=(16, 10))
    try:
        for frame, column, label, color, linestyle, marker, linewidth in series_specs:
            if frame is None or column not in frame.columns:
                continue
            # Keep only the mean (text before '±'); astype(str) makes NaN or
            # already-numeric cells safe to split.
            mean_text = frame[column].astype(str).str.split('±').str[0].str.strip()
            auc_values = pd.to_numeric(mean_text, errors='coerce')
            plt.plot(frame['Train_Ratio'].values, auc_values.values,
                     label=label,
                     color=color,
                     linestyle=linestyle,
                     linewidth=linewidth,
                     marker=marker,
                     markersize=12)

        # Formatting
        plt.xlabel("Training Data Ratio", fontsize=18)
        plt.ylabel("AUC", fontsize=18)
        plt.title(
            "AUC Comparison: Improved Mixed Benchmark vs Single Benchmark "
            f"({len(sorted_ratios)} Train Ratios)",
            fontsize=22, pad=25)
        plt.legend(loc="lower right", fontsize=16)
        plt.grid(True, alpha=0.3)
        plt.ylim(0.5, 1.0)
        plt.xticks(sorted_ratios, [f"{r:.1f}" for r in sorted_ratios], fontsize=16)
        plt.yticks(fontsize=16)

        # Save plot
        plt.tight_layout()
        output_path = os.path.join(output_dir, "comprehensive_auc_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved AUC comparison plot to: {output_path}")

def create_accuracy_comparison_plot(improved_mixed_df, single_df, output_dir):
    """
    Create the Accuracy comparison plot and save it as a PNG.

    One line is drawn per approach whose column is present:
    ``multibench_irt`` from ``improved_mixed_df`` and ``irt_1pl`` /
    ``irt_2pl`` / ``model_mean`` from ``single_df``. Cells are expected to
    look like ``"0.85 ± 0.01"``; only the part before ``±`` is plotted. Cells
    that cannot be parsed (for example missing values) become NaN and leave a
    gap in the line instead of aborting the plot.

    Args:
        improved_mixed_df: Accuracy summary DataFrame with a ``Train_Ratio``
            column, or ``None`` to skip the improved mixed benchmark.
        single_df: Accuracy summary DataFrame with a ``Train_Ratio`` column,
            or ``None`` to skip the single benchmark lines.
        output_dir: Directory that receives
            ``comprehensive_accuracy_comparison.png``.
    """
    # (source frame, column, legend label, color, linestyle, marker, linewidth)
    line_specs = [
        (improved_mixed_df, 'multibench_irt',
         "Improved Mixed Benchmark (Multi-IRT)", "red", "solid", "o", 3),
        (single_df, 'irt_1pl',
         "Single Benchmark (IRT-1PL)", "blue", "dashed", "s", 3),
        (single_df, 'irt_2pl',
         "Single Benchmark (IRT-2PL)", "purple", "dashdot", "^", 3),
        (single_df, 'model_mean',
         "Single Benchmark (Model Mean)", "orange", "dotted", "d", 3),
    ]

    # Union of the train ratios of every supplied frame; drives both the
    # x-axis ticks and the ratio count shown in the title.
    all_ratios = set()
    for frame in (improved_mixed_df, single_df):
        if frame is not None:
            all_ratios.update(frame['Train_Ratio'].values)
    sorted_ratios = sorted(all_ratios)

    fig = plt.figure(figsize=(16, 10))
    try:
        for frame, column, label, color, linestyle, marker, linewidth in line_specs:
            if frame is None or column not in frame.columns:
                continue
            # Keep only the mean (text before '±'); astype(str) makes NaN or
            # already-numeric cells safe to split.
            mean_text = frame[column].astype(str).str.split('±').str[0].str.strip()
            acc_values = pd.to_numeric(mean_text, errors='coerce')
            plt.plot(frame['Train_Ratio'].values, acc_values.values,
                     label=label,
                     color=color,
                     linestyle=linestyle,
                     linewidth=linewidth,
                     marker=marker,
                     markersize=12)

        # Formatting
        plt.xlabel("Training Data Ratio", fontsize=18)
        plt.ylabel("Accuracy", fontsize=18)
        plt.title(
            "Accuracy Comparison: Improved Mixed Benchmark vs Single Benchmark "
            f"({len(sorted_ratios)} Train Ratios)",
            fontsize=22, pad=25)
        plt.legend(loc="lower right", fontsize=16)
        plt.grid(True, alpha=0.3)
        plt.ylim(0.5, 1.0)
        plt.xticks(sorted_ratios, [f"{r:.1f}" for r in sorted_ratios], fontsize=16)
        plt.yticks(fontsize=16)

        # Save plot
        plt.tight_layout()
        output_path = os.path.join(output_dir, "comprehensive_accuracy_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved Accuracy comparison plot to: {output_path}")

# Sentinel that distinguishes "accuracy frame not supplied" from an explicit
# None (e.g. a summary file that could not be loaded).
_UNSET = object()

# Lines drawn in each panel of the combined figure:
# (frame key, column, legend label, color, linestyle, marker, linewidth).
# NOTE: the Model Mean baseline is not drawn in the combined figure.
_COMBINED_SERIES = (
    ("improved", 'multibench_irt',
     "Improved Mixed Benchmark (Multi-IRT)", "red", "solid", "o", 3),
    ("single", 'irt_1pl',
     "Single Benchmark (IRT-1PL)", "blue", "dashed", "s", 3),
    ("single", 'irt_2pl',
     "Single Benchmark (IRT-2PL)", "purple", "dashdot", "^", 3),
)


def _resolve_accuracy_frames(auc_improved, auc_single, acc_improved, acc_single):
    """Return the (improved, single) frames to use for the Accuracy outputs.

    When neither accuracy frame was supplied, the AUC frames are reused (the
    legacy behavior). Otherwise an unsupplied frame is treated as missing.
    """
    if acc_improved is _UNSET and acc_single is _UNSET:
        return auc_improved, auc_single
    if acc_improved is _UNSET:
        acc_improved = None
    if acc_single is _UNSET:
        acc_single = None
    return acc_improved, acc_single


def _collect_train_ratios(*frames):
    """Return the sorted union of ``Train_Ratio`` values of the non-None frames."""
    ratios = set()
    for frame in frames:
        if frame is not None:
            ratios.update(frame['Train_Ratio'].values)
    return sorted(ratios)


def _mean_values(column):
    """Return the number before '±' in each cell as floats (NaN if unparsable)."""
    # astype(str) makes NaN or already-numeric cells safe to split.
    mean_text = column.astype(str).str.split('±').str[0].str.strip()
    return pd.to_numeric(mean_text, errors='coerce')


def _draw_metric_panel(ax, improved_df, single_df, metric_label):
    """Draw one metric's comparison lines on ``ax`` and format the axes."""
    frames = {"improved": improved_df, "single": single_df}
    for key, column, label, color, linestyle, marker, linewidth in _COMBINED_SERIES:
        frame = frames[key]
        if frame is None or column not in frame.columns:
            continue
        ax.plot(frame['Train_Ratio'].values, _mean_values(frame[column]).values,
                label=label,
                color=color,
                linestyle=linestyle,
                linewidth=linewidth,
                marker=marker,
                markersize=12)

    ax.set_xlabel("Training Data Ratio", fontsize=18)
    ax.set_ylabel(metric_label, fontsize=18)
    ax.set_title(f"{metric_label} Comparison", fontsize=20)
    ax.legend(loc="lower right", fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0.5, 1.0)

    # One tick per train ratio present in this panel's data.
    tick_ratios = _collect_train_ratios(improved_df, single_df)
    ax.set_xticks(tick_ratios)
    ax.set_xticklabels([f"{r:.1f}" for r in tick_ratios], fontsize=14)
    ax.tick_params(axis='y', labelsize=14)


def _lookup_cell(frame, ratio, column):
    """Return ``frame[column]`` at the first row matching ``ratio``, else "N/A"."""
    if frame is None or column not in frame.columns:
        return "N/A"
    matches = frame.loc[frame['Train_Ratio'] == ratio, column]
    if matches.empty:
        return "N/A"
    return matches.iloc[0]


def create_combined_metrics_plot(improved_mixed_df, single_df, output_dir, *,
                                 improved_mixed_acc_df=_UNSET,
                                 single_acc_df=_UNSET):
    """
    Create a side-by-side figure with an AUC panel and an Accuracy panel.

    Args:
        improved_mixed_df: AUC summary of the improved mixed benchmark
            (``Train_Ratio`` plus optionally ``multibench_irt``), or ``None``.
        single_df: AUC summary of the single benchmark (``Train_Ratio`` plus
            optionally ``irt_1pl`` / ``irt_2pl``), or ``None``.
        output_dir: Directory that receives
            ``comprehensive_combined_metrics_comparison.png``.
        improved_mixed_acc_df: Accuracy summary of the improved mixed
            benchmark for the right-hand panel; ``None`` means no data.
        single_acc_df: Accuracy summary of the single benchmark for the
            right-hand panel; ``None`` means no data.

    If neither accuracy frame is passed, the Accuracy panel is drawn from the
    AUC frames, which keeps the behavior of three-argument callers unchanged.
    """
    improved_mixed_acc_df, single_acc_df = _resolve_accuracy_frames(
        improved_mixed_df, single_df, improved_mixed_acc_df, single_acc_df)

    fig, axes = plt.subplots(1, 2, figsize=(24, 10))
    try:
        _draw_metric_panel(axes[0], improved_mixed_df, single_df, "AUC")
        _draw_metric_panel(axes[1], improved_mixed_acc_df, single_acc_df, "Accuracy")

        # Save plot
        fig.tight_layout()
        output_path = os.path.join(output_dir, "comprehensive_combined_metrics_comparison.png")
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved combined metrics comparison plot to: {output_path}")


def create_detailed_metric_table(improved_mixed_df, single_df, output_dir, *,
                                 improved_mixed_acc_df=_UNSET,
                                 single_acc_df=_UNSET):
    """
    Create a detailed comparison table with AUC and Accuracy values.

    The table has one row per train ratio found in any supplied frame and is
    written to ``comprehensive_metric_comparison_detailed.csv`` in
    ``output_dir``. Cells are copied verbatim from the summary frames; a cell
    is ``"N/A"`` when the frame, the column, or the ratio is missing.

    Args:
        improved_mixed_df: AUC summary of the improved mixed benchmark, or
            ``None``.
        single_df: AUC summary of the single benchmark, or ``None``.
        output_dir: Directory that receives the CSV file.
        improved_mixed_acc_df: Accuracy summary of the improved mixed
            benchmark; ``None`` means no data.
        single_acc_df: Accuracy summary of the single benchmark; ``None``
            means no data.

    If neither accuracy frame is passed, the Accuracy columns are filled from
    the AUC frames, which keeps the behavior of three-argument callers
    unchanged.

    Returns:
        The comparison table as a pandas DataFrame.
    """
    improved_mixed_acc_df, single_acc_df = _resolve_accuracy_frames(
        improved_mixed_df, single_df, improved_mixed_acc_df, single_acc_df)

    sorted_ratios = _collect_train_ratios(
        improved_mixed_df, single_df, improved_mixed_acc_df, single_acc_df)

    # (output column, source frame, source column), in CSV column order.
    cell_specs = [
        ('Improved Mixed Benchmark (AUC)', improved_mixed_df, 'multibench_irt'),
        ('Improved Mixed Benchmark (Accuracy)', improved_mixed_acc_df, 'multibench_irt'),
        ('Single Benchmark IRT-1PL (AUC)', single_df, 'irt_1pl'),
        ('Single Benchmark IRT-2PL (AUC)', single_df, 'irt_2pl'),
        ('Single Benchmark Model Mean (AUC)', single_df, 'model_mean'),
        ('Single Benchmark IRT-1PL (Accuracy)', single_acc_df, 'irt_1pl'),
        ('Single Benchmark IRT-2PL (Accuracy)', single_acc_df, 'irt_2pl'),
        ('Single Benchmark Model Mean (Accuracy)', single_acc_df, 'model_mean'),
    ]

    comparison_data = []
    for ratio in sorted_ratios:
        row = {'Train_Ratio': ratio}
        for out_column, frame, source_column in cell_specs:
            row[out_column] = _lookup_cell(frame, ratio, source_column)
        comparison_data.append(row)

    # Passing the column order keeps the header even when there are no rows.
    column_order = ['Train_Ratio'] + [spec[0] for spec in cell_specs]
    df_comparison = pd.DataFrame(comparison_data, columns=column_order)
    output_file = os.path.join(output_dir, "comprehensive_metric_comparison_detailed.csv")
    df_comparison.to_csv(output_file, index=False)
    print(f"Saved detailed metric comparison table to: {output_file}")

    return df_comparison


def main():
    """Load the AUC and Accuracy summaries and write all comparison outputs."""
    # Define output directory
    output_dir = "yourpath/comparison_results"
    os.makedirs(output_dir, exist_ok=True)

    print("Loading metric data for comprehensive comparison...")

    # Load metric data (each is None when its summary file is missing)
    improved_mixed_auc_df = load_metric_data("IMPROVED_MIXED", "auc")
    improved_mixed_acc_df = load_metric_data("IMPROVED_MIXED", "accuracy")
    single_auc_df = load_metric_data("SINGLE", "auc")
    single_acc_df = load_metric_data("SINGLE", "accuracy")

    # Create AUC comparison plot
    print("Creating AUC comparison plot...")
    create_auc_comparison_plot(improved_mixed_auc_df, single_auc_df, output_dir)

    # Create Accuracy comparison plot
    print("Creating Accuracy comparison plot...")
    create_accuracy_comparison_plot(improved_mixed_acc_df, single_acc_df, output_dir)

    # Create combined metrics plot; the accuracy frames feed the Accuracy panel
    print("Creating combined metrics comparison plot...")
    create_combined_metrics_plot(
        improved_mixed_auc_df, single_auc_df, output_dir,
        improved_mixed_acc_df=improved_mixed_acc_df,
        single_acc_df=single_acc_df)

    # Create detailed metric table; the accuracy frames feed the Accuracy columns
    print("Creating detailed metric comparison table...")
    create_detailed_metric_table(
        improved_mixed_auc_df, single_auc_df, output_dir,
        improved_mixed_acc_df=improved_mixed_acc_df,
        single_acc_df=single_acc_df)

    print(f"\nAll comprehensive metric comparison results saved to: {output_dir}")


if __name__ == "__main__":
    main()