import math
import re
from collections import OrderedDict
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots

# Set consistent font sizes, one rc group at a time, then activate the 'science' style
plt.rc('font', size=10)
plt.rc('axes', titlesize=12, labelsize=10)
plt.rc('xtick', labelsize=9)
plt.rc('ytick', labelsize=9)
plt.rc('legend', fontsize=9)
plt.style.use('science')

# Maps the method identifiers found in the data's ``ot_fun`` column to the
# labels shown in legends.  The insertion order matters: the marker map built in
# create_marker_color_maps pairs these keys with markers positionally.
METHOD_NAMES_MAPPING = OrderedDict([
    ("weighted_cost_upper_bound", r"Weighted-Cost"),
    ("bilevel_upper_bound", r"Primal Upscaling"),
    ("entropy_upper_bound", r"Primal Entropic Reg."),
    ("bilevel_lower_bound", r"Dual Upscaling"),
    ("min_cost_lower_bound", r"Min-Cost"),
    ("entropy_lower_bound", r"Dual Entropic Reg."),
])
def fun_name_to_display(name: str) -> str:
    """Return the display label for ``name``, or ``name`` itself when unmapped."""
    try:
        return METHOD_NAMES_MAPPING[name]
    except KeyError:
        # Unknown identifiers are shown verbatim.
        return name

def create_marker_color_maps(df: pd.DataFrame, params_field: str = "params") -> Tuple[Dict, Dict]:
    """Build the marker-per-method and color-per-parameter lookup tables.

    Args:
        df: data whose ``params_field`` column supplies the parameter values
        params_field: name of the column holding the parameter values

    Returns:
        ``(markers_map, colors_map)``: the first maps each key of
        METHOD_NAMES_MAPPING to a marker symbol, the second maps each unique
        parameter value to a viridis color sampled evenly over [0, 1].
    """
    marker_symbols = ['x', 's', '^', 'v', '+', '<']
    # Methods and markers are paired by position; zip stops at the shorter one.
    markers_map = dict(zip(METHOD_NAMES_MAPPING.keys(), marker_symbols))

    distinct_params = df[params_field].unique()
    palette = plt.cm.viridis(np.linspace(0, 1, len(distinct_params)))
    colors_map = dict(zip(distinct_params, palette))
    return markers_map, colors_map

def setup_plot_style(ax: plt.Axes, xlog: bool = True, ylog: bool = True) -> None:
    """Apply the shared axis scaling, grid and tick-label sizes to ``ax``.

    Args:
        ax: axes to style
        xlog: switch the x axis to a log scale when true
        ylog: switch the y axis to a log scale when true
    """
    # Only the requested axes are switched; the others keep their current scale.
    for wants_log, set_scale in ((xlog, ax.set_xscale), (ylog, ax.set_yscale)):
        if wants_log:
            set_scale('log')

    ax.grid(True, which='both', linestyle='--', alpha=0.3)

    # Minor tick labels are one point smaller than major ones.
    for tick_kind, label_size in (('major', 9), ('minor', 8)):
        ax.tick_params(axis='both', which=tick_kind, labelsize=label_size)

def create_legend(fig: plt.Figure, ax: plt.Axes, ncol: int = 1, title: Optional[str] = None, y_pos: float = 0.5, frame_alpha: float = 0.8) -> None:
    """Attach a figure-level legend built from the artists of one axes.

    Args:
        fig: matplotlib figure that receives the legend
        ax: matplotlib axes whose handles and labels populate the legend
        ncol: number of columns in legend
        title: legend title; skipped when empty or None
        y_pos: vertical position of the legend anchor
        frame_alpha: transparency of legend frame (0 to 1)
    """
    entries, entry_labels = ax.get_legend_handles_labels()

    # Compact spacing, anchored to the right-hand edge of the figure.
    legend_options = {
        'loc': 'center right',
        'bbox_to_anchor': (1., y_pos),
        'ncol': ncol,
        'frameon': True,
        'fontsize': 9,
        'handletextpad': 0.1,
        'borderpad': 0.2,
        'labelspacing': 0.2,
        'columnspacing': 0.2,
        'framealpha': frame_alpha,
    }
    legend = fig.legend(entries, entry_labels, **legend_options)

    if not title:
        return
    legend.set_title(title, prop={'size': 10})

def save_plot(fig: plt.Figure, filename: str) -> None:
    """Save ``fig`` into the ``output`` directory and then display it.

    Args:
        fig: figure to write to disk
        filename: file name (including extension) inside ``output``
    """
    output_dir = Path("output")
    output_dir.mkdir(exist_ok=True)
    # Save the figure that was passed in, not whichever figure pyplot
    # currently considers active.
    fig.savefig(output_dir / filename, bbox_inches='tight', dpi=300)
    plt.show()

