import os
import math
import networkx as nx
import numpy as np
import json
import habitat_sim
from hovsg.data.hm3dsem.habitat_utils import make_cfg_mp3d
import quaternion
from scipy.spatial.transform import Rotation as R
from hovsg.graph.object import Object
from viewpoint_base_hovsg.viewpoint import Viewpoint
from collections import defaultdict
from tqdm import tqdm

def read_node_wp(scan, vp):
    """Return the entry stored for viewpoint ``vp`` in the waypoint graph of ``scan``.

    Reads ``../HOV-SG/navigation_graph/wp_graph/{scan}_0.json`` (relative to
    the current working directory) and returns ``data['vps'][vp]`` unchanged.
    The driver loop in this file compares the value against ``vp.vp_pos`` and
    passes it on as the agent position.

    Raises:
        OSError: if the graph file cannot be opened.
        KeyError: if ``'vps'`` or ``vp`` is missing from the JSON document.
    """
    graph_path = f"../HOV-SG/navigation_graph/wp_graph/{scan}_0.json"
    with open(graph_path, "r") as graph_file:
        return json.load(graph_file)['vps'][vp]

def quaternion_to_rotation_matrix(quat):
    """Convert a quaternion given as ``[x, y, z, w]`` into a 3x3 rotation matrix.

    The components are reordered into the ``(w, x, y, z)`` argument order of
    ``np.quaternion`` before the conversion.
    """
    x, y, z, w = quat[0], quat[1], quat[2], quat[3]
    return quaternion.as_rotation_matrix(np.quaternion(w, x, y, z))

def get_obb_vertices_with_rotation(obj, y_offset=1.5):
    """Return the corners of ``obj``'s oriented bounding box, shifted along Y.

    Args:
        obj: Object exposing ``pcd.get_oriented_bounding_box()``, whose result
            provides ``get_box_points()``.
            NOTE(review): presumably an Open3D point cloud - confirm against
            the ``Object`` class.
        y_offset: Value added to the Y coordinate of every corner. The default
            of 1.5 equals the ``sensor_height`` configured in ``get_sim``.
            NOTE(review): the two values look related - confirm.

    Returns:
        A new float64 array with one row per corner (x, y, z), in the order
        produced by ``get_box_points()``.
    """
    obb = obj.pcd.get_oriented_bounding_box()
    # Force a floating dtype so the shift is never truncated by an integer array.
    vertices = np.array(obb.get_box_points(), dtype=np.float64)
    vertices[:, 1] += y_offset
    return vertices

def project_to_camera_space(vertices, camera_transform):
    """Apply a 4x4 transform to an (N, 3) array of points.

    Each point is extended with a homogeneous 1, multiplied by
    ``camera_transform``, and the first three components of the result are
    returned. No division by the homogeneous coordinate is performed.
    """
    ones_column = np.ones((vertices.shape[0], 1))
    homogeneous = np.concatenate((vertices, ones_column), axis=1)
    transformed = np.matmul(camera_transform, homogeneous.T).T
    return transformed[:, :3]

def get_camera_intrinsics(sim, sensor_name):
    """Derive a 3x3 pinhole intrinsics matrix from a sensor's render camera.

    The render camera is reached through
    ``sim._sensors[sensor_name]._sensor_object.render_camera``; its
    ``projection_matrix`` and ``viewport`` are combined as::

        fx = P[0, 0] * width  / 2      cx = (P[2, 0] + 1) * width  / 2
        fy = P[1, 1] * height / 2      cy = (P[2, 1] + 1) * height / 2

    NOTE(review): the ``P[2, 0]`` / ``P[2, 1]`` lookups presumably rely on the
    projection matrix being indexed as [column, row] - confirm against the
    matrix type returned by the simulator.

    Returns:
        ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]`` as a numpy array.
    """
    camera = sim._sensors[sensor_name]._sensor_object.render_camera
    proj = camera.projection_matrix
    viewport = camera.viewport
    width, height = viewport[0], viewport[1]

    focal_x = proj[0, 0] * width / 2.0
    focal_y = proj[1, 1] * height / 2.0
    center_x = (proj[2, 0] + 1.0) * width / 2.0
    center_y = (proj[2, 1] + 1.0) * height / 2.0

    return np.array([
        [focal_x, 0, center_x],
        [0, focal_y, center_y],
        [0, 0, 1],
    ])

