import os
import torch
import numpy as np
from collections import deque

from gen_rl.policy.ddpg import DDPG
from gen_rl.commons.seeds import set_randomSeed
from gen_rl.commons.scheduler import AnnealingSchedule
from gen_rl.commons.buffer import ReplayBuffer
from gen_rl.commons.tensorboard import TensorBoard
from gen_rl.commons.utils import logging, mean_dict, tile_images, save_mp4
from gen_rl.commons.launcher import launch_env, launch_models


def decompose_obs(args, obses_t, obses_tp1):
    """Split batched rows into their state part and their observation part.

    Rows are assumed to be ``[observation | state]`` concatenations — TODO
    confirm against the producer of each env's buffer entries.

    Args:
        args: run configuration; only ``args["env_name"]`` is read.
        obses_t: batch at time t, indexed as ``[:, cols]`` (so at least 2-D).
        obses_tp1: batch at time t+1 with the same layout.

    Returns:
        ``(state, next_state, obses_t, obses_tp1)``:
        - ``state`` / ``next_state`` are every column from the split column on.
        - The returned observations are the inputs with the trailing state
          columns dropped.
        ``"mujoco-HalfCheetah-v3"`` splits at column 17 and drops 18 columns.
        Every other env (including ``pendulum-v0``) splits at column 3 and
        drops 2.
    """
    env_name = args["env_name"]
    # Pendulum and the fallback share one layout, so only HalfCheetah differs.
    is_pendulum = env_name.lower() == "pendulum-v0"
    if not is_pendulum and env_name == "mujoco-HalfCheetah-v3":
        split_col, n_state_cols = 17, 18
    else:
        split_col, n_state_cols = 3, 2

    state = obses_t[:, split_col:]
    next_state = obses_tp1[:, split_col:]
    return state, next_state, obses_t[:, :-n_state_cols], obses_tp1[:, :-n_state_cols]


