import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import RectBivariateSpline
from matplotlib.ticker import FormatStrFormatter
from fnmatch import fnmatch
import pickle
import os





# Hyper-parameter sweep: network widths are the powers of two 2**4 .. 2**11,
# depths are the consecutive integers 4 .. 13.
list_widths = [2 ** k for k in range(4, 12)]
list_depths = list(range(4, 14))



the_folder1 = "./CodeR_BTmodel/result"
the_folder2 = "./CodeR_THmodel/result"
list_allfile1 = os.listdir(the_folder1)
list_allfile2 = os.listdir(the_folder2)


def _mean_regret_grid(folder, file_names, widths, depths):
    """Average the pickled ``reg_test`` values for every (depth, width) cell.

    For each ``depths[i]`` and ``widths[j]``, every name in ``file_names``
    matching ``*dCOR_idata_The_widths_<width>_depths_<depth>_*`` is unpickled
    from ``folder`` and its ``['reg_test'].item()`` value collected; cell
    ``[i, j]`` of the result is the mean of those values.

    Args:
        folder: directory the file names are relative to.
        file_names: candidate file names (e.g. ``os.listdir(folder)``).
        widths: sweep widths, one column per entry.
        depths: sweep depths, one row per entry.

    Returns:
        float ndarray of shape ``(len(depths), len(widths))``.

    Raises:
        FileNotFoundError: if some (depth, width) cell has no matching file,
            rather than silently storing NaN (``np.mean([])``), which would
            only surface later as an obscure spline/colour-scale failure.
    """
    grid = np.zeros([len(depths), len(widths)])
    for i, depth in enumerate(depths):
        for j, width in enumerate(widths):
            # The trailing "_" keeps e.g. depth 1 from matching depth 10.
            pattern = '*dCOR_idata_The_widths_%s_depths_%s_*' % (width, depth)
            matches = [name for name in file_names if fnmatch(name, pattern)]
            if not matches:
                raise FileNotFoundError(
                    "no result files matching %r in %s" % (pattern, folder))
            the_res = []
            for name in matches:
                # NOTE: pickle.load can execute arbitrary code; only point
                # this at result files produced by your own runs.
                with open(os.path.join(folder, name), 'rb') as file:
                    ff = pickle.load(file)
                # assumes 'reg_test' holds a scalar-like object exposing
                # .item() (numpy 0-d array / tensor) - TODO confirm
                the_res.append(ff['reg_test'].item())
            grid[i, j] = np.mean(the_res)
    return grid


# Mean test regret per (depth, width) for each model's result folder.
Regret1 = _mean_regret_grid(the_folder1, list_allfile1, list_widths, list_depths)
Regret2 = _mean_regret_grid(the_folder2, list_allfile2, list_widths, list_depths)
# Original data points: mean-regret grids indexed [depth, width].
# Result[0] comes from the BT result folder, Result[1] from the TH one.
Result = [np.array(Regret1), np.array(Regret2)]

# Widths are plotted on a log2 axis.
width_log = np.log2(list_widths)

# 50x50 evaluation grid spanning the measured depth / log-width range, used to
# draw a smooth spline surface through the coarse experimental grid.
# X_fine varies along columns (depth), Y_fine along rows (log2 width).
depths_fine = np.linspace(min(list_depths), max(list_depths), 50)
width_log_fine = np.linspace(min(width_log), max(width_log), 50)
X_fine, Y_fine = np.meshgrid(depths_fine, width_log_fine)

fig = plt.figure(figsize=(12, 5))
# Two 3-D panels plus a narrow third column reserved for the shared colorbar.
gs = fig.add_gridspec(1, 3, width_ratios=[1, 1, 0.1])

# One colour scale shared by both panels, taken from the raw (un-smoothed)
# regret grids so the two surfaces are directly comparable.
vmin = min(Result[0].min(), Result[1].min())
vmax = max(Result[0].max(), Result[1].max())

axes = np.array([fig.add_subplot(gs[0, j], projection='3d') for j in range(2)])

titles = [r"BT", r"Thurstonian"]

for idx, values in enumerate(Result):
    ax = axes[idx]
    # values is indexed [depth, width]; transpose to [width, depth] so the
    # array matches the (x=width_log, y=list_depths) argument order below.
    Z = values.T

    # Bicubic smoothing spline through the grid; increase s for a smoother
    # surface (s=0 would interpolate the data points exactly).
    spline = RectBivariateSpline(width_log, list_depths, Z,
                                 kx=3, ky=3,
                                 s=0.001)

    # Z_fine[r, c] is the value at (width_log_fine[r], depths_fine[c]), the
    # same row/column layout as X_fine / Y_fine from the meshgrid.
    Z_fine = spline(width_log_fine, depths_fine)

    ax.plot_surface(X_fine, Y_fine, Z_fine,
                    cmap='coolwarm',
                    vmin=vmin, vmax=vmax,
                    antialiased=True)
    ax.view_init(elev=20, azim=15)
    ax.set_xlabel('depth', fontsize=10)
    ax.set_ylabel(r'$\operatorname{log}_2(\text{width})$', fontsize=14)
    ax.set_zlabel('regret', labelpad=5, fontsize=14)
    ax.set_title(titles[idx], pad=-10, fontsize=14)
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.3f'))


# Shared colorbar in the third gridspec column, using the same norm as the surfaces.
cax = fig.add_subplot(gs[0, -1])
sm = plt.cm.ScalarMappable(cmap='coolwarm', norm=plt.Normalize(vmin=vmin, vmax=vmax))
sm.set_array([])
# With an explicit cax, matplotlib ignores layout kwargs such as shrink/aspect,
# so the colorbar axes is shifted, narrowed and shortened by hand instead.
cbar = plt.colorbar(sm, cax=cax)
pos = cax.get_position()
cax.set_position([pos.x0 - 0.02,
                  pos.y0 + 0.1,
                  pos.width * 0.6,
                  pos.height * 0.8])

plt.savefig("Two_Regrets.png", dpi=600)
plt.show()