#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.ticker import ScalarFormatter, MaxNLocator, FormatStrFormatter
import seaborn as sns
from pathlib import Path
import argparse

# Publication-oriented plotting defaults: start from the seaborn "paper"
# style sheet, then override fonts, text sizes, resolution and grid look.
plt.style.use('seaborn-v0_8-paper')
mpl.rcParams.update({
    # Typeface
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    # Enlarged text sizes
    'axes.titlesize': 29,
    'axes.labelsize': 23,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 20,
    # Line weight and raster resolution
    'lines.linewidth': 2.2,
    'figure.dpi': 300,
    'savefig.dpi': 300,
    # Light dashed grid on every axes
    'axes.grid': True,
    'grid.alpha': 0.4,
    'grid.linestyle': '--',
})

# Sweep grids: the alpha values used as x-axis ticks and the epsilon values
# that each get their own line in every plot.
ALPHAS = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
EPSILONS = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25]

# Metric column names mapped to their axis labels; METRICS fixes plot order.
METRIC_LABELS = dict(
    accuracy="Accuracy",
    kl_divergence="KL Divergence",
    js_divergence="JS Divergence",
)
METRICS = list(METRIC_LABELS)

# Line colour per epsilon, paired positionally with EPSILONS.
EPSILON_COLORS = dict(zip(EPSILONS, (
    '#1f77b4',  # Blue
    '#ff7f0e',  # Orange
    '#2ca02c',  # Green
    '#d62728',  # Red
    '#9467bd',  # Purple
    '#8c564b',  # Brown
)))

# Marker glyph per epsilon, paired positionally with EPSILONS.
EPSILON_MARKERS = dict(zip(EPSILONS, (
    'o',  # Circle
    's',  # Square
    '^',  # Triangle up
    'v',  # Triangle down
    'D',  # Diamond
    'P',  # Plus (filled)
)))

# Dataset identifiers mapped to the names shown in plot titles.
DATASET_NAMES = dict(
    snli="SNLI",
    multinli="MNLI",
    summeval="SummEval",
    mtbench="MTBench",
)

