from __future__ import annotations

import logging
import math
from typing import Callable, Sequence

import gymnasium as gym
import gymnasium.spaces as spaces
import mujoco
import numpy as np
import open3d as o3d
from open3d.geometry import PointCloud

from .. import PointCloudSpace

logger = logging.getLogger(__name__)


# Dict key under which the wrapped env's original observation is returned by
# MetaworldPointCloudObservations when it is constructed with points_only=False.
STATE_KEY = "state"


class MetaworldPointCloudObservations(gym.ObservationWrapper):
    """Observation wrapper that renders MuJoCo cameras into a fused point cloud.

    For every selected camera of the wrapped env, an RGB and a depth image are
    rendered, back-projected with a pinhole model, transformed into world
    coordinates and concatenated. By default the observation is an ``(N, 6)``
    array of ``[x, y, z, r, g, b]`` rows.

    Args:
        env: Environment to wrap. Its unwrapped env must expose ``model``,
            ``data`` and ``mujoco_renderer``.
        only_cameras: Names of the cameras to use. Mutually exclusive with
            ``exclude_cameras``.
        exclude_cameras: Names of cameras to skip. Mutually exclusive with
            ``only_cameras``.
        depth_cutoff: Passed to open3d as ``depth_trunc``; depth values beyond
            it are discarded.
        max_expected_num_points: Upper bound on the number of points declared in
            the observation space. Defaults to one point per pixel of every
            selected camera, plus the debug markers when they are enabled.
        camera_resolution: Currently ignored (a warning is logged). The
            resolution is read from the model's offscreen buffer size.
        color: Whether to include RGB features. Only ``True`` is supported.
        points_only: If True the observation is only the point cloud; otherwise
            a dict with the original observation under ``STATE_KEY`` and the
            point cloud under ``points_key``.
        points_key: Dict key of the point cloud when ``points_only`` is False.
        post_processing: Callables applied in order to the open3d point cloud
            before it is converted to an array.
        debug_camera_origins: Add a green point at each camera origin and a blue
            point at the world origin.
        debug_o3d_pcd: Return the open3d ``PointCloud`` instead of an array.

    Raises:
        NotImplementedError: If ``color`` is False.
        ValueError: If both ``only_cameras`` and ``exclude_cameras`` are given.
    """

    def __init__(
        self,
        env: gym.Env,
        only_cameras: Sequence[str] | None = None,
        exclude_cameras: Sequence[str] | None = None,
        depth_cutoff: float = 1000,
        max_expected_num_points: int | None = None,
        camera_resolution: tuple[int, int] | None = None,
        color: bool = True,
        points_only: bool = True,
        points_key: str = "points",
        post_processing: list[Callable[[PointCloud], PointCloud]] | None = None,
        debug_camera_origins: bool = False,
        debug_o3d_pcd: bool = False,
    ) -> None:
        super().__init__(env)

        self.depth_cutoff = depth_cutoff
        self.color = color
        self.points_only = points_only
        self.points_key = points_key
        self.post_processing = post_processing
        self.debug_camera_origins = debug_camera_origins
        self.debug_o3d_pcd = debug_o3d_pcd

        if not self.color:
            raise NotImplementedError(self.color)

        # MuJoCo model object, stores most static simulation information
        mj_model = self.env.unwrapped.model

        self.num_cameras: int = mj_model.ncam
        if only_cameras is not None and exclude_cameras is not None:
            raise ValueError("Specify only one of `only_cameras` or `exclude_cameras`.")
        elif only_cameras is not None:
            self.camera_ids = [mj_model.cam(name).id for name in only_cameras]
        elif exclude_cameras is not None:
            self.camera_ids = [
                i
                for i in range(self.num_cameras)
                if mj_model.cam(i).name not in exclude_cameras
            ]
        else:
            self.camera_ids = list(range(self.num_cameras))
        logger.info(
            "Using the following cameras for point cloud generation: %s",
            [mj_model.cam(i).name for i in self.camera_ids],
        )

        # NOTE(review): assumes the renderer produces images at the model's
        # offscreen resolution; the intrinsics below depend on it.
        width: int = mj_model.vis.global_.offwidth
        height: int = mj_model.vis.global_.offheight

        if camera_resolution is not None:
            logger.warning(
                "`camera_resolution=%s` is currently ignored; using the model's "
                "offscreen resolution %dx%d instead.",
                camera_resolution,
                width,
                height,
            )

        # precompute intrinsics for all selected cameras
        self.camera_intrinsics = {}
        for camera_id in self.camera_ids:
            fovy = mj_model.cam(camera_id).fovy
            half_fovy_rad = math.radians(fovy / 2)
            # use aspect ratio and vertical fov to compute horizontal fov
            # https://en.wikipedia.org/wiki/Field_of_view_in_video_games
            half_fovx_rad = math.atan(math.tan(half_fovy_rad) * width / height)
            # focal length can be computed from image size and FOV
            # https://www.edmundoptics.com/knowledge-center/application-notes/imaging/understanding-focal-length-and-field-of-view/
            fy = height / (2 * math.tan(half_fovy_rad))
            fx = width / (2 * math.tan(half_fovx_rad))
            self.camera_intrinsics[camera_id] = o3d.camera.PinholeCameraIntrinsic(
                width=width,
                height=height,
                fx=fx,
                fy=fy,
                cx=width / 2,
                cy=height / 2,
            )

        if max_expected_num_points is None:
            # One point per pixel of each *selected* camera, plus the optional
            # debug markers (one per camera and one at the world origin).
            max_expected_num_points = len(self.camera_ids) * height * width
            if self.debug_camera_origins:
                max_expected_num_points += len(self.camera_ids) + 1

        pointcloud_space = PointCloudSpace(
            max_expected_num_points=max_expected_num_points,
            low=-np.float32("inf"),
            high=np.float32("inf"),
            feature_shape=(6,) if self.color else (3,),
        )
        # Cache the point dtype: when points_only is False the observation space
        # is a Dict, whose dtype is None and would yield float64 arrays.
        self._points_dtype = pointcloud_space.dtype

        if self.points_only:
            self.observation_space = pointcloud_space
        else:
            self.observation_space = spaces.Dict(
                {
                    STATE_KEY: env.observation_space,
                    self.points_key: pointcloud_space,
                }
            )

    def observation(
        self, observation: np.ndarray | dict
    ) -> np.ndarray | PointCloud | dict:
        """Build the point-cloud observation for the current simulator state.

        Args:
            observation: Observation from the wrapped env. Only returned (under
                ``STATE_KEY``) when ``points_only`` is False.

        Returns:
            The point cloud as an ``(N, 6)`` array (or the open3d ``PointCloud``
            when ``debug_o3d_pcd`` is set), optionally wrapped in a dict
            together with the original observation.
        """
        pcd = self.pointcloud()

        for func in self.post_processing or ():
            pcd = func(pcd)

        if self.debug_o3d_pcd:
            points = pcd
        else:
            pos = np.asarray(pcd.points, dtype=self._points_dtype)
            if self.color:
                colors = np.asarray(pcd.colors, dtype=self._points_dtype)
                points = np.concatenate((pos, colors), axis=-1)
            else:
                points = pos

        if self.points_only:
            return points
        return {
            STATE_KEY: observation,
            self.points_key: points,
        }

    def pointcloud(self) -> PointCloud:
        """Render every selected camera and fuse the results in world coordinates."""
        mj_model = self.env.unwrapped.model
        mj_data = self.env.unwrapped.data
        mj_renderer = self.env.unwrapped.mujoco_renderer

        combined_pointcloud = o3d.geometry.PointCloud()
        for camera_id in self.camera_ids:
            rgb = mj_renderer.render(render_mode="rgb_array", camera_id=camera_id)
            depth = mj_renderer.render(render_mode="depth_array", camera_id=camera_id)

            # convert the non-linear depth buffer to linear depth
            depth = convert_zbuffer_to_distance(depth, mj_model)

            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
                color=o3d.geometry.Image(np.ascontiguousarray(rgb)),
                # float32 is open3d's floating-point depth format; the cast guards
                # against an accidental upcast to float64 in the conversion.
                depth=o3d.geometry.Image(np.ascontiguousarray(depth, dtype=np.float32)),
                # depth is already linear, no need to rescale
                depth_scale=1.0,
                depth_trunc=self.depth_cutoff,
                convert_rgb_to_intensity=False,
            )

            camera = mj_data.cam(camera_id)
            camera_position = camera.xpos
            camera_rotation = camera.xmat.reshape((3, 3))
            extrinsic = compute_camera_extrinics(camera_rotation, camera_position)

            pc = o3d.geometry.PointCloud.create_from_rgbd_image(
                rgbd,
                self.camera_intrinsics[camera_id],
                extrinsic=extrinsic,
            )

            if self.debug_camera_origins:
                pc += self._color_marker(camera_position, channel=1)  # green

            # TODO: this concatenation step can be optimized
            combined_pointcloud += pc

        if self.debug_camera_origins:
            combined_pointcloud += self._color_marker(np.zeros(3), channel=2)  # blue

        return combined_pointcloud

    @staticmethod
    def _color_marker(position: np.ndarray, channel: int) -> PointCloud:
        """Return a one-point cloud at ``position`` colored pure ``channel`` (0=R, 1=G, 2=B)."""
        marker = o3d.geometry.PointCloud()
        points = np.asarray(position, dtype=np.float64).reshape(1, 3)
        colors = np.zeros_like(points)
        colors[:, channel] = 1.0
        marker.points = o3d.utility.Vector3dVector(points)
        marker.colors = o3d.utility.Vector3dVector(colors)
        return marker


