import os
import json
import shutil
from typing import Any
import torch
import hydra
import argparse
import cv2
import numpy as np
from tqdm import tqdm
from hydra import compose, initialize
from models.networks import Runner
from pathlib import Path
from utils.misc import set_random_seed
from vgn.io import my_IO
from vgn.grasp import Grasp, Label
from vgn.perception import *
from vgn.simulation import ClutterRemovalSim
from vgn.utils.misc import apply_noise
from vgn.utils.transform import Rotation, Transform
import matplotlib.pyplot as plt
from utils.visualization import *
from utils.misc import EasyDict
from utils.transform import *
from typing import Any, List, Dict, Set, Tuple, Union


def angles_to_extrinsic(angles: Tuple[float, float, float], size=0.3) -> Transform:
    """Convert spherical view angles into a camera extrinsic.

    Args:
        angles: ``(radius, theta, phi)`` of the camera, passed unchanged to
            ``camera_on_sphere``.
        size: workspace edge length; the sphere centre handed to
            ``camera_on_sphere`` is ``(size / 2, size / 2, 0)``.

    Returns:
        Whatever ``camera_on_sphere`` builds for that centre and those angles.
    """
    radius, polar, azimuth = angles
    half_size = size / 2
    sphere_center = Transform(Rotation.identity(), np.r_[half_size, half_size, 0.0])
    return camera_on_sphere(sphere_center, radius, polar, azimuth)


def extrinsic_to_angles(transform: Transform, size=0.3) -> Tuple[float, float, float]:
    """Recover spherical view angles ``(radius, theta, phi)`` from a camera extrinsic.

    Intended as the counterpart of ``angles_to_extrinsic``: the camera position is
    expressed relative to the workspace centre ``(size / 2, size / 2, 0)``.

    Args:
        transform: camera extrinsic (composed with the centre offset and inverted below).
        size: workspace edge length defining the centre.

    Returns:
        ``radius`` (distance to the centre), ``theta`` (angle from the +z axis) and
        ``phi`` (azimuth in the xy-plane, wrapped into ``[0, 2*pi)``).
    """
    origin = Transform(Rotation.identity(), np.r_[size / 2, size / 2, 0.0])
    # Translation of the inverted composite = camera position in the centred frame.
    eye = (transform * origin).inverse().as_matrix()[:3, 3]

    radius = np.linalg.norm(eye)
    # Clip: round-off can push the cosine marginally outside [-1, 1] -> arccos NaN.
    theta = np.arccos(np.clip(eye[2] / radius, -1.0, 1.0))
    phi = np.arctan2(eye[1], eye[0])

    if phi < 0.0:
        phi += 2.0 * np.pi

    return radius, theta, phi


def get_random_angles(size=0.3) -> Tuple[float, float, float]:
    """Draw a random camera view ``(radius, theta, phi)`` from NumPy's global RNG.

    radius is uniform in ``[1.6, 2.4] * size``, theta uniform in ``[0, pi/3)`` and
    phi uniform in ``[0, 2*pi)``; the three draws happen in that order.
    """
    radius_factor = np.random.uniform(1.6, 2.4)
    polar = np.random.uniform(0.0, np.pi / 3.0)
    azimuth = np.random.uniform(0.0, 2.0 * np.pi)
    return radius_factor * size, polar, azimuth


def generate_extrinsic_neighbor(origin_cam: Transform, size=0.3, num_direction=6,
                                min_step=0.0, max_step=0.5*np.pi,
                                num_steps=3) -> List[Tuple[float, float, float]]:
    """Sample candidate camera views in the neighbourhood of ``origin_cam``.

    The camera is rotated about the workspace centre ``(size / 2, size / 2, 0)``:
    for each of ``num_direction`` directions around its viewing ray (each jittered by
    a uniform random offset in ``[0, 2*pi/num_direction)`` drawn from NumPy's global
    RNG), and each angle in ``linspace(max_step, min_step, num_steps, endpoint=False)``,
    it is rotated by that angle about an axis perpendicular to the ray.

    Returns:
        ``(radius, theta, phi)`` tuples of the candidates whose polar angle lies strictly
        inside ``(0, pi/3)``; the others are dropped, so the list may be empty.
    """
    half_size = size / 2
    to_center = Transform(Rotation.identity(), np.array([half_size, half_size, 0]))
    from_center = Transform(Rotation.identity(), np.array([-half_size, -half_size, 0]))

    # Everything below up to the loops is independent of the loop variables.
    centered_cam = origin_cam * to_center
    # Translation of the inverse = camera position in the centred frame (centre -> camera ray).
    optic_axis = centered_cam.inverse().as_matrix()[:3, 3]
    optic_axis = optic_axis / np.linalg.norm(optic_axis)
    base_axis = np.cross(optic_axis, np.array([0, 0, 1]))
    if np.linalg.norm(base_axis) < 1e-9:
        # Camera exactly above the centre: the cross product with +z vanishes and would
        # normalise to NaN. Any horizontal axis is perpendicular to a vertical ray.
        base_axis = np.array([1.0, 0.0, 0.0])

    candidates = []
    for q in np.linspace(0, 2 * np.pi, num_direction, endpoint=False):
        for s in np.linspace(max_step, min_step, num_steps, endpoint=False):
            # One random draw per (direction, step) pair, as before.
            q_shifted = q + np.random.uniform(0, 2 * np.pi / num_direction)
            # Spin the perpendicular axis around the viewing ray to get this direction.
            rot_axis = Rotation.from_rotvec(q_shifted * optic_axis).as_matrix() @ base_axis
            rot_axis = rot_axis / np.linalg.norm(rot_axis)
            rotated_cam = (centered_cam
                           * Transform(Rotation.from_rotvec(s * rot_axis), np.array([0, 0, 0]))
                           * from_center)
            radius, theta, phi = extrinsic_to_angles(rotated_cam, size=size)
            if 0 < theta < np.pi / 3.0:
                candidates.append((radius, theta, phi))
    return candidates


