import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
from matplotlib.colors import LinearSegmentedColormap

# Load the cached loss-landscape results.
# NOTE: pickle.load can execute arbitrary code stored in the file; only open
# cache files you generated yourself.
pkl_file = 'f1_loss_landscape_reload.pkl'  # Update this to your actual filename

print(f"Loading data from {pkl_file}...")
with open(pkl_file, 'rb') as f:
    cache = pickle.load(f)

# Extract data. Everything is coerced to float ndarrays so that list-valued
# caches support the arithmetic below, and any None entries become NaN (which
# masked_invalid then hides from the plots instead of crashing).
alphas = np.asarray(cache['alphas'], dtype=float)
betas = np.asarray(cache['betas'], dtype=float)
f1_loss_surface = np.asarray(cache['f1_loss'], dtype=float)
acc_surface = np.asarray(cache['accuracy'], dtype=float)

# The plotting code indexes surfaces as [alpha_idx, beta_idx] and contours their
# transpose against (alphas, betas); fail early with a clear message otherwise
# rather than deep inside matplotlib.
_expected_shape = (alphas.size, betas.size)
for _key, _surface in (('f1_loss', f1_loss_surface), ('accuracy', acc_surface)):
    if _surface.shape != _expected_shape:
        raise ValueError(
            f"cache['{_key}'] has shape {_surface.shape}; expected "
            f"{_expected_shape} = (len(alphas), len(betas))"
        )

# Metrics of the unperturbed model (plotted at the origin of the alpha/beta plane)
orig_f1_loss = cache['original_f1_loss']
orig_acc = cache['original_accuracy']

print(f"Original metrics - F1 Loss: {orig_f1_loss:.4f}, Accuracy: {orig_acc:.2f}%")

# Accuracy is in percent; convert it to an error rate in [0, 1]: 1 - accuracy/100
error_surface = 1 - acc_surface / 100
orig_error = 1 - orig_acc / 100

print(f"Original error rate (1 - Accuracy/100): {orig_error:.4f}")

# Mask NaN/inf cells so min/max/argmin and the contour plots ignore them
f1_loss_surface_masked = np.ma.masked_invalid(f1_loss_surface)
error_surface_masked = np.ma.masked_invalid(error_surface)

# Global styling: apply the white-grid theme first, then the font-size
# overrides, so the overrides are the values that end up in rcParams.
plt.style.use('seaborn-v0_8-whitegrid')
_font_overrides = {
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'xtick.labelsize': 32,
    'ytick.labelsize': 32,
    'legend.fontsize': 40,
    'figure.titlesize': 18,
}
plt.rcParams.update(_font_overrides)

# Colormap shared by the error-rate plots: reversed plasma, so darker colours
# correspond to higher error.
error_cmap = plt.cm.plasma_r

# Side-by-side canvas: ax1 holds the F1-loss panel, ax2 the error-rate panel.
fig, _panels = plt.subplots(1, 2, figsize=(20, 9), dpi=150)
ax1, ax2 = _panels

from matplotlib.colors import Normalize

# ---- Left panel: F1-loss landscape ----------------------------------------
# 25 filled levels spanning the finite (unmasked) range of the surface.
levels_f1 = np.linspace(np.min(f1_loss_surface_masked), np.max(f1_loss_surface_masked), 25)
# Surfaces are indexed [alpha, beta]; contourf expects Z[y, x], hence the transpose.
cf1 = ax1.contourf(alphas, betas, f1_loss_surface_masked.T, levels=levels_f1, cmap='viridis', alpha=0.95)
ctr1 = ax1.contour(alphas, betas, f1_loss_surface_masked.T, levels=10, colors='white', alpha=0.5, linewidths=0.8)
ax1.clabel(ctr1, inline=True, fontsize=10, fmt='%.3f')
cbar1 = fig.colorbar(cf1, ax=ax1, pad=0.01)
cbar1.set_label('F1 Loss (1-F1/100)', fontsize=14)
ax1.scatter([0], [0], color='red', marker='*', s=250, label='Original model')
ax1.set_title("F1 Loss Landscape", fontsize=16, pad=20)
ax1.set_xlabel("Direction 1 (α)", fontsize=14)
ax1.set_ylabel("Direction 2 (β)", fontsize=14)
ax1.grid(True, linestyle='--', alpha=0.7)
ax1.set_aspect('equal')

