import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import rc
# Render all figure text with LaTeX (configured once, through a single API).
# amsmath provides \text; amsfonts provides \mathbb.
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amsmath,amsfonts}')
rc('font', family='serif', size=16)

# Action space (a1, a2) on [0, 1] x [0, 1], 100 evenly spaced samples per axis
# (spacing 1/99). np.meshgrid's default 'xy' indexing puts a1 along columns and
# a2 along rows: a1_values[i, j] == grid[j] and a2_values[i, j] == grid[i].
a1_values, a2_values = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))


def Q_star_3_significantly_diff_height_peaks(a1, a2, *, variance=0.005):
    """Return a toy Q*(a1, a2) surface made of three isotropic Gaussian bumps.

    Each bump contributes
    ``height * exp(-((a1 - c1)**2 + (a2 - c2)**2) / (2 * variance))``; the
    bumps are listed in ``peaks`` below as ``(height, c1, c2)``. The 1.5 bump
    centred at (0.5, 0.5) is the tallest; the 0.5 and 0.7 bumps are the
    suboptimal ones.

    Args:
        a1, a2: Action coordinates: scalars, sequences, or NumPy arrays.
            They are combined with NumPy broadcasting, so ``np.meshgrid``
            outputs produce a 2-D surface.
        variance: Variance shared by all three bumps. The default 0.005
            reproduces the original peak width.

    Returns:
        NumPy array (a NumPy scalar for scalar inputs) of summed bump heights,
        with the broadcast shape of ``a1`` and ``a2``.
    """
    # asarray lets plain Python sequences broadcast like arrays; it does not
    # copy inputs that are already ndarrays.
    a1 = np.asarray(a1)
    a2 = np.asarray(a2)
    peaks = (
        (0.5, 0.2, 0.3),  # much shorter peak
        (1.5, 0.5, 0.5),  # tallest peak
        (0.7, 0.7, 0.7),  # another shorter peak
    )
    total = 0.0
    for height, c1, c2 in peaks:
        total = total + height * np.exp(-((a1 - c1)**2 + (a2 - c2)**2) / (2 * variance))
    return total

# Q*(a1, a2) for a fixed state, sampled on the (a1, a2) action grid.
q_star_3_significantly_diff_height_peaks_values = Q_star_3_significantly_diff_height_peaks(a1_values, a2_values)
q_star_values = q_star_3_significantly_diff_height_peaks_values  # short alias for the plotting code

# Output files. Only the SVG is written; PDF export is disabled. To turn it
# back on, open PdfPages as a context manager so the file is always closed:
#     with PdfPages(pdf_file_path) as pdf_pages:
#         pdf_pages.savefig(fig, transparent=True)
pdf_file_path = 'q_function_plots.pdf'
svg_file_path = 'q_function_plots.svg'

# Previous actions whose Q* values become the clipping floors of Q1 and Q2.
# prev_actions_Q2 is a flat [a1, a2, a1', a2'] list; only the second pair is
# read, because Q2 is built on top of Q1, which already covers the first pair.
prev_actions_Q1 = [0.25, 0.25]  # next to the lowest peak, centred at (0.2, 0.3)
prev_actions_Q2 = [0.25, 0.25, 0.55, 0.55]
# NOTE(review): (0.55, 0.55) lies on the flank of the tallest peak centred at
# (0.5, 0.5), not near the medium 0.7 peak at (0.7, 0.7). Confirm this is intended.

# Evaluate Q* exactly at each previous action instead of indexing the sampled
# grid. The grid spacing is 1/99, so int(a * 100) does not land on `a`, and
# truncation can drop a whole cell (int(0.29 * 100) == 28). Also, with
# np.meshgrid's default 'xy' layout, the row index is a2, not a1.
q1_threshold = Q_star_3_significantly_diff_height_peaks(prev_actions_Q1[0], prev_actions_Q1[1])
q2_threshold = Q_star_3_significantly_diff_height_peaks(prev_actions_Q2[2], prev_actions_Q2[3])

# Q1 = Q* raised to at least q1_threshold; Q2 is additionally raised to q2_threshold.
q_Q1_values = np.maximum(q_star_values, q1_threshold)
q_Q2_values = np.maximum(q_Q1_values, q2_threshold)

fig, axs = plt.subplots(1, 3, figsize=(21, 6), subplot_kw={'projection': '3d'})

# Lighter shade of purple for the threshold planes (alpha 0.1, i.e. 90% transparent).
light_purple_rgba = (0.8, 0.6, 0.8, 0.1)

# Panel 0: Q* on its own.
axs[0].plot_surface(a1_values, a2_values, q_star_values, cmap='plasma', alpha=0.7)

# Panels 1 and 2: the clipped surface, its flat threshold plane, and Q* in
# translucent purple so the overridden part stays visible.
clipped_panels = ((axs[1], q_Q1_values, q1_threshold), (axs[2], q_Q2_values, q2_threshold))
for ax, q_values, threshold in clipped_panels:
    ax.plot_surface(a1_values, a2_values, q_values, cmap='plasma', alpha=0.7)
    ax.plot_surface(a1_values, a2_values, np.full_like(a1_values, threshold), alpha=0.1, color=light_purple_rgba)
    ax.plot_surface(a1_values, a2_values, q_star_values, alpha=0.3, color='purple')

# Titles and labels are intentionally omitted for a clean figure. Candidates:
#     axs[0].set_title(r'$\mathbb{Q} (a[0], a[1])$')
#     axs[1].set_title('$Q_1(a[0], a[1])$'); axs[2].set_title('$Q_2(a[0], a[1])$')
#     ax.set_xlabel('$a[0]$'); ax.set_ylabel('$a[1]$')

# Strip the grid and tick labels from every panel.
for ax in axs:
    ax.grid(False)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_zticklabels([])

plt.tight_layout()
fig.patch.set_visible(False)
for ax in axs:
    ax.axis('off')
fig.savefig(svg_file_path, transparent=True, format='svg')

# plt.show()  # uncomment for interactive viewing