import numpy as np
import matplotlib.pyplot as plt
from scipy.special import jv
import pickle

def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only point this at pickles this project produced itself.
    """
    with open(path, 'rb') as fp:
        return pickle.load(fp)


# Logged series, one pickle file per series (presumably per-step metrics
# from a training run -- confirm against the code that writes them).
loss_0 = _load_pickle('LOSS_0')
loss = _load_pickle('LOSS')
loss2 = _load_pickle('LOSS2')
loss3 = _load_pickle('LOSS3')

succ = _load_pickle('SS')   # success ratio
rw = _load_pickle('RW')     # reward

# Max/min action probabilities for agents 0 and 1.
px0 = _load_pickle('S_P_MX_0')
pn0 = _load_pickle('S_P_MN_0')
px1 = _load_pickle('S_P_MX_1')
pn1 = _load_pickle('S_P_MN_1')

vocab = _load_pickle('VOCAB')
wusage = _load_pickle('WUSAGE')

# Total word usage. Element [1] of each wusage entry is treated as its count.
# A 0.0 start keeps ``smv`` a float, matching the original accumulation loop.
usage_counts = [entry[1] for entry in wusage]
smv = sum(usage_counts, 0.0)

# Relative usage frequency per entry. The counts are pulled out explicitly as
# floats; np.array(wusage)[:, 1] would turn into a string array (and the
# division would fail) if any other field of an entry is non-numeric.
print(np.array(usage_counts, dtype=float) / smv)

world = _load_pickle('WORLD_LOC')
world_colors = _load_pickle('WORLD_COLORS')

# Per-position word accuracies (first/second/third word).
acc_1 = _load_pickle('A_1')
acc_2 = _load_pickle('A_2')
acc_3 = _load_pickle('A_3')


#print(world_colors)

import matplotlib as mpl

# Show which style sheets this matplotlib installation provides, then use "bmh".
print(plt.style.available)

plt.style.use("bmh")

mpl.rcParams['grid.color'] = "black"

# Font sizes (points) used by the rc settings below.
SMALL_SIZE = 18
MEDIUM_SIZE = 20
BIGGER_SIZE = 32

# Same effect as calling plt.rc(group, name=value) once per group:
# every rc key is "<group>.<name>".
mpl.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': BIGGER_SIZE,   # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure title
})

# Text properties passed as ``fontdict`` to the axis-label calls.
font = dict(
    family='serif',
    color='darkred',
    weight='normal',
    size=35,
)


#print("Vocab", vocab)

fig, ax = plt.subplots()

# Time axis shared by every curve below: 1700 evenly spaced points on
# [0, 17000]. Each plot uses the first 1000 (or 500/530) of them.
# NOTE(review): linspace(0, 17000, 1700) steps by 17000/1699 (~10.006), not
# exactly 10. The commented-out ``vocab[i][int(ps/10)]`` lookup suggests one
# logged sample per 10 time units; if so, np.arange(0, 17000, 10) is
# presumably what was meant. Confirm before changing.
x = np.linspace(0, 17000, 1700)

# Bar labels 'W0'..'W24' for the (commented-out) word-usage bar chart.
words = ['W{}'.format(w) for w in range(25)]

# Inputs for the word-usage bar chart: positions and relative frequencies.
# ``smv`` is expected to hold the total usage (sum of element [1] over wusage).
xvals = range(len(wusage))
yvals = np.array([entry[1] for entry in wusage], dtype=float) / smv
bar_labels = range(len(wusage))
#bar_colors = ['tab:red', 'tab:blue', 'tab:red', 'tab:orange']

#ax.bar(words, yvals, label=words, color='tab:grey')

# --- Alternative plot: loss terms (enable together with its y-label) -------
#for i in range(0, 3):
#  J = jv(i, x)
#ax.plot(x[0:1000], np.array(loss[0:1000])*0.8, label=r'$\mathcal{J}$')
#ax.plot(x[0:1000], np.array(loss_0[0:1000])*1, label=r'$\mathcal{L}_2$')
#ax.plot(x[0:1000], np.array(loss2[0:1000])*3+np.array(loss3[0:1000])*3, label=r'$\mathcal{L}_3$')
#ax.plot(x[0:1000], np.array(loss3[0:1000])*3, label="LOSS3")
#plt.ylabel('Loss', fontdict=font)

# --- Active plot: success ratio ---------------------------------------------
ax.plot(x[0:1000], succ[0:1000], label='SUCCESS RATIO', linewidth=2.5)
plt.ylabel('Success ratio', fontdict=font)

# --- Alternative plot: per-concept vocabulary curves ------------------------
# ``ps`` is the x position of the first concept annotation; it advances by
# 400 for each subsequent concept.
ps = 3000
#for i in range(12):
#    ax.plot(x[0:1000], vocab[i][0:1000],  linewidth=2.5)
#    #ax.annotate("Concept")
#    stt = 'Concept {}'.format(i)
#    print(stt)
#    ax.annotate(stt, (ps,vocab[i][int(ps/10)]), weight='bold', size=20)
#    ps = ps+400
#ax.plot(x[0:1000], np.array(loss_0[0:1000])/25, label='WORD ORDER LOSS x 4e-2')

# --- Alternative plot: reward -----------------------------------------------
#ax.plot(x[0:1000], np.array(rw[0:1000]), label='REWARD')
#plt.ylabel('Reward', fontdict=font)

# --- Alternative plot: max/min action probabilities -------------------------
#ax.plot(x[0:500], np.array(px0[0:500]), label='DL_MAX_PROB_A0')
#ax.plot(x[0:500], np.array(pn0[0:500]), label='DL_MIN_PROB_A0')
#ax.plot(x[0:500], np.array(px1[0:500]), label='DL_MAX_PROB_A1')
#ax.plot(x[0:500], np.array(pn1[0:500]), label='DL_MIN_PROB_A1')
#plt.ylabel('Probability', fontdict=font)

#ax.fill_between(x[0:530], np.array(pn0[0:530]), np.array(px0[0:530]))
#ax.fill_between(x[0:530], np.array(pn1[0:530]), np.array(px1[0:530]))

# --- Alternative plot: per-position word accuracy ---------------------------
#ax.plot(x[0:1000], np.array(acc_1[0:1000]), label="FIRST WORD ACCURACY",  linewidth=2.5)
#ax.plot(x[0:1000], acc_2[0:1000], label='SECOND WORD ACCURACY',  linewidth=2.0)
#ax.plot(x[0:1000], np.array(acc_3[0:1000]), label='THIRD WORD ACCURACY',  linewidth=2.0)
#plt.ylabel('Accuracy', fontdict=font)

# Shared axis decoration. The x label goes through the pyplot module ``plt``;
# there is no ``plot`` name in scope.
#plt.xlabel('Word', fontdict=font)
plt.xlabel('Time', fontdict=font)
#plt.ylabel('Usage freq', fontdict=font)
#ax.legend()
plt.grid(color='black', linestyle='--', linewidth=0.5)
#plt.ylim(-100, 600)     # set the ylim to bottom, top

plt.show()