import os
import pickle
import numpy as np
import pandas as pd
import jax
import time
import matplotlib.pyplot as plt
import scienceplots

from ot_jax.optimal_transport.jax_wasserstein import exact_wasserstein
from ot_jax.optimal_transport.jax_wasserstein import (bilevel_upper_bound, weighted_cost_upper_bound, bilevel_lower_bound, min_cost_lower_bound)
from ot_jax.optimal_transport.jax_wasserstein import (entropy_upper_bound, entropy_lower_bound)
from dot_benchmark_plot import fun_name_to_display, create_marker_color_maps, setup_plot_style, create_legend, METHOD_NAMES_MAPPING

from logging import getLogger, basicConfig, INFO
# Module-wide logging at INFO level, with one logger named after this module.
basicConfig(level=INFO)
logger = getLogger(__name__)

# Input volumes live in data/rotmol3d three directory levels above this file;
# the path is resolved relative to this source file, not the working directory.
DATA_FOLDER = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "../../../data/rotmol3d/")
)

# Bound estimators (imported above) grouped by family. The grouping drives both
# the benchmark grid and the split of the results table.
BILEVEL_METHOD_NAMES = [
    "bilevel_lower_bound",
    "bilevel_upper_bound",
    "weighted_cost_upper_bound",
    "min_cost_lower_bound",
]
ENTROPY_METHOD_NAMES = [
    "entropy_upper_bound",
    "entropy_lower_bound",
]

# Canonical method order for tables and plots: the key order of the shared
# display-name mapping.
METHOD_ORDER = list(METHOD_NAMES_MAPPING.keys())

# Base style first: SciencePlots' 'science' style (made available by the
# `import scienceplots` at the top of the file).
plt.style.use('science')

# Set consistent font sizes. These are applied *after* the base style so that
# the explicit values win for any key the style might also define.
plt.rcParams.update({
    'font.size': 10,
    'axes.titlesize': 12,
    'axes.labelsize': 10,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
})

def run_benchmarks(rotated_volume_series_npy_filename: str):
    """Benchmark exact OT and all bound estimators on a rotated-volume series.

    Every volume ``vols[i]`` is compared with the reference ``vols[0]`` using:
    the exact Wasserstein distance for p in {1, 2}; every bilevel bound for
    p in {1, 2} and kappa in {2, 4}; and every entropic bound for p in {1, 2}
    and epsilon in {1e-3 * N**p, 4e-3 * N**p}.

    Args:
        rotated_volume_series_npy_filename: Name of a ``.npy`` file inside
            ``DATA_FOLDER`` holding a stack of cubic volumes with shape
            ``(n_volumes, N, N, N)``.

    Returns:
        dict mapping ``(angle, method_name, p[, kappa_or_epsilon])`` to
        ``(ot_value, duration_seconds)``. The same dict is pickled to
        ``<name>_results.pickle`` in the current working directory after every
        single evaluation, so an interrupted run keeps its partial results.

    Raises:
        ValueError: If the loaded array is not 4-D with three equal trailing axes.
    """
    results_filename = rotated_volume_series_npy_filename.replace(".npy", "_results.pickle")

    rotated_volumes_filename = os.path.join(DATA_FOLDER, rotated_volume_series_npy_filename)
    logger.info(f'Loading {rotated_volumes_filename}...')
    vols = np.load(rotated_volumes_filename)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if vols.ndim != 4 or not (vols.shape[1] == vols.shape[2] == vols.shape[3]):
        raise ValueError(
            f"Expected an array of shape (n_volumes, N, N, N), got {vols.shape}"
        )
    N = vols.shape[1]

    logger.info(f'Loaded {vols.shape[0]} volumes of shape {N}x{N}x{N}')
    # Assumes volume i is rotated by i/n_volumes of a full turn (uniform steps
    # starting at 0). Angles are in radians.
    angles = np.linspace(0, 2*np.pi, vols.shape[0], endpoint=False).tolist()
    logger.info(f'Angles: {angles}')

    # Benchmark grid: (function name, p[, kappa | epsilon]).
    exact_benchmarks = [("exact_wasserstein", p) for p in [1, 2]]
    bilevel_benchmarks = [
        (name, p, kappa)
        for name in BILEVEL_METHOD_NAMES for p in [1, 2] for kappa in [2, 4]
    ]
    # The entropic regularisation epsilon is scaled by N**p.
    entropy_benchmarks = [
        (name, p, epsilon)
        for name in ENTROPY_METHOD_NAMES for p in [1, 2]
        for epsilon in [1e-3 * N**p, 4e-3 * N**p]
    ]
    benchmarks = exact_benchmarks + bilevel_benchmarks + entropy_benchmarks

    # Explicit name -> callable dispatch table (instead of a globals() lookup).
    functions = {
        "exact_wasserstein": exact_wasserstein,
        "bilevel_lower_bound": bilevel_lower_bound,
        "bilevel_upper_bound": bilevel_upper_bound,
        "weighted_cost_upper_bound": weighted_cost_upper_bound,
        "min_cost_lower_bound": min_cost_lower_bound,
        "entropy_upper_bound": entropy_upper_bound,
        "entropy_lower_bound": entropy_lower_bound,
    }

    results = {}
    for i, angle in enumerate(angles):
        logger.info(f'==== Volume {i}/{vols.shape[0]} (vs. Volume 0) ==============')
        for benchmark in benchmarks:
            func_name, *params = benchmark
            func = functions[func_name]
            arguments = [vols[0], vols[i], *params]
            logger.info(f'Running {func_name} with arguments: vols[0], vols[{i}], {benchmark[1:]}...')

            # NOTE(review): if these functions are jax.jit-compiled, the first
            # call for each configuration also includes compilation time in
            # `duration`. Consider a warm-up call if that matters.
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            result = float(jax.block_until_ready(func(*arguments)))
            duration = time.perf_counter() - start_time
            logger.info(f'Result: {result} (Duration: {duration} seconds)')

            results[(angle, *benchmark)] = (result, duration)
            # Checkpoint after every evaluation so a crash loses at most one result.
            with open(results_filename, "wb") as f:
                pickle.dump(results, f)
    return results

