import os
from typing import Any, Tuple
import gym
import numpy as np
import tqdm
from absl import app, flags
from tensorboardX import SummaryWriter
import warnings
import h5py
import argparse
import torch
from vae import VAE
import utils
import d4rl
from tqdm import tqdm
import torch.nn as nn

import torch.nn.functional as F
from coolname import generate_slug
import json
from utils import get_lr
import tree

warnings.filterwarnings('ignore')
# CUDA_VISIBLE_DEVICES=1 python traj_vae_loop_loc_label_detach.py --env ant --dataset medium

def split_into_trajectories_scores(observations, actions, rewards, masks, dones_float,
                                   next_observations, scores):
    """Group flat per-step arrays into a list of trajectories.

    Every step ``t`` becomes the tuple
    ``(observations[t], actions[t], rewards[t], masks[t], dones_float[t],
    next_observations[t], scores[t])`` and is appended to the current
    trajectory.  A step whose ``dones_float`` equals 1.0 closes the current
    trajectory; a new (empty) one is opened unless that step was the last.

    Returns:
        List of trajectories, each a list of 7-tuples in step order.  The
        length of ``observations`` determines how many steps are read.
    """
    n_steps = len(observations)
    current = []
    trajectories = [current]

    for t in tqdm(range(n_steps), ncols=150):
        current.append((observations[t], actions[t], rewards[t], masks[t],
                        dones_float[t], next_observations[t], scores[t]))
        has_more_steps = t + 1 < n_steps
        if dones_float[t] == 1.0 and has_more_steps:
            current = []
            trajectories.append(current)

    return trajectories

def compute_iql_reward_scale(trajs):
    """Return the IQL-style reward scale ``1000 / (max_total - min_total)``.

    This mirrors the reward rescaling of the original IQL implementation, but
    the per-trajectory total summed here is the *score* field (element 6 of
    the step tuples built by ``split_into_trajectories_scores``), not the
    environment reward (element 2).

    Args:
        trajs: Non-empty sequence of trajectories; each trajectory is a
            sequence of step tuples whose element 6 is numeric.  The input is
            not modified.

    Returns:
        The reward scale.  If every trajectory has the same total, the
        denominator is zero: Python floats raise ZeroDivisionError, while
        NumPy scalars yield ``inf``.

    Raises:
        ValueError: If ``trajs`` is empty.
    """
    if not trajs:
        raise ValueError("compute_iql_reward_scale() needs at least one trajectory")

    # Sum each trajectory exactly once; min/max replaces copying + sorting the list.
    totals = [sum(step[6] for step in tr) for tr in trajs]
    return 1000.0 / (max(totals) - min(totals))