def roll_out(agent, env, args, num_episodes, if_pretrain_rollout=False):  # if_pretrain_rollout=False means Evaluation
    """Run ``agent`` in ``env`` for ``num_episodes`` episodes.

    The mode is chosen by ``if_pretrain_rollout``:

    * ``False`` (evaluation): actions are selected with ``epsilon=0.0``.
      - When ``args["if_visualise"]`` is set, the first episode's frames are
        tiled and written with ``save_mp4`` into ``./videos``.
      - In that case ``args["video_path"]`` is set to
        ``./videos/ts_<global_ts>.mp4``.
    * ``True`` (data collection): actions are selected with ``epsilon=1.0``,
      and every transition is added to a freshly created ``ReplayBuffer``.

    Episode termination:
    - For ``mujoco-single*`` envs an episode ends on the first ``done``.
    - Otherwise it ends once every sub-environment has reported ``done`` at
      least once. Rewards of sub-environments that finished in an earlier
      step are zeroed, so only the first pass of each counts toward the return.

    Args:
        agent: policy exposing ``select_action(state=..., epsilon=...)``.
        env: environment to roll out in; vector-style envs are indexed per
            sub-env with ``obs[i]`` / ``info[i]``.
        args: run configuration dict (reads ``num_envs``, ``env_name``,
            ``if_visualise``, ``if_use_act_val_fn``, ``global_ts`` and, for the
            "paint" env, ``img_metric``).
        num_episodes: number of episodes to run.
        if_pretrain_rollout: collect a buffer instead of evaluating.

    Returns:
        ``(ep_metrics, buffer)`` when ``if_pretrain_rollout`` is true, else
        ``(ep_metrics, args)``. ``ep_metrics`` is built as follows:
        - ``mean_dict`` is applied to the per-episode image metrics, which are
          only collected for the "paint" env.
        - It additionally gets ``"ep_return"``: the mean over episodes of the
          per-episode return averaged across sub-environments.
    """
    if if_pretrain_rollout:
        buffer = ReplayBuffer(args=args)
    ep_return, ep_metrics = list(), list()
    for ep in range(num_episodes):
        _ts = 0
        obs, done = env.reset(), [False] * args["num_envs"]
        # True for sub-envs that have already finished this episode.
        done_env_mask = np.asarray([False] * args["num_envs"])
        _ep_return = np.zeros(args["num_envs"])
        # Only record video for the first evaluation episode.
        _if_vis = args["if_visualise"] and (not if_pretrain_rollout) and ep == 0
        if _if_vis:
            frame_buffer = list()
            frame_buffer.append(tile_images(img_nhwc=env.render(), _size=(4, 4)))
        while not all(done_env_mask):
            action = agent.select_action(state=obs, epsilon=1.0 if if_pretrain_rollout else 0.0)
            next_obs, reward, done, info = env.step(action)
            if _if_vis:
                frame_buffer.append(tile_images(img_nhwc=env.render(), _size=(4, 4)))

            if if_pretrain_rollout:
                if args["env_name"].startswith("mujoco-single"):
                    buffer.add(obs, action, reward, next_obs, float(done))
                else:
                    # One buffer entry per sub-environment.
                    for i in range(args["num_envs"]):
                        # Pendulum without an action-value fn: append the env's
                        # state from `info` to the observation row.
                        # NOTE(review): decompose_obs matches "pendulum-v0"
                        # (lower-cased) while this checks == "Pendulum"; confirm
                        # which env_name launch_env actually produces.
                        if not args["if_use_act_val_fn"] and args["env_name"] == "Pendulum":
                            s = np.hstack([obs[i], info[i]["state"]])
                            ns = np.hstack([next_obs[i], info[i]["next_state"]])
                        else:
                            s, ns = obs[i], next_obs[i]
                        buffer.add(s, action[i], reward[i], ns, float(done[i]))

            if args["env_name"].startswith("mujoco-single"):
                # Single env: `done` is a scalar; the episode ends on it.
                done_env_mask = [done]
            else:
                # Vector env: zero rewards of sub-envs that finished in an
                # EARLIER step (mask not yet updated), then mark the newly
                # finished ones. Assumes `reward` and `done` are numpy arrays.
                reward[done_env_mask] = 0.0
                done_env_mask[done] = True

            _ep_return += reward
            # print(_ep_return)
            obs = next_obs
            _ts += 1
        """ === After 1 Episode === """
        ep_return.append(np.mean(_ep_return))
        if args["env_name"].lower() == "paint":
            # Image metrics between the produced canvas and the ground truth,
            # both rescaled by 1/255.
            metrics = args["img_metric"].evaluate(env.canvas.float() / 255., env.gt / 255.)
            ep_metrics.append(metrics)
        if _if_vis:
            args["video_path"] = f"./videos/ts_{args['global_ts']}.mp4"
            save_mp4(np.asarray(frame_buffer), vid_dir="./videos", name=f"ts_{args['global_ts']}", fps=10.0)

    """ === After All Episodes === """
    ep_metrics = mean_dict(ep_metrics)
    ep_metrics["ep_return"] = np.mean(ep_return)
    if if_pretrain_rollout:
        return ep_metrics, buffer
    else:
        return ep_metrics, args


def update_agent(agent, buffer, args=None):
    """Run one round of model and policy updates and return averaged diagnostics.

    Each of the ``args["num_updates"]`` iterations does two things:
    - updates the learned models once;
    - updates the policy ``args["num_RL_updates"]`` times.
    Every call samples mini-batches of ``args["batch_size"]`` from ``buffer``.

    Args:
        agent: object exposing ``update_models`` and ``update_policy``.
        buffer: replay buffer passed straight through to the agent.
        args: run configuration dict. Optional for backward compatibility.
            When omitted, the module-level ``args`` is used; that name only
            exists when this file runs as a script.

    Returns:
        ``mean_dict`` of the per-iteration averaged policy-update results,
        merged with the result of the *last* ``update_models`` call.

    Raises:
        NameError: if ``args`` is omitted and no module-level ``args`` exists.
    """
    if args is None:
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "update_agent() needs `args`; pass it explicitly when this module is imported"
            ) from None

    batch_size = args["batch_size"]
    rl_results = list()
    # Pre-initialised so num_updates == 0 no longer hits an unbound local.
    model_result = dict()
    for _ in range(args["num_updates"]):
        model_result = agent.update_models(buffer=buffer, batch_size=batch_size)
        policy_results = [
            agent.update_policy(buffer=buffer, batch_size=batch_size)
            for _ in range(args["num_RL_updates"])
        ]
        rl_results.append(mean_dict(policy_results))
    update_result = mean_dict(rl_results)
    update_result.update(model_result)
    return update_result


