import numpy as np
import matplotlib.pyplot as plt
from scipy.special import jv
import pickle
from numpy import random
import seaborn as sns
import matplotlib.cm as cm
from scipy.special import zeta  



from labellines import labelLine, labelLines





#labelLines(ax.get_lines(), zorder=2.5)


def _load_pickle(path):
    """Return the object unpickled from the file at *path*.

    pickle can execute arbitrary code while loading, so only point this at
    files you trust.
    """
    with open(path, 'rb') as fp:
        return pickle.load(fp)


loss_0 = _load_pickle('LOSS_0')
loss = _load_pickle('LOSS')
loss2 = _load_pickle('LOSS2')
loss3 = _load_pickle('LOSS3')

succ = _load_pickle('SS')
rw = _load_pickle('RW')

px0 = _load_pickle('S_P_MX_0')
pn0 = _load_pickle('S_P_MN_0')
px1 = _load_pickle('S_P_MX_1')
pn1 = _load_pickle('S_P_MN_1')

vocab = _load_pickle('VOCAB')
wusage = _load_pickle('WUSAGE')

wos = _load_pickle('WOSTAS0')
wos2 = _load_pickle('WOSTAS2')
wos3 = _load_pickle('WOSTAS3')
print(wos)
wos1 = _load_pickle('WOSTAS1')

# smv: total of the second element of every wusage entry; used below and in
# the bar chart as the normaliser for word usage.
smv = sum((entry[1] for entry in wusage), 0.0)
# Pull that column out numerically. np.array() over (word, count) tuples
# yields a string array when the first element is a string, and dividing a
# string array by a float raises.
print(np.array([entry[1] for entry in wusage], dtype=float) / smv)

world = _load_pickle('WORLD_LOC')
world_colors = _load_pickle('WORLD_COLORS')

acc_1 = _load_pickle('A_1')
acc_2 = _load_pickle('A_2')
acc_3 = _load_pickle('A_3')

cusage = _load_pickle('CUSAGE')

# Normalise cusage so its values sum to 1 (variable name spelling kept as-is).
csum = sum(cusage.values(), 0.0)
csuage_norm = {key: count / csum for key, count in cusage.items()}
#print(world_colors)

import matplotlib as mpl

# Show which style sheets are installed (useful if the name below is missing).
print(plt.style.available)

plt.style.use("seaborn-v0_8-poster")

mpl.rcParams['grid.color'] = "black"

# Font sizes shared by every figure in this script.
SMALL_SIZE = 18
MEDIUM_SIZE = 20
BIGGER_SIZE = 32

mpl.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': BIGGER_SIZE,   # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure title
})

# Text properties passed as ``fontdict`` to the axis-label calls below.
font = dict(family='serif', color='darkred', weight='normal', size=35)


#print("Vocab", vocab)

# Index axis for the time-series plots below: x[i] == i. (The previous
# np.linspace(0, 50000, 50000) had a step of 50000/49999, not 1, so
# ``100 * x[i]`` drifted away from 100 * i.)
x = np.arange(50000, dtype=float)
fig, ax = plt.subplots()

if False:
    # Bar chart of normalized word usage, most-used word first.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_color('#DDDDDD')
    ax.set_axisbelow(True)
    ax.yaxis.grid(True)
    ax.xaxis.grid(False)
    ax.tick_params(bottom=False, left=True)

    sorted_wusage = sorted(wusage, key=lambda entry: entry[1], reverse=True)
    words = ['{}'.format(entry[0]) for entry in sorted_wusage]
    # Extract the second element numerically: np.array() over tuples whose
    # first element is a string yields a string array that cannot be divided.
    yvals = np.array([entry[1] for entry in sorted_wusage], dtype=float) / smv

    ax.bar(words, yvals, color='tab:gray')
    ax.set_ylabel('Normalized Word Usage', fontsize=20)
    ax.set_xlabel('Words', fontsize=20)

#for i in range(0, 3):
#  J = jv(i, x)
#ax.plot(x[0:1000], np.array(loss[0:1000])*0.8, label=r'$\mathcal{J}$')