class ExtrinsicMonitor:
    """Accumulates visited and candidate camera views and plots them in 3D.

    Views are ``(radius, theta, phi)`` tuples, converted with ``angles_to_extrinsic``
    and drawn as camera pyramids with ``draw_pyramid``. Visited cameras are red.
    A candidate with quality exactly 0.0 is gray. Otherwise it is coloured from blue
    (quality <= 0.9) to red (quality >= 1.0).
    """

    def __init__(self, intrinsic):
        self.intrinsic = intrinsic  # passed unchanged to draw_pyramid
        self.cam_angles = []  # visited views, in visiting order
        self.candidates_angles = []  # current candidate views
        self.quality_list = []  # one quality value per candidate view
        self._fig = None  # figure created by the last draw_and_save call

    def add_cam(self, angles):
        """Record a view the camera has visited."""
        self.cam_angles.append(angles)

    def add_candidates(self, angles_list: List[Tuple[float, float, float]]):
        """Set candidate views without scores (they are drawn gray)."""
        self.candidates_angles = angles_list
        self.quality_list = [0.0] * len(angles_list)

    def add_candidates_with_quality(self, angles_list: List[Tuple[float, float, float]],
                                    quality_list: List[float]):
        """Set candidate views together with one predicted quality per view.

        Raises:
            ValueError: if the two lists have different lengths.
        """
        if len(angles_list) != len(quality_list):
            raise ValueError(f"got {len(angles_list)} views but {len(quality_list)} qualities")
        self.candidates_angles = angles_list
        self.quality_list = quality_list

    # Original (misspelled) name, kept for existing callers.
    add_candidatas_with_quality = add_candidates_with_quality

    def clear_candidates(self):
        """Drop all candidate views and their qualities."""
        self.candidates_angles = []
        self.quality_list = []

    @staticmethod
    def _quality_color(quality):
        """Map a candidate quality to a matplotlib colour."""
        if quality == 0.0:
            return 'gray'
        # Clamp into [0, 1]: matplotlib rejects RGB components outside that range.
        scaled = min(1.0, max(0.0, (quality - 0.9) / 0.1))
        return (scaled, 0.0, 1.0 - scaled)

    def draw_and_save(self, path=None):
        """Draw candidates and visited views into a new figure; save it if ``path`` is given.

        The figure from the previous call is closed first, so repeated calls do not pile
        up open figures. The new figure stays open so ``show`` can display it.
        """
        previous = getattr(self, '_fig', None)
        if previous is not None:
            plt.close(previous)
        fig = plt.figure(figsize=(10, 10))
        self._fig = fig
        ax = fig.add_subplot(projection='3d')
        for angle, quality in zip(self.candidates_angles, self.quality_list):
            draw_pyramid(ax, angles_to_extrinsic(angle).as_matrix(),
                         intrinsic=self.intrinsic, height=0.1,
                         color=self._quality_color(quality),
                         label="candidate")
        for angle in self.cam_angles:
            draw_pyramid(ax, angles_to_extrinsic(angle).as_matrix(),
                         intrinsic=self.intrinsic, height=0.1,
                         color='red', label="cam")
        draw_floor_grid(ax)
        set_axis_range(ax, [-0.3, 0.6, -0.3, 0.6, -0.3, 0.6])
        fix_3d_axis_equal(ax)
        mark_axis_label(ax)
        if path is not None:
            fig.savefig(path)

    def show(self):
        """Display open matplotlib figures (blocks on interactive backends)."""
        plt.show()


