import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

# Task whose reward models are compared below.  Earlier revisions loaded several
# other reward files first and immediately overwrote them; only the loads whose
# results are actually used are kept.
# NOTE: the PreferenceTransformer file loaded below is basketball-specific, so
# changing `name` alone does not switch every input.
name = 'metaworld_basketball-v2'

# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load files
# from a trusted source.

# Per-step rewards predicted by the learned reward ensemble.
learned_reward = np.load(
        '/data3/zj/optimal_transport_reward/rewards/ensemble_'+name+'_initial_pairs_5_num_queries_5_num_iter_20_retrain_num_iter_20_seed_287_round_num_9_50_all.npy',
        allow_pickle=True)

# Task id without the 'metaworld_' prefix, e.g. 'basketball-v2'.
short_name = name.split('_')[1]

# Ground-truth per-step rewards; the file stores a pickled dict with a
# 'rewards' entry.
true_reward = np.load(
        '/data3/zj/optimal_transport_reward/data/'+short_name+'/data_randgoal_08_50_08_batch.npy',
        allow_pickle=True).tolist()['rewards']

# Per-step rewards from the OPAL reward ensemble.
opal_reward = np.load(
        '/data3/zj/optimal_transport_reward/opal_rewards/rewards/ensemble_'+short_name+'_initial_pairs_1_num_queries_5_num_iter_20_retrain_num_iter_20_voi_dis_seed_285_round_num_9.npy',
        allow_pickle=True)

# Per-step rewards from the PreferenceTransformer baseline (30 queries).
pt_reward = np.load(
        '/data3/zj/optimal_transport_reward/learned_rewards/metaworld_basketball-v230queries.npy',
        allow_pickle=True)

# Sum the per-step rewards over consecutive 500-step windows to get one return
# per episode.
# Assumes every array is a concatenation of 1000 episodes x 500 steps -- TODO
# confirm against the data-generation code.
episode_slices = [slice(k * 500, (k + 1) * 500) for k in range(1000)]
learned_returns = [np.sum(learned_reward[window]) for window in episode_slices]
true_returns = [np.sum(true_reward[window]) for window in episode_slices]
opal_returns = [np.sum(opal_reward[window]) for window in episode_slices]
pt_returns = [np.sum(pt_reward[window]) for window in episode_slices]
# Pearson correlation between predicted and ground-truth episode returns for
# the OPAL and PreferenceTransformer reward models.
print(np.corrcoef(opal_returns,true_returns)[0,1],np.corrcoef(pt_returns,true_returns)[0,1],)

# Quick-look scatter plots (default styling) for the learned ensemble and OPAL.
plt.figure()
plt.scatter(learned_returns,true_returns)
plt.figure()
plt.scatter(opal_returns,true_returns)

# Publication-style figure for the PreferenceTransformer returns.
plt.figure(figsize=(8.5,6))
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names.  Use whichever name this installation provides;
# if neither is present, fall back to the original name so the original error
# still surfaces.
grid_style = next(
    (candidate
     for candidate in ('seaborn-whitegrid', 'seaborn-v0_8-whitegrid')
     if candidate in plt.style.available),
    'seaborn-whitegrid')
