import numpy as np
from collections import OrderedDict
import imageio
import time
import matplotlib.pyplot as plt
from robomimic.envs.env_base import EnvBase
from robomimic.envs.wrappers import EnvWrapper
from omnigibson.object_states.contact_bodies import ContactBodies
import os
from typing import Union, Optional, Tuple, Dict, List
from brs_algo.utils import any_concat, any_to_torch_tensor, any_slice
import torch
from collections import deque

# Point-cloud settings. Neither name is referenced in the visible part of this
# file; presumably they are consumed by the observation pipeline elsewhere
# -- TODO confirm.
num_pcd_points = 4096
pad_pcd_if_needed = True
# Threshold on the summed gripper joint positions: get_normalized_obs reports
# the gripper as closed (+1) when the sum is <= this value and open (-1) otherwise.
gripper_half_width = 0.05

# Length of the observation history window that run_rollout feeds to the policy.
num_latest_obs = 2
# Number of actions executed from each predicted action trajectory before
# run_rollout queries the policy again.
num_deployed_actions=8

# Joint limits used to map joint positions to and from the [-1, 1] range.
# Presumably radians taken from the robot URDF range limits -- TODO confirm.
torso_joint_high = np.array([1.8326, 2.5307, 1.8326, 3.0543])
torso_joint_low = np.array([-1.1345, -2.7925, -2.0944, -3.0543])
left_arm_joint_high = np.array([2.8798, 3.2289, 0, 2.8798, 1.6581, 2.8798])
left_arm_joint_low = np.array([-2.8798, 0, -3.3161, -2.8798, -1.6581, -2.8798])
right_arm_joint_high = np.array([2.8798, 3.2289, 0, 2.8798, 1.6581, 2.8798])
right_arm_joint_low = np.array([-2.8798, 0, -3.3161, -2.8798, -1.6581, -2.8798])

class OGRollout:
    """Container for the point-cloud cropping ranges used during rollouts.

    Each range is a ``(min, max)`` pair along one axis.
    """

    # The constructor was previously spelled ``__init``; inside a class body
    # that name is mangled to ``_OGRollout__init`` and is never invoked by
    # ``OGRollout(...)``, so the ranges were never stored.
    def __init__(self,
        pcd_x_range: Tuple[float, float],
        pcd_y_range: Tuple[float, float],
        pcd_z_range: Tuple[float, float],
               ):
        """Store the x/y/z point-cloud ranges on the instance.

        Args:
            pcd_x_range: (min, max) bounds along x.
            pcd_y_range: (min, max) bounds along y.
            pcd_z_range: (min, max) bounds along z.
        """
        self._pcd_x_range = pcd_x_range
        self._pcd_y_range = pcd_y_range
        self._pcd_z_range = pcd_z_range


