import os
import sys
sys.path.append('.')

import argparse
import pickle
import random
import numpy as np
import torch
import gym
import d4rl

from models.preference_transformer import PreferenceTransformer
from models.score_model import ScoreModel
from models.temporal import TrajCondUnet
from models.diffusion import GaussianDiffusion
from models.actor import DeterministicActor, StochasticActor, RNNGMMActor
from data import BlockRankingDataset, ActorDataset, TrajectoryDataset
from utils.normalizer import DatasetNormalizer
from utils.logger import Logger, make_log_dirs
from utils.trainer import Trainer
from utils.render import Render
from utils.timer import Timer
from utils.evaluator import Evaluator
from utils.helpers import make_dataset


def is_valid(trajs, threshold=30):
    """Flag trajectories whose entries all stay strictly below ``threshold`` in magnitude.

    Args:
        trajs: sized iterable of array-like trajectories (callers in this file
            pass a numpy array); ``len(trajs)`` must be defined.
        threshold: exclusive bound on ``abs(value)`` over every entry.

    Returns:
        float ndarray of shape ``(len(trajs),)``: 1.0 where the largest absolute
        entry of the trajectory is below ``threshold``, 0.0 otherwise (including
        when the trajectory contains NaN, since NaN comparisons are False).
    """
    within_bound = (np.abs(traj).max() < threshold for traj in trajs)
    return np.fromiter(within_bound, dtype=float, count=len(trajs))


def get_tlen(trajs, normalizer, obs_dim, pad_threshold=0.1):
    """Return the effective length of each trajectory in ``trajs``.

    Each trajectory is split into observations (first ``obs_dim`` features)
    and actions (remaining features), both are unnormalized with
    ``normalizer``, and the index of the first step whose summed absolute
    observation + action magnitude is below ``pad_threshold`` is taken as
    the trajectory length (presumably such steps are zero padding — confirm
    against the data pipeline). Trajectories with no such step get the full
    horizon length.

    Args:
        trajs: numpy array indexed as ``trajs[n, t, feature]``; ``trajs.shape[1]``
            is the horizon.
        normalizer: object exposing ``unnormalize(x, key)`` for the keys
            ``"observations"`` and ``"actions"``.
        obs_dim: number of leading feature columns that are observations.
        pad_threshold: magnitude below which a step counts as the end of the
            trajectory. Defaults to 0.1.

    Returns:
        integer ndarray of shape ``(len(trajs),)`` — one length per trajectory,
        also when the horizon is 0 or ``trajs`` is empty.
    """
    horizon = trajs.shape[1]
    tlens = []
    for traj in trajs:
        observations = normalizer.unnormalize(traj[..., :obs_dim], "observations")
        actions = normalizer.unnormalize(traj[..., obs_dim:], "actions")
        for step in range(horizon):
            if np.sum(np.abs(observations[step])) + np.sum(np.abs(actions[step])) < pad_threshold:
                tlens.append(step)
                break
        else:
            # No end-of-trajectory step found (this also covers horizon == 0),
            # so every trajectory still contributes exactly one length.
            tlens.append(horizon)
    # Explicit dtype so an empty result is still usable as slice lengths.
    return np.array(tlens, dtype=int)


