import cv2
import uuid
import argparse
import numpy as np
import multiprocessing as mp
import scipy.signal as signal
from tqdm import tqdm
from pathlib import Path
try:
    import open3d as o3d
except ImportError:
    print("Running without Open3D.")

from vgn.io import my_IO
from vgn.perception import *
from vgn.grasp import Grasp, Label
from vgn.simulation import ClutterRemovalSim
from vgn.utils.transform import Rotation, Transform
from vgn.utils.implicit import get_mesh_pose_list_from_world
from vgn.utils.misc import apply_noise
from utils.misc import set_random_seed
from utils.transform import *


def generate_scene_and_grasp(args, rank):
    """
    Worker routine: simulate cluttered scenes and label grasps sampled on them.

    For every scene the function renders 12 viewpoints, fuses the view(s)
    selected by ``args.observation_type`` into a TSDF, stores the voxel grid
    and the first depth image, then samples ``args.grasps_per_scene`` grasp
    points from the first depth image and evaluates each one in simulation.

    args: arguments from the command line (``args.size`` is overwritten with
          the simulator's ``size``)
    rank: index for multiprocessing; rank 0 additionally writes the setup
          file and is the only rank that shows the progress bar
    """
    # Draw a base seed from the current global RNG state and offset it by the
    # rank before reseeding NumPy for this worker.
    # NOTE(review): presumably the inherited RNG state is identical in every
    # forked worker, so `+ rank` is what separates the streams -- confirm.
    seed = np.random.randint(0, 1000) + rank
    np.random.seed(seed)
    sim = ClutterRemovalSim(args.scene, args.object_set, gui=args.sim_gui,
                            log_render=True, save_dir=args.root)
    args.size = sim.size  # make the simulator's size available via args
    if args.record_video:
        default_extrinsic = get_default_extrinsic(args.size)
        sim.world.log_renderer.reset(default_extrinsic)
        sim.world.log_renderer.enable()

    # Integer division: any remainder of num_grasps is not generated.
    grasps_per_worker = args.num_grasps // args.num_proc
    pbar = tqdm(total=grasps_per_worker, disable=rank != 0)  # only rank 0 displays progress

    myio = my_IO(args.root)
    if rank == 0:
        # Only one worker writes the shared setup information.
        myio.write_setup(
            sim.size,
            sim.camera.intrinsic,
            sim.gripper.max_opening_width,
            sim.gripper.finger_depth,
        )
        if args.save_scene:
            (args.root / "mesh_pose_list").mkdir(parents=True, exist_ok=True)
    
    # Number of scenes per worker follows from the per-worker grasp budget.
    for scene_idx in range(grasps_per_worker // args.grasps_per_scene):
        scene_id = uuid.uuid4().hex
        object_count = np.random.poisson(args.object_count_lambda) + 1  # at least one object
        sim.reset(object_count)
        sim.save_state()  # evaluate_grasp_point() restores this state before each trial
        
        # The RGB images returned by render_images() are not used here.
        depth_imgs, _, segmentation_masks, extrinsics = render_images(sim, viewpoint_count=12)
        
        # Choose which rendered view(s) feed the observation TSDF; 'dex' noise
        # is applied to the depth images in every case.
        if args.observation_type == "facing":  # first view only
            tsdf = generate_tsdf_from_depth(args.size, args.grid_resolution,
                                            depth_imgs[[0]], sim.camera.intrinsic,
                                            extrinsics[[0]], add_noise='dex')
        elif args.observation_type == "side":  # second view only
            assert len(depth_imgs) >= 2, "Need at least 2 depth images for side observation"
            tsdf = generate_tsdf_from_depth(args.size, args.grid_resolution,
                                            depth_imgs[[1]], sim.camera.intrinsic,
                                            extrinsics[[1]], add_noise='dex')
        elif args.observation_type == "multiview":  # all rendered views
            tsdf = generate_tsdf_from_depth(args.size, args.grid_resolution,
                                            depth_imgs, sim.camera.intrinsic,
                                            extrinsics, add_noise='dex')
        else:
            raise ValueError("Invalid observation type")
        pc = tsdf.get_cloud()
        if args.save_pointcloud:
            (args.root / "scenes").mkdir(parents=True, exist_ok=True)
            # NOTE(review): the file name uses the per-worker loop index, not
            # scene_id, so workers sharing args.root would reuse the same
            # names -- confirm whether this is intended.
            o3d.io.write_point_cloud(str(args.root / "scenes" / ("pc_scene_{:03d}".format(scene_idx) + ".pcd")), pc)
        if args.sim_gui:
            o3d.visualization.draw_geometries([pc])
        # The emptiness check happens after the optional save/draw above; a
        # skipped scene writes no grid/depth data and does not advance pbar.
        if pc.is_empty():
            print("Point cloud empty, skipping scene")
            continue
        grid = tsdf.get_grid()
        myio.write_voxel_grid(scene_id, grid)
        myio.write_depth_image(scene_id, depth_imgs[[0]], extrinsics[[0]])  # always the first view

        if args.save_scene:
            mesh_pose_list = get_mesh_pose_list_from_world(sim.world, args.object_set)
            myio.write_mesh_pose_list(scene_id, mesh_pose_list, name="mesh_pose_list")

        
        # Rebuild a resolution-120 TSDF from all views without added noise and
        # crop its cloud to the [sim.lower, sim.upper] box.
        # NOTE(review): `pc` is not read again below in this function, so this
        # step currently has no visible consumer -- confirm whether it is needed.
        tsdf = generate_tsdf_from_depth(args.size, 120,
                                        depth_imgs, sim.camera.intrinsic,
                                        extrinsics)
        pc = tsdf.get_cloud()
        bounding_box = o3d.geometry.AxisAlignedBoundingBox(sim.lower, sim.upper)
        pc = pc.crop(bounding_box)
        
        
        # Grasp points are sampled from the first rendered view.
        depth = depth_imgs[0]
        extrinsic = extrinsics[0]
        extrinsic_matrix = Transform.from_list(list(extrinsic)).as_matrix()
        # Transform with the rotation taken from the first four extrinsic
        # entries and a fixed translation of [0, 0, 1].
        camera_M = Transform(Rotation.from_quat(list(extrinsic)[:4]),
                             np.array([0, 0, 1])).as_matrix()
        # Translation column of the inverse transform; presumably a unit
        # vector along the camera's viewing axis in world coordinates --
        # TODO confirm sign/convention.
        direction_vector = np.linalg.inv(camera_M)[:3, 3]
        segmentation_mask = segmentation_masks[0]
        # (row_indices, col_indices) of pixels with a non-zero segmentation
        # value; presumably these are object pixels -- confirm.
        valid_point = np.nonzero(segmentation_mask)

        if args.with_heatmap:
            visulizer = GraspVisualizer(depth_imgs[0], sim.camera.intrinsic.K, extrinsic_matrix, camera_M)

        for _ in range(args.grasps_per_scene):
            # Called for every grasp regardless of args.record_video.
            sim.world.log_renderer.reset()


            point, pixel = sample_grasp_point_from_depth(depth, valid_point,
                                                         sim.camera.intrinsic.K, extrinsic_matrix)
            
            # Shift the sampled point along direction_vector by an offset drawn
            # uniformly from [-eps, 1 + eps] times the finger depth.
            eps = 0.1
            grasp_depth = np.random.uniform(-eps * sim.gripper.finger_depth, (1.0 + eps) * sim.gripper.finger_depth)
            grasp_point = point + direction_vector * grasp_depth

            label, yaw_from_view, width = evaluate_grasp_point(sim, grasp_point, direction_vector, num_rotations=args.rotation_per_pos)
            
            if args.with_heatmap:
                visulizer.add_grasp(pixel, point, direction_vector, yaw_from_view, label, width)

            
            myio.write_grasp_yaw(scene_id, grasp_point, yaw_from_view, width, label)
            # A video is exported only for grasps whose label is non-zero.
            if label != 0:
                sim.world.log_renderer.export_video()
            pbar.update()

        if args.with_heatmap:
            visulizer.save(scene_id)

    pbar.close()
    print('Process %d finished!' % rank)


def get_default_extrinsic(size):
    """Return the fixed camera pose for a workspace of the given ``size``.

    The pose comes from ``camera_on_sphere`` around the point
    ``(size / 2, size / 2, size / 3)`` with radius ``2 * size``,
    ``theta = pi / 3`` and ``phi = -pi / 2``.
    """
    look_at = Transform(Rotation.identity(), np.r_[size / 2, size / 2, size / 3])
    return camera_on_sphere(look_at, 2 * size, np.pi / 3.0, -np.pi / 2.0)


def get_random_extrinsic(size=0.3, origin=None):
    """Sample a random camera pose on a sphere around ``origin``.

    Args:
        size: Workspace size; scales the sampled radius.
        origin: Transform the camera is placed around. When omitted it is
            built from ``size`` as ``(size / 2, size / 2, 0)`` with identity
            rotation. For the default ``size`` of 0.3 this equals the previous
            hard-coded default; for other sizes the origin now follows
            ``size`` instead of staying at the 0.3 workspace centre.

    Returns:
        Whatever ``camera_on_sphere(origin, r, theta, phi)`` returns, with
        ``r`` uniform in ``[1.6, 2.4) * size``, ``theta`` uniform in
        ``[0, pi / 3)`` and ``phi`` uniform in ``[0, 2 * pi)``.
    """
    # Build the default lazily: a default expression in the signature is
    # evaluated once at definition time and cannot depend on `size`.
    if origin is None:
        origin = Transform(Rotation.identity(), np.r_[size / 2, size / 2, 0.0])
    # Keep the draw order (r, theta, phi) so seeded runs stay reproducible.
    r = np.random.uniform(1.6, 2.4) * size
    theta = np.random.uniform(0.0, np.pi / 3.0)
    phi = np.random.uniform(0.0, 2.0 * np.pi)
    return camera_on_sphere(origin, r, theta, phi)


def render_images(sim, viewpoint_count=6):
    """Render ``viewpoint_count`` random views of the current scene.

    Each view uses a pose from ``get_random_extrinsic`` centred on
    ``(sim.size / 2, sim.size / 2, 0)``.

    Returns:
        ``(depth_imgs, rgb_imgs, segmentation_masks, extrinsics)`` as float32
        arrays with one entry per view; each extrinsic row holds the seven
        values produced by the pose's ``to_list()``.
    """
    intrinsic = sim.camera.intrinsic
    rows, cols = intrinsic.height, intrinsic.width
    look_at = Transform(Rotation.identity(), np.r_[sim.size / 2, sim.size / 2, 0.0])

    # Pre-allocate one slot per viewpoint for every output.
    extrinsics = np.empty((viewpoint_count, 7), np.float32)
    depth_imgs = np.empty((viewpoint_count, rows, cols), np.float32)
    rgb_imgs = np.empty((viewpoint_count, rows, cols, 3), np.float32)
    segmentation_masks = np.empty((viewpoint_count, rows, cols), np.float32)

    for view in range(viewpoint_count):
        pose = get_random_extrinsic(sim.size, look_at)
        color, depth_map, seg_map = sim.camera.render(pose)

        extrinsics[view] = pose.to_list()
        rgb_imgs[view] = np.array(color)
        depth_imgs[view] = depth_map
        segmentation_masks[view] = seg_map

    return depth_imgs, rgb_imgs, segmentation_masks, extrinsics


def generate_tsdf_from_depth(size, grid_resolution, depth_imgs, intrinsic, extrinsics, add_noise=''):
    """Build a TSDF from depth images, optionally perturbing them first.

    ``add_noise`` must be one of ``''``, ``'dex'``, ``'trans'`` or ``'norm'``
    and is forwarded to ``apply_noise`` for every depth image. The number of
    depth images must match the number of extrinsics. Both conditions are
    enforced with ``assert``.

    Returns the object produced by ``create_tsdf``.
    """
    assert add_noise in ('', 'dex', 'trans', 'norm')
    assert len(depth_imgs) == len(extrinsics)
    perturbed = [apply_noise(img, add_noise) for img in depth_imgs]
    depth_stack = np.array(perturbed)
    return create_tsdf(size, grid_resolution, depth_stack, intrinsic, extrinsics)


def sample_grasp_point_from_pointcloud(point_cloud, intrinsic, extrinsic):
    """Pick one point of ``point_cloud`` uniformly at random.

    Returns:
        ``(point, pixel, normal)`` where ``pixel`` is the result of
        ``world2pixel`` for the chosen point and ``normal`` is the cloud's
        normal at the same index.
    """
    cloud_points = np.asarray(point_cloud.points)
    cloud_normals = np.asarray(point_cloud.normals)

    # A single uniform draw; no candidate is ever rejected.
    pick = np.random.randint(len(cloud_points))
    point = cloud_points[pick]
    normal = cloud_normals[pick]

    return point, world2pixel(*point, intrinsic, extrinsic), normal


def sample_grasp_point_from_depth(depth, valid_point,
                                  intrinsic, extrinsic):
    """Pick a random valid pixel and convert it to a position via ``pixel2world``.

    ``valid_point`` is a ``(row_indices, col_indices)`` pair such as the one
    returned by ``np.nonzero``; one entry is drawn uniformly.

    Returns:
        ``(pos, (x, y))`` with ``x`` the column and ``y`` the row of the
        chosen pixel.
    """
    pick = np.random.randint(len(valid_point[0]))
    col = valid_point[1][pick]
    row = valid_point[0][pick]
    # The depth image is indexed [row, column].
    world_pos = pixel2world(col, row, depth[row, col], intrinsic, extrinsic)
    return world_pos, (col, row)


def evaluate_grasp_point(sim, pos, normal, num_rotations=6):
    """Try ``num_rotations`` yaw angles at ``pos`` and summarise the outcome.

    The grasp frame's z axis is ``-normal``. Yaw angles are evenly spaced
    over ``[0, pi)`` and shifted by one shared random offset. Before every
    trial the simulator is restored to its saved state.

    Returns:
        ``(label, yaw, width)``. ``label`` is ``int(max(outcomes))``. If any
        trial equals ``Label.SUCCESS``, ``yaw``/``width`` come from the peak
        index that ``find_peaks`` reports for the widest run of successes
        (runs are not joined across the first/last yaw); otherwise they are
        the yaw and width of the last trial.
    """
    # Build a frame around the approach axis; fall back to the world y axis
    # as reference when the approach axis is (anti)parallel to world x.
    approach = -normal
    reference = np.r_[1.0, 0.0, 0.0]
    if np.isclose(np.abs(np.dot(reference, approach)), 1.0, 1e-4):
        reference = np.r_[0.0, 1.0, 0.0]
    lateral = np.cross(approach, reference)
    forward = np.cross(lateral, approach)
    # NOTE(review): forward/lateral are not normalised before being handed to
    # Rotation.from_matrix -- confirm this is intended.
    R = Rotation.from_matrix(np.column_stack((forward, lateral, approach)))

    yaws = np.linspace(0.0, np.pi, num_rotations, endpoint=False)
    yaws += np.random.uniform(0.0, np.pi / num_rotations)

    outcomes = []
    widths = []
    for angle in yaws:
        sim.restore_state()
        pose = Transform(R * Rotation.from_euler("z", angle), pos)
        trial = Grasp(pose, width=sim.gripper.max_opening_width)
        result, opening = sim.execute_grasp(trial, remove=False)
        outcomes.append(result)
        widths.append(opening)

    best_label = int(np.max(outcomes))
    success_mask = (np.asarray(outcomes) == Label.SUCCESS).astype(float)
    if not np.sum(success_mask):
        # Nothing succeeded: report the last yaw tried and its width.
        return best_label, yaws[-1], widths[-1]

    # Pad with zeros so runs touching either end are still seen as peaks.
    padded = np.r_[0, success_mask, 0]
    peaks, properties = signal.find_peaks(x=padded, height=1, width=1)
    best = peaks[np.argmax(properties["widths"])] - 1  # undo the padding shift
    return best_label, yaws[best], widths[best]


class GraspVisualizer():
    """Collects debug images of sampled grasps drawn over one depth image.

    Three 3-channel uint8 canvases are kept, each initialised from
    ``depth_img * 255``:

    * ``grasp_map_q``  -- one dot per grasp, coloured by label
    * ``grasp_map_rw`` -- smaller dots, plus a line between the projected
      finger positions for grasps with label 1
    * ``normal_map``   -- one dot per grasp, coloured from
      ``camera_M[:3, :3] @ direction_vector``
    """

    def __init__(self, depth_img, intrinsic, extrinsic, camera_M):
        """Store the camera parameters and create the three canvases.

        Args:
            depth_img: 2-D depth image; it is multiplied by 255 and cast to
                uint8 for the canvases.
            intrinsic: 3x3 matrix (it is later extended with a zero column).
            extrinsic: Matrix accepted by ``Transform.from_matrix``.
            camera_M: Matrix whose upper-left 3x3 block is applied to the
                direction vector when colouring ``normal_map``.
        """
        self.depth_img = depth_img
        base = np.expand_dims(depth_img * 255, -1).repeat(3, -1).astype(np.uint8)
        # Each canvas needs its own buffer because they are drawn on separately.
        self.grasp_map_q = base
        self.grasp_map_rw = base.copy()
        self.normal_map = base.copy()
        self.intrinsic = intrinsic
        self.extrinsic = extrinsic
        self.camera_M = camera_M

    def _project(self, pose):
        """Project the translation of ``pose`` to homogeneous pixel coordinates."""
        projection = np.hstack((self.intrinsic, np.zeros((3, 1))))  # 3x4
        in_camera = Transform.from_matrix(self.extrinsic) * pose
        pixel = (projection @ in_camera.as_matrix())[:, 3]
        pixel /= pixel[-1]
        return pixel

    def add_grasp(self, pixel, point, direction_vector, yaw_from_view, label, width):
        """Draw one evaluated grasp onto the canvases.

        Args:
            pixel: ``(x, y)`` image coordinates of the sampled point.
            point: Position used as the gripper centre for the finger line.
            direction_vector: Vector whose negation is the grasp z axis.
            yaw_from_view: Rotation about the grasp z axis.
            label: Grasp label; 1 is drawn as a success.
            width: Gripper opening, used as the distance between fingers.
        """
        # Cast to plain ints so OpenCV accepts the centre whatever integer
        # type the caller used (the line endpoints below are cast the same way).
        center = (int(pixel[0]), int(pixel[1]))

        # Colour tuples are in OpenCV channel order.
        color = (0, 0, 255) if label == 1 else (255, 0, 0)
        cv2.circle(self.grasp_map_q, center, 2, color, -1)
        cv2.circle(self.grasp_map_rw, center, 1, color, -1)

        if label == 1:
            # Same frame construction as evaluate_grasp_point().
            z_axis = -direction_vector
            x_axis = np.r_[1.0, 0.0, 0.0]
            if np.isclose(np.abs(np.dot(x_axis, z_axis)), 1.0, 1e-4):
                x_axis = np.r_[0.0, 1.0, 0.0]
            y_axis = np.cross(z_axis, x_axis)
            x_axis = np.cross(y_axis, z_axis)
            R = Rotation.from_matrix(np.vstack((x_axis, y_axis, z_axis)).T)
            ori = R * Rotation.from_euler("z", yaw_from_view)

            pos_center = Transform(ori, point)

            # Finger positions: half the opening width to either side along
            # the grasp frame's y axis.
            pos1 = pos_center * Transform(Rotation.identity(), [0, width / 2, 0])
            pos2 = pos_center * Transform(Rotation.identity(), [0, -width / 2, 0])

            pixel1 = self._project(pos1)
            pixel2 = self._project(pos2)

            cv2.line(self.grasp_map_rw,
                     (int(pixel1[0]), int(pixel1[1])),
                     (int(pixel2[0]), int(pixel2[1])), (0, 255, 0), 1)

        # Map each component of the rotated direction from [-1, 1] to [0, 254].
        rotated = self.camera_M[:3, :3] @ direction_vector
        normal_color = tuple(int(c * 127 + 127) for c in rotated)
        cv2.circle(self.normal_map, center, 4, normal_color, -1)

    def save(self, scene_id, root=None):
        """Write the canvases and the depth image to ``<root>/heat_map``.

        Args:
            scene_id: Prefix for the four PNG file names.
            root: Output root directory. When omitted, the module-level
                ``args.root`` is used, which only exists when this file is
                run as a script.
        """
        if root is None:
            root = args.root
        path = Path(root) / "heat_map"
        path.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(f'{path}/{scene_id}_grasp_q.png', self.grasp_map_q)
        cv2.imwrite(f'{path}/{scene_id}_grasp_rw.png', self.grasp_map_rw)
        cv2.imwrite(f'{path}/{scene_id}_depth.png', self.depth_img * 255)
        cv2.imwrite(f'{path}/{scene_id}_normal.png', self.normal_map)


def clean_balance_data(root):
    """Filter out-of-bounds grasps, balance the labels and prune unused scenes.

    Steps:
      1. Drop grasps whose x, y or z is below 0.02 or above 0.28.
      2. If there are more negatives (label 0) than positives (label 1),
         drop randomly chosen negatives until the counts match.
      3. Write the data frame back and delete every ``depth_imgs/*.npz``
         file (and its ``mesh_pose_list`` counterpart, if present) whose
         scene has no remaining grasp.

    Args:
        root: Dataset root directory (a ``pathlib.Path``). All reads, writes
            and deletions are relative to this directory.
    """
    myio = my_IO(root)
    df = myio.read_df()
    positives = df[df["label"] == 1]
    negatives = df[df["label"] == 0]

    print("Before clean and balance:")
    print("Number of samples:", len(df.index))
    print("Number of positives:", len(positives.index))
    print("Number of negatives:", len(negatives.index))

    # NOTE(review): the bounds presumably leave a 0.02 margin inside a 0.3
    # workspace -- confirm if the workspace size ever changes.
    lower, upper = 0.02, 0.28
    for axis in ("x", "y", "z"):
        df.drop(df[df[axis] < lower].index, inplace=True)
    for axis in ("x", "y", "z"):
        df.drop(df[df[axis] > upper].index, inplace=True)

    positives = df[df["label"] == 1]
    negatives = df[df["label"] == 0]
    # Only subsample when negatives actually outnumber positives; a negative
    # sample size would make np.random.choice raise.
    surplus = len(negatives.index) - len(positives.index)
    if surplus > 0:
        discard = np.random.choice(negatives.index, surplus, replace=False)
        df = df.drop(discard)
    positives = df[df["label"] == 1]
    negatives = df[df["label"] == 0]
    myio.write_df(df)

    print("After clean and balance:")
    print("Number of samples:", len(df.index))
    print("Number of positives:", len(positives.index))
    print("Number of negatives:", len(negatives.index))

    # Remove stored scenes that no longer have any grasp referring to them.
    kept_scenes = set(df["scene_id"].values)
    mesh_dir = root / "mesh_pose_list"
    has_mesh_dir = mesh_dir.is_dir()
    # Snapshot the listing first because files are deleted while looping.
    for f in list((root / "depth_imgs").iterdir()):
        if f.suffix == ".npz" and f.stem not in kept_scenes:
            print("Removed scene", f.stem)
            f.unlink()
            if has_mesh_dir:
                mesh_file = mesh_dir.joinpath(f.stem + ".npz")
                # The mesh list is optional per scene; skip missing files.
                if mesh_file.exists():
                    mesh_file.unlink()


if __name__ == "__main__":
    # Command-line entry point: parse options, wipe any previous output,
    # generate the data (optionally in parallel), then clean and balance it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=Path, default=Path("./data/pile/data_pile_bind_minimal"))

    # Scene options
    parser.add_argument("--scene", type=str, choices=["pile", "packed"], default="packed")
    parser.add_argument("--object-count-lambda", type=int, default=4)
    parser.add_argument("--object-set", type=str, default="packed/train")
    parser.add_argument("--observation-type", type=str, choices=["facing", "side", 'multiview'], default="facing")
    parser.add_argument("--grid-resolution", type=int, default=40)

    # Grasp sampling options
    parser.add_argument("--num-grasps", type=int, default=4000000)
    parser.add_argument("--grasps-per-scene", type=int, default=240)
    parser.add_argument("--rotation-per-pos", type=int, default=12)

    # Run options
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-proc", type=int, default=40)
    parser.add_argument("--save-scene", action="store_true")
    parser.add_argument("--record-video", action="store_true")
    parser.add_argument("--sim-gui", action="store_true")
    parser.add_argument("--with-heatmap", action="store_true")
    parser.add_argument("--save-pointcloud", action="store_true")
    args = parser.parse_args()

    # An existing output directory is deleted. The breakpoint pauses in the
    # debugger first.
    # NOTE(review): presumably a manual safety stop before the delete --
    # confirm it is meant to stay.
    if args.root.exists():
        breakpoint()
        import shutil
        shutil.rmtree(args.root)

    print('====== Summary ======')
    print(f'Generate data with {args.object_set} object in {args.scene} scene.')
    print(f'{args.num_grasps} grasps, {args.grasps_per_scene} grasps per scene, {args.num_proc} workers in total.')
    print(f'Output to {args.root}.')
    print('=====================')
    set_random_seed(args.seed)

    if args.num_proc > 1:
        def _report_worker_error(exc):
            """Print an exception raised inside a worker process."""
            import traceback
            print('Worker failed:')
            traceback.print_exception(type(exc), exc, exc.__traceback__)

        pool = mp.Pool(processes=args.num_proc)
        for i in range(args.num_proc):
            # Without an error callback a failing worker would go unnoticed,
            # because the AsyncResult is never inspected.
            pool.apply_async(func=generate_scene_and_grasp, args=(args, i),
                             error_callback=_report_worker_error)
        pool.close()
        pool.join()
    else:
        generate_scene_and_grasp(args, 0)

    print('=====================')
    clean_balance_data(args.root)