def project_3D_to_2D(points_3D, K):
    """Project camera-space points to pixel coordinates with a pinhole model.

    The camera is taken to look along -Z with +Y up, so depth is ``-z`` and
    the image row grows with ``-y``::

        u = fx * x / (-z) + cx
        v = fy * (-y) / (-z) + cy

    No behind-the-camera test is done here; callers are expected to pass
    points that have already been clipped.

    Args:
        points_3D: (N, 3) array of camera-space points.
        K: 3x3 intrinsics array ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.

    Returns:
        (N, 2) float32 array of ``[u, v]`` rows.
    """
    count = points_3D.shape[0]
    if count == 0:
        return np.zeros((0, 2), dtype=np.float32)

    # Compute in the dtype that scalar-by-scalar arithmetic would promote to,
    # then store as float32.
    work_dtype = np.result_type(points_3D.dtype, K.dtype)
    cam_x, cam_y, cam_z = points_3D.astype(work_dtype).T
    depth = -cam_z
    flipped_y = -cam_y

    u = (K[0, 0] * cam_x / depth) + K[0, 2]
    v = (K[1, 1] * flipped_y / depth) + K[1, 2]
    return np.stack((u, v), axis=1).astype(np.float32)

def clip_bbox(bbox, near_plane=-0.1):
    """Clip the corners of a convex box against the camera near plane.

    The camera looks along -Z, so a point is "in front" when
    ``z < near_plane``. The result contains every corner that is in front,
    plus the point where each segment joining a front corner to a
    behind/on-plane corner crosses the plane.

    All front/behind corner pairs are clipped rather than a fixed table of 12
    edges. A fixed table only works for one specific corner numbering, and
    the boxes passed in by ``get_sim`` come from ``get_box_points()``.
    NOTE(review): if that is Open3D's method, its numbering is not the
    binary x/y/z numbering - confirm.
    Because the box is convex, the true edges are among these pairs. The
    extra crossings (from diagonals) lie inside the hull of the true edge
    crossings, so they never change the extent of the projected box.

    Args:
        bbox: (N, 3) array-like of camera-space corners (N is 8 for a box).
        near_plane: Z value of the clipping plane (negative = in front).

    Returns:
        (M, 3) float32 array of unique points, sorted lexicographically;
        shape (0, 3) when no corner is in front of the plane.
    """
    corners = np.asarray(bbox, dtype=np.float64).reshape(-1, 3)
    in_front = corners[:, 2] < near_plane
    front = corners[in_front]
    behind = corners[~in_front]

    pieces = [front]
    if len(front) and len(behind):
        a = front[:, None, :]    # (F, 1, 3)
        b = behind[None, :, :]   # (1, B, 3)
        # b_z >= near_plane > a_z, so the denominator is strictly positive.
        t = (near_plane - a[..., 2]) / (b[..., 2] - a[..., 2])
        crossings = a + t[..., None] * (b - a)
        pieces.append(crossings.reshape(-1, 3))

    points = np.concatenate(pieces, axis=0).astype(np.float32)
    if points.shape[0] == 0:
        return np.empty((0, 3), dtype=np.float32)
    return np.unique(points, axis=0)

def habitat_quat_to_np_quat(hab_q):
    """Build an ``np.quaternion`` from an object exposing ``w``, ``x``, ``y``, ``z``.

    The components are read by attribute and passed in ``(w, x, y, z)`` order.
    """
    scalar_part = hab_q.w
    vector_part = (hab_q.x, hab_q.y, hab_q.z)
    return np.quaternion(scalar_part, *vector_part)

def get_2d_bbox(bbox, img_width=1080, img_height=720):
    """Return the axis-aligned pixel box around 2D points, clamped to the image.

    Args:
        bbox: (N, 2) array of ``[u, v]`` pixel coordinates.
        img_width: Image width in pixels (default 1080, the width configured
            in ``get_sim``).
        img_height: Image height in pixels (default 720).

    Returns:
        ``(x1, y1, x2, y2)`` as Python ints clamped to
        ``[0, img_width - 1] x [0, img_height - 1]``, or ``None`` when there
        are no points or the box does not overlap the image.

    Raises:
        OverflowError / ValueError: if a coordinate is infinite or NaN.
    """
    points = np.asarray(bbox)
    if points.size == 0:
        return None

    # Floor rather than truncate: int() rounds toward zero, which would turn a
    # box lying entirely in (-1, 0) into an in-frame box at column/row 0.
    min_x = int(np.floor(np.min(points[:, 0])))
    max_x = int(np.floor(np.max(points[:, 0])))
    min_y = int(np.floor(np.min(points[:, 1])))
    max_y = int(np.floor(np.max(points[:, 1])))

    # Reject boxes that lie completely outside the image.
    if max_x < 0 or min_x >= img_width or max_y < 0 or min_y >= img_height:
        return None

    return (
        max(0, min_x),
        max(0, min_y),
        min(img_width - 1, max_x),
        min(img_height - 1, max_y),
    )