plt.style.use(grid_style)
plt.rc('font', family='Times New Roman')
plt.clf()
ax = plt.gca()
ax.xaxis.label.set_size(20)
ax.yaxis.label.set_size(20)
plt.tick_params(labelsize=17)
plt.scatter(pt_returns,true_returns)
plt.xlabel('Predicted Returns')
plt.ylabel('Ground-Truth Returns')
plt.title('Basketball-V2',fontdict={'fontsize':30})
plt.tight_layout()
plt.show()
def relabel_rewards(env, dataset, env_name, relabel='dense',
                    goal_radius=0.5, max_episode_steps=200):
    """Overwrite ``dataset['rewards']`` and ``dataset['terminals']`` from goal distance.

    The distance is measured between the first two dimensions of each
    successor observation (``observations[1:]``) and the environment's goal,
    so entry ``i`` of the relabelled arrays is derived from observation
    ``i + 1``.  The final entry has no successor; it gets reward 0 and is
    always marked terminal.

    Args:
        env: Environment exposing ``target_goal`` (used when ``'antmaze'`` is
            in ``env_name``) or ``goal_locations`` (first entry used otherwise).
        dataset: Dict with ``'observations'`` and, when the rewards are not
            relabelled, ``'rewards'``.  It is modified in place.
            Assumes ``observations[:, :2]`` is the agent position -- TODO confirm.
        env_name: Environment id; selects the goal attribute and, for names
            containing ``'maze2d'``, enables the episode-timeout terminals.
        relabel: ``'dense'`` for ``exp(-distance)``, ``'sparse'`` for a 0/1
            indicator of being within ``goal_radius``; any other value keeps
            the rewards already stored in ``dataset``.
        goal_radius: Distance at or below which a state counts as at the goal.
        max_episode_steps: For maze2d only, number of consecutive non-terminal
            steps after which a terminal is inserted.

    Returns:
        The same ``dataset`` dict, with ``'rewards'`` and ``'terminals'`` replaced.
    """
    target_goal = env.target_goal if 'antmaze' in env_name else env.goal_locations[0]
    print ('Target Goal: ', target_goal)

    all_obs = dataset['observations'][:]

    # Distance of every successor state s' to the goal (one entry fewer than
    # the number of observations).
    goal_dist = np.linalg.norm(all_obs[1:, :2] - target_goal, axis=1)
    reached_goal = (goal_dist <= goal_radius).astype(np.float32)

    if relabel == 'dense':
        # Reward at the next state: exp(-dist(s', g)); pad the last step with 0.
        _rew = np.concatenate([np.exp(-goal_dist), np.array([0])], 0)
    elif relabel == 'sparse':
        _rew = np.concatenate([reached_goal, np.array([0])], 0)
    else:
        # Stored rewards already have one entry per transition.  Padding them
        # (as the relabelled cases need) would leave one more reward than
        # terminals, so they are kept as-is.
        _rew = dataset['rewards'][:]

    # Terminal whenever the successor state is within the goal radius.  The
    # concatenation creates a fresh array, so the in-place edits below never
    # touch the sparse rewards.
    _terminals = np.concatenate([reached_goal, np.array([0])], 0)

    if "maze2d" in env_name:
        # Insert a timeout terminal after every `max_episode_steps`
        # consecutive non-terminal steps.
        steps_since_terminal = 0
        for i in range(len(_terminals)):
            if _terminals[i]:
                steps_since_terminal = 0
                continue
            steps_since_terminal += 1
            if steps_since_terminal >= max_episode_steps:
                _terminals[i] = 1.0
                steps_since_terminal = 0

    # The dataset always ends on a terminal.
    _terminals[-1] = 1
    print ('Sum of rewards: ', _rew.sum())
    print ('Sum of terminals: ', _terminals.sum())
    dataset['rewards'] = _rew
    dataset['terminals'] = _terminals
    return dataset

# Compare PreferenceTransformer's learned per-step rewards on
# antmaze-large-diverse-v2 against the dense goal-distance reward.
learned_reward = np.load('/data3/zj/PreferenceTransformer/learned_rewards/antmaze-large-diverse-v2100queries.npy',allow_pickle=True)
import d4rl
import gym
env = gym.make('antmaze-large-diverse-v2')
dataset = d4rl.qlearning_dataset(env)
dataset = relabel_rewards(env, dataset, 'antmaze-large-diverse-v2')
true_reward = dataset['rewards']
# Sanity-check that both arrays line up before plotting.  This is printed
# after `true_reward` is refreshed so it no longer reports the shape of the
# array left over from the previous section.
print(learned_reward.shape,true_reward.shape)
plt.figure()
plt.scatter(learned_reward,true_reward)
plt.show()

