import os
from typing import Tuple
import gym
import numpy as np
import tqdm
from absl import app, flags
from tensorboardX import SummaryWriter
import warnings

import argparse
import torch
from vae import VAE
import utils
import d4rl
from tqdm import tqdm
import torch.nn as nn

import torch.nn.functional as F
from coolname import generate_slug
import json
from utils import get_lr
import tree

warnings.filterwarnings('ignore')
# Example invocation:
# CUDA_VISIBLE_DEVICES=2 python traj_vae_otr.py --env antmaze --no_normalize --seed=6 --dataset=medium-diverse --lambda_loss=0.3


parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
# dataset
parser.add_argument('--env', type=str, default='hopper')
# weight of the expert-latent variance penalties in the VAE loss
parser.add_argument('--lambda_loss', type=float, default=1.0)
parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
parser.add_argument('--version', type=str, default='v2')
# number of top-scoring trajectories selected as expert demonstrations
parser.add_argument('--k', type=int, default=10)
# model
parser.add_argument('--model', default='VAE', type=str)
parser.add_argument('--hidden_dim', type=int, default=512) 
# weight of the KL term in the VAE loss
parser.add_argument('--beta', type=float, default=0.5)
# train
parser.add_argument('--num_iters', type=int, default=int(1e5))
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', default=0.0001, type=float)
parser.add_argument('--scheduler', default=False, action='store_true')
# decay factor for the optional ExponentialLR scheduler
parser.add_argument('--gamma', default=0.95, type=float)
parser.add_argument('--no_max_action', default=False, action='store_true')
parser.add_argument('--clip_to_eps', default=False, action='store_true')
parser.add_argument('--eps', default=1e-4, type=float)
parser.add_argument('--latent_dim', default=None, type=int, help="default: action_dim * 2")
parser.add_argument('--no_normalize', default=False, action='store_true', help="do not normalize states")
args = parser.parse_args()

