import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colormaps
from matplotlib.colors import ListedColormap

# Configure matplotlib with LaTeX text rendering and a serif (Times) face.
# NOTE: "text.usetex": True requires a working LaTeX installation at plot time.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": ["serif"],
    # Under usetex, matplotlib only recognizes a fixed set of font names that
    # map to LaTeX packages; "Times New Roman" is not among them. Without the
    # "Times" fallback the text silently renders in Computer Modern instead.
    "font.serif": ["Times New Roman", "Times"],
    "font.size": 10,
    "axes.linewidth": 0.8,
    "axes.labelsize": 12,
    "axes.titlesize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    "figure.dpi": 300,
    "figure.figsize": (10, 6),
    "lines.linewidth": 2,  # Slightly thicker lines for better visibility
    "lines.markersize": 5,
})

# Named lookup tables for line colors (hex), marker symbols and line styles.
# get_method_styles() composes per-method plot styles from these names.
_COLOR_MAPS = dict(
    blue="#1f77b4",
    orange="#ff7f0e",
    green="#2ca02c",
    red="#d62728",
    purple="#9467bd",
    brown="#8c564b",
)

_MARKER_STYLES = dict(
    circle="o",
    square="s",
    triangle="^",
    diamond="D",
    inverted_triangle="v",
    cross="x",
)

# Each friendly name is also the literal matplotlib linestyle string.
_LINE_STYLES = {name: name for name in ("solid", "dashed", "dotted", "dashdot")}

def _parse_mean_ci(cell):
    """Split one ``"<mean> ± <ci>"`` cell into ``(mean, ci)`` floats.

    A cell without ``±`` (e.g. a plain number) yields a CI of 0.0.
    Raises ValueError if either part is not a valid float.
    """
    mean_text, sep, ci_text = str(cell).partition("±")
    # float() tolerates the surrounding whitespace around "±".
    return float(mean_text), (float(ci_text) if sep else 0.0)


def load_data_with_ci(csv_path, columns=None):
    """Load per-method means and confidence half-widths from a CSV.

    Method columns hold strings like ``"0.123 ± 0.004"``. The part before
    ``±`` is the mean and the part after is the CI half-width. Cells without
    ``±`` are accepted and get a CI of 0. This also covers columns that
    ``read_csv`` parsed as numeric.

    Args:
        csv_path: Path or file-like object accepted by ``pandas.read_csv``.
        columns: Method columns to load. Defaults to the columns this script
            plots. ``"Train_Ratio"`` is always loaded (unparsed) as the x-axis
            and is ignored if listed here.

    Returns:
        ``(mean_data, ci_data)``. ``mean_data`` maps ``"Train_Ratio"`` and
        each method column to an array of values. ``ci_data`` maps each
        method column to a float array of CI half-widths.

    Raises:
        KeyError: if a requested column is missing from the CSV.
        ValueError: if a cell cannot be parsed as a number.
    """
    if columns is None:
        columns = [
            "Global Mean",
            "Model Mean",
            "Question Mean",
            "Simple Multi-metric IRT (Aux)",
            "1PL IRT (bertscore_F1)",
            # "Cite IRT (bertscore_F1)",  # disabled: not currently plotted
        ]
    method_columns = [col for col in columns if col != "Train_Ratio"]

    df = pd.read_csv(csv_path)
    df = df[["Train_Ratio", *method_columns]]

    mean_data = {"Train_Ratio": df["Train_Ratio"].values}
    ci_data = {}
    for col in method_columns:
        # Parse cell by cell, which sidesteps the pandas .str accessor.
        # That accessor raises on numeric columns and yields NaN CIs for
        # rows lacking "±", which would later break the y-limit computation.
        pairs = [_parse_mean_ci(cell) for cell in df[col]]
        mean_data[col] = np.array([mean for mean, _ in pairs], dtype=float)
        ci_data[col] = np.array([ci for _, ci in pairs], dtype=float)

    return mean_data, ci_data

def get_method_styles():
    """Return the plot style for every method shown in the legend.

    Returns:
        dict mapping legend label -> ``(color, marker, linestyle, column)``,
        where ``column`` is the data column (as loaded from the CSV) that
        supplies the method's means and confidence intervals.
    """
    # (legend label, color name, marker name, line-style name, data column)
    table = (
        ("Global Mean", "blue", "circle", "solid", "Global Mean"),
        ("Model Mean", "orange", "square", "solid", "Model Mean"),
        ("Question Mean", "green", "triangle", "solid", "Question Mean"),
        ("Multi-metric IRT", "red", "diamond", "dashed",
         "Simple Multi-metric IRT (Aux)"),
        ("1PL IRT", "purple", "inverted_triangle", "dashed",
         "1PL IRT (bertscore_F1)"),
        # Disabled method (its CSV column is not loaded by default):
        # ("Tiny Benchmark IRT", "brown", "cross", "dashdot",
        #  "Cite IRT (bertscore_F1)"),
    )
    return {
        label: (
            _COLOR_MAPS[color_name],
            _MARKER_STYLES[marker_name],
            _LINE_STYLES[style_name],
            column,
        )
        for label, color_name, marker_name, style_name, column in table
    }

