import numpy as np
from collections import defaultdict
from matplotlib import pyplot as plt
from matplotlib import font_manager as fm, rcParams
from matplotlib import rc
import os
import pandas as pd
import seaborn as sns
import torch

#####################################################################################
# Root directory that holds the experiment outputs.
base_dir = '.'
# Directory containing the per-run result files (*.h5) read by plot1/plot2.
path = os.path.join(base_dir, 'dqn', 'dataA')
# Key of the per-episode series loaded from each result file.
data_type = 'episode_returns'

def plotdata(data, name, save_dir='./dqn/plots'):
    """Plot mean +/- std learning curves and save the figure as a PDF.

    Args:
        data: Iterable of (mean, std, label) triples, as returned by
            `process_data`. `mean` and `std` are 1-D arrays of equal length.
        name: File stem; the figure is written to `<save_dir>/<name>.pdf`.
        save_dir: Output directory for the PDF; created if it does not exist.
    """
    s = 20
    rc_ = {'figure.figsize': (10, 6), 'axes.labelsize': 30, 'xtick.labelsize': s,
           'ytick.labelsize': s, 'legend.fontsize': 20}
    sns.set(rc=rc_, style="darkgrid")

    fig, ax = plt.subplots()

    lw = 2.0
    for (mean, std, label) in data:
        # Each curve gets its own x range, so series of different lengths
        # (e.g. from different smoothing windows) still plot correctly.
        x = np.arange(len(mean))
        ax.plot(x, mean, label=label, lw=lw)
        ax.fill_between(x, mean - std, mean + std, alpha=0.4)

    ax.legend()
    ax.set_xlabel("Episodes")
    ax.set_ylabel('Returns')
    # Use scientific/offset notation for x tick labels outside 10^0..10^1.
    ax.xaxis.get_major_formatter().set_powerlimits((0, 1))
    fig.tight_layout()

    # Save BEFORE show(): a blocking show() closes the figure when its window
    # is closed, which can leave nothing (or a blank figure) to save afterwards.
    # Also create the output directory, which savefig will not do itself.
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, '{}.pdf'.format(name)), bbox_inches='tight')
    plt.show()

def process_data(alldata, window=1000, std_scale=0.5):
    """Reduce per-run curves to smoothed mean and scaled-std curves.

    Args:
        alldata: Sequence of (data, label) pairs. `data` is array-like with
            runs along axis 0 (one row per run) and episodes along axis 1.
        window: Moving-average window length. It is clamped to the number of
            episodes: with 'valid' mode, np.convolve would otherwise return
            `window - n + 1` meaningless points for a series shorter than
            the window.
        std_scale: Factor applied to the across-run standard deviation
            (0.5 draws a half-std band).

    Returns:
        List of [mean, std, label] entries. Each curve has `n - w + 1`
        points, where `n` is the number of episodes and `w` is the
        effective (clamped) window.
    """
    pdata = []
    for data, label in alldata:
        data = np.asarray(data)
        mean = data.mean(axis=0)
        std = data.std(axis=0) * std_scale

        # Clamp the window to the series length (and to at least 1).
        w = max(1, min(window, len(mean)))
        kernel = np.ones(w) / w
        pdata.append([np.convolve(mean, kernel, mode='valid'),
                      np.convolve(std, kernel, mode='valid'),
                      label])

    return pdata

