import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize, ListedColormap, BoundaryNorm
import argparse
import seaborn as sns

# Global plot styling: seaborn "whitegrid" theme with the "husl" palette,
# then white figure/axes backgrounds and a uniform 20 pt size for general
# text, axis labels, titles and legend entries.
sns.set_style("whitegrid")
sns.set_palette("husl")
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'font.size': 20,
    'axes.labelsize': 20,
    'axes.titlesize': 20,
    'legend.fontsize': 20,
})

# Command-line interface. Every option is optional:
#   --pareto-indices  row indices to mark as the Pareto frontier instead of
#                     computing it from the data
#   --final-config    row index to mark as the final configuration
#   --no-interactive  do not attach the click/Enter selection handlers
parser = argparse.ArgumentParser(
    description='Ablation analysis with optional pre-specified configurations',
)
parser.add_argument(
    '--pareto-indices',
    type=int,
    nargs='*',
    help='Row indices to force as Pareto frontier (space-separated)',
)
parser.add_argument(
    '--final-config',
    type=int,
    help='Row index for final configuration',
)
parser.add_argument(
    '--no-interactive',
    action='store_true',
    help='Skip interactive mode',
)
args = parser.parse_args()

# Ablation results, written one configuration per row for readability and
# transposed into the column-oriented dict (column name -> list of values)
# that pd.DataFrame consumes. Row order defines the row indices used by
# --pareto-indices / --final-config.
_columns = ('TV weight', 'Jacobian weight', 'Dice', 'Folding, %')
_rows = [
    # TV weight, Jacobian weight, Dice, Folding %
    (0.25, 0, 0.793, 1.119),
    (0.3, 5, 0.791, 0.144),
    (0.33, 5, 0.791, 0.101),
    (0.4, 5, 0.791, 0.084),
    (0.3, 10, 0.790, 0.137),
    (0.3, 15, 0.790, 0.110),
    (0.33, 10, 0.790, 0.105),
    (0.4, 10, 0.790, 0.064),
    (0.5, 5, 0.790, 0.041),
    (0.25, 10, 0.790, 0.200),
    (0.5, 0, 0.789, 0.346),
    (0.5, 10, 0.788, 0.036),
    (0.5, 20, 0.788, 0.024),
    (0.25, 100, 0.783, 0.026),
    (0.25, 20, 0.790, 0.118),
    (0.25, 5, 0.791, 0.229),
    (0.75, 0, 0.782, 0.177),
    (0.4, 0, 0.791, 0.476),
]
data = {name: list(values) for name, values in zip(_columns, zip(*_rows))}
df = pd.DataFrame(data)

# Calculate Pareto frontier (for reference)
_USE_CLI_INDICES = object()  # sentinel: fall back to the --pareto-indices option


def calculate_pareto(df, forced_indices=_USE_CLI_INDICES):
    """Return the rows of *df* on the Dice / Folding Pareto frontier.

    A row is on the frontier when no other row has Dice >= its Dice and
    'Folding, %' <= its folding with at least one of the two strictly better
    (higher Dice and lower folding are both preferred). Rows with identical
    (Dice, Folding) values do not dominate each other, so duplicates are all
    kept. Rows with NaN in either column never dominate and are never
    dominated.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'Dice' and 'Folding, %'.
    forced_indices : sequence of int or None, optional
        Positional row indices to return instead of computing the frontier.
        Pass ``None`` to force automatic computation. When omitted, the
        module-level ``args.pareto_indices`` (the --pareto-indices option)
        is used, which is the original behaviour.

    Returns
    -------
    pandas.DataFrame
        The selected subset of *df*, original index preserved.

    Raises
    ------
    IndexError
        If a forced index is out of range for *df*.
    """
    if forced_indices is _USE_CLI_INDICES:
        forced_indices = args.pareto_indices

    if forced_indices is not None:
        # Use the specified positional indices verbatim.
        pareto_mask = np.zeros(len(df), dtype=bool)
        pareto_mask[list(forced_indices)] = True
        return df[pareto_mask]

    # Work on plain ndarrays so the mask stays an ndarray (mixing an ndarray
    # with a Series in-place, as `mask &= ~series`, silently rebinds it to a
    # Series via pandas' __array_ufunc__).
    dice = df['Dice'].to_numpy(dtype=float)
    folding = df['Folding, %'].to_numpy(dtype=float)

    # Pairwise matrices, entry [i, j] compares row i against row j.
    at_least_as_good = (dice[:, None] >= dice[None, :]) & (folding[:, None] <= folding[None, :])
    strictly_better = (dice[:, None] > dice[None, :]) | (folding[:, None] < folding[None, :])
    # Row j is dominated if any row i is at least as good on both axes and
    # strictly better on one. A row never dominates itself.
    dominated = (at_least_as_good & strictly_better).any(axis=0)
    return df[~dominated]

