import gym
import torch


class TensorWrapper(gym.Wrapper):
    """Gym wrapper that converts observations and rewards to float32 tensors.

    ``dones`` and ``infos`` returned by :meth:`step` are passed through
    unchanged.

    NOTE(review): this assumes the older gym API in which ``reset`` returns
    only the observation and ``step`` returns a 4-tuple
    ``(obs, rews, dones, infos)``. Newer gym releases (0.26+), or calling
    ``reset(return_info=True)`` on some earlier releases, return
    ``(obs, info)`` from ``reset`` and a 5-tuple from ``step``, which this
    wrapper does not handle -- confirm against the installed gym version.
    """

    @staticmethod
    def _to_float_tensor(value) -> torch.Tensor:
        """Return a new, detached float32 tensor holding a copy of *value*.

        ``torch.tensor`` always copies its input, but it warns when the input
        is already a tensor. Existing tensors are therefore copied with
        ``Tensor.to(..., copy=True)``, which gives the same result (detached,
        float32, on the source tensor's device) without the warning. All other
        inputs (arrays, Python scalars, nested sequences) go through
        ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().to(dtype=torch.float32, copy=True)
        # A freshly built tensor has requires_grad=False, so no detach() needed.
        return torch.tensor(value, dtype=torch.float32)

    def reset(self, **kwargs):
        """Reset the wrapped env and return its observation as a float32 tensor.

        Keyword arguments are forwarded unchanged to ``self.env.reset``.
        """
        return self._to_float_tensor(self.env.reset(**kwargs))

    def step(self, act) -> tuple[torch.Tensor, torch.Tensor, bool, dict]:
        """Step the wrapped env with *act*.

        Returns ``(obs, rews, dones, infos)``, where ``obs`` and ``rews`` are
        converted to float32 tensors and ``dones``/``infos`` are returned
        exactly as the wrapped env produced them.

        NOTE(review): the plural names suggest a vectorized env, in which case
        ``dones``/``infos`` may be arrays/lists rather than the ``bool``/``dict``
        in the annotation -- confirm against the wrapped env.
        """
        obs, rews, dones, infos = self.env.step(act)
        return self._to_float_tensor(obs), self._to_float_tensor(rews), dones, infos
