import os

import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.io import loadmat

# --- Input -------------------------------------------------------------------
# Root folder of the runs; each run is read from log_dir + <run> + '/logs.mat'.
log_dir = 'log/'

# Run folder names to compare (one curve per entry).
# Previously used: '2024-06-10--11:24:03'
# PATH = ['2024-08-05--17:33:36','2024-07-30--20:32:31']
PATH = [
    '2024-08-23--02:51:02',
    '2024-08-23--02:51:22',
]

# Legend label for each run, index-aligned with PATH.
# Previously used: 'box constraint 0.9', 'box constraint 0.7'
# label_list = ['expert lead','no expert lead']
label_list = [
    'beta 0',
    'beta 0.02',
]

# Empty placeholders; not referenced by the plotting code in this script.
expert_obs, expert_act = [], []

# --- Output ------------------------------------------------------------------
# Figures are written to saved_dir + prefix + '_<kind>' + suffix + '.png'.
saved_dir = 'compare_fig/hopper/'
prefix = 'hopper_expert_actions_early_terminate_and_5'
suffix = ''
# Reward curves: one line per run, read from the 'returns' entry of its logs.mat.
for run_idx, run_name in enumerate(PATH):
    logs = loadmat(log_dir + run_name + '/logs.mat')
    # Take the first row of the stored 'returns' array as the per-episode series.
    returns = np.array(logs['returns'])[0]
    episodes = np.arange(returns.shape[0])
    run_label = label_list[run_idx]
    plt.plot(episodes, returns, label=run_label)
    # Summary for the console: mean of the last (up to) 20 entries.
    print(run_label, ": ", returns[-20:].mean())

# for idxa, path, in enumerate(PATH):
#     vgg = loadmat(log_dir+path+'/logs.mat')
#     kk_rew = [0,0,0,0,0]
#     kk_dis = [0,0,0,0,0]
#     kk_idx = [0,0,0,0,0]
#     for idx, num in enumerate(vgg['returns'][0]):
#         if kk_rew[idx%5]<num:
#             kk_rew[idx%5]=num
#             kk_idx[idx%5]=idx
#             kk_dis[idx%5]=vgg['dtws'][0][idx]
#     print(label_list[idxa])
#     print(kk_idx)
#     print(kk_rew)
    
#     print(kk_dis)
#     for i in kk_idx:
#         print(vgg['observations'][0][i].shape)
#         print()
# print((kk_rew[0]+kk_rew[1]+kk_rew[2]+kk_rew[3]+kk_rew[4])/5)
# print((kk_dis[0]+kk_dis[1]+kk_dis[2]+kk_dis[3]+kk_dis[4])/5)
# quit()
# plt.axhline(y=5097, xmin=0, xmax=1,color='purple',label='expert trajectory')
# plt.axhline(y=1708, xmin=0, xmax=1,color='black',label='BC+projection')
# plt.axhline(y=-5.438, xmin=0, xmax=1,color='purple',label='expert trajectory')
# Finalize the reward figure (title, axis labels, legend) and write it to disk.
plt.title("PETS_DTWIL reward curve")  # title
plt.ylabel("Return")  # y label
plt.xlabel("Episodes")  # x label
plt.legend(
    loc='best',
    fontsize=10,
    shadow=True,
    facecolor='white',
    edgecolor='black',
)
# savefig does not create missing directories, so make sure the target
# folder exists first (skipped when saved_dir is empty, i.e. the cwd).
if saved_dir:
    os.makedirs(saved_dir, exist_ok=True)
plt.savefig(saved_dir + prefix + '_reward' + suffix + '.png')
# Clear the current figure so later plotting does not draw on top of it.
plt.clf()

# DTW-distance curves: one line per run, read from the 'dtws' entry of its
# logs.mat, then saved as a separate figure.
for idx, path in enumerate(PATH):
    logs = loadmat(log_dir + path + '/logs.mat')
    # First row of 'dtws', with its first entry skipped.
    # NOTE(review): because of the [1:] slice, x=0 here corresponds to the
    # second stored entry, so the x axis is offset by one relative to the
    # reward figure -- confirm this is intended.
    y = np.array(logs["dtws"])[0][1:]
    x = np.arange(y.shape[0])
    plt.plot(x, y, label=label_list[idx])
    # Summary for the console: mean of the last (up to) 20 plotted values.
    print(label_list[idx], " distance: ", y[-20:].mean())

plt.title("PETS_DTWIL DTW distance curve")  # title
plt.ylabel("DTW distance")  # y label
plt.xlabel("Episodes")  # x label
plt.legend(
    loc='best',
    fontsize=10,
    shadow=True,
    facecolor='white',
    edgecolor='black',
)
# savefig does not create missing directories, so make sure the target
# folder exists first (skipped when saved_dir is empty, i.e. the cwd).
if saved_dir:
    os.makedirs(saved_dir, exist_ok=True)
plt.savefig(saved_dir + prefix + '_distance' + suffix + '.png')
# Clear the current figure so later plotting does not draw on top of it.
plt.clf()

