import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Read the per-run results table.
csv_file = 'run_results.csv'
results_df = pd.read_csv(csv_file)

# Fail early with a clear message instead of a bare KeyError further down.
required_columns = [
    'Dataset', 'Split', 'Method',
    'Test Accuracy (%)', 'Test F1 (%)',
    'Test Accuracy Std (%)', 'Test F1 Std (%)',
]
missing_columns = [col for col in required_columns if col not in results_df.columns]
if missing_columns:
    raise KeyError(f"{csv_file} is missing required columns: {missing_columns}")

# Normalise casing so differently-cased names collapse into one category.
results_df['Dataset'] = results_df['Dataset'].str.lower()
results_df['Split'] = results_df['Split'].str.lower()
results_df['Method'] = results_df['Method'].str.upper()

# One "dataset-split" label per CSV row, in row order. Vectorised concat
# replaces a Python-level iterrows loop; astype(str) keeps missing values
# rendered as 'nan', as the previous f-string did.
dataset_splits = (
    results_df['Dataset'].astype(str) + '-' + results_df['Split'].astype(str)
).tolist()

# Sorted unique categories drive the colour/marker assignments and legend.
datasets = sorted(results_df['Dataset'].unique().tolist())
splits = sorted(results_df['Split'].unique().tolist())
methods = sorted(results_df['Method'].unique().tolist())

# Per-row metrics, converted from percentages to fractions in [0, 1].
accuracy = results_df['Test Accuracy (%)'].to_numpy() / 100
f1 = results_df['Test F1 (%)'].to_numpy() / 100
accuracy_std = results_df['Test Accuracy Std (%)'].to_numpy() / 100
f1_std = results_df['Test F1 Std (%)'].to_numpy() / 100

# Fixed categorical palettes: matplotlib "tab" colours and marker codes.
colors = [
    'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
]
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h']

# Datasets are distinguished by colour and splits by marker shape. Palette
# positions wrap around when there are more categories than entries.
dataset_colors = dict(
    zip(datasets, (colors[pos % len(colors)] for pos in range(len(datasets))))
)
print(dataset_colors)
split_markers = dict(
    zip(splits, (markers[pos % len(markers)] for pos in range(len(splits))))
)

# Plotting configuration.
# The bare 'seaborn' style name was deprecated in Matplotlib 3.6 and removed in
# 3.8, where plt.style.use('seaborn') raises OSError. It survives there as
# 'seaborn-v0_8'. Use whichever name this Matplotlib installation provides.
for _style_name in ('seaborn-v0_8', 'seaborn'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'xtick.labelsize': 10,
    'ytick.labelsize': 12,
    'legend.fontsize': 10,
    'figure.figsize': (14, 10),
    'axes.grid': True,
    'grid.alpha': 0.5
})

# Two vertically stacked panels sharing the x-axis: accuracy on top (ax1),
# F1-score below (ax2).
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(14, 10))

# One x position per CSV row, in row order (same length as dataset_splits).
x = np.arange(len(dataset_splits))

# (axis, per-row mean, per-row std) for each metric panel. The metric
# arrays hold fractions in [0, 1].
metric_panels = [
    (ax1, accuracy, accuracy_std),
    (ax2, f1, f1_std),
]

for method in methods:
    # Positional boolean mask over CSV rows for this method. Converting to
    # NumPy keeps the selection positional, independent of the DataFrame index.
    method_mask = (results_df['Method'] == method).to_numpy()
    method_x = x[method_mask]
    method_datasets = results_df.loc[method_mask, 'Dataset'].tolist()
    method_splits = results_df.loc[method_mask, 'Split'].tolist()

    for ax, values, stds in metric_panels:
        method_values = values[method_mask]
        method_stds = stds[method_mask]
        for i, (ds, split) in enumerate(zip(method_datasets, method_splits)):
            # Mean ± std point: colour encodes the dataset, marker the split.
            ax.errorbar(method_x[i], method_values[i], yerr=method_stds[i],
                        color=dataset_colors[ds], marker=split_markers[split],
                        linestyle='none', capsize=5, markersize=8, markeredgecolor='black',
                        label=f'{method}' if i == 0 else None)

            # Value label just above the point. Metrics are fractions, so
            # scale to a percentage for display.
            ax.annotate(f'{method_values[i] * 100:.1f}%',
                        (method_x[i], method_values[i]),
                        textcoords="offset points",
                        xytext=(0, 10),
                        ha='center',
                        fontsize=8,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor='white', alpha=0.7))

# Shared legend: marker colour encodes the dataset, marker shape the split.
# Methods have no visual encoding of their own, so they are named in the x
# tick labels rather than by (indistinguishable) legend lines.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=dataset_colors[ds], label=ds,
           markersize=10, markeredgecolor='black')
    for ds in datasets
] + [
    Line2D([0], [0], marker=split_markers[sp], color='w', markerfacecolor='gray', label=sp,
           markersize=10, markeredgecolor='black')
    for sp in splits
]

# One tick label per CSV row, matching the plotted x positions. The label is
# "dataset-split" with the method on a second line, so rows that share a
# dataset/split but come from different methods stay distinguishable.
tick_labels = [
    f"{ds}-{sp}\n{m}"
    for ds, sp, m in zip(results_df['Dataset'], results_df['Split'], results_df['Method'])
]


def _metric_ylim(values, stds, default=(0.5, 1.0), pad=0.02):
    """Return (bottom, top) y-limits for a metric panel.

    *values* and *stds* are per-row fractions in [0, 1]. The range always
    covers *default* (50%-100%). It is widened by *pad* beyond any mean ± std
    outside that range, so low-scoring runs are not clipped off the plot.
    """
    spread_low = values - stds
    spread_high = values + stds
    if spread_low.size == 0 or np.all(np.isnan(spread_low)):
        return default
    bottom = min(default[0], float(np.nanmin(spread_low)) - pad)
    top = max(default[1], float(np.nanmax(spread_high)) + pad)
    return bottom, top


# Customize plots. Metrics are fractions, so limits are in [0, 1] units, not
# percentages.
ax1.set_ylabel('Accuracy')
ax1.set_title('Accuracy of Methods Across Datasets and Splits')
ax1.set_ylim(*_metric_ylim(accuracy, accuracy_std))
# Single legend for the whole figure, placed outside the top panel.
ax1.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left', title='Datasets and Splits')
ax1.grid(True, alpha=0.3)

ax2.set_xlabel('Dataset and Split')
ax2.set_ylabel('F1-Score')
ax2.set_title('F1-Score of Methods Across Datasets and Splits')
ax2.set_xticks(x)
ax2.set_xticklabels(tick_labels, rotation=45, ha='right')
ax2.set_ylim(*_metric_ylim(f1, f1_std))
ax2.grid(True, alpha=0.3)

# Adjust layout
plt.tight_layout()

# Save plot. bbox_inches='tight' keeps the outside legend in the image.
plt.savefig('methods_datasets_with_labels.png', dpi=300, bbox_inches='tight')
plt.savefig('methods_datasets_with_labels.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close(fig)

print("Visualization with labels saved as 'methods_datasets_with_labels.png' "
      "and 'methods_datasets_with_labels.pdf'")