def _padded_ylim(lows, highs, frac=0.1):
    """Return ``(bottom, top)`` enclosing every finite band value, or None.

    Each end is padded by ``frac`` of its own magnitude. For positive data
    this equals ``(min * 0.9, max * 1.1)``, but a negative bottom (a wide CI
    around a small MSE) is moved further *down* instead of being clipped.
    Returns None when there is no finite data or the range is degenerate,
    so that matplotlib's autoscaling is left in charge.
    """
    if not lows:
        return None
    low_vals = np.concatenate([np.ravel(a) for a in lows])
    high_vals = np.concatenate([np.ravel(a) for a in highs])
    low_vals = low_vals[np.isfinite(low_vals)]
    high_vals = high_vals[np.isfinite(high_vals)]
    if low_vals.size == 0 or high_vals.size == 0:
        return None
    bottom = low_vals.min()
    top = high_vals.max()
    bottom -= frac * abs(bottom)
    top += frac * abs(top)
    if not bottom < top:
        return None
    return float(bottom), float(top)


def create_plot(mean_data, ci_data, method_styles, save_dir="plots"):
    """Plot mean test-set MSE per method against the training-data ratio.

    Each method is drawn as a line with markers plus a shaded band from
    ``mean - ci`` to ``mean + ci``. The figure is saved as a timestamped PDF
    in ``save_dir``, which is created if missing.

    Args:
        mean_data: Mapping with a ``"Train_Ratio"`` array (x values) and one
            array of means per data column.
        ci_data: Mapping from data column to CI half-widths, aligned with
            ``mean_data``.
        method_styles: Mapping from legend label to
            ``(color, marker, linestyle, data_column)``.
        save_dir: Output directory.

    Returns:
        Path of the saved PDF file.
    """
    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        train_ratio = mean_data["Train_Ratio"]

        band_lows = []
        band_highs = []
        for label, (color, marker, linestyle, original_col) in method_styles.items():
            means = np.asarray(mean_data[original_col], dtype=float)
            cis = np.asarray(ci_data[original_col], dtype=float)
            lower = means - cis
            upper = means + cis
            band_lows.append(lower)
            band_highs.append(upper)

            ax.plot(
                train_ratio,
                means,
                label=label,
                color=color,
                marker=marker,
                linestyle=linestyle,
                alpha=0.9,
            )
            ax.fill_between(
                train_ratio,
                lower,
                upper,
                color=color,
                alpha=0.15,
                linewidth=0,
            )

        ax.set_xlabel(r"Training Data Ratio ($Train\_Ratio$)", labelpad=10)
        ax.set_ylabel(r"Test Set MSE (Lower = Better Performance)", labelpad=10)

        title = "MSE Variation with Training Data Ratio (with 95% CI)"
        if plt.rcParams["text.usetex"]:
            # In LaTeX an unescaped '%' comments out the rest of the line,
            # including the closing brace matplotlib adds, so rendering fails.
            # usetex also ignores `fontweight`, so request bold via LaTeX.
            title = r"\textbf{" + title.replace("%", r"\%") + "}"
        ax.set_title(title, pad=15, fontweight="bold")

        ax.set_xticks(train_ratio)
        ax.set_xticklabels([f"{x:.1f}" for x in train_ratio])

        limits = _padded_ylim(band_lows, band_highs)
        if limits is not None:
            ax.set_ylim(*limits)

        ax.grid(True, alpha=0.3, linestyle="--")
        ax.legend(loc="upper right", frameon=True, fancybox=True, shadow=True, ncol=1)

        fig.tight_layout()
        timestamp = int(time.time())
        file_path = os.path.join(save_dir, f"mse_with_ci_updated_legend_{timestamp}.pdf")
        fig.savefig(file_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"Plot with updated legend saved to: {file_path}")
    return file_path

def main():
    """Script entry point: read the MSE summary CSV and write the CI plot."""
    # Relative path, resolved against the current working directory.
    summary_csv = "data/mse_summary.csv"
    means, intervals = load_data_with_ci(summary_csv)
    create_plot(means, intervals, get_method_styles())


if __name__ == "__main__":
    main()
