import argparse
import numpy as np
import torch
import gym
from vae import VAE
#from coolname import generate_slug
import utils
import d4rl
from tqdm import tqdm
import h5py
import torch.nn as nn
import tree

# CUDA_VISIBLE_DEVICES=3 python add_score_vae_otr.py --num_iters=1 --env antmaze --dataset umaze 

def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Group consecutive transitions into a list of trajectories.

    Every index ``i`` becomes the tuple
    ``(observations[i], actions[i], rewards[i], masks[i], dones_float[i],
    next_observations[i])`` appended to the current trajectory. A new
    trajectory is opened right after any step whose ``dones_float`` equals
    1.0, except when that step is the last one (so no empty trailing
    trajectory is created).

    Returns:
        A list of trajectories, each a list of transition tuples. With empty
        input the result is ``[[]]``.
    """
    total = len(observations)
    current = []
    trajectories = [current]
    for idx in tqdm(range(total)):
        current.append((observations[idx], actions[idx], rewards[idx],
                        masks[idx], dones_float[idx], next_observations[idx]))
        is_boundary = dones_float[idx] == 1.0
        has_more = idx + 1 < total
        if is_boundary and has_more:
            current = []
            trajectories.append(current)
    return trajectories

def merge_trajectories(trajs):
    """Flatten a list of trajectories and stack each transition field.

    ``trajs`` is a list of trajectories, each a list of transitions with an
    identical structure. The transitions are concatenated in order, then
    ``tree.map_structure`` stacks corresponding leaves with ``np.stack``.
    The result has the same structure as a single transition, with each
    leaf gaining a leading axis of length ``total number of transitions``.
    """
    transitions = [step for trajectory in trajs for step in trajectory]

    def _stack_leaves(*leaves):
        return np.stack(leaves)

    return tree.map_structure(_stack_leaves, *transitions)

def qlearning_dataset_with_timeouts(env,
                                    dataset=None,
                                    terminate_on_end=False,
                                    disable_goal=True,
                                    **kwargs):
    """Build (s, a, r, s', done) transitions, treating goal changes as episode ends.

    Episode ends are detected per step ``i``. When the dataset has an
    ``'infos/goal'`` entry, a step ends an episode if the goal at ``i + 1``
    differs from the goal at ``i``. Otherwise ``dataset['timeouts'][i]`` is
    used. The last row is never emitted because it has no next observation.

    Args:
        env: Object exposing ``get_dataset(**kwargs)``. Only used when
            ``dataset`` is None.
        dataset: Optional pre-loaded dict of arrays. It needs 'observations',
            'actions', 'rewards', 'terminals', and either 'infos/goal' or
            'timeouts'. It is not modified.
        terminate_on_end: If False, transitions on an episode-ending step are
            dropped. If True, they are kept and marked done.
        disable_goal: If False and 'infos/goal' is present, the goal is
            concatenated to each observation along axis 1.
        **kwargs: Forwarded to ``env.get_dataset``.

    Returns:
        A dict of numpy arrays with these keys:

        - 'observations', 'actions', 'next_observations', 'rewards'.
        - 'terminals': the dataset terminal flag OR the episode-end flag.
        - 'realterminals': the dataset terminal flag only.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_goal = "infos/goal" in dataset
    observations = dataset['observations']
    if has_goal and not disable_goal:
        # Build a new array rather than writing back into ``dataset`` so the
        # caller's dict is not mutated (repeat calls would re-append goals).
        observations = np.concatenate(
            [observations, dataset['infos/goal']], axis=1)

    n = dataset['rewards'].shape[0]
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    realdone_ = []

    for i in range(n - 1):
        realdone_bool = bool(dataset['terminals'][i])
        if has_goal:
            final_timestep = bool(
                (dataset['infos/goal'][i] != dataset['infos/goal'][i + 1]).any())
        else:
            final_timestep = bool(dataset['timeouts'][i])

        if final_timestep and not terminate_on_end:
            # Skip this transition; don't apply terminals on the last step of
            # an episode.
            continue

        obs_.append(observations[i])
        next_obs_.append(observations[i + 1])
        action_.append(dataset['actions'][i])
        reward_.append(dataset['rewards'][i])
        # Logical OR, not arithmetic: ``True + True`` would store 2 and break
        # downstream ``terminals == 1`` checks.
        done_.append(realdone_bool or final_timestep)
        realdone_.append(realdone_bool)

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'realterminals': np.array(realdone_),
    }