class GraspPlanner:
    """Evaluates grasp candidates with the learned ``Runner`` network in mini-batches."""

    # Maximum number of grasp positions sent to the network in one forward pass.
    batch_size = 1000

    def __init__(self, cfg):
        self.net = Runner(cfg)

    @staticmethod
    def _stack_positions(pos_list):
        """Stack a Python list of ndarrays into one array; return anything else unchanged.

        ``torch.tensor`` on a list of ndarrays is very slow and warns. Stacking first gives
        the same values and dtype.
        """
        if isinstance(pos_list, list) and pos_list and isinstance(pos_list[0], np.ndarray):
            return np.stack(pos_list)
        return pos_list

    def __call__(self, tsdf, pos_list, extrinsic_list, intrinsic_list, size) -> 'EasyDict':
        """Predict grasps at every position in ``pos_list`` seen from a single camera.

        Args:
            tsdf: TSDF grid (array or tensor), forwarded unchanged to ``prepare_batch``.
            pos_list: grasp positions in world coordinates (list of xyz arrays, array or tensor).
            extrinsic_list: flat camera extrinsic (main passes ``Transform.to_list()``).
            intrinsic_list: ``[fx, fy, cx, cy, width, height]``.
            size: workspace edge length used to normalise positions and rescale widths.

        Returns:
            The network prediction, with batches merged via ``append`` and
            ``grasp_width`` multiplied by ``size``. For an empty ``pos_list`` only a dummy
            ``grasp_label = tensor([0])`` is returned.
        """
        if len(pos_list) == 0:
            empty = EasyDict()
            empty.grasp_label = torch.tensor([0])
            return empty

        pos_list = self._stack_positions(pos_list)
        result = None
        for start in range(0, len(pos_list), self.batch_size):
            batch = self.prepare_batch(tsdf,
                                       pos_list[start:start + self.batch_size],
                                       extrinsic_list,
                                       intrinsic_list,
                                       size)
            prediction_batch = self.clean_prediction(self.net.predict_grasp(batch), size)
            if result is None:
                result = prediction_batch
            else:
                result.append(prediction_batch)
        return result

    def prepare_batch(self, tsdf, pos_list, extrinsics, intrinsics, size) -> 'EasyDict':
        """Convert inputs to tensors and build the network input dict.

        One grasp point per sample. The single camera extrinsic/intrinsic is repeated
        for every sample. Positions are normalised with ``pos / size - 0.5``.
        """
        tsdf = tsdf if isinstance(tsdf, torch.Tensor) else torch.tensor(tsdf)
        pos_list = pos_list if isinstance(pos_list, torch.Tensor) else torch.tensor(pos_list)
        if len(pos_list.shape) == 2:
            pos_list = pos_list.unsqueeze(1)  # (N, 3) -> (N, 1, 3)
        extrinsics = extrinsics if isinstance(extrinsics, torch.Tensor) else torch.tensor(extrinsics)
        intrinsics = intrinsics if isinstance(intrinsics, torch.Tensor) else torch.tensor(intrinsics)

        n = len(pos_list)
        assert n != 0

        extrinsics = extrinsics[None, None, :].repeat(n, 1, 1)
        intrinsics = intrinsics[None, :].repeat(n, 1)

        pos_list = pos_list / size - 0.5

        return EasyDict({'tsdf': tsdf, 'point_grasp': pos_list,
                         'camera_extrinsic': extrinsics,
                         'camera_intrinsic': intrinsics})

    def clean_prediction(self, prediction: 'EasyDict', size) -> 'EasyDict':
        """Scale predicted grasp widths by ``size`` in place; return the prediction."""
        prediction.grasp_width *= size
        return prediction


