"""
Environments and wrappers for Sonic training.
"""

import gym
import numpy as np

from spinup.sonic_utils.wrapper_utils import WarpFrame, FrameStack

import retro
import retro_contest
import gym.wrappers

import random

# Level (save-state) names used by the multi-level factories below. Tuples so the
# shared module-level constants cannot be mutated; every env receives its own list copy.
SONIC_TRAIN_LEVELS = (
    'SpringYardZone.Act3', 'SpringYardZone.Act2',
    'GreenHillZone.Act3', 'GreenHillZone.Act1',
    'StarLightZone.Act2', 'StarLightZone.Act1',
    'MarbleZone.Act2', 'MarbleZone.Act1', 'MarbleZone.Act3',
    'ScrapBrainZone.Act2', 'LabyrinthZone.Act2', 'LabyrinthZone.Act1', 'LabyrinthZone.Act3',
)

SONIC2_TRAIN_LEVELS = (
    'EmeraldHillZone.Act1', 'EmeraldHillZone.Act2', 'ChemicalPlantZone.Act2', 'ChemicalPlantZone.Act1',
    'MetropolisZone.Act1', 'MetropolisZone.Act2', 'OilOceanZone.Act1', 'MysticCaveZone.Act1',
    'HillTopZone.Act1', 'CasinoNightZone.Act1', 'WingFortressZone', 'AquaticRuinZone.Act2',
    'AquaticRuinZone.Act1',
)


def _build_sonic_env(game, levels, stack=True, scale_rew=True):
    """
    Create a contest-scenario retro env for ``game`` and apply the standard wrapper stack.

    Wrappers, innermost first:
      1. SonicLevelPicker  -- on every reset, load a random state from ``levels``
      2. SonicDiscretizer  -- discrete action space
      3. RewardScaler      -- only if ``scale_rew``
      4. StochasticFrameSkip(n=4, stickprob=0.05)
      5. TimeLimit(max_episode_steps=2700) -- applied after frame skipping, so it
         counts agent steps rather than emulator frames
      6. WarpFrame
      7. AllowBacktracking
      8. FrameStack(4)     -- only if ``stack``

    :param game: retro game id, e.g. 'SonicTheHedgehog-Genesis'.
    :param levels: iterable of state names to sample from on reset.
    :param stack: whether to stack the last 4 frames.
    :param scale_rew: whether to wrap with RewardScaler.
    :return: the fully wrapped environment.
    """
    env = retro.make(game, scenario='contest')
    # Copy so each env owns its level list (the constants above are shared).
    env = SonicLevelPicker(env, random=True, levels=list(levels))
    env = SonicDiscretizer(env)
    if scale_rew:
        env = RewardScaler(env)
    env = retro_contest.StochasticFrameSkip(env, n=4, stickprob=0.05)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=2700)
    env = WarpFrame(env)
    env = AllowBacktracking(env)
    if stack:
        env = FrameStack(env, 4)
    return env


def make_singleLevel_sth_env(level=None, stack=True, scale_rew=True):
    """
    Build a Sonic the Hedgehog training env that always starts from a single level.

    :param level: state name to use; defaults to 'GreenHillZone.Act1' when None.
    :param stack: whether to stack the last 4 frames.
    :param scale_rew: whether to scale rewards with RewardScaler.
    """
    levels = ['GreenHillZone.Act1'] if level is None else [level]
    return _build_sonic_env('SonicTheHedgehog-Genesis', levels,
                            stack=stack, scale_rew=scale_rew)


def make_sth_env(stack=True, scale_rew=True):
    """
    Build a Sonic the Hedgehog training env that samples a random level from
    ``SONIC_TRAIN_LEVELS`` on every reset.

    :param stack: whether to stack the last 4 frames.
    :param scale_rew: whether to scale rewards with RewardScaler.
    """
    return _build_sonic_env('SonicTheHedgehog-Genesis', SONIC_TRAIN_LEVELS,
                            stack=stack, scale_rew=scale_rew)


def make_sth2_env(stack=True, scale_rew=True):
    """
    Build a Sonic the Hedgehog 2 training env that samples a random level from
    ``SONIC2_TRAIN_LEVELS`` on every reset.

    :param stack: whether to stack the last 4 frames.
    :param scale_rew: whether to scale rewards with RewardScaler.
    """
    return _build_sonic_env('SonicTheHedgehog2-Genesis', SONIC2_TRAIN_LEVELS,
                            stack=stack, scale_rew=scale_rew)

class SonicLevelPicker(gym.Wrapper):
    """
    Choose which save-state (level) a gym-retro Sonic env starts from on reset.

    Must be applied directly on top of ``retro.make`` so that ``load_state``
    reaches the underlying retro environment.

    :param env: the retro environment to wrap.
    :param random: if true, every ``reset()`` without an explicit ``state_nm``
        loads a state chosen uniformly from ``levels``.
    :param levels: candidate state names. When None and ``random`` is true,
        all states returned by ``retro.data.list_states`` for the env's game
        are used.
    """
    def __init__(self, env, random=False, levels=None):
        super(SonicLevelPicker, self).__init__(env)
        # NB: inside __init__ ``random`` is the boolean argument, not the module.
        self.random = random
        self.levels = levels
        if random and self.levels is None:
            self.levels = retro.data.list_states(env.gamename)

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        return observation, reward, done, info

    def reset(self, state_nm=None, **kwargs):
        """
        Optionally switch level, then reset the wrapped env.

        :param state_nm: explicit state to load; takes precedence over random picking.
        :param kwargs: forwarded unchanged to the wrapped env's ``reset``.
        """
        if state_nm is None:
            if self.random:
                # Here ``random`` resolves to the stdlib module.
                self.env.load_state(random.choice(self.levels))
        else:
            self.env.load_state(state_nm)
        return self.env.reset(**kwargs)

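"""
Note to self: implement a *game* picker (switching between different games/ROMs, as
opposed to levels within one game) outside of a gym wrapper rather than as one. That way
the running game can be shut down cleanly when switching to another game; doing it inside
a wrapper might cause problems.
To list the available levels (save states) of a game:
    levels = retro.data.list_states('SonicTheHedgehog-Genesis')
"""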

class SonicDiscretizer(gym.ActionWrapper):
    """
    Wrap a gym-retro environment and make it use discrete
    actions for the Sonic game.

    Discrete action ``i`` maps to the ``i``-th button combination below,
    encoded as a boolean vector over the Genesis button layout:
    0 LEFT, 1 RIGHT, 2 LEFT+DOWN, 3 RIGHT+DOWN, 4 DOWN, 5 DOWN+B, 6 B.
    """
    def __init__(self, env):
        super(SonicDiscretizer, self).__init__(env)
        buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"]
        actions = [['LEFT'], ['RIGHT'], ['LEFT', 'DOWN'], ['RIGHT', 'DOWN'], ['DOWN'],
                   ['DOWN', 'B'], ['B']]
        self._actions = []
        for action in actions:
            # Vector length follows the button list instead of a hard-coded 12.
            arr = np.zeros(len(buttons), dtype=bool)
            for button in action:
                arr[buttons.index(button)] = True
            self._actions.append(arr)
        self.action_space = gym.spaces.Discrete(len(self._actions))

    def action(self, a): # pylint: disable=W0221
        """
        Convert a discrete action index into a button vector.

        :param a: the action index as a Python int, a numpy integer scalar, a 0-d
            array, or a sequence/array whose first element is the index.
        :return: a fresh copy of the boolean button vector, so callers may mutate it.
        """
        # Flattening handles 0-d arrays (e.g. from ``tensor.numpy()``), which the old
        # ``np.isscalar`` check rejected, sending them to ``a[0]`` -> IndexError.
        index = np.asarray(a).reshape(-1)[0]
        return self._actions[index].copy()

class RewardScaler(gym.RewardWrapper):
    """
    Bring rewards to a reasonable scale for PPO.

    This is incredibly important and affects performance
    drastically.

    :param env: the environment to wrap.
    :param scale: multiplicative factor applied to every reward (default 0.01,
        the value previously hard-coded).
    """
    def __init__(self, env, scale=0.01):
        super(RewardScaler, self).__init__(env)
        self.scale = scale

    def reward(self, reward):
        return reward * self.scale

class AllowBacktracking(gym.Wrapper):
    """
    Reward forward progress measured against the furthest point reached so far.

    The inner env's per-step reward is treated as the change in x and summed into
    a running position. The reward emitted is how far that position exceeds the
    previous maximum (never negative). Moving backwards is therefore not
    penalised, which helps when the level cannot be advanced head-on.
    """
    def __init__(self, env):
        super(AllowBacktracking, self).__init__(env)
        self._clear_progress()

    def _clear_progress(self):
        # Running position and best position seen in the current episode.
        self._cur_x = 0
        self._max_x = 0

    def reset(self, **kwargs): # pylint: disable=E0202
        self._clear_progress()
        return self.env.reset(**kwargs)

    def step(self, action): # pylint: disable=E0202
        observation, delta_x, done, info = self.env.step(action)
        self._cur_x += delta_x
        gain = self._cur_x - self._max_x
        if self._cur_x > self._max_x:
            self._max_x = self._cur_x
        return observation, max(0, gain), done, info