if False:
    # Loss curves: running mean of loss_0 and the combined loss2 + loss3 term.

    # Running (cumulative) mean of loss_0; eloss holds each sample's residual.
    saq = 0
    loss_0_avg = []
    eloss = []
    for kk, qr in enumerate(loss_0, start=1):
        saq = saq + (1.0 / kk) * (qr - saq)
        eloss.append(qr - saq)
        loss_0_avg.append(saq)

    # Exponential moving average (alpha = 0.09) of succ and its residuals.
    saq = 0
    succ_avg = []
    eloss2 = []
    for qr in succ:
        saq = saq + 0.09 * (qr - saq)
        eloss2.append(qr - saq)
        succ_avg.append(saq)

    print("#$$$")
    print(len(succ))
    print("#####")

    # Slicing x past its end silently truncates, so bound every limit by
    # len(x) to keep the x and y arrays the same length.
    lim = min(len(x), len(loss_0_avg))
    ax.plot(100 * x[0:lim], np.array(loss_0_avg[0:lim]) * 0.5,
            label=r'$\mathcal{L}_2$', linewidth=2.5)

    # loss2 / loss3 need not be as long as loss_0: plot only the prefix all
    # of them share instead of failing on mismatched lengths.
    lim23 = min(lim, len(loss2), len(loss3))
    ax.plot(100 * x[0:lim23],
            np.array(loss2[0:lim23]) * 3 + np.array(loss3[0:lim23]) * 3,
            label=r'$\mathcal{L}_3$')

    # Success-ratio curve (disabled):
    # ax.plot(100 * x[0:lim], np.array(succ_avg[0:lim]), label='SUCCESS RATIO', linewidth=2.5)
    plt.xlabel('Time', fontdict=font)

#ps = 3000
#for i in range(12):
#    ax.plot(x[0:1000], vocab[i][0:1000],  linewidth=2.5)
#    #ax.annotate("Concept")
#    stt = 'Concept {}'.format(i)
#    print(stt)
#    ax.annotate(stt, (ps,vocab[i][int(ps/10)]), weight='bold', size=20)
#    ps = ps+400
#ax.plot(x[0:1000], np.array(loss_0[0:1000])/25, label='WORD ORDER LOSS x 4e-2')

if False:
    # Raw reward series plotted against 100 * x.
    n_rw = len(rw)
    ax.plot(100 * x[:n_rw], np.array(rw[:n_rw]), label='REWARD')
    ax.set_ylabel('Reward', fontdict=font)
    ax.set_xlabel('Time', fontdict=font)

#ax.plot(x[0:500], np.array(px0[0:500]), label='DL_MAX_PROB_A0')
#ax.plot(x[0:500], np.array(pn0[0:500]), label='DL_MIN_PROB_A0')
#ax.plot(x[0:500], np.array(px1[0:500]), label='DL_MAX_PROB_A1')
#ax.plot(x[0:500], np.array(pn1[0:500]), label='DL_MIN_PROB_A1')
#plt.ylabel('Probability', fontdict=font)

#ax.fill_between(x[0:530], np.array(pn0[0:530]), np.array(px0[0:530])) 
#ax.fill_between(x[0:530], np.array(pn1[0:530]), np.array(px1[0:530])) 


def ema_with_residuals(values, alpha=0.09):
    """Return ``(smoothed, residuals)`` for an exponential moving average.

    The average starts at 0.0 and is updated per sample as
    ``s = s + alpha * (v - s)``; ``residuals[i]`` is ``values[i]`` minus the
    updated average at step ``i``.
    """
    smoothed = []
    residuals = []
    s = 0.0
    for v in values:
        s = s + alpha * (v - s)
        residuals.append(v - s)
        smoothed.append(s)
    return smoothed, residuals


if False:
    # Smoothed per-word-position accuracies with an (asymmetric) residual band.
    s_acc_1_list, eloss1 = ema_with_residuals(acc_1)
    s_acc_3_list, eloss3 = ema_with_residuals(acc_3)
    s_acc_2_list, eloss2 = ema_with_residuals(acc_2)

    # The series may differ in length; plot the prefix they all share (and
    # that x covers) rather than failing on mismatched x/y lengths.
    lim = min(len(x), len(s_acc_1_list), len(s_acc_2_list), len(s_acc_3_list))
    t = 100 * x[0:lim]

    # (label, smoothed, residuals, line width, lower-band scale), in the
    # original plotting order so the colour cycle is unchanged.
    for label, smooth, resid, lw, lower in (
        ("FIRST WORD ACCURACY", s_acc_1_list, eloss1, 2.5, 0.2),
        ('THIRD WORD ACCURACY', s_acc_3_list, eloss3, 2.0, 0.15),
        ('SECOND WORD ACCURACY', s_acc_2_list, eloss2, 2.0, 0.15),
    ):
        sm = np.array(smooth[0:lim])
        rs = np.array(resid[0:lim])
        ax.plot(t, sm, label=label, linewidth=lw)
        ax.fill_between(t, sm - rs * lower, sm + 0.01 * rs)

    plt.ylabel('Accuracy', fontdict=font)
    plt.xlabel('Time', fontdict=font)
    plt.legend()