def get_normalized_obs(ob_dict, 
                       device,
                       pcd_x_range,
                       pcd_y_range,
                       pcd_z_range,
                       mobile_base_vel_action_max,
                       mobile_base_vel_action_min,
                       ):
    """
    Convert a raw environment observation into the nested tensor dict consumed
    by the policy.

    Only the last entry along the leading axis of each observation is used.
    The keys read from ``ob_dict`` are 'combined::color_point_cloud',
    'prop_state', 'gripper_left_qpos', 'gripper_right_qpos', 'base_qvel',
    'joint_qpos' and 'prop_eef_state'.

    Args:
        ob_dict (dict): raw observation returned by the environment.
        device: device passed to ``any_to_torch_tensor`` for every output tensor.
        pcd_x_range, pcd_y_range, pcd_z_range: (min, max) pairs used to rescale
            the point-cloud xyz coordinates to [-1, 1].
        mobile_base_vel_action_max, mobile_base_vel_action_min: bounds used to
            rescale the base velocity to [-1, 1] (the result is clipped).

    Returns:
        dict: nested dict with groups "pointcloud", "qpos", "odom",
            "multi_view_cameras", "link_poses" and "object_states". Every
            tensor gets two leading singleton axes added.
    """

    def _rescale(values, low, high):
        # Affine map that sends `low` to -1 and `high` to +1.
        return 2 * (values - low) / (high - low) - 1

    def _as_bt_tensor(values):
        # float32 tensor on `device` with two leading singleton axes (B=1, T=1).
        tensor = any_to_torch_tensor(values, device=device, dtype=torch.float32)
        return tensor.unsqueeze(0).unsqueeze(0)

    def _gripper_state(qpos_key):
        # 1 = gripper closed, -1 = gripper opened; decided by comparing the
        # summed joint positions of the latest frame against gripper_half_width.
        latest_qpos = np.array(ob_dict[qpos_key])[-1, :]
        opening = np.array(latest_qpos).sum()[np.newaxis]
        return (opening <= gripper_half_width).astype(np.float32) * 2 - 1

    # Point cloud: the first three channels are treated as rgb and the rest as
    # xyz. Only xyz is rescaled here; rgb is passed through unchanged.
    latest_cloud = ob_dict['combined::color_point_cloud']
    cloud_rgb = np.array(latest_cloud[-1, :, :3])
    cloud_xyz = np.array(latest_cloud[-1, :, 3:])
    xyz_low = np.array([pcd_x_range[0], pcd_y_range[0], pcd_z_range[0]])
    xyz_high = np.array([pcd_x_range[1], pcd_y_range[1], pcd_z_range[1]])
    cloud_xyz = _rescale(cloud_xyz, xyz_low, xyz_high)

    # No camera views are provided to the policy.
    camera_views = {}

    # Joint positions, sliced out of the proprioceptive state and rescaled with
    # the module-level joint limits.
    proprio = ob_dict['prop_state']
    left_arm_norm = _rescale(
        proprio[-1, 7:13], left_arm_joint_low, left_arm_joint_high
    ).astype(np.float32)
    right_arm_norm = _rescale(
        proprio[-1, 13:19], right_arm_joint_low, right_arm_joint_high
    ).astype(np.float32)
    torso_norm = _rescale(
        proprio[-1, 3:7], torso_joint_low, torso_joint_high
    ).astype(np.float32)

    # Binary gripper states.
    left_grip = _gripper_state('gripper_left_qpos')
    right_grip = _gripper_state('gripper_right_qpos')

    # Base velocity, rescaled and clipped to [-1, 1].
    base_vel = _rescale(
        ob_dict['base_qvel'][-1, :],
        mobile_base_vel_action_min,
        mobile_base_vel_action_max,
    )
    base_vel = np.clip(base_vel, -1, 1)

    # Base odometry: the first six entries of the full joint position vector.
    base_pose = ob_dict['joint_qpos'][-1, 0:6]

    # No object state is reported.
    cup_state = None

    # End-effector poses sliced from the eef proprio state.
    eef_state = ob_dict['prop_eef_state']
    left_eef_pose = np.array(eef_state[-1, 13:20])
    right_eef_pose = np.array(ob_dict['prop_eef_state'][-1, 27:34])

    pointcloud_group = {
        "xyz": _as_bt_tensor(cloud_xyz),  # (B=1, T=1, N, 3)
        "rgb": _as_bt_tensor(cloud_rgb),
    }
    qpos_group = {
        "left_arm": _as_bt_tensor(left_arm_norm),
        "right_arm": _as_bt_tensor(right_arm_norm),
        "torso": _as_bt_tensor(torso_norm),
        "left_gripper": _as_bt_tensor(left_grip),
        "right_gripper": _as_bt_tensor(right_grip),
    }
    odom_group = {
        "base_qpos": _as_bt_tensor(base_pose),
        "base_velocity": _as_bt_tensor(base_vel),
    }
    link_pose_group = {
        "left_eef": _as_bt_tensor(left_eef_pose),
        "right_eef": _as_bt_tensor(right_eef_pose),
    }

    return {
        "pointcloud": pointcloud_group,
        "qpos": qpos_group,
        "odom": odom_group,
        "multi_view_cameras": camera_views,
        "link_poses": link_pose_group,
        # NOTE(review): this is a one-element *set* containing None, not a
        # dict -- kept as-is to preserve behavior; confirm what consumers expect.
        "object_states": {cup_state},
    }