class GraspExp:
    """State of one grasping round: scene source, current view, TSDF and target object.

    ``args.mode`` selects the scene source:
      * ``'sim'``: a ``ClutterRemovalSim`` scene is created and rendered from ``self.extrinsic``;
      * ``'dataset'``: the scene of the first row of the packed grasp dataframe is loaded
        (TSDF, depth image and camera pose);
      * ``'real'``: not implemented; methods that would need a sensor/robot raise
        ``NotImplementedError``.
    """

    def __init__(self, args, save_dir, rng=None, tsdf_resolution=40):
        self.save_dir = save_dir
        self.mode = args.mode
        self.depth_imgs = []
        self.tsdf = TSDFVolume(args.size, tsdf_resolution)
        self.args = args

        if self.mode == 'sim':
            self.create_sim_scene(args, rng=rng)
            self.sim_gui = args.sim_gui
            self.add_noise = args.add_noise
            self.intrinsic = self.sim.camera.intrinsic
            self.intrinsic_list = [getattr(self.intrinsic, k) for k in ['fx', 'fy', 'cx', 'cy', 'width', 'height']]
            self.target_label = None
            self.angles = get_random_angles()
            self.extrinsic: 'Transform' = angles_to_extrinsic(self.angles)

        elif self.mode == 'dataset':
            self.sim_gui = args.sim_gui
            self.io = my_IO('data/packed/data_packed_facing_grasp')
            self.df = self.io.read_df()
            _, self.intrinsic, _, _ = self.io.read_setup()
            self.intrinsic_list = [getattr(self.intrinsic, k) for k in ['fx', 'fy', 'cx', 'cy', 'width', 'height']]
            scene_id = self.df.loc[0, "scene_id"]
            voxel_grid = self.io.read_voxel_grid(scene_id)
            # The dataset already provides the TSDF grid; it replaces the empty volume above.
            self.tsdf = torch.tensor(voxel_grid[0], dtype=torch.float32)
            [self.depth_img], [extrinsic_list] = self.io.read_depth_image(scene_id)
            self.depth_imgs.append(self.depth_img)
            self.extrinsic: 'Transform' = Transform.from_list(extrinsic_list)
            self.angles = extrinsic_to_angles(self.extrinsic)
            print(f'Load scene {scene_id} from dataset')

    def clear_depth_imgs(self):
        """Forget the stored depth images.

        NOTE(review): the TSDF volume is not reset here. Previous views stay fused into it.
        """
        self.depth_imgs = []

    def create_sim_scene(self, args, rng=None):
        """Create the simulator and reset it until at least one object is present."""
        assert self.mode == 'sim', 'Only sim mode can create sim scene'

        # Honour --num-objects (previously ignored and hard-coded to 4, its default).
        object_count = getattr(args, 'num_objects', 4)
        self.sim = ClutterRemovalSim(args.scene_type, args.object_set,
                                     gui=args.sim_gui,
                                     add_noise=args.add_noise,
                                     save_dir=self.save_dir, save_pkl=True)
        if args.record_video:
            self.sim.world.log_renderer.enable()

        while self.sim.num_objects == 0:
            self.sim.world.log_renderer.reset()
            self.sim.reset(object_count, rng=rng)
            self.object_count = self.sim.num_objects
            print(f"Resetting simulation with {self.object_count} objects")

    def get_tsdf(self):
        """Observe the scene from the current view and return ``(tsdf_grid, point_cloud)``.

        In sim mode a depth image is rendered, noise is applied with ``apply_noise``, and
        that noisy image is integrated into the TSDF, stored in ``self.depth_img`` and
        used later for target back-projection. In dataset mode the pre-computed grid is
        returned with an empty point cloud, so callers can always unpack two values.

        Raises:
            NotImplementedError: in ``'real'`` mode.
        """
        if self.mode == 'dataset':
            return self.tsdf, o3d.geometry.PointCloud()

        if self.mode != 'sim':
            raise NotImplementedError(f"get_tsdf is not implemented for mode '{self.mode}'")

        depth_img, self.segmentation = self.sim.camera.render(self.extrinsic)[1:3]
        self.depth_img = apply_noise(depth_img, self.add_noise)
        self.depth_imgs.append(self.depth_img)

        # Integrate the same (noisy) image that targets are back-projected from.
        self.tsdf.integrate(self.depth_img, self.intrinsic, self.extrinsic)
        pc = self.tsdf.get_cloud()
        return self.tsdf.get_grid(), pc

    def get_target_pos(self, pixel: Union[Tuple, List, None] = None, any_point: bool = False) -> Tuple[List, List]:
        """Sample up to ``args.sample_points_per_view`` grasp positions on the target.

        Sim mode:
          * ``any_point``: every pixel with segmentation label > 0 is eligible.
          * Otherwise the target label is fixed on the first call: the label under ``pixel``
            (given as ``(x, y)``), or a random visible label in ``1..object_count``. If no
            such label is visible, nothing is returned and the choice is retried next call.
        Dataset mode samples rows of the grasp dataframe instead.

        Returns:
            ``(pos_list, pixel_list)``: world positions and matching ``(row, col)`` pixels.

        Raises:
            NotImplementedError: in ``'real'`` mode.
        """
        n_samples = self.args.sample_points_per_view

        if self.mode == 'dataset':
            pos_len = len(self.df.index)
            pos_list = []
            pixel_list = []
            for i in np.random.choice(pos_len, min(n_samples, pos_len), replace=False):
                pos = self.df.loc[i, "x":"z"].to_numpy(np.single)
                pos_list.append(pos)
                u, v = world2pixel(*pos, self.intrinsic, self.extrinsic)
                pixel_list.append([v, u])
            return pos_list, pixel_list

        if self.mode != 'sim':
            raise NotImplementedError(f"get_target_pos is not implemented for mode '{self.mode}'")

        if any_point:
            target_mask = self.segmentation > 0
        else:
            if self.target_label is None:
                if pixel is None:
                    visible = np.unique(self.segmentation)
                    if not np.any((visible >= 1) & (visible <= self.object_count)):
                        # No candidate label is visible; the rejection loop below would spin forever.
                        return [], []
                    while True:
                        self.target_label = np.random.randint(1, self.object_count + 1)
                        if np.any(self.segmentation == self.target_label):
                            break
                else:
                    self.target_label = self.segmentation[pixel[1], pixel[0]]
            target_mask = self.segmentation == self.target_label

        pixel_list = np.argwhere(target_mask)  # (row, col) pairs
        if len(pixel_list) > n_samples:
            pixel_list = pixel_list[np.random.choice(len(pixel_list), n_samples, replace=False)]

        pos_list = [pixel2world(x, y, self.depth_img[y, x], self.intrinsic, self.extrinsic)
                    for y, x in pixel_list]
        return pos_list, list(pixel_list)

    def expand_grasp_depth(self, pos_list: List, pixel_list: List,
                           extrinsic: 'Transform',
                           finger_depth=0.05, depth_candidates=5) -> Tuple[List, List]:
        """Replicate each position at ``depth_candidates`` offsets along the viewing ray.

        Offsets span ``[-0.1, 1.1] * finger_depth``. The output is depth-major: all points
        at the first offset, then all at the second, and so on. ``pixel_list`` is repeated
        in the same order.
        """
        if len(pos_list) == 0:
            return [], pixel_list

        # Inverse of (R, e_z) has translation -R^T e_z: the camera's +z axis in world
        # coordinates, negated (assuming ``extrinsic`` maps world -> camera).
        camera_M = Transform(extrinsic.rotation, np.array([0, 0, 1])).as_matrix()
        direction_vector = np.linalg.inv(camera_M)[:3, 3]

        positions = np.asarray(pos_list)
        eps = 0.1
        pos_list_with_depth = []
        for depth in np.linspace(-eps * finger_depth, (1.0 + eps) * finger_depth, depth_candidates):
            pos_list_with_depth.extend(positions + direction_vector * depth)

        # list(): repeating an ndarray with ``*`` would multiply its values instead.
        return pos_list_with_depth, list(pixel_list) * depth_candidates

    def excute_grasp(self, prediction, pos_list, easy_mode=False):
        """Execute the highest-quality predicted grasp in the simulator and return its outcome.

        Raises:
            NotImplementedError: outside ``'sim'`` mode (previously an unbound ``outcome``).
        """
        if self.mode != 'sim':
            raise NotImplementedError(f"Grasp execution is not implemented for mode '{self.mode}'")
        index = int(prediction.grasp_label.argmax())
        ori = Rotation.from_quat(prediction.grasp_rotation[index].cpu())
        # Plain float width instead of a (possibly GPU) tensor element.
        candidate = Grasp(Transform(ori, pos_list[index]),
                          width=float(prediction.grasp_width[index]))
        print(f"Try grasp: {candidate.pose.to_dict()}, width: {candidate.width}")
        outcome, _ = self.sim.execute_grasp(candidate, remove=False, easy_mode=easy_mode)
        self.sim.world.log_renderer.export_video()
        return outcome

    def move_to_next_view(self, next_view):
        """Move the (simulated) camera to ``next_view`` = ``(radius, theta, phi)``; no-op otherwise."""
        if self.mode == 'sim':
            self.angles = next_view
            self.extrinsic = angles_to_extrinsic(self.angles)


