import os
import json
import yaml
import math
from PIL import Image
import open3d as o3d
import networkx as nx
import numpy as np
import habitat_sim
from tqdm import tqdm
from omegaconf import DictConfig
from hovsg.graph.graph import Graph
from scipy.spatial.transform import Rotation as R
from hovsg.data.hm3dsem.habitat_utils import save_obs, make_cfg_mp3d
from viewpoint_base_hovsg.viewpoint import Viewpoint
from scipy.spatial import cKDTree
from hovsg.utils.graph_utils import pcd_denoise_dbscan, find_intersection_share, find_overlapping_ratio_faiss

class HM3DSemDataset():
    """
    Dataset of posed RGB-D frames stored as individual files.

    Every sample is an (rgb, depth, pose) triple of file paths. Images are
    opened with PIL, poses are read from text files, and the pinhole
    intrinsics are derived from the size of the first RGB image using a fixed
    90 degree horizontal field of view.
    """
    def __init__(self, rgb_data_list, depth_data_list, pose_data_list):
        """
        Args:
            rgb_data_list: Paths to the RGB images.
            depth_data_list: Paths to the depth images.
            pose_data_list: Paths to the camera pose text files.

        The three lists are zipped, so sample ``i`` is built from the ``i``-th
        entry of each list and surplus entries of a longer list are ignored.

        Raises:
            IndexError: If the zipped sample list is empty.
        """
        super().__init__()
        self.data_list = list(zip(rgb_data_list, depth_data_list, pose_data_list))
        if not self.data_list:
            raise IndexError("HM3DSemDataset needs at least one (rgb, depth, pose) sample")
        # Divisor applied to raw depth values in create__pcd.
        # NOTE(review): presumably raw depth is in millimetres - confirm against the writer.
        self.scale = 1000.0
        # Read the image size from the first frame only; open it once and
        # close the file handle immediately instead of leaking it.
        with self._load_image(self.data_list[0][0]) as first_rgb:
            self.rgb_W, self.rgb_H = first_rgb.size
        self.depth_intrinsics = self._load_depth_intrinsics(self.rgb_H, self.rgb_W)

    def __getitem__(self, idx):
        """
        Get a data sample based on the given index.

        Args:
            idx: Index of the data sample.

        Returns:
            Tuple ``(rgb_image, depth_image, pose, [], depth_intrinsics)``:
            the two PIL images, the 4x4 pose matrix, an empty list and a
            freshly computed 3x3 intrinsics matrix.
        """
        rgb_path, depth_path, pose_path = self.data_list[idx]
        rgb_image = self._load_image(rgb_path)
        depth_image = self._load_depth(depth_path)
        pose = self._load_pose(pose_path)
        # Recomputed per sample so callers never share a mutable array.
        depth_intrinsics = self._load_depth_intrinsics(self.rgb_H, self.rgb_W)
        return rgb_image, depth_image, pose, list(), depth_intrinsics

    def _load_image(self, path):
        """
        Open the RGB image at the given path.

        Args:
            path: Path to the RGB image file.

        Returns:
            The opened PIL image (not converted to a numpy array).
        """
        return Image.open(path)

    def _load_depth(self, path):
        """
        Open the depth image at the given path.

        Args:
            path: Path to the depth image file.

        Returns:
            The opened PIL image (not converted to a numpy array).
        """
        return Image.open(path)

    def _load_pose(self, path):
        """
        Load the camera pose from the given path.

        Only the first line of the file is read; it must hold 16
        whitespace-separated numbers, interpreted as a row-major 4x4 matrix.

        Args:
            path: Path to the camera pose file.

        Returns:
            4x4 numpy array: the stored matrix right-multiplied by
            ``diag(1, -1, -1, 1)``, i.e. with its second and third columns
            negated.
        """
        with open(path, "r") as file:
            values = [float(val) for val in file.readline().split()]
        transformation_matrix = np.array(values).reshape((4, 4))
        # NOTE(review): negating the Y and Z columns presumably converts
        # between camera axis conventions - confirm against the pose writer.
        flip_yz = np.diag([1.0, -1.0, -1.0, 1.0])
        return np.matmul(transformation_matrix, flip_yz)

    def _load_depth_intrinsics(self, H, W):
        """
        Build pinhole intrinsics for an image of the given size.

        The horizontal field of view is fixed at 90 degrees; the vertical
        field of view follows from the aspect ratio, and the principal point
        is the image centre.

        Args:
            H: Image height in pixels.
            W: Image width in pixels.

        Returns:
            Camera intrinsics as a numpy array (3x3 matrix).
        """
        hfov = 90 * np.pi / 180
        vfov = 2 * math.atan(np.tan(hfov / 2) * H / W)
        fx = W / (2.0 * np.tan(hfov / 2.0))
        fy = H / (2.0 * np.tan(vfov / 2.0))
        cx = W / 2
        cy = H / 2
        depth_camera_matrix = np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
        return depth_camera_matrix

    def create__pcd(self, rgb, depth, camera_pose=None):
        """
        Create a coloured point cloud from one RGB-D frame.

        Args:
            rgb: RGB image (anything ``np.array`` can convert); its first two
                dimensions are taken as height and width. Colours are divided
                by 255, so 8-bit channels are assumed.
            depth: Depth image with the same height and width as ``rgb``.
                Values are divided by ``self.scale``; non-positive values are
                discarded.
            camera_pose: Optional 4x4 matrix applied to the cloud. When
                omitted the cloud is returned in the camera frame.

        Returns:
            Point cloud as an Open3D object.
        """
        rgb = np.array(rgb)
        depth = np.array(depth)
        H = rgb.shape[0]
        W = rgb.shape[1]
        camera_matrix = self._load_depth_intrinsics(H, W)

        depth = depth.astype(np.float32) / self.scale
        mask = depth > 0
        # Pixel coordinates of the valid depth values, in row-major order
        # (the same order boolean-mask indexing uses below).
        y, x = np.nonzero(mask)
        depth = depth[mask]

        # Back-project through the pinhole model.
        X = (x - camera_matrix[0, 2]) * depth / camera_matrix[0, 0]
        Y = (y - camera_matrix[1, 2]) * depth / camera_matrix[1, 1]
        Z = depth

        points = np.column_stack((X, Y, Z))
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        colors = rgb[mask]
        pcd.colors = o3d.utility.Vector3dVector(colors / 255.0)
        # The pose is optional: transforming with None would raise.
        if camera_pose is not None:
            pcd.transform(camera_pose)
        return pcd

    def __len__(self):
        """Return the number of (rgb, depth, pose) samples."""
        return len(self.data_list)