def unnormalize_action(action: Dict[str, np.ndarray], 
                       mobile_base_vel_action_max,
                       mobile_base_vel_action_min,
                       ):
    """
    Map a normalized policy action back to environment units.

    The mobile-base command is clipped to [-1, 1] and rescaled to the supplied
    velocity bounds. Arm and torso commands are rescaled with the module-level
    joint limits (no clipping). Gripper commands are binarized to +/-1.0.

    Args:
        action (dict): normalized action with keys "mobile_base", "left_arm",
            "right_arm", "torso", "left_gripper" and "right_gripper".
        mobile_base_vel_action_max: upper bound of the base velocity command.
        mobile_base_vel_action_min: lower bound of the base velocity command.

    Returns:
        dict: un-normalized action with the same six keys.
    """

    def _from_unit_range(unit_value, low, high):
        # Inverse of the [-1, 1] normalization: -1 -> low, +1 -> high.
        return (unit_value + 1) / 2 * (high - low) + low

    def _binarize(gripper_value):
        # in og, -1, close gripper, 1, open gripper
        if gripper_value > 0:
            return 1.0
        return -1.0

    base_cmd = _from_unit_range(
        np.clip(action["mobile_base"], -1, 1),
        mobile_base_vel_action_min,
        mobile_base_vel_action_max,
    )
    left_arm_cmd = _from_unit_range(
        action["left_arm"], left_arm_joint_low, left_arm_joint_high
    )
    right_arm_cmd = _from_unit_range(
        action["right_arm"], right_arm_joint_low, right_arm_joint_high
    )
    torso_cmd = _from_unit_range(
        action["torso"], torso_joint_low, torso_joint_high
    )
    left_grip_cmd = _binarize(action["left_gripper"])
    right_grip_cmd = _binarize(action["right_gripper"])

    unnormalized = {}
    unnormalized["mobile_base"] = base_cmd
    unnormalized["left_arm"] = left_arm_cmd
    unnormalized["left_gripper"] = left_grip_cmd
    unnormalized["right_arm"] = right_arm_cmd
    unnormalized["right_gripper"] = right_grip_cmd
    unnormalized["torso"] = torso_cmd
    return unnormalized


def action_flatten(action: Dict[str, np.ndarray]):
    """
    Concatenate the per-part action dict into one flat array.

    The parts are laid out in the order mobile_base, torso, left_arm,
    left_gripper, right_arm, right_gripper. Each gripper entry is wrapped in an
    array and given a leading axis before concatenation.
    """
    layout = (
        ("mobile_base", False),
        ("torso", False),
        ("left_arm", False),
        ("left_gripper", True),
        ("right_arm", False),
        ("right_gripper", True),
    )
    pieces = []
    for part_name, is_gripper in layout:
        piece = action[part_name]
        if is_gripper:
            # Gripper commands need a leading axis so they can be concatenated.
            piece = np.array(piece)[np.newaxis]
        pieces.append(piece)
    return np.concatenate(pieces)