def train(args):
    """Train an agent with periodic logging, updates, evaluation and checkpointing.

    The loop steps the (possibly vectorised) environment until
    ``args["global_ts"]`` exceeds ``args["total_ts"]``. Once ``ep_ts`` grows
    past ``2 * args["max_episode_steps"]``, a "session" ends. At that point
    the loop:
    - logs training returns;
    - runs a policy update (unless updates happen every step);
    - evaluates every ``eval_freq`` sessions;
    - checkpoints every ``save_freq`` sessions.

    Args:
        args: run configuration dict (typically ``vars()`` of the CLI
            namespace). It is mutated in place (``save_dir_name``, ``output``,
            ``global_ts``, ``video_path``). It is also rebound to the dict
            returned by ``launch_env``.
    """
    import time

    print(args)
    # NOTE(review): the writer is built from args["output"] *before* that key
    # is reassigned to ./train_log/<time> below; confirm which directory the
    # TensorBoard logs are meant to go to.
    writer = TensorBoard(args["output"])
    args["save_dir_name"] = f"{args['prefix']}-{time.time()}"

    os.makedirs(f"./results/{args['save_dir_name']}", exist_ok=True)
    if args["if_save_agent"]:
        os.makedirs(f"./weights/{args['save_dir_name']}", exist_ok=True)

    # Per-run log directory name, set before launch_env() receives args.
    args["output"] = f"./train_log/{time.time()}"

    env, eval_env, args = launch_env(args=args)

    if args["policy_name"] == "dqn":
        from gen_rl.policy.dqn import DQN
        agent = DQN(args=args)
    else:
        agent = DDPG(random_act_fn=args["random_act_fn"], args=args)

    buffer = ReplayBuffer(args=args)
    scheduler = AnnealingSchedule(start=args["epsilon_start"], end=args["epsilon_end"], decay_steps=args["decay_steps"])

    if not args["if_use_act_val_fn"] and args["if_train_models"]:
        state_model, reward_model = launch_models(env=env, args=args)
        agent.set_models(reward_model=reward_model, state_model=state_model, decompose_obs_fn=decompose_obs)

    args["global_ts"] = 0
    train_ep_return = deque(maxlen=100)  # returns of the last 100 finished training episodes
    train_return_vec = np.zeros(args["num_envs"])  # running return per sub-environment
    # A session ends once more than this many loop iterations have elapsed.
    per_train_ts = args["max_episode_steps"] * 2

    obs, done = env.reset(), [False] * args["num_envs"]
    # A dict (not a list) so the `.items()` call in the session block is safe
    # even when no update has happened yet (e.g. buffer still < batch_size).
    update_result = dict()
    ep_ts, session_num = 0, 0

    # roll_out untrained policy
    eval_metrics, args = roll_out(agent=agent, env=eval_env, num_episodes=args["eval_num_episodes"], args=args)
    for k, v in eval_metrics.items():
        writer.add_scalar(f"eval/{k}", v, args["global_ts"])
    if args["if_visualise"] and args["wandb"]:
        writer.add_video(path_to_video=args["video_path"], step=args["global_ts"], log_string="eval/video")
    logging(f"[Eval] {0}/{args['total_ts']}| Return: {eval_metrics['ep_return']:.3f}")

    # ============ Pretrain the models on one episode of exploratory data
    if not args["if_use_act_val_fn"] and args["if_train_models"] and args["env_name"].lower() != "recsim":
        _, _pre_buffer = roll_out(agent=agent, env=env, num_episodes=1, if_pretrain_rollout=True, args=args)
        print(f"[Pretrain] Collected {len(_pre_buffer)} transitions")
        for _ in range(100):
            agent.update_models(buffer=_pre_buffer, batch_size=args["batch_size"])

    while args["global_ts"] <= args["total_ts"]:
        _if_warmup = args["global_ts"] <= args["warmup_ts"]
        eps = scheduler.get_value(ts=args["global_ts"])
        action = agent.select_action(state=obs, epsilon=eps, if_warmup=_if_warmup)
        next_obs, reward, done, info = env.step(action)

        # Add to buffer (one transition per sub-environment for vector envs).
        if args["env_name"].startswith("mujoco-single"):
            buffer.add(obs, action, reward, next_obs, float(done))
        else:
            for i in range(args["num_envs"]):
                if not args["if_use_act_val_fn"] and args["env_name"] == "Pendulum":
                    # Append the env's state from `info` (presumably split off
                    # again by decompose_obs for the learned models).
                    _o = np.hstack([obs[i], info[i]["state"]])
                    _no = np.hstack([next_obs[i], info[i]["next_state"]])
                else:
                    _o, _no = obs[i], next_obs[i]
                buffer.add(_o, action[i], reward[i], _no, float(done[i]))

        obs = next_obs
        args["global_ts"] += args["num_envs"]
        ep_ts += 1  # this indicates the time-step in episode and not progress
        train_return_vec += reward

        if args["if_update_per_ts"] and not _if_warmup and len(buffer) >= args["batch_size"]:
            update_result = update_agent(agent=agent, buffer=buffer, args=args)

        if args["env_name"].startswith("mujoco-single"):
            if done:
                train_ep_return.append(train_return_vec[0])
                train_return_vec[0] = 0.0  # empty the bucket
                obs, done = env.reset(), [False] * args["num_envs"]
        elif any(done):
            # NOTE(review): env.reset() is called on the whole vector env, but
            # only the finished sub-envs get their return bucket cleared.
            # Confirm whether reset() restarts every sub-env; if so, the
            # partial returns of the others carry over.
            for _id in np.arange(args["num_envs"])[done]:  # indices of finished sub-envs
                agent.reset(id=_id)
                train_ep_return.append(train_return_vec[_id])
                train_return_vec[_id] = 0.0  # empty the bucket
            obs, done = env.reset(), [False] * args["num_envs"]

        if ep_ts > per_train_ts:
            logging(f"[Train] {args['global_ts']}/{args['total_ts']} | "
                    f"R: {np.mean(train_ep_return):.3f}, Eps: {eps}, Buffer: {len(buffer)}")
            writer.add_scalar('train/mean_reward', np.mean(train_ep_return), args["global_ts"])

            if not args["if_update_per_ts"] and not _if_warmup and len(buffer) >= args["batch_size"]:
                update_result = update_agent(agent=agent, buffer=buffer, args=args)

            if not _if_warmup:
                for k, v in update_result.items():
                    writer.add_scalar(f"update/{k}", v, args["global_ts"])
                logging(f"[Update] {update_result}")
                update_result = dict()

            # roll_out episode
            if session_num % args["eval_freq"] == 0:
                eval_metrics, args = roll_out(agent=agent, env=eval_env, num_episodes=args["eval_num_episodes"],
                                              args=args)
                for k, v in eval_metrics.items():
                    writer.add_scalar(f"eval/{k}", v, args["global_ts"])
                if args["if_visualise"] and args["wandb"]:
                    writer.add_video(path_to_video=args["video_path"], step=args["global_ts"], log_string="eval/video")
                logging(f"[Eval] {args['global_ts']}/{args['total_ts']} | Return: {eval_metrics['ep_return']:.3f}")

            if args["if_save_agent"] and (session_num % args["save_freq"] == 0):
                agent.save(filename=f"./weights/{args['save_dir_name']}/session-{session_num}")
            ep_ts = 0
            session_num += 1

    # After all episodes: the final checkpoint is always written. Its directory
    # is only created above when if_save_agent is set, so ensure it exists.
    final_dir = f"./weights/{args['save_dir_name']}"
    os.makedirs(final_dir, exist_ok=True)
    agent.save(filename=f"{final_dir}/final")


if __name__ == "__main__":
    from gen_rl.commons.args import get_all_args

    # Parse the CLI into a plain dict (train() indexes it by key), seed the
    # RNGs via set_randomSeed, then start training. The module-level name
    # ``args`` is kept on purpose: update_agent() may read it.
    args = vars(get_all_args())
    set_randomSeed(seed=args["seed"])
    train(args=args)