def get_mp3d_poses(scan, vp2pos_dir="vp2pos", angle_step=30):
    """
    Build the camera poses rendered at every viewpoint of a scan.

    Reads ``<vp2pos_dir>/vp2pos_<scan>.json``, a mapping from viewpoint id to
    position, and pairs every position with a set of headings obtained by
    rotating about the Y axis in ``angle_step`` degree increments over
    [0, 360).

    Args:
        scan: Scan identifier used to pick the JSON file.
        vp2pos_dir: Directory holding the ``vp2pos_<scan>.json`` files.
        angle_step: Heading increment in degrees (30 gives 12 headings).

    Returns:
        Dict mapping each viewpoint id to a list of poses. Each pose is the
        position followed by the four quaternion components returned by
        scipy's ``Rotation.as_quat()``.
    """
    with open(os.path.join(vp2pos_dir, f"vp2pos_{scan}.json"), "r") as f:
        vp2pos = json.load(f)

    # The headings are identical for every viewpoint, so compute them once.
    heading_quats = [
        R.from_euler('y', np.radians(angle_deg)).as_quat().tolist()
        for angle_deg in range(0, 360, angle_step)
    ]

    return {
        vp: [list(pos) + quat for quat in heading_quats]
        for vp, pos in vp2pos.items()
    }