def run_rollout(
        policy, 
        env, 
        horizon,
        use_goals=False,
        render=False,
        video_writer=None,
        video_skip=5,
        terminate_on_success=False,
        demo_actions=None,
        check_action_plot=False,
        init_states=None,
        model_type='brs',
        mobile_base_vel_action_max=None,
        mobile_base_vel_action_min=None,
    ):
    """
    Runs a rollout in an environment with the current network parameters.

    Args:
        policy (RolloutPolicy instance): policy to use for rollouts.

        env (EnvBase instance): environment to use for rollouts.

        horizon (int): maximum number of steps to roll the agent out for

        use_goals (bool): if True, the goal is retrieved from the env. Note that
            the goal is currently not forwarded to the policy.

        render (bool): if True, render the rollout to the screen

        video_writer (imageio Writer instance): if not None, use video writer object to append frames at 
            rate given by @video_skip

        video_skip (int): how often to write video frame

        terminate_on_success (bool): if True, terminate episode early once the "task" success
            metric has been reached AND the "lift" success metric is also set

        demo_actions: unused; kept for interface compatibility

        check_action_plot (bool): unused; kept for interface compatibility

        init_states: if not None, passed to env.set_object_pose right after reset

        model_type (str): 'brs' (point-cloud policy predicting action chunks) or
            'openpi' (camera + proprio policy queried every step)

        mobile_base_vel_action_max: upper bound used to (un)normalize base velocities ('brs' only)

        mobile_base_vel_action_min: lower bound used to (un)normalize base velocities ('brs' only)

    Returns:
        results (dict): dictionary containing return, success rate, etc.

        ac_list (list): flat action sent to the env at every step

        pcd_list (list): last 'combined::color_point_cloud' frame observed before every step

    Raises:
        ValueError: if @model_type is not one of the supported values.
    """

    assert isinstance(env, (EnvBase, EnvWrapper))
    # Fail fast instead of hitting an unbound action variable inside the loop.
    if model_type not in ('brs', 'openpi'):
        raise ValueError("run_rollout: unsupported model_type '{}'".format(model_type))

    ob_dict = env.reset()

    if init_states is not None:
        # set initial state
        env.set_object_pose(init_states)

    if use_goals:
        # retrieve goal from the environment
        # NOTE(review): the goal is fetched but never passed to the policy.
        env.get_goal()

    results = {}
    video_count = 0  # video frame counter

    total_reward = 0.
    success = { k: False for k in env.is_success() } # success metrics
    got_exception = False

    # Defined before the loop so the returned values exist even when the loop
    # body never runs (horizon <= 0) or is interrupted immediately.
    steps_taken = 0
    ac_list = []
    pcd_list = []

    try:
        obs_history = deque(maxlen=num_latest_obs)
        action_idx = 0
        for step_i in range(horizon):
            steps_taken = step_i + 1

            if model_type == 'brs':
                # ob_dict is the raw observation from the environment
                cur_obs = get_normalized_obs(
                    ob_dict, policy.device,
                    env.x_range,
                    env.y_range,
                    env.z_range,
                    mobile_base_vel_action_max,
                    mobile_base_vel_action_min,
                )
                # On the first step the history is padded with copies of the
                # first observation; afterwards one observation is appended.
                repeats = 1 if obs_history else num_latest_obs
                for _ in range(repeats):
                    obs_history.append(cur_obs)

                obs = any_concat(obs_history, dim=1)  # (B = 1, T = num_latest_obs, ...)

                # Query the policy only once every num_deployed_actions steps
                # and replay the predicted trajectory in between.
                if action_idx % num_deployed_actions == 0:
                    action_traj_pred = policy.act(obs)  # dict of (B = 1, T_A, ...)
                    action_traj_pred = {
                        k: v[0].detach().cpu().numpy() for k, v in action_traj_pred.items()
                    }  # dict of (T_A, ...)
                    action_idx = 0

                action = any_slice(action_traj_pred, np.s_[action_idx])
                action = unnormalize_action(
                    action,
                    mobile_base_vel_action_max,
                    mobile_base_vel_action_min,
                )
                ac = action_flatten(action)

            else:
                # model_type == 'openpi' (validated above)
                obs_ego = ob_dict['robot_r1::robot_r1:eyes:Camera:0::rgb'][-1, ...,:3]
                obs_wrist_left = ob_dict['robot_r1::robot_r1:left_eef_link:Camera:0::rgb'][-1, ...,:3]
                obs_wrist_right = ob_dict['robot_r1::robot_r1:right_eef_link:Camera:0::rgb'][-1, ...,:3]
                cam_obs = np.stack([obs_ego, obs_wrist_left, obs_wrist_right], axis=0)
                cam_obs = cam_obs[None, None]  # (B, T, num_cameras, H, W, C)
                prop_state = ob_dict['prop_state'][-1, ...][None, None]  # (B, T, ...)
                obs_dict = {
                    "observation": cam_obs,
                    "proprio": prop_state,
                }

                if step_i == 0:
                    policy.reset()

                try:
                    action = policy.act(obs_dict)
                except Exception:
                    # Previously a bare `except` dropped into a debugger and then
                    # continued with an undefined/stale action; surface the error.
                    print('error when getting action from policy')
                    raise

                # Only the first action of each predicted sequence is executed.
                ac = np.concatenate([value[0] for value in action.values()])

            ac_list.append(ac)
            pcd_list.append(ob_dict['combined::color_point_cloud'][-1])

            # play action
            ob_dict, r, _done, _truncated, _info = env.step(ac)

            action_idx += 1

            # render to screen
            if render:
                env.render(mode="human")

            # compute reward
            total_reward += r
            cur_success_metrics = env.is_success()
            for k in success:
                success[k] = success[k] or cur_success_metrics[k]

            # visualization
            if video_writer is not None:
                if video_count % video_skip == 0:
                    video_img = env.render(mode="rgb_array", height=720, width=1280)
                    video_writer.append_data(video_img)

                video_count += 1

            # NOTE(review): the episode only ends early once BOTH "task" and
            # "lift" have succeeded; this requires env.is_success() to expose
            # a "lift" key.
            if terminate_on_success and success["task"]:
                print('success episode')
                if success['lift']:
                    print('lift success step:', step_i)
                    break

            # The env-specific early-termination check runs every 10th step.
            if step_i % 10 == 0:
                early_termination = env.early_termination(step_i, ob_dict)
                if early_termination:
                    print("early termination condition met")
                    break

            # A (near-)zero-sum point cloud is treated as a missing observation.
            if np.sum(ob_dict['combined::color_point_cloud'])< 0.0001:
                print("pcd in observation is None")
                break

    except KeyboardInterrupt:
        print('keyboard interrupt')
        got_exception = True

    results["Return"] = total_reward
    # Number of steps started; 0 when the loop never ran.
    results["Horizon"] = steps_taken
    results["Success_Rate"] = float(success["task"])
    results["Exception_Rate"] = float(got_exception)

    # log additional success metrics
    for k in success:
        if k != "task":
            results["{}_Success_Rate".format(k)] = float(success[k])

    return results, ac_list, pcd_list

    
