from time import sleep

from gymnasium import Wrapper


class FruitTreeTeleportToTheStart(Wrapper):
    """Chain several episodes of the wrapped env by teleporting back to the start.

    When the wrapped environment reports ``terminated``, this wrapper may
    suppress the termination instead of ending the episode. In that case it
    puts the agent back at the observation returned by the most recent
    ``reset()`` and continues. This repeats until
    ``terminate_after_n_episodes`` episodes have been played. A falsy value
    (``0`` / ``None``) means the wrapper never lets a termination through
    (infinite horizon).
    """

    def __init__(self, env, terminate_after_n_episodes=1, render_delay=1):
        """Wrap ``env``.

        Args:
            env: Environment to wrap. Its unwrapped env must accept an
                assignment to ``current_state``; that attribute is written on
                every teleport. Assumed to be the env's position state —
                confirm against the env.
            terminate_after_n_episodes: Number of episodes to play before a
                termination is passed through. Falsy means infinite horizon.
            render_delay: Seconds to pause before re-rendering after a
                teleport, when the env has a ``render_mode``. Defaults to 1,
                the previously hard-coded value.
        """
        super().__init__(env)

        # Same object as ``self.env``; kept because other code may read it.
        self._env = env
        self.render_delay = render_delay

        self._initial_position = None
        self._episode = 0
        # Defined up front so reading them before reset()/step() works
        # (gymnasium's Wrapper may refuse lookups of missing private attributes).
        self._reset_kwargs = {}
        self._reward_shape = None

        self.infinite_horizon = not terminate_after_n_episodes
        # Always defined; ``None`` signals "no episode limit".
        self.max_episodes = (
            None if self.infinite_horizon else terminate_after_n_episodes
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and remember the start observation.

        Returns:
            The ``(obs, info)`` pair from the wrapped env's ``reset``.
        """
        self._reset_kwargs = kwargs
        # The episode that begins now is episode number 1.
        self._episode = 1

        obs, info = self.env.reset(**kwargs)
        # Store a copy so later in-place changes to ``obs`` can't alter the start.
        self._initial_position = obs.copy()
        return obs, info

    def step(self, action):
        """Step the wrapped env, teleporting to the start on a suppressed termination.

        Returns:
            The usual ``(obs, reward, terminated, truncated, info)`` tuple. On a
            teleport, ``obs`` is a copy of the start observation and
            ``terminated`` is ``False``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        # ``reward`` must expose ``.shape`` (e.g. a NumPy array).
        self._reward_shape = reward.shape

        episodes_left = self.infinite_horizon or self._episode < self.max_episodes
        if terminated and episodes_left:
            self._episode += 1
            # Separate copies so the returned obs and the env's internal state
            # never alias each other.
            obs = self._initial_position.copy()
            self.unwrapped.current_state = self._initial_position.copy()
            if self.unwrapped.render_mode is not None:
                # Pause so a human viewer can see the teleport happen.
                sleep(self.render_delay)
                self.unwrapped.render()
            terminated = False

        return obs, reward, terminated, truncated, info