# Seed the RNGs used for batch sampling (numpy) and for the VAE (torch) so
# that --seed actually makes runs reproducible.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Prefer the GPU, but still run on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Group a flat stream of transitions into a list of trajectories.

    Every index ``t`` contributes the tuple
    ``(observations[t], actions[t], rewards[t], masks[t], dones_float[t],
    next_observations[t])`` to the current trajectory. A new (empty)
    trajectory is opened after each index whose ``dones_float`` equals 1.0,
    except after the very last index, so no empty trailing trajectory is
    produced. Empty input yields ``[[]]``.

    Returns:
        list of lists of transition tuples.
    """
    num_steps = len(observations)
    episodes = [[]]

    for t in tqdm(range(num_steps)):
        transition = (observations[t], actions[t], rewards[t], masks[t],
                      dones_float[t], next_observations[t])
        episodes[-1].append(transition)

        episode_ended = dones_float[t] == 1.0
        has_more_steps = t + 1 < num_steps
        if episode_ended and has_more_steps:
            episodes.append([])

    return episodes

def merge_trajectories(trajs):
    """Flatten a list of trajectories and stack every transition field.

    Args:
        trajs: iterable of trajectories, each an iterable of transitions that
            all share one nested structure (e.g. the tuples built by
            ``split_into_trajectories``).

    Returns:
        One structure shaped like a single transition, in which every leaf is
        ``np.stack`` of that leaf across all transitions, in order.
    """
    all_transitions = [step for episode in trajs for step in episode]
    return tree.map_structure(lambda *leaves: np.stack(leaves), *all_transitions)

def qlearning_dataset_with_timeouts(env,
                                    dataset=None,
                                    terminate_on_end=False,
                                    disable_goal=True,
                                    **kwargs):
    """Convert a D4RL-style dataset into (s, a, r, s', done) arrays.

    Episode ends are detected from a change of ``infos/goal`` between
    consecutive steps when that key is present, and from ``timeouts``
    otherwise. Only the first ``N - 1`` steps are turned into transitions,
    because each transition needs the following observation.

    Args:
        env: object whose ``get_dataset(**kwargs)`` is called when
            ``dataset`` is None.
        dataset: optional pre-loaded dataset dict. When it contains
            ``infos/goal`` and ``disable_goal`` is False,
            ``dataset['observations']`` is replaced IN PLACE by the
            observations concatenated with the goals.
        terminate_on_end: if False, transitions at an episode end are
            dropped (their next observation belongs to another episode);
            if True they are kept and marked done.
        disable_goal: if False, append the goal to the observations.

    Returns:
        dict with 'observations', 'actions', 'next_observations', 'rewards',
        'terminals' (dataset terminal OR kept episode end, as bools) and
        'realterminals' (dataset terminals only, as bools).
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    has_goal = "infos/goal" in dataset
    if has_goal and not disable_goal:
        dataset["observations"] = np.concatenate(
            [dataset["observations"], dataset['infos/goal']], axis=1)

    obs_, next_obs_, action_, reward_, done_, realdone_ = [], [], [], [], [], []

    for i in range(N - 1):
        realdone_bool = bool(dataset['terminals'][i])
        if has_goal:
            # A goal switch between i and i + 1 means a new episode starts at i + 1.
            final_timestep = bool(
                (dataset['infos/goal'][i] != dataset['infos/goal'][i + 1]).any())
        else:
            final_timestep = bool(dataset['timeouts'][i])

        if final_timestep and not terminate_on_end:
            # Skip this transition and don't apply terminals on the last step of an episode.
            continue

        # Logical OR keeps this a bool; summing the two flags would give 2
        # when a terminal coincides with an episode end.
        done_bool = realdone_bool or final_timestep

        obs_.append(dataset['observations'][i])
        next_obs_.append(dataset['observations'][i + 1])
        action_.append(dataset['actions'][i])
        reward_.append(dataset['rewards'][i])
        done_.append(done_bool)
        realdone_.append(realdone_bool)

    return {
        'observations': np.array(obs_),
        'actions': np.array(action_),
        'next_observations': np.array(next_obs_),
        'rewards': np.array(reward_),
        'terminals': np.array(done_),
        'realterminals': np.array(realdone_),
    }

def load_trajectories(name: str, fix_antmaze_timeout=True):
    """Load the D4RL dataset for environment ``name`` split into trajectories.

    AntMaze datasets go through ``qlearning_dataset_with_timeouts`` (unless
    ``fix_antmaze_timeout`` is False); everything else uses
    ``d4rl.qlearning_dataset``. A step ends a trajectory when the next
    stored observation does not continue it (L2 gap > 1e-6) or when it is
    flagged terminal; the final step always ends a trajectory. Masks are
    ``1 - realterminals`` when available, otherwise ``1 - terminals``.

    Returns:
        list of trajectories as produced by ``split_into_trajectories``,
        with observations, actions, rewards and next observations as float32.
    """
    environment = gym.make(name)
    use_timeout_fix = "antmaze" in name and fix_antmaze_timeout
    data = (qlearning_dataset_with_timeouts(environment) if use_timeout_fix
            else d4rl.qlearning_dataset(environment))

    observations = data['observations']
    next_observations = data['next_observations']
    terminals = data['terminals']

    dones_float = np.zeros_like(data['rewards'])
    for idx in range(len(dones_float) - 1):
        gap = np.linalg.norm(observations[idx + 1] - next_observations[idx])
        dones_float[idx] = 1 if (gap > 1e-6 or terminals[idx] == 1.0) else 0
    dones_float[-1] = 1

    terminal_key = 'realterminals' if 'realterminals' in data else 'terminals'
    masks = 1.0 - data[terminal_key].astype(np.float32)

    return split_into_trajectories(
        observations=observations.astype(np.float32),
        actions=data['actions'].astype(np.float32),
        rewards=data['rewards'].astype(np.float32),
        masks=masks,
        dones_float=dones_float.astype(np.float32),
        next_observations=next_observations.astype(np.float32))

def compute_returns(traj):
    """Return the undiscounted sum of rewards of one trajectory.

    ``traj`` is a sequence of transition tuples whose element at index 2 is
    the reward. An empty trajectory returns 0.
    """
    return sum(transition[2] for transition in traj)

def compute_rewards_per_step(traj, mean_center):
    """Print latent-similarity statistics for goal-reaching vs. other trajectories.

    Every trajectory with more than one transition is scored as follows:
    each (state, action) pair is encoded with the module-level ``vae``, and
    the L2 distance between ``mean_center`` and the encoder mean is turned
    into a similarity ``exp(-lam * d)`` for ``lam`` in 1..10. The per-lambda
    similarities are averaged over the trajectory. Trajectories whose final
    transition has reward 1.0 are reported separately from the rest.

    Args:
        traj: list of trajectories (e.g. from ``load_trajectories``); each is
            a list of transition tuples (obs, action, reward, ...). The name
            is kept for backward compatibility.
        mean_center: reference point in latent space, a torch tensor on
            ``device``.

    Relies on the module-level globals ``vae`` and ``device``. Prints the
    per-lambda means and counts; returns None.
    """
    pdist = nn.PairwiseDistance(p=2)
    lambdas = range(1, 11)
    per_done = []
    per_ndone = []

    # Pure inference: no autograd graph is needed for the distances.
    with torch.no_grad():
        for episode in traj:
            if len(episode) <= 1:
                continue

            distances = []
            for transition in episode:
                train_states = torch.from_numpy(transition[0]).to(device)
                train_actions = torch.from_numpy(transition[1]).to(device)
                _, latent_mean, _ = vae(train_states, train_actions)
                distances.append(pdist(mean_center, latent_mean).item())
            distances = np.array(distances)

            per_lambda = [np.exp(-lam * distances).mean() for lam in lambdas]

            # Bucket by the reward of the final transition.
            if episode[-1][2] == 1.0:
                per_done.append(per_lambda)
            else:
                per_ndone.append(per_lambda)

    print(np.mean(np.array(per_ndone), axis=0))
    print('numbers of the ndone trajs:', len(per_ndone))
    print(np.mean(np.array(per_done), axis=0))
    print('numbers of the done trajs:', len(per_done))
 


# train vae
# Build the environment named by --env/--dataset/--version and read its sizes.
env_name = f"{args.env}-{args.dataset}-{args.version}"
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
# --no_max_action removes the action bound passed to the VAE.
max_action = None if args.no_max_action else float(env.action_space.high[0])
print('state_dim:', state_dim, 'action_dim:', action_dim, 'max_action:', max_action)
# Latent size defaults to twice the action dimension unless --latent_dim is set.
latent_dim = action_dim * 2 if args.latent_dim is None else args.latent_dim

# original dataset
# Load the offline dataset for the configured environment into the buffer.
replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
replay_buffer.convert_selfdata(
    f'../implicit_q_learning/datasets/datasets_cvae_full_split_1/{env_name}.hdf5')

# split expert from original dataset
# Score each trajectory by its total reward divided by the norm of the first
# two coordinates of its initial observation (presumably the start x/y
# position in AntMaze -- verify), then keep the top-k as demonstrations.
name = env_name
trajs = load_trajectories(name)
print('trajs numbers:', len(trajs))
returns = [sum(t[2] for t in traj) / (1e-4 + np.linalg.norm(traj[0][0][:2])) for traj in trajs]
# argpartition fails when k exceeds the number of trajectories.
k = min(args.k, len(trajs))
idx = np.argpartition(returns, -k)[-k:]
demo_returns = [returns[i] for i in idx]
print(f"demo returns {demo_returns}, mean {np.mean(demo_returns)}")
expert_demo = merge_trajectories([trajs[i] for i in idx])
print('length of selected expert demo:', len(expert_demo[0]))

# Optionally standardize states. The statistics get their own names so the
# VAE outputs (``mean``/``std``) in the training loop cannot clobber them, and
# the expert transitions that are mixed into every training batch go through
# the same transform as the replay-buffer states.
if not args.no_normalize:
    state_mean, state_std = replay_buffer.normalize_states()
    # NOTE(review): assumes the returned statistics broadcast against
    # (num_samples, state_dim) arrays -- confirm in utils.ReplayBuffer.
    expert_demo = list(expert_demo)
    for field in (0, 5):  # observations and next_observations
        expert_demo[field] = ((expert_demo[field] - state_mean) / state_std).astype(
            replay_buffer.state.dtype)
    expert_demo = tuple(expert_demo)
else:
    print("No normalize")
if args.clip_to_eps:
    replay_buffer.clip_to_eps(args.eps)
states = replay_buffer.state
actions = replay_buffer.action

# train
if args.model == 'VAE':
    vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=args.hidden_dim).to(device)
