import os
import sys
import argparse
import numpy as np
import tqdm
import torch
from matplotlib import pyplot as plt
import json
import cv2

# Path layout: this script sits one directory below the repository root.
# Experiment outputs live under <root>/results, datasets under <root>/data.
FILE_DIR = os.path.split(os.path.realpath(__file__))[0]
ROOT_DIR = os.path.split(FILE_DIR)[0]
RES_DIR = os.path.join(ROOT_DIR, "results")
MAIN_DATA_DIR = os.path.join(ROOT_DIR, "data")

# Make the script directory and the repository root importable (in that
# order) so the `modules.*` imports below resolve from any working directory.
sys.path.extend([FILE_DIR, ROOT_DIR])

from modules.velap.rl.td3_bc import TD3BC
from modules.velap.dynamics.model_dynamics import DynamicsModel
from modules.velap.rl.model_encoder import StateEncoder
from modules.velap.action_sampler.model_action_sampler import ActionSampler
from modules.velap.planner.planner import Planner
from modules.velap.controller.controller import Controller


from modules.envs.get_eval_env import get_eval_env


def _set_td3bc_eval(policy):
    """Switch a TD3BC policy's critic, every module in ``critic.models`` and its actor to eval mode."""
    policy.critic.eval()
    for member in policy.critic.models:
        member.eval()
    policy.actor.eval()


def main():
    """Run the VELAP planner/controller evaluation for ``args.exp_name``.

    Configuration comes from the module-level ``args`` namespace created in
    the ``__main__`` block. Several training-time settings (``z_dim``,
    ``env``, ``discount``, ...) are copied onto ``args`` from the saved
    ``params.json`` files before the evaluation environment is created.

    Models are loaded from ``RES_DIR/<exp_name>/{encoder,action_sampler,
    dynamics,policy_high}``. Up to ``args.n_eval`` episodes are run, each with
    a newly built Planner and Controller. After every episode the per-episode
    results, the summary statistics and the arguments are (re)written to
    ``RES_DIR/<exp_name>/velap`` so partial results survive an interruption.
    The final summary ``[n_success, success_rate, mean_steps]`` is printed
    (``None`` if no episode was run).
    """
    exp_dir = os.path.join(RES_DIR, args.exp_name)
    device = args.device

    # Training-time hyper-parameters stored next to the sub-models.
    with open(os.path.join(exp_dir, "encoder", "params.json")) as f:
        params = json.load(f)

    with open(os.path.join(exp_dir, "action_sampler", "params.json")) as f:
        action_sampler_params = json.load(f)

    args.z_dim = params["z_dim"]
    # spiral_env and obstacle_env use 2-D actions; every other env uses 4-D.
    args.action_dim = 2 if params["env"] in ("spiral_env", "obstacle_env") else 4
    args.max_action = 1.0
    args.dataset = params["dataset_train"]
    args.env = params["env"]
    args.n_frames = params["n_frames"]
    args.channel_dim = params["channel_dim"]
    args.discount = params["discount"]
    args.n_qs = params["n_qs"]
    args.vae_latent_dim = action_sampler_params["vae_latent_dim"]

    # Every checkpoint is loaded with map_location so weights saved on one
    # device (e.g. GPU) can be evaluated on another (e.g. --device cpu).
    encoder = StateEncoder(args.z_dim, args.channel_dim, args.n_frames).to(device)
    encoder.load_state_dict(
        torch.load(os.path.join(exp_dir, "encoder", "model", "model_encoder"),
                   map_location=device),
        strict=True)
    encoder.eval()

    action_sampler = ActionSampler(state_dim=args.z_dim,
                                   action_dim=args.action_dim,
                                   latent_dim=args.vae_latent_dim,
                                   device=device).to(device)
    action_sampler.load_state_dict(
        torch.load(os.path.join(exp_dir, "action_sampler", "model", "action_sampler"),
                   map_location=device))
    action_sampler.eval()

    model_dyn = DynamicsModel(args.z_dim, args.action_dim).to(device)
    model_dyn.load_state_dict(
        torch.load(os.path.join(exp_dir, "dynamics", "model", "model_dynamics"),
                   map_location=device))
    model_dyn.eval()

    # High-level policy: no goal input (goal_dim=0).
    policy_high = TD3BC(z_dim=args.z_dim,
                        action_dim=args.action_dim,
                        goal_dim=0,
                        device=device)
    policy_high.load(os.path.join(exp_dir, "policy_high", "model"))
    _set_td3bc_eval(policy_high)

    # Low-level policy: takes a goal of size z_dim. Its weights live in the
    # encoder directory with the "_low" suffix. The device is passed
    # explicitly (like policy_high) so it matches the encoder's latents.
    policy_low = TD3BC(z_dim=args.z_dim,
                       action_dim=args.action_dim,
                       goal_dim=args.z_dim,
                       n_qs=args.n_qs,
                       device=device)
    policy_low.load(os.path.join(exp_dir, "encoder", "model"), type="_low")
    _set_td3bc_eval(policy_low)

    # Distance thresholds: multiples of element [-3] of the stored transition
    # distance percentiles (which percentile that is depends on how the file
    # was produced -- TODO confirm).
    dist_stats = np.load(os.path.join(exp_dir, "transition_eucl_dist_percentiles.npy"))
    d_neigh = dist_stats[-3] * args.d_neigh_n_step
    d_discard = dist_stats[-3] * args.d_discard_n_step

    # Q-value thresholds are given on the CLI as step counts n and converted
    # to discount**n (presumably the value of a goal reached after n steps --
    # verify against the training reward definition).
    tau_q_neigh = args.discount ** args.q_neigh_n_step
    tau_q_discard = args.discount ** args.q_discard_n_step
    tau_q_goal = args.discount ** args.q_goal_n_step
    tau_expand_q_min = args.discount ** args.q_min_n_step
    tau_expand_q_std = 1.0 - args.discount ** args.q_max_std_n_step
    tau_q_stop_plan = args.discount ** args.q_stop_plan_n_step
    tau_q_next_wp = args.discount ** args.q_next_wp_n_step

    reset_info = {"mode": "eval", "init_mode": "", "context_info": None, "context_params": None}
    env, max_steps = get_eval_env(args)

    out_dir = os.path.join(RES_DIR, args.exp_name, "velap")
    os.makedirs(out_dir, exist_ok=True)

    success = []
    steps = []
    eval_stats = None  # stays None when args.n_eval <= 0

    for _ in range(args.n_eval):
        # A new Planner and Controller are built for every episode.
        planner = Planner(d_discard=d_discard,
                          d_neigh=d_neigh,
                          tau_q_neigh=tau_q_neigh,
                          tau_q_discard=tau_q_discard,
                          tau_q_goal=tau_q_goal,
                          dynamics=model_dyn,
                          policy_low=policy_low,
                          policy_high=policy_high,
                          tau_expand_q_min=tau_expand_q_min,
                          tau_expand_q_std=tau_expand_q_std,
                          action_sampler=action_sampler,
                          t_sparse=1.0,
                          t_value=1.0,
                          p_sample_node_sparse=args.p_sample_node_sparse,
                          p_sample_node_rand=args.p_sample_node_rand,
                          p_sample_action_rand=args.p_sample_action_rand,
                          p_sample_action_model=args.p_sample_action_model,
                          n_iter=args.n_iter,
                          n_sim=args.n_sim,
                          discount=args.discount,
                          use_q_std_reject=args.use_q_std_reject,
                          batch_size=32)

        ctrl = Controller(planner,
                          policy_low,
                          policy_high,
                          args.replan_every,
                          tau_q_stop_plan,
                          tau_q_next_wp)

        state = env.reset(reset_info=reset_info)
        ctrl.reset()
        total_steps = 0
        while True:
            # uint8 frame -> float32 in [0, 1]; transpose(2, 0, 1) moves the
            # last axis (presumably channels) first; unsqueeze adds a batch dim.
            frame = np.array(state["video"], dtype=np.uint8).astype(np.float32)
            frame = frame.transpose(2, 0, 1) / 255.0
            video_t = torch.from_numpy(frame).to(device).unsqueeze(0)

            # Inference only: no autograd graph is needed, and one attached to
            # z_t would keep encoder activations alive for as long as the
            # controller/planner hold a reference to the latent.
            with torch.no_grad():
                z_t = encoder(video_t)

            action = ctrl.get_action(z_t)

            state, _, done, _ = env.step(action.copy())
            total_steps += 1

            if done:
                success.append(1.0)
                steps.append(total_steps)
                break

            # NOTE: ">" allows max_steps + 1 steps; kept so results remain
            # comparable with earlier evaluations.
            if total_steps > max_steps:
                success.append(0.0)
                steps.append(total_steps)
                break

        eval_all = [success, steps]
        # Plain floats keep the JSON and the printed summary free of numpy reprs.
        eval_stats = [float(np.sum(success)), float(np.mean(success)), float(np.mean(steps))]

        np.save(os.path.join(out_dir, "%seval.npy" % args.name), eval_all)
        # The stats file stores the summary statistics, not the raw lists.
        np.save(os.path.join(out_dir, "%seval_stats.npy" % args.name), eval_stats)

        # Serialize a copy of the arguments rather than mutating `args`.
        run_params = dict(vars(args))
        run_params["eval_success"] = eval_stats
        with open(os.path.join(out_dir, "%sparams.json" % args.name), 'w') as json_file:
            json.dump(run_params, json_file, sort_keys=True, indent=2)

    print("final stats ", eval_stats)


