#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Visualization script to compare multi-benchmark joint modeling vs single-benchmark modeling results
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for plots.
# The seaborn style is applied first: set_style() also writes font-related
# rcParams (font.family is one of the style keys), so the explicit overrides
# below have to run afterwards or they would be replaced by seaborn's values.
sns.set_style("whitegrid")
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
# Render minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

def load_mse_data(mixed_benchmark_dir, single_benchmark_dir):
    """
    Read the MSE summary tables of both modeling approaches.

    Each result directory must contain "04_metrics/mse_summary.csv"; the
    multi-benchmark file is read first, then the single-benchmark one.

    Args:
        mixed_benchmark_dir: Result directory of the multi-benchmark run.
        single_benchmark_dir: Result directory of the single-benchmark run.

    Returns:
        Tuple (mixed_df, single_df) of pandas DataFrames.
    """
    # Location of the summary file relative to a result directory.
    summary_parts = ("04_metrics", "mse_summary.csv")

    mixed_df, single_df = (
        pd.read_csv(os.path.join(result_dir, *summary_parts))
        for result_dir in (mixed_benchmark_dir, single_benchmark_dir)
    )
    return mixed_df, single_df

def parse_mse_values(df, method_column):
    """
    Split a column of MSE summaries into separate mean and std arrays.

    A cell may be a "mean ± std" string (whitespace around the "±" is
    optional) or a bare number, either as text or as a numeric value.

    Args:
        df: DataFrame containing the column.
        method_column: Name of the column to parse.

    Returns:
        Tuple (means, stds) of NumPy float arrays with one entry per row.
        The std is 0.0 for cells that carry no "±" part.

    Raises:
        ValueError: If a cell's mean or std part cannot be converted to float.
    """
    means = []
    stds = []
    for val in df[method_column]:
        # pandas hands back floats (not strings) for all-numeric columns, so
        # normalise to text before searching for the separator.
        mean_text, separator, std_text = str(val).partition('±')
        # float() ignores surrounding whitespace, so no explicit strip needed.
        means.append(float(mean_text))
        stds.append(float(std_text) if separator else 0.0)
    return np.array(means), np.array(stds)

