import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib import rc
import os
import pandas as pd
import seaborn as sns
import deepdish as dd



#####################################################################################

def _cumulative_mean_std(data):
    """Return the across-run mean and std of the running total of ``data``.

    ``data`` is accumulated along axis 1 and then reduced over axis 0, so it is
    expected to be 2-D (runs x tasks), matching how ``plot1`` indexes it with
    ``shape[1]``. Taking the std of the accumulated values (rather than of the
    per-task values) keeps the shaded band consistent with the plotted
    cumulative mean.
    """
    running = np.cumsum(data, axis=1)
    return running.mean(axis=0), running.std(axis=0)


def plot1():
    """Plot the cumulative number of stored policies per task for each method.

    Loads three arrays from ``data/exp1.h5`` (plotted as D_best, D_worst and
    D_Sampled) and one from ``data/exp3.h5`` (plotted as max Q), treats every
    entry ``> 0`` as a stored policy, and draws the across-run mean of the
    running total with a +/- 1 std band. The figure is written to
    ``plots/exp1.pdf`` and then shown.
    """
    data1, data2, data3 = dd.io.load('data/exp1.h5')
    data4 = dd.io.load('data/exp3.h5')

    # Binarize: presumably a positive entry means "a policy was stored at this
    # task" (matches the y-label) -- TODO confirm against the experiment code.
    data1, data2, data3, data4 = data1 > 0, data2 > 0, data3 > 0, data4 > 0

    s = 20
    rc_ = {'figure.figsize': (10, 6), 'axes.labelsize': 30, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': 25}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    fig, ax = plt.subplots()
    task = np.arange(1, data1.shape[1] + 1)

    # (data, legend label, extra line style) in legend order.
    series = [
        (data1, r'$D_{best}$', {'ls': '--'}),
        (data2, r'$D_{worst}$', {'ls': '--'}),
        (data3, r'$D_{Sampled}$', {'color': 'black'}),
        (data4, r'$max \mathcal{Q}$', {}),
    ]
    for data, label, style in series:
        mean, std = _cumulative_mean_std(data)
        line, = ax.plot(task, mean, alpha=0.5, label=label, lw=3.0, **style)
        # ax.plot and ax.fill_between advance independent colour cycles (and an
        # explicit line colour does not advance the line cycle), so each band
        # is given its line's colour explicitly to keep them matched.
        ax.fill_between(task, mean - std, mean + std, alpha=0.4,
                        color=line.get_color())

    ax.legend()
    ax.set_xlabel("Tasks")
    ax.set_ylabel('Policies stored')
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    fig.tight_layout()
    # Save before show(): a blocking show() closes the figure once its window
    # is dismissed, and the output directory may not exist yet.
    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/exp1.pdf", bbox_inches='tight')
    plt.show()

#####################################################################################

def _nonnegative_band(mean, std):
    """Return the ``(lower, upper)`` edges of ``mean +/- std``, floored at 0.

    Only a lower bound is applied. A fixed upper cap would silently truncate
    the band whenever ``mean + std`` exceeds it.
    """
    return np.maximum(mean - std, 0), np.maximum(mean + std, 0)


def plot2():
    """Plot the per-task timestep counts of each method with a +/- 1 std band.

    Loads three arrays from ``data/exp1.h5`` (plotted as D_best, D_worst and
    D_Sampled) and one from ``data/exp3.h5`` (plotted as max Q), averages them
    over axis 0, and draws each mean with a std band floored at zero. The
    figure is written to ``plots/exp2.pdf`` and then shown.
    """
    data1, data2, data3 = dd.io.load('data/exp1.h5')
    data4 = dd.io.load('data/exp3.h5')

    s = 20
    rc_ = {'figure.figsize': (10, 6), 'axes.labelsize': 30, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': 20}
    sns.set(rc=rc_, style="darkgrid")
    rc('text', usetex=True)

    fig, ax = plt.subplots()
    task = np.arange(1, data1.shape[1] + 1)

    # (data, legend label, extra line style) in legend order.
    series = [
        (data1, r'$D_{best}$', {'ls': '--'}),
        (data2, r'$D_{worst}$', {'ls': '--'}),
        (data3, r'$D_{Sampled}$', {'alpha': 0.5, 'color': 'black'}),
        (data4, r'$max \mathcal{Q}$', {}),
    ]
    for data, label, style in series:
        mean, std = data.mean(axis=0), data.std(axis=0)
        # Presumably timestep counts are non-negative, so the band is
        # floored at zero.
        lower, upper = _nonnegative_band(mean, std)
        line, = ax.plot(task, mean, label=label, lw=3.0, **style)
        # Tie the band to its line's colour; lines and fills use separate
        # colour cycles.
        ax.fill_between(task, lower, upper, alpha=0.4, color=line.get_color())

    ax.legend()
    ax.set_xlabel("Tasks")
    ax.set_ylabel('Timesteps')
    ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
    fig.tight_layout()
    # Save before show(): a blocking show() closes the figure once its window
    # is dismissed, and the output directory may not exist yet.
    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/exp2.pdf", bbox_inches='tight')
    plt.show()

#####################################################################################

def _returns_frame(runs_by_type, n_tasks):
    """Build the long-form DataFrame consumed by ``sns.boxplot`` in ``plot3``.

    ``runs_by_type`` is an iterable of ``(label, runs)`` pairs where ``runs``
    is indexed as ``runs[run, task]``. The first ``n_tasks`` task columns of
    every run are kept, and every group contributes all of its own runs, so
    groups may differ in run count. The result has one row per (run, task)
    with columns ``""`` (the group label), ``"Tasks"`` (1-based task number)
    and ``"Average Returns"``.
    """
    task_cols = list(range(1, n_tasks + 1))
    frames = []
    for label, runs in runs_by_type:
        wide = pd.DataFrame(np.asarray(runs)[:, :n_tasks], columns=task_cols)
        wide[""] = label
        frames.append(wide)
    data = pd.concat(frames, ignore_index=True)
    return pd.melt(data, id_vars="", var_name="Tasks",
                   value_name="Average Returns")


def plot3():
    """Box-plot per-task average returns of the "Optimal" vs "Composed" runs.

    Loads two arrays from ``data/exp2.h5`` (indexed ``[run, task]``), reshapes
    them to long form and writes the figure to ``plots/exp3.pdf`` before
    showing it.
    """
    rc_ = {'figure.figsize': (30, 10), 'axes.labelsize': 30, 'font.size': 30,
           'legend.fontsize': 20, 'axes.titlesize': 30}
    sns.set(rc=rc_, style="darkgrid", font_scale=1.8)
    rc('text', usetex=False)

    data0, data1 = dd.io.load('data/exp2.h5')
    types = ["Optimal", "Composed"]
    # Each group uses its own number of runs; the task count comes from data0.
    data = _returns_frame(zip(types, (data0, data1)), data0.shape[1])

    fig, ax = plt.subplots()
    sns.boxplot(x="Tasks", y="Average Returns", hue="", data=data,
                linewidth=3, showfliers=False, ax=ax)
    # Save before show(): a blocking show() closes the figure once its window
    # is dismissed, and the output directory may not exist yet.
    os.makedirs("plots", exist_ok=True)
    fig.savefig("plots/exp3.pdf", bbox_inches='tight')
    plt.show()

#####################################################################################
    
# Generate all figures only when run as a script, not on import.
if __name__ == "__main__":
    plot1()
    plot2()
    plot3()