def flow_to_better(pref_model, diffusion_model, normalizer, trajs, init_obs, max_flow_step=5, threshold=1.03):
    """Iteratively replace trajectories with diffusion samples that score higher.

    At each step ``diffusion_model.flow_one_step`` proposes one candidate per
    conditioning trajectory. Each candidate is scored with
    ``pref_model.predict_traj_return``. A candidate is adopted as the new
    conditioning trajectory when it passes ``is_valid`` and its score exceeds
    ``threshold`` times the current score.

    Iteration stops after ``max_flow_step`` steps, or as soon as fewer than
    5% of the trajectories improve in a step.

    Args:
        pref_model: model exposing ``predict_traj_return(trajs, tlens)``.
        diffusion_model: model exposing ``flow_one_step(cond, init_obs)`` and
            ``observation_dim``.
        normalizer: forwarded to ``get_tlen`` to unnormalize trajectories.
        trajs: starting trajectories as a torch tensor. The tensor is not
            modified; a private copy is updated instead.
        init_obs: forwarded unchanged to ``flow_one_step``.
        max_flow_step: maximum number of flow steps.
        threshold: minimum score ratio (candidate / current) for acceptance.

    Returns:
        list of numpy arrays. The first entry is a snapshot of the starting
        trajectories. Each following entry holds the candidates accepted at
        one executed step.
    """
    # Private copy: rows are overwritten in place below, and that must not
    # leak back into the caller's tensor.
    cond = trajs.clone()
    obs_dim = diffusion_model.observation_dim

    cond_tlens = get_tlen(cond.detach().cpu().numpy(), normalizer, obs_dim)
    cond_score = pref_model.predict_traj_return(cond, cond_tlens)
    print("min_score: %.4f    max_score:%.4f    mean_score:%.4f"%(np.min(cond_score), np.max(cond_score), np.mean(cond_score)))

    # .copy(): for a CPU tensor, .numpy() shares storage with `cond`. Without
    # the copy, this snapshot of the starting trajectories would change as
    # rows of `cond` are replaced.
    generate_trajs = [cond.detach().cpu().numpy().copy()]
    for i in range(max_flow_step):
        print("flow step:", i)
        generate_traj = diffusion_model.flow_one_step(cond, init_obs)

        tlens = get_tlen(generate_traj.detach().cpu().numpy(), normalizer, obs_dim)
        score = pref_model.predict_traj_return(generate_traj, tlens)
        valid = is_valid(generate_traj.detach().cpu().numpy())

        # NOTE(review): the ratio test assumes positive scores. If current
        # scores can be <= 0, "ratio > threshold" does not mean "better";
        # confirm the preference model's output range.
        ratio = score / cond_score
        indices = []
        for j in range(len(score)):
            if valid[j] and ratio[j] > threshold:
                cond[j] = generate_traj[j]
                cond_score[j] = score[j]
                indices.append(j)

        improve_ratio = len(indices) / len(cond)
        print("improve ratio: %.4f"%(improve_ratio))
        print("min_score: %.4f    max_score:%.4f    mean_score:%.4f"%(np.min(cond_score), np.max(cond_score), np.mean(cond_score)))

        generate_trajs.append(generate_traj[indices].detach().cpu().numpy())
        # Stop early once almost nothing improves any more.
        if improve_ratio < 0.05:
            break

    return generate_trajs