def create_mse_comparison_plot(mixed_df, single_df, output_dir):
    """
    Plot test-set MSE against training ratio for the joint and per-benchmark IRT models.

    The three IRT curves are drawn with error bars taken from the summary
    tables. The Global/Model/Question Mean baselines of the multi-benchmark
    table are overlaid as thin dotted reference lines. Every curve uses
    mixed_df["Train_Ratio"] as its x positions.

    Args:
        mixed_df: Multi-benchmark MSE summary table.
        single_df: Single-benchmark MSE summary table.
        output_dir: Directory the PNG is written into.

    Returns:
        Path of the saved "mse_comparison.png".
    """
    ratios = mixed_df["Train_Ratio"].values

    # (source table, column, legend label, per-curve styling) in draw order.
    irt_series = (
        (mixed_df, "Multi-benchmark IRT", "Multi-benchmark IRT",
         {"color": "red", "marker": "o"}),
        (single_df, "IRT-1PL", "Single-benchmark IRT-1PL",
         {"color": "blue", "marker": "s", "linestyle": "--"}),
        (single_df, "IRT-2PL", "Single-benchmark IRT-2PL",
         {"color": "green", "marker": "^", "linestyle": "-."}),
    )
    # Baseline column (also used as legend label) and its line colour.
    baseline_series = (
        ("Global Mean", "gray"),
        ("Model Mean", "orange"),
        ("Question Mean", "purple"),
    )

    # Parse all IRT columns before opening the figure.
    irt_curves = [
        (label, style, *parse_mse_values(table, column))
        for table, column, label, style in irt_series
    ]

    plt.figure(figsize=(10, 6))

    for label, style, means, stds in irt_curves:
        plt.errorbar(ratios, means, yerr=stds, label=label,
                     linewidth=2, markersize=6, capsize=3, **style)

    # Baselines are reference lines only, so their std is discarded.
    baseline_curves = [
        (name, colour, parse_mse_values(mixed_df, name)[0])
        for name, colour in baseline_series
    ]
    for name, colour, means in baseline_curves:
        plt.plot(ratios, means, label=name, color=colour,
                 linewidth=1, linestyle=":", alpha=0.7)

    # Axis labels, title, legend and ticks.
    plt.xlabel("Training Data Ratio", fontsize=12)
    plt.ylabel("Test Set MSE", fontsize=12)
    plt.title("MSE Comparison: Multi-benchmark vs Single-benchmark IRT Models", fontsize=14)
    plt.legend(loc="upper right", fontsize=10)
    plt.grid(True, alpha=0.3)
    tick_labels = [f"{r:.1f}" for r in ratios]
    plt.xticks(ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    # Write the figure to disk and release it.
    plt.tight_layout()
    output_path = os.path.join(output_dir, "mse_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"MSE comparison plot saved to: {output_path}")
    return output_path

def create_relative_improvement_plot(mixed_df, single_df, output_dir):
    """
    Plot how much lower the multi-benchmark MSE is than each single-benchmark MSE.

    The improvement is (single - multi) / single * 100, computed per row, so
    positive values mean the multi-benchmark model has the lower MSE. A
    horizontal line at 0 marks the break-even point.

    Args:
        mixed_df: Multi-benchmark MSE summary table.
        single_df: Single-benchmark MSE summary table.
        output_dir: Directory the PNG is written into.

    Returns:
        Path of the saved "relative_improvement.png".
    """
    ratios = mixed_df["Train_Ratio"].values

    # Only the means are needed here; the std arrays are dropped.
    joint_mse = parse_mse_values(mixed_df, "Multi-benchmark IRT")[0]
    mse_1pl = parse_mse_values(single_df, "IRT-1PL")[0]
    mse_2pl = parse_mse_values(single_df, "IRT-2PL")[0]

    def percent_gain(reference):
        # Relative reduction of the joint model's MSE versus `reference`, in %.
        return (reference - joint_mse) / reference * 100

    gain_1pl = percent_gain(mse_1pl)
    gain_2pl = percent_gain(mse_2pl)

    plt.figure(figsize=(10, 6))

    # (y values, per-curve styling) in draw order.
    curves = (
        (gain_1pl, {"label": "Improvement over IRT-1PL",
                    "color": "blue", "marker": "s"}),
        (gain_2pl, {"label": "Improvement over IRT-2PL",
                    "color": "green", "marker": "^", "linestyle": "--"}),
    )
    for gains, style in curves:
        plt.plot(ratios, gains, linewidth=2, markersize=6, **style)

    # Break-even reference line.
    plt.axhline(y=0, color="red", linestyle="-", linewidth=1, alpha=0.7)

    # Axis labels, title, legend and ticks.
    plt.xlabel("Training Data Ratio", fontsize=12)
    plt.ylabel("Relative Improvement (%)", fontsize=12)
    plt.title("Relative Improvement of Multi-benchmark IRT over Single-benchmark IRT", fontsize=14)
    plt.legend(loc="upper left", fontsize=10)
    plt.grid(True, alpha=0.3)
    tick_labels = [f"{r:.1f}" for r in ratios]
    plt.xticks(ratios, tick_labels, fontsize=10)
    plt.yticks(fontsize=10)

    # Write the figure to disk and release it.
    plt.tight_layout()
    output_path = os.path.join(output_dir, "relative_improvement.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"Relative improvement plot saved to: {output_path}")
    return output_path

def print_numerical_comparison(mixed_df, single_df):
    """
    Print a per-train-ratio table of MSE values and relative improvements.

    For each row the table shows the multi-benchmark MSE, the single-benchmark
    IRT-1PL and IRT-2PL MSEs, and the relative improvement
    (single - multi) / single * 100 of the multi-benchmark model over each.
    A final row shows the mean of those improvements.

    Rows of the two tables are paired by position.
    NOTE(review): this assumes both tables list the same train ratios in the
    same order; only the row counts are checked here.

    Args:
        mixed_df: Multi-benchmark MSE summary table.
        single_df: Single-benchmark MSE summary table.

    Raises:
        ValueError: If the two tables have different numbers of rows.
    """
    train_ratios = mixed_df["Train_Ratio"].values

    # Parse each column once instead of re-parsing a one-row slice per line.
    mixed_means, _ = parse_mse_values(mixed_df, "Multi-benchmark IRT")
    single_1pl_means, _ = parse_mse_values(single_df, "IRT-1PL")
    single_2pl_means, _ = parse_mse_values(single_df, "IRT-2PL")

    # Fail early with a clear message rather than part-way through the table
    # (or silently, via NumPy broadcasting of a one-row table).
    if len(mixed_means) != len(single_1pl_means):
        raise ValueError(
            "Row count mismatch: multi-benchmark table has "
            f"{len(mixed_means)} rows, single-benchmark table has "
            f"{len(single_1pl_means)} rows"
        )

    improvement_1pl = (single_1pl_means - mixed_means) / single_1pl_means * 100
    improvement_2pl = (single_2pl_means - mixed_means) / single_2pl_means * 100

    print("\n" + "="*80)
    print("NUMERICAL COMPARISON OF MULTI-BENCHMARK VS SINGLE-BENCHMARK IRT MODELS")
    print("="*80)

    print(f"{'Train Ratio':<12} {'Multi-IRT':<12} {'Single-1PL':<12} {'Single-2PL':<12} {'Improvement 1PL':<18} {'Improvement 2PL':<18}")
    print("-" * 80)

    rows = zip(train_ratios, mixed_means, single_1pl_means, single_2pl_means,
               improvement_1pl, improvement_2pl)
    for ratio, mixed_mse, single_1pl_mse, single_2pl_mse, imp_1pl, imp_2pl in rows:
        # Attach the percent sign before padding so it sits next to the number
        # and the column keeps the same width as its header.
        pct_1pl = f"{imp_1pl:.2f}%"
        pct_2pl = f"{imp_2pl:.2f}%"
        print(f"{ratio:<12.1f} {mixed_mse:<12.6f} {single_1pl_mse:<12.6f} {single_2pl_mse:<12.6f} {pct_1pl:<18} {pct_2pl:<18}")

    print("-" * 80)

    # Average of the per-row improvements (not the improvement of the averages).
    avg_pct_1pl = f"{np.mean(improvement_1pl):.2f}%"
    avg_pct_2pl = f"{np.mean(improvement_2pl):.2f}%"

    print(f"{'Average':<12} {'-':<12} {'-':<12} {'-':<12} {avg_pct_1pl:<18} {avg_pct_2pl:<18}")

def main(mixed_benchmark_dir="yourpath/result_mixed_benchmark",
         single_benchmark_dir="yourpath/result_single_benchmark",
         output_dir="yourpath/comparison_results"):
    """
    Load both MSE summaries, write the two comparison plots and print the table.

    The defaults are placeholder paths; pass real locations to use the script
    from other code without editing it.

    Args:
        mixed_benchmark_dir: Result directory of the multi-benchmark run.
        single_benchmark_dir: Result directory of the single-benchmark run.
        output_dir: Directory for the generated plots; created if missing.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Load data
    print("Loading MSE data...")
    mixed_df, single_df = load_mse_data(mixed_benchmark_dir, single_benchmark_dir)

    # Create visualizations
    print("Creating MSE comparison plot...")
    mse_plot_path = create_mse_comparison_plot(mixed_df, single_df, output_dir)

    print("Creating relative improvement plot...")
    improvement_plot_path = create_relative_improvement_plot(mixed_df, single_df, output_dir)

    # Print numerical comparison
    print_numerical_comparison(mixed_df, single_df)

    print(f"\nAll comparison results saved to: {output_dir}")
    print(f"MSE comparison plot: {mse_plot_path}")
    print(f"Relative improvement plot: {improvement_plot_path}")

# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()