import matplotlib.pyplot as plt
import matplotlib as mpl
import json
from matplotlib import rcParams
rcParams.update({'figure.autolayout': True})
import numpy as np
import pickle
import sys

# Validate the command line: the script takes at most one optional argument,
# the path of an .npz archive containing an 'entropy' array.
if len(sys.argv) not in (1, 2):
    # Usage errors go to stderr and end with a non-zero exit status so that
    # calling shells/scripts can detect the failure. sys.exit is used instead
    # of quit(), which is only injected by the optional `site` module.
    print("Usage: python3 plot.py (npz file)", file=sys.stderr)
    print("Or: python3 plot.py #this will plot ave_entropy.pickle", file=sys.stderr)
    sys.exit(1)

# Load the entropy grid, either from the archive named on the command line or
# from the default pickle in the current working directory.
if len(sys.argv) == 2:
    # The context manager closes the archive's file handle once the array has
    # been read into memory.
    with np.load(sys.argv[1]) as data:
        entropy = data['entropy']
else:
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only run this against an ave_entropy.pickle you trust.
    with open('ave_entropy.pickle', 'rb') as f:
        entropy = np.array(pickle.load(f))


# Coordinate grids for the heatmap, stored under the keys 'x' and 'y'.
xy = np.load("XAndY_optimized.npz")

# Index window of the grid that gets plotted: the first axis is cut to
# [plotRange_ylb, plotRange_yub), the second to [plotRange_xlb, plotRange_xub).
plotRange_xlb = 500
plotRange_xub = 1001
plotRange_ylb = 0
plotRange_yub = 2001
window = (
    slice(plotRange_ylb, plotRange_yub),
    slice(plotRange_xlb, plotRange_xub),
)

# Font settings shared by the axis labels and the title.
hfont = {'fontname': 'DejaVu Serif'}

fig = plt.figure(dpi=300, figsize=(4, 3))
ax = fig.subplots()

# Summary statistics: total entropy and the number of bins above 1e-3, first
# for the full grid and then for the plotted window only.
total_entropy = np.nansum(entropy)
total_bins = np.sum(entropy > 1e-3)
print("sum entropy={:.2e}, num of bins non-zero={:.2e}".format(total_entropy, total_bins))
region = entropy[window]
region_entropy = np.nansum(region)
region_bins = np.sum(region > 1e-3)
print("sum entropy in interested region={:.2e}, num of bins non-zero={:.2e}".format(region_entropy, region_bins))

# Shift by the maximum and shift back; apart from floating-point rounding this
# leaves the values unchanged but rebinds `entropy` to a fresh array.
nanmax = np.nanmax(entropy)
entropy = (entropy - nanmax) + nanmax
# NOTE(review): the threshold 1e-3-nanmax looks like a leftover from an
# earlier variant that masked values while the array was still shifted by
# -nanmax; after shifting back, 1e-3 may have been intended - confirm.
entropy[entropy < 1e-3 - nanmax] = 0

# Fixed colour scale so that separately generated plots are comparable.
norm = mpl.colors.Normalize(vmin=0, vmax=25000)
heatmap = ax.pcolormesh(
    xy['x'][window],
    xy['y'][window] / 10,  # y coordinates are divided by 10 for display
    entropy[window],
    cmap=plt.cm.jet,
    norm=norm,
)
fig.colorbar(heatmap)

plt.xlabel('ln(train loss)', **hfont)
plt.ylabel("smoothed test accuracy", **hfont)

# Find the highest-entropy state at each train loss and plot it as magenta dots.
#
# An NpzFile reads the requested array from disk on every key access, so the
# one row of 'x' and the one column of 'y' that are needed are extracted once
# here instead of being re-read on every loop iteration.
x_axis = xy['x'][0]
y_axis = xy['y'][:, 0]

xx = []
yy = []
for i in range(plotRange_xlb, plotRange_xub):
    column = entropy[plotRange_ylb:plotRange_yub, i]
    # np.nanargmax raises on an all-NaN slice, so such columns are skipped.
    if np.isnan(column).all():
        continue
    # nanargmax is relative to the sliced column; add the lower bound to get
    # the index into the full grid.
    maxEntropyIndex = np.nanargmax(column) + plotRange_ylb
    xx.append(x_axis[i])
    yy.append(y_axis[maxEntropyIndex] / 10)
ax.plot(xx, yy, 'o', ms=2, color='magenta', label='max entropy')

# Choose the output image name. With an input file on the command line the
# name is the file's base name up to its first '.', plus ".png". Without an
# argument (the ave_entropy.pickle mode) there is no sys.argv[1] to derive a
# name from, so a fixed default is used rather than failing with IndexError.
if len(sys.argv) == 2:
    prefix = sys.argv[1].split('/')[-1].split('.')
    filename = prefix[0] + ".png"
else:
    filename = "ave_entropy.png"

# Final axes cosmetics, then write the figure to disk.
plt.xlim(-10, 0)
plt.ylim(0, 1)
plt.legend(loc="lower left", facecolor="silver")
plt.title("Entropy landscape", **hfont)
plt.savefig(filename)