import gym
import retro
import numpy as np


class StreetFighter2ActionWrapper(gym.ActionWrapper):
    """Expose a multi-player gym-retro env through per-player discrete actions.

    The wrapped env must have a ``MultiBinary`` action space whose length is
    split evenly between ``players`` (one button vector per player,
    concatenated). This wrapper replaces it with a ``Tuple`` of ``Discrete``
    spaces: each player picks an index into ``combos``, and :meth:`action`
    expands those indices back into the concatenated button vector.

    Args:
        env: gym-retro environment with a ``MultiBinary`` action space.
        players: number of players sharing the action vector.
        combos: optional ordered sequence of button combinations, each a
            sequence of names from ``env.unwrapped.buttons``. The position of
            a combo is its discrete action index. Defaults to
            ``DEFAULT_COMBOS``.

    Raises:
        ValueError: if ``players`` is not positive or the action vector
            cannot be split evenly between players, or if a combo names a
            button the env does not have.
    """

    # Default discrete action set. Order matters: index i of this tuple is
    # discrete action i for every player. Index 0 is "press nothing".
    DEFAULT_COMBOS = (
        (),
        ('UP',),
        ('DOWN',),
        ('LEFT',),
        ('UP', 'LEFT'),
        ('DOWN', 'LEFT'),
        ('RIGHT',),
        ('UP', 'RIGHT'),
        ('DOWN', 'RIGHT'),
        ('B',),
        ('B', 'DOWN'),
        ('B', 'LEFT'),
        ('B', 'RIGHT'),
        ('A',),
        ('A', 'DOWN'),
        ('A', 'LEFT'),
        ('A', 'RIGHT'),
        ('C',),
        ('DOWN', 'C'),
        ('LEFT', 'C'),
        ('RIGHT', 'C'),
        ('Y',),
        ('DOWN', 'Y'),
        ('LEFT', 'Y'),
        ('DOWN', 'LEFT', 'Y'),
        ('RIGHT', 'Y'),
        ('X',),
        ('DOWN', 'X'),
        ('LEFT', 'X'),
        ('DOWN', 'LEFT', 'X'),
        ('RIGHT', 'X'),
        ('DOWN', 'RIGHT', 'X'),
        ('Z',),
        ('DOWN', 'Z'),
        ('LEFT', 'Z'),
        ('DOWN', 'LEFT', 'Z'),
        ('RIGHT', 'Z'),
        ('DOWN', 'RIGHT', 'Z'),
    )

    def __init__(self, env, players, combos=None):
        super().__init__(env)
        # Decoding below assumes one flat binary button vector.
        assert isinstance(env.action_space, gym.spaces.MultiBinary)
        if combos is None:
            combos = self.DEFAULT_COMBOS

        total_buttons = int(env.action_space.n)
        # Integer split; previously n / players was silently truncated.
        if players <= 0 or total_buttons % players:
            raise ValueError(
                f"cannot split {total_buttons} buttons between {players} players"
            )
        self.players = players
        self._single_action_len = total_buttons // players

        buttons = env.unwrapped.buttons  # ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
        self._combos = [list(combo) for combo in combos]
        # One 0/1 button vector (length _single_action_len) per combo.
        self._decode_discrete_action = []
        for combo in self._combos:
            arr = np.zeros(self._single_action_len, dtype=int)
            for button in combo:
                try:
                    arr[buttons.index(button)] = 1
                except ValueError:
                    raise ValueError(
                        f"unknown button {button!r}; env buttons are {buttons}"
                    ) from None
            self._decode_discrete_action.append(arr)

        self.action_space = gym.spaces.Tuple(
            [gym.spaces.Discrete(len(self._decode_discrete_action)) for _ in range(players)]
        )

    def action(self, action):
        """Expand one discrete combo index per player into a flat button list.

        Args:
            action: indexable with one combo index per player (e.g. a sample
                from ``self.action_space``).

        Returns:
            list of 0/1 values: the players' button vectors concatenated in
            player order.
        """
        atomic_actions = []
        for i in range(self.players):
            # extend() copies the values, so the cached array is not shared.
            atomic_actions.extend(self._decode_discrete_action[action[i]])
        return atomic_actions

    def reverse_action(self, action):
        """Map a flat button vector back to one combo index per player.

        Inverse of :meth:`action`. The first matching combo wins.

        Args:
            action: flat 0/1 sequence of length ``players * single_action_len``.

        Returns:
            list of ints, one combo index per player.

        Raises:
            ValueError: if a player's buttons match none of the combos.
        """
        width = self._single_action_len
        indices = []
        for i in range(self.players):
            chunk = np.asarray(action[i * width:(i + 1) * width])
            for index, decoded in enumerate(self._decode_discrete_action):
                if np.array_equal(decoded, chunk):
                    indices.append(index)
                    break
            else:
                raise ValueError(
                    f"buttons {chunk.tolist()} for player {i} match no combo"
                )
        return indices