def analyse_rewards(name, data_root='/data3/zj/optimal_transport_reward',
                    num_episodes=1000, episode_len=500, num_lowest=300):
    """Compare true returns of the episodes ranked lowest by the learned reward.

    Loads the learned per-step rewards and the ground-truth per-step rewards
    for one task, sums both over consecutive ``episode_len``-step windows,
    sorts the episodes by learned return in ascending order, and averages the
    true return over the first ``num_lowest`` of them.  A summary line with
    the task name and two ratios is printed.

    NOTE(review): the original variable name ``last_300_avg`` suggests a
    "last 300" selection, but ``argsort`` is ascending and the slice takes the
    first entries, i.e. the episodes with the LOWEST learned return.  If the
    highest-ranked episodes were intended, the slice should be
    ``[-num_lowest:]`` -- confirm with the author.

    Args:
        name: Task id of the form ``'<prefix>_<task>'``; the part after the
            first underscore selects the ground-truth data directory.
        data_root: Directory containing the ``rewards/`` and ``data/`` trees.
        num_episodes: Number of consecutive windows to sum.
        episode_len: Number of steps per window.
        num_lowest: How many of the lowest-ranked episodes to average.

    Returns:
        Tuple ``(lowest_avg, overall_avg)``: mean true return over the selected
        episodes, and mean true return over all episodes.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only point
    # `data_root` at trusted files.
    learned_reward = np.load(
        data_root + '/rewards/ensemble_' + name + '_initial_pairs_5_num_queries_5_num_iter_20_retrain_num_iter_20_seed_287_round_num_9_50_all.npy',
        allow_pickle=True)
    short_name = name.split('_')[1]
    # The ground-truth file stores a pickled dict with a 'rewards' entry.
    true_reward = np.load(
        data_root + '/data/' + short_name + '/data_randgoal_08_200_08_batch.npy',
        allow_pickle=True).tolist()['rewards']

    # One return per episode: sum of rewards over each fixed-length window.
    windows = [slice(k * episode_len, (k + 1) * episode_len)
               for k in range(num_episodes)]
    learned_returns = np.array([np.sum(learned_reward[w]) for w in windows])
    true_returns = np.array([np.sum(true_reward[w]) for w in windows])

    # Ascending order: the first `num_lowest` indices are the episodes the
    # learned reward ranks lowest.
    argsorted = np.argsort(learned_returns)
    last_300_avg = np.mean(true_returns[argsorted[:num_lowest]])
    remaining_avg = np.mean(true_returns[argsorted[num_lowest:]])
    overall_avg = np.mean(true_returns)
    print(short_name,last_300_avg/remaining_avg,last_300_avg/overall_avg)
    return last_300_avg,overall_avg




# Metaworld tasks to analyse, in the order their summary lines are printed.
names = [
    "metaworld_assembly-v2",
    "metaworld_basketball-v2",
    "metaworld_bin-picking-v2",
    "metaworld_box-close-v2",
    "metaworld_button-press-topdown-v2",
    "metaworld_button-press-v2",
    "metaworld_coffee-button-v2",
    "metaworld_coffee-pull-v2",
    "metaworld_coffee-push-v2",
    "metaworld_disassemble-v2",
    "metaworld_door-close-v2",
    "metaworld_door-lock-v2",
    "metaworld_door-open-v2",
    "metaworld_door-unlock-v2",
    "metaworld_drawer-close-v2",
    "metaworld_drawer-open-v2",
    "metaworld_faucet-close-v2",
    "metaworld_faucet-open-v2",
    "metaworld_hammer-v2",
    "metaworld_hand-insert-v2",
    "metaworld_handle-press-side-v2",
    "metaworld_handle-press-v2",
    "metaworld_handle-pull-side-v2",
    "metaworld_handle-pull-v2",
    "metaworld_lever-pull-v2",
    "metaworld_peg-insert-side-v2",
    "metaworld_peg-unplug-side-v2",
    "metaworld_pick-out-of-hole-v2",
    "metaworld_pick-place-v2",
    "metaworld_pick-place-wall-v2",
    "metaworld_plate-slide-back-side-v2",
    "metaworld_plate-slide-back-v2",
    "metaworld_plate-slide-side-v2",
    "metaworld_plate-slide-v2",
    "metaworld_push-back-v2",
    "metaworld_push-v2",
    "metaworld_push-wall-v2",
    "metaworld_reach-v2",
    "metaworld_reach-wall-v2",
    "metaworld_soccer-v2",
    "metaworld_stick-push-v2",
    "metaworld_sweep-v2",
    "metaworld_sweep-into-v2",
    "metaworld_window-close-v2",
    "metaworld_window-open-v2",
]

# analyse_rewards returns a pair per task; collect the first and second
# elements into two parallel lists.
per_task_stats = [analyse_rewards(task_name) for task_name in names]
last_300_avg = [stats[0] for stats in per_task_stats]
all_avg = [stats[1] for stats in per_task_stats]

mean_of_last_300 = np.mean(last_300_avg)
mean_of_all = np.mean(all_avg)
# NOTE(review): the third value divides every per-task entry by the single
# cross-task mean, so it equals mean_of_last_300 / mean_of_all.  If a per-task
# ratio was intended, the divisor should be np.array(all_avg) -- confirm.
print(mean_of_last_300, mean_of_all, np.mean(np.array(last_300_avg) / mean_of_all))


# Second group of Metaworld tasks; not used within this part of the file.
names2 = [
    "metaworld_button-press-topdown-wall-v2",
    "metaworld_dial-turn-v2",
    "metaworld_button-press-wall-v2",
    "metaworld_shelf-place-v2",
    "metaworld_stick-pull-v2",
]