# Grid point with the lowest F1 loss; argmin on a masked array does not pick
# masked (NaN) cells.
min_f1_idx = np.unravel_index(np.argmin(f1_loss_surface_masked), f1_loss_surface_masked.shape)
min_f1_alpha = alphas[min_f1_idx[0]]
min_f1_beta = betas[min_f1_idx[1]]
min_f1_loss = f1_loss_surface_masked[min_f1_idx]

# Mark the minimum only when it is not (numerically) the origin, which already
# carries the "Original model" star.
if abs(min_f1_alpha) > 0.001 or abs(min_f1_beta) > 0.001:
    ax1.scatter([min_f1_alpha], [min_f1_beta], color='cyan', marker='o', s=400,
                edgecolor='black', label=f'Min loss: {min_f1_loss:.4f}')
    ax1.annotate(f'Min: ({min_f1_alpha:.3f}, {min_f1_beta:.3f})',
                 xy=(min_f1_alpha, min_f1_beta), xytext=(min_f1_alpha + 0.01, min_f1_beta + 0.01),
                 arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
                 fontsize=10, backgroundcolor='white')

# Build the legend once, after every labelled artist exists. This keeps the
# same placement and frame in both branches; a second bare legend() call would
# otherwise discard loc/framealpha whenever the minimum marker is drawn.
ax1.legend(loc='upper left', framealpha=0.9)

# ---- Right panel: error-rate landscape (1 - Accuracy/100) -----------------
# Locate the lowest-error grid point first (independent of the drawing below);
# masked (NaN) cells are never selected by a masked array's argmin.
min_err_idx = np.unravel_index(error_surface_masked.argmin(), error_surface_masked.shape)
min_err_alpha, min_err_beta = alphas[min_err_idx[0]], betas[min_err_idx[1]]
min_err = error_surface_masked[min_err_idx]

# 25 filled levels across the finite error range; the surface is transposed
# once because contourf expects Z[y, x] while the data is indexed [alpha, beta].
err_lo, err_hi = error_surface_masked.min(), error_surface_masked.max()
levels_err = np.linspace(err_lo, err_hi, 25)
err_grid = error_surface_masked.T

cf2 = ax2.contourf(alphas, betas, err_grid, levels=levels_err, cmap=error_cmap, alpha=0.95)
ctr2 = ax2.contour(alphas, betas, err_grid, levels=10, colors='white', alpha=0.5, linewidths=0.8)
ax2.clabel(ctr2, inline=True, fontsize=15, fmt='%.3f')
cbar2 = fig.colorbar(cf2, ax=ax2, pad=0.01)
# cbar2.set_label('Error Rate (1-Accuracy/100)', fontsize=14)

ax2.scatter([0], [0], color='red', marker='*', s=400, label='Original model')
# ax2.set_title("Error Rate Landscape", fontsize=16, pad=20)
ax2.set_xlabel("Direction 1 (α)", fontsize=17)
ax2.set_ylabel("Direction 2 (β)", fontsize=17)
ax2.legend(loc='upper left', framealpha=0.9)
ax2.grid(True, linestyle='--', alpha=0.7)
ax2.set_aspect('equal')


# ---- Stand-alone, publication-style error-rate landscape ------------------
fig_err, ax = plt.subplots(figsize=(12, 10), dpi=200)

