import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.colors import ListedColormap

# Configure matplotlib to typeset all figure text with LaTeX (text.usetex
# requires a working LaTeX installation at runtime).
plt.rcParams.update({
    "text.usetex": True,
    "font.family": ["serif"],
    # Under usetex, matplotlib only honours serif names it can map to a LaTeX
    # font package: "Times" selects mathptmx. "Times New Roman" is not such a
    # name, so on its own it silently fell back to Computer Modern; it is kept
    # as the fallback for non-TeX rendering if usetex is switched off.
    "font.serif": ["Times", "Times New Roman"],
    "font.size": 10,
    "axes.linewidth": 0.8,
    "axes.labelsize": 12,
    "axes.titlesize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.dpi": 300,
    "lines.linewidth": 2,  # Slightly thicker lines for better visibility
    "lines.markersize": 5,
})

# Color, marker, and line style configuration
_COLOR_MAPS = {
    "blue": "#1f77b4",
    "orange": "#ff7f0e",
    "green": "#2ca02c",
    "red": "#d62728",
    "purple": "#9467bd",
    "brown": "#8c564b"
}

_MARKER_STYLES = {
    "circle": "o",
    "square": "s",
    "triangle": "^",
    "diamond": "D",
    "inverted_triangle": "v",
    "cross": "x"
}

_LINE_STYLES = {
    "solid": "solid",
    "dashed": "dashed",
    "dotted": "dotted",
    "dashdot": "dashdot"
}

# Columns read from the summary CSV, in plotting order. "Train_Ratio" holds the
# x-axis values; every other column holds "mean ± ci" cells.
_DEFAULT_COLUMNS = (
    "Train_Ratio",
    "Global Mean",
    "Model Mean",
    "Question Mean",
    # "Mixed-metric IRT",
    "Simple Multi-metric IRT (Aux)",
    "1PL IRT (bertscore_F1)",
    # "Cite IRT (bertscore_F1)",
)


def _parse_mean_ci(cell):
    """Split one ``"mean ± ci"`` cell into a ``(mean, ci)`` pair of floats.

    A cell without a ``±`` separator (e.g. a plain number that pandas already
    parsed as float) yields ``ci = 0.0``. An empty/NaN cell becomes the string
    ``"nan"`` and so yields ``(nan, 0.0)``.

    Raises:
        ValueError: if the mean or the ci part is not a valid float.
    """
    mean_text, sep, ci_text = str(cell).partition("±")
    mean = float(mean_text.strip())
    ci = float(ci_text.strip()) if sep else 0.0
    return mean, ci


def load_data_with_ci(csv_path, columns=_DEFAULT_COLUMNS):
    """Load a summary CSV whose metric cells are formatted as ``"mean ± ci"``.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).
        columns: Column names to keep, in order. Names absent from the CSV are
            skipped silently. ``"Train_Ratio"`` is always loaded, whether or
            not it is listed.

    Returns:
        ``(mean_data, ci_data)``. ``mean_data`` maps ``"Train_Ratio"`` to its
        raw values and each kept metric column to a float array of means.
        ``ci_data`` maps each metric column to a float array of the ``±``
        values (0.0 where a cell has no ``±`` part).

    Raises:
        KeyError: if the CSV has no ``"Train_Ratio"`` column.
        ValueError: if a metric cell cannot be parsed as numbers.
    """
    df = pd.read_csv(csv_path)

    mean_data = {"Train_Ratio": df["Train_Ratio"].to_numpy()}
    ci_data = {}

    for col in columns:
        if col == "Train_Ratio" or col not in df.columns:
            continue
        pairs = [_parse_mean_ci(cell) for cell in df[col]]
        mean_data[col] = np.array([mean for mean, _ in pairs], dtype=float)
        ci_data[col] = np.array([ci for _, ci in pairs], dtype=float)

    return mean_data, ci_data

def get_method_styles():
    """Return the plot style of every method, keyed by its legend label.

    Each value is a ``(color, marker, linestyle, column)`` tuple, where
    ``column`` is the key looked up in the mean/CI dicts produced by
    ``load_data_with_ci``. Dict order is the order in which curves and legend
    entries are drawn.
    """
    # (legend label, palette key, marker key, line-style key, data column)
    rows = (
        ("Global Mean", "blue", "circle", "solid", "Global Mean"),
        ("Model Mean", "orange", "square", "solid", "Model Mean"),
        ("Question Mean", "green", "triangle", "solid", "Question Mean"),
        # Disabled: ("Mixed-metric IRT", "red", "diamond", "dashed",
        #            "Mixed-metric IRT"),
        ("Multi-metric IRT", "brown", "cross", "dashdot",
         "Simple Multi-metric IRT (Aux)"),
        ("1PL IRT", "purple", "inverted_triangle", "dashed",
         "1PL IRT (bertscore_F1)"),
        # Disabled: "Cite IRT" used the raw color "#FF6B6B" with the "circle"
        # marker and "dotted" line on column "Cite IRT (bertscore_F1)".
    )
    return {
        label: (_COLOR_MAPS[color], _MARKER_STYLES[marker], _LINE_STYLES[line], column)
        for label, color, marker, line, column in rows
    }

def _bold(text):
    """Return *text* marked up so that it renders bold, with or without usetex.

    When ``text.usetex`` is on, the string is typeset by LaTeX and matplotlib's
    ``fontweight`` property is not applied, so the weight is expressed with the
    LaTeX textbf command instead.
    """
    if plt.rcParams["text.usetex"]:
        return r"\textbf{" + text + "}"
    return text


