from enum import Enum
from typing import Optional
from numpydantic import NDArray
from pydantic import BaseModel, Field, field_serializer
from .embodiment_tags import EmbodimentTag
class RotationType(Enum):
    """Enumerates the supported rotation representations.

    Each member's value is the lowercase string identifier of the
    representation, so a member can be looked up from that string with
    ``RotationType("quaternion")``.
    """

    # Non-Euler representations.
    AXIS_ANGLE = "axis_angle"
    QUATERNION = "quaternion"
    ROTATION_6D = "rotation_6d"
    MATRIX = "matrix"

    # Euler angles: one member for each of the six orderings of the
    # letters r, p and y.
    EULER_ANGLES_RPY = "euler_angles_rpy"
    EULER_ANGLES_RYP = "euler_angles_ryp"
    EULER_ANGLES_PRY = "euler_angles_pry"
    EULER_ANGLES_PYR = "euler_angles_pyr"
    EULER_ANGLES_YRP = "euler_angles_yrp"
    EULER_ANGLES_YPR = "euler_angles_ypr"
class LeRobotModalityField(BaseModel):
    """Base metadata record for a single LeRobot modality entry.

    Attributes:
        original_key: Key of the modality in the LeRobot dataset. Optional;
            defaults to ``None``.
    """

    original_key: Optional[str] = Field(
        None, description="The original key of the modality in the LeRobot dataset"
    )
class LeRobotStateActionMetadata(LeRobotModalityField):
    """Metadata for one named slice of a LeRobot state or action vector.

    Extends ``LeRobotModalityField`` with the position of the slice inside
    the concatenated state/action vector and optional details about how the
    values are represented.

    Attributes:
        start: Start index of the modality in the concatenated vector (required).
        end: End index of the modality in the concatenated vector (required).
        rotation_type: Rotation representation, or ``None`` when not given.
        absolute: Whether the modality is absolute. Defaults to ``True``.
        dtype: Name of the data type. Defaults to ``"float64"``.
        range: Optional pair of floats giving the range of the modality.
        original_key: Key of the modality in the LeRobot dataset.
    """

    # NOTE(review): whether `end` is inclusive or exclusive is not visible
    # here — confirm against the code that slices the vector.
    start: int = Field(
        description="The start index of the modality in the concatenated state/action vector"
    )
    end: int = Field(
        description="The end index of the modality in the concatenated state/action vector"
    )
    rotation_type: Optional[RotationType] = Field(
        None,
        description="The type of rotation for the modality",
    )
    absolute: bool = Field(True, description="Whether the modality is absolute")
    dtype: str = Field(
        "float64", description="The data type of the modality. Defaults to float64."
    )
    range: Optional[tuple[float, float]] = Field(
        None, description="The range of the modality, if applicable. Defaults to None."
    )
    # Re-declares the field inherited from LeRobotModalityField with the same
    # default; only the description text differs (trailing period).
    original_key: Optional[str] = Field(
        None, description="The original key of the modality in the LeRobot dataset."
    )
class LeRobotStateMetadata(LeRobotStateActionMetadata):
    """State-vector slice metadata.

    Identical to ``LeRobotStateActionMetadata`` except that ``original_key``
    defaults to ``"observation.state"``.
    """

    # "observation.state" is the LeRobot convention for states.
    original_key: Optional[str] = Field(
        "observation.state",
        description="The original key of the state modality in the LeRobot dataset",
    )
class LeRobotActionMetadata(LeRobotStateActionMetadata):
    """Action-vector slice metadata.

    Identical to ``LeRobotStateActionMetadata`` except that ``original_key``
    defaults to ``"action"``.
    """

    # "action" is the LeRobot convention for actions.
    original_key: Optional[str] = Field(
        "action",
        description="The original key of the action modality in the LeRobot dataset",
    )
class LeRobotModalityMetadata(BaseModel):
    """Per-modality metadata tables of a LeRobot dataset.

    Attributes:
        state: Metadata for each named split of the state vector (required).
        action: Metadata for each named split of the action vector (required).
        video: Metadata for each video modality, keyed by its new name (required).
        annotation: Metadata for each annotation modality, keyed by its new
            name, or ``None`` (the default) when the dataset has no annotations.
    """

    state: dict[str, LeRobotStateMetadata] = Field(
        description="The metadata for the state modality. The keys are the names of each split of the state vector.",
    )
    action: dict[str, LeRobotActionMetadata] = Field(
        description="The metadata for the action modality. The keys are the names of each split of the action vector.",
    )
    video: dict[str, LeRobotModalityField] = Field(
        description="The metadata for the video modality. The keys are the new names of each video modality.",
    )
    annotation: Optional[dict[str, LeRobotModalityField]] = Field(
        None,
        description="The metadata for the annotation modality. The keys are the new names of each annotation modality.",
    )

    def get_key_meta(self, key: str) -> LeRobotModalityField:
        """Get the metadata for a key in the LeRobot modality metadata.

        The key has the form ``"<modality>.<subkey>"``. Only the first dot
        separates the modality from the subkey, so the subkey may itself
        contain dots.

        Args:
            key (str): The key to get the metadata for.

        Returns:
            LeRobotModalityField: The metadata for the key.

        Raises:
            ValueError: If the modality is not one of ``state``, ``action``,
                ``video`` or ``annotation``, or if the subkey is not present
                in that modality's metadata.
            AssertionError: If annotation metadata is requested but the
                dataset has no annotations.

        Example:
            lerobot_modality_meta = LeRobotModalityMetadata.model_validate(U.load_json(modality_meta_path))
            lerobot_modality_meta.get_key_meta("state.joint_shoulder_y")
            lerobot_modality_meta.get_key_meta("video.main_camera")
            lerobot_modality_meta.get_key_meta("annotation.human.action.task_description")
        """
        modality, _, subkey = key.partition(".")
        tables = {
            "state": self.state,
            "action": self.action,
            "video": self.video,
            "annotation": self.annotation,
        }
        if modality not in tables:
            raise ValueError(f"Key: {key}, unexpected modality: {modality}")

        table = tables[modality]
        if modality == "annotation" and table is None:
            # Raised explicitly rather than via `assert` so the check still
            # runs under `python -O`; the exception type is kept as
            # AssertionError so existing callers observe the same error.
            raise AssertionError(
                "Trying to get annotation metadata for a dataset with no annotations"
            )
        if subkey not in table:
            raise ValueError(
                f"Key: {key}, {modality} key {subkey} not found in metadata, "
                f"available {modality} keys: {table.keys()}"
            )
        return table[subkey]
