# rlbench
from rlbench.observation_config import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class

from rlbench.environment import Environment
from rlbench.backend.exceptions import *
import argparse
import os
from pathlib import Path

import IPython
# Shorthand for IPython.embed (interactive debugging: call e() to open a shell
# at that point). Not referenced elsewhere in this file's visible code.
e = IPython.embed
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import cv2

def _write_rgb_frames(frames, fps, video_path):
    """Encode an iterable of HxWx3 RGB frames into an MP4; return frames written.

    The writer is opened lazily from the first frame's height/width and is
    always released, even if building or encoding a frame raises.
    """
    out = None
    written = 0
    try:
        for frame in frames:
            if out is None:
                h, w = frame.shape[:2]
                out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                if not out.isOpened():
                    print(f'Failed to open video writer for: {video_path}')
                    return 0
            # OpenCV expects BGR channel order; fancy indexing yields a contiguous copy.
            out.write(frame[:, :, [2, 1, 0]])
            written += 1
    finally:
        if out is not None:
            out.release()
    return written


def save_videos(video, dt, video_path=None):
    """Write multi-camera RGB frames to an MP4 file, tiling cameras left to right.

    Args:
        video: Either a list of per-timestep dicts ``{cam_name: HxWx3 image}``
            (camera order taken from the first entry's keys) or a dict
            ``{cam_name: sequence of HxWx3 frames}``. Images are assumed to be
            RGB uint8 as OpenCV's writer needs -- TODO confirm for new callers.
        dt: Seconds between frames; the frame rate is ``round(1 / dt)``,
            clamped to at least 1 fps.
        video_path: Destination file (str or path-like). Its parent directory
            is created if missing. May be None only when there is nothing to write.

    Raises:
        ValueError: if ``dt`` is not positive, if the cameras in dict form have
            different frame counts, or if ``video_path`` is None while there
            are frames to write.
        TypeError: if ``video`` is neither a list nor a dict.
    """
    if dt <= 0:
        raise ValueError(f'dt must be positive, got {dt!r}')
    fps = max(1, round(1 / dt))

    if isinstance(video, list):
        cam_names = list(video[0].keys()) if video else []
        n_frames = len(video) if cam_names else 0
        # axis=1 is the width axis of an HxWxC image -> cameras side by side.
        frames = (np.concatenate([step[cam] for cam in cam_names], axis=1) for step in video)
    elif isinstance(video, dict):
        cam_names = list(video.keys())
        lengths = {len(cam_frames) for cam_frames in video.values()}
        if len(lengths) > 1:
            raise ValueError(f'All cameras must have the same number of frames, got {sorted(lengths)}')
        n_frames = lengths.pop() if lengths else 0
        # Build each tiled frame on demand instead of concatenating the whole video.
        frames = (np.concatenate([video[cam][t] for cam in cam_names], axis=1) for t in range(n_frames))
    else:
        raise TypeError(f'video must be a list or dict, got {type(video).__name__}')

    if n_frames == 0:
        print(f'No frames to save, skipping: {video_path}')
        return
    if video_path is None:
        raise ValueError('video_path is required when there are frames to save')

    video_path = os.fspath(video_path)
    save_dir = os.path.dirname(video_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if _write_rgb_frames(frames, fps, video_path):
        print(f'Saved video to: {video_path}')


def make_sim_env(task_name, onscreen_render, robot_name, img_size=(224, 224)):
    """Launch an RLBench simulator and load a single task.

    Args:
        task_name: RLBench task file name, e.g. ``"beat_the_buzz"``; resolved
            with ``task_file_to_task_class``.
        onscreen_render: Show the simulator GUI when True; run headless otherwise.
        robot_name: Value passed as ``robot_setup`` to ``Environment`` (e.g. ``"panda"``).
        img_size: Resolution assigned to the wrist, overhead and front cameras'
            ``image_size``. The default of 224x224 is square, so axis order does
            not matter; for non-square sizes, use the order RLBench expects.

    Returns:
        ``(task_env, rlbench_env)``: the task environment used for
        ``reset``/``step``, and the underlying ``Environment``. The caller must
        call ``shutdown()`` on the ``Environment`` when done.
    """
    camera_size = list(img_size)

    # Start with every observation disabled, then enable the three cameras the
    # evaluation uses plus all low-dimensional observations.
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    for camera in (obs_config.wrist_camera, obs_config.overhead_camera, obs_config.front_camera):
        camera.set_all(True)
        camera.image_size = camera_size
    obs_config.set_all_low_dim(True)
    obs_config.wrist_camera.depth_in_meters = False

    # Resolve the task before launching so a bad task name fails fast without
    # starting the simulator.
    task_class = task_file_to_task_class(task_name)

    rlbench_env = Environment(
        # Joint-position arm control followed by a discrete gripper action.
        action_mode=MoveArmThenGripper(JointPosition(), Discrete()),
        obs_config=obs_config,
        headless=not onscreen_render,
        robot_setup=robot_name)
    rlbench_env.launch()
    try:
        rlbench_env._pyrep.step_ui()
        # get_task(Type[Task]) -> TaskEnvironment (includes the scene).
        task_env = rlbench_env.get_task(task_class)
    except BaseException:
        # Do not leave an orphaned simulator process behind if setup fails.
        rlbench_env.shutdown()
        raise
    return task_env, rlbench_env
"""
task_env, rlbench_env = make_sim_env("beat_the_buzz", False, "panda")
rlbench_env.shutdown()
"""

def main(args):
    """Evaluate the remote policy on one RLBench task and print the summary.

    Args:
        args: Dict of command-line options (as from ``vars(parse_args())``).
            Reads ``task_name``, ``max_timesteps`` and ``save_dir``. ``seed`` is
            optional and defaults to 42 (the value previously hard-coded).
    """
    np.set_printoptions(linewidth=300)

    task_name = args['task_name']
    robot_name = "panda"
    max_timesteps = args['max_timesteps']
    onscreen_render = False
    seed = args.get('seed', 42)

    # Policy server is reached over a websocket at 127.0.0.1:8005.
    policy = _websocket_client_policy.WebsocketClientPolicy("127.0.0.1", 8005)
    save_dir = Path(args["save_dir"]).expanduser()

    success_rate, avg_return = eval_bc(
        policy,
        task_name,
        max_timesteps,
        robot_name,
        onscreen_render,
        save_episode=True,
        save_dir=save_dir,
        seed=seed,
    )
    results = [task_name, success_rate, avg_return]
    print(results)


def _run_episode(env, policy, max_timesteps, variation, env_max_reward):
    """Roll out one episode with action chunking; return ``(rewards, frames)``.

    The policy is queried only when the previous action chunk is used up. The
    episode ends early as soon as a step's reward equals ``env_max_reward``.
    """
    env.set_variation(variation)
    descriptions, obs = env.reset()
    prompt = str(descriptions[0])

    frames = []  # per-step camera images for the episode video
    rewards = []
    action_chunk = []
    for _ in range(max_timesteps):
        frames.append({'front': obs.front_rgb, 'head': obs.overhead_rgb, 'wrist': obs.wrist_rgb})
        if len(action_chunk) == 0:
            state = np.append(obs.joint_positions, obs.gripper_open)  # joints + gripper (7 + 1 for panda)
            input_dict = {
                "observation/image": obs.overhead_rgb,
                "observation/wrist_image": obs.wrist_rgb,
                "observation/state": state,
                "prompt": prompt,
                "cot_info": "template"
            }
            action_chunk = policy.infer(input_dict)['actions']
            print("infer done")
            if len(action_chunk) == 0:
                raise RuntimeError("policy.infer returned an empty action chunk")
        action, action_chunk = action_chunk[0], action_chunk[1:]
        print("action done")

        obs, reward, _terminate = env.step(action)
        if reward is not None:
            rewards.append(reward)
        if reward == env_max_reward:
            break
    return rewards, frames


def eval_bc(
    policy,
    task_name,
    max_timesteps,
    robot_name,
    onscreen_render,
    save_episode=True,
    num_verification=50,
    variation=0,
    save_dir: Path | str | None = None,
    seed: int = 42,
):
    """Run ``num_verification`` rollouts of ``policy`` on one RLBench task.

    An episode counts as a success when one of its rewards equals
    ``env_max_reward`` (1). Per-episode videos and a text summary are written to
    ``save_dir``, or to ``./rlbench_eval_outputs`` under the CWD if none is given.

    Args:
        policy: Object whose ``infer(dict)`` returns a dict with an ``'actions'``
            sequence; the actions are executed one per step.
        task_name: RLBench task file name.
        max_timesteps: Step budget per episode.
        robot_name: Robot setup passed to ``make_sim_env``.
        onscreen_render: Show the simulator GUI when True.
        save_episode: Write one MP4 per episode when True.
        num_verification: Number of rollouts; must be at least 1.
        variation: Task variation index set before every reset.
        save_dir: Output directory for videos and the result file.
        seed: Seed for NumPy's global RNG, applied before the simulator launches.

    Returns:
        ``(success_rate, avg_return)`` over all rollouts.

    Raises:
        ValueError: if ``num_verification`` is less than 1.
    """
    if num_verification < 1:
        raise ValueError(f'num_verification must be >= 1, got {num_verification}')
    np.random.seed(seed)

    env, rlbench_env = make_sim_env(task_name, onscreen_render, robot_name)
    env_max_reward = 1  # reward value treated as task success in this script

    num_rollouts = num_verification
    episode_returns = []
    highest_rewards = []
    try:
        resolved_save_dir = Path(save_dir).expanduser() if save_dir else Path.cwd() / "rlbench_eval_outputs"
        resolved_save_dir.mkdir(parents=True, exist_ok=True)

        for rollout_id in range(num_rollouts):
            rewards, image_list = _run_episode(env, policy, max_timesteps, variation, env_max_reward)
            episode_returns.append(sum(rewards))
            # default=0 keeps a zero-step episode (max_timesteps == 0) from crashing max().
            episode_highest_reward = max(rewards, default=0)
            highest_rewards.append(episode_highest_reward)

            if save_episode:
                video_path = resolved_save_dir / f"video_{task_name}_{rollout_id}_{episode_highest_reward == env_max_reward}.mp4"
                save_videos(image_list, 0.1, video_path=str(video_path))

        highest = np.array(highest_rewards)
        success_rate = np.mean(highest == env_max_reward)
        avg_return = np.mean(episode_returns)
        summary_str = f'\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n'
        for r in range(env_max_reward + 1):
            more_or_equal_r = int((highest >= r).sum())
            more_or_equal_r_rate = more_or_equal_r / num_rollouts
            summary_str += f'Reward >= {r}: {more_or_equal_r}/{num_rollouts} = {more_or_equal_r_rate*100}%\n'
        print(summary_str)

        # The result file name carries the share of episodes reaching the max reward.
        top_rate = int((highest >= env_max_reward).sum()) / num_rollouts
        result_file_path = resolved_save_dir / f'result_{task_name}({top_rate*100}%).txt'
        with open(result_file_path, 'w', encoding='utf-8') as f:
            f.write(summary_str)
            f.write('\n\n')
            f.write(repr(highest_rewards))
    finally:
        # Always stop the simulator, even if a rollout or file write fails.
        rlbench_env.shutdown()

    return success_rate, avg_return


if __name__ == '__main__':
    # Command-line entry point: parse options and pass them to main() as a dict.
    # Note: --batch_size is accepted but main() in this file does not read it.
    arg_parser = argparse.ArgumentParser()
    cli_options = (
        # (flag, value type, default, help text)
        ('--task_name', str, "beat_the_buzz", 'task_name'),
        ('--batch_size', int, 1, 'batch_size'),
        ('--seed', int, 42, 'seed'),
        ('--max_timesteps', int, 400, 'max_timesteps'),
        # Help text means: "directory for saving videos and results".
        ('--save_dir', str, "./outputs/rlbench_eval", '保存视频与结果的目录'),
    )
    for flag, value_type, default_value, help_text in cli_options:
        arg_parser.add_argument(flag, action='store', type=value_type, default=default_value, help=help_text)
    main(vars(arg_parser.parse_args()))