def create_mp3d_views(scan, vp2poses):
    """
    Render and save observations for every pose of every viewpoint.

    A Habitat simulator is created for the scan's ``.glb`` mesh, the agent is
    placed at each pose in turn, and the resulting sensor observations are
    handed to ``save_obs`` under
    ``viewpoint_base_hovsg/mp3d_views/<scan>/<viewpoint>``.

    Args:
        scan: Scan identifier; also used as the scene name.
        vp2poses: Dict mapping viewpoint id to a list of poses, each pose
            being three position values followed by the rotation values.
    """
    scene_name = scan
    root_dataset_dir = "../scene_datasets/mp3d"
    scene_data_dir = f"{root_dataset_dir}/{scene_name}/"
    save_dir = f"viewpoint_base_hovsg/mp3d_views/{scene_name}"

    scene_mesh = os.path.join(scene_data_dir, scene_name + ".glb")
    print("scene:", scene_mesh)

    sim_settings = {
        "scene": scene_mesh,
        "default_agent": 0,
        "sensor_height": 1.5,
        "color_sensor": True,
        "depth_sensor": True,
        "semantic_sensor": True,
        "lidar_sensor": False,
        "move_forward": 0.2,
        "move_backward": 0.2,
        "turn_left": 5,
        "turn_right": 5,
        "look_up": 5,
        "look_down": 5,
        "look_left": 5,
        "look_right": 5,
        "width": 1080,
        "height": 720,
        "enable_physics": False,
        "seed": 42,
        "lidar_fov": 360,
        "depth_img_for_lidar_n": 20,
        "img_save_dir": save_dir,
    }
    # NOTE(review): these are set after habitat_sim has been imported -
    # confirm the library still honours them at this point.
    os.environ["MAGNUM_LOG"] = "quiet"
    os.environ["HABITAT_SIM_LOG"] = "quiet"

    sim_cfg = make_cfg_mp3d(sim_settings, root_dataset_dir, scene_data_dir, scene_name)
    sim = habitat_sim.Simulator(sim_cfg)
    # Always release the simulator, even if rendering or saving fails.
    try:
        agent_id = sim_settings["default_agent"]

        # Initialize the agent at a random navigable point and report it.
        agent = sim.initialize_agent(agent_id)
        agent_state = habitat_sim.AgentState()
        random_pt = sim.pathfinder.get_random_navigable_point()
        agent_state.position = random_pt
        agent.set_state(agent_state)

        agent_state = agent.get_state()
        print(
            "agent_state: position",
            agent_state.position,
            "rotation",
            agent_state.rotation,
        )

        agent = sim.get_agent(agent_id)
        for vp, poses in tqdm(vp2poses.items(), desc="Generating viewpoints", total=len(vp2poses)):
            vp_save_dir = os.path.join(save_dir, vp)
            os.makedirs(vp_save_dir, exist_ok=True)
            # The frame index doubles as the save id passed to save_obs.
            for steps, pose in enumerate(poses):
                agent_state = habitat_sim.AgentState()
                agent_state.position = pose[:3]
                agent_state.rotation = pose[3:]

                agent.set_state(agent_state, reset_sensors=True, infer_sensor_states=False)
                obs = sim.get_sensor_observations(agent_id)
                save_obs(vp_save_dir, sim_settings, obs, pose, steps)
    finally:
        sim.close()

def aabb_overlap(major_pcd, candidate_pcd):
    """
    Tell whether the axis-aligned bounding boxes of two point clouds intersect.

    Boxes that only touch (share a face, edge or corner) count as overlapping.

    Args:
        major_pcd: Object exposing ``get_axis_aligned_bounding_box()``.
        candidate_pcd: Object exposing ``get_axis_aligned_bounding_box()``.

    Returns:
        True if the boxes overlap on all three axes, False otherwise.
    """
    box_a = major_pcd.get_axis_aligned_bounding_box()
    box_b = candidate_pcd.get_axis_aligned_bounding_box()

    # The boxes are disjoint as soon as a single axis separates them.
    if not np.all(box_a.max_bound >= box_b.min_bound):
        return False
    return bool(np.all(box_b.max_bound >= box_a.min_bound))

def _frame_index(file_name):
    """Return the integer after the last underscore of the name's part before its first dot."""
    return int(file_name.split(".")[0].split("_")[-1])


def _sorted_frame_paths(vp_dir, subdir):
    """List the files of ``vp_dir/subdir`` as full paths, ordered by frame index."""
    frame_dir = os.path.join(vp_dir, subdir)
    # Sort on the bare file names so the key does not depend on the path separator.
    return [os.path.join(frame_dir, name) for name in sorted(os.listdir(frame_dir), key=_frame_index)]