def plot_efficiency(df: pd.DataFrame, bound_type: str) -> None:
    """Plot relative error against relative time, one panel per value of ``p``.

    Each (params, ot_fun) group becomes one scatter series; the marker encodes
    the method and the color encodes the parameter value.  The figure is saved
    as ``efficiency_{bound_type}_bounds.pdf`` via save_plot.

    Args:
        df: data with columns ``p``, ``params``, ``ot_fun``, ``time_rel`` and
            ``error_rel``
        bound_type: used for the legend title and the output file name
    """
    # Work on a copy so the in-place sort does not touch the caller's frame
    df = df.copy()
    df.sort_values(by=["params", "ot_fun"], inplace=True)
    n_cols = len(df['p'].unique())
    # squeeze=False keeps ``axes`` two-dimensional even when there is a single
    # p value; otherwise a bare Axes is returned, which can be neither
    # iterated nor indexed.
    fig, axes = plt.subplots(ncols=n_cols, figsize=(5*n_cols, 4), sharey=True, squeeze=False)
    axes_row = axes[0]
    markers_map, colors_map = create_marker_color_maps(df)

    for col_idx, (ax, (p, group)) in enumerate(zip(axes_row, df.groupby('p'))):
        for (param, ot_fun), grp in group.groupby(['params', 'ot_fun']):
            marker = markers_map[ot_fun]
            color = colors_map[param]
            ot_fun_print = fun_name_to_display(ot_fun)
            # Separate name so the group key is not overwritten mid-loop
            param_display = f"${param}$"
            ax.scatter(grp["time_rel"], grp["error_rel"],
                      label=f"{ot_fun_print} ({param_display})",
                      marker=marker,
                      color=color,
                      s=20,
                      alpha=0.8)

        setup_plot_style(ax, xlog=True, ylog=True)
        ax.set_xlabel('Relative Time', fontsize=10)
        if col_idx == 0:  # Only set y-label for left-most subplot
            ax.set_ylabel('Relative Error', fontsize=10)

        # Add title text inside the plot
        ax.text(0.98, 0.98, f'p = {p}',
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))

    create_legend(fig, axes_row[0], title=f"{bound_type.title()} Bounds", y_pos=0.3)  # Lower position for efficiency plots
    plt.tight_layout()
    plt.subplots_adjust(right=0.85)  # Make room for the legend
    save_plot(fig, f"efficiency_{bound_type}_bounds.pdf")

def plot_accuracy(df: pd.DataFrame, bound_type: str) -> None:
    """Plot relative error against the exact OT value, one panel per ``p``.

    Only rows whose ``bound`` column equals ``bound_type`` are drawn.  Each
    (ot_fun, params) group becomes one scatter series; the marker encodes the
    method and the color encodes the parameter value.  The figure is saved as
    ``accuracy_{bound_type}_bounds.pdf`` via save_plot.

    Args:
        df: data with columns ``p``, ``bound``, ``ot_fun``, ``params``,
            ``ot_exact`` and ``error_rel``
        bound_type: value of ``bound`` to plot; also used for the legend title
            and the output file name
    """
    # Work on a private copy of the caller's frame
    data = df.copy()
    p_values = sorted(data['p'].unique())
    n_panels = len(p_values)
    fig, axes = plt.subplots(1, n_panels,
                            figsize=(5*n_panels, 4),
                            squeeze=False,
                            sharey=True)
    panels = axes[0]

    markers_map, colors_map = create_marker_color_maps(data)

    for panel_idx, (ax, p) in enumerate(zip(panels, p_values)):
        in_panel = (data['bound'] == bound_type) & (data['p'] == p)
        panel_data = data[in_panel]

        for (method, param_value), points in panel_data.groupby(['ot_fun', 'params']):
            series_marker = markers_map[method]
            series_color = colors_map[param_value]
            method_label = fun_name_to_display(method)
            ax.scatter(points['ot_exact'], points['error_rel'],
                      label=f"{method_label} (${param_value}$)",
                      marker=series_marker,
                      color=series_color,
                      s=20,
                      alpha=0.8)

        setup_plot_style(ax, xlog=False, ylog=True)
        ax.set_xlabel(r'$\mathcal{W}_p$', fontsize=10)
        # The y axis is shared, so only the first panel carries its label
        if panel_idx == 0:
            ax.set_ylabel('Relative Error', fontsize=10)

        # In-plot annotation instead of a regular axes title
        annotation_box = {'facecolor': 'white', 'alpha': 0.8, 'edgecolor': 'none', 'pad': 1}
        ax.text(0.98, 0.98, f'p = {p}',
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=annotation_box)

    # Legend sits higher than in the efficiency plots
    create_legend(fig, panels[0], title=f"{bound_type.title()} Bounds", y_pos=0.7)
    plt.tight_layout()
    plt.subplots_adjust(right=0.85)  # Make room for the legend
    save_plot(fig, f'accuracy_{bound_type}_bounds.pdf')

