import netCDF4
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.colors as colors
from scipy.special import binom
import seaborn as sns; sns.set_theme()

# Configure Matplotlib to typeset all text through LaTeX with the AMS packages
# and a 20pt sans-serif font. This is best effort: if the rc machinery rejects
# a setting, report it and continue with whatever was applied so far.
try:
    matplotlib.rc('text', usetex=True)
    # Matplotlib 3.3+ expects 'text.latex.preamble' to be ONE string; list
    # values are no longer supported there. A newline-joined string is also
    # accepted by older releases, so this form works on both.
    matplotlib.rcParams['text.latex.preamble'] = "\n".join(
        [r"\usepackage{amsmath}", r"\usepackage{amsfonts}"]
    )
    rc('font', family='sans-serif', size=20)
except Exception:
    # `Exception` (not a bare except) so Ctrl-C / SystemExit still propagate.
    print("Latex is not loaded")



def number_of_harmonics(deg, dim):
    """Return the number of linearly independent spherical harmonics.

    Counts the harmonics of degree ``deg`` on the unit sphere of a
    ``dim``-dimensional space using

        (2 * deg + dim - 2) / (dim - 2) * C(deg + dim - 3, dim - 3),

    where the binomial coefficient is accumulated as a product of ratios.

    Args:
        deg: Non-negative integer degree.
        dim: Ambient dimension; expected to be at least 2.

    Returns:
        An ``int`` when ``dim == 2``; otherwise a NumPy float, since the
        value comes from a floating-point product and may differ from the
        exact integer by rounding error.
    """
    if dim == 2:
        # The closed form divides by (dim - 2), so the circle is special-cased:
        # degree 0 contributes only the constant function, every higher
        # degree contributes a cosine/sine pair.
        return 1 if deg == 0 else 2

    # C(deg + dim - 3, deg) == C(deg + dim - 3, dim - 3); use whichever
    # expansion has fewer factors (deg factors vs. dim - 3 factors).
    if deg < dim - 2:
        binomial = np.prod(np.arange(dim - 2, deg + dim - 2) / np.arange(1, deg + 1))
    else:
        binomial = np.prod(np.arange(deg + 1, deg + dim - 2) / np.arange(1, dim - 2))

    # Same value as (2 * deg + dim - 2) / (dim - 2) * binomial.
    return (2 * (deg / (dim - 2)) + 1) * binomial

    
import scipy

def mean_confidence_interval(data, confidence=0.95, axis=0):
    """Return the mean and the half-width of its Student-t confidence interval.

    Args:
        data: Array-like of samples.
        confidence: Two-sided confidence level in (0, 1).
        axis: Axis along which the samples are aggregated. ``None`` treats
            the flattened array as one sample set.

    Returns:
        Tuple ``(m, h)``: the mean along ``axis`` and the half-width ``h``
        such that the interval is ``m - h`` to ``m + h``. Both share the
        same shape.
    """
    # Imported here so the function does not depend on ``scipy.stats`` having
    # been loaded as a side effect of a bare ``import scipy``.
    from scipy import stats

    a = np.asarray(data)
    # The sample count and the standard error must follow the same axis as
    # the mean; otherwise any axis other than 0 gives mismatched results.
    n = a.size if axis is None else a.shape[axis]
    m = np.mean(a, axis=axis)
    se = stats.sem(a, axis=axis)
    h = se * stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h

# Experiment selection: noise level and dimension of the saved run to plot
# (e.g. d = 3 with noise = 1e-5 reads "./heatmap_d3_noise1e-05.pth").
noise = 1e-6
d = 4

# Label of the form "10^{{-k}}", where k is the exponent of `noise` in
# scientific notation; it is interpolated into the panel titles later.
noise_exponent = int(f"{noise:.1e}".split("e-")[-1])
noise_str = "10^{{-" + str(noise_exponent) + "}}"
print(noise_str)

res = torch.load(f"./heatmap_d{d}_noise{noise}.pth")

# htmp[i, j] is the fraction of entries of errors[i, :, j] that lie below the
# matching entries of l2norm_noise[i, :, j] (scaled by a factor of 1).
htmp = np.zeros_like(res['heatmap'])
n_rows, n_cols = htmp.shape[0], htmp.shape[1]
for row, col in np.ndindex(n_rows, n_cols):
    below_noise = res['errors'][row, :, col] < 1 * res['l2norm_noise'][row, :, col]
    htmp[row, col] = below_noise.mean()

# Font sizes: axis labels, tick labels, titles.
xsize = 50
lsize = 32
tsize = 44
# NOTE: this rebinds `colors`, hiding the `matplotlib.colors` module imported
# under the same name at the top of the file; from here on it is the tab10
# palette (a tuple of RGB tuples).
colors = plt.get_cmap('tab10').colors
# Raw string for the mathtext marker: '\c' is not a valid Python escape.
markers = ('o', 'v', '^', '<', '>', r'$\clubsuit$', 's', 'P', 'X', 'D', 'H', 'D', 'd', 'P', 'X')

# Three columns: heatmap, its narrow colorbar, and the error-curve panel.
fig, (ax1, axcb, ax) = plt.subplots(
    1, 3, figsize=(18, 8), gridspec_kw={'width_ratios': [1, 0.08, 1]}
)

