from copy import deepcopy

import gym
import numpy as np
from gym.spaces import Discrete, Dict, Box


class CartPole:
    """
    Wrapper for gym CartPole environment where the reward
    is accumulated to the end.

    Every step reports a reward of 0 until the wrapped environment signals
    ``done``; on that step the sum of all rewards collected since the last
    ``reset`` (or the value restored by ``set_state``) is returned instead.

    Observations are dicts with two entries:
        "obs": the observation produced by the wrapped environment.
        "action_mask": a float32 vector with one entry per action, all
            ones (this wrapper never masks out an action).
    """

    def __init__(self, config=None):
        """Create the wrapped environment and declare the spaces.

        Args:
            config: Unused by this class.
        """
        self.env = gym.make("CartPole-v0")
        self.action_space = Discrete(2)
        self.observation_space = Dict({
            "obs": self.env.observation_space,
            # Explicit dtype so the declared space and the masks built by
            # _action_mask() agree instead of relying on a library default.
            "action_mask": Box(
                low=0, high=1, shape=(self.action_space.n, ),
                dtype=np.float32),
        })
        self.running_reward = 0

    def _action_mask(self):
        """Return a fresh all-ones float32 mask sized to the action space."""
        return np.ones(self.action_space.n, dtype=np.float32)

    def reset(self):
        """Reset the wrapped environment and the accumulated reward.

        Returns:
            The initial observation dict ("obs" and "action_mask").
        """
        self.running_reward = 0
        return {"obs": self.env.reset(), "action_mask": self._action_mask()}

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            A tuple ``(observation, score, done, info)`` where ``score`` is
            the accumulated reward if ``done`` is true and 0 otherwise;
            ``done`` and ``info`` come straight from the wrapped environment.
        """
        obs, rew, done, info = self.env.step(action)
        self.running_reward += rew
        # Withhold the reward until the episode is over.
        score = self.running_reward if done else 0
        return (
            {"obs": obs, "action_mask": self._action_mask()},
            score,
            done,
            info,
        )

    def set_state(self, state):
        """Restore a snapshot previously produced by ``get_state``.

        Args:
            state: A pair ``(env, running_reward)``. The environment is
                deep-copied, so the snapshot itself is not modified by
                later steps and can be restored again.

        Returns:
            The observation dict corresponding to the restored state, with
            "obs" rebuilt from the restored environment's ``unwrapped.state``.
        """
        self.running_reward = state[1]
        self.env = deepcopy(state[0])
        obs = np.array(list(self.env.unwrapped.state))
        return {"obs": obs, "action_mask": self._action_mask()}

    def get_state(self):
        """Return a snapshot ``(deep copy of env, accumulated reward)``."""
        return deepcopy(self.env), self.running_reward