def process_results(results, N=32):
    """Turn the raw benchmark dictionary into a tidy comparison table.

    Args:
        results: dict as produced by ``run_benchmarks``, with values
            ``(ot_value, time)``. Keys are ``(angle, method, p)`` for the
            exact solver and ``(angle, method, p, param)`` for the bounds.
        N: Volume side length used when the benchmarks were run. Entropic
            epsilons were chosen as ``c * N**p``, and ``N`` is needed to
            recover ``c`` for the parameter label. ``N`` is not stored in
            ``results``, so it must match the data (default 32).

    Returns:
        DataFrame with one row per bound evaluation (bilevel and entropic).
        Each row is joined with the exact result at the same (angle, p) as
        ``ot_value_exact`` / ``time_exact``, and carries ``relative_time`` and
        ``relative_error`` columns. Rows are sorted by METHOD_ORDER. Angles are
        converted to integer degrees, and only angles <= 180 are kept.
    """
    # Split the raw results by method family.
    results_exact = {k: v for k, v in results.items() if k[1] == 'exact_wasserstein'}
    results_bilevel = {k: v for k, v in results.items() if k[1] in BILEVEL_METHOD_NAMES}
    results_entropy = {k: v for k, v in results.items() if k[1] in ENTROPY_METHOD_NAMES}

    # Exact reference values, indexed by (angle, p) so they can be joined below.
    df_exact = pd.DataFrame(
        list(results_exact.values()), columns=['ot_value', 'time'],
        index=pd.MultiIndex.from_tuples(list(results_exact.keys()), names=['angle', 'method', 'p']),
    )
    df_exact = df_exact.reset_index('method', drop=True)

    # Bilevel bounds: param is kappa, labelled "\kappa_<kappa>".
    df_bilevel = pd.DataFrame(
        list(results_bilevel.values()), columns=['ot_value', 'time'],
        index=pd.MultiIndex.from_tuples(list(results_bilevel.keys()), names=['angle', 'method', 'p', 'param']),
    )
    df_bilevel = df_bilevel.reset_index(['param', 'method'])
    df_bilevel['param'] = r"\kappa_" + df_bilevel['param'].astype(int).astype(str)

    # Entropic bounds: param is epsilon = c * N**p, labelled "\varepsilon_<1000*c>".
    # Round before converting to int: plain truncation would turn float error
    # such as 3.9999999 into the wrong label "3".
    df_entropy = pd.DataFrame(
        list(results_entropy.values()), columns=['ot_value', 'time'],
        index=pd.MultiIndex.from_tuples(list(results_entropy.keys()), names=['angle', 'method', 'p', 'param']),
    )
    df_entropy = df_entropy.reset_index(['param', 'method'])
    p_level = df_entropy.index.get_level_values('p')
    epsilon_code = (1000 * df_entropy['param'] / (N ** p_level)).round().astype(int)
    df_entropy['param'] = r"\varepsilon_" + epsilon_code.astype(str)

    df_bounds = pd.concat([df_bilevel, df_entropy])

    # Attach the exact value/time for the same (angle, p) to every bound row.
    df_joined = df_bounds.join(df_exact, rsuffix='_exact')
    df_joined['relative_time'] = df_joined['time'] / df_joined['time_exact']
    # Relative error w.r.t. the exact value. Where the exact value is 0
    # (presumably the angle-0 comparison of volume 0 with itself), fall back
    # to the bound's own value as the denominator.
    denominator = np.where(df_joined['ot_value_exact'] == 0, df_joined['ot_value'], df_joined['ot_value_exact'])
    df_joined['relative_error'] = abs(df_joined['ot_value'] - df_joined['ot_value_exact']) / denominator
    # NOTE(review): fillna(0) also reports 0/0 (and any row without an exact
    # match) as a 0 relative error/time. Confirm this is the intended reporting.
    df_joined = df_joined.fillna(0).reset_index()
    df_joined = df_joined.sort_values(
        by='method',
        key=lambda x: pd.Categorical(x, categories=METHOD_ORDER, ordered=True),
        axis=0,
    )
    # Radians -> whole degrees; keep only the first half-turn.
    df_joined["angle"] = np.round(df_joined["angle"] * 180 / np.pi, 0).astype(int)
    df_joined = df_joined[df_joined["angle"] <= 180]
    return df_joined