def load_dataset(dataset_name, model_name, target_epoch):
    """Load one dataset's evaluation results, filtered to a model and epoch.

    The CSV is read from
    ``evaluation_results/metrics/<dataset_name>/evaluation_results.csv``
    relative to the directory containing this script.

    Args:
        dataset_name: Name of the sub-directory under ``metrics``.
        model_name: Model to keep; matched case-insensitively against the
            ``model_name`` column.
        target_epoch: Epoch to keep; matched exactly against the ``epoch``
            column.

    Returns:
        The matching rows as a DataFrame. An empty DataFrame is returned
        (after printing a message) when the file is missing, the file has no
        content, or no row matches the filter.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(current_dir, "evaluation_results", "metrics", dataset_name, "evaluation_results.csv")

    # Only the read itself is guarded so that unrelated errors raised by the
    # filtering below are not mistaken for I/O problems.
    try:
        df = pd.read_csv(file_path)
    except FileNotFoundError:
        print(f"Error: File not found: {file_path}")
        return pd.DataFrame()
    except pd.errors.EmptyDataError:
        # A zero-byte results file is treated like a missing one instead of
        # aborting the whole run.
        print(f"Error: File is empty: {file_path}")
        return pd.DataFrame()

    # Filter for specific model and epoch
    matches_model = df['model_name'].str.lower() == model_name.lower()
    matches_epoch = df['epoch'] == target_epoch
    df = df[matches_model & matches_epoch]
    if df.empty:
        print(f"Warning: No data found for model {model_name} at epoch {target_epoch} in {dataset_name}")
    return df

def handle_kl_divergence_scaling(dataset_df, ax, dataset_name):
    """Clip the KL-divergence y-axis on MTBench when alpha=1.0 dwarfs the rest.

    Only applies when ``dataset_name`` is "mtbench" (case-insensitive). If the
    largest KL value at alpha == 1.0 is more than twice the largest value for
    alpha <= 0.8, the y-axis is limited to ``[0.6, y_limit_detail]`` so the
    lower-alpha curves stay readable, a "//" break mark is drawn, and the
    clipped peak value is shown in a text box.

    Args:
        dataset_df: DataFrame with ``alpha`` and ``kl_divergence`` columns.
        ax: Matplotlib axes to adjust and annotate.
        dataset_name: Dataset identifier.

    Returns:
        True if the special scaling was applied to ``ax``, False otherwise
        (in which case ``ax`` is left untouched).
    """
    if dataset_name.lower() != "mtbench":
        return False

    # Data for detailed view (alpha <= 0.8)
    detailed_view_df = dataset_df[dataset_df['alpha'] <= 0.8]
    # Data for the peak value (alpha = 1.0)
    peak_value_df = dataset_df[dataset_df['alpha'] == 1.0]
    if detailed_view_df.empty or peak_value_df.empty:
        return False

    kl_values_detailed = detailed_view_df['kl_divergence'].dropna()
    kl_values_peak = peak_value_df['kl_divergence'].dropna()
    if kl_values_detailed.empty or kl_values_peak.empty:
        return False

    # Both series are non-empty and NaN-free here, so the maxima are valid.
    max_kl_detailed = kl_values_detailed.max()
    max_kl_at_peak = kl_values_peak.max()

    # The broken-axis treatment is only worthwhile when the peak is
    # significantly larger (more than 2x) than everything else.
    if max_kl_at_peak <= max_kl_detailed * 2.0:
        return False

    # Upper limit for the detailed region: headroom above both the 95th
    # percentile and (when positive) the maximum of the detailed values.
    p95_detailed = np.percentile(kl_values_detailed, 95)
    y_limit_detail = p95_detailed * 1.2
    if max_kl_detailed > 0:
        y_limit_detail = max(y_limit_detail, max_kl_detailed * 1.1)
    y_limit_detail = max(y_limit_detail, 0.001)  # Ensure a small minimum y-limit if values are tiny

    # If the detail limit already reaches the peak, nothing would be clipped.
    if y_limit_detail >= max_kl_at_peak:
        return False

    # The y-axis is requested to start at 0.6, so the upper limit must sit
    # meaningfully above it for the visible range to be valid.
    if y_limit_detail <= 0.605:
        print(f"Warning: Calculated y_limit_detail ({y_limit_detail:.2f}) for MTBench KL is not sufficiently greater than 0.6. Skipping special scaling.")
        return False

    ax.set_ylim(0.6, y_limit_detail)

    # Text-based break mark near the top of the y-axis, positioned in axes
    # coordinates so it stays put regardless of the data range.
    ax.text(0.02, 0.98, "//", transform=ax.transAxes,
            fontsize=18, color='gray', ha='center', va='top', weight='bold')

    # Report the clipped alpha=1.0 peak in the top-right corner.
    ax.text(0.98, 0.95, f"KL(α=1.0) ≈ {max_kl_at_peak:.2f}",
            transform=ax.transAxes,
            fontsize=14,
            color="darkred",
            ha='right', va='top',
            bbox=dict(boxstyle="round,pad=0.2", fc="ivory", ec="darkred", alpha=0.75))
    return True

def format_axis_for_metric(ax, metric, y_values):
    """Set y-limits, tick formatting and grid on ``ax`` for one metric.

    Limits are the finite data range padded by 10% on each side. For
    "accuracy" the limits are clamped to [0, 1] and ticks are shown as
    percentages; for any other metric the lower limit is clamped at 0, ticks
    use two decimals and at most 6 bins.

    Args:
        ax: Matplotlib axes to format.
        metric: Metric name; "accuracy" selects the percentage formatting.
        y_values: Sequence of metric values that will be plotted on ``ax``.

    Returns:
        None. ``ax`` is left untouched when ``y_values`` is empty or contains
        no finite value.
    """
    if len(y_values) == 0:
        return

    # NaN/inf entries would otherwise propagate into set_ylim, which rejects
    # non-finite limits.
    finite_values = np.asarray(y_values, dtype=float)
    finite_values = finite_values[np.isfinite(finite_values)]
    if finite_values.size == 0:
        return

    y_min, y_max = finite_values.min(), finite_values.max()
    y_range = y_max - y_min
    y_padding = y_range * 0.1  # 10% padding
    if y_padding == 0:
        # All values identical: fall back to a small non-zero padding so the
        # lower and upper limits do not coincide.
        y_padding = abs(y_max) * 0.05 or 0.05

    # Special handling for different metrics
    if metric == "accuracy":
        # Accuracy is a fraction, so keep the padded range inside [0, 1].
        y_bottom = max(0, y_min - y_padding)
        y_top = min(1.0, y_max + y_padding)
        ax.set_ylim(y_bottom, y_top)
        ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(1.0))
    else:  # KL and JS divergence
        ax.set_ylim(max(0, y_min - y_padding), y_max + y_padding)
        # Set formatter to always show 2 decimal places
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        ax.yaxis.set_major_locator(MaxNLocator(nbins=6))

    # Add grid for better readability
    ax.grid(True, linestyle='--', alpha=0.4)

def plot_metric_by_dataset(dataset_name, dataset_df, metric, model_name, target_epoch):
    """Plot one metric against alpha for a single dataset and save it.

    One line is drawn per value in ``EPSILONS``. The figure is written as PNG
    to ``evaluation_results/img/<model_name>_epoch<target_epoch>/`` next to
    this script, and as SVG to the ``svg`` sub-directory of that folder.

    Args:
        dataset_name: Dataset identifier; used for the title and file name.
        dataset_df: DataFrame with ``alpha``, ``epsilon`` and ``metric``
            columns.
        metric: Column to plot; must be a key of ``METRIC_LABELS``.
        model_name: Model name, used in the output directory name.
        target_epoch: Epoch, used in the output directory name.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # try/finally guarantees the figure is released even if plotting or
    # saving raises.
    try:
        format_axis_for_metric(ax, metric, dataset_df[metric].values)

        # Special handling for KL divergence to better show details
        if metric == "kl_divergence":
            handle_kl_divergence_scaling(dataset_df, ax, dataset_name)

        # Create a line for each epsilon value
        for epsilon in EPSILONS:
            # Tolerant comparison: values parsed from CSV may differ from the
            # literal by a rounding error.
            epsilon_df = dataset_df[np.isclose(dataset_df['epsilon'], epsilon)]
            if epsilon_df.empty:
                continue
            # Sort by alpha for proper line plotting
            epsilon_df = epsilon_df.sort_values('alpha')
            ax.plot(
                epsilon_df['alpha'],
                epsilon_df[metric],
                marker=EPSILON_MARKERS[epsilon],
                linestyle='-',
                color=EPSILON_COLORS[epsilon],
                label=f"ε = {epsilon}",
                markersize=10,
            )

        # Setup the axes and labels
        ax.set_xlabel("Alpha (α)", fontweight='bold')
        ax.set_ylabel(METRIC_LABELS[metric], fontweight='bold')
        ax.set_title(f"{DATASET_NAMES.get(dataset_name, dataset_name)}", fontweight='bold')

        # Set x-axis ticks to match alpha values
        ax.set_xticks(ALPHAS)
        ax.set_xticklabels([str(alpha) for alpha in ALPHAS])

        legend = ax.legend(loc='best', frameon=True, framealpha=0.9)
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_edgecolor('lightgray')

        fig.tight_layout()

        # Output directory carries model and epoch info; creating the svg
        # sub-directory also creates its parents.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(current_dir, "evaluation_results", "img",
                                  f"{model_name}_epoch{target_epoch}")
        svg_dir = os.path.join(output_dir, "svg")
        os.makedirs(svg_dir, exist_ok=True)

        # Save PNG version
        output_path = os.path.join(output_dir, f"{dataset_name}_{metric}.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

        # Save SVG version
        svg_path = os.path.join(svg_dir, f"{dataset_name}_{metric}.svg")
        fig.savefig(svg_path, format='svg', bbox_inches='tight')

        print(f"Saved plot to {output_path} and {svg_path}")
    finally:
        plt.close(fig)

def plot_combined_view(dataset_name, dataset_df, model_name, target_epoch):
    """Plot every metric in ``METRICS`` side by side for one dataset and save.

    Each panel shows one metric against alpha with one line per value in
    ``EPSILONS``; a single shared legend is placed below the panels. The
    figure is written as ``<dataset_name>_combined.png`` to
    ``evaluation_results/img/<model_name>_epoch<target_epoch>/`` next to this
    script, and as SVG to the ``svg`` sub-directory of that folder.

    Args:
        dataset_name: Dataset identifier; used for the title and file name.
        dataset_df: DataFrame with ``alpha``, ``epsilon`` and one column per
            entry of ``METRICS``.
        model_name: Model name, used in the output directory name.
        target_epoch: Epoch, used in the output directory name.
    """
    n_panels = len(METRICS)
    # squeeze=False keeps `axes` two-dimensional whatever the panel count.
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
    axes = axes[0]
    # try/finally guarantees the figure is released even if plotting or
    # saving raises.
    try:
        for ax, metric in zip(axes, METRICS):
            format_axis_for_metric(ax, metric, dataset_df[metric].values)

            # Special handling for KL divergence to better show details
            if metric == "kl_divergence":
                handle_kl_divergence_scaling(dataset_df, ax, dataset_name)

            for epsilon in EPSILONS:
                # Tolerant comparison: values parsed from CSV may differ from
                # the literal by a rounding error.
                epsilon_df = dataset_df[np.isclose(dataset_df['epsilon'], epsilon)]
                if epsilon_df.empty:
                    continue
                epsilon_df = epsilon_df.sort_values('alpha')
                ax.plot(
                    epsilon_df['alpha'],
                    epsilon_df[metric],
                    marker=EPSILON_MARKERS[epsilon],
                    linestyle='-',
                    color=EPSILON_COLORS[epsilon],
                    label=f"ε = {epsilon}",
                    markersize=9,
                )

            ax.set_xlabel("Alpha (α)", fontweight='bold')
            ax.set_ylabel(METRIC_LABELS[metric], fontweight='bold')

            # Set x-axis ticks to match alpha values
            ax.set_xticks(ALPHAS)
            ax.set_xticklabels([str(alpha) for alpha in ALPHAS])

        # Add a single title for the entire figure
        fig.suptitle(f"{DATASET_NAMES.get(dataset_name, dataset_name)}",
                     fontweight='bold', fontsize=30)

        # Every panel draws the same epsilon lines, so the first panel's
        # handles serve as the shared legend.
        handles, labels = axes[0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncol=len(EPSILONS),
                   bbox_to_anchor=(0.5, 0), frameon=True, framealpha=0.9)

        # Reserve space below for the legend and above for the title.
        fig.tight_layout(rect=[0, 0.08, 1, 0.95])

        # Output directory carries model and epoch info; creating the svg
        # sub-directory also creates its parents.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(current_dir, "evaluation_results", "img",
                                  f"{model_name}_epoch{target_epoch}")
        svg_dir = os.path.join(output_dir, "svg")
        os.makedirs(svg_dir, exist_ok=True)

        # Save PNG version
        output_path = os.path.join(output_dir, f"{dataset_name}_combined.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

        # Save SVG version
        svg_path = os.path.join(svg_dir, f"{dataset_name}_combined.svg")
        fig.savefig(svg_path, format='svg', bbox_inches='tight')

        print(f"Saved combined plot to {output_path} and {svg_path}")
    finally:
        plt.close(fig)

def plot_all_datasets_grid(datasets_data, metric, model_name, target_epoch):
    """
    Create a grid plot comparing all datasets for a single metric.
    Each dataset uses its own scale to better show details.

    Up to two datasets are laid out in one row; more are laid out in two
    rows. The figure is written as ``all_datasets_<metric>_comparison.png`` to
    ``evaluation_results/img/<model_name>_epoch<target_epoch>/`` next to this
    script, and as SVG to the ``svg`` sub-directory of that folder.

    Args:
        datasets_data: dict of {dataset_name: dataframe}; each dataframe
            needs ``alpha``, ``epsilon`` and ``metric`` columns.
        metric: The metric to plot; must be a key of ``METRIC_LABELS``.
        model_name: Model name, used in the output directory name.
        target_epoch: Epoch, used in the output directory name.
    """
    # Skip if we have no data
    if not datasets_data:
        print(f"No datasets with data available for grid plot of {metric}")
        return

    # Determine grid size based on number of datasets
    n_datasets = len(datasets_data)
    if n_datasets <= 2:
        n_rows, n_cols = 1, n_datasets
    else:
        n_rows = 2
        n_cols = (n_datasets + 1) // 2  # Ceiling division

    # squeeze=False always yields a 2D axes array, so no manual reshaping is
    # needed for the single-row or single-cell layouts.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 8, n_rows * 5.7),
                             squeeze=False)
    # try/finally guarantees the figure is released even if plotting or
    # saving raises.
    try:
        # Row-major order matches filling the grid left-to-right, top-down.
        flat_axes = axes.ravel()
        used_axes = flat_axes[:n_datasets]

        # Plot each dataset with its own scale
        for ax, (dataset_name, dataset_df) in zip(used_axes, datasets_data.items()):
            format_axis_for_metric(ax, metric, dataset_df[metric].values)

            # Special handling for KL divergence to better show details
            if metric == "kl_divergence":
                handle_kl_divergence_scaling(dataset_df, ax, dataset_name)

            # Plot each epsilon line
            for epsilon in EPSILONS:
                # Tolerant comparison: values parsed from CSV may differ from
                # the literal by a rounding error.
                epsilon_df = dataset_df[np.isclose(dataset_df['epsilon'], epsilon)]
                if epsilon_df.empty:
                    continue
                epsilon_df = epsilon_df.sort_values('alpha')
                ax.plot(
                    epsilon_df['alpha'],
                    epsilon_df[metric],
                    marker=EPSILON_MARKERS[epsilon],
                    linestyle='-',
                    color=EPSILON_COLORS[epsilon],
                    label=f"ε = {epsilon}",
                    markersize=8,
                )

            # Set labels and title
            ax.set_xlabel("Alpha (α)", fontweight='bold')
            ax.set_ylabel(METRIC_LABELS[metric], fontweight='bold')

            # Add dataset title to each subplot
            ax.set_title(f"{DATASET_NAMES.get(dataset_name, dataset_name)}",
                         fontweight='bold', fontsize=23)

            # Set x-axis ticks
            ax.set_xticks(ALPHAS)
            ax.set_xticklabels([str(alpha) for alpha in ALPHAS])

        # Hide any unused subplots
        for ax in flat_axes[n_datasets:]:
            ax.axis('off')

        # Build the shared legend from every subplot, so an epsilon missing
        # from the first dataset still gets an entry; keep EPSILONS order.
        legend_entries = {}
        for ax in used_axes:
            for handle, label in zip(*ax.get_legend_handles_labels()):
                legend_entries.setdefault(label, handle)
        labels = [label for label in (f"ε = {epsilon}" for epsilon in EPSILONS)
                  if label in legend_entries]
        handles = [legend_entries[label] for label in labels]
        fig.legend(handles, labels, loc='lower center', ncol=len(EPSILONS),
                   bbox_to_anchor=(0.5, 0), frameon=True, framealpha=0.9,
                   fontsize=21.5)

        # Add overall title
        fig.suptitle(f"{METRIC_LABELS[metric]} Comparison Across Datasets",
                     fontweight='bold', fontsize=32)

        # Reserve space below for the legend and above for the title.
        fig.tight_layout(rect=[0, 0.05, 1, 0.95])

        # Output directory carries model and epoch info; creating the svg
        # sub-directory also creates its parents.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(current_dir, "evaluation_results", "img",
                                  f"{model_name}_epoch{target_epoch}")
        svg_dir = os.path.join(output_dir, "svg")
        os.makedirs(svg_dir, exist_ok=True)

        # Save PNG version
        output_path = os.path.join(output_dir, f"all_datasets_{metric}_comparison.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')

        # Save SVG version
        svg_path = os.path.join(svg_dir, f"all_datasets_{metric}_comparison.svg")
        fig.savefig(svg_path, format='svg', bbox_inches='tight')

        print(f"Saved grid comparison plot to {output_path} and {svg_path}")
    finally:
        plt.close(fig)

def main(datasets, model_name, target_epoch, combined=True):
    """Generate every plot for the requested datasets.

    For each dataset that yields data, one figure per metric is produced and,
    when ``combined`` is true, a combined multi-metric figure as well. If
    ``combined`` is true and more than one dataset was loaded, a
    cross-dataset grid figure is also produced for every metric.

    Args:
        datasets: Iterable of dataset names to process.
        model_name: Model name used to filter the results.
        target_epoch: Epoch used to filter the results.
        combined: Whether to create the combined and grid figures.
    """
    print(f"Creating plots for datasets: {', '.join(datasets)}")
    print(f"Model: {model_name}, Epoch: {target_epoch}")

    # Successfully loaded frames, kept for the cross-dataset grid figures.
    loaded = {}

    for name in datasets:
        frame = load_dataset(name, model_name, target_epoch)
        if frame.empty:
            print(f"Skipping {name} due to empty dataset")
            continue

        loaded[name] = frame
        print(f"Processing {name}...")

        # One figure per metric
        for metric in METRICS:
            print(f"  Plotting {metric}...")
            plot_metric_by_dataset(name, frame, metric, model_name, target_epoch)

        # Combined figure only on request
        if not combined:
            continue
        print("  Creating combined view...")
        plot_combined_view(name, frame, model_name, target_epoch)

    # Grid figures need at least two datasets to compare.
    if combined and len(loaded) > 1:
        print("Creating grid comparison plots across datasets...")
        for metric in METRICS:
            print(f"  Grid comparison for {metric}...")
            plot_all_datasets_grid(loaded, metric, model_name, target_epoch)

    print("All plots have been generated!")

if __name__ == "__main__":
    # Command-line entry point: parse options, validate the dataset list and
    # hand over to main().
    parser = argparse.ArgumentParser(description="Generate performance plots for model evaluation across datasets")
    parser.add_argument("--datasets", nargs='+', default=["snli", "multinli", "summeval", "mtbench"],
                        help="List of datasets to plot (space-separated)")
    parser.add_argument("--no-combined", dest="combined", action="store_false",
                        help="Disable combined view plots")
    parser.add_argument("--model", default="qwen2.5-7b",
                        help="Model name to filter by (default: qwen2.5-7b)")
    parser.add_argument("--epoch", type=int, default=2,
                        help="Epoch to filter by (default: 2)")
    parser.add_argument("--output-dir", default="evaluation_results/img",
                        help="Base output directory for plots (default: evaluation_results/img)")
    parser.set_defaults(combined=True)

    args = parser.parse_args()

    # NOTE(review): --output-dir is accepted but never passed on; main() takes
    # no output location. Tell the user instead of silently ignoring a
    # custom value.
    if args.output_dir != parser.get_default("output_dir"):
        print(f"Warning: --output-dir ({args.output_dir}) is currently ignored; "
              "plots are written to evaluation_results/img next to this script.")

    # Validate dataset names (case-insensitive), reporting anything dropped.
    valid_datasets = ["snli", "multinli", "summeval", "mtbench"]
    requested = [d.lower() for d in args.datasets]
    unknown = [d for d in requested if d not in valid_datasets]
    if unknown:
        print(f"Warning: Ignoring unknown datasets: {', '.join(unknown)}")
    datasets_to_plot = [d for d in requested if d in valid_datasets]

    if not datasets_to_plot:
        print(f"No valid datasets specified. Please choose from: {', '.join(valid_datasets)}")
        # SystemExit rather than the site-provided exit(), which is not
        # guaranteed to exist in every interpreter configuration.
        raise SystemExit(1)

    main(datasets_to_plot, args.model, args.epoch, args.combined)