import os
import torch
import numpy as np

from models.low_level_embedding.low_model import Low_Model
from models.baseline.bc_model import BC_Model

from configs import get_test_config, get_task_parameter
from funcs import get_setting, get_rl_hyperparameter, load_model

import gymnasium as gym
import mani_skill.envs

# Torch device that models and observation tensors are moved to:
# CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Width of the zero-padded observation buffer built at every rollout step.
MAX_STATE_DIM = 54
# Upper bound on the number of environment steps in one rollout episode.
MAX_STEPS = 50


def find_closest_play(curr_play, play_list):
    """Return the entry of ``play_list`` that is nearest to ``curr_play``.

    Args:
        curr_play: Query play vector. May be a torch tensor or anything
            ``np.asarray`` accepts. It is flattened to a single row, so both
            ``(D,)`` and ``(1, D)`` inputs are handled.
        play_list: Torch tensor holding one candidate play per row
            (``(B, D)``).

    Returns:
        The row of ``play_list`` (NumPy array of shape ``(D,)``) with the
        smallest Euclidean distance to ``curr_play``.
    """
    if isinstance(curr_play, torch.Tensor):
        curr_play = curr_play.detach().cpu().numpy()
    query = np.asarray(curr_play).reshape(1, -1)
    candidates = play_list.detach().cpu().numpy()

    # NOTE: the second positional argument of np.linalg.norm is `ord`, not
    # `axis`. Passing a bare 1 collapses the whole (B, D) difference matrix
    # into one scalar, which makes argmin always pick row 0. The reduction
    # must be requested per row via `axis=1`.
    distances = np.linalg.norm(candidates - query, axis=1)
    return candidates[np.argmin(distances)]

def _format_task_summary(task_name, result_list, reward_list):
    """Build the one-line success/reward summary for the episodes run so far."""
    n_episodes = len(result_list)
    # Guard the empty case: np.mean([]) is nan and emits a RuntimeWarning.
    n_success = int(np.sum(result_list)) if n_episodes else 0
    success_pct = np.mean(result_list) * 100.0 if n_episodes else 0.0
    mean_reward = np.mean(reward_list) if n_episodes else 0.0
    return "[Task: {}] Success Rate {}/{} ({:.1f}%) | Reward {:.2f} ".format(
        task_name, n_success, n_episodes, success_pct, mean_reward
    )


def run_test(test_env, test_models, task_name, n_demos, config):
    """Run ``n_demos`` evaluation episodes on one task and print the statistics.

    Args:
        test_env: Environment handed to ``rollout_episode``; it is closed
            before this function returns, even when a rollout raises.
        test_models: Model pair forwarded unchanged to ``rollout_episode``.
        task_name: Task label used in the printed report.
        n_demos: Number of episodes to roll out.
        config: Configuration object forwarded to ``rollout_episode``.
    """
    print()
    print("** RESULT: {}".format(task_name))

    result_list, reward_list = [], []
    try:
        for n in range(n_demos):
            # Success rate over the episodes finished so far (0.0 before the first).
            running_rate = np.mean(result_list) if result_list else 0.0
            result_, reward_ = rollout_episode(
                test_env, test_models, [n + 1, running_rate], config
            )
            result_list.append(result_)
            reward_list.append(reward_)
            # '\r' keeps the running summary on a single, overwritten line.
            print(_format_task_summary(task_name, result_list, reward_list), end="\r")

        print(_format_task_summary(task_name, result_list, reward_list))
        print()
    finally:
        # Release the environment even if an episode failed part-way through.
        test_env.close()

