import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from matplotlib.patches import Rectangle
import os

# Load the per-run results table and normalise the categorical key columns:
# dataset and split names are lower-cased and method names upper-cased, so
# values that differ only in letter case collapse into one category for the
# grouping, sorting and palette lookups further down.
csv_file = 'run_results.csv'
results_df = pd.read_csv(csv_file).assign(
    Dataset=lambda frame: frame['Dataset'].str.lower(),
    Split=lambda frame: frame['Split'].str.lower(),
    Method=lambda frame: frame['Method'].str.upper(),
)

# One "dataset-split" label per CSV row (duplicates kept, row order preserved),
# used as x-axis tick labels, plus the sorted unique values of each key column.
dataset_splits = [
    f"{ds_name}-{split_name}"
    for ds_name, split_name in zip(results_df['Dataset'], results_df['Split'])
]
datasets, splits, methods = (
    sorted(results_df[key_col].unique().tolist())
    for key_col in ('Dataset', 'Split', 'Method')
)

# Raw metric arrays, positionally aligned with the rows of results_df.
accuracy, f1, accuracy_std, f1_std = (
    results_df[metric_col].values
    for metric_col in (
        'Test Accuracy (%)',
        'Test F1 (%)',
        'Test Accuracy Std (%)',
        'Test F1 Std (%)',
    )
)

# Shared visual encodings: one colour per method and per dataset, one marker
# shape per split, handed out in sorted-key order and wrapping around when a
# column has more categories than the palette has entries.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h']


def _cycle_palette(keys, palette):
    """Map each key to a palette entry in order, wrapping around the palette."""
    return {key: palette[pos % len(palette)] for pos, key in enumerate(keys)}


method_colors = _cycle_palette(methods, colors)
dataset_colors = _cycle_palette(datasets, colors)
split_markers = _cycle_palette(splits, markers)

# Create vis directory if it doesn't exist
os.makedirs('vis', exist_ok=True)

# Set style.  Matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 dropped the old names, so calling
# plt.style.use('seaborn-whitegrid') unconditionally raises OSError on current
# releases.  Use whichever of the two names this installation provides; if
# neither exists, skip the style (the rcParams below still enable the grid).
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

plt.rcParams.update({
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 9,
    'ytick.labelsize': 10,
    'legend.fontsize': 9,
    'figure.figsize': (12, 8),
    'axes.grid': True,
    'grid.alpha': 0.3
})

# 1. LINE PLOTS - Accuracy and F1 Score
# Two side-by-side panels.  On each, every method is drawn as a line through
# its own rows (x = the row's position in results_df), with a shaded band of
# +/- the per-row standard deviation.
fig, (ax_acc_line, ax_f1_line) = plt.subplots(1, 2, figsize=(16, 6))

x_pos = np.arange(len(dataset_splits))

