import dm_env
from bsuite.environments.cartpole import Cartpole as _Cartpole
from bsuite.environments.catch import Catch as _Catch
from bsuite.environments.mountain_car import MountainCar as _MountainCar


class Cartpole(_Cartpole):
  """Cartpole variant that keeps per-episode bookkeeping on the instance.

  Attributes set by this subclass:
    episode_id: starts at 0 and is bumped by one on every `reset` call.
    episode_return: running sum of non-None rewards seen by `step` since the
      most recent `reset`.
    bsuite_id: fixed to the string "cartpole/0".
  """

  # NOTE(review): this class hooks the public `reset`/`step`, whereas the
  # sibling wrappers in this file hook `_reset`/`_step` — confirm this matches
  # the base class of the bsuite version in use.

  def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent, then initialise the counters."""
    super().__init__(*args, **kwargs)
    self.bsuite_id = "cartpole/0"
    self.episode_id = self.episode_return = 0

  def reset(self) -> dm_env.TimeStep:
    """Begin a new episode: clear the return, advance the id, delegate."""
    self.episode_return = 0
    self.episode_id += 1
    first_step = super().reset()
    return first_step

  def step(self, action: int) -> dm_env.TimeStep:
    """Delegate to the parent `step` and fold its reward into the return."""
    result = super().step(action)
    reward = result.reward
    if reward is None:
      # Nothing to accumulate for a timestep that carries no reward.
      return result
    self.episode_return += reward
    return result


class Catch(_Catch):
  """Catch variant that keeps per-episode bookkeeping on the instance.

  Attributes set by this subclass:
    episode_id: starts at 0 and is bumped by one on every `_reset` call.
    episode_return: running sum of non-None rewards seen by `_step` since the
      most recent `_reset`.
    bsuite_id: fixed to the string "catch/0".
  """

  def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent, then initialise the counters."""
    super().__init__(*args, **kwargs)
    self.bsuite_id = "catch/0"
    self.episode_id = self.episode_return = 0

  def _reset(self) -> dm_env.TimeStep:
    """Begin a new episode: clear the return, advance the id, delegate."""
    self.episode_return = 0
    self.episode_id += 1
    first_step = super()._reset()
    return first_step

  def _step(self, action: int) -> dm_env.TimeStep:
    """Delegate to the parent `_step` and fold its reward into the return."""
    result = super()._step(action)
    reward = result.reward
    if reward is None:
      # Nothing to accumulate for a timestep that carries no reward.
      return result
    self.episode_return += reward
    return result


class MountainCar(_MountainCar):
  """MountainCar variant that keeps per-episode bookkeeping on the instance.

  Attributes set by this subclass:
    episode_id: starts at 0 and is bumped by one on every `_reset` call.
    episode_return: running sum of non-None rewards seen by `_step` since the
      most recent `_reset`.
    bsuite_id: fixed to the string "mountain_car/0".
  """

  def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent, then initialise the counters."""
    super().__init__(*args, **kwargs)
    self.bsuite_id = "mountain_car/0"
    self.episode_id = self.episode_return = 0

  def _reset(self) -> dm_env.TimeStep:
    """Begin a new episode: clear the return, advance the id, delegate."""
    self.episode_return = 0
    self.episode_id += 1
    first_step = super()._reset()
    return first_step

  def _step(self, action: int) -> dm_env.TimeStep:
    """Delegate to the parent `_step` and fold its reward into the return."""
    result = super()._step(action)
    reward = result.reward
    if reward is None:
      # Nothing to accumulate for a timestep that carries no reward.
      return result
    self.episode_return += reward
    return result