def compute_done(path=None):
    """Compute trajectory-end flags and bootstrap masks from an HDF5 dataset.

    Step ``i`` is marked done when the next stored observation does not
    continue from ``next_observations[i]`` (L2 gap > 1e-6) or when
    ``terminals[i] == 1.0``.  The final step is always marked done.

    Args:
        path: HDF5 file to read.  Defaults to
            ``'../scik/datasets/' + env_name + '.hdf5'`` using the module-level
            ``env_name``, which keeps the old no-argument call working.

    Returns:
        ``(dones_float, masks)``.  ``dones_float`` has the shape and dtype of
        the ``rewards`` dataset, with 1 at trajectory ends and 0 elsewhere.
        ``masks`` is ``1 - realterminals`` when the file has a
        ``realterminals`` dataset, otherwise ``1 - terminals`` (float32).
    """
    if path is None:
        path = '../scik/datasets/' + env_name + '.hdf5'

    # Context manager so the file handle is always released.
    with h5py.File(path, 'r') as f:
        observations = np.array(f['observations'][:])
        next_observations = np.array(f['next_observations'][:])
        rewards = np.array(f['rewards'][:])
        terminals = np.array(f['terminals'][:])
        # Read 'realterminals' from the file itself; the old code looked it up
        # in a dict that never contained it and raised KeyError.
        if 'realterminals' in f:
            mask_terminals = np.array(f['realterminals'][:])
        else:
            mask_terminals = terminals

    dones_float = np.zeros_like(rewards)
    for i in tqdm(range(len(dones_float) - 1)):
        gap = np.linalg.norm(observations[i + 1] - next_observations[i])
        if gap > 1e-6 or terminals[i] == 1.0:
            dones_float[i] = 1
    if len(dones_float) > 0:  # guard: indexing [-1] on an empty array raises
        dones_float[-1] = 1

    masks = 1.0 - mask_terminals.astype(np.float32)
    return dones_float, masks

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=6)
# dataset
parser.add_argument('--env', type=str, default='hopper')
# Weight of the expert latent-variance losses.  Default 0.5 matches the value
# the script previously hard-coded while silently ignoring this flag.
parser.add_argument('--lambda_loss', type=float, default=0.5)
parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
parser.add_argument('--version', type=str, default='v2')
parser.add_argument('--k', type=int, default=1)
parser.add_argument('--save_dir', type=str, default='./tmp/')
# model
parser.add_argument('--model', default='VAE', type=str)
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--beta', type=float, default=0.5)
# train
parser.add_argument('--num_iters', type=int, default=int(1e4))
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--weight_decay', default=0.0001, type=float)
parser.add_argument('--scheduler', default=False, action='store_true')
parser.add_argument('--gamma', default=0.95, type=float)
parser.add_argument('--no_max_action', default=False, action='store_true')
parser.add_argument('--clip_to_eps', default=False, action='store_true')
parser.add_argument('--eps', default=1e-4, type=float)
parser.add_argument('--latent_dim', default=None, type=int, help="default: action_dim * 2")
parser.add_argument('--no_normalize', default=False, action='store_true', help="do not normalize states")
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --seed used to be parsed but never applied.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

NUM_CLASSES = 100        # number of per-class expert files that get scored
N_EXPERT_PER_BATCH = 5   # expert samples appended at the END of every training batch

# train vae
env_name = f'{args.env}-{args.dataset}-{args.version}'
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
if args.no_max_action:
    max_action = None
print('state_dim:', state_dim, 'action_dim:', action_dim, 'max_action:', max_action)
# --latent_dim used to be ignored; fall back to its documented default.
latent_dim = args.latent_dim if args.latent_dim is not None else action_dim * 2
dones_float, masks = compute_done()

dataset_path = '../scik/datasets/' + env_name + '.hdf5'
class_dir = '../scik/datasets/datasets_class/' + args.env + '_' + args.dataset + '/'
score_dir = '../scik/datasets/datasets_class_scores/' + args.env + '_' + args.dataset + '/'
os.makedirs(score_dir, exist_ok=True)

replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
replay_buffer.convert_selfdata(dataset_path)

# Normalize / clip the full dataset exactly ONCE.  Previously this ran inside
# the class loop, re-normalizing the already-normalized buffer 100 times.
state_mean, state_std = None, None
if not args.no_normalize:
    state_mean, state_std = replay_buffer.normalize_states()
else:
    print("No normalize")
if args.clip_to_eps:
    replay_buffer.clip_to_eps(args.eps)
states = replay_buffer.state
actions = replay_buffer.action
total_size = states.shape[0]
batch_size = args.batch_size
lambda_loss = args.lambda_loss
n_dataset_per_batch = batch_size - N_EXPERT_PER_BATCH

# Full-dataset tensors used for scoring; identical for every class, so move once.
all_states_t = torch.from_numpy(states).to(device)
all_actions_t = torch.from_numpy(actions).to(device)

# Raw transitions for trajectory splitting are the same for every class: read once.
with h5py.File(dataset_path, 'r') as f1:
    raw_observations = np.array(f1['observations'][:])
    raw_actions = np.array(f1['actions'][:])
    raw_rewards = np.array(f1['rewards'][:])
    raw_next_observations = np.array(f1['next_observations'][:])
dones_float32 = dones_float.astype(np.float32)

