#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Plot the test-set MSE of the Global Mean, Model Mean, and Question Mean
baselines against the training-data ratio, one figure per benchmark.
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Global matplotlib style applied to every figure this script produces.
plt.rcParams.update({
    "font.family": ["Arial", "Helvetica"],  # preferred fonts, tried in order
    "axes.unicode_minus": False,  # draw minus signs as ASCII hyphen-minus
    "figure.dpi": 150,  # default figure resolution
})

def _parse_mean(cell):
    """Return the mean part of a ``"mean ± std"`` cell as a float.

    Also accepts a bare number (pandas parses a column with no ``±`` as
    numeric, and floats have no ``.split``) and tolerates missing whitespace
    around ``±``. Missing or blank cells become ``np.nan`` so matplotlib
    leaves a gap in the line instead of the script crashing.
    """
    if pd.isna(cell):
        return np.nan
    head = str(cell).split('±', 1)[0].strip()
    return float(head) if head else np.nan

def plot_mean_mse_comparison(csv_path, save_path):
    """
    Plot Global Mean, Model Mean, and Question Mean MSE values as training ratio changes.

    Args:
        csv_path: CSV with a ``Train_Ratio`` column and the columns
            ``Global Mean (MSE)``, ``Model Mean (MSE)`` and
            ``Question Mean (MSE)``. Their cells hold either ``"mean ± std"``
            strings or plain numbers; only the mean is plotted.
        save_path: Output image path passed to ``Figure.savefig`` (300 dpi).

    Raises:
        KeyError: if a required column is missing from the CSV.
        ValueError: if a cell's mean part is not a number.
    """
    df = pd.read_csv(csv_path)

    # Lines connect points in row order; sort so an unsorted CSV cannot zig-zag.
    df = df.sort_values('Train_Ratio', kind='stable').reset_index(drop=True)
    train_ratios = df['Train_Ratio'].to_numpy(dtype=float)

    # (CSV column, legend label, color, marker) for each baseline curve.
    series_specs = [
        ('Global Mean (MSE)', 'Global Mean', 'blue', 'o'),
        ('Model Mean (MSE)', 'Model Mean', 'orange', 's'),
        ('Question Mean (MSE)', 'Question Mean', 'green', '^'),
    ]
    # Parse everything before creating the figure so bad input fails early.
    curves = [
        (df[column].map(_parse_mean).to_numpy(dtype=float), label, color, marker)
        for column, label, color, marker in series_specs
    ]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for means, label, color, marker in curves:
            ax.plot(train_ratios, means, label=label, color=color,
                    marker=marker, linewidth=2)

        ax.set_xlabel('Training Data Ratio', fontsize=12)
        ax.set_ylabel('Test Set MSE', fontsize=12)
        ax.set_title('MSE vs. Training Data Ratio for Mean Baselines', fontsize=14, pad=20)
        ax.legend(loc='upper right', fontsize=10)
        ax.grid(True, alpha=0.3)
        # One tick per measured ratio, labelled with one decimal place.
        ax.set_xticks(train_ratios)
        ax.set_xticklabels([f"{r:.1f}" for r in train_ratios], fontsize=10)
        ax.tick_params(axis='y', labelsize=10)

        fig.tight_layout()
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

    print(f"MSE plot saved: {save_path}")

def main(comparison_dir="yourpath/comparison_results_rep3_improved",
         benchmarks=('ceval', 'csqa', 'mmlu')):
    """Generate the mean-baseline MSE plot for every benchmark CSV that exists.

    For each benchmark, reads
    ``<comparison_dir>/<benchmark>_metrics_mix_vs_single_rep3_improved.csv``
    and writes ``<comparison_dir>/<benchmark>_mean_mse_comparison.png``.
    A warning is printed for each missing CSV, and that benchmark is skipped.

    Args:
        comparison_dir: Directory holding the input CSVs; plots are written
            to the same directory.
        benchmarks: Benchmark name prefixes to process.

    Returns:
        list[str]: Paths of the plots that were written.
    """
    generated = []
    missing = []
    for benchmark in benchmarks:
        csv_path = os.path.join(comparison_dir, f"{benchmark}_metrics_mix_vs_single_rep3_improved.csv")
        plot_path = os.path.join(comparison_dir, f"{benchmark}_mean_mse_comparison.png")

        if not os.path.isfile(csv_path):
            print(f"Warning: {csv_path} does not exist")
            missing.append(csv_path)
            continue

        plot_mean_mse_comparison(csv_path, plot_path)
        generated.append(plot_path)

    # Only claim full success when no inputs were skipped.
    if missing:
        total = len(generated) + len(missing)
        print(f"Generated {len(generated)} of {total} MSE comparison plots; "
              f"{len(missing)} input file(s) missing.")
    else:
        print("All MSE comparison plots generated successfully!")
    return generated

if __name__ == '__main__':
    # Generate the plots only when run as a script, not when imported.
    main()