def depth_to_point_cloud(depth_image, K, T_world_cam):
    """Unproject a depth image into world-space points.

    The camera is taken to look along -Z with +Y up, so a pixel ``(u, v)``
    with depth ``z`` maps to the camera-space point::

        Z_cam = -z
        X_cam = -(u - cx) / fx * Z_cam
        Y_cam =  (v - cy) / fy * Z_cam

    which is then transformed by ``T_world_cam``.

    Args:
        depth_image: (H, W) array of depth in meters.
        K: (3, 3) intrinsics ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``.
        T_world_cam: (4, 4) transform from camera to world.

    Returns:
        (N, 3) float64 array of world points in row-major pixel order. Pixels
        with ``z <= 0`` or ``z >= 100`` are skipped; the shape is (0, 3) when
        every pixel is skipped.
    """
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    depth = np.asarray(depth_image, dtype=np.float64)
    # Same rejection test as a per-pixel `z <= 0 or z >= 100.0` check.
    rejected = (depth <= 0) | (depth >= 100.0)
    rows, cols = np.nonzero(~rejected)  # row-major order: v outer, u inner

    z_cam = -depth[rows, cols]
    x_cam = -(cols - cx) / fx * z_cam
    y_cam = (rows - cy) / fy * z_cam

    pts_cam_h = np.stack((x_cam, y_cam, z_cam, np.ones_like(z_cam)), axis=1)
    pts_world_h = pts_cam_h @ np.asarray(T_world_cam).T
    return pts_world_h[:, :3] / pts_world_h[:, 3:4]

def load_depth_intrinsics(H, W):
    """Compute pinhole intrinsics for an ``H`` x ``W`` image with a 90 degree
    horizontal field of view.

    The vertical field of view is derived from the horizontal one and the
    aspect ratio; the principal point is the image center.

    Returns:
        Depth camera intrinsics as a numpy array (3x3 matrix).
    """
    hfov = 90 * np.pi / 180
    vfov = 2 * math.atan(np.tan(hfov / 2) * H / W)

    focal_x = W / (2.0 * np.tan(hfov / 2.0))
    focal_y = H / (2.0 * np.tan(vfov / 2.0))
    center_x, center_y = W / 2, H / 2

    return np.array([
        [focal_x, 0.0, center_x],
        [0.0, focal_y, center_y],
        [0.0, 0.0, 1.0],
    ])
    
def hovsg_depth_to_world(depth_image, pose):
    """Unproject a depth image into world-space points from a 7-element pose.

    The pose is turned into a 4x4 matrix and right-multiplied by
    ``diag(1, -1, -1, 1)``, so pixels are unprojected in a +Z-forward camera
    frame (``X = (x - cx) * z / fx``, ``Y = (y - cy) * z / fy``, ``Z = z``)
    before being moved to world coordinates. Intrinsics come from
    ``load_depth_intrinsics`` for the image size.

    Args:
        depth_image: (H, W) array of depth values.
        pose: Sequence ``[tx, ty, tz, qx, qy, qz, qw]``; the last four entries
            are given to ``Rotation.from_quat`` (scalar-last order).

    Returns:
        List of length-3 float64 arrays, one per kept pixel in row-major
        order. Pixels with ``z <= 0`` or ``z >= 100`` are skipped.
    """
    pose_mat = np.eye(4)
    pose_mat[:3, :3] = R.from_quat(pose[3:]).as_matrix()
    pose_mat[:3, 3] = pose[:3]
    # Flip the Y and Z axes of the camera frame.
    axis_flip = np.diag([1.0, -1.0, -1.0, 1.0])
    cam_to_world = pose_mat @ axis_flip

    depth = np.asarray(depth_image, dtype=np.float64)
    H, W = depth.shape[0], depth.shape[1]
    K = load_depth_intrinsics(H, W)
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Same rejection test as a per-pixel `z <= 0 or z >= 100.0` check.
    rejected = (depth <= 0) | (depth >= 100.0)
    ys, xs = np.nonzero(~rejected)  # row-major order: y outer, x inner
    z = depth[ys, xs]

    X = (xs - cx) * z / fx
    Y = (ys - cy) * z / fy
    pts_cam_h = np.stack((X, Y, z, np.ones_like(z)), axis=1)
    pts_world_h = pts_cam_h @ cam_to_world.T
    pts_world = pts_world_h[:, :3] / pts_world_h[:, 3:4]

    # Keep the list-of-points return type.
    return list(pts_world)

