import gymnasium as gym
import numpy as np

from jax_rl.wrappers.common import TimeStep


class FetchRepeatAction(gym.Wrapper):
    def __init__(self, env, action_repeat=4):
        super().__init__(env)
        self._action_repeat = action_repeat

    def step(self, action: np.ndarray) -> TimeStep:
        total_reward = 0.0
        done = None
        combined_info = {}

        for _ in range(self._action_repeat):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            total_reward += reward
            combined_info.update(info)
            if done:
                break

        return obs, total_reward, terminated, truncated, combined_info
