from utilsrc import *
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl

mpl.use('Agg')

from joblib import Parallel, delayed
import multiprocessing
import matplotlib.patches as mpatches

import numpy as np
# Number of copies handed to the EoS helpers (get_eos_mins / eos_conf_ax).
copies = 1


# Optimisation hyper-parameters shared by every run of the sweep.
eta = 0.2
mu = 1
epochs = 10000

# Window of initialisations (x, y) to sweep. Zoomed-in alternatives used before:
#   x in [3, 3.4] with 1000 points,    y in [0.1, 0.5] with 1000 points
#   x in [0.25, 0.35] with 400 points, y in [2.2, 2.3] with 400 points
start_x, end_x, count_x = 0, 4, 200
start_y, end_y, count_y = 0, 2.5, 200
x_grid = np.linspace(start_x, end_x, count_x)
y_grid = np.linspace(start_y, end_y, count_y)
x_mesh, y_mesh = np.meshgrid(x_grid, y_grid)
# One (x, y) row per grid point, in the row-major order of the meshes
# (x varies fastest).
xy_grid = np.column_stack((x_mesh.ravel(), y_mesh.ravel()))
exp_name = f"3num_square_eta{eta}_{start_x}-{end_x}-{count_x}_epoch{epochs}"

eos_mins = get_eos_mins(eta, mu, copies)

def get_sharpness_helper(x, y, eta, epochs=epochs):
    """Train from the initialisation ``[x**2, y, y]`` and summarise the run.

    Calls ``train_net_vec([x ** 2, y, y], eta, epochs, mu)`` (``mu`` is the
    module-level value) and reduces its outputs to four scalars.

    Returns:
        ``(eos, loss_avg, sharpness_avg, min_angle)``. When ``eos == -1``
        (a run treated as diverged), ``loss_avg`` and ``sharpness_avg`` are
        ``-1`` and the angle is ``crd2angle(1, 1)``. Otherwise the averages
        are taken over the last 10 recorded losses / sharpnesses, and the
        angle is computed from the first two coordinates of the final
        trajectory entry.
    """
    eos, losses, sharpnesses, trajectory = train_net_vec([x ** 2, y, y], eta, epochs, mu)
    if eos == -1:
        # Placeholder values so every grid point yields a row of equal length.
        return eos, -1, -1, crd2angle(1, 1)
    loss_avg = np.mean(losses[-10:])
    sharpness_avg = np.mean(sharpnesses[-10:])
    x_min, y_min = trajectory[-1][:2]
    return eos, loss_avg, sharpness_avg, crd2angle(x_min, y_min)

import os

# Run get_sharpness_helper on every (x, y) initialisation in parallel. Each row
# of ``res`` is (eos, loss_avg, sharpness_avg, min_angle) for one grid point,
# in the same order as xy_grid.
res = np.array(Parallel(n_jobs=multiprocessing.cpu_count(), verbose=5)(
    delayed(get_sharpness_helper)(x, y, eta) for (x, y) in xy_grid))

# Create the output folders (and their parent) so saving works on a fresh checkout.
os.makedirs('./data/boundary_vis/figs', exist_ok=True)
torch.save(res, f'./data/boundary_vis/{exp_name}.pkl')

# xy_grid is the row-major ravel of np.meshgrid(x_grid, y_grid) with the
# default 'xy' indexing, i.e. of arrays shaped (count_y, count_x) with x
# varying fastest. Reshape back to that shape so rows index y and columns
# index x, which is what imshow(origin='lower', extent=...) expects.
# (Reshaping to (count_x, count_y) only happened to work for square grids.)
grid_shape = (count_y, count_x)
eos = res[:, 0].reshape(grid_shape)
loss_record = res[:, 1].reshape(grid_shape)
sharpness_record = res[:, 2].reshape(grid_shape)
min_angle = res[:, 3].reshape(grid_shape)
eos_mask = (eos == -1)  # runs flagged as diverged; hidden in the first three panels
# Inverse mask, only used by the optional overlay commented out below.
eos_mask_rev = np.ma.masked_where((eos != -1), np.ones_like(sharpness_record))

# Plot
extent = (start_x, end_x, start_y, end_y)
fig = plt.figure(figsize=(12, 12))
gs = gridspec.GridSpec(nrows=2, ncols=2)

ax = fig.add_subplot(gs[0, 0])
# np.ma.log10 masks non-positive values (e.g. the -1 placeholders of diverged
# runs) instead of producing NaN/-inf and a RuntimeWarning like np.log10.
loss_masked = np.ma.log10(np.ma.masked_where(eos_mask, loss_record))
im = ax.imshow(loss_masked, origin='lower', extent=extent, cmap='binary')
ax.set_title('Converging Loss (log)')
eos_conf_ax(ax, fig, eos_mins, mu, copies)
fig.colorbar(im, ax=ax, shrink=0.69)

ax = fig.add_subplot(gs[0, 1])
min_angle_masked = np.ma.masked_where(eos_mask, min_angle)
im = ax.imshow(min_angle_masked, origin='lower', extent=extent, cmap='rainbow')
ax.set_title('Converging Angle')
eos_conf_ax(ax, fig, eos_mins, mu, copies)
fig.colorbar(im, ax=ax, shrink=0.6)

ax = fig.add_subplot(gs[1, 0])
sharpness_masked = np.ma.masked_where(eos_mask, sharpness_record)
# Cap the colour scale at 2/eta only when some run ended with eos code 0.
vmax_eos = 2 / eta if np.any(eos == 0) else None
im = ax.imshow(sharpness_masked, origin='lower', extent=extent,
               norm=mpl.colors.Normalize(vmin=4*mu**(3/2), vmax=vmax_eos, clip=True),
               cmap='coolwarm')
ax.set_title('Converging Sharpness')
eos_conf_ax(ax, fig, eos_mins, mu, copies)
fig.colorbar(im, ax=ax, shrink=0.6)
# plt.imshow(eos_mask_rev, origin='lower', extent=(start_x, end_x, start_y, end_y), cmap='binary', norm=mpl.colors.Normalize(vmin=0, vmax=1))

ax = fig.add_subplot(gs[1, 1])
# Categorical phase map: RGBA colour and legend label for each eos code.
cmap = {-1: [0, 0, 0, 1], 0: [1.0, 0.1, 0.1, 1], 1: [1.0, 0.5, 0.1, 1], 2: [0.1, 0.1, 1.0, 1]}
labels = {1: 'Lower', 2: 'Flattest', 0: 'EoS', -1: 'Diverge'}
patches = [mpatches.Patch(color=cmap[i], label=labels[i]) for i in cmap]
arrayShow = np.array([[cmap[i] for i in j] for j in eos])
im = ax.imshow(arrayShow, origin='lower', extent=extent)
ax.legend(handles=patches, borderaxespad=0.)
eos_conf_ax(ax, fig, eos_mins, mu, copies)


fig.suptitle(r'{} layers $\mu={}, \eta={}$ {} epochs'.format(copies * 2, mu, eta, epochs))
fig.tight_layout()
fig.savefig(f'./data/boundary_vis/figs/xyy_{exp_name}.jpg', dpi=400)


# print(eos)