
import numpy as np

#####################################################################################################################
# Constants
#
#####################################################################################################################
shot_cool_down = 5
enemy_move_interval = 12
enemy_shot_interval = 10


#####################################################################################################################
# Env
#
# The player controls a cannon at the bottom of the screen and can shoot bullets upward at a cluster of aliens above.
# The aliens move across the screen until one of them hits the edge, at which point they all move down and switch
# directions. The current alien direction is indicated by 2 channels (one for left and one for right) one of which is
# active at the location of each alien. A reward of +1 is given each time an alien is shot, and that alien is also
# removed. The aliens will also shoot bullets back at the player. When few aliens are left, alien speed will begin to
# increase. When only one alien is left, it will move at one cell per frame. When a wave of aliens is fully cleared a
# new one will spawn which moves at a slightly faster speed than the last. Termination occurs when an alien or bullet
# hits the player.
#
#####################################################################################################################
class Env:
    def __init__(self, ramping=True, random_state=None):
        self.channels = {
            "cannon": 0,
            "alien": 1,
            "alien_left": 2,
            "alien_right": 3,
            "friendly_bullet": 4,
            "enemy_bullet": 5,
        }
        self.action_map = ["n", "l", "u", "r", "d", "f"]
        self.ramping = ramping
        if random_state is None:
            self.random = np.random.RandomState()
        else:
            self.random = random_state
        self.reset()

    # Update environment according to agent action
    def act(self, a):
        r = 0
        if self.terminal:
            return r, self.terminal

        a = self.action_map[a]

        # Resolve player action
        if a == "f" and self.shot_timer == 0:
            self.f_bullet_map[9, self.pos] = 1
            self.shot_timer = shot_cool_down
        elif a == "l":
            self.pos = max(0, self.pos - 1)
        elif a == "r":
            self.pos = min(9, self.pos + 1)

        # Update Friendly Bullets
        self.f_bullet_map = np.roll(self.f_bullet_map, -1, axis=0)
        self.f_bullet_map[9, :] = 0

        # Update Enemy Bullets
        self.e_bullet_map = np.roll(self.e_bullet_map, 1, axis=0)
        self.e_bullet_map[0, :] = 0
        if self.e_bullet_map[9, self.pos]:
            self.terminal = True

        # Update aliens
        if self.alien_map[9, self.pos]:
            self.terminal = True
        if self.alien_move_timer == 0:
            self.alien_move_timer = min(
                np.count_nonzero(self.alien_map), self.enemy_move_interval
            )
            if (np.sum(self.alien_map[:, 0]) > 0 and self.alien_dir < 0) or (
                np.sum(self.alien_map[:, 9]) > 0 and self.alien_dir > 0
            ):
                self.alien_dir = -self.alien_dir
                if np.sum(self.alien_map[9, :]) > 0:
                    self.terminal = True
                self.alien_map = np.roll(self.alien_map, 1, axis=0)
            else:
                self.alien_map = np.roll(self.alien_map, self.alien_dir, axis=1)
            if self.alien_map[9, self.pos]:
                self.terminal = True
        if self.alien_shot_timer == 0:
            self.alien_shot_timer = enemy_shot_interval
            nearest_alien = self._nearest_alien(self.pos)
            self.e_bullet_map[nearest_alien[0], nearest_alien[1]] = 1

        kill_locations = np.logical_and(
            self.alien_map, self.alien_map == self.f_bullet_map
        )

        r += np.sum(kill_locations)
        self.alien_map[kill_locations] = self.f_bullet_map[kill_locations] = 0

        # Update various timers
        self.shot_timer -= self.shot_timer > 0
        self.alien_move_timer -= 1
        self.alien_shot_timer -= 1
        if np.count_nonzero(self.alien_map) == 0:
            if self.enemy_move_interval > 6 and self.ramping:
                self.enemy_move_interval -= 1
                self.ramp_index += 1
            self.alien_map[0:4, 2:8] = 1
        return r, self.terminal

    # Find the alien closest to player in manhattan distance, currently used to decide which alien shoots
    def _nearest_alien(self, pos):
        search_order = [i for i in range(10)]
        search_order.sort(key=lambda x: abs(x - pos))
        for i in search_order:
            if np.sum(self.alien_map[:, i]) > 0:
                return [np.max(np.where(self.alien_map[:, i] == 1)), i]
        return None

    # Query the current level of the difficulty ramp, could be used as additional input to agent for example
    def difficulty_ramp(self):
        return self.ramp_index

    # Process the game-state into the 10x10xn state provided to the agent and return
    def state(self):
        state = np.zeros((10, 10, len(self.channels)), dtype=bool)
        state[9, self.pos, self.channels["cannon"]] = 1
        state[:, :, self.channels["alien"]] = self.alien_map
        if self.alien_dir < 0:
            state[:, :, self.channels["alien_left"]] = self.alien_map
        else:
            state[:, :, self.channels["alien_right"]] = self.alien_map
        state[:, :, self.channels["friendly_bullet"]] = self.f_bullet_map
        state[:, :, self.channels["enemy_bullet"]] = self.e_bullet_map
        return state

    # Reset to start state for new episode
    def reset(self):
        self.pos = 5
        self.f_bullet_map = np.zeros((10, 10))
        self.e_bullet_map = np.zeros((10, 10))
        self.alien_map = np.zeros((10, 10))
        self.alien_map[0:4, 2:8] = 1
        self.alien_dir = -1
        self.enemy_move_interval = enemy_move_interval
        self.alien_move_timer = self.enemy_move_interval
        self.alien_shot_timer = enemy_shot_interval
        self.ramp_index = 0
        self.shot_timer = 0
        self.terminal = False

    # Dimensionality of the game-state (10x10xn)
    def state_shape(self):
        return [10, 10, len(self.channels)]

    # Subset of actions that actually have a unique impact in this environment
    def minimal_action_set(self):
        minimal_actions = ["n", "l", "r", "f"]
        return [self.action_map.index(x) for x in minimal_actions]