else:
    raise NotImplementedError
optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr, weight_decay=args.weight_decay)
if args.scheduler:
    # NOTE(review): scheduler.step() is never called by the training loop in
    # this file, so --scheduler currently has no effect.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.gamma)

total_size = states.shape[0]
batch_size = args.batch_size
lambda_loss = args.lambda_loss

# Each batch is (batch_size - num_expert_per_batch) replay-buffer samples
# followed by num_expert_per_batch expert samples. The variance penalties
# below act on exactly those trailing expert rows, for any --batch_size.
num_expert_per_batch = 5
expert_rows = torch.arange(batch_size - num_expert_per_batch, batch_size, device=device)

for step in tqdm(range(args.num_iters + 2), desc='train'):
    buffer_idx = np.random.choice(total_size, batch_size - num_expert_per_batch)
    expert_idx = np.random.choice(len(expert_demo[0]), num_expert_per_batch, replace=False)
    states_t = np.concatenate([states[buffer_idx], expert_demo[0][expert_idx]], axis=0)
    actions_t = np.concatenate([actions[buffer_idx], expert_demo[1][expert_idx]], axis=0)

    train_states = torch.from_numpy(states_t).to(device)
    train_actions = torch.from_numpy(actions_t).to(device)

    # Variational Auto-Encoder Training
    recon, mean, std = vae(train_states, train_actions)

    # Pull the expert samples' latent posteriors together by penalizing the
    # per-dimension variance of their means and stds.
    sub_std = torch.index_select(std, 0, expert_rows)
    sub_mean = torch.index_select(mean, 0, expert_rows)
    std_loss = torch.var(sub_std, 0, unbiased=False).mean()
    mean_loss = torch.var(sub_mean, 0, unbiased=False).mean()

    recon_loss = F.mse_loss(recon, train_actions)
    KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    vae_loss = recon_loss + args.beta * KL_loss + std_loss * lambda_loss + mean_loss * lambda_loss

    optimizer.zero_grad()
    vae_loss.backward()
    optimizer.step()

    # Checkpoint once, at the configured iteration count (reached because the
    # loop runs num_iters + 2 steps).
    if step == args.num_iters:
        os.makedirs('./models', exist_ok=True)
        torch.save(vae.state_dict(), './models/vae_model_%s_%s_%s_%s.pt' %
                   (args.env, args.dataset, lambda_loss, step))


# Latent "center" of the expert data: the mean (over all selected expert
# transitions) of the trained VAE's encoder mean and std. The environment
# dimensions are not recomputed here, so --no_max_action / --latent_dim stay
# in effect.
states_expert = expert_demo[0]
actions_expert = expert_demo[1]
with torch.no_grad():  # inference only
    train_states = torch.from_numpy(states_expert).to(device)
    train_actions = torch.from_numpy(actions_expert).to(device)
    _, mean_all, std_all = vae(train_states, train_actions)
    mean_center = torch.mean(mean_all, 0)
    std_center = torch.mean(std_all, 0)

# Optional diagnostics: compute_rewards_per_step(trajs, mean_center)
print(args.dataset, lambda_loss)
    