# Panel spec: (axes, metric values, metric std, marker, title, y label)
for panel_ax, centre, spread, marker_style, panel_title, y_label in (
    (ax_acc_line, accuracy, accuracy_std, 'o',
     'Accuracy Trends Across Dataset-Split Combinations', 'Accuracy (%)'),
    (ax_f1_line, f1, f1_std, 's',
     'F1-Score Trends Across Dataset-Split Combinations', 'F1-Score (%)'),
):
    for method in methods:
        method_mask = results_df['Method'] == method
        method_x = x_pos[method_mask]
        line_vals = centre[method_mask]
        band = spread[method_mask]

        panel_ax.plot(method_x, line_vals,
                      marker=marker_style, linewidth=2, markersize=6,
                      color=method_colors[method], label=method,
                      alpha=0.8)
        panel_ax.fill_between(method_x, line_vals - band, line_vals + band,
                              alpha=0.2, color=method_colors[method])

    panel_ax.set_title(panel_title, fontweight='bold')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_xticks(x_pos)
    panel_ax.set_xticklabels(dataset_splits, rotation=45, ha='right')
    panel_ax.legend(loc='best')
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('vis/01_line_plots_trends.png', dpi=300, bbox_inches='tight')
plt.savefig('vis/01_line_plots_trends.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close()

# 2. ENHANCED SCATTER PLOTS with Labels
# Two stacked panels (accuracy on top, F1 below).  Every CSV row becomes one
# point at x = its row position, coloured by dataset, shaped by split, with a
# +/- std error bar and a "METHOD\nvalue%" label drawn above it.
fig, (ax_acc_scatter, ax_f1_scatter) = plt.subplots(2, 1, figsize=(8, 8))

# Panel spec: (axes, value column, std column, title, y label, preferred y floor)
_scatter_panels = (
    (ax_acc_scatter, 'Test Accuracy (%)', 'Test Accuracy Std (%)',
     ' Accuracy Results ', 'Accuracy (%)', 60),
    (ax_f1_scatter, 'Test F1 (%)', 'Test F1 Std (%)',
     'F1-Score Results ', 'F1-Score (%)', 55),
)

for ax_scatter, value_col, std_col, panel_title, y_label, y_floor in _scatter_panels:
    # enumerate() yields a positional x even if results_df has a non-default index.
    for x_idx, (_, row) in enumerate(results_df.iterrows()):
        dataset = row['Dataset']
        split = row['Split']
        method = row['Method']
        value = row[value_col]

        ax_scatter.errorbar(x_idx, value, yerr=row[std_col],
                            color=dataset_colors[dataset],
                            marker=split_markers[split],
                            linestyle='none', capsize=4, markersize=10,
                            markeredgecolor='black', markeredgewidth=1,
                            alpha=0.8)

        # Add value labels
        ax_scatter.annotate(f'{method}\n{value:.1f}%',
                            (x_idx, value),
                            textcoords="offset points",
                            xytext=(0, 15),
                            ha='center', va='bottom',
                            fontsize=8,
                            bbox=dict(boxstyle="round,pad=0.3",
                                      facecolor='white', alpha=0.8,
                                      edgecolor=dataset_colors[dataset]))

    ax_scatter.set_title(panel_title, fontweight='bold')
    ax_scatter.set_ylabel(y_label)
    ax_scatter.grid(True, alpha=0.3)

    # Keep the preferred floor (60 % / 55 %) unless some value or the bottom of
    # its error bar falls below it; a fixed floor would silently clip such runs
    # out of view, so lower it to fit the data in that case.
    lowest = (results_df[value_col] - results_df[std_col].fillna(0)).min()
    if pd.notna(lowest):
        y_floor = min(y_floor, lowest - 2)
    ax_scatter.set_ylim(y_floor, 100)

# Add legend for scatter plots: colour encodes the dataset, marker the split.
legend_elements = []
legend_elements.extend([
    Line2D([0], [0], marker='o', color='w', markerfacecolor=dataset_colors[ds],
           label=f'Dataset: {ds}', markersize=8, markeredgecolor='black')
    for ds in datasets
])
legend_elements.extend([
    Line2D([0], [0], marker=split_markers[sp], color='w', markerfacecolor='gray',
           label=f'Split: {sp}', markersize=8, markeredgecolor='black')
    for sp in splits
])

fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 0.02),
           ncol=len(datasets) + len(splits), title='Legend')