################################################
if False:
    # Heatmap figure: one panel per snapshot index, one row per run
    # (wos, wos1, wos2, wos3), one column per key of wos.
    fig, axes = plt.subplots(nrows=1, ncols=3)

    # Largest final time stamp over wos's series (0 if there are none); used
    # as the extent of the placeholder line for keys missing from wos.
    mlen = max([0] + [series[-1][0] for series in wos.values()])

    # Each series is a sequence of (time, value) pairs.
    wos_t = {key: np.array([ele[0] for ele in series]) for key, series in wos.items()}
    wos_v = {key: np.array([ele[1] for ele in series]) for key, series in wos.items()}

    # Use one column order (wos's key order) for every run so column j is the
    # same key in every heatmap row, even if the dicts were filled in a
    # different order. A key missing from another run raises KeyError here
    # instead of silently shifting that row's columns.
    keys = list(wos.keys())
    run_values = [
        np.array([[ele[1] for ele in run[key]] for key in keys])
        for run in (wos, wos1, wos2, wos3)
    ]
    for values in run_values:
        print(len(values[0]))

    # Series index shown in each panel.
    snapshots = (0, 10, 34)
    im = None
    for panel, snap in zip(axes, snapshots):
        grid = [values[:, snap] for values in run_values]
        panel.xaxis.grid(False)
        panel.yaxis.grid(False)
        panel.set_xticks(np.arange(len(keys)))
        im = panel.imshow(grid, vmin=0, vmax=1, alpha=0.9, cmap=cm.copper,
                          interpolation='antialiased')

    # All panels share vmin/vmax, so one colorbar (from the last image) serves all.
    cbar_ax = fig.add_axes([0.95, 0.15, 0.01, 0.7])
    fig.subplots_adjust(right=0.8)
    fig.colorbar(im, cax=cbar_ax)

    if False:
        # Value-vs-time line for each 3-bit key, drawn on the main ``ax``.
        zero_t = x[0:int(mlen)]
        for key in [(1, 1, 1), (0, 0, 1), (0, 1, 0), (0, 1, 1),
                    (1, 0, 0), (1, 0, 1), (1, 1, 0), (0, 0, 0)]:
            label = ''.join(str(bit) for bit in key)
            if key in wos_t:
                ax.plot(wos_t[key], wos_v[key], label=label, linewidth=2.0)
            else:
                # Key absent from wos: draw a flat zero line so it still gets a label.
                ax.plot(zero_t, np.zeros(len(zero_t)), label=label, linewidth=2.0)
        labelLines(ax.get_lines(), align=False, fontsize=24)
        # Label ``ax`` explicitly: plt.xlabel would target the heatmap
        # figure's current axes (the colorbar), not the axes drawn on here.
        ax.set_xlabel('Time', fontdict=font)
        ax.set_ylabel('Probability', fontdict=font)


#print(wos[(1,1,1)][-1])
#print(wos[(0,0,1)][-1])
#print(wos[(0,1,0)][-1])
#print(wos[(0,1,1)][-1])
#print(wos[(1,0,0)][-1])
#print(wos[(1,0,1)][-1])
#print(wos[(1,1,0)][-1])
#print(wos[(0,0,0)][-1])
###############################333

#print(succ[0:100])
#print(succ[0:100:5])

if True:
    # Rank-frequency plot of word usage against samples from a Zipf distribution.
    n_top = 25  # number of ranks shown for each curve

    total_usage = sum(ww[1] for ww in wusage)
    fll = {ww[0]: ww[1] / float(total_usage) for ww in wusage}
    print(fll)

    sfll = sorted(fll.values(), reverse=True)
    print(sfll)

    # Empirical Zipf(a=2) frequencies. minlength guarantees bins 1..n_top
    # exist even when no sample is that large, so the curve always has
    # n_top points to match range(n_top).
    zp = random.zipf(a=2.0, size=2000)
    count = np.bincount(zp, minlength=n_top + 1)
    zipf_freq = count[1:n_top + 1] / count.sum()

    top = sfll[:n_top]
    ranks = range(len(top))
    ax.plot(ranks, top, label='Word usage', linewidth=2.5)
    # Use ax (not plt) so the fill lands on the same axes as the line even if
    # another figure was created and made current earlier in the script.
    ax.fill_between(ranks, top, color='blue', alpha=0.6)

    ax.plot(range(n_top), zipf_freq, label='Zipf Dist', linewidth=2.5)
    ax.fill_between(range(n_top), zipf_freq, color='red', alpha=0.6)
# plt.legend()
plt.show()