def convert_zbuffer_to_distance(
    depth: np.ndarray,
    mj_model: mujoco.MjModel,
) -> np.ndarray:
    # Convert from [0 1] from opengl depth buffer to depth in m
    # References:
    # https://github.com/google-deepmind/mujoco/blob/2.3.7/python/mujoco/renderer.py#L171C15-L171C15
    # http://stackoverflow.com/a/6657284/1461210
    # https://www.khronos.org/opengl/wiki/Depth_Buffer_Precision
    extent = mj_model.stat.extent
    near_plane = mj_model.vis.map.znear * extent
    far_plane = mj_model.vis.map.zfar * extent
    true_depth = near_plane / (1 - depth * (1 - near_plane / far_plane))
    return true_depth


def compute_camera_extrinics(
    cam_rotation: np.ndarray,
    cam_position: np.ndarray,
) -> np.ndarray:
    """Return the 4x4 world-to-camera extrinsic matrix for a MuJoCo camera.

    MuJoCo (like OpenGL) cameras look along their local -z axis with +y up,
    whereas open3d's pinhole model (like OpenCV) looks along +z with +y down.
    The camera frame is therefore rotated by 180 degrees about its x axis
    before the pose is inverted. Without this flip, back-projected points
    would end up behind the camera.
    Reference: https://github.com/mattcorsaro1/mj_pc/blob/main/mj_point_clouds.py

    Args:
        cam_rotation: 3x3 camera orientation in world coordinates (MuJoCo
            ``xmat`` reshaped to 3x3).
        cam_position: Camera position in world coordinates (MuJoCo ``xpos``).

    Returns:
        A 4x4 homogeneous matrix mapping world points into the camera frame.

    Raises:
        ValueError: If the inputs do not have shapes (3, 3) and (3,).
    """
    cam_rotation = np.asarray(cam_rotation, dtype=np.float64)
    cam_position = np.asarray(cam_position, dtype=np.float64)
    if cam_rotation.shape != (3, 3):
        raise ValueError(f"cam_rotation must have shape (3, 3), got {cam_rotation.shape}")
    if cam_position.shape != (3,):
        raise ValueError(f"cam_position must have shape (3,), got {cam_position.shape}")

    # 180 degree rotation about x (quaternion (w, x, y, z) = (0, 1, 0, 0)):
    # flips the y and z camera axes from the MuJoCo to the open3d convention.
    flip_yz = np.diag([1.0, -1.0, -1.0])
    camera_to_world = cam_rotation @ flip_yz

    # The extrinsic matrix describes how the world is transformed relative to
    # the camera, i.e. the inverse of the camera pose [R | t]. Since R is
    # orthonormal, that inverse is [R^T | -R^T t].
    # https://ksimek.github.io/2012/08/22/extrinsic/
    world_to_camera = camera_to_world.T
    extrinsic_matrix = np.identity(4)
    extrinsic_matrix[:3, :3] = world_to_camera
    extrinsic_matrix[:3, 3] = -world_to_camera @ cam_position
    return extrinsic_matrix
