import copy
from dataclasses import dataclass
from typing import Any, Dict

import torch


@dataclass
class EnvState:
    """Snapshot of environment state that round-trips through plain Python data.

    ``to_dict`` converts every tensor field to nested Python lists, so the
    result contains only lists, bools and whatever ``predicates`` holds.
    ``from_dict`` rebuilds tensors from those lists.

    Note: ``tolist()`` discards dtype and device. On reload, ``torch.as_tensor``
    infers the dtype from the Python values (floats -> default float dtype,
    ints -> int64) and places tensors on ``device`` (default device if None).
    """

    # NOTE(review): the semantics of the tensor fields are not visible here.
    # Names suggest rigid-body state (rb), joint/DOF state (dof), the last
    # grasp pose, and an end-effector -> tool transform. Confirm against callers.
    rb: torch.Tensor
    dof: torch.Tensor
    gripper_closed: bool
    last_grasp: torch.Tensor
    T_ee_to_tool: torch.Tensor
    # Arbitrary key -> value mapping. Deep-copied on (de)serialization so the
    # dict and the state never share mutable data.
    predicates: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-Python representation of this state.

        Tensors become nested lists via ``Tensor.tolist()``. ``predicates`` is
        deep-copied so that mutating the returned dict cannot alter this state.
        """
        return {
            "rb": self.rb.tolist(),
            "dof": self.dof.tolist(),
            "gripper_closed": self.gripper_closed,
            "last_grasp": self.last_grasp.tolist(),
            "T_ee_to_tool": self.T_ee_to_tool.tolist(),
            # Deep copy keeps this field as independent as the tolist() fields.
            "predicates": copy.deepcopy(self.predicates),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], *, device=None) -> "EnvState":
        """Build a state from a mapping shaped like the output of ``to_dict``.

        Args:
            data: Mapping with keys "rb", "dof", "gripper_closed",
                "last_grasp", "T_ee_to_tool" and "predicates".
            device: Optional torch device for every tensor field. ``None``
                keeps ``torch.as_tensor``'s default placement.

        Returns:
            An instance of ``cls``, so subclasses get their own type back.

        Raises:
            KeyError: if any required key is missing from ``data``.
        """
        return cls(
            rb=torch.as_tensor(data["rb"], device=device),
            dof=torch.as_tensor(data["dof"], device=device),
            gripper_closed=data["gripper_closed"],
            last_grasp=torch.as_tensor(data["last_grasp"], device=device),
            T_ee_to_tool=torch.as_tensor(data["T_ee_to_tool"], device=device),
            # Deep copy so later edits to `data` do not leak into the state.
            predicates=copy.deepcopy(data["predicates"]),
        )