# Transpose and flip so the first axis of `htmp` runs along x and the second
# axis runs bottom-to-top along y.
p = sns.heatmap(htmp.T[::-1, :], ax=ax1, vmin=0, vmax=1, cbar_ax=axcb, linewidths=0, linecolor='black')

# Label every second row at its cell centre. The tick count follows the
# length of `sall` instead of a hard-coded 15, so it always matches the
# number of labels.
p.set_yticks(np.arange(len(res['sall']))[::2] + 0.5)
p.set_yticklabels(res['sall'][::-1][::2], rotation=0, fontsize=lsize)
# Label every second column starting at column 1, using qall[::2] + 1.
# NOTE(review): this assumes qall holds consecutive integers, so that
# qall[k] + 1 == qall[k + 1] - confirm against the saved data.
p.set_xticks(np.arange(len(res['qall'][::2])) * 2 + 1.5)
p.set_xticklabels(res['qall'][::2] + 1, rotation=0, fontsize=lsize)

p.set_ylabel("number of samples $s$", fontsize=xsize, labelpad=20)
p.set_xlabel("degree $q$", fontsize=xsize, labelpad=10)

ax1.set_title(f"noise scale ${noise_str}$", fontsize=tsize)
# Hide the grid on the heatmap. The string 'off' is truthy, so passing it
# with linewidth=0 actually enabled a zero-width grid instead of disabling it.
ax1.grid(False)

# Right panel: error versus number of samples, one curve per selected degree,
# with the 95% confidence half-width drawn as error bars.
x = res['sall']
for i, q_idx in enumerate(np.arange(0, 14, 2)):
    q = res['qall'][q_idx]
    # Move the last axis of errors[q_idx] in front of the sample-count axis
    # and flatten the leading axes, giving one column per entry of `sall`.
    y = np.moveaxis(res['errors'][q_idx, :, :], 2, 1).reshape(-1, len(res['sall']))
    y_mean, y_ci = mean_confidence_interval(y, 0.95, 0)

    ax.errorbar(
        x, y_mean, y_ci,
        linewidth=2, markersize=10, color=colors[i], marker=markers[i],
        linestyle='-', capsize=7, capthick=1.2, ecolor="gray", label=f"$q={q}$",
    )

# Per-dimension x-axis layout, kept in one place: index of the first labelled
# tick in `sall`, x-extent of the noise reference line, and the x-limits.
x_axis_by_dim = {
    3: {'tick_start': 1, 'noise_line': [20, 620], 'xlim': [20, 620]},
    4: {'tick_start': 2, 'noise_line': [20, 2500], 'xlim': [20, 2520]},
}
# None for any other `d`: the Matplotlib defaults are left untouched.
x_axis = x_axis_by_dim.get(d)

y_ticks = np.logspace(-5, 0, 6)
ax.set_yticks(y_ticks)
ax.set_yticklabels(y_ticks, rotation=0, fontsize=lsize)
if x_axis is not None:
    # Label every third sample count.
    s_ticks = res['sall'][x_axis['tick_start']::3]
    ax.set_xticks(s_ticks)
    ax.set_xticklabels(s_ticks, rotation=0, fontsize=lsize)

# NOTE(review): the log scale is applied after the y ticks/labels are set,
# exactly as before; switching the scale may replace the tick locator and
# formatter - confirm the rendered y tick labels are the intended ones.
ax.set_yscale('log')
# Raw string: '\e' in '\ell' is not a valid Python escape sequence.
ax.set_ylabel(r"avg. $\ell_2$-norm of errors", fontsize=xsize, labelpad=10)
ax.set_xlabel("number of samples $s$", fontsize=xsize, labelpad=10)

ax.set_title(f"noise scale ${noise_str}$", fontsize=40)
if x_axis is not None:
    # Horizontal dash-dot reference line at the noise level.
    ax.plot(
        x_axis['noise_line'], [noise] * 2, 'k',
        linestyle=(0, (3, 1, 1, 1)), linewidth=3.5, zorder=3, label='noise',
    )

# Black frame on all four sides.
for side in ('bottom', 'top', 'right', 'left'):
    ax.spines[side].set_color('k')
if x_axis is not None:
    ax.set_xlim(x_axis['xlim'])

ax.legend(fontsize=lsize - 4)

# Let tight_layout place the panels first; the manual adjustments below are
# applied on top of that result, so their order matters.
plt.tight_layout()

# Colorbar: enlarge the tick labels and pin the axes to a fixed rectangle
# given as [left, bottom, width, height].
axcb.tick_params(labelsize=lsize)
axcb.set_position([0.42, 0.197, 0.03, 0.703])

# Shift the right-hand panel left by `mg`, keeping its width and height.
mg = 0.065
pos = ax.get_position()
newbbox = matplotlib.transforms.Bbox(
    [[pos.x0 - mg, pos.y0], [pos.x1 - mg, pos.y1]]
)
ax.set_position(newbbox)

plt.savefig(f'./fig_heatmap_d{d}_noise.png', dpi=300, facecolor='w', format='png', bbox_inches="tight")