def aggregate_results(df_joined, field='relative_time', method_order=None):
    """Summarise one metric as mean/std per (method, p, param).

    The statistics are taken over all rows sharing (method, p, param). For
    the output of ``process_results``, that means over rotation angles.

    Args:
        df_joined: Tidy results with 'method', 'p', 'param' and ``field``
            columns. 'method' may be a plain or a categorical column.
        field: Name of the column to aggregate.
        method_order: Order used for the method column groups. Defaults to
            the module-level ``METHOD_ORDER``.

    Returns:
        DataFrame indexed by (p, statistic), where statistic is 'mean' or
        'std', with (method, param) MultiIndex columns sorted by method order.
    """
    if method_order is None:
        method_order = METHOD_ORDER
    # observed=True: only aggregate combinations present in the data, even if
    # 'method' is a Categorical with unused categories. The implicit default
    # (observed=False) is deprecated in pandas and builds the full cartesian
    # product of categories.
    grouped = (
        df_joined
        .groupby(['method', 'p', 'param'], observed=True)[field]
        .agg(['mean', 'std'])
    )
    df_aggregated = (
        grouped
        .stack(0)
        .unstack(['method', 'param'])
        .sort_index(axis=1, level=0,
                    key=lambda x: pd.Categorical(x, categories=method_order, ordered=True))
    )
    return df_aggregated

def generate_latex_table(df_aggregated, field='relative_time'):
    """Render an aggregated results frame as a LaTeX ``table`` environment.

    Every cell is multiplied by 100 and written as a percentage with two
    decimals. MultiIndex column headers become centred multicolumn groups,
    and repeated outer index values become multirow cells.

    Args:
        df_aggregated: Numeric DataFrame (e.g. from ``aggregate_results``).
            It may have a MultiIndex on either axis.
        field: Metric name, used to build the caption
            (``'relative_time'`` -> ``'Relative Time Results'``).

    Returns:
        The LaTeX source as a string, ending with ``\\end{table}``.
    """
    # Percent-format each cell. DataFrame.map superseded the deprecated
    # DataFrame.applymap in pandas 2.1; fall back on older pandas.
    elementwise = df_aggregated.map if hasattr(pd.DataFrame, "map") else df_aggregated.applymap
    formatted = elementwise(lambda x: f"{x * 100:.2f}\\%")

    # One left-aligned column per *index level* (aggregate_results produces a
    # two-level (p, statistic) index) plus one centred column per data column.
    # Too few specs make LaTeX fail with "Extra alignment tab".
    column_format = 'l' * df_aggregated.index.nlevels + 'c' * len(df_aggregated.columns)

    latex_str = formatted.to_latex(
        multicolumn=True,
        multicolumn_format='c',
        multirow=True,
        column_format=column_format,
        escape=False,
    )

    # Wrap the tabular in a floating table with a caption derived from `field`.
    latex_table = (
        "\\begin{table}[ht]\n"
        "\\centering\n"
        f"\\caption{{{field.replace('_', ' ').title()} Results}}\n"
        "\\small\n"
        f"{latex_str}"
        "\\end{table}"
    )

    return latex_table