def get_sim(scan, position, objs, yaw):
    """Project the bounding boxes of ``objs`` into the color camera at one pose.

    A simulator is created for the scene ``scan``, the agent is placed at
    ``position`` with a rotation of ``yaw`` degrees about the Y axis, and each
    object's bounding-box corners are transformed into the color sensor's
    frame, clipped against the near plane, projected with the sensor
    intrinsics and reduced to a 2D pixel box.

    Args:
        scan: Scene name; the mesh is read from
            ``../scene_datasets/mp3d/{scan}/{scan}.glb``.
        position: Agent position assigned to ``AgentState.position``.
        objs: Objects providing ``pcd`` (for the bounding box) and
            ``object_id`` (used as the result key).
        yaw: Heading in degrees.

    Returns:
        Dict mapping ``object_id`` to ``(x1, y1, x2, y2)``. Objects whose box
        has no point in front of the near plane, or whose 2D box falls outside
        the image, are omitted.
    """
    scene_name = scan
    root_dataset_dir = "../scene_datasets/mp3d"
    scene_data_dir = f"{root_dataset_dir}/{scene_name}/"
    save_dir = "viewpoint_base_hovsg"

    scene_mesh = os.path.join(scene_data_dir, scene_name + ".glb")

    sim_settings = {
        "scene": scene_mesh,
        "default_agent": 0,
        "sensor_height": 1.5,
        "color_sensor": True,
        "depth_sensor": True,
        "semantic_sensor": True,
        "lidar_sensor": False,
        "move_forward": 0.2,
        "move_backward": 0.2,
        "turn_left": 5,
        "turn_right": 5,
        "look_up": 5,
        "look_down": 5,
        "look_left": 5,
        "look_right": 5,
        "width": 1080,
        "height": 720,
        "enable_physics": False,
        "seed": 42,
        "lidar_fov": 360,
        "depth_img_for_lidar_n": 20,
        "img_save_dir": save_dir,
    }
    os.environ["MAGNUM_LOG"] = "quiet"
    os.environ["HABITAT_SIM_LOG"] = "quiet"

    sim_cfg = make_cfg_mp3d(sim_settings, root_dataset_dir, scene_data_dir, scene_name, print_scene=False)
    sim = habitat_sim.Simulator(sim_cfg)
    # The simulator must be closed even if a later step raises; otherwise every
    # failed call would leave a simulator instance open.
    try:
        agent = sim.initialize_agent(sim_settings["default_agent"])
        agent_state = habitat_sim.AgentState()
        agent_state.position = position
        # Heading about Y composed with a zero rotation about X.
        quad_1 = R.from_euler('y', np.radians(yaw)).as_quat()
        quad_2 = R.from_euler('x', np.radians(0)).as_quat()
        new_quat = R.from_quat(quad_1) * R.from_quat(quad_2)
        quad = new_quat.as_quat()
        agent_state.rotation = quad
        agent.set_state(agent_state)

        bboxes = [get_obb_vertices_with_rotation(obj) for obj in objs]

        # World <- camera transform of the color sensor, and its inverse.
        sensor_state = agent.get_state().sensor_states["color_sensor"]
        translation_0 = sensor_state.position
        q = habitat_quat_to_np_quat(sensor_state.rotation)
        rotation_0 = quaternion.as_rotation_matrix(q)
        T_world_camera0 = np.eye(4)
        T_world_camera0[0:3, 0:3] = rotation_0
        T_world_camera0[0:3, 3] = translation_0
        T_camera0_world = np.linalg.inv(T_world_camera0)

        # Debugging aid: hovsg_depth_to_world() and depth_to_point_cloud() can
        # be run on sim.get_sensor_observations()["depth_sensor"] here to
        # cross-check this transform against the HOV-SG unprojection.

        intrinsics = get_camera_intrinsics(sim, "color_sensor")

        coords2D = {}
        for obj, bbox in zip(objs, bboxes):
            camera_bbox = project_to_camera_space(bbox, T_camera0_world)
            clipped_bbox = clip_bbox(camera_bbox)
            if len(clipped_bbox) == 0:
                # Nothing of this box is in front of the near plane.
                continue
            projected_2D = project_3D_to_2D(clipped_bbox, intrinsics)
            bbox_2d = get_2d_bbox(projected_2D)
            if bbox_2d is not None:
                coords2D[obj.object_id] = bbox_2d

        return coords2D
    finally:
        sim.close()