pareto_df = calculate_pareto(df)

# Create the plot
plt.figure(figsize=(12, 5.5))
ax = plt.gca()

# Discrete colouring by TV weight: one viridis colour per distinct value.
tv_unique = sorted(df['TV weight'].unique())
colors = sns.color_palette("viridis", len(tv_unique))
discrete_cmap = ListedColormap(colors)
# One colour bin per TV value, with bin edges at -0.5, 0.5, ..., K-0.5 so the
# integer codes 0..K-1 (and the colorbar ticks placed on them below) sit in
# the centre of each colour band. With the default linear norm the first and
# last ticks would land on the colorbar edges instead.
tv_norm = BoundaryNorm(np.arange(len(tv_unique) + 1) - 0.5, len(tv_unique))

# Map TV weights to discrete values for coloring
tv_discrete = df['TV weight'].map({val: i for i, val in enumerate(tv_unique)})


def _jacobian_marker_size(weight, max_weight):
    """Return scatter marker area(s) (points^2) for Jacobian weight(s).

    Area is 100 + 400 * sqrt(weight / max_weight): 100 for weight 0, up to
    500 for weight == max_weight. A non-positive or NaN *max_weight* yields
    the base area 100 instead of NaN/inf.
    """
    w = np.asarray(weight, dtype=float)
    if max_weight > 0:
        return 100 + 400 * np.sqrt(w / max_weight)
    return np.full_like(w, 100.0)


# Marker size encodes Jacobian weight. The same mapping and the same maximum
# are used for the points, the Pareto rings and the size legend so that all
# three agree (area = 100 + 400 * sqrt(w) / sqrt(max)).
jac_max = df['Jacobian weight'].max()
jacobian_sizes = _jacobian_marker_size(df['Jacobian weight'], jac_max)

# Plot all points; picker enables click selection in interactive mode.
scatter = ax.scatter(
    x=df['Folding, %'],
    y=df['Dice'],
    c=tv_discrete,
    s=jacobian_sizes,
    cmap=discrete_cmap,
    norm=tv_norm,
    alpha=0.85,
    edgecolors='white',
    linewidth=1.0,
    picker=True,
    pickradius=10
)

# Highlight Pareto frontier with red rings sized like the underlying points.
pareto_jacobian_sizes = _jacobian_marker_size(pareto_df['Jacobian weight'], jac_max)
ax.scatter(
    pareto_df['Folding, %'],
    pareto_df['Dice'],
    s=pareto_jacobian_sizes,
    facecolors='none',
    edgecolors='red',
    linewidths=2.5,
    label='Pareto Frontier',
    alpha=0.9
)

# Colorbar for TV weight: one tick per distinct value, labelled with the value.
cbar = plt.colorbar(scatter, ax=ax, ticks=range(len(tv_unique)))
cbar.set_ticklabels([f'{val}' for val in tv_unique])
cbar.set_label('TV Weight', fontsize=20)

# Labels and title with improved styling
ax.set_xlabel('Folding (%) → Lower is better', fontsize=20, fontweight='medium')
ax.set_ylabel('Dice → Higher is better', fontsize=20, fontweight='medium')
title = 'Configuration Trade-off Analysis'
if not args.no_interactive and args.final_config is None:
    title += '\n(Click to select, Enter to finalize)'
# NOTE: the title is currently not drawn; uncomment to show it.
# ax.set_title(title, fontsize=20, fontweight='bold', pad=20)

# Improve grid styling
ax.grid(True, alpha=0.3, linestyle='-', linewidth=0.5)
ax.set_axisbelow(True)

# Legend of reference Jacobian weights, sized with the same mapping as the data.
jac_vals = [0, 5, 20, 100]
sizes = [float(_jacobian_marker_size(v, jac_max)) for v in jac_vals]
legend_elements = [plt.scatter([], [], s=s, c='gray', alpha=0.7, edgecolors='white', linewidth=0.5, label=f'J={v}') for v, s in zip(jac_vals, sizes)]
legend = ax.legend(handles=legend_elements, title='Jacobian Weights', loc='lower right',
                  frameon=True, borderpad=1, handletextpad=1, fancybox=True, shadow=True)