# Fixed colour range (presumably to keep the scale comparable across figures;
# adjust if your error values live elsewhere).
vmin = 0.24
vmax = 0.41
levels_re = np.linspace(vmin, vmax, 50)
# extend='both' paints cells whose error falls outside [vmin, vmax] with the
# colormap's end colours; without it contourf leaves those regions blank.
cf = ax.contourf(alphas, betas, error_surface_masked.T, levels=levels_re,
                 cmap=error_cmap, alpha=0.95, extend='both')
ctr = ax.contour(alphas, betas, error_surface_masked.T, levels=12, colors='white', alpha=0.6, linewidths=0.8)
ax.clabel(ctr, inline=True, fontsize=20, fmt='%.3f')

# Markers: the original model at the origin, plus the minimum-error grid point
# when it is not (numerically) the origin.
ax.scatter([0], [0], color='red', marker='*', s=800, label='Original model')
if abs(min_err_alpha) > 0.001 or abs(min_err_beta) > 0.001:
    ax.scatter([min_err_alpha], [min_err_beta], color='cyan', marker='o', s=600,
               edgecolor='black', label=f'Min error: {min_err:.4f}')

# Axes, labels and legend
ax.set_xlabel("Direction 1 (α)", fontsize=38)
ax.set_ylabel("Direction 2 (β)", fontsize=38)
ax.legend(loc='best', frameon=True, framealpha=0.9, fontsize=32)
ax.grid(True, linestyle='--', alpha=0.6)
# NOTE(review): hard-coded y-range; confirm it matches the span of `betas`.
ax.set_ylim(-99, 99)

# Colorbar (inherits the triangular out-of-range ends from extend='both')
cbar = fig_err.colorbar(cf, ax=ax, pad=0.03, shrink=0.95)
cbar.set_label('Error Rate (1-Accuracy/100)', fontsize=16, labelpad=10)
cbar.ax.tick_params(labelsize=30)

# Let the data range, not an equal aspect ratio, determine the axes shape
ax.set_aspect('auto')

fig_err.tight_layout()
fig_err.savefig(pkl_file + "_.png", dpi=300, bbox_inches='tight')








# # Print comparison of minima
# print("\nComparison of minimum points:")
# print(f"F1 Loss - Original: {orig_f1_loss:.4f}, Minimum: {min_f1_loss:.4f} at (α={min_f1_alpha:.4f}, β={min_f1_beta:.4f})")
# print(f"Error Rate - Original: {orig_error:.4f}, Minimum: {min_err:.4f} at (α={min_err_alpha:.4f}, β={min_err_beta:.4f})")
#
# # Calculate distance between original model and minimum points
# f1_min_distance = np.sqrt(min_f1_alpha**2 + min_f1_beta**2)
# err_min_distance = np.sqrt(min_err_alpha**2 + min_err_beta**2)
# print(f"\nDistance from origin to F1 loss minimum: {f1_min_distance:.4f}")
# print(f"Distance from origin to Error minimum: {err_min_distance:.4f}")
#
# # Calculate percentage improvement
# f1_improvement = (orig_f1_loss - min_f1_loss) / orig_f1_loss * 100
# err_improvement = (orig_error - min_err) / orig_error * 100
# print(f"\nPotential F1 loss improvement: {f1_improvement:.2f}%")
# print(f"Potential Error rate improvement: {err_improvement:.2f}%")
#
# print("\nEnhanced visualizations saved to:")
# print("- loss_and_error_landscape_2d.png")
# print("- loss_and_error_landscape_3d.png")
# print("- polished_error_landscape.png")




# Add annotation for minimum error point if it's different from origin
# if abs(min_err_alpha) > 0.001 or abs(min_err_beta) > 0.001:
#     ax2.scatter([min_err_alpha], [min_err_beta], color='cyan', marker='o', s=150,
#                 edgecolor='black', label=f'Min error: {min_err:.4f}')
#     ax2.annotate(f'Min: ({min_err_alpha:.3f}, {min_err_beta:.3f})',
#                 xy=(min_err_alpha, min_err_beta), xytext=(min_err_alpha+0.01, min_err_beta+0.01),
#                 arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
#                 fontsize=10, backgroundcolor='white')
#     ax2.legend()