def plot1(task=0, episodes=2000, runs=4, name='None'):
    """Plot smoothed learning curves for the experiment-1 result files of one task.

    For each run 0..runs-1, loads the `data_type` entry of five result files
    under `path`. Each series is truncated to its first `episodes` entries.
    The curves are then smoothed with `process_data` and drawn with
    `plotdata`.

    Args:
        task: Task index used in the result file names.
        episodes: Maximum number of episodes kept from each run.
        runs: Number of runs to load.
        name: Output file stem. The default ('None'), or None, falls back
            to 'bases_<task>'. Previously any value passed here was ignored.
    """
    if name is None or name == 'None':
        name = 'bases_{}'.format(task)

    def load(filename):
        data_path = '{}/{}'.format(path, filename)
        # NOTE(review): torch.load unpickles the file - only use it on trusted
        # result files.
        d = torch.load(data_path)[data_type]
        print(len(d), data_path)
        return d[:episodes]

    def stack(rows):
        # Runs may hold fewer than `episodes` entries. Trim every run to the
        # shortest one so the result is a rectangular (runs, episodes) array
        # rather than a ragged/object array.
        n = min((len(r) for r in rows), default=0)
        return np.asarray([r[:n] for r in rows])

    # label -> function mapping a run index to its result file name.
    sources = {
        "SOPGOL": lambda run: 'exp1_{}_{}_{}_{}.h5'.format('sop', 'None', task, run),
        "SOPGOL continual": lambda run: 'exp1_{}_{}_{}_{}.h5'.format('sop', 'continual', task, run),
        "GPISF transfer": lambda run: 'exp1_{}_{}_{}_{}.h5'.format('sf', 'transfer', task, run),
        "GPISF continual": lambda run: 'exp1_{}_{}_{}_{}.h5'.format('sf', 'continual', task, run),
        'DQN': lambda run: 'exp_{}_{}_{}.h5'.format('dqn', task, run),
    }

    curves = {label: [] for label in sources}
    for run in range(runs):
        for label, filename in sources.items():
            curves[label].append(load(filename(run)))

    data = [(stack(rows), label) for label, rows in curves.items()]
    print(data[0][0].shape, len(data[0][0]))

    plotdata(process_data(data), name)

def plot2(task=0, episodes=2000, runs=4, name=None, skip_continual_runs=(1,)):
    """Plot smoothed learning curves for the experiment-2 result files of one task.

    For each run 0..runs-1, loads the `data_type` entry of the exp2_* result
    files (and the shared DQN file) under `path`. Each series is truncated
    to its first `episodes` entries. The curves are then smoothed with
    `process_data` and drawn with `plotdata`.

    Args:
        task: Task index used in the result file names.
        episodes: Maximum number of episodes kept from each run.
        runs: Number of runs to load.
        name: Output file stem; defaults to 'nbases_<task>'.
        skip_continual_runs: Run indices whose GPISF-continual file is not
            loaded. The default (1,) reproduces the previous hard-coded
            `run != 1` check. Presumably that run's file is missing or
            unusable - confirm.
    """
    if name is None:
        name = 'nbases_{}'.format(task)

    def load(filename):
        data_path = '{}/{}'.format(path, filename)
        # NOTE(review): torch.load unpickles the file - only use it on trusted
        # result files.
        d = torch.load(data_path)[data_type]
        print(len(d), data_path)
        return d[:episodes]

    def stack(rows):
        # Trim every run to the shortest one so the result is a rectangular
        # (runs, episodes) array rather than a ragged/object array.
        n = min((len(r) for r in rows), default=0)
        return np.asarray([r[:n] for r in rows])

    # label -> function mapping a run index to its result file name.
    sources = {
        "SOPGOL": lambda run: 'exp2_{}_{}_{}_{}.h5'.format('sop', 'None', task, run),
        "SOPGOL transfer": lambda run: 'exp2_{}_{}_{}_{}.h5'.format('sop', 'transfer', task, run),
        "GPISF transfer": lambda run: 'exp2_{}_{}_{}_{}.h5'.format('sf', 'transfer', task, run),
        "GPISF continual": lambda run: 'exp2_{}_{}_{}_{}.h5'.format('sf', 'continual', task, run),
        'DQN': lambda run: 'exp_{}_{}_{}.h5'.format('dqn', task, run),
    }

    curves = {label: [] for label in sources}
    for run in range(runs):
        for label, filename in sources.items():
            if label == "GPISF continual" and run in skip_continual_runs:
                continue
            curves[label].append(load(filename(run)))

    data = [(stack(rows), label) for label, rows in curves.items()]
    print(data[0][0].shape, len(data[0][0]))

    plotdata(process_data(data), name)

#####################################################################################

if __name__ == "__main__":
    # Task 0, first 6000 episodes, 3 runs.
    plot1(0, 6000, 3)