import pickle
import pdb
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Load the hyper-parameter grid-search results. The `with` block guarantees
# the file handle is closed after unpickling.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load ml25grid.pkl from a trusted source.
with open("ml25grid.pkl", "rb") as grid_file:
    grid_res = pickle.load(grid_file)

# Weight on the validation-loss term of the plotted objective.
factor = 100
# Constant objective value of the NBANT baseline, drawn as a flat band.
# Terms: factor * 0.658 (presumably NBANT's validation loss)
#        + 2e-6 * 970000 (presumably its parameter count)
#        + (0.01 - 3e-6) * (8 + 8) (anchor term).
# NOTE(review): the parameter coefficient here is 2e-6 while the anchor term
# (and the ANT curve below) use 3e-6 -- confirm which weight is intended.
# An earlier baseline used: factor*0.649 + 2e-6*1090000 + (0.01-2e-6)*(6+6).
val3 = factor*0.658 + 2e-6*970000 + (0.01-3e-6)*(8+8)

# Candidate anchor counts; the first `n_points` of them form the x axis.
item_as = [1, 2, 3, 5, 8, 10, 15, 20, 30, 50, 80, 100, 120, 150, 200, 300,
           500, 800, 1000, 1500, 2000, 3000, 3500]
val_loss_mat = grid_res['val_loss_mat']
num_params_mat = grid_res['num_params_mat']

n_points = 19
anchor_counts = item_as[:n_points]
# Matplotlib's default second colour ("tab:orange") as an RGB tuple.
nbant_color = (1.0, 0.4980392156862745, 0.054901960784313725)

# ANT objective along the diagonal [i, i] of the grid for the second
# regularisation setting (first index hard-coded to 1).
# NOTE(review): the anchor penalty is 0.01*i, i.e. the *position* i, whereas
# the NBANT value above and the 3D objective later penalise the anchor
# *count* -- confirm against Equation (6).
ant_objective = [
    factor*val_loss_mat[1, i, i] + (3e-6)*num_params_mat[1, i, i] + 0.01*i
    for i in range(n_points)
]
plt.semilogx(anchor_counts, ant_objective, linestyle='-', linewidth=2,
             marker='s', markersize=10, label='ANT')
plt.semilogx(anchor_counts, [val3]*n_points, linewidth=4)
# +/-1 band around the NBANT value.
plt.fill_between(anchor_counts, [val3-1]*n_points, [val3+1]*n_points,
                 alpha=0.5, edgecolor=nbant_color, facecolor=nbant_color,
                 label='NBANT')

plt.gcf().subplots_adjust(bottom=0.15)  # leave room for the large x label
plt.tick_params(axis='both', which='both', labelsize=20)
plt.xlabel('Number of anchors', fontsize=20)
plt.ylabel('Equation (6)', fontsize=20)
plt.legend(fontsize=20, loc='upper left')
plt.savefig('movielens_fig.pdf')

# Grid axes recovered from the nested dict grid_res['val_loss'], which is
# indexed as [lda2][user anchors][item anchors].
# NOTE(review): this assumes the dict keys are stored in the same order as
# the corresponding axes of val_loss_mat / num_params_mat (dict insertion
# order) -- confirm against the script that produced the pickle.
lda2s = list(grid_res['val_loss'].keys())
uas = list(grid_res['val_loss'][lda2s[0]].keys())
ias = list(grid_res['val_loss'][lda2s[0]][uas[0]].keys())

# Regularisation weights of the objective.
lda1 = 0.01
lda2 = 2e-06
lda2_index = lda2s.index(lda2)  # ValueError if lda2 was not in the grid

# Loss and parameter-count slices for the chosen lda2; presumably each has
# shape (len(uas), len(ias)) so that it broadcasts with xx / yy below.
arr = np.array(grid_res['val_loss_mat'][lda2_index])
arr_params = np.array(grid_res['num_params_mat'][lda2_index])

factor = 100

# Coordinate grids: xx[a, b] == uas[a] and yy[a, b] == ias[b].
xx, yy = np.meshgrid(uas, ias, indexing='ij')
# Row-major flattened coordinates (same order as arr_final.flatten()).
xs = xx.ravel().tolist()
ys = yy.ravel().tolist()
xs1 = xs
ys1 = ys

# Constant objective of the NBANT baseline, drawn as a flat plane.
dynamic_final = factor*0.649 + lda2*1090000 + (lda1-lda2)*(8+8)
zz = np.ones_like(xx) * dynamic_final

# ANT objective at every (user anchors, item anchors) grid point:
# weighted loss + lda2 * parameter count + (lda1 - lda2) * total anchors.
arr_final = factor*arr + lda2*arr_params + (lda1-lda2)*(xx+yy)
zs1 = arr_final.flatten()

# Plot the ANT objective over the (user anchors, movie anchors) grid as a
# surface, together with the constant NBANT plane, in natural-log
# coordinates.
# Figure.add_subplot(projection='3d') replaces figure().gca(projection='3d'):
# gca() stopped accepting keyword arguments (deprecated in Matplotlib 3.4,
# removed in 3.6).
ax = plt.figure().add_subplot(projection='3d')

# Constant NBANT plane.
ax.plot_surface(np.log(xx), np.log(yy), zz, color='red', alpha=0.8)

# Keep only grid points whose objective is <= 150, the upper z limit set
# below. NaN compares False, so NaN entries are dropped as well.
keep = np.asarray(zs1) <= 150
xs = np.asarray(xs1)[keep]
ys = np.asarray(ys1)[keep]
zs = np.asarray(zs1)[keep]
ax.plot_trisurf(np.log(xs), np.log(ys), zs, cmap='viridis',
                edgecolor='none', alpha=0.7)

# Explicit limits disable autoscaling; the x axis is inverted.
ax.set_xlim([9, 0])
ax.set_ylim([0, 9])
ax.set_zlim([50, 150])

# Coordinates are np.log(count), so each tick is placed at the natural log
# of the anchor count it is labelled with.
anchor_ticks = [1, 4, 16, 64, 256, 1024, 4096]
tick_positions = np.log(anchor_ticks)
tick_labels = [str(count) for count in anchor_ticks]
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels)

ax.set_xlabel('# user anchors')
ax.set_ylabel('# movie anchors')
ax.set_zlabel('objective value')
plt.show()

# pdb.set_trace()