def rollout_episode(test_env, test_models, episode_info, config):
    """Roll out a single evaluation episode and report its outcome.

    Args:
        test_env: Environment providing ``reset``, ``step`` and ``render``.
            Observations must support ``.to(device)``; reward, both done
            flags and ``info["success"]`` are read through ``.tolist()[0]``.
        test_models: ``(playbook_model, bc_model)``. ``bc_model`` is only
            queried when ``config.n_weights`` sums to more than one.
        episode_info: ``[episode_index, running_success_rate]`` where the
            index is 1-based and the rate covers the previous episodes.
        config: Object providing ``n_weights`` and ``window_size``.

    Returns:
        ``(succeeded, total_reward)``: a bool taken from the last step's
        ``info["success"]`` and the reward summed over the episode.
    """
    playbook_model, bc_model = test_models
    epi_idx, success_rate = episode_info

    num_bars = 50
    curr_play = None
    total_reward = 0.0

    # Defined up front so the final report is valid even if the episode
    # terminates on its very first step (or the loop never runs).
    success = False
    steps_done = 0
    progress_, percent_ = 0, 0.0

    obs, _ = test_env.reset()
    obs = obs.to(device)
    for i in range(MAX_STEPS):
        # Zero-pad the observation up to the fixed model input width.
        obs_pad = torch.zeros((1, MAX_STATE_DIM), device=device)
        obs_dim = obs.size()[-1]
        obs_pad[:, :obs_dim] = obs

        if curr_play is None:
            # (Re-)select the play that conditions the low-level policy.
            play_set = playbook_model.get_playbook()
            if int(np.sum(config.n_weights)) > 1:
                curr_play_raw = bc_model.get_action(obs_pad)
                curr_play = find_closest_play(curr_play_raw, play_set)
                curr_play = torch.tensor(curr_play).float().to(device).unsqueeze(0)
            else:
                curr_play = play_set[0].unsqueeze(0)

        (move, _), _ = playbook_model.get_action_from_weight(obs_pad, curr_play)
        curr_action = np.clip(move.detach().cpu().numpy()[0], -1.0, 1.0)

        obs, reward, done0, done1, info = test_env.step(curr_action)
        test_env.render()
        obs = obs.to(device)
        success = info["success"].tolist()[0]
        total_reward += reward.tolist()[0]

        # Update the progress figures before the termination check so that
        # the final report always reflects the step that was just executed.
        steps_done = i + 1
        progress_ = int(steps_done / MAX_STEPS * num_bars)
        percent_ = steps_done / MAX_STEPS * 100

        if done0.tolist()[0] or done1.tolist()[0]:
            break

        # Force a new play selection at the end of every window.
        if steps_done % config.window_size == 0:
            curr_play = None

        print('  [EPISODE{:03d}][Progress {}{}:{:.1f}%] Steps: {} / {}    '\
            .format(epi_idx, '█'*progress_, ' '*(num_bars-progress_), percent_, steps_done, MAX_STEPS), end='\r')

    result_ = bool(success)
    result_str = "Success" if result_ else "Fail"
    # Fold this episode's outcome into the running success rate.
    success_rate = (success_rate*(epi_idx-1)+result_)/epi_idx
    print('  [EPISODE{:03d}][Progress {}{}:{:.1f}%] Steps: {} / {} | Result: {} | Success: {:.2f}%   '\
        .format(epi_idx, '█'*progress_, ' '*(num_bars-progress_), percent_, steps_done, MAX_STEPS, result_str, success_rate*100.0))
    return result_, total_reward

if __name__ == "__main__":
    # Script entry point: load the pre-trained playbook (and, when several
    # weights are configured, the high-level BC model), then evaluate them on
    # every ManiSkill task selected by `config.task`.
    config = get_test_config()
    setting = get_setting(0)

    # Keyword -> ManiSkill task ids, checked in order; the first keyword found
    # in `config.task` wins.
    # NOTE(review): "all" currently evaluates only PullCube-v1. The full set
    # would be ["PushCube-v1", "PullCube-v1", "PokeCube-v1",
    # "LiftPegUpright-v1"] -- confirm which one is intended.
    task_keywords = (
        ("all", ["PullCube-v1"]),
        ("push", ["PushCube-v1"]),
        ("pull", ["PullCube-v1"]),
        ("poke", ["PokeCube-v1"]),
        ("lift", ["LiftPegUpright-v1"]),
    )
    mani_task_list = next(
        (tasks for keyword, tasks in task_keywords if keyword in config.task),
        None,
    )
    if mani_task_list is None:
        # Fail early with a clear message instead of a NameError after the
        # models have already been loaded.
        raise ValueError(
            "Unsupported task '{}': expected it to contain one of {}".format(
                config.task, [keyword for keyword, _ in task_keywords]
            )
        )

    state_dim, action_dim, input_type, consider_gripper, tt_max_length = get_task_parameter(config.task)
    play_dim = int(np.sum(config.n_subpols))

    load_name = config.loadname
    load_path = "./results/{}/{}".format(config.task,load_name)
    model_name = "{}_EMB_{}_best".format(load_name,config.task)

    # define playbook model
    emb_model = Low_Model(
        input_type, state_dim, action_dim,
        config.z_dep_dim, config.z_ind_dim, config.window_size,
        config.n_subpols, config.n_weights,
        consider_gripper, config.use_newdata, False, True, setting
    ).to(device)
    load_model(emb_model, model_name, load_path)
    emb_model.eval()
    print("* Successfully Load Pre-Trained Playbook.")

    # The BC model is only needed when there is more than one weight to
    # choose from; otherwise the rollout uses the first playbook entry.
    n_weights = int(np.sum(config.n_weights))
    if n_weights > 1:
        bc_model = BC_Model(
            input_type, state_dim, play_dim, consider_gripper, config, setting
        ).to(device)

        bc_load_path = os.path.join(config.logbase, config.task, config.loadname, "step0/high_bc")
        bc_model_name = "best_seed{}".format(config.seed)
        load_model(bc_model, bc_model_name, bc_load_path)
        bc_model.eval()
        print("* Successfully Load Pre-Trained BC Model.")
    else:
        bc_model = None

    for curr_task in mani_task_list:
        env = gym.make(
            curr_task,
            num_envs=1,
            obs_mode="state",
            reconfiguration_freq=1,
            control_mode="pd_ee_delta_pos",
            render_mode="rgb_array",  # alternative: "human"
        )

        # run_test closes the environment when it is done.
        test_model = [emb_model, bc_model]
        run_test(env, test_model, curr_task, config.eval_episodes, config)
    

    
