import numpy as np
import matplotlib.pyplot as plt
import os
import re
import seaborn as sns
# Apply seaborn's "white" theme globally before any figure is drawn.
sns.set_theme(style="white")

# Each matching log line looks like "CURRENT VAL LOSS: <number>". The whole
# number is captured (escaped decimal point, optional sign/exponent) and parsed
# with float(), so values with any number of fractional digits are read as
# printed instead of being reassembled from separate integer/fraction pieces.
LOSS_PATTERN = re.compile(r"CURRENT VAL LOSS: ([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")

# Number of epochs shown on the x-axis; each log holds one loss value per epoch.
N_EPOCHS = 150


def parse_losses(lines):
    """Return the validation losses found in *lines*, in order.

    Only the first ``CURRENT VAL LOSS:`` occurrence on a line is used, so each
    matching line contributes exactly one value.
    """
    losses = []
    for line in lines:
        match = LOSS_PATTERN.search(line)
        if match is not None:
            losses.append(float(match.group(1)))
    return losses


def load_rewards(directory, n_epochs=N_EPOCHS, *, verbose=True):
    """Load one row of rewards (negated validation losses) per run log.

    Every regular file in *directory* whose name contains ``"out"`` is treated
    as one run. Files are read in sorted name order so row order is stable.

    Args:
        directory: folder holding the run logs.
        n_epochs: number of columns in the result.
        verbose: print each directory entry and each parsed value.

    Returns:
        float array of shape ``(n_runs, n_epochs)``. Epochs a run did not log
        are NaN (not 0) so they do not bias the per-epoch mean/std.

    Raises:
        ValueError: if no run logs are found, or a log contains more than
            ``n_epochs`` loss values.
    """
    runs = []
    for name in sorted(os.listdir(directory)):
        if verbose:
            print(name)
        path = os.path.join(directory, name)
        if "out" not in name or not os.path.isfile(path):
            continue
        with open(path, encoding="utf-8", errors="replace") as f:
            losses = parse_losses(f)
        if len(losses) > n_epochs:
            raise ValueError(
                f"{path}: found {len(losses)} loss values, "
                f"expected at most {n_epochs}"
            )
        if verbose:
            for loss in losses:
                print(len(runs), loss)
        runs.append(losses)

    if not runs:
        raise ValueError(f"no '*out*' log files found in {directory!r}")

    rewards = np.full((len(runs), n_epochs), np.nan)
    for row, losses in zip(rewards, runs):
        # Reward is the negated loss.
        row[: len(losses)] = -np.asarray(losses, dtype=float)
    return rewards


def mean_and_std(rewards):
    """Return per-epoch (mean, std) across runs, ignoring NaN entries.

    Epochs where every run is missing come back as NaN, which matplotlib
    leaves as a gap.
    """
    masked = np.ma.masked_invalid(rewards)
    return (
        masked.mean(axis=0).filled(np.nan),
        masked.std(axis=0).filled(np.nan),
    )


def plot_rewards(curves, n_epochs=N_EPOCHS, *, ylim=(-10000, -5)):
    """Draw mean reward per epoch with a ±1 std band for each curve.

    Args:
        curves: iterable of ``(label, rewards)`` pairs, where *rewards* is an
            array as returned by :func:`load_rewards`.
        n_epochs: number of epochs on the x-axis.
        ylim: y-axis limits.

    Returns:
        The created matplotlib figure.
    """
    fig, ax = plt.subplots(1, 1)
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Training Reward (Log Scale)")
    # Rewards are negative, so a plain log axis cannot show them; symlog can.
    ax.set_yscale("symlog")

    epochs = np.arange(n_epochs)
    for label, rewards in curves:
        mean, std = mean_and_std(rewards)
        (line,) = ax.plot(epochs, mean, linewidth=1, label=label)
        # Pin the band to the line's colour instead of relying on the cycle.
        ax.fill_between(
            epochs, mean - std, mean + std, color=line.get_color(), alpha=0.1
        )

    ax.set_ylim(list(ylim))
    ax.set_xlim([0, n_epochs])
    ax.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=3)
    return fig


if __name__ == "__main__":
    y = load_rewards("./data/neg/")
    y_g = load_rewards("./data/gail/")
    print(y_g[:, -5:])
    y_w = load_rewards("./data/wgail/")

    # Plot order (and therefore colour order) matches the original figure.
    plot_rewards([("SS-GAIL", y), ("WAIL", y_w), ("GAIL", y_g)])
    plt.show()