import jax.numpy as jnp
import jax
from typing import List, Optional
import numpy.typing as npt
import flax
import numpy as np

import util.latent_obj_util as loutil

@flax.struct.dataclass
class ImgFeatures:
    """Camera parameters bundled with optional image feature arrays.

    The class is a ``flax.struct.dataclass``, so ``jax.tree_util`` can map
    over its fields; ``__getitem__`` relies on that to index every array
    field at once.
    """
    intrinsic: jnp.ndarray  # (... 6) [px, py, fx, fy, cx, cy]
    cam_posquat: jnp.ndarray  # (... 7) pos-3 / quat-4
    img_feat: Optional[jnp.ndarray] = None  # (..., pi, pj, nf)
    dino_feat: Optional[jnp.ndarray] = None  # (..., pi, pj, nf)
    img_feat_patch: Optional[jnp.ndarray] = None
    img_state: Optional[jnp.ndarray] = None  # (..., pi, pj, nf)
    spatial_PE: Optional[jnp.ndarray] = None  # (..., pi, pj, nf)
    rgb: Optional[jnp.ndarray] = None

    def __getitem__(self, idx: jnp.ndarray) -> "ImgFeatures":
        """Convenient indexing for dataclass.

        Applies ``x[idx]`` to every array field and returns a new
        ``ImgFeatures`` holding the indexed arrays.
        """
        # Use jax.tree_util.tree_map (as the other dataclasses in this file
        # do): the top-level alias jax.tree_map is deprecated and no longer
        # available in recent JAX releases.
        return jax.tree_util.tree_map(lambda x: x[idx], self)


@flax.struct.dataclass
class EnvInfo:
    """Environment record: object poses, scale, uid list and mesh name."""

    # Object poses, shaped [..., #obj, 7].
    EnvInfo_obj_posquats: npt.NDArray
    EnvInfo_scale: float
    EnvInfo_uid_list: List[int]
    # May be None when no mesh name is attached.
    EnvInfo_mesh_name: Optional[str]

@flax.struct.dataclass
class SceneData:
    """Flat scene record: cameras, object/env/robot info and optional images.

    Fields prefixed ``ObjInfo_``, ``EnvInfo_`` and ``RobotInfo_`` are stored
    flat on this class; the ``EnvInfo`` property regroups the ``EnvInfo_*``
    fields into a standalone ``EnvInfo`` instance.
    """

    @flax.struct.dataclass
    class NVRenInfo:
        """Object poses, scales and mesh name kept under ``nvren_info``."""

        # Object poses, shaped [..., #obj, 7].
        obj_posquats: npt.NDArray
        scales: npt.NDArray
        mesh_name: str

    # Camera poses, shaped [..., #cam, 7].
    cam_posquats: npt.NDArray
    # Camera intrinsics, shaped [..., #cam, 6]. (W, H, Fx, Fy, Cx, Cy)
    cam_intrinsics: npt.NDArray

    # Object info; poses are shaped [..., #obj, 7].
    ObjInfo_obj_posquats: npt.NDArray
    ObjInfo_scale: float
    ObjInfo_uid_list: List[int]
    ObjInfo_mesh_name: Optional[str]

    # Environment info; poses are shaped [..., #obj, 7].
    EnvInfo_obj_posquats: npt.NDArray
    EnvInfo_scale: float
    EnvInfo_uid_list: List[int]
    EnvInfo_mesh_name: Optional[str]

    RobotInfo_posquat: npt.NDArray
    RobotInfo_q: npt.NDArray

    # Optional per-camera images.
    rgbs: Optional[npt.NDArray]    # [..., #cam, H, W, 3]
    depths: Optional[npt.NDArray]  # [..., #cam, H, W]
    seg: Optional[npt.NDArray]     # [..., #cam, H, W]
    # table_params: npt.NDArray               # [..., 3]
    # robot_params: npt.NDArray
    nvren_info: Optional[NVRenInfo]

    @property
    def EnvInfo(self):
        """Return the ``EnvInfo_*`` fields packed into an ``EnvInfo`` object."""
        # Inside this method the name ``EnvInfo`` resolves to the module-level
        # class, not to this property (class scope is not a closure scope).
        return EnvInfo(
            EnvInfo_obj_posquats=self.EnvInfo_obj_posquats,
            EnvInfo_scale=self.EnvInfo_scale,
            EnvInfo_uid_list=self.EnvInfo_uid_list,
            EnvInfo_mesh_name=self.EnvInfo_mesh_name,
        )


