# example from https://github.com/Baekalfen/PyBoy/wiki/Using-PyBoy-with-Gym
import gymnasium as gym
from gymnasium import spaces
import numpy as np

# Button names understood by the emulator; a discrete action index selects
# one entry of this list.
actions = ['up', 'down', 'left', 'right']

# Shape and value range of the observation space advertised to Gymnasium.
matrix_shape = (16, 20)
game_area_observation_space = spaces.Box(
    low=0,
    high=255,
    shape=matrix_shape,
    dtype=np.uint8,
)

class GenericPyBoyEnv(gym.Env):
    """Gymnasium environment that drives a PyBoy emulator with d-pad actions.

    Every step presses one of the module-level ``actions`` buttons, advances
    the emulator by 60 ticks and returns the current screen image. Episodes
    are (re)started by loading the save state stored at ``state_path``.

    The values returned by ``step`` and ``reset`` are wrapped in
    single-element lists/arrays rather than being plain Gymnasium scalars.

    NOTE(review): ``observation_space`` is declared as a ``(16, 20)`` uint8
    box, while ``step``/``reset`` return ``pyboy.screen.image``. Confirm that
    consumers do not rely on the declared space matching the observation.
    """

    def __init__(self, pyboy, state_path, debug=False, max_episode_steps=64):
        """Store the emulator handle, configure the spaces and start the game.

        Args:
            pyboy: Emulator instance to control (presumably a ``pyboy.PyBoy``).
            state_path: Path of the save-state file loaded on every reset.
            debug: When false, ``pyboy.set_emulation_speed(0)`` is called
                (in PyBoy, 0 means no speed limit - confirm for the PyBoy
                version in use).
            max_episode_steps: Number of steps after which the episode is
                forcibly ended and the state reloaded.
        """
        super().__init__()
        self.pyboy = pyboy
        self.state_path = state_path
        self._fitness = 0
        self._previous_fitness = 0
        self.debug = debug
        self.max_episode_steps = max_episode_steps
        self.num_steps = 0

        if not self.debug:
            self.pyboy.set_emulation_speed(0)

        self.action_space = spaces.Discrete(len(actions))
        self.observation_space = game_area_observation_space

        self.pyboy.game_wrapper.start_game()

    def step(self, action):
        """Press one button, advance the emulator and report the outcome.

        Args:
            action: Index into the module-level ``actions`` list.

        Returns:
            A 5-tuple ``([observation], np.array([reward]), [done],
            truncated, [info])``:

            * ``observation`` is ``pyboy.screen.image``, read *after* any
              reset triggered by this step.
            * ``reward`` equals the game-over flag computed from the game
              wrapper (1 when the episode ended, otherwise 0); it is not
              affected by the step limit.
            * ``done`` is 1 when the game wrapper signalled the end of the
              episode or the step limit was reached, otherwise 0.
            * ``truncated`` is always ``False``; hitting the step limit is
              reported through ``done`` instead.
            * ``info`` is ``{'TimeLimit.truncated': 0}`` when the step limit
              was reached, otherwise an empty dict.

        Raises:
            AssertionError: If ``action`` is not contained in the action
                space.
        """
        assert self.action_space.contains(action), (
            f"{action!r} ({type(action)}) invalid"
        )
        print('DOING ACTION: ', actions[action])
        # Move the agent
        self.pyboy.button(actions[action])

        # Consider disabling renderer when not needed to improve speed:
        # self.pyboy.tick(1, False)
        self.pyboy.tick(60)

        # The episode counts as finished when the wrapper's string form does
        # not contain 'True'.
        # NOTE(review): what 'True' denotes depends on the concrete game
        # wrapper's string representation - confirm for the game in use.
        done = int('True' not in str(self.pyboy.game_wrapper))

        self.num_steps += 1

        if done:
            # Auto-reset: the observation returned below is the fresh state.
            self.reset()
        reward = done

        info = {}
        truncated = False

        if self.num_steps >= self.max_episode_steps:
            # The step limit ends the episode via `done`; reset() also
            # clears the step counter.
            info, done = {'TimeLimit.truncated': 0}, 1
            self.reset()

        observation = self.pyboy.screen.image

        return [observation], np.array([reward]), [done], truncated, [info]

    def _calculate_fitness(self):
        """Remember the previous fitness and refresh it from the game score.

        This helper is not called by any other method of this class.
        """
        self._previous_fitness = self._fitness

        # NOTE: Only some game wrappers will provide a score
        # If not, you'll have to investigate how to score the game yourself
        self._fitness = self.pyboy.game_wrapper.score

    def reset(self, **kwargs):
        """Reload the save state and clear the per-episode bookkeeping.

        Args:
            **kwargs: Gymnasium reset keywords. ``seed`` is forwarded to
                ``gym.Env.reset`` so that ``self.np_random`` is seeded as the
                Gymnasium API expects; any other keyword is ignored.

        Returns:
            A 2-tuple ``([observation], [info])`` where ``observation`` is
            ``pyboy.screen.image`` after loading the state and ``info`` is an
            empty dict.
        """
        # Let the base class handle seeding of its random generator.
        super().reset(seed=kwargs.get('seed'))

        with open(self.state_path, "rb") as f:
            self.pyboy.load_state(f)
        self._fitness = 0
        self._previous_fitness = 0
        self.num_steps = 0

        observation = self.pyboy.screen.image
        info = {}
        return [observation], [info]

    def render(self, mode='human'):
        """Do nothing; this environment performs no rendering of its own."""
        pass

    def close(self):
        """Stop the underlying emulator."""
        self.pyboy.stop()