def rollout_with_stats(
        policy,
        envs,
        horizon,
        use_goals=False,
        num_episodes=None,
        render=False,
        video_dir=None,
        video_path=None,
        epoch=None,
        video_skip=5,
        terminate_on_success=False,
        verbose=False,
        demo_actions=None,
        check_action_plot=False,
        init_states_list=None,
        train_init_states=True,
        val_init_states=False,
        model_type='brs',
        mobile_base_vel_action_max=None,
        mobile_base_vel_action_min=None,
    ):
    """
    A helper function used in the train loop to conduct evaluation rollouts per environment
    and summarize the results.

    Can specify @video_dir (to dump a video per environment) or @video_path (to dump a single video
    for all environments).

    Args:
        policy (RolloutPolicy instance): policy to use for rollouts.

        envs (dict): dictionary that maps env_name (str) to EnvBase instance. The policy will
            be rolled out in each env.

        horizon (int): maximum number of steps to roll the agent out for

        use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env

        num_episodes (int): number of rollout episodes per environment. Must be provided; the
            default of None is not a usable value.

        render (bool): if True, render the rollout to the screen

        video_dir (str): if not None, dump rollout videos to this directory (one per environment)

        video_path (str): if not None, dump a single rollout video for all environments

        epoch (int): epoch number (used for video naming)

        video_skip (int): how often to write video frame

        terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered

        verbose (bool): if True, print results of each rollout

        demo_actions: forwarded to run_rollout

        check_action_plot (bool): forwarded to run_rollout

        init_states_list (list): if not None, episode i uses entry (i % len(init_states_list)) as
            its initial state

        train_init_states (bool): if True, per-env video names get a "_train" suffix

        val_init_states (bool): if True (and @train_init_states is False), per-env video names
            get a "_val" suffix

        model_type (str): forwarded to run_rollout

        mobile_base_vel_action_max: forwarded to run_rollout

        mobile_base_vel_action_min: forwarded to run_rollout

    Returns:
        all_rollout_logs (dict): dictionary of rollout statistics (e.g. return, success rate, ...) 
            averaged across all rollouts 

        video_paths (dict): path to rollout videos for each environment
    """

    all_rollout_logs = OrderedDict()

    # handle paths and create writers for video writing
    assert (video_path is None) or (video_dir is None), "rollout_with_stats: can't specify both video path and dir"
    write_video = (video_path is not None) or (video_dir is not None)
    video_paths = OrderedDict()
    video_writers = OrderedDict()
    video_writer = None
    if video_path is not None:
        # a single video is written for all envs
        video_paths = { k : video_path for k in envs }
        video_writer = imageio.get_writer(video_path, fps=20)
        video_writers = { k : video_writer for k in envs }
    if video_dir is not None:
        # video is written per env
        if train_init_states:
            split_tag = "_train"
        elif val_init_states:
            split_tag = "_val"
        else:
            # Previously this case left the file-name suffix undefined and
            # crashed with a NameError; fall back to an untagged name.
            split_tag = ""
        video_str = "{}{}.mp4".format(epoch, split_tag) if epoch is not None else ".mp4"
        video_paths = { k : os.path.join(video_dir, "{}{}".format(k, video_str)) for k in envs }
        video_writers = { k : imageio.get_writer(video_paths[k], fps=20) for k in envs }

    # Names of envs whose per-env writer has already been closed, so the
    # cleanup below never closes a writer twice.
    closed_envs = set()
    try:
        for env_name, env in envs.items():
            env_video_writer = None
            if write_video:
                print("video writes to " + video_paths[env_name])
                env_video_writer = video_writers[env_name]

            print("rollout: env={}, horizon={}, use_goals={}, num_episodes={}".format(
                env.name, horizon, use_goals, num_episodes,
            ))
            rollout_logs = []

            num_success = { k: 0 for k in env.is_success() }
            for ep_i in range(num_episodes):
                init_states = None
                if init_states_list is not None:
                    # cycle through the provided initial states
                    init_states = init_states_list[ep_i % len(init_states_list)]
                rollout_timestamp = time.time()

                # the per-step actions and point clouds are not needed here
                rollout_info, _, _ = run_rollout(
                    policy=policy,
                    env=env,
                    horizon=horizon,
                    render=render,
                    use_goals=use_goals,
                    video_writer=env_video_writer,
                    video_skip=video_skip,
                    terminate_on_success=terminate_on_success,
                    demo_actions=demo_actions,
                    check_action_plot=check_action_plot,
                    init_states=init_states,
                    model_type=model_type,
                    mobile_base_vel_action_max=mobile_base_vel_action_max,
                    mobile_base_vel_action_min=mobile_base_vel_action_min,
                )

                rollout_info["time"] = time.time() - rollout_timestamp
                rollout_logs.append(rollout_info)
                for k in num_success:
                    if k != "task":
                        num_success[k] += rollout_info[k + "_Success_Rate"]
                num_success["task"] += rollout_info["Success_Rate"]
                if verbose:
                    print("")
                    print("Episode {}, horizon={}, total_num_success={}".format(ep_i + 1, horizon, num_success["task"]))
                    print("time", rollout_info["time"])
                    print("")

                print('rollout info', rollout_info)

            if video_dir is not None:
                # close this env's video writer (next env has its own)
                env_video_writer.close()
                closed_envs.add(env_name)

            # average metric across all episodes
            per_key_logs = { k : [log[k] for log in rollout_logs] for k in rollout_logs[0] }
            rollout_logs_mean = { k : np.mean(v) for k, v in per_key_logs.items() if k != "actions" }
            rollout_logs_mean["Time_Episode"] = np.sum(per_key_logs["time"]) / 60. # total time taken for rollouts in minutes
            all_rollout_logs[env_name] = rollout_logs_mean
    finally:
        # Release writers even when a rollout raises, so partially written
        # videos are flushed and file handles are not leaked.
        if video_dir is not None:
            for name, writer in video_writers.items():
                if name not in closed_envs:
                    writer.close()
        if video_path is not None:
            # close video writer that was used for all envs
            video_writer.close()

    return all_rollout_logs, video_paths