if __name__ == '__main__':
    # Command-line options as (flag, type, default). They are registered in
    # this order, which is also the order argparse lists them in --help.
    cli_options = [
        # Run settings
        ("--exp_name", str, "bench/obstacle_env_0_0"),
        ("--name", str, ""),
        ("--device", str, "cuda"),
        ("--replan_every", int, 25),
        ("--n_eval", int, 100),
        ("--n_sim", int, 10),
        ("--n_iter", int, 250),
        ("--render", int, 1),
        # Planner parameters
        ("--d_neigh_n_step", int, 5),
        ("--d_discard_n_step", int, 2),
        ("--q_discard_n_step", float, 2),
        ("--q_neigh_n_step", float, 5),
        ("--q_max_std_n_step", int, 1),
        ("--q_min_n_step", int, 10),
        ("--q_goal_n_step", int, 3),
        ("--use_q_std_reject", int, 1),
        ("--p_sample_node_sparse", float, 0.6),
        ("--p_sample_node_rand", float, 0.1),
        ("--p_sample_action_rand", float, 0.0),
        ("--p_sample_action_model", float, 0.7),
        # Controller parameters
        ("--q_stop_plan_n_step", int, 3),
        ("--q_next_wp_n_step", int, 5),
    ]

    parser = argparse.ArgumentParser()
    for flag, arg_type, default in cli_options:
        parser.add_argument(flag, type=arg_type, default=default)

    # main() reads this module-level `args` namespace.
    args = parser.parse_args()

    main()