def load_trajectories(name: str, fix_antmaze_timeout=True):
    """Load the D4RL dataset for ``name`` and split it into trajectories.

    For antmaze tasks (when ``fix_antmaze_timeout`` is True) episode ends are
    recovered with ``qlearning_dataset_with_timeouts``. Otherwise
    ``d4rl.qlearning_dataset`` is used.

    A transition closes a trajectory in either of two cases:

    - the next row's observation does not continue from this row's
      next_observation (L2 gap > 1e-6);
    - the row is flagged terminal.

    The final row always closes a trajectory.

    Returns:
        A list of trajectories. Each is a list of tuples
        ``(observation, action, reward, mask, done_float, next_observation)``,
        where ``mask = 1 - realterminal`` (or ``1 - terminal`` when no
        'realterminals' key is present).
    """
    env = gym.make(name)
    if "antmaze" in name and fix_antmaze_timeout:
        dataset = qlearning_dataset_with_timeouts(env)
    else:
        dataset = d4rl.qlearning_dataset(env)

    observations = dataset['observations']
    next_observations = dataset['next_observations']
    terminals = dataset['terminals']
    dones_float = np.zeros_like(dataset['rewards'])
    n = len(dones_float)
    if n > 1:
        # Vectorised continuity check: per-row L2 norm of the flattened gap
        # between observations[i + 1] and next_observations[i].
        gap = (observations[1:] - next_observations[:-1]).reshape(n - 1, -1)
        discontinuous = np.linalg.norm(gap, axis=1) > 1e-6
        # ``!= 0`` treats any truthy terminal flag as an episode end.
        dones_float[:-1] = np.logical_or(discontinuous, terminals[:-1] != 0)
    if n > 0:
        dones_float[-1] = 1

    if 'realterminals' in dataset:
        masks = 1.0 - dataset['realterminals'].astype(np.float32)
    else:
        masks = 1.0 - dataset['terminals'].astype(np.float32)

    traj = split_into_trajectories(
        observations=observations.astype(np.float32),
        actions=dataset['actions'].astype(np.float32),
        rewards=dataset['rewards'].astype(np.float32),
        masks=masks,
        dones_float=dones_float.astype(np.float32),
        next_observations=next_observations.astype(np.float32))
    return traj

# Command-line options. --env and --dataset select the D4RL task
# "<env>-<dataset>-v2". --k is the number of top-return trajectories used as
# the expert demo, and it also selects the dataset directory.
# NOTE(review): --seed and --num_iters are not read anywhere in this script's
# visible code — confirm whether they are still needed.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ('--seed', int, 0),
    ('--env', str, 'hopper'),
    ('--dataset', str, 'medium'),  # medium, medium-replay, medium-expert, expert
    ('--num_iters', int, int(1e5)),
    ('--k', int, 10),
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()


# Use the GPU when available. Falling back to CPU lets the script run on
# CPU-only hosts instead of crashing on the first ``.to('cuda')``.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load data
env_name = f'{args.env}-{args.dataset}-v2'
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

print(state_dim, action_dim, max_action)
latent_dim = action_dim * 2

replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
replay_buffer.convert_selfdata(
    f'../implicit_q_learning/datasets/datasets_merge_full_split_{args.k}'
    f'_ablation/{args.env}-{args.dataset}-v2.hdf5')

states = replay_buffer.state
actions = replay_buffer.action

vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=512).to(device)
# map_location lets a checkpoint saved on GPU load on whichever device is in use.
# NOTE(review): the '_10_0_100000' suffix is hard-coded and does not follow
# --k / --seed / --num_iters — confirm this is intentional.
vae.load_state_dict(torch.load(
    f'./models/vae_model_{args.env}_{args.dataset}_10_0_100000.pt',
    map_location=device))

