import numpy as np
import matplotlib.pyplot as plt
from itertools import product
import matplotlib
import seaborn as sns
import brewer2mpl
# Global typography for every figure: serif text, 'cm' mathtext glyphs.
matplotlib.rcParams.update({
    'font.family': 'serif',
    'mathtext.fontset': 'cm',
})
# Qualitative 7-colour "Set1" ColorBrewer map; `colors` is its
# matplotlib-ready colour list, indexed per curve by the loss plot.
bmap = brewer2mpl.get_map('Set1', 'qualitative', 7)
colors = bmap.mpl_colors
################################################################################################
# Himmelblau-type test function (uses 1.5*pi and pi in place of the usual 11 and 7)
################################################################################################ 
sample_point = 20  # number of grid nodes along each input axis
x1 = np.linspace(-np.pi, np.pi, sample_point)
x2 = np.linspace(-np.pi, np.pi, sample_point)


def _himmelblau_like(u, v):
    """Return the unscaled target (u^2 + v - 1.5*pi)^2 + (u + v^2 - pi)^2.

    Works elementwise on scalars or numpy arrays of matching shape.
    """
    return (u**2+v-1.5*np.pi)**2 + (u+v**2-np.pi)**2


# Every (x1[i], x2[j]) pair as one row, with x1 as the slow index and x2 as
# the fast index, i.e. row k = i * sample_point + j.
train_X = np.stack(np.meshgrid(x1, x2, indexing='ij'), axis=-1).reshape(-1, 2)
train_y = _himmelblau_like(train_X[:, 0], train_X[:, 1])

# Affine map y -> a*y + b that sends [min_y, max_y] onto [-1, 1].
min_y, max_y = train_y.min(), train_y.max()
y_range = max_y - min_y
a, b = 2 / y_range, 1 - 2 * max_y / y_range

train_y = train_y * a + b  # now in [-1, 1]

###################################################################### target function
# Default ('xy') meshgrid: X[i, j] = x1[j] and Y[i, j] = x2[i], so `f` is the
# transpose of train_y viewed as a (sample_point, sample_point) array.
X, Y = np.meshgrid(x1, x2)

f = _himmelblau_like(X, Y) * a + b

# 3-D surface of the rescaled target `f` over the (X, Y) grid.
fig = plt.figure()
fig.set_size_inches(10, 10)
ax = fig.add_subplot(111, projection='3d')
ax.tick_params(labelsize=30)
ax.plot_surface(X, Y, f, cmap='viridis', edgecolor='none', linewidth=2)

# All three axis labels share the same font size; only padding/rotation vary.
axis_label_style = {'fontsize': 50}
ax.set_xlabel(r"$x$", labelpad=20, **axis_label_style)
ax.set_ylabel(r"$y$", labelpad=20, **axis_label_style)
ax.set_zlabel(r"$f(x,y)$", labelpad=24, rotation=180, **axis_label_style)
# z ticks at -1, 0 and 1.
ax.set_zticks(np.arange(-1, 2, 1))

fig.savefig('limitation/Himmelblau_function.png', dpi=600)


###################################################################### loss 
sample_nums = 5  # number of saved repetitions (file index) averaged per depth
fig = plt.figure()
fig.set_size_inches(12, 10)
# NOTE(review): axes_style() only *returns* a style dict; without a `with`
# block (or sns.set_style) this call does not restyle the figure. Left
# untouched so the rendered output stays exactly as before.
sns.axes_style("ticks")
loss_ax = plt.gca()
loss_ax.tick_params(labelsize=26)

# File-name pattern: first slot is the layer count, second the repetition.
loss_file_pattern = './limitation/data/train_loss_1_%s_1_%s.npy'
layer_counts = [4, 10, 20, 40]
ls = ['--','-.', 'dotted','-']

# One averaged curve per layer count, each with its own line style and
# palette colour (zip stops after the four layer counts).
for depth, line_style, line_color in zip(layer_counts, ls, colors):
    loss_list = np.array([
        np.load(loss_file_pattern % (depth, run)).reshape(-1)
        for run in range(sample_nums)
    ])
    average_loss = np.average(loss_list, axis=0)
    # Spread across repetitions; computed but not drawn in this section.
    std_loss = np.std(loss_list, axis=0)
    loss_ax.plot(
        average_loss,
        label="QNN L=%s" % depth,
        lw=6,
        ls=line_style,
        c=line_color,
        alpha=0.8,
    )

loss_ax.legend(prop={'size': 40})
loss_ax.set_xlabel(r'Iteration', fontsize=40)
loss_ax.set_ylabel(r'MSE', fontsize=40)
plt.xticks(fontsize=36)
plt.yticks(fontsize=36)
fig.savefig('limitation/Himmelblau_loss.png', dpi=600)



###################################################################### predict 
# Load the saved predictions of every repetition (files predict_y_1_40_1_<j>;
# presumably the 40-layer model, matching the loss-file naming -- confirm).
predict_list = np.array([
    np.load('./limitation/data/predict_y_1_40_1_%s.npy' % j)
    for j in range(sample_nums)
])
average_predict = np.average(predict_list, axis=0)
# Spread across repetitions; not used by the plot below.
std_predict = np.std(predict_list, axis=0)

# Coordinates for the prediction surface. An 'ij' meshgrid gives
# grid_x[i, j] = x1[i] and grid_y[i, j] = x2[j], which is the layout a flat
# array ordered with x1 as the slow index has after reshaping (assumes the
# saved predictions follow the train_X row order -- confirm). This replaces
# the former hard-coded reshape(20, 20) and the plot_surface(Y, X, ...)
# argument swap, which was only valid because x1 and x2 are the same vector.
grid_x, grid_y = np.meshgrid(x1, x2, indexing='ij')
average_predict = average_predict.reshape(grid_x.shape)

fig = plt.figure()
fig.set_size_inches(10, 10)
ax = plt.axes(projection='3d')
plt.tick_params(labelsize=30)
ax.plot_surface(grid_x, grid_y, average_predict, cmap='viridis', edgecolor='none')
ax.set_xlabel(r"$x$", fontsize=50, labelpad=20)
ax.set_ylabel(r"$y$", fontsize=50, labelpad=20)
ax.set_zlabel(r"$f_{\mathbf{\theta},L}(x,y)$", fontsize=50, labelpad=24, rotation=0)

# Same z ticks (-1, 0, 1) as the target-function figure.
ax.set_zticks(np.arange(-1, 2, 1))
plt.savefig('limitation/Himmelblau_predict_function.png', dpi=600)
