import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import MultipleLocator
from matplotlib.transforms import Bbox

class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``cfg.task = "x"`` is equivalent to ``cfg["task"] = "x"`` and ``cfg.task``
    to ``cfg["task"]``. A missing key raises ``AttributeError`` on attribute
    access (so ``hasattr``/``getattr(obj, name, default)`` and ``copy`` probes
    behave normally) and ``KeyError`` on item access, as for a plain dict.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so dict methods such
        # as ``items`` still resolve normally.
        try:
            return self[name]
        except KeyError:
            # Attribute protocol requires AttributeError, not KeyError.
            raise AttributeError(name) from None

def format_num(x, y):
    """Tick-label callback for ``FuncFormatter``: render ``x * 5`` with a 'k' suffix.

    ``y`` is the tick position supplied by matplotlib and is ignored. The
    scaled value is rounded to zero decimals (e.g. 10 -> '50k'). Since the
    x axis is labelled 'Environment Steps', each x unit presumably stands for
    5k environment steps.
    """
    thousands = x * 5
    return f"{thousands:.0f}k"

def define_config():
    """Return the default settings as an ``AttrDict``.

    In ``__main__`` every key becomes a ``--<key>`` command-line option whose
    type and default come from the value here. The ``*_dir`` entries are the
    per-method directory names passed to ``data_helper``; ``task`` selects
    the task sub-directory.
    """
    return AttrDict(
        noda_dreamer_dir="NODA-Dreamer",
        noda_bird_dir="NODA-BIRD",
        dreamer_dir="Dreamer",
        bird_dir="BIRD",
        task="dmc_walker_walk",
    )

def _read_seed_curve(path, num_points, window):
    """Read one seed's log file and return a length-``num_points`` curve of window averages.

    Each non-blank line contributes one value: the second ':'-separated piece
    of the line's second comma-separated field, parsed as a float (for a line
    like ``{"step": 0, "return": 12.3, ...}`` that is 12.3). Consecutive groups
    of ``window`` values are averaged into one curve point. A trailing partial
    group is averaged over the values it actually has; points with no data
    stay 0.

    Raises:
        ValueError: if the file holds more than ``num_points * window`` values.
    """
    curve = np.zeros(num_points)
    counts = np.zeros(num_points, dtype=int)
    idx = 0
    with open(path, 'r', encoding='utf-8') as load_f:
        for line in load_f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no data.
                continue
            bucket = idx // window
            if bucket >= num_points:
                raise ValueError(
                    "{} has more than {} values (num_points={} x window={})".format(
                        path, num_points * window, num_points, window))
            # Sequential accumulation then division keeps full windows
            # bit-identical to a plain running sum divided by ``window``.
            curve[bucket] += float(line.split(",")[1].split(":")[1])
            counts[bucket] += 1
            idx += 1
    filled = counts > 0
    curve[filled] /= counts[filled]
    return curve


def data_helper(task, noda_dreamer_dir, noda_bird_dir, dreamer_dir, bird_dir,
                data_root="../example_data", num_seeds=4, num_points=40, window=10):
    """Load every seed's curve for four methods and return per-point mean and std.

    For each method directory, reads
    ``<data_root>/<method_dir>/<task>/seed_<j>.json`` for ``j`` in
    ``range(num_seeds)`` (see ``_read_seed_curve`` for the line format and
    windowing). It then takes the mean and population std (``ddof=0``) across
    seeds at each of the ``num_points`` points.

    Args:
        task: task sub-directory name (e.g. ``"dmc_walker_walk"``).
        noda_dreamer_dir, noda_bird_dir, dreamer_dir, bird_dir: per-method
            directory names under ``data_root``.
        data_root: root directory of the logs (relative paths resolve against
            the current working directory).
        num_seeds: number of seed files per method.
        num_points: number of curve points per seed.
        window: number of log lines averaged into one curve point.

    Returns:
        8-tuple of 1-D arrays of length ``num_points``, in the order
        (noda_dreamer_mean, noda_dreamer_std, noda_bird_mean, noda_bird_std,
        dreamer_mean, dreamer_std, bird_mean, bird_std).

    Raises:
        ValueError: for non-positive sizes or a log with too many values.
        OSError: if a seed file cannot be opened.
    """
    if num_seeds < 1 or num_points < 1 or window < 1:
        raise ValueError("num_seeds, num_points and window must all be >= 1")

    stats = []
    for run_dir in (noda_dreamer_dir, noda_bird_dir, dreamer_dir, bird_dir):
        curves = np.stack([
            _read_seed_curve(
                os.path.join(data_root, run_dir, task, "seed_{}.json".format(seed)),
                num_points, window)
            for seed in range(num_seeds)
        ])
        stats.append(np.mean(curves, axis=0))
        stats.append(np.std(curves, axis=0))
    return tuple(stats)

