import numpy as np
import matplotlib
import matplotlib.pyplot as plt
#%matplotlib inline
import csv
import pickle
import os
import colour
import torch
from rlkit.torch.networks import SnailEncoder,MlpEncoder
# config
exp_id = '2019_11_18_10_59_38'  # mine
# Task IDs to analyse: range(tlow, thigh), i.e. 100..129 (thigh is exclusive).
tlow, thigh = 100, 130
# See the `n_tasks` and `n_eval_tasks` args in the training config json.
# By convention, the test tasks are always the last `n_eval_tasks` IDs, and
# IDs start at 0 (tasks are loaded with range(tlow, thigh)), so with 100 tasks
# total and 20 test tasks the test tasks are IDs 80-99.
epoch = 361  # epoch number embedded in the trajectory file names

# Run directory and its eval-trajectory subdirectory (derived from one prefix
# so the two paths cannot drift apart).
# NOTE: `dir` shadows the builtin dir(); the name is kept so existing
# references keep working.
dir = './outputmetacure/cheetah-vel-sparse/{}/'.format(exp_id)
expdir = dir + 'eval_trajectories/'

# helpers
def load_pkl(task, run=0):
    """Load the pickled evaluation trajectories of one task.

    Reads ``task{task}-epoch{epoch}-run{run}.pkl`` from ``expdir``; both
    ``expdir`` and ``epoch`` are module-level config values.

    Args:
        task: absolute task ID used in the file name.
        run: run index used in the file name. Defaults to 0, the only run
            the helper could read before this parameter existed.

    Returns:
        The unpickled object. Callers in this script index it as a list of
        path dicts (keys such as 'goal', 'rewards', 'env_infos').
    """
    # NOTE: pickle.load can execute arbitrary code; only load files you produced.
    file_name = 'task{}-epoch{}-run{}.pkl'.format(task, epoch, run)
    with open(os.path.join(expdir, file_name), 'rb') as f:
        return pickle.load(f)

def load_pkl_prior():
    """Unpickle and return ``prior-epoch{epoch}.pkl`` from ``expdir``.

    ``expdir`` and ``epoch`` are read from the module-level config.
    """
    prior_name = 'prior-epoch{}.pkl'.format(epoch)
    prior_path = os.path.join(expdir, prior_name)
    with open(prior_path, 'rb') as handle:
        return pickle.load(handle)

# Load every evaluation task once and collect, per task, the goal of its first
# path and the 'rewards' entry of each of its paths (each pickle is read once
# instead of once for the goals and again for the rewards).
goals = []
all_paths_rew = []
for task in range(tlow, thigh):
    task_paths = load_pkl(task)
    goals.append(task_paths[0]['goal'])
    all_paths_rew.append([t['rewards'] for t in task_paths])

# Per-task latent statistics can be gathered in the same loop if needed, e.g.
#     all_paths_z_means.append([t['z_means'] for t in task_paths])
#     all_paths_z_vars.append([t['z_vars'] for t in task_paths])
# (an earlier disabled draft mistakenly appended the means to both lists).





# Inspect one evaluation task step by step. `task` is an offset into
# [tlow, thigh), so the file read is for absolute task ID tlow + task.
task = 12
ap = list(load_pkl(task + tlow))

# Per-path, per-step series (outer index = path, inner index = step); the
# plotting code at the end of the script reads these.
list_vel, list_goal, list_z_means, list_z_vars = [], [], [], []
for path_idx in range(2):
    path = ap[path_idx]
    vel_series, goal_series = [], []
    z_mean_series, z_var_series = [], []
    list_vel.append(vel_series)
    list_goal.append(goal_series)
    list_z_means.append(z_mean_series)
    list_z_vars.append(z_var_series)

    forward_sum = 0
    ctrl_sum = 0
    for step in range(64):
        info = path['env_infos'][step]
        forward_sum = forward_sum + info['reward_forward']
        print(info['velocity'], path['goal'])
        ctrl_sum = ctrl_sum + info['reward_ctrl']
        vel_series.append(info['velocity'])
        goal_series.append(path['goal'])
        z_mean_series.append(path['z_means'][step])
        z_var_series.append(path['z_vars'][step])
    # Totals of the 'reward_forward' and 'reward_ctrl' entries over 64 steps.
    print('n', forward_sum, ctrl_sum)

# Mean 'velocity' over the first 2 paths x 64 steps of every evaluation task,
# printed with the first path's goal and the task's offset from tlow.
# The task count comes from the config range instead of a hard-coded 30.
for task in range(thigh - tlow):
    ap = list(load_pkl(task + tlow))
    velocities = [ap[i]['env_infos'][j]['velocity']
                  for i in range(2) for j in range(64)]
    # Divide by the number of samples actually summed (128 here), so the
    # divisor cannot drift out of sync with the loop bounds.
    aver_vel = sum(velocities) / len(velocities)
    print('n', aver_vel, ap[0]['goal'], task)


# For each of the two inspected paths: per-step velocity (blue dots), the goal
# (solid red) and a +/-0.5 band around it (dotted red; presumably the
# sparse-reward radius -- confirm against the env). Selected points are
# annotated with the smallest entry of that step's z_vars.
plt.figure()
for i in range(2):
    plt.subplot(1, 2, i + 1)
    n_steps = len(list_vel[i])
    xs = np.linspace(1, n_steps, n_steps)  # 1-based step index
    # One scatter call per series instead of one artist per point.
    plt.scatter(xs, list_vel[i], c='b')
    # Annotate the first five steps, then every 8th step starting at index 9.
    label_steps = list(range(min(5, n_steps))) + list(range(9, n_steps, 8))
    for j in label_steps:
        plt.text(xs[j], list_vel[i][j], '%f' % np.min(list_z_vars[i][j]))
    goal = np.array(list_goal[i])
    plt.plot(xs, goal, 'r')
    plt.plot(xs, goal + 0.5, 'r:')
    plt.plot(xs, goal - 0.5, 'r:')

plt.show()