def format_scientific(x: float) -> str:
    """Format a number in scientific notation without sign/zero padding in the exponent.

    One digit is kept after the decimal point, e.g. ``1234.5 -> "1.2e3"`` and
    ``0.00012 -> "1.2e-4"``.  Non-finite values have no exponent part and are
    returned as Python formats them (``"nan"``, ``"inf"``, ``"-inf"``).

    Args:
        x: number to format

    Returns:
        The formatted string.
    """
    s = f"{x:.1e}"
    # Split into mantissa and exponent; nan/inf contain no 'e' separator
    mantissa, separator, exp = s.partition('e')
    if not separator:
        return s
    # int() drops the '+' sign and the zero padding of the exponent
    return f"{mantissa}e{int(exp)}"

def _create_stats_table(df: pd.DataFrame, value_col: str, debug_header: str, table_label: str) -> pd.DataFrame:
    """Build a formatted mean/std summary table for one metric column.

    The data is grouped by ``class``, ``p``, ``ot_fun`` and ``params``; the mean
    and standard deviation of ``value_col`` are multiplied by 100 and written
    as LaTeX-ready percentage strings.

    Args:
        df: DataFrame containing the benchmark results
        value_col: name of the column to aggregate
        debug_header: first line of the debug output
        table_label: text inserted before "table" in the progress and error
            messages (empty, or e.g. ``"time "``)

    Returns:
        DataFrame with (method, param) multi-index columns, ordered by
        ``method_order``, and (class, p, stat) multi-index rows.  Methods that
        are not listed in ``method_order`` are left out.  An empty DataFrame is
        returned when grouping or table construction fails; the error is
        printed.
    """
    # Print debug information
    print(f"\n{debug_header}")
    print("Columns in DataFrame:", df.columns.tolist())
    print("Sample of ot_fun values:", df['ot_fun'].unique())
    print("Sample of params values:", df['params'].unique())

    # Define the desired order of methods
    method_order = [
        'weighted_cost_upper_bound',
        'bilevel_upper_bound',
        'entropy_upper_bound',
        'bilevel_lower_bound',
        'min_cost_lower_bound',
        'entropy_lower_bound'
    ]

    # Group by class, p, ot_fun, and params to calculate statistics
    try:
        grouped = df.groupby(['class', 'p', 'ot_fun', 'params'])[value_col].agg(['mean', 'std'])
        print("\nGrouped data shape:", grouped.shape)
        print("Grouped data sample:\n", grouped.head())
    except Exception as e:
        print(f"Error in grouping: {str(e)}")
        return pd.DataFrame()

    # Flatten the group keys into ordinary columns
    grouped = grouped.reset_index()

    try:
        # Rows: every (class, p) combination, each with a mean and a std line
        row_index = pd.MultiIndex.from_product(
            [grouped['class'].unique(), grouped['p'].unique(), ['mean', 'std']],
            names=['class', 'p', 'stat']
        )

        # Only methods with a defined position appear in the table
        known = grouped[grouped['ot_fun'].isin(method_order)]
        # Wrap params in LaTeX math environment once, so column keys and cell
        # writes are guaranteed to use identical strings
        known = known.assign(param_math=[f"${value}$" for value in known['params']])

        # Columns are created directly with their final keys and order:
        # methods follow method_order, params keep their order of appearance
        column_keys = [
            (method, param_math)
            for method in method_order
            for param_math in known.loc[known['ot_fun'] == method, 'param_math'].unique()
        ]
        table = pd.DataFrame(
            index=row_index,
            columns=pd.MultiIndex.from_tuples(column_keys, names=['method', 'param']),
            dtype=object,
        )

        # Fill the table with the formatted values
        cells = zip(known['class'], known['p'], known['ot_fun'],
                    known['param_math'], known['mean'], known['std'])
        for cls, p, method, param_math, mean, std in cells:
            # Add mean value as percentage
            table.loc[(cls, p, 'mean'), (method, param_math)] = f"{mean * 100:.2f}\\%"
            # Add std value with \pm as percentage
            table.loc[(cls, p, 'std'), (method, param_math)] = f"$\\pm$ {std * 100:.2f}\\%"

        print(f"\nFinal {table_label}table shape:", table.shape)
        print(f"Final {table_label}table sample:\n", table.head())
    except Exception as e:
        print(f"Error in creating final {table_label}table: {str(e)}")
        return pd.DataFrame()

    return table