def load_objs(obj_dir):
    """Load one ``Object`` per ``.ply`` file found in ``obj_dir``.

    The part of the file name before the first ``.`` is passed to ``Object``
    as its first and third constructor arguments (``None`` is the second),
    and ``load(obj_dir)`` is then called on the new object. Objects are
    returned in ``os.listdir`` order and the count is printed.
    """
    objs = []
    for file_name in os.listdir(obj_dir):
        if not file_name.endswith(".ply"):
            continue
        stem = file_name.split(".")[0]
        loaded = Object(stem, None, stem)
        loaded.load(obj_dir)
        objs.append(loaded)
    print(f"Loaded {len(objs)} objects")
    return objs

def load_vps(vp_dir):
    """Load one ``Viewpoint`` per ``.ply`` file found in ``vp_dir``.

    The part of the file name before the first ``.`` is used as the first
    constructor argument and ``'0'`` as the second; ``load(vp_dir)`` is then
    called on the new viewpoint. Viewpoints are returned in ``os.listdir``
    order and the count is printed.
    """
    vps = []
    for file_name in os.listdir(vp_dir):
        if not file_name.endswith(".ply"):
            continue
        stem = file_name.split(".")[0]
        viewpoint = Viewpoint(stem, '0')
        viewpoint.load(vp_dir)
        vps.append(viewpoint)
    print(f"Loaded {len(vps)} viewpoints")
    return vps

def filter_few_points(objs, min_points=5):
    """Return the objects whose point cloud has at least ``min_points`` points.

    Each object is judged on its own point count, in a single pass. Order is
    preserved and the input list is not modified.

    Args:
        objs: Iterable of objects exposing ``pcd.points`` (any sized container).
        min_points: Minimum number of points an object needs to be kept
            (default 5).

    Returns:
        A new list containing the kept objects.
    """
    return [obj for obj in objs if len(obj.pcd.points) >= min_points]

# Driver: for every scan, project the objects visible from each viewpoint into
# 12 headings (30 degree steps) and store the 2D boxes as JSON. Results are
# checkpointed after every pose so an interrupted run can resume.
with open("../VLN-DUET/datasets/R2R/connectivity/scans.txt", "r") as f:
    scans = [line.strip() for line in f]

for scan in scans:
    vp_dir = f"../HOV-SG/viewpoint_base_hovsg/wp_vps/{scan}"
    obj_dir = f"../HOV-SG/data/scene_graphs/hm3dsem/{scan}/graph/objects"
    try:
        vps = load_vps(vp_dir)
        objs = load_objs(obj_dir)
    except Exception as exc:
        # Best effort: skip scans whose inputs are missing or unreadable, but
        # report the cause and let KeyboardInterrupt/SystemExit propagate.
        print(f"Error loading {scan}: {exc!r}")
        continue

    result_path = f"viewpoint_base_hovsg/wp_bbox/{scan}_bbox.json"
    os.makedirs(os.path.dirname(result_path), exist_ok=True)
    if os.path.exists(result_path):
        with open(result_path, "r") as f:
            results = json.load(f)
        print(f"Loaded {len(results)} saved results")
    else:
        results = {}

    with tqdm(total=len(vps)*12, desc="Processing objects", unit="object") as progress_bar:
        for vp in vps:
            position = read_node_wp(scan, vp.vp_id)
            # Sanity check that the waypoint graph and the viewpoint file agree.
            assert position == vp.vp_pos, "Position mismatch"

            # The object selection does not depend on the heading; build it at
            # most once per viewpoint, and only if some heading is still missing.
            filtered_objs = None
            for angle in range(12):
                yaw = angle * 30
                key = f"{vp.vp_id}_{yaw}"
                if key not in results:
                    if filtered_objs is None:
                        vp_objs = [obj for obj in objs if obj.object_id in vp.objects]
                        filtered_objs = filter_few_points(vp_objs)
                    results[key] = get_sim(scan, position, filtered_objs, yaw)

                    # Write to a temporary file and swap it in, so an
                    # interruption mid-write cannot corrupt the checkpoint.
                    tmp_path = result_path + ".tmp"
                    with open(tmp_path, "w") as f:
                        json.dump(results, f, indent=4)
                    os.replace(tmp_path, result_path)

                progress_bar.update(1)