def plot_helper(task, ax, results_noda_dreamer_mean, results_noda_dreamer_std, results_noda_bird_mean, results_noda_bird_std, results_dreamer_mean, results_dreamer_std, results_bird_mean, results_bird_std, asymp_line):
    """Plot the four methods' mean +/- std learning curves for one task on ``ax``.

    Curves are drawn in the order Dreamer, BIRD, NODA-Dreamer, NODA-BIRD. Each
    is a solid line with a shaded mean +/- std band. They are followed by a
    dashed horizontal line at ``asymp_line``. The x axis runs from 1 to the
    curve length, with major ticks every 10 points labelled by ``format_num``.

    Args:
        task: task id such as ``"dmc_walker_walk"``; every '_'-separated part
            after the first is capitalized into the title (-> "Walker Walk").
        ax: matplotlib Axes to draw on.
        results_*_mean / results_*_std: per-point mean and std for each
            method (array-like, one entry per x position).
        asymp_line: y value of the dashed reference line.

    Returns:
        List of the four mean-curve ``Line2D`` handles in plot order, suitable
        for an explicit ``legend(handles, labels)`` call.
    """
    colors = ['#1f77b4', '#2ca02c', '#d62728', 'orange']
    mean_data = [results_dreamer_mean, results_bird_mean, results_noda_dreamer_mean, results_noda_bird_mean]
    std_data = [results_dreamer_std, results_bird_std, results_noda_dreamer_std, results_noda_bird_std]

    handles = []
    for color, mean, std in zip(colors, mean_data, std_data):
        mean = np.asarray(mean, dtype=float)
        std = np.asarray(std, dtype=float)
        x = np.arange(1, len(mean) + 1)
        line, = ax.plot(x, mean, linewidth=3.5, color=color)
        ax.fill_between(x, mean - std, mean + std, alpha=0.3, color=color)
        handles.append(line)

    num_points = len(mean_data[0])
    ax.plot(np.arange(1, num_points + 1), np.full(num_points, asymp_line),
            linewidth=3.5, color='#9467bd', linestyle='--')

    ax.set_xlabel('Environment Steps')
    ax.set_ylabel("Episode Return")
    ax.set_title(' '.join(part.capitalize() for part in task.split('_')[1:]))
    ax.xaxis.set_major_formatter(FuncFormatter(format_num))
    ax.xaxis.set_major_locator(MultipleLocator(10))

    # Global style settings. rcParams only affect artists created after they
    # are set, so they influence later text and figures, not what was drawn above.
    plt.rcParams['font.sans-serif'] = ['Times New Roman']
    plt.rcParams.update({'figure.autolayout': True})
    plt.rc('font', size=23)
    # Grid on this axes explicitly (plt.grid() acts on whatever axes is current).
    ax.grid(True)
    return handles

if __name__ == '__main__':
    # Reference return drawn as the dashed asymptote line, per task.
    asymptotes = {
        "dmc_hopper_stand": 923.72,
        "dmc_walker_walk": 961.67,
        "dmc_walker_run": 824.67,
    }

    parser = argparse.ArgumentParser()
    for key, value in define_config().items():
        parser.add_argument(f'--{key}', type=type(value), default=value)
    parser.add_argument('--asymp_line', type=float, default=None,
                        help='y value of the dashed reference line (default: per-task value)')
    config = parser.parse_args()

    asymp_line = config.asymp_line
    if asymp_line is None:
        if config.task not in asymptotes:
            parser.error(f"no default asymptote for task {config.task!r}; pass --asymp_line")
        asymp_line = asymptotes[config.task]

    fig = plt.figure(figsize=(10, 6))
    ax1 = fig.add_subplot(1, 1, 1)
    noda_dreamer_mean, noda_dreamer_std, noda_bird_mean, noda_bird_std, dreamer_mean, dreamer_std, bird_mean, bird_std = data_helper(
        config.task, config.noda_dreamer_dir, config.noda_bird_dir, config.dreamer_dir, config.bird_dir)
    # Use the same task for the title/asymptote as for the loaded data.
    plot_helper(config.task, ax1, noda_dreamer_mean, noda_dreamer_std, noda_bird_mean, noda_bird_std,
                dreamer_mean, dreamer_std, bird_mean, bird_std, asymp_line)

    # Three-panel paper layout (formerly kept here as commented-out code):
    # figsize=(21, 5), one add_subplot(1, 3, i) per task for dmc_hopper_stand,
    # dmc_walker_walk and dmc_walker_run, saved with
    # bbox_inches=Bbox([[0, -0.5], [21, 5]]).

    # Pass the four mean-curve lines explicitly. With labels only, matplotlib
    # pairs labels with the axes' artists in creation order, which in newer
    # versions interleaves the fill_between bands with the lines.
    fig.legend(ax1.lines[:4], ["Dreamer", "BIRD", "NODA-Dreamer", "NODA-BIRD"],
               loc="lower center", ncol=4, bbox_to_anchor=(0.5, -0.1))
    plt.rcParams['font.sans-serif'] = ['Times New Roman']
    plt.rcParams.update({'figure.autolayout': True})
    plt.rc('font', size=23)
    # Save before show(): once the interactive window is closed the figure may
    # be torn down, producing an empty file.
    fig.savefig(os.path.join('../', 'Results.pdf'), bbox_inches=Bbox([[-1, -0.5], [11, 6]]))
    plt.show()