def _depth_to_gray3(depth_img):
    """Render a depth image as a 3-channel uint8 image (depth * 255, clipped to [0, 255])."""
    gray = np.clip(np.asarray(depth_img) * 255, 0, 255).astype(np.uint8)
    return np.repeat(gray[..., None], 3, axis=-1)


def _select_best_depth(prediction, num_points):
    """Keep, for each of ``num_points`` points, the depth candidate with the best quality.

    Predictions must be depth-major, as produced by ``GraspExp.expand_grasp_depth``:
    candidate ``j`` of point ``i`` sits at flat index ``j * num_points + i``.
    Ties pick the shallowest candidate, because ``argmax`` returns the first maximum.

    Returns:
        ``(grasp_label, grasp_rotation, grasp_width)`` as CPU float tensors with
        ``num_points`` rows each.
    """
    labels = prediction.grasp_label.detach().reshape(-1, num_points)
    best_depth = labels.argmax(dim=0)
    flat = best_depth * num_points + torch.arange(num_points, device=best_depth.device)
    best_q = labels.reshape(-1)[flat].cpu().float()
    best_rot = prediction.grasp_rotation.detach().reshape(labels.numel(), -1)[flat].cpu().float()
    best_width = prediction.grasp_width.detach().reshape(-1)[flat].cpu().float()
    return best_q, best_rot, best_width


def visualize_prediction(prediction, pixel_list, pos_list, grasp_exp, depth_expanded=True,
                         best_depth_only=True, draw_thresh=0.94):
    """Draw predicted grasp quality and gripper lines over the current depth image.

    Args:
        prediction: network output for ``pos_list`` expanded over depth candidates
            (depth-major order) or, with ``depth_expanded=False``, not expanded.
        pixel_list: ``(row, col)`` pixel of each surface point (not expanded).
        pos_list: world position of each surface point (not expanded).
        grasp_exp: provides ``depth_img``, ``intrinsic.K`` and ``extrinsic``.
        depth_expanded: whether ``prediction`` contains several depths per point.
        best_depth_only: draw only the best depth's gripper (green); otherwise draw all
            depth candidates, coloured by quality.
        draw_thresh: gripper lines are drawn only for points whose best quality exceeds it.

    Returns:
        ``(quality_map, gripper_map)`` BGR uint8 images.
    """
    if len(pos_list) == 0 or len(pixel_list) == 0:
        raw_img = _depth_to_gray3(grasp_exp.depth_img)
        return raw_img, raw_img

    num_points = len(pixel_list)
    depth_candidates = len(prediction.grasp_label) // num_points
    if not depth_expanded:
        assert depth_candidates == 1
    # With a single depth this is an identity selection.
    best_q, best_rot, best_width = _select_best_depth(prediction, num_points)

    # Quality map: one dot per point, blue (low quality) to red (high quality) in BGR.
    prediction_map_q = _depth_to_gray3(grasp_exp.depth_img)
    for index, (y, x) in enumerate(pixel_list):
        q = int(np.clip((float(best_q[index]) - 0.001) / 0.999 * 255, 0, 255))
        cv2.circle(prediction_map_q, (int(x), int(y)), 4, (255 - q, 0, q), -1)

    # Gripper map: projected finger positions of confident grasps.
    K = grasp_exp.intrinsic.K
    prediction_map_rw = _depth_to_gray3(grasp_exp.depth_img)
    for index in range(num_points):
        if best_q[index] <= draw_thresh:
            continue
        if best_depth_only:
            # NOTE: the centre is the un-offset surface point, not the depth-shifted one.
            pos_center = Transform(Rotation.from_quat(best_rot[index].numpy()), pos_list[index])
            pixel1, pixel2 = _cal_gripper_pixel(pos_center, float(best_width[index]),
                                                K, grasp_exp.extrinsic)
            cv2.line(prediction_map_rw,
                     (int(pixel1[0]), int(pixel1[1])),
                     (int(pixel2[0]), int(pixel2[1])), (0, 255, 0), 1)
        else:
            for j in range(depth_candidates):
                k = j * num_points + index  # depth-major layout
                quat = prediction.grasp_rotation[k].detach().cpu().numpy()
                pos_center = Transform(Rotation.from_quat(quat), pos_list[index])
                pixel1, pixel2 = _cal_gripper_pixel(pos_center, float(prediction.grasp_width[k]),
                                                    K, grasp_exp.extrinsic)
                q_color = int(np.clip(float(prediction.grasp_label[k]) * 255, 0, 255))
                # Deeper candidates are drawn slightly brighter.
                j_ratio = (j + 8) / (depth_candidates + 10)
                cv2.line(prediction_map_rw,
                         (int(pixel1[0]), int(pixel1[1])),
                         (int(pixel2[0]), int(pixel2[1])),
                         (int((255 - q_color) * j_ratio), 0, int(q_color * j_ratio)),
                         1)
    return prediction_map_q, prediction_map_rw


