from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from gr00t.data.schema import DatasetMetadata
class ModalityTransform(BaseModel, ABC):
    """
    Abstract class for transforming data modalities, e.g. video frame augmentation or action normalization.

    Subclasses implement `apply()`. Calling an instance (`transform(data)`) simply
    delegates to `apply()`. Dataset metadata can be attached with `set_metadata()`
    and read back through the `dataset_metadata` property.
    """

    apply_to: list[str] = Field(..., description="The keys to apply the transform to.")
    training: bool = Field(
        default=True, description="Whether to apply the transform in training mode."
    )
    # Private (non-field) storage behind the `dataset_metadata` property;
    # stays None until a value is assigned.
    _dataset_metadata: DatasetMetadata | None = PrivateAttr(default=None)
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @property
    def dataset_metadata(self) -> DatasetMetadata:
        """Return the dataset metadata previously stored on this transform.

        Raises:
            AssertionError: If no metadata has been set yet.
        """
        # Raise explicitly instead of using an `assert` statement: asserts are
        # stripped under `python -O`, which would let None leak out of a
        # property annotated as non-optional. The exception type and message
        # are kept so existing callers see the same error.
        if self._dataset_metadata is None:
            raise AssertionError(
                "Dataset metadata is not set. Please call set_metadata() before calling apply()."
            )
        return self._dataset_metadata

    @dataset_metadata.setter
    def dataset_metadata(self, value: DatasetMetadata) -> None:
        self._dataset_metadata = value

    def set_metadata(self, dataset_metadata: DatasetMetadata) -> None:
        """
        Set the dataset metadata. This is useful for transforms that need to know the dataset metadata, e.g. to normalize actions.
        Subclasses can override this method if they need to do something more complex.

        Args:
            dataset_metadata (DatasetMetadata): The metadata to store on this transform.
        """
        self.dataset_metadata = dataset_metadata

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation by delegating to `apply()` and return the processed data.

        Args:
            data (dict[str, Any]): The data to transform.
                example: data = {
                    "video.image_side_0": np.ndarray,
                    "action.eef_position": np.ndarray,
                    ...
                }
        Returns:
            dict[str, Any]: The transformed data.
                example: transformed_data = {
                    "video.image_side_0": np.ndarray,
                    "action.eef_position": torch.Tensor,  # Normalized and converted to tensor
                    ...
                }
        """
        return self.apply(data)

    @abstractmethod
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Apply the transformation to the data under the keys selected by `apply_to` and return the processed data.

        NOTE(review): `apply_to` is declared as a list of keys; how entries are
        matched against `data` is decided by each subclass.
        """

    def train(self) -> None:
        """Switch the transform to training mode (`training = True`)."""
        self.training = True

    def eval(self) -> None:
        """Switch the transform to evaluation mode (`training = False`)."""
        self.training = False
class InvertibleModalityTransform(ModalityTransform):
    """A modality transform that can also be reversed through `unapply()`."""
    @abstractmethod
    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Reverse the transformation on the data under the keys selected by `apply_to` and return the processed data.

        Args:
            data (dict[str, Any]): The (previously transformed) data to revert.
        Returns:
            dict[str, Any]: The data with this transform undone.

        NOTE(review): `apply_to` is declared as a list of keys; how entries are
        matched against `data` is decided by each subclass.
        """
class ComposedModalityTransform(ModalityTransform):
    """Compose multiple modality transforms.

    `apply()` runs the member transforms in list order; `unapply()` walks them
    in reverse order and reverts those that support it. Metadata and the
    train/eval mode are forwarded to every member.
    """

    transforms: list[ModalityTransform] = Field(..., description="The transforms to compose.")
    apply_to: list[str] = Field(
        default_factory=list, description="Will be ignored for composed transforms."
    )
    training: bool = Field(
        default=True, description="Whether to apply the transform in training mode."
    )
    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)

    def set_metadata(self, dataset_metadata: DatasetMetadata) -> None:
        """Store the metadata on this transform and forward it to every member.

        Args:
            dataset_metadata (DatasetMetadata): The metadata to distribute.
        """
        # Keep a copy on the composite as well, so that reading
        # `dataset_metadata` on it works after this call.
        super().set_metadata(dataset_metadata)
        for transform in self.transforms:
            transform.set_metadata(dataset_metadata)

    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Run every member transform on `data`, in list order.

        Args:
            data (dict[str, Any]): The data to transform.
        Returns:
            dict[str, Any]: The output of the last member (or `data` unchanged
                when there are no members).
        Raises:
            ValueError: If a member raises; the message carries the member's
                index and the original exception is chained as the cause.
        """
        for i, transform in enumerate(self.transforms):
            try:
                data = transform(data)
            except Exception as e:
                raise ValueError(f"Error applying transform {i} to data: {e}") from e
        return data

    def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
        """Revert the member transforms in reverse list order.

        Members that are neither invertible nor themselves a composition are
        skipped.

        Args:
            data (dict[str, Any]): The (previously transformed) data to revert.
        Returns:
            dict[str, Any]: The data after all revertible members were undone.
        Raises:
            ValueError: If a member's `unapply()` raises; the message carries
                the member's index in `transforms` and the original exception
                is chained as the cause.
        """
        last_index = len(self.transforms) - 1
        for i, transform in enumerate(reversed(self.transforms)):
            # A nested composition is not an InvertibleModalityTransform but
            # does provide unapply(); include it so its invertible members
            # are reverted instead of being silently skipped.
            if not isinstance(
                transform, (InvertibleModalityTransform, ComposedModalityTransform)
            ):
                continue
            try:
                data = transform.unapply(data)
            except Exception as e:
                # Report the position in `transforms`, not in the reversed walk.
                step = last_index - i
                raise ValueError(f"Error unapplying transform {step} to data: {e}") from e
        return data

    def train(self) -> None:
        """Switch this transform and every member to training mode."""
        # Update the composite's own flag too, so `training` does not go stale.
        super().train()
        for transform in self.transforms:
            transform.train()

    def eval(self) -> None:
        """Switch this transform and every member to evaluation mode."""
        # Update the composite's own flag too, so `training` does not go stale.
        super().eval()
        for transform in self.transforms:
            transform.eval()