# Add overall title
# fig.suptitle('Model Loss Landscape Analysis', fontsize=20, y=0.98)
# plt.tight_layout(rect=[0, 0, 1, 0.96])
# plt.savefig("loss_and_error_landscape_2d.png", dpi=300, bbox_inches='tight')

# Create 3D visualizations with enhanced aesthetics
# fig = plt.figure(figsize=(20, 9), dpi=150)
#
# # 3D F1 Loss landscape
# ax3 = fig.add_subplot(121, projection='3d')
# X, Y = np.meshgrid(alphas, betas)
# surf1 = ax3.plot_surface(X, Y, f1_loss_surface_masked.T, cmap='viridis',
#                         edgecolor='none', alpha=0.9, antialiased=True,
#                         rstride=1, cstride=1)
# ax3.contour(X, Y, f1_loss_surface_masked.T, zdir='z', offset=np.min(f1_loss_surface_masked)-0.005,
#             levels=10, cmap='viridis', alpha=0.5)
# fig.colorbar(surf1, ax=ax3, shrink=0.7, pad=0.1, label="F1 Loss (1-F1/100)")
# ax3.scatter([0], [0], [orig_f1_loss], color='red', s=200, marker='*', label='Original model')
# if abs(min_f1_alpha) > 0.001 or abs(min_f1_beta) > 0.001:
#     ax3.scatter([min_f1_alpha], [min_f1_beta], [min_f1_loss], color='cyan', s=150, marker='o',
#                 edgecolor='black', label=f'Min loss: {min_f1_loss:.4f}')
# ax3.set_xlabel('Direction 1 (α)', fontsize=14, labelpad=10)
# ax3.set_ylabel('Direction 2 (β)', fontsize=14, labelpad=10)
# ax3.set_zlabel('F1 Loss', fontsize=14, labelpad=10)
# ax3.set_title('3D F1 Loss Landscape', fontsize=16, pad=20)
# ax3.view_init(elev=30, azim=-60)
# ax3.grid(True, linestyle='--', alpha=0.5)
# ax3.legend()
#
# # 3D Error landscape
# ax4 = fig.add_subplot(122, projection='3d')
# surf2 = ax4.plot_surface(X, Y, error_surface_masked.T, cmap=error_cmap,
#                         edgecolor='none', alpha=0.9, antialiased=True,
#                         rstride=1, cstride=1)
# ax4.contour(X, Y, error_surface_masked.T, zdir='z', offset=np.min(error_surface_masked)-0.01,
#             levels=10, cmap=error_cmap, alpha=0.5)
# fig.colorbar(surf2, ax=ax4, shrink=0.7, pad=0.1, label="Error Rate (1-Accuracy/100)")
# ax4.scatter([0], [0], [orig_error], color='red', s=200, marker='*', label='Original model')
# if abs(min_err_alpha) > 0.001 or abs(min_err_beta) > 0.001:
#     ax4.scatter([min_err_alpha], [min_err_beta], [min_err], color='cyan', s=150, marker='o',
#                 edgecolor='black', label=f'Min error: {min_err:.4f}')
# ax4.set_xlabel('Direction 1 (α)', fontsize=14, labelpad=10)
# ax4.set_ylabel('Direction 2 (β)', fontsize=14, labelpad=10)
# ax4.set_zlabel('Error Rate', fontsize=14, labelpad=10)
# ax4.set_title('3D Error Rate Landscape', fontsize=16, pad=20)
# ax4.view_init(elev=30, azim=-60)
# ax4.grid(True, linestyle='--', alpha=0.5)
# ax4.legend()
#
# fig.suptitle('3D Model Loss Landscape Analysis', fontsize=20, y=0.98)
# plt.tight_layout(rect=[0, 0, 1, 0.96])
# plt.savefig("loss_and_error_landscape_3d.png", dpi=300, bbox_inches='tight')