class BaseRetroEnv(gym.Env):
    """Wrap a gym-retro game so it is compatible with RLlib.

    The retro env is created with ``state=retro.State.NONE`` and wrapped in
    ``StreetFighter2ActionWrapper``, so ``action_space`` is that wrapper's
    per-player discrete action space.

    Args:
        game: gym-retro game id to load.
        players: number of players passed to both retro and the wrapper.
        obs_type: retro observation type (RAM by default).
    """

    def __init__(
            self,
            game="CustomStreetFighterIISpecialChampionEdition-Genesis",
            players=2,
            obs_type=retro.Observations.RAM,
    ):
        self._env = StreetFighter2ActionWrapper(
            env=retro.make(
                game=game,
                players=players,
                obs_type=obs_type,
                state=retro.State.NONE,
            ),
            players=players
        )
        self.action_space = self._env.action_space
        self.observation_space = self._env.observation_space

    def reset(self, **kwargs):
        """Reset the environment and return the initial observation.

        Extra keyword arguments are accepted for API compatibility but are
        not forwarded to the wrapped env.
        """
        return self._env.reset()

    def step(self, actions: list):
        """Step the environment with one discrete action per player.

        Returns:
            the wrapped env's ``(obs, rew, done, info)`` tuple unchanged.
        """
        obs, rew, done, info = self._env.step(actions)
        return obs, rew, done, info

    def close(self):
        """Close the environment."""
        self._env.close()

    def render(self, mode="human"):
        """Render the current frame using the requested ``mode``.

        Previously ``mode`` was ignored and the result was discarded.
        Returns whatever the wrapped env's ``render`` returns. For
        ``"rgb_array"`` this is presumably the frame array; confirm against
        the retro version in use.
        """
        return self._env.render(mode=mode)


def test(steps=10000):
    """Smoke-test BaseRetroEnv by stepping it with random actions and rendering.

    The env is always closed, even if a step or render call raises.

    Args:
        steps: number of environment steps to run (default 10000).
    """
    env = BaseRetroEnv()
    try:
        obs = env.reset()
        print("obs:", obs)
        print("====:", env.action_space, env.observation_space)
        # The underlying retro env uses MultiBinary(12 * num_players); the
        # wrapper exposes one Discrete combo index per player instead.
        # Observation space: Box(0, 255, (65536,), uint8) for RAM,
        # Box(0, 255, (200, 256, 3), uint8) for images.
        for _ in range(steps):
            # Previously sampled but discarded in favour of a fixed [0, 1].
            act = env.action_space.sample()
            obs, rew, done, info = env.step(act)
            env.render("human")
            if done:
                obs = env.reset()
                print("done:", done)
    finally:
        env.close()


if "__main__" == __name__:
    # Only when run as a script (not on import): launch the interactive smoke test.
    test()