legend.get_title().set_fontweight('bold')

# Interactive selection state shared by the pick/key callbacks below.
#   final_config     - row index marked with a gold star (None until chosen)
#   selected_artists - overlay artists drawn by update_selections(), removed
#                      and redrawn on every update
#   selected_points  - row indices currently highlighted
final_config = args.final_config
selected_artists = []
# A pre-specified --final-config also starts out as the only selected point.
selected_points = [] if args.final_config is None else [args.final_config]

def on_pick(event):
    """Toggle the clicked data point in ``selected_points`` and redraw.

    Only picks on the main ``scatter`` collection are handled; picks on any
    other artist are ignored. When several points are under the cursor, only
    the first index reported in ``event.ind`` is toggled.
    """
    if event.artist is not scatter:
        return

    # event.ind holds positional indices into the scatter's data, which was
    # plotted in df row order; df has a default RangeIndex, so these also
    # work as df.loc labels. Convert to a plain int for clean storage/printing.
    ind = int(event.ind[0])

    # Toggle selection (the list is mutated in place, so no `global` needed).
    if ind in selected_points:
        selected_points.remove(ind)
    else:
        selected_points.append(ind)

    update_selections()

def on_key(event):
    """Promote the most recent selection to the final configuration.

    On Enter with at least one point selected, ``final_config`` becomes the
    last entry of ``selected_points``; the overlays are redrawn and the
    chosen row is printed. Any other key, or Enter with nothing selected,
    is ignored.
    """
    global final_config
    if event.key != 'enter' or not selected_points:
        return

    final_config = selected_points[-1]
    update_selections()
    print("\nFinal configuration selected:")
    print(df.loc[final_config])
    print("\nClose plot to continue...")

def update_selections():
    """Redraw the selection overlays on ``ax`` from the current state.

    Removes every artist recorded in ``selected_artists`` by the previous
    call. It then draws a dodger-blue ring around each row listed in
    ``selected_points`` and a gold star at ``final_config`` (when set). The
    new artists are recorded so the next call can remove them.
    """
    # Drop the overlays drawn last time.
    for stale in selected_artists:
        stale.remove()
    del selected_artists[:]

    if selected_points:
        picked = df.loc[selected_points]
        rings = ax.scatter(
            picked['Folding, %'], picked['Dice'],
            s=200, facecolors='none', edgecolors='dodgerblue',
            linewidths=2, zorder=10,
        )
        selected_artists.append(rings)

        # Per-point text annotations are currently disabled; to re-enable,
        # uncomment the loop below.
        # for i, (_, row) in enumerate(picked.iterrows()):
        #     text = ax.annotate(
        #         f"TV:{row['TV weight']} J:{row['Jacobian weight']}",
        #         (row['Folding, %'], row['Dice']),
        #         xytext=(10, -20 - 20*i),
        #         textcoords='offset points',
        #         arrowprops=dict(arrowstyle="->", color='dodgerblue'),
        #         bbox=dict(boxstyle="round,pad=0.3", fc="w", alpha=0.8),
        #         zorder=11
        #     )
        #     selected_artists.append(text)

    if final_config is not None:
        chosen = df.loc[final_config]
        star = ax.scatter(
            [chosen['Folding, %']], [chosen['Dice']],
            s=400, marker='*', color='gold', edgecolor='k',
            zorder=12, label='Final Config',
        )
        selected_artists.append(star)

    plt.draw()

# Interactive mode: the user picks points by clicking and finalises with Enter.
# It is off when --no-interactive is given or a --final-config is pre-specified.
interactive = not args.no_interactive and args.final_config is None
fig = plt.gcf()
if interactive:
    fig.canvas.mpl_connect('pick_event', on_pick)
    fig.canvas.mpl_connect('key_press_event', on_key)

# Update initial selections if any (for both interactive and non-interactive modes)
if selected_points or final_config is not None:
    update_selections()

fig.tight_layout()
fig.savefig('abl.png')

if interactive:
    # Without show() the pick/key callbacks can never fire: the script would
    # save and exit immediately. show() blocks until the window is closed.
    plt.show()
    # Re-save so abl.png reflects the selections made in the window.
    fig.savefig('abl.png')

# Report the selection (after the window is closed in interactive mode).
if selected_points:
    print("\nSelected configurations:")
    print(df.loc[selected_points])
if final_config is not None:
    print(f"\nFinal configuration (index {final_config}):")
    print(df.loc[final_config])