def _cal_gripper_pixel(pos_center, grasp_width, intrinsic, extrinsic) -> Tuple[np.ndarray, np.ndarray]:
    """Project the two finger positions of a parallel-jaw grasp into the image.

    The fingers are placed at ``+/- grasp_width / 2`` along the grasp frame's local y axis.

    Args:
        pos_center: grasp pose (``Transform``) in world coordinates.
        grasp_width: opening width, same units as the pose translation.
        intrinsic: 3x3 camera matrix K.
        extrinsic: camera extrinsic (composed on the left of the finger poses).

    Returns:
        Two homogeneous pixel coordinates ``[u, v, 1]`` (float arrays of length 3).
    """
    K = np.asarray(intrinsic)

    def _project(offset_y):
        finger = pos_center * Transform(Rotation.identity(), [0, offset_y, 0])
        # Equivalent to [K | 0] @ (extrinsic * finger).as_matrix(), last column.
        pixel = K @ (extrinsic * finger).translation
        return pixel / pixel[-1]

    return _project(grasp_width / 2), _project(-grasp_width / 2)


def check_all_angles(tsdf, pos, intrinsic, grasp_planner, size=0.3, step=100, batch_size=200):
    """Score one grasp position from a dense grid of viewpoints and return a heat map.

    Cameras lie on a sphere of radius ``2 * size`` around the workspace centre, with
    ``step`` polar angles in ``(0, pi/2)`` (rows) and ``step`` azimuths in ``(0, 2*pi)``
    (columns).

    Args:
        tsdf: TSDF grid (array or tensor).
        pos: grasp position in world coordinates (array-like of length 3).
        intrinsic: ``[fx, fy, cx, cy, width, height]``.
        grasp_planner: object exposing ``predict_grasp(batch)`` (e.g. the ``Runner`` network).
        size: workspace edge length (normalisation and camera radius).
        step: grid resolution per angle.
        batch_size: viewpoints per forward pass.

    Returns:
        ``(step, step, 3)`` uint8 image with channel 0 = 255 - q and channel 2 = q.
    """
    radius = 2.0 * size
    extrinsics = torch.zeros(step, step, 7)
    for i, theta in enumerate(np.linspace(1e-3, np.pi / 2.0 - 1e-3, step)):
        for j, phi in enumerate(np.linspace(1e-3, np.pi * 2 - 1e-3, step)):
            extrinsics[i, j] = torch.tensor(
                angles_to_extrinsic((radius, theta, phi), size=size).to_list())
    extrinsics = extrinsics.reshape(step ** 2, 1, 7)

    # Previously used the global ``args.size`` (NameError when imported) and required a tensor.
    pos = torch.as_tensor(pos) / size - 0.5
    pos = pos[None, None, :]
    tsdf = torch.as_tensor(tsdf)
    intrinsics = torch.as_tensor(intrinsic)[None, :]

    total = step ** 2
    result = None
    for start in tqdm(range(0, total, batch_size)):
        # The last batch may be smaller than batch_size.
        n = min(batch_size, total - start)
        batch = EasyDict({'tsdf': tsdf.repeat(n, 1, 1, 1),
                          'point_grasp': pos.repeat(n, 1, 1),
                          'camera_extrinsic': extrinsics[start:start + n],
                          'camera_intrinsic': intrinsics.repeat(n, 1, 1)})
        prediction_batch = grasp_planner.predict_grasp(batch)
        if result is None:
            result = prediction_batch
        else:
            result.append(prediction_batch)

    q_scaled = (result.grasp_label.reshape(step, step) - 0.001) / 0.999 * 255
    q = q_scaled.clamp(0, 255).int().cpu().numpy()
    prediction_map_q = np.zeros((step, step, 3), dtype=np.uint8)
    prediction_map_q[:, :, 0] = 255 - q
    prediction_map_q[:, :, 2] = q
    return prediction_map_q


