import numpy as np
import torch as t
from typing import Any, Dict, Tuple, Union
from pyquaternion import Quaternion

from rise import *  # noqa: F401,F403 (project convention)
from sim.env.env_for_latent_conditioned_moe import RiseEnvForLatentConditionedMoE


class RiseEnvForLatentConditionedMoEHeightReward(RiseEnvForLatentConditionedMoE):
    """
    Variant of RiseEnvForLatentConditionedMoE whose observation processor
    additionally records, in the reward state, the pose of the rigid body
    with segment_bid == 1 and that body's up vector in the world frame.
    """

    @staticmethod
    def _segment1_pose(
        structure_record: RS_StructureRecord,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return (position, orientation) for the rigid body with segment_bid == 1.

        Position is the rigid body COM as reported by
        ``structure_record.rigid_body_com()``; orientation is the matching
        row of ``structure_record.rigid_body_orientation()``. The caller in
        this class unpacks the orientation as a quaternion [x, y, z, w]
        (NOTE(review): layout and world-frame convention are assumed from
        the record API -- confirm against the simulator).

        If several rigid bodies match, the first one is used. If none
        matches, a zero position and the identity quaternion [0, 0, 0, 1]
        are returned so the reward pipeline keeps running.

        Both arrays are returned as float32.
        """
        seg_mask = structure_record.rigid_body_segment_bid() == 1
        if not np.any(seg_mask):
            # Fallback: zeros (keeps reward pipeline robust if the segment is absent)
            return np.zeros(3, dtype=np.float32), np.array(
                [0, 0, 0, 1], dtype=np.float32
            )

        # copy=False avoids an extra copy when the data is already float32.
        pos = structure_record.rigid_body_com()[seg_mask][0].astype(
            np.float32, copy=False
        )
        quat = structure_record.rigid_body_orientation()[seg_mask][0].astype(
            np.float32, copy=False
        )
        return pos, quat

    def get_observation_processor(
        self,
        build_voxel_observations: bool = True,
        build_kinematic_graph: bool = True,
        build_reward_state: bool = True,
        build_velocity_observations: bool = True,
    ):
        """
        Build a per-structure observation processing function.

        Each ``build_*`` flag toggles one independent group of outputs:

        - ``build_voxel_observations``: ``relative_voxel_positions``,
          ``voxel_features`` and ``com`` in the observation.
        - ``build_velocity_observations``: ``velocity`` in the observation;
          also updates the ``prev_com`` mapping passed to the processor.
        - ``build_kinematic_graph``: node/edge entries in the observation.
        - ``build_reward_state``: voxel positions, COM, second segment pose
          and ``body_up_vector`` in the reward state.

        Returns:
            A function ``process_observation(structure_controller, prev_com,
            robot_latents)`` returning
            ``(structure_name, observation, reward_state)``.
        """

        def process_observation(
            structure_controller,
            prev_com: Dict[str, Any],
            robot_latents: Dict[str, Any],
        ):
            """
            Build the observation and reward state for one structure.

            ``prev_com`` is mutated in place when velocity observations are
            enabled: the current COM tensor is stored under the structure
            name for use on the next call.

            Raises:
                ValueError: if ``robot_latents`` has no entry for the
                    structure.
            """
            structure_name = structure_controller.name()

            prev_com_structure = prev_com.get(structure_name)
            robot_latent_structure = robot_latents.get(structure_name)
            if robot_latent_structure is None:
                raise ValueError(f"Missing latent for robot {structure_name}")

            structure_record = structure_controller.structure_record()

            observation: Dict[str, Any] = {}
            reward_state: Dict[str, Any] = {}

            # Mean over the voxel axis; shape [3] given voxel positions of
            # shape [voxel_num, 3].
            com = np.mean(structure_record.voxel_position(), axis=0)

            # COM as a [1, 3] tensor. Built once here so the velocity branch
            # does not depend on the voxel branch having populated
            # observation["com"] (previously a KeyError when voxel
            # observations were disabled but velocity was enabled).
            com_tensor = None
            if build_voxel_observations or build_velocity_observations:
                com_tensor = t.from_numpy(com).unsqueeze(0)

            # Attach per-robot latent vector for gating
            # (expected shape [1, latent_dim] -- TODO confirm with caller).
            observation["robot_latent"] = robot_latent_structure

            if build_voxel_observations:
                relative_voxel_positions, pressures = (
                    RiseEnvForLatentConditionedMoE.build_voxel_observations(
                        structure_record, com
                    )
                )
                observation["relative_voxel_positions"] = relative_voxel_positions
                observation["voxel_features"] = pressures
                observation["com"] = com_tensor

            if build_velocity_observations:
                # Velocity is the COM displacement since the previous call
                # for this structure; zero on the first call.
                if prev_com_structure is not None:
                    velocity = com_tensor - prev_com_structure
                else:
                    velocity = t.zeros_like(com_tensor)
                observation["velocity"] = velocity
                # Clone so the stored value does not alias the numpy-backed
                # tensor handed out in the observation.
                prev_com[structure_name] = com_tensor.clone()

            if build_kinematic_graph:
                (
                    relative_node_positions,
                    node_features,
                    edges,
                    edge_features,
                ) = RiseEnvForLatentConditionedMoE.build_kinematic_graph(
                    structure_record, com
                )
                observation["relative_node_positions"] = relative_node_positions
                observation["node_features"] = node_features
                observation["edges"] = edges
                observation["edge_features"] = edge_features

            if build_reward_state:
                # numpy, shape [voxel_num, 3]
                reward_state["voxel_positions"] = structure_record.voxel_position()
                # numpy, shape [3]
                reward_state["com"] = com

                # Add second segment pose (segment bid == 1)
                seg_pos, seg_quat = self._segment1_pose(structure_record)
                reward_state["second_segment_position"] = seg_pos
                reward_state["second_segment_orientation"] = seg_quat

                # seg_quat is unpacked as [x, y, z, w].
                seg_raw_quat = Quaternion(
                    x=seg_quat[0],
                    y=seg_quat[1],
                    z=seg_quat[2],
                    w=seg_quat[3],
                )
                seg_rotation_matrix = seg_raw_quat.rotation_matrix
                # Rotate the world +Z axis by the segment orientation to get
                # the direction the segment's local +Z points in.
                world_up_vector = np.array([0, 0, 1], dtype=np.float32)
                body_up_vector = seg_rotation_matrix @ world_up_vector
                reward_state["body_up_vector"] = body_up_vector

            return (
                structure_name,
                observation,
                reward_state,
            )

        return process_observation