class DatasetStatisticalValues(BaseModel):
    """Summary statistics, each stored as an array.

    All six fields are required. When the model is serialized in JSON mode,
    every field is converted to a plain Python list.
    """

    max: NDArray = Field(description="Maximum values")
    min: NDArray = Field(description="Minimum values")
    mean: NDArray = Field(description="Mean values")
    std: NDArray = Field(description="Standard deviation")
    q01: NDArray = Field(description="1st percentile values")
    q99: NDArray = Field(description="99th percentile values")

    @field_serializer("*", when_used="json")
    def serialize_ndarray(self, v: NDArray) -> list[float]:
        """Convert an array field to a list; applied to all fields in JSON mode."""
        # NOTE(review): presumably the arrays are 1-D; `tolist()` returns
        # nested lists for higher-dimensional input — confirm.
        as_list = v.tolist()
        return as_list
class DatasetStatistics(BaseModel):
    """Statistics of a dataset, grouped by modality.

    Attributes:
        state: Maps a name to the statistics of that part of the state.
        action: Maps a name to the statistics of that part of the action.
    """

    state: dict[str, DatasetStatisticalValues] = Field(
        description="Statistics of the state"
    )
    action: dict[str, DatasetStatisticalValues] = Field(
        description="Statistics of the action"
    )
class VideoMetadata(BaseModel):
    """Describes a single video stream.

    Attributes:
        resolution: Pair of integers giving the video resolution.
        channels: Number of channels; must be greater than 0.
        fps: Frames per second; must be greater than 0.
    """

    # NOTE(review): the order of the two resolution components
    # (width/height vs. height/width) is not visible here — confirm.
    resolution: tuple[int, int] = Field(description="Resolution of the video")
    channels: int = Field(
        description="Number of channels in the video",
        gt=0,
    )
    fps: float = Field(
        description="Frames per second",
        gt=0,
    )
class StateActionMetadata(BaseModel):
    """Describes a single state or action entry.

    Attributes:
        absolute: Whether the state or action is absolute (required).
        rotation_type: Rotation representation, or ``None`` when there is none.
        shape: Shape of the state or action, as a tuple of ints (required).
        continuous: Whether the state or action is continuous (required).
    """

    absolute: bool = Field(
        description="Whether the state or action is absolute"
    )
    rotation_type: Optional[RotationType] = Field(
        None,
        description="Type of rotation, if any",
    )
    shape: tuple[int, ...] = Field(
        description="Shape of the state or action"
    )
    continuous: bool = Field(
        description="Whether the state or action is continuous"
    )
class DatasetModalities(BaseModel):
    """Metadata of every modality in a dataset, grouped by modality type.

    Attributes:
        video: Maps a name to the metadata of that video.
        state: Maps a name to the metadata of that state entry.
        action: Maps a name to the metadata of that action entry.
    """

    video: dict[str, VideoMetadata] = Field(
        description="Metadata of the video"
    )
    state: dict[str, StateActionMetadata] = Field(
        description="Metadata of the state"
    )
    action: dict[str, StateActionMetadata] = Field(
        description="Metadata of the action"
    )
class DatasetMetadata(BaseModel):
    """Metadata of the trainable dataset.

    Attributes:
        statistics: Statistics of the dataset.
        modalities: Metadata of the modalities.
        embodiment_tag: Embodiment tag of the dataset.
    """

    # NOTE(review): the previous docstring carried a changelog entry about
    # "RawCommitHashMetadataMetadata_V1_2"; nothing by that name is used in
    # this class — confirm whether that entry is still relevant.
    statistics: DatasetStatistics = Field(
        description="Statistics of the dataset"
    )
    modalities: DatasetModalities = Field(
        description="Metadata of the modalities"
    )
    embodiment_tag: EmbodimentTag = Field(
        description="Embodiment tag of the dataset"
    )