def plot_ot_vs_angle(df, output_filename, exact=False):
    """Plot OT value vs. rotation angle, with one panel per value of ``p``.

    Each (method, param) combination is drawn as a scatter series, with the
    marker chosen by method and the colour by param. The y axis is requested
    log-scaled via ``setup_plot_style(..., ylog=True)``. The figure is saved
    to ``output_filename`` with ``.npy`` replaced by ``_ot_vs_angle.pdf``.

    Args:
        df: DataFrame with 'method', 'param', 'p', 'angle' and 'ot_value'
            columns, plus 'ot_value_exact' when ``exact`` is True (e.g. the
            output of ``process_results``). The caller's frame is not modified.
        output_filename: Base filename used to derive the output PDF name.
        exact: If True, also draw the exact Wasserstein distance as a black
            line. Angle 0 and non-positive values are skipped because they
            cannot be shown on a log axis.
    """
    # Work on a copy: converting 'method' to a Categorical below must not leak
    # into the caller's DataFrame (and from there into later groupbys).
    df = df.copy()
    p_values = sorted(df['p'].unique())
    n_cols = len(p_values)
    fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 3), squeeze=False)
    markers_map, colors_map = create_marker_color_maps(df, params_field="param")
    # Ordered categorical so that sorting and grouping follow METHOD_ORDER.
    df['method'] = pd.Categorical(df['method'], categories=METHOD_ORDER, ordered=True)

    for j, p in enumerate(p_values):
        ax = axes[0, j]
        subset = df[df['p'] == p].sort_values('method')

        # observed=True: iterate only (method, param) pairs present in the
        # data, never empty groups for unused categories.
        for (method, param), group in subset.groupby(['method', 'param'], observed=True):
            ax.scatter(
                group['angle'], group['ot_value'],
                label=f"{fun_name_to_display(method)} (${param}$)",
                marker=markers_map[method],
                color=colors_map[param],
                s=20,
                alpha=0.8,
            )

        if exact:
            # The exact distance depends only on (angle, p), so take one row
            # per angle from the whole panel. Skip angle 0 and non-positive
            # values, which the log-scaled axis cannot display.
            exact_curve = subset.drop_duplicates('angle').sort_values('angle')
            exact_curve = exact_curve[(exact_curve['angle'] > 0) & (exact_curve['ot_value_exact'] > 0)]
            ax.plot(
                exact_curve['angle'], exact_curve['ot_value_exact'],
                label="Exact Wasserstein",
                marker='o',
                color="black",
                linestyle='-',
                linewidth=2,
                markersize=3,
            )

        setup_plot_style(ax, xlog=False, ylog=True)
        ax.set_xlabel("Rotation Angle (deg)", fontsize=10)
        ax.set_xlim(-5, 185)
        ax.set_xticks(np.arange(0, 181, 20))
        if j == 0:  # y-label only on the left-most panel
            ax.set_ylabel('Wasserstein Distance Bounds', fontsize=10)

        # Panel title as a boxed label in the top-right corner of the axes.
        ax.text(0.98, 0.98, f'p = {p}',
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))

    create_legend(fig, axes[0, 0], title="Methods", y_pos=0.3, frame_alpha=0.95)
    fig.tight_layout()
    fig.subplots_adjust(right=0.85)  # leave room for the legend
    fig.savefig(output_filename.replace(".npy", "_ot_vs_angle.pdf"))
    plt.close(fig)

if __name__ == "__main__":
    filename = "rotated_EMDB2660_mask_radius=128_downscale_factor=8_n_angles=36.npy"
    # Derived outputs (pickle, CSV, PDF, .tex) are written relative to the
    # current working directory, matching run_benchmarks. Only the input .npy
    # is read from DATA_FOLDER.
    results_file = filename.replace(".npy", "_results.pickle")
    if not os.path.exists(results_file):
        # Prefer GPU 3 (the last GPU on the original 4-GPU machine). On
        # machines with fewer GPUs, fall back to the last available one
        # instead of raising IndexError.
        jax.config.update('jax_platform_name', 'gpu')
        gpus = jax.devices('gpu')
        jax.config.update('jax_default_device', gpus[min(3, len(gpus) - 1)])

        run_benchmarks(filename)
    else:
        logger.info(f"Results file {results_file} already exists, skipping...")

    # Locally produced pickle (written by run_benchmarks). Never point this at
    # untrusted files: unpickling can execute arbitrary code.
    with open(results_file, "rb") as f:
        results = pickle.load(f)

    # Process results
    df_joined = process_results(results)
    df_joined.to_csv(filename.replace(".npy", "_results.csv"), index=False)
    print(df_joined.head())

    # Plot OT value vs angle
    plot_ot_vs_angle(df_joined, filename)

    # Generate tables and print/save them
    df_time = aggregate_results(df_joined, field='relative_time')
    latex_time = generate_latex_table(df_time, field='relative_time')
    df_error = aggregate_results(df_joined, field='relative_error')
    latex_error = generate_latex_table(df_error, field='relative_error')

    print("\nRelative Time Results:")
    print(df_time)
    print("\nRelative Error Results:")
    print(df_error)

    # Save LaTeX tables, then report the actual file names written.
    time_table_file = filename.replace(".npy", "_time_table.tex")
    error_table_file = filename.replace(".npy", "_error_table.tex")
    with open(time_table_file, 'w', encoding="utf-8") as f:
        f.write(latex_time)
    with open(error_table_file, 'w', encoding="utf-8") as f:
        f.write(latex_error)

    print(f"\nLaTeX tables have been written to '{time_table_file}' and '{error_table_file}'")