def create_accuracy_table(df: pd.DataFrame) -> pd.DataFrame:
    """Create a table of average and standard deviation of relative errors.

    Args:
        df: DataFrame containing the benchmark results; the ``error_rel``
            column is aggregated

    Returns:
        DataFrame with multi-index columns for methods and parameters,
        and multi-index rows for classes, p values, and statistics (mean/std)
    """
    return _create_stats_table(df, 'error_rel', "Debug Info:", "")


def create_time_table(df: pd.DataFrame) -> pd.DataFrame:
    """Create a table of average and standard deviation of relative times.

    Args:
        df: DataFrame containing the benchmark results; the ``time_rel``
            column is aggregated

    Returns:
        DataFrame with multi-index columns for methods and parameters,
        and multi-index rows for classes, p values, and statistics (mean/std)
    """
    return _create_stats_table(df, 'time_rel', "Debug Info for Time Table:", "time ")

def table_to_latex(
    table: pd.DataFrame,
    caption: str = 'Accuracy comparison of different methods',
    label: str = 'tab:accuracy',
) -> str:
    """Convert a summary table to a standalone LaTeX document.

    Args:
        table: DataFrame with multi-index columns and rows
        caption: caption of the generated table
        label: LaTeX label of the generated table

    Returns:
        LaTeX formatted string: a document preamble, the table, and the
        closing ``\\end{document}``
    """
    def _format_float(value) -> str:
        # Formatter for floating point cells; non-finite values are passed
        # through as plain text.
        number = float(value)
        if not math.isfinite(number):
            return str(number)
        return format_scientific(number)

    # Convert to LaTeX with specific formatting
    latex_str = table.to_latex(
        float_format=_format_float,
        multirow=True,
        multicolumn=True,
        multicolumn_format='c',
        # One left-aligned column per row-index level, then the data columns
        column_format='l' * table.index.nlevels + 'c' * len(table.columns),
        escape=False,
        caption=caption,
        label=label
    )

    # Post-process the LaTeX to add \scriptsize for std rows.  Each cell that
    # starts with $\pm$ is wrapped on its own, so empty cells in the same row
    # cannot unbalance the braces.  A cell ends before the next '&' or before
    # the closing '\\' of the row.
    std_cell = re.compile(r'\$\\pm\$.*?(?=\s*(?:&|\\\\\s*$))')
    processed_lines = []
    for line in latex_str.split('\n'):
        if 'std' in line and '\\pm' in line:
            line = std_cell.sub(lambda match: r'{\scriptsize ' + match.group(0) + '}', line)
        processed_lines.append(line)
    latex_str = '\n'.join(processed_lines)

    # Add some LaTeX packages and formatting
    header = r"""\documentclass{article}
\usepackage{booktabs}
\usepackage{multirow}
\usepackage{siunitx}
\usepackage{graphicx}
\usepackage{adjustbox}

\begin{document}
"""

    footer = r"""
\end{document}
"""

    return header + latex_str + footer

if __name__ == "__main__":
    plt.style.use('science')

    # Load the benchmark results and move the identifying columns to the front
    df_full = pd.read_csv("dot_benchmark_full_LATEST.csv")
    key_columns = ["class", "resolution", "p", "i", "j", 'ot_fun_full', 'bound']
    df_full = df_full.set_index(key_columns).reset_index()

    # ot_fun_full looks like "name(params)": split it into its two parts
    name_parts = df_full["ot_fun_full"].str.split('(')
    df_full["ot_fun"] = name_parts.str[0]
    df_full["params"] = name_parts.str[1].str.split(')').str[0]
    df_full["params_latex"] = "$" + df_full["params"] + "$"

    # Filter out rows with epsilon_factor in [0.0005, 0.002]
    excluded = df_full['epsilon_factor'].isin([0.0005, 0.002])
    df_full = df_full[~excluded]

    # Print unique values for debugging
    print("Unique ot_fun values:", sorted(df_full['ot_fun'].unique()))
    print("Unique dot_class values:", sorted(df_full['class'].unique()))

    # Create output directory
    output_dir = Path("output")
    output_dir.mkdir(exist_ok=True)

    # Build each summary table, then save it as CSV and as LaTeX
    table_jobs = (
        ("accuracy", "Accuracy", create_accuracy_table),
        ("time", "Time", create_time_table),
    )
    for stem, heading, build_table in table_jobs:
        summary = build_table(df_full)
        summary.to_csv(output_dir / f"{stem}_table.csv")
        (output_dir / f"{stem}_table.tex").write_text(table_to_latex(summary))

        print(f"\n{heading} Table:")
        print(summary)

    # One independent copy of the data per bound type
    bound_frames = {
        bound: df_full[df_full["bound"] == bound].copy()
        for bound in ("upper", "lower")
    }

    # Efficiency plots first (upper, lower), then the accuracy plots
    for draw in (plot_efficiency, plot_accuracy):
        for bound, frame in bound_frames.items():
            draw(frame, bound)