def main(args, cfg):
    """Run ``args.num_rounds`` active-view grasping rounds and report aggregate metrics.

    Each round builds a fresh ``GraspExp`` scene and then repeats:
      1. observe from the current view (TSDF + point cloud);
      2. predict grasp quality for target points expanded over several depths;
      3. if the best quality is below ``args.grasp_q_thresh`` and the look budget
         (``args.max_look_time``) is not exhausted, score neighbouring views and move to
         the best one (or to a random view if no neighbour is valid); otherwise execute
         the best grasp.
    A round ends on success, after ``args.max_grasp_time`` failed grasps, or when there
    is no grasp candidate left to execute.

    Outputs go under ``experiments/<experiment_name>/grasp_results/ckpt_<ckpt_index>``.
    That directory is deleted first if it already exists.
    """
    grasp_planner = GraspPlanner(cfg)

    save_dir_root = Path("experiments") / args.experiment_name / "grasp_results" / f"ckpt_{args.ckpt_index}"
    if os.path.exists(save_dir_root):
        shutil.rmtree(save_dir_root)
    os.makedirs(save_dir_root, exist_ok=True)

    if args.seed < 0:
        with open('seed_list.txt', 'r') as f:
            # Skip blank lines (e.g. a trailing empty line), which int() would reject.
            seeds_list = [int(line) for line in f if line.strip()]
        # Message: "not enough seeds in seed_list.txt".
        assert len(seeds_list) >= args.num_rounds, 'seed_list.txt 中的种子数量不足'
    else:
        np.random.seed(args.seed)
        seeds_list = np.random.choice(2**20, args.num_rounds, replace=False)

    results = []
    for index in range(args.num_rounds):
        set_random_seed(seeds_list[index])
        save_dir = save_dir_root / f'round_{index:04d}'
        os.makedirs(save_dir, exist_ok=True)
        print('\033[1;34m' + f'===== Round {index:04d} Start =====' + '\033[0m')
        print(f'Model: {args.experiment_name}  Seed: {seeds_list[index]}')
        grasp_exp = GraspExp(args, save_dir, rng=np.random.default_rng(seeds_list[index]))
        eplot = ExtrinsicMonitor(grasp_exp.intrinsic_list)
        os.makedirs(save_dir / 'cameras', exist_ok=True)

        success_flag = False
        grasp_count = 0
        look_count = 0
        while not success_flag:
            # --- observe ---
            look_count += 1
            tsdf, pc = grasp_exp.get_tsdf()
            o3d.io.write_point_cloud(f'{save_dir}/pointcloud_{look_count}.pcd', pc)
            if args.sim_gui:
                o3d.visualization.draw_geometries([pc])

            print('\033[1;35m' + f'Get depth image from view {grasp_exp.extrinsic.to_list()}.' + '\033[0m')

            eplot.add_cam(grasp_exp.angles)
            eplot.draw_and_save(f'{save_dir}/cameras/{args.ckpt_index}_{look_count}_0.png')
            if args.sim_gui:
                # Align the simulator GUI camera with the current view.
                grasp_exp.sim.world.p.resetDebugVisualizerCamera(
                    cameraDistance=0.6,
                    cameraYaw=90 + grasp_exp.angles[2] / np.pi * 180,
                    cameraPitch=-90 + grasp_exp.angles[1] / np.pi * 180,
                    cameraTargetPosition=[0.15, 0.15, 0.0],
                )
                eplot.show()

            # --- predict grasps from the current view ---
            pos_list, pixel_list = grasp_exp.get_target_pos(any_point=args.grasp_any_object)
            pos_list_with_depth, pixel_list_duplicated = grasp_exp.expand_grasp_depth(
                pos_list, pixel_list, extrinsic=grasp_exp.extrinsic)

            prediction = grasp_planner(tsdf, pos_list_with_depth,
                                       grasp_exp.extrinsic.to_list(),
                                       grasp_exp.intrinsic_list,
                                       args.size)
            print(f'Evaluate {len(pos_list_with_depth)} grasp positions from current view, '
                  f'and the best possible grasp quality is {prediction.grasp_label.max():.4f}.')

            map_q, map_rw = visualize_prediction(prediction, pixel_list, pos_list, grasp_exp)
            cv2.imwrite(f'{save_dir}/prediction_map_q_{args.ckpt_index}_{look_count}.png', map_q)
            cv2.imwrite(f'{save_dir}/prediction_map_rw_{args.ckpt_index}_{look_count}.png', map_rw)
            if args.sim_gui:
                cv2.imshow('prediction_map_q', map_q)
                cv2.imshow('prediction_map_rw', map_rw)
                cv2.waitKey(0)

            if prediction.grasp_label.max() < args.grasp_q_thresh and look_count <= args.max_look_time:
                # --- not confident enough: choose the next best view ---
                print('No good grasp found. Start to sample new view.')
                angles_candidate = generate_extrinsic_neighbor(grasp_exp.extrinsic,
                                                               num_direction=args.new_view_directions,
                                                               num_steps=args.new_view_step)
                print(f'Sampled {len(angles_candidate)} new view angles.')

                if not angles_candidate:
                    # Every neighbour was filtered out; max([]) below would raise.
                    next_view = get_random_angles()
                    print(f'No valid neighbour view. Move to random view {next_view}.')
                    grasp_exp.move_to_next_view(next_view)
                    continue

                eplot.add_candidates(angles_candidate)
                eplot.draw_and_save(f'{save_dir}/cameras/{args.ckpt_index}_{look_count}_1.png')
                if args.sim_gui:
                    eplot.show()

                # Predict, from the current TSDF, how good the grasps would look from each candidate.
                candidate_q = []
                for angles in angles_candidate:
                    cand_extrinsic = angles_to_extrinsic(angles)
                    cand_pos, _ = grasp_exp.expand_grasp_depth(pos_list, pixel_list, extrinsic=cand_extrinsic)
                    cand_prediction = grasp_planner(tsdf, cand_pos,
                                                    cand_extrinsic.to_list(),
                                                    grasp_exp.intrinsic_list,
                                                    args.size)
                    candidate_q.append(cand_prediction.grasp_label.max().item())

                eplot.add_candidatas_with_quality(angles_candidate, candidate_q)
                eplot.draw_and_save(f'{save_dir}/cameras/{args.ckpt_index}_{look_count}_2.png')
                eplot.clear_candidates()
                if args.sim_gui:
                    eplot.show()

                best_q = max(candidate_q)
                next_view = angles_candidate[candidate_q.index(best_q)]
                print(f'Next best predicted grasp quality is {best_q:.4f} in new view {next_view}.')
                grasp_exp.move_to_next_view(next_view)
                print(f'Move to next view {next_view}.')

            else:
                # --- confident enough (or out of looks): execute the best grasp ---
                if len(pos_list_with_depth) == 0:
                    # The planner returned only a dummy label; there is no pose to execute.
                    print('\033[1;31m' + 'No grasp candidate available. Stop this scene.' + '\033[0m')
                    break
                print('\033[1;33m' + 'Start to execute grasp.' + '\033[0m')
                grasp_count += 1
                success_flag = grasp_exp.excute_grasp(prediction, pos_list_with_depth, easy_mode=True)
                if success_flag:
                    print('\033[1;32m' + 'Grasp success!' + '\033[0m')
                    break
                grasp_exp.clear_depth_imgs()
                print('\033[1;31m' + 'Grasp failed!' + '\033[0m' + ' Clear previous observation.')
                if grasp_count >= args.max_grasp_time:
                    print('\033[1;31m' + 'Tired. Reach max trail count. Stop this scene.' + '\033[0m')
                    break

        print(f'Round {index} finished, looked {look_count} times, tried {grasp_count} times and',
              'succeed' if success_flag else 'failed')
        results.append([success_flag, look_count, grasp_count])

    analyze_results(results, save_dir_root)


