import matplotlib.pyplot as plt
import numpy as np
import pickle
import os

from scipy.interpolate import RectBivariateSpline

def plot_grid(values, depth, width, filename):
    """Plot a smoothed 3-D surface of ``values`` over (depth, log2(width)).

    The raw grid is fitted with a smoothing ``RectBivariateSpline`` and
    evaluated on a 50x50 mesh so the surface looks smooth. The figure is
    saved to ``<filename>.png`` and then shown with ``plt.show()``.

    Parameters
    ----------
    values : array-like, shape (len(depth), len(width))
        Grid of values; ``values[i, j]`` belongs to ``depth[i]`` and
        ``width[j]``. Plotted on the z axis (labelled "regret").
    depth : sequence of numbers
        Depth coordinates. Must be strictly increasing with at least two
        entries (spline-fitting requirement).
    width : sequence of positive numbers
        Width coordinates, plotted on a log2 axis. Must be strictly
        increasing with at least two entries.
    filename : str
        Output path without extension; ``".png"`` is appended.

    Raises
    ------
    ValueError
        If ``values`` does not have shape ``(len(depth), len(width))``, or
        if the spline cannot be fitted (e.g. non-increasing coordinates).
    """
    values = np.asarray(values, dtype=float)
    depth = np.asarray(depth, dtype=float)
    log_width = np.log2(width)
    if values.shape != (depth.size, log_width.size):
        raise ValueError(
            f"values has shape {values.shape}, expected "
            f"(len(depth), len(width)) = ({depth.size}, {log_width.size})"
        )

    # Fine 50x50 evaluation mesh for a smooth-looking surface.
    depths_fine = np.linspace(depth.min(), depth.max(), 50)
    width_log_fine = np.linspace(log_width.min(), log_width.max(), 50)
    X_fine, Y_fine = np.meshgrid(depths_fine, width_log_fine)

    # Colour limits span the whole raw grid so every (depth, width) cell maps
    # into the colormap instead of being clipped to an end colour.
    vmin, vmax = values.min(), values.max()

    # RectBivariateSpline expects z with shape (len(x), len(y)); x is
    # log2(width) and y is depth here, hence the transpose. The spline is
    # cubic per axis, lowered when an axis has too few points (the degree
    # must be smaller than the number of points on that axis).
    spline = RectBivariateSpline(
        log_width, depth, values.T,
        kx=min(3, log_width.size - 1),
        ky=min(3, depth.size - 1),
        s=0.001,  # smoothing factor, increase for more smoothing
    )
    # Z_fine[i, j] = spline(width_log_fine[i], depths_fine[j]), which matches
    # the (row = width, column = depth) layout of X_fine / Y_fine.
    Z_fine = spline(width_log_fine, depths_fine)

    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(X_fine, Y_fine, Z_fine,
                           cmap='coolwarm',
                           vmin=vmin, vmax=vmax,
                           antialiased=True)

    ax.set_xlabel('depth', fontsize=14)
    ax.set_ylabel(r'$\operatorname{log}_2(\text{width})$', fontsize=14)
    ax.set_zlabel('regret', labelpad=5, fontsize=14)
    ax.view_init(elev=20, azim=15)
    fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10)

    fig.savefig(f"{filename}.png")
    plt.show()




list_widths = [16, 32, 64, 128, 256, 512, 1024, 2048]
list_depths = [4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
list_sim = list(range(20))  # presumably run/simulation indices; not used below


from fnmatch import fnmatch

the_folder = "./result/"
list_allfile = os.listdir(the_folder)

# Regret[i, j] = mean of the pickled 'reg_test' entries over every result file
# whose name matches depth list_depths[i] and width list_widths[j].
Regret = np.zeros([len(list_depths), len(list_widths)])

for i, depth in enumerate(list_depths):
    for j, width in enumerate(list_widths):
        pattern = f'*dCOR_idata_The_widths_{width}_depths_{depth}_*'
        matches = [name for name in list_allfile if fnmatch(name, pattern)]
        if not matches:
            # np.mean([]) would silently yield NaN, which breaks the spline
            # fit in plot_grid; fail loudly with the offending pattern instead.
            raise FileNotFoundError(
                f"no result files matching {pattern!r} in {the_folder!r}"
            )
        the_res = []
        for name in matches:
            # NOTE: pickle.load can execute arbitrary code -- only load
            # result files produced by trusted runs.
            with open(os.path.join(the_folder, name), 'rb') as fh:
                the_res.append(pickle.load(fh)['reg_test'].item())
        Regret[i, j] = np.mean(the_res)

plot_grid(Regret, list_depths, list_widths, 'Regret')