plt.tight_layout()
# Reserve room at the bottom for the figure-level legend.
plt.subplots_adjust(bottom=0.15)
plt.savefig('vis/02_scatter_plots_detailed.png', dpi=300, bbox_inches='tight')
plt.savefig('vis/02_scatter_plots_detailed.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close()

# 3. SUMMARY STATISTICS BAR PLOT
# Grouped bars per method: mean accuracy (solid) beside mean F1 (hatched),
# each with the across-row standard deviation as its error bar and the mean
# value printed just above the bar.
fig, ax_summary = plt.subplots(1, 1, figsize=(10, 6))

method_stats = results_df.groupby('Method').agg({
    'Test Accuracy (%)': ['mean', 'std', 'max', 'min'],
    'Test F1 (%)': ['mean', 'std', 'max', 'min']
}).round(2)

methods_list = list(method_stats.index)
acc_block = method_stats['Test Accuracy (%)']
f1_block = method_stats['Test F1 (%)']
acc_means, acc_stds = acc_block['mean'].values, acc_block['std'].values
f1_means, f1_stds = f1_block['mean'].values, f1_block['std'].values
bar_palette = [method_colors[m] for m in methods_list]

x_methods = np.arange(len(methods_list))
width = 0.35

bars1 = ax_summary.bar(x_methods - width / 2, acc_means, width,
                       yerr=acc_stds, capsize=5,
                       label='Accuracy', alpha=0.8,
                       color=bar_palette)
bars2 = ax_summary.bar(x_methods + width / 2, f1_means, width,
                       yerr=f1_stds, capsize=5,
                       label='F1-Score', alpha=0.8,
                       color=bar_palette,
                       hatch='///')

# Value labels: all accuracy bars first, then all F1 bars.
for bar in (*bars1, *bars2):
    bar_top = bar.get_height()
    ax_summary.annotate(f'{bar_top:.1f}%',
                        xy=(bar.get_x() + bar.get_width() / 2, bar_top),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

ax_summary.set_title('Average Performance by Method (with Standard Deviation)', fontweight='bold')
ax_summary.set_ylabel('Performance (%)')
ax_summary.set_xlabel('Methods')
ax_summary.set_xticks(x_methods)
ax_summary.set_xticklabels(methods_list)
ax_summary.legend()
ax_summary.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('vis/03_summary_statistics.png', dpi=300, bbox_inches='tight')
plt.savefig('vis/03_summary_statistics.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close()

# 4. PERFORMANCE COMPARISON (Accuracy vs F1)
# One point per CSV row (accuracy on x, F1 on y), coloured by dataset, shaped
# by split and labelled with the method, plus the dashed identity line y = x.
# Points on that line have equal accuracy and F1; points below it have an F1
# lower than their accuracy.
fig, ax_comparison = plt.subplots(1, 1, figsize=(10, 8))

for _, row in results_df.iterrows():
    acc_val = row['Test Accuracy (%)']
    f1_val = row['Test F1 (%)']
    ax_comparison.scatter(acc_val, f1_val,
                          color=dataset_colors[row['Dataset']],
                          marker=split_markers[row['Split']],
                          s=120, alpha=0.7, edgecolor='black', linewidth=1)
    ax_comparison.annotate(row['Method'],
                           (acc_val, f1_val),
                           xytext=(5, 5), textcoords='offset points',
                           fontsize=9, alpha=0.8, fontweight='bold')

ax_comparison.set_title('Accuracy vs F1-Score Correlation Analysis', fontweight='bold')
ax_comparison.set_xlabel('Accuracy (%)')
ax_comparison.set_ylabel('F1-Score (%)')

# Span the identity line over the joint range of both metrics.  NaN-aware
# reductions are used because the builtin min()/max() over a Series that
# contains NaN return a result that depends on where the NaN sits.
both_metrics = results_df[['Test Accuracy (%)', 'Test F1 (%)']].to_numpy(dtype=float)
min_val = np.nanmin(both_metrics)
max_val = np.nanmax(both_metrics)
# y = x marks equality of the two metrics; it is not a "perfect correlation"
# line (any positive linear relation has correlation 1).
identity_label = 'y = x (Accuracy = F1)'
ax_comparison.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5,
                   label=identity_label, linewidth=2)

# Add legend: colour = dataset, marker = split, dashed line = identity.
legend_elements = []
legend_elements.extend([
    Line2D([0], [0], marker='o', color='w', markerfacecolor=dataset_colors[ds],
           label=f'Dataset: {ds}', markersize=8, markeredgecolor='black')
    for ds in datasets
])
legend_elements.extend([
    Line2D([0], [0], marker=split_markers[sp], color='w', markerfacecolor='gray',
           label=f'Split: {sp}', markersize=8, markeredgecolor='black')
    for sp in splits
])
legend_elements.append(Line2D([0], [0], color='k', linestyle='--',
                              label=identity_label))

ax_comparison.legend(handles=legend_elements, loc='best')
ax_comparison.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('vis/04_correlation_analysis.png', dpi=300, bbox_inches='tight')
plt.savefig('vis/04_correlation_analysis.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close()

# 5. PERFORMANCE HEATMAP
# Method x dataset grid of mean test accuracy.  Each cell averages every row
# sharing that method/dataset pair (e.g. across splits).  The red-yellow-green
# colour map is scaled to the observed min/max; cells with no data stay
# unlabelled.
fig, ax_heatmap = plt.subplots(1, 1, figsize=(12, 8))

pivot_acc = results_df.pivot_table(index='Method',
                                   columns='Dataset',
                                   values='Test Accuracy (%)',
                                   aggfunc='mean')
cell_values = pivot_acc.values

im = ax_heatmap.imshow(cell_values, cmap='RdYlGn', aspect='auto',
                       vmin=pivot_acc.min().min(), vmax=pivot_acc.max().max())

# Print each cell's value at its centre (row index = y, column index = x).
for (row_idx, col_idx), cell in np.ndenumerate(cell_values):
    if pd.isna(cell):
        continue
    ax_heatmap.text(col_idx, row_idx, f'{cell:.1f}%',
                    ha="center", va="center", color="black",
                    fontweight='bold', fontsize=10)

ax_heatmap.set_title('Performance Heatmap: Accuracy (%) by Method and Dataset',
                     fontweight='bold', pad=20)
ax_heatmap.set_xticks(range(len(pivot_acc.columns)))
ax_heatmap.set_xticklabels(pivot_acc.columns, rotation=45, ha='right')
ax_heatmap.set_yticks(range(len(pivot_acc.index)))
ax_heatmap.set_yticklabels(pivot_acc.index)

# Colour scale legend; the label is rotated to read bottom-to-top.
cbar = plt.colorbar(im, ax=ax_heatmap, shrink=0.8)
cbar.set_label('Accuracy (%)', rotation=270, labelpad=20)

plt.tight_layout()
plt.savefig('vis/05_performance_heatmap.png', dpi=300, bbox_inches='tight')
plt.savefig('vis/05_performance_heatmap.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close()

# 6. COMBINED PERFORMANCE OVERVIEW
# 2x2 grid: the top row shows per-method box plots of accuracy and F1 across
# all rows; the bottom row shows per-dataset mean accuracy and F1 as bars
# labelled with their value.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

method_box_colors = [method_colors[m] for m in methods]
box_positions = np.arange(1, len(methods) + 1)  # boxplot's default positions

# Box plots by method.  Tick labels are set explicitly rather than through
# boxplot(labels=...), which Matplotlib 3.9 deprecated in favour of
# tick_labels=, so this works across Matplotlib versions.
for ax_box, metric_col, box_title, box_ylabel in (
    (ax1, 'Test Accuracy (%)', 'Accuracy Distribution by Method', 'Accuracy (%)'),
    (ax2, 'Test F1 (%)', 'F1-Score Distribution by Method', 'F1-Score (%)'),
):
    per_method = [results_df.loc[results_df['Method'] == method, metric_col].values
                  for method in methods]
    bp = ax_box.boxplot(per_method, positions=box_positions, patch_artist=True)
    ax_box.set_xticks(box_positions)
    ax_box.set_xticklabels(methods)
    for patch, color in zip(bp['boxes'], method_box_colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    ax_box.set_title(box_title, fontweight='bold')
    ax_box.set_ylabel(box_ylabel)
    ax_box.grid(True, alpha=0.3)

# Performance by dataset
x_datasets = np.arange(len(datasets))
dataset_bar_colors = [dataset_colors[d] for d in datasets]
by_dataset = results_df.groupby('Dataset')

for ax_bar, metric_col, bar_title, bar_ylabel in (
    (ax3, 'Test Accuracy (%)', 'Average Accuracy by Dataset', 'Average Accuracy (%)'),
    (ax4, 'Test F1 (%)', 'Average F1-Score by Dataset', 'Average F1-Score (%)'),
):
    # reindex() pins the bar order to `datasets`, the same order used for the
    # colours and tick labels, instead of relying on groupby's key order.
    dataset_means = by_dataset[metric_col].mean().reindex(datasets)
    bars = ax_bar.bar(x_datasets, dataset_means.values,
                      color=dataset_bar_colors, alpha=0.8)
    ax_bar.set_title(bar_title, fontweight='bold')
    ax_bar.set_ylabel(bar_ylabel)
    ax_bar.set_xticks(x_datasets)
    ax_bar.set_xticklabels(datasets)
    ax_bar.grid(True, alpha=0.3)

    # Add value labels
    for bar in bars:
        height = bar.get_height()
        ax_bar.annotate(f'{height:.1f}%',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig('vis/06_performance_overview.png', dpi=300, bbox_inches='tight')
plt.savefig('vis/06_performance_overview.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.close()

# Report the generated figure files and a per-method summary table on stdout,
# then save the same table as vis/summary_statistics.csv.
_figure_index = (
    ("01_line_plots_trends", "Line plots showing trends with confidence bands"),
    ("02_scatter_plots_detailed", "Detailed scatter plots with method labels"),
    ("03_summary_statistics", "Bar charts of average performance with std dev"),
    ("04_correlation_analysis", "Accuracy vs F1-Score correlation plot"),
    ("05_performance_heatmap", "Heatmap of method performance by dataset"),
    ("06_performance_overview", "Combined overview with box plots and dataset analysis"),
)

print("All visualizations saved in 'vis' directory:")
for _stem, _blurb in _figure_index:
    print(f"  - {_stem}.png/pdf - {_blurb}")

# Mean / std / max / min of both metrics per method, rounded to 2 decimals.
method_stats = (
    results_df.groupby('Method')[['Test Accuracy (%)', 'Test F1 (%)']]
    .agg(['mean', 'std', 'max', 'min'])
    .round(2)
)

print("\nSummary Statistics by Method:")
print(method_stats)

print(f"\nDatasets analyzed: {datasets}")
for _label, _value in (
    ("Splits analyzed", splits),
    ("Methods compared", methods),
    ("Total experiments", len(results_df)),
):
    print(f"{_label}: {_value}")

# Save summary statistics to CSV
method_stats.to_csv('vis/summary_statistics.csv')
print("\nSummary statistics also saved to 'vis/summary_statistics.csv'")