def _padded_ylim(edges, pad_fraction=0.1):
    """Return ``(ymin, ymax)`` enclosing every finite value in *edges*, or None.

    The margin is ``pad_fraction`` of the data span rather than a scale factor
    on the endpoints. A scale factor would clip the band when the lower edge
    is negative (0.9 * negative moves the limit up) and collapse to (0, 0) for
    all-zero data. When every edge is non-negative, the lower limit is clamped
    at 0. NaN/inf values are ignored; None is returned if nothing finite is
    left.
    """
    values = np.asarray(edges, dtype=float)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return None

    low, high = float(values.min()), float(values.max())
    span = high - low
    if span == 0:
        # Flat data: size the margin from the value itself.
        span = abs(high) or 1.0
    pad = span * pad_fraction

    y_min = low - pad
    if low >= 0:
        y_min = max(y_min, 0.0)
    return y_min, high + pad


def _plot_panel(ax, mean_data, ci_data, method_styles, title):
    """Draw each available method's mean curve and ± band on *ax*, then style it.

    Methods whose column is missing from *mean_data* are skipped.
    """
    train_ratio = mean_data["Train_Ratio"]
    band_edges = []

    for label, (color, marker, linestyle, column) in method_styles.items():
        if column not in mean_data:
            continue
        means = np.asarray(mean_data[column], dtype=float)
        cis = np.asarray(ci_data[column], dtype=float)
        lower, upper = means - cis, means + cis

        ax.plot(
            train_ratio,
            means,
            label=label,
            color=color,
            marker=marker,
            linestyle=linestyle,
            alpha=0.9,
        )
        ax.fill_between(
            train_ratio,
            lower,
            upper,
            color=color,
            alpha=0.15,
            linewidth=0,
        )
        band_edges.extend(lower)
        band_edges.extend(upper)

    ax.set_xlabel(r"Training Data Ratio ($Train\_Ratio$)", labelpad=10)
    ax.set_ylabel(r"Test Set MSE (Lower = Better Performance)", labelpad=10)
    ax.set_title(_bold(title), pad=15, fontweight="bold")

    # One tick per training ratio actually present in the data.
    ax.set_xticks(train_ratio)
    ax.set_xticklabels([f"{x:.1f}" for x in train_ratio])

    ylim = _padded_ylim(band_edges)
    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.grid(True, alpha=0.3, linestyle="--")


def create_comparison_plot(mean_data_1, ci_data_1, mean_data_2, ci_data_2, method_styles, save_dir="plots",
                           dataset_names=("XSUM", "WMT")):
    """Plot two datasets side by side with a shared legend below both panels.

    Args:
        mean_data_1, ci_data_1: Mean / ± dicts for the left panel (as returned
            by ``load_data_with_ci``).
        mean_data_2, ci_data_2: Same for the right panel.
        method_styles: Mapping ``label -> (color, marker, linestyle, column)``;
            its order sets the drawing and legend order.
        save_dir: Output directory, created if missing.
        dataset_names: Two names used in the left/right panel titles.

    Returns:
        Path of the written PDF (``mse_comparison_<unix-time>.pdf``).
    """
    os.makedirs(save_dir, exist_ok=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    try:
        panels = (
            (ax1, mean_data_1, ci_data_1, dataset_names[0]),
            (ax2, mean_data_2, ci_data_2, dataset_names[1]),
        )
        for ax, mean_data, ci_data, name in panels:
            _plot_panel(
                ax,
                mean_data,
                ci_data,
                method_styles,
                f"MSE Variation with Training Data Ratio ({name} Dataset)",
            )

        # Shared legend: one entry per method present in either dataset.
        handles = []
        labels = []
        for label, (color, marker, linestyle, column) in method_styles.items():
            if column in mean_data_1 or column in mean_data_2:
                handles.append(plt.Line2D([0], [0], color=color, marker=marker, linestyle=linestyle))
                labels.append(label)

        if handles:
            # Anchored at the bottom centre of the figure; the tight_layout
            # rect below reserves the bottom 10% so it does not hit the axes.
            fig.legend(handles, labels, loc="lower center", bbox_to_anchor=(0.5, 0.0),
                       ncol=4, frameon=True, fancybox=True, shadow=True)

        fig.tight_layout(rect=[0, 0.1, 1, 1])

        timestamp = int(time.time())
        file_path = os.path.join(save_dir, f"mse_comparison_{timestamp}.pdf")
        fig.savefig(file_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"Comparison plot saved to: {file_path}")
    return file_path

def main(argv=None):
    """Load both MSE summaries and write the side-by-side comparison plot.

    Args:
        argv: Command-line arguments to parse (``None`` means ``sys.argv[1:]``).
            ``--csv1`` (XSUM, left panel), ``--csv2`` (WMT, right panel) and
            ``--save-dir`` override the defaults, which are the paths this
            script was originally run with.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Plot MSE vs. training-data ratio for two datasets side by side."
    )
    parser.add_argument(
        "--csv1",
        default="/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/mix_metric_v1/04_metrics/mse_summary.csv",
        help="mse_summary.csv for the left (XSUM) panel",
    )
    parser.add_argument(
        "--csv2",
        default="/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/exp_0918_mix_metric/mix_metric_v1_wmt/04_metrics/mse_summary.csv",
        help="mse_summary.csv for the right (WMT) panel",
    )
    parser.add_argument("--save-dir", default="plots", help="directory for the output PDF")
    args = parser.parse_args(argv)

    mean_data_1, ci_data_1 = load_data_with_ci(args.csv1)
    mean_data_2, ci_data_2 = load_data_with_ci(args.csv2)

    method_styles = get_method_styles()

    create_comparison_plot(mean_data_1, ci_data_1, mean_data_2, ci_data_2, method_styles, args.save_dir)


if __name__ == "__main__":
    main()