import os
import sys
sys.path.append('.')

import argparse
import pickle
import random
import numpy as np
import torch
import gym
import d4rl

from models.preference_transformer import PreferenceTransformer
from models.temporal import TrajCondUnet
from models.diffusion import GaussianDiffusion
from models.score_model import ScoreModel
from data import BlockRankingDataset
from utils.normalizer import DatasetNormalizer
from utils.logger import Logger, make_log_dirs
from utils.trainer import Trainer
from utils.render import Render
from utils.timer import Timer
from utils.evaluator import Evaluator
from utils.helpers import make_dataset
from utils.robosuite import make_robosuite_env_and_dataset


def get_tlen(trajs, normalizer, obs_dim, tol=1e-2):
    """Return the effective length of each trajectory in ``trajs``.

    Each trajectory is mapped back to raw scale with ``normalizer`` and scanned
    step by step. The first step whose summed absolute observation and action
    values falls below ``tol`` ends the trajectory (presumably zero padding —
    confirm against the dataset builder), and its index is the length. A
    trajectory with no such step gets the full padded length ``trajs.shape[1]``.

    Args:
        trajs: batch of trajectories, iterated over its first axis and indexed
            by time step on the second; the last axis holds ``obs_dim``
            observation features followed by the action features.
        normalizer: object exposing ``unnormalize(x, key)`` for the keys
            ``"observations"`` and ``"actions"``.
        obs_dim: number of observation features at the front of the last axis.
        tol: magnitude below which a step is treated as the end of the
            trajectory. Defaults to 1e-2.

    Returns:
        list of int, exactly one length per trajectory, in input order.
    """
    max_len = trajs.shape[1]
    tlens = []
    for traj in trajs:
        observations = normalizer.unnormalize(traj[..., :obs_dim], "observations")
        actions = normalizer.unnormalize(traj[..., obs_dim:], "actions")
        for i in range(max_len):
            if np.sum(np.abs(observations[i])) + np.sum(np.abs(actions[i])) < tol:
                tlens.append(i)
                break
        else:
            # No near-zero step found. The for/else form also covers
            # max_len == 0, so every trajectory contributes exactly one entry.
            tlens.append(max_len)
    return tlens


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``type=bool`` in argparse calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"``), so an explicit parser is needed.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching plain ``parser.parse_args()``.

    Returns:
        argparse.Namespace with general, preference-model (``pref_*``) and
        diffusion-model (``diff_*``) hyperparameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--algo-name", type=str, default="Dropout")
    parser.add_argument("--task", type=str, default="walker2d-medium-expert-v2")
    parser.add_argument("--domain", type=str, default="gym")
    parser.add_argument("--teacher", type=str, default="block_ranking", choices=["scripted", "human", "block_ranking"])
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--episode-len", type=int, default=1000)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # preference
    # Accepts "--use-human-label", "--use-human-label true" or "... false".
    parser.add_argument("--use-human-label", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--pref-episode-len", type=int, default=100)
    parser.add_argument("--pref-num", type=int, default=500)
    parser.add_argument("--dist-noise", type=float, default=0.1)
    parser.add_argument("--pref-embed-dim", type=int, default=256)
    parser.add_argument("--warmup-steps", type=int, default=10000)
    parser.add_argument("--pref-lr", type=float, default=1e-4)
    parser.add_argument("--pref-max-iters", type=int, default=50)
    parser.add_argument("--pref-num-steps-per-iter", type=int, default=100)
    parser.add_argument("--pref-batch-size", type=int, default=256)

    # diffusion
    parser.add_argument("--diff-episode-len", type=int, default=1000)
    parser.add_argument("--improve-step", type=int, default=20)
    parser.add_argument("--diff-embed-dim", type=int, default=128)
    # A list of ints, e.g. "--dim-mults 1 2 4 8"; nargs="+" keeps it a list.
    parser.add_argument("--dim-mults", type=int, nargs="+", default=[1, 2, 4, 8])
    parser.add_argument("--n-diffusion-steps", type=int, default=1000)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--diff-ema-start-epoch", type=int, default=20)
    parser.add_argument("--guidance-scale", type=float, default=1.2)
    parser.add_argument("--diff-lr", type=float, default=1e-4)
    parser.add_argument("--diff-max-iters", type=int, default=500)
    parser.add_argument("--diff-num-steps-per-iter", type=int, default=1000)
    parser.add_argument("--diff-batch-size", type=int, default=64)

    return parser.parse_args(argv)


def train(args=None):
    """Train the preference model, relabel the dataset, then train the diffusion model.

    Steps:
      1. Seed the Python, NumPy and PyTorch RNGs, then load the D4RL dataset
         for ``args.task`` via ``gym.make(...).get_dataset()``.
      2. Load preference labels from ``./data/human_label/<task>/data.pkl``.
         Unless ``args.use_human_label`` is set, they are passed through
         ``make_dataset`` together with the D4RL data and ``args.pref_num``.
      3. Train a ``ScoreModel`` with Adam and a linear LR warm-up over
         ``args.warmup_steps`` steps.
      4. Score the dataset with the trained model (``set_returns``) and apply
         ``block_ranking(args.improve_step)``.
      5. Train a ``GaussianDiffusion`` model built on ``TrajCondUnet``.

    Args:
        args: namespace as returned by ``get_args()``. When ``None`` (the
            default), the command line is parsed at call time. This function
            writes ``obs_shape``, ``obs_dim``, ``action_dim`` and
            ``max_action`` onto the namespace.
    """
    # Parse lazily. A ``get_args()`` default would run once at import time,
    # reading the importing program's sys.argv, and every call would share
    # (and mutate) that single namespace.
    if args is None:
        args = get_args()

    # Seed all RNGs before anything else runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True

    # NOTE(review): only the gym/D4RL path is implemented. ``args.domain`` is
    # not consulted, and the environment itself is not seeded here.
    env = gym.make(args.task)
    dataset = env.get_dataset()

    # pickle.load can execute arbitrary code, so only load trusted label files.
    label_path = os.path.join('.', 'data', 'human_label', args.task, 'data.pkl')
    with open(label_path, 'rb') as f:
        label_dataset = pickle.load(f)
    if not args.use_human_label:
        label_dataset = make_dataset(dataset, label_dataset, args.pref_num)
    normalizer = DatasetNormalizer(dataset)
    dataset = BlockRankingDataset(args.task, dataset, normalizer, label_dataset, args.episode_len, args.pref_num, args.device)

    args.obs_shape = env.observation_space.shape
    args.obs_dim = int(np.prod(args.obs_shape))
    args.action_dim = int(np.prod(env.action_space.shape))
    args.max_action = env.action_space.high[0]

    # create preference model
    preference_model = ScoreModel(observation_dim=args.obs_dim,
                                  action_dim=args.action_dim,
                                  device=args.device)
    preference_model.to(args.device)
    # preference_model.load_state_dict(torch.load('logs_0919/%s/use_human_label=%s/preference.pth'%(args.task, args.use_human_label)))
    preference_optim = torch.optim.Adam(preference_model.parameters(), lr=args.pref_lr)
    # Linear warm-up: the LR factor grows from 1/warmup_steps to 1, then stays at 1.
    preference_scheduler = torch.optim.lr_scheduler.LambdaLR(
        preference_optim,
        lambda steps: min((steps + 1) / args.warmup_steps, 1)
    )

    # create diffusion model
    temporal_model = TrajCondUnet(args.diff_episode_len, args.obs_dim + args.action_dim, hidden_dim=args.diff_embed_dim, dim_mults=args.dim_mults, condition_dropout=args.dropout)
    diffusion_model = GaussianDiffusion(
        model=temporal_model,
        horizon=args.diff_episode_len,
        observation_dim=args.obs_dim,
        action_dim=args.action_dim,
        n_timesteps=args.n_diffusion_steps,
        guidance_scale=args.guidance_scale,
        loss_type='l2',
        clip_denoised=False,
    )
    diffusion_model.to(args.device)
    # diffusion_model.load_state_dict(torch.load('logs_0919/%s/use_human_label=%s/diffusion.pth'%(args.task, args.use_human_label)))
    diffusion_optim = torch.optim.Adam(diffusion_model.parameters(), lr=args.diff_lr)

    # logger
    log_dirs = make_log_dirs(args.task, args.algo_name, args.seed, vars(args))
    # key: output file name, value: output handler type
    output_config = {
        "consoleout_backup": "stdout",
        "critic_training_progress": "csv",
        "diffusion_training_progress": "csv",
        "tb": "tensorboard"
    }
    logger = Logger(log_dirs, output_config)
    logger.log_hyperparameters(vars(args))

    render = Render(env, args.task, args.obs_dim, args.action_dim)
    timer = Timer()
    evaluator = Evaluator(env, normalizer)
    # NOTE(review): the third positional argument is passed as None; its
    # meaning is defined by Trainer and not visible here.
    trainer = Trainer(preference_model, diffusion_model, None, dataset, logger, timer, render, evaluator, device=args.device)

    print('-------------train preference model-------------')
    trainer.train_preference(
        preference_optim,
        preference_scheduler,
        args.pref_max_iters,
        args.pref_num_steps_per_iter,
        batch_size=args.pref_batch_size,
        dist_noise=args.dist_noise,
    )

    # Score the dataset with the trained preference model, then re-rank it.
    trainer.preference_model.eval()
    trainer.dataset.set_returns(trainer.preference_model)
    trainer.dataset.block_ranking(args.improve_step)

    print('-------------train diffusion model-------------')
    trainer.train_diffusion(
        optim=diffusion_optim,
        ema_decay=0.995,
        epoch_start_ema=args.diff_ema_start_epoch,
        update_ema_every=10,
        max_iters=args.diff_max_iters,
        num_steps_per_iter=args.diff_num_steps_per_iter,
        batch_size=args.diff_batch_size,
    )


if __name__ == "__main__":
    # Script entry point: parse the command line and run the full pipeline.
    cli_args = get_args()
    train(cli_args)
    