for i in tqdm(range(NUM_CLASSES)):
    # expert (class) dataset
    replay_selfbuffer = utils.ReplayBuffer(state_dim, action_dim)
    replay_selfbuffer.convert_selfdata(class_dir + args.env + '_class_' + str(i) + '.hdf5')
    if args.clip_to_eps:
        # Treat expert actions like dataset actions.  NOTE(review): assumes
        # clip_to_eps only touches this buffer's own arrays — confirm in utils.
        replay_selfbuffer.clip_to_eps(args.eps)

    states_expert = replay_selfbuffer.state
    if state_mean is not None:
        # Previously expert states were fed raw while dataset states were
        # normalized.  Apply the DATASET statistics so both live in one space.
        # NOTE(review): assumes normalize_states() returns exactly the
        # (mean, std) it applied as (s - mean) / std — verify in utils.
        states_expert = np.asarray((states_expert - state_mean) / state_std, dtype=states.dtype)
    actions_expert = replay_selfbuffer.action

    # train
    if args.model == 'VAE':
        vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=args.hidden_dim).to(device)
    else:
        raise NotImplementedError
    optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.scheduler:
        # NOTE(review): this scheduler is never stepped, so --scheduler has no
        # effect; choose a step interval before relying on it.
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.gamma)

    for step in range(args.num_iters + 2):
        idx = np.random.choice(total_size, n_dataset_per_batch)
        idx_self = np.random.choice(states_expert.shape[0], N_EXPERT_PER_BATCH, replace=False)
        train_states = torch.from_numpy(
            np.concatenate([states[idx], states_expert[idx_self]])).to(device)
        train_actions = torch.from_numpy(
            np.concatenate([actions[idx], actions_expert[idx_self]])).to(device)

        # Variational Auto-Encoder Training
        recon, z_mean, z_std = vae(train_states, train_actions)

        # Expert rows are the last N_EXPERT_PER_BATCH of the batch (replaces the
        # hard-coded indices 251..255, which were only right for batch_size 256).
        expert_std = z_std[-N_EXPERT_PER_BATCH:]
        expert_mean = z_mean[-N_EXPERT_PER_BATCH:]
        std_loss = torch.var(expert_std, 0, unbiased=False).mean()
        mean_loss = torch.var(expert_mean, 0, unbiased=False).mean()

        recon_loss = F.mse_loss(recon, train_actions)
        KL_loss = -0.5 * (1 + torch.log(z_std.pow(2)) - z_mean.pow(2) - z_std.pow(2)).mean()
        vae_loss = recon_loss + args.beta * KL_loss + std_loss * lambda_loss + mean_loss * lambda_loss

        optimizer.zero_grad()
        vae_loss.backward()
        optimizer.step()

    # Score every dataset transition by the L2 distance of its latent mean to
    # the centre of the expert-class latent means.  no_grad: inference only,
    # avoids building an autograd graph over the whole dataset.
    with torch.no_grad():
        _, expert_z_mean, _ = vae(torch.from_numpy(states_expert).to(device),
                                  torch.from_numpy(actions_expert).to(device))
        center = torch.mean(expert_z_mean, 0)
        _, data_z_mean, _ = vae(all_states_t, all_actions_t)
        score = (torch.sum((center - data_z_mean).pow(2), 1)).pow(0.5)

    # One device->host copy instead of a per-element .item() loop; float64
    # matches the dtype np.array(list_of_python_floats) used to produce.
    scores_ = score.cpu().numpy().astype(np.float64)

    with h5py.File(score_dir + args.env + '_class_' + str(i) + '-oriscores.hdf5', "w") as f2:
        f2['scores'] = scores_

    traj = split_into_trajectories_scores(
        observations=raw_observations,
        actions=raw_actions,
        rewards=raw_rewards,
        masks=masks,
        dones_float=dones_float32,
        next_observations=raw_next_observations,
        scores=np.exp(-1 * scores_))
    reward_scale = compute_iql_reward_scale(traj)
    print('===========', reward_scale)
    print('\n')

        