import embodied
import numpy as np


class Dummy(embodied.Env):
    """Placeholder environment that emits constant, zero-filled observations.

    The step following a reset returns reward ``0.0`` with ``is_first`` set.
    Every other step returns reward ``1.0``. Once ``length`` such steps have
    been taken the episode ends, and that final observation has both
    ``is_last`` and ``is_terminal`` set. The action value never influences
    the observation.

    Args:
        task: ``"cont"`` to expose a float32 action of shape ``(6,)``, or
            ``"disc"`` to expose a scalar int32 action.
        size: Leading dimensions of the ``image`` observation; the image has
            shape ``size + (3,)``. Any sequence is accepted and stored as a
            tuple.
        length: Number of non-reset steps after which the episode ends.

    Raises:
        ValueError: If ``task`` is neither ``"cont"`` nor ``"disc"``.
    """

    def __init__(self, task: str, size=(64, 64), length: int = 100):
        # Explicit raise instead of `assert`: assertions are removed under
        # `python -O`, which would let an invalid task fall through to the
        # discrete branch of `act_space` unnoticed.
        if task not in ("cont", "disc"):
            raise ValueError(
                f"task must be 'cont' or 'disc', got {task!r}"
            )
        self._task = task
        # Stored as a tuple so that `self._size + (3,)` also works when the
        # caller passes a list.
        self._size = tuple(size)
        self._length = length
        self._step = 0
        self._done = False

    @property
    def obs_space(self) -> dict:
        """Return the mapping from observation key to its ``embodied.Space``."""
        return {
            "image": embodied.Space(np.uint8, self._size + (3,)),
            "vector": embodied.Space(np.float32, (7,)),
            "step": embodied.Space(np.int32, (), 0, self._length),
            "reward": embodied.Space(np.float32),
            "is_first": embodied.Space(bool),
            "is_last": embodied.Space(bool),
            "is_terminal": embodied.Space(bool),
        }

    @property
    def act_space(self) -> dict:
        """Return the mapping from action key to its ``embodied.Space``.

        The ``"action"`` entry depends on the task chosen at construction;
        the boolean ``"reset"`` entry is always present.
        """
        if self._task == "cont":
            space = embodied.Space(np.float32, (6,))
        else:
            space = embodied.Space(np.int32, (), 0, 5)
        return {"action": space, "reset": embodied.Space(bool)}

    def step(self, action) -> dict:
        """Advance the environment by one step and return the observation.

        Args:
            action: Mapping with the keys ``"reset"`` and ``"action"``.
                ``"action"`` is only looked up on non-reset steps and its
                value is ignored.

        Returns:
            The observation dict built by ``_obs``. When ``action["reset"]``
            is truthy, or the previous step ended the episode, the step
            counter is reset and a first-step observation with reward ``0.0``
            is returned. Otherwise the counter is incremented and the reward
            is ``1.0``.
        """
        if action["reset"] or self._done:
            self._step = 0
            self._done = False
            return self._obs(0.0, is_first=True)
        # The value is unused, but the lookup is kept so that a missing
        # "action" key still raises KeyError as it did before.
        _ = action["action"]
        self._step += 1
        self._done = self._step >= self._length
        return self._obs(1.0, is_last=self._done, is_terminal=self._done)

    def _obs(
        self,
        reward: float,
        is_first: bool = False,
        is_last: bool = False,
        is_terminal: bool = False,
    ) -> dict:
        """Build an observation dict for the current step counter.

        ``image`` and ``vector`` are freshly allocated zero arrays on every
        call; the remaining entries are the arguments and the step counter.
        """
        return dict(
            image=np.zeros(self._size + (3,), np.uint8),
            vector=np.zeros(7, np.float32),
            step=self._step,
            reward=reward,
            is_first=is_first,
            is_last=is_last,
            is_terminal=is_terminal,
        )
