import gym
import math
import numpy as np


class ActionRepeatWrapper(gym.Wrapper):
    """Repeat every action ``k`` times and prepend ``k`` to each observation.

    The wrapper uses the 4-tuple ``step`` API
    (``observation, reward, done, info``). Observations returned by
    ``reset`` and ``step`` are the wrapped env's observation with the
    current repeat count ``k`` concatenated in front, so the wrapped env is
    expected to produce 1-D observations.

    Args:
        env: The environment to wrap.
        k: Number of times each action is applied per ``step`` call.
        return_history: If true, ``step`` returns the array of per-repeat
            discounted rewards; otherwise it returns their sum.
        discount: Factor applied as ``discount ** i`` to the reward of the
            ``i``-th repeat (``i`` starting at 0).
    """

    def __init__(self, env, k=1, return_history=True, discount=1.0):
        super().__init__(env)
        self.k = k
        self.return_history = return_history
        self.discount = discount
        # Build a *new* space instead of mutating ``env.observation_space``:
        # the wrapper and the wrapped env share that object, so editing its
        # shape in place would corrupt the inner env's space and leave its
        # bounds inconsistent with the new shape.
        self.observation_space = self._extended_space(env.observation_space)

    @staticmethod
    def _extended_space(inner):
        """Return a Box describing ``inner`` with one extra leading entry."""
        if not isinstance(inner, gym.spaces.Box):
            # Bounds are unknown for other space types; fall back to an
            # unbounded Box of the extended length.
            return gym.spaces.Box(
                low=-np.inf,
                high=np.inf,
                shape=(inner.shape[0] + 1,),
                dtype=np.float32,
            )
        dtype = np.dtype(inner.dtype)
        # The repeat count is never negative and has no fixed upper limit;
        # integer dtypes cannot represent infinity, so use their maximum.
        if np.issubdtype(dtype, np.integer):
            k_high = np.iinfo(dtype).max
        else:
            k_high = np.inf
        low = np.concatenate(
            (np.array([0], dtype=dtype), np.asarray(inner.low, dtype=dtype))
        )
        high = np.concatenate(
            (np.array([k_high], dtype=dtype), np.asarray(inner.high, dtype=dtype))
        )
        return gym.spaces.Box(low=low, high=high, dtype=dtype)

    def step(self, action):
        """Apply ``action`` up to ``k`` times, stopping early on ``done``.

        Returns:
            A tuple ``(observation, reward, done, info)`` where
            ``observation`` is the last inner observation with ``k``
            prepended, ``reward`` is either the length-``k`` array of
            discounted rewards (zeros after termination) or its sum,
            depending on ``return_history``, and ``info`` is the info dict
            of the last inner step.

        Raises:
            ValueError: If the current repeat count is smaller than 1, in
                which case no inner step could be taken.
        """
        if self.k < 1:
            raise ValueError(
                "action repeat must be >= 1 to step, got {!r}".format(self.k)
            )
        reward_history = np.zeros((self.k,))
        done = False
        info = {}
        for i in range(self.k):
            next_state, reward, done, info = self.env.step(action)
            reward_history[i] = reward * self.discount ** i
            if done:
                # Remaining entries stay at their initial value of zero.
                break
        if self.return_history:
            rewards = reward_history
        else:
            rewards = reward_history.sum()
        return self.obs(next_state), rewards, done, info

    def obs(self, state):
        """Return ``state`` with the current repeat count prepended."""
        return np.concatenate([np.array((self.k,)), state], axis=0)

    def reset(self, **kwargs):
        """Reset the wrapped env and return its augmented first observation.

        Keyword arguments are forwarded unchanged to the wrapped env's
        ``reset``.
        """
        return self.obs(self.env.reset(**kwargs))

    def set_k(self, k):
        """Change the number of repeats used by subsequent ``step`` calls.

        The check is kept as an ``assert`` so the exception type seen by
        existing callers does not change.
        """
        assert k > 0, "attempted to set action repeat <= 0"
        self.k = k