total_size = states.shape[0]

# Calculate the center point of the expert distribution: pick the k
# trajectories with the highest (normalised) return as the expert demo.
trajs = load_trajectories(env_name)
if "antmaze" in env_name:
    # Return divided by the norm of the first two components of the
    # trajectory's first observation (presumably the start x/y position);
    # 1e-4 guards against division by zero.
    returns = [sum([t[2] for t in traj]) / (1e-4 + np.linalg.norm(traj[0][0][:2])) for traj in trajs]
else:
    returns = [sum([t[2] for t in traj]) for traj in trajs]
# np.argpartition raises if k exceeds the number of trajectories, so clamp.
num_demos = min(args.k, len(returns))
idx = np.argpartition(returns, -num_demos)[-num_demos:]
demo_returns = [returns[i] for i in idx]
print(f"demo returns {demo_returns}, mean {np.mean(demo_returns)}")
expert_demo = merge_trajectories([trajs[i] for i in idx])
print('length of selected expert demo:', len(expert_demo[0]))

states_expert = expert_demo[0]
actions_expert = expert_demo[1]
# Average the VAE's per-transition latent mean/std over all expert
# transitions. This is inference only, so skip building an autograd graph.
with torch.no_grad():
    train_states = torch.from_numpy(states_expert).to(device)
    train_actions = torch.from_numpy(actions_expert).to(device)
    _, mean_all, std_all = vae(train_states, train_actions)
    mean = torch.mean(mean_all, 0)
    std = torch.mean(std_all, 0)


# Score each buffer transition by the L2 distance between its VAE latent mean
# and the expert center ``mean`` (smaller means closer to the expert demo).
# NOTE: an earlier, disabled variant scored samples with a per-dimension
# Gaussian KL between (mean, std) and each sample's (mean1, std1). The active
# score is the plain distance below.
#
# Transitions are scored in batches under no_grad rather than one VAE forward
# per row. The VAE already accepts 2-D batches (see the expert-center pass),
# and its latent mean is assumed to be computed row-independently — TODO
# confirm if the VAE architecture changes.
SCORE_BATCH_SIZE = 4096
# Loop-invariant; computes ||x1 - x2 + eps||_2 over the last dim (eps=1e-6).
pdist = nn.PairwiseDistance(p=2)
scores_ = []
with torch.no_grad():
    for start in tqdm(range(0, total_size, SCORE_BATCH_SIZE)):
        stop = min(start + SCORE_BATCH_SIZE, total_size)
        batch_states = torch.from_numpy(states[start:stop]).to(device)
        batch_actions = torch.from_numpy(actions[start:stop]).to(device)
        _, batch_mean, _ = vae(batch_states, batch_actions)
        # Keep the original argument order (expert center first): the eps
        # inside PairwiseDistance makes the distance order-sensitive.
        output = pdist(mean.expand_as(batch_mean), batch_mean)
        scores_.extend(output.cpu().tolist())
# Copy the source transitions into a sibling "-oriscores" HDF5 file and add
# the per-transition scores as a 'scores' dataset. The context managers close
# both files even if a read or write fails part-way.
# NOTE(review): len(scores_) equals replay_buffer.state.shape[0]; presumably
# this matches the row count of the source file — confirm.
_data_dir = f'../implicit_q_learning/datasets/datasets_merge_full_split_{args.k}_ablation/'
_src_path = f'{_data_dir}{args.env}-{args.dataset}-v2.hdf5'
_dst_path = f'{_data_dir}{args.env}-{args.dataset}-v2-oriscores.hdf5'
with h5py.File(_src_path, 'r') as f1, h5py.File(_dst_path, "w") as f2:
    for _key in ('observations', 'actions', 'next_observations', 'rewards', 'terminals'):
        # Reading with [:] already yields an in-memory numpy array.
        f2[_key] = f1[_key][:]
    f2['scores'] = np.array(scores_)