@flax.struct.dataclass
class ColDataset:
    """Collision dataset record holding per-sample object ids, poses and labels."""

    obj_idx: jax.Array
    obj_scale: jax.Array
    obj_pos: jax.Array
    obj_quat: jax.Array
    col_gt: jax.Array
    distance: jax.Array
    min_direction: jax.Array
    fps_col_labels: Optional[jax.Array] = None

    def __getitem__(self, idx: jnp.ndarray) -> "ColDataset":
        """Index every array field with ``idx`` and return the result."""
        def _take(leaf):
            return leaf[idx]

        return jax.tree_util.tree_map(_take, self)

    def __len__(self):
        """Return the size of the leading axis of ``obj_idx``."""
        leading_dim = self.obj_idx.shape[0]
        return leading_dim

    def make_latent_obj(input: "ColDataset", latent_obj_list:loutil.LatentObjects, rot_configs, latent_obj=None)-> loutil.LatentObjects:
        """Build scaled and posed latent objects for the samples in ``input``.

        When ``latent_obj`` is not given, it is looked up from
        ``latent_obj_list`` by ``input.obj_idx`` and passed through
        ``init_pos_zero()``. The object is then scaled by ``input.obj_scale``
        and posed with ``input.obj_pos`` / ``input.obj_quat``.
        """
        if latent_obj is None:
            latent_obj = latent_obj_list[input.obj_idx].init_pos_zero()

        scaled_obj = latent_obj.apply_scale(input.obj_scale)

        # Objects with a relative point cloud are posed through the pcd path;
        # all others go through the z path, which also needs rot_configs.
        if scaled_obj.rel_pcd is not None:
            return scaled_obj.apply_pq_pcd(input.obj_pos, input.obj_quat)
        return scaled_obj.apply_pq_z(input.obj_pos, input.obj_quat, rot_configs)

    def visualize_fps_col_labels(self, latent_obj_list:loutil.LatentObjects, rot_configs):
        """Show, sample by sample, the two objects' fps points coloured by label.

        Points whose ``fps_col_labels`` entry is truthy get the highlight
        colour; the remaining points keep the object's base colour (red for
        the first object, green for the second). One Open3D window is opened
        per sample.
        """
        import open3d as o3d

        def _labelled_cloud(points, base_rgb, hit_rgb, hit_mask):
            # Build the cloud in its base colour, then overwrite per-point
            # colours according to the label mask.
            cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points)).paint_uniform_color(base_rgb)
            per_point = np.where(hit_mask[..., None], np.array(hit_rgb), np.array(base_rgb))
            cloud.colors = o3d.utility.Vector3dVector(per_point)
            return cloud

        posed_objs = ColDataset.make_latent_obj(self, latent_obj_list, rot_configs)

        for sample_id in range(posed_objs.shape[0]):
            obj_pair = posed_objs[sample_id]
            label_pair = self.fps_col_labels[sample_id]
            first_cloud = _labelled_cloud(obj_pair[0].fps_tf, [1,0,0], [0.5,0,1], label_pair[0])
            second_cloud = _labelled_cloud(obj_pair[1].fps_tf, [0,1,0], [0,0.5,1], label_pair[1])
            o3d.visualization.draw_geometries([first_cloud, second_cloud])


@flax.struct.dataclass
class OccDataset:
    """Dataset record of query points with distance, normal and ray data.

    ``surface_points`` is optional and defaults to ``None``.
    """
    query_points: jax.Array
    signed_distance: jax.Array
    normal_vector: jax.Array
    rays: jax.Array
    t_hit: jax.Array
    surface_points: Optional[jax.Array] = None

    def __getitem__(self, idx: jnp.ndarray) -> "OccDataset":
        """Convenient indexing for dataclass.

        Applies ``x[idx]`` to every array field and returns a new
        ``OccDataset`` holding the indexed arrays.
        """
        return jax.tree_util.tree_map(lambda x: x[idx], self)

    def __len__(self):
        """Return the size of the leading axis of ``query_points``."""
        # The previous implementation read ``self.obj_idx``, which is not a
        # field of this class, so ``len()`` always raised AttributeError.
        # NOTE(review): assumes all fields share the leading axis, as
        # ``__getitem__`` indexes them with the same ``idx`` - confirm.
        return self.query_points.shape[0]
    

@flax.struct.dataclass
class LossArgs:
    """Arguments bundle for loss evaluation.

    Only ``fixed_oriCORNs`` is required; every other field defaults to
    ``None``.
    """

    fixed_oriCORNs: loutil.LatentObjects
    moving_oriCORNs: Optional[loutil.LatentObjects] = None
    ee_to_obj_pq: Optional[jnp.ndarray] = None
    plane_params: Optional[jnp.ndarray] = None
    fixed_moving_idx_pair: Optional[jnp.ndarray] = None
    moving_spheres: Optional[List[jnp.ndarray]] = None
    mesh_ids: Optional[jnp.ndarray] = None