def get_args():
    """Parse command-line options for trajectory generation and actor training.

    Returns:
        argparse.Namespace with one attribute per option below. Dashes in
        option names become underscores, e.g. ``--pref-lr`` -> ``args.pref_lr``.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string, including
        # "False", to True, so the common spellings are parsed explicitly.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="bc")
    parser.add_argument("--task", type=str, default="hopper-medium-replay-v2")
    parser.add_argument("--domain", type=str, default="gym")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--episode-len", type=int, default=1000)
    parser.add_argument("--model-dir", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # preference
    # Accepts "--use-human-label", "--use-human-label true" and "--use-human-label False".
    parser.add_argument("--use-human-label", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--pref-episode-len", type=int, default=100)
    parser.add_argument("--pref-num", type=int, default=500)
    parser.add_argument("--dist-noise", type=float, default=0.1)
    parser.add_argument("--pref-embed-dim", type=int, default=256)
    # Learning rate for a preference-model optimizer.
    parser.add_argument("--pref-lr", type=float, default=1e-4)

    # diffusion
    parser.add_argument("--diff-episode-len", type=int, default=1000)
    parser.add_argument("--improve-step", type=int, default=20)
    parser.add_argument("--diff-embed-dim", type=int, default=128)
    # One or more ints, e.g. "--dim-mults 1 2 4 8".
    parser.add_argument("--dim-mults", type=int, nargs="+", default=[1, 2, 4, 8])
    parser.add_argument("--n-diffusion-steps", type=int, default=1000)
    parser.add_argument("--upsample-temp", type=float, default=1.)
    parser.add_argument("--guidance-scale", type=float, default=1.2)

    # actor
    parser.add_argument("--actor-embed-dim", type=int, default=256)
    parser.add_argument("--actor-hidden-layer", type=int, default=1)
    parser.add_argument("--weight-decay", type=float, default=2e-4)
    parser.add_argument("--actor-type", type=str, default="deterministic")
    parser.add_argument("--flow-step", type=int, default=5)
    parser.add_argument("--percentile", type=float, default=0.1)
    parser.add_argument("--actor-lr", type=float, default=3e-5)
    parser.add_argument("--threshold", type=float, default=1.02)
    # Iteration counts: parsed as int so CLI values work as loop bounds.
    parser.add_argument("--actor-max-iters", type=int, default=100)
    parser.add_argument("--actor-num-steps-per-iter", type=int, default=1000)
    parser.add_argument("--actor-batch-size", type=int, default=1024)

    return parser.parse_args()


def train(args=None):
    """Generate improved trajectories with pretrained models and fit an actor on them.

    Steps:
      1. Load the environment's dataset and the preference labels.
      2. Load the pretrained preference (score) and diffusion weights from
         ``args.model_dir``.
      3. Block-rank the dataset by predicted return.
      4. Flow the top trajectories toward higher predicted return.
      5. Keep the best-scoring generated trajectories and train an actor on
         their (observation, action) pairs via ``trainer.train_actor``.

    Args:
        args: namespace as produced by ``get_args()``. When ``None`` (the
            default), the command line is parsed when ``train`` is called
            rather than when this module is imported.

    Raises:
        ValueError: if ``args.model_dir`` is not set.
    """
    if args is None:
        args = get_args()
    if args.model_dir is None:
        raise ValueError("--model-dir must point to a directory containing preference.pth and diffusion.pth")

    # seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    env = gym.make(args.task)
    raw_dataset = env.get_dataset()
    datadir = f'./data/human_label/{args.task}/data.pkl'
    with open(datadir, 'rb') as f:
        label_dataset = pickle.load(f)
    if not args.use_human_label:
        label_dataset = make_dataset(raw_dataset, label_dataset)
    # Single normalizer shared by the ranking dataset, trajectory-length
    # estimation, evaluation and the BC dataset below.
    normalizer = DatasetNormalizer(raw_dataset)
    dataset = BlockRankingDataset(args.task, raw_dataset, normalizer, label_dataset, args.episode_len, args.pref_num, args.device)

    args.obs_shape = env.observation_space.shape
    args.obs_dim = int(np.prod(args.obs_shape))
    args.action_dim = int(np.prod(env.action_space.shape))
    args.max_action = env.action_space.high[0]

    # pretrained preference model
    preference_model = ScoreModel(observation_dim=args.obs_dim,
                                  action_dim=args.action_dim,
                                  device=args.device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    preference_model.load_state_dict(
        torch.load(os.path.join(args.model_dir, "preference.pth"), map_location=args.device))
    preference_model.to(args.device)

    # pretrained diffusion model
    temporal_model = TrajCondUnet(args.diff_episode_len, args.obs_dim + args.action_dim, hidden_dim=args.diff_embed_dim, dim_mults=args.dim_mults)
    diffusion_model = GaussianDiffusion(
        model=temporal_model,
        horizon=args.diff_episode_len,
        observation_dim=args.obs_dim,
        action_dim=args.action_dim,
        n_timesteps=args.n_diffusion_steps,
        guidance_scale=args.guidance_scale,
        loss_type='l2',
        clip_denoised=False,
    )
    diffusion_model.load_state_dict(
        torch.load(os.path.join(args.model_dir, "diffusion.pth"), map_location=args.device))
    diffusion_model.to(args.device)
    diffusion_model.eval()

    # actor
    actor_classes = {
        "deterministic": DeterministicActor,
        "stochastic": StochasticActor,
    }
    if args.actor_type not in actor_classes:
        raise NotImplementedError(f"unsupported actor type: {args.actor_type}")
    actor = actor_classes[args.actor_type](
        observation_dim=args.obs_dim,
        action_dim=args.action_dim,
        hidden_dim=args.actor_embed_dim,
        hidden_layer=args.actor_hidden_layer
    )
    actor.to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), args.actor_lr, weight_decay=args.weight_decay)

    # logger
    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args), record_params=["weight_decay", "flow_step", "threshold"])
    # key: output file name, value: output handler type
    output_config = {
        "consoleout_backup": "stdout",
        "critic_training_progress": "csv",
        "diffusion_training_progress": "csv",
        "actor_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    render = Render(env, args.task, args.obs_dim, args.action_dim)
    timer = Timer()
    evaluator = Evaluator(env, normalizer)

    trainer = Trainer(preference_model, diffusion_model, actor, dataset, logger, timer, render, evaluator, device=args.device)

    print('-------------generate_data-------------')
    preference_model.eval()
    dataset.set_returns(preference_model)
    dataset.block_ranking(args.improve_step)

    # NOTE(review): args.percentile is not read here; the percentile is derived
    # from the number of trajectories instead. Confirm this is intended and
    # matches the units get_top_traj expects.
    percentile = 100 / len(dataset.trajs)
    trajs, init_obs = dataset.get_top_traj(percentile)
    tlens = get_tlen(trajs.detach().cpu().numpy(), normalizer, args.obs_dim)
    score = preference_model.predict_traj_return(trajs.detach().cpu().numpy(), tlens)
    flow_step = dataset.improve_step - dataset.get_block_id(score.min())
    print("flow_step", flow_step)

    generate_trajs = flow_to_better(preference_model, diffusion_model, normalizer, trajs, init_obs, max_flow_step=flow_step, threshold=args.threshold)
    for i, step_trajs in enumerate(generate_trajs):
        np.save(os.path.join(logger.result_dir, 'generate_data_%d'%i), step_trajs)
    generate_trajs = np.concatenate(generate_trajs)

    # Keep the len(trajs) highest-scoring generated trajectories. Invalid ones
    # get score 0 (NOTE(review): that only ranks them last if scores are positive).
    tlens = get_tlen(generate_trajs, normalizer, args.obs_dim)
    valid = is_valid(generate_trajs)
    score = preference_model.predict_traj_return(generate_trajs, tlens) * valid
    indices = np.argsort(score)[-len(trajs):]
    generate_trajs = generate_trajs[indices]
    tlens = tlens[indices]
    min_score = score[indices].min()
    print("min score:", min_score)
    np.save(os.path.join(logger.result_dir, 'min_score_%.4f.npy'%min_score), min_score)

    # Truncate each trajectory to its effective length and flatten the steps
    # into unnormalized (observation, action) rows for behaviour cloning.
    bc_trajs = np.concatenate([traj[:tlen] for traj, tlen in zip(generate_trajs, tlens)])
    bc_dataset = {
        "observations": normalizer.unnormalize(bc_trajs[..., :args.obs_dim], "observations").reshape(-1, args.obs_dim),
        "actions": normalizer.unnormalize(bc_trajs[..., args.obs_dim:], "actions").reshape(-1, args.action_dim),
    }
    bc_dataset = ActorDataset(bc_dataset, normalizer, device=args.device)

    print('-------------train actor-------------')
    trainer.train_actor(
        dataset=bc_dataset,
        optim=actor_optim,
        max_iters=args.actor_max_iters,
        num_steps_per_iter=args.actor_num_steps_per_iter,
        batch_size=args.actor_batch_size,
    )


if __name__ == "__main__":
    # Script entry point: run train() with its default arguments, i.e. options
    # parsed from the command line via get_args().
    train()