def create_vps(hovsg, vp2poses, scan):
    """
    Build, save and return a ``Viewpoint`` for every viewpoint of a scan.

    For each viewpoint the rendered frames under
    ``viewpoint_base_hovsg/mp3d_views/<scan>/<viewpoint>`` are fused into one
    point cloud, which is voxel-downsampled and passed through
    ``pcd_denoise_dbscan``. Every object of the graph whose bounding box
    overlaps the cloud and whose overlap ratio exceeds 0.1 is attached to the
    viewpoint. Viewpoints are saved to
    ``viewpoint_base_hovsg/mp3d_vps/<scan>``.

    Args:
        hovsg: Loaded graph; its ``objects`` and ``floors`` are read.
        vp2poses: Dict mapping viewpoint id to its list of poses; the first
            three values of the first pose become the viewpoint position.
        scan: Scan identifier used to build the input and output paths.

    Returns:
        List of the created viewpoints, in ``vp2poses`` iteration order.
    """
    vp_save_dir = f"viewpoint_base_hovsg/mp3d_vps/{scan}"
    os.makedirs(vp_save_dir, exist_ok=True)

    objects = hovsg.objects
    all_viewpoints = []
    for vp in tqdm(vp2poses.keys(), desc="Creating viewpoints", total=len(vp2poses), position=0, leave=True):
        # Every viewpoint is attached to the first floor of the graph; look it
        # up before the expensive fusion so a missing floor fails early.
        viewpoint = Viewpoint(vp, hovsg.floors[0].floor_id)

        vp_dir = f"viewpoint_base_hovsg/mp3d_views/{scan}/{vp}"
        pose_path = _sorted_frame_paths(vp_dir, "pose")
        rgb_path = _sorted_frame_paths(vp_dir, "rgb")
        depth_path = _sorted_frame_paths(vp_dir, "depth")

        dataset = HM3DSemDataset(rgb_path, depth_path, pose_path)

        # Fuse all frames of the viewpoint into a single cloud.
        vp_pcd = o3d.geometry.PointCloud()
        for frame_idx in range(len(dataset)):
            rgb_image, depth_image, pose, _, _ = dataset[frame_idx]
            vp_pcd += dataset.create__pcd(rgb_image, depth_image, pose)
        vp_pcd = vp_pcd.voxel_down_sample(voxel_size=0.02)
        # NOTE(review): eps (0.01) is smaller than the voxel size (0.02) used
        # just above - confirm the clustering still behaves as intended.
        vp_pcd = pcd_denoise_dbscan(vp_pcd, eps=0.01, min_points=100)

        viewpoint.pcd = vp_pcd
        viewpoint.vp_pos = np.asarray(vp2poses[vp])[0][:3]
        viewpoint.vp_images = rgb_path

        vp_xyz = np.asarray(vp_pcd.points)
        for obj in tqdm(objects, desc="Adding objects", position=1, leave=False):
            obj_pcd = obj.pcd
            # Cheap bounding-box rejection before the point-level comparison.
            if not aabb_overlap(vp_pcd, obj_pcd):
                continue

            obj_xyz = np.asarray(obj_pcd.points)
            overlapping_ratio = find_overlapping_ratio_faiss(vp_xyz, obj_xyz, radius=0.1)
            if overlapping_ratio > 0.1:
                viewpoint.add_object(obj)

        viewpoint.save(vp_save_dir)
        all_viewpoints.append(viewpoint)

    return all_viewpoints


def main(scan="zsNo4HB9uLZ", param_path="config/Nav3DSG.yaml"):
    """
    Run the whole pipeline for one scan.

    Steps: build the per-viewpoint poses, render and save the views, load the
    scene graph from ``../HOV-SG/data/scene_graphs/hm3dsem/<scan>/graph`` and
    create the viewpoints from the rendered views.

    Args:
        scan: Scan identifier to process.
        param_path: Path to the YAML configuration handed to ``Graph``.
    """
    mp3d_vp2poses = get_mp3d_poses(scan)
    create_mp3d_views(scan, mp3d_vp2poses)

    with open(param_path, "r") as f:
        params = yaml.safe_load(f)
    params = DictConfig(params)
    hovsg = Graph(params)
    hovsg.load_graph(f"../HOV-SG/data/scene_graphs/hm3dsem/{scan}/graph")

    create_vps(hovsg, mp3d_vp2poses, scan)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()