def analyze_results(results, save_dir_root):
    '''Print summary metrics of all rounds and save the raw results as JSON.

    Each entry of ``results`` is ``[success_flag, look_count, grasp_count]``.

    Grasp Success Rate  (GSR): successful rounds / rounds
    Average Look Count  (ALC): total looks / rounds
    Average Trail Count (ATC): grasp attempts / successful rounds, over successful rounds only

    A metric is NaN when undefined (no rounds at all, or no successful round for ATC).
    The raw list is written to ``save_dir_root / 'results.json'``.

    Returns:
        dict with keys ``'GSR'``, ``'ALC'`` and ``'ATC'``.
    '''
    cnt = len(results)
    success_cnt = sum(1 for success_flag, _, _ in results if success_flag)
    GSR = success_cnt / cnt if cnt else np.nan
    ALC = sum(look_count for _, look_count, _ in results) / cnt if cnt else np.nan
    if success_cnt:
        ATC = sum(grasp_count for success_flag, _, grasp_count in results if success_flag) / success_cnt
    else:
        ATC = np.nan
    print(f'GSR: {GSR:.4f}, ALC: {ALC:.4f}, ATC: {ATC:.4f}')

    def _to_builtin(obj):
        # Seeds and counters may be NumPy scalars, which json cannot serialise directly.
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

    with open(save_dir_root / 'results.json', 'w') as f:
        json.dump(results, f, indent=2, default=_to_builtin)
    print(f'Saving results to {save_dir_root / "results.json"}')
    return {'GSR': GSR, 'ALC': ALC, 'ATC': ATC}


def load_hydra_cfg(args):
    '''Compose the Hydra config saved with an experiment and point it at one checkpoint.

    The config is read from ``experiments/<experiment_name>/configs`` (name ``config``).
    ``load_path`` is overridden with the checkpoint chosen by ``args.ckpt_index`` from
    ``get_ckpts``. ``device`` is overridden with CUDA when available, otherwise CPU.
    '''
    # Drop any previously initialised global Hydra state so initialize() can run again.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    config_dir = "experiments/" + args.experiment_name + "/configs"
    initialize(version_base=None, config_path=config_dir, job_name="test_app")

    checkpoint = get_ckpts(args.experiment_name)[args.ckpt_index]
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    overrides = [f"load_path='{checkpoint}'", f"device={target_device}"]
    return compose(config_name="config", overrides=overrides)


def get_ckpts(experiment, only_epoch_end=True):
    """List checkpoint paths of ``experiment`` in natural (numeric-aware) order.

    Files are read from ``experiments/<experiment>/ckpts/``. With ``only_epoch_end``, only
    names ending in ``end.pt`` are kept. Digit runs compare as numbers, so
    ``epoch_2_end.pt`` sorts before ``epoch_10_end.pt`` and index ``-1`` is the latest
    epoch. Plain lexicographic sorting got this wrong for unpadded numbers.
    """
    import re

    ckpts_path = 'experiments/' + experiment + '/ckpts/'

    def natural_key(name):
        # re.split with a capturing group alternates text (even idx) and digits (odd idx),
        # so element types always line up between keys. The full name breaks ties.
        parts = re.split(r'(\d+)', name)
        return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)], name

    ckpts = [f for f in os.listdir(ckpts_path)
             if not only_epoch_end or f.endswith('end.pt')]
    ckpts.sort(key=natural_key)
    return [os.path.join(ckpts_path, f) for f in ckpts]


if __name__ == "__main__":
    # Command-line interface. Options are registered in table order, which is also
    # the order shown by --help.
    cli = argparse.ArgumentParser()
    cli.add_argument("experiment_name", type=str)
    option_table = (
        # checkpoint selection
        ("--ckpt-index", dict(type=int, default=-1)),
        # run control
        ("--seed", dict(type=int, default=42)),
        ("--num-rounds", dict(type=int, default=20)),
        ("--mode", dict(type=str, choices=["sim", "real", "dataset"], default="sim")),
        ("--size", dict(type=float, default=0.3)),
        # scene / simulator
        ("--load-scene", dict(type=str)),
        ("--scene-type", dict(type=str, choices=["pile", "packed"], default="packed")),
        ("--object-set", dict(type=str, default="packed/test")),
        ("--num-objects", dict(type=int, default=4)),
        ("--add-noise", dict(type=str, default='dex')),
        ("--sim-gui", dict(action="store_true")),
        ("--record-video", dict(action="store_true")),
        # grasping / view-selection policy
        ("--grasp-any-object", dict(action="store_true")),
        ("--sample-points-per-view", dict(type=int, default=200)),
        ("--grasp-q-thresh", dict(type=float, default=0.9)),
        ("--new-view-directions", dict(type=int, default=6)),
        ("--new-view-step", dict(type=int, default=3)),
        ("--max-grasp-time", dict(type=int, default=1)),
        ("--max-look-time", dict(type=int, default=12)),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    cfg = load_hydra_cfg(args)
    main(args, cfg)
