
import numpy as np

#####################################################################################################################
# Constants
#
#####################################################################################################################
# NOTE(review): ramp_interval is not read anywhere in the visible source.
ramp_interval = 100
# Oxygen level at reset and after every successful surfacing; also the denominator of the oxygen gauge.
max_oxygen = 200
# Initial value of Env.e_spawn_speed, the value the enemy spawn countdown is reset to after each spawn attempt.
init_spawn_speed = 20
# Value the diver spawn countdown is reset to after each diver spawn.
diver_spawn_speed = 30
# Initial value of Env.move_speed, the per-enemy movement countdown (enemies move when it reaches 0).
init_move_interval = 5
# Value the player's shot countdown is reset to after firing; the player can fire again once it reaches 0.
shot_cool_down = 5
# Value an enemy sub's shot countdown is reset to after it fires (and its starting value when spawned).
enemy_shot_interval = 10
# NOTE(review): enemy_move_interval is not read anywhere in the visible source (enemies use Env.move_speed).
enemy_move_interval = 5
# Value a diver's movement countdown is reset to after it moves (and its starting value when spawned).
diver_move_interval = 5


#####################################################################################################################
# Env
#
# The player controls a submarine consisting of two cells, front and back, to allow direction to be determined. The
# player can also fire bullets from the front of the submarine. Enemies consist of submarines and fish, distinguished
# by the fact that submarines shoot bullets and fish do not. A reward of +1 is given each time an enemy is struck by
# one of the player's bullets, at which point the enemy is also removed. There are also divers which the player can
# move onto to pick up, doing so increments a bar indicated by another channel along the bottom of the screen. The
# player also has a limited supply of oxygen indicated by another bar in another channel. Oxygen degrades over time,
# and is replenished whenever the player moves to the top of the screen as long as the player has at least one rescued
# diver on board. The player can carry a maximum of 6 divers. When surfacing with fewer than 6, one diver is removed.
# When surfacing with 6, all divers are removed and a reward is given for each active cell in the oxygen bar. Each
# time the player surfaces the difficulty is increased by increasing the spawn rate and movement speed of enemies.
# Termination occurs when the player is hit by an enemy fish, sub or bullet; or when oxygen runs out; or when the
# player attempts to surface with no rescued divers. Enemy and diver directions are indicated by a trail channel
# active in their previous location to reduce partial observability.
#
#####################################################################################################################
class Env:
    """Seaquest-style environment played on a 10x10 grid.

    Entities are stored as small mutable lists:

    * friendly / enemy bullets: ``[x, y, moving_right]``
    * enemy fish and divers:    ``[x, y, moving_right, move_timer]``
    * enemy subs:               ``[x, y, moving_right, move_timer, shot_timer]``

    ``x`` is the column (0-9) and ``y`` the row (0-9); row 0 is the surface
    and row 9 holds the oxygen and diver gauges.
    """

    def __init__(self, ramping=True, random_state=None):
        """Create the environment and reset it to the start state.

        Args:
            ramping: if true, enemy spawn rate and movement speed increase
                each time the player surfaces successfully.
            random_state: object providing ``choice`` (e.g. a
                ``numpy.random.RandomState``); a fresh ``RandomState`` is
                created when omitted.
        """
        self.channels = {
            "sub_front": 0,
            "sub_back": 1,
            "friendly_bullet": 2,
            "trail": 3,
            "enemy_bullet": 4,
            "enemy_fish": 5,
            "enemy_sub": 6,
            "oxygen_guage": 7,
            "diver_guage": 8,
            "diver": 9,
        }
        self.action_map = ["n", "l", "u", "r", "d", "f"]
        self.ramping = ramping
        if random_state is None:
            self.random = np.random.RandomState()
        else:
            self.random = random_state
        self.reset()

    def act(self, a):
        """Advance the game by one step using action index ``a``.

        Args:
            a: index into ``self.action_map``.

        Returns:
            ``(reward, terminal)`` for this step. Once the episode is
            terminal, further calls do nothing and return ``(0, True)``.
        """
        r = 0
        if self.terminal:
            return r, self.terminal

        a = self.action_map[a]

        # Spawn enemy if timer is up
        if self.e_spawn_timer == 0:
            self._spawn_enemy()
            self.e_spawn_timer = self.e_spawn_speed

        if self.d_spawn_timer == 0:
            self._spawn_diver()
            self.d_spawn_timer = diver_spawn_speed

        # The update order below matters: later stages see the positions
        # produced by earlier ones.
        self._apply_action(a)
        r += self._update_friendly_bullets()
        self._update_divers()
        r += self._update_enemy_subs()
        self._update_enemy_bullets()
        r += self._update_enemy_fish()
        r += self._update_timers_and_oxygen()
        return r, self.terminal

    def _apply_action(self, a):
        """Resolve the player's action (a letter from ``self.action_map``)."""
        if a == "f" and self.shot_timer == 0:
            self.f_bullets.append([self.sub_x, self.sub_y, self.sub_or])
            self.shot_timer = shot_cool_down
        elif a == "l":
            self.sub_x = max(0, self.sub_x - 1)
            self.sub_or = False
        elif a == "r":
            self.sub_x = min(9, self.sub_x + 1)
            self.sub_or = True
        elif a == "u":
            self.sub_y = max(0, self.sub_y - 1)
        elif a == "d":
            # Row 9 is reserved for the gauges, so the player stops at row 8.
            self.sub_y = min(8, self.sub_y + 1)

    def _at_player(self, entity):
        """Return True if ``entity`` occupies the player's front cell."""
        return entity[0:2] == [self.sub_x, self.sub_y]

    def _update_friendly_bullets(self):
        """Move the player's bullets and resolve hits; return reward earned."""
        reward = 0
        for bullet in reversed(self.f_bullets):
            bullet[0] += 1 if bullet[2] else -1
            if bullet[0] < 0 or bullet[0] > 9:
                self.f_bullets.remove(bullet)
                continue
            # Fish are checked before subs; a bullet destroys at most one enemy.
            for enemies in (self.e_fish, self.e_subs):
                target = next(
                    (e for e in enemies if e[0:2] == bullet[0:2]), None
                )
                if target is not None:
                    enemies.remove(target)
                    self.f_bullets.remove(bullet)
                    reward += 1
                    break
        return reward

    def _try_rescue(self, diver):
        """Pick up ``diver`` if it is on the player and there is room.

        Returns True when the diver was picked up (and removed).
        """
        if self._at_player(diver) and self.diver_count < 6:
            self.divers.remove(diver)
            self.diver_count += 1
            return True
        return False

    def _update_divers(self):
        """Pick up divers on the player's cell and move the remaining ones."""
        for diver in reversed(self.divers):
            if self._try_rescue(diver):
                continue
            if diver[3] != 0:
                diver[3] -= 1
                continue
            diver[3] = diver_move_interval
            diver[0] += 1 if diver[2] else -1
            if diver[0] < 0 or diver[0] > 9:
                self.divers.remove(diver)
            else:
                # The diver may have just swum onto the player.
                self._try_rescue(diver)

    def _step_enemy(self, enemy, group):
        """Advance one fish or sub that lives in the list ``group``.

        Handles collision with the player (sets ``self.terminal``), movement,
        leaving the board, and being hit by a friendly bullet.

        Returns:
            ``(alive, reward)`` where ``alive`` is False if the enemy was
            removed from ``group`` during this call.
        """
        if self._at_player(enemy):
            self.terminal = True
        if enemy[3] != 0:
            enemy[3] -= 1
            return True, 0
        enemy[3] = self.move_speed
        enemy[0] += 1 if enemy[2] else -1
        if enemy[0] < 0 or enemy[0] > 9:
            group.remove(enemy)
            return False, 0
        if self._at_player(enemy):
            self.terminal = True
            return True, 0
        for bullet in self.f_bullets:
            if bullet[0:2] == enemy[0:2]:
                group.remove(enemy)
                self.f_bullets.remove(bullet)
                return False, 1
        return True, 0

    def _update_enemy_subs(self):
        """Move enemy subs and let surviving ones shoot; return reward earned."""
        reward = 0
        for sub in reversed(self.e_subs):
            alive, gained = self._step_enemy(sub, self.e_subs)
            reward += gained
            if not alive:
                # A sub that left the board or was destroyed must not fire.
                continue
            if sub[4] == 0:
                sub[4] = enemy_shot_interval
                self.e_bullets.append([sub[0], sub[1], sub[2]])
            else:
                sub[4] -= 1
        return reward

    def _update_enemy_bullets(self):
        """Move enemy bullets, checking the player before and after the move."""
        for bullet in reversed(self.e_bullets):
            if self._at_player(bullet):
                self.terminal = True
            bullet[0] += 1 if bullet[2] else -1
            if bullet[0] < 0 or bullet[0] > 9:
                self.e_bullets.remove(bullet)
            elif self._at_player(bullet):
                self.terminal = True

    def _update_enemy_fish(self):
        """Move enemy fish; return reward earned from fish hit by bullets."""
        reward = 0
        for fish in reversed(self.e_fish):
            _, gained = self._step_enemy(fish, self.e_fish)
            reward += gained
        return reward

    def _update_timers_and_oxygen(self):
        """Tick countdown timers, consume oxygen and handle surfacing.

        Returns the reward earned by surfacing on this step (0 otherwise).
        """
        if self.e_spawn_timer > 0:
            self.e_spawn_timer -= 1
        if self.d_spawn_timer > 0:
            self.d_spawn_timer -= 1
        if self.shot_timer > 0:
            self.shot_timer -= 1

        if self.oxygen < 0:
            self.terminal = True
        if self.sub_y > 0:
            self.oxygen -= 1
            self.surface = False
            return 0
        if self.surface:
            # Still sitting on the surface from a previous step.
            return 0
        if self.diver_count == 0:
            # Surfacing without any rescued diver ends the game.
            self.terminal = True
            return 0
        return self._surface()

    def _oxygen_cells(self):
        """Number of lit oxygen-gauge cells (0-10), never negative.

        Oxygen can briefly be -1 on the step before termination; without the
        clamp the floor division would yield -1.
        """
        return max(0, self.oxygen * 10 // max_oxygen)

    def _surface(self):
        """Handle the player reaching the top row with at least one diver.

        With 6 divers all of them are delivered and the reward equals the
        number of lit oxygen-gauge cells; otherwise one diver is removed and
        the reward is 0. Oxygen is restored in both cases and, if ramping is
        enabled, the difficulty is increased.

        Returns:
            The reward for this surfacing.
        """
        self.surface = True
        if self.diver_count == 6:
            self.diver_count = 0
            r = self._oxygen_cells()
        else:
            # Only a partial load costs a diver; never drop below zero.
            self.diver_count = max(0, self.diver_count - 1)
            r = 0
        self.oxygen = max_oxygen
        if self.ramping and (self.e_spawn_speed > 1 or self.move_speed > 2):
            # Movement speeds up only on every other ramp step.
            if self.move_speed > 2 and self.ramp_index % 2:
                self.move_speed -= 1
            if self.e_spawn_speed > 1:
                self.e_spawn_speed -= 1
            self.ramp_index += 1
        return r

    def _spawn_enemy(self):
        """Spawn a fish or sub at a side edge in a random row (1-8).

        Nothing is spawned if an existing enemy in the chosen row travels in
        the opposite direction.
        """
        # The order of these draws is kept fixed so seeded runs are reproducible.
        lr = self.random.choice([True, False])
        is_sub = self.random.choice([True, False], p=[1 / 3, 2 / 3])
        x = 0 if lr else 9
        y = self.random.choice(np.arange(1, 9))

        # Do not spawn in same row and opposite direction as existing
        if any(z[1] == y and z[2] != lr for z in self.e_subs + self.e_fish):
            return
        if is_sub:
            self.e_subs.append([x, y, lr, self.move_speed, enemy_shot_interval])
        else:
            self.e_fish.append([x, y, lr, self.move_speed])

    def _spawn_diver(self):
        """Spawn a diver at a side edge in a random row (1-8)."""
        lr = self.random.choice([True, False])
        x = 0 if lr else 9
        y = self.random.choice(np.arange(1, 9))
        self.divers.append([x, y, lr, diver_move_interval])

    def difficulty_ramp(self):
        """Return the current difficulty ramp index."""
        return self.ramp_index

    def _draw_with_trail(self, state, entity, channel):
        """Mark ``entity`` in ``channel`` and its trail cell just behind it."""
        state[entity[1], entity[0], self.channels[channel]] = 1
        back_x = entity[0] - 1 if entity[2] else entity[0] + 1
        if 0 <= back_x <= 9:
            state[entity[1], back_x, self.channels["trail"]] = 1

    def state(self):
        """Return the observation as a boolean array of shape (10, 10, n_channels).

        The array is indexed as ``[row, column, channel]``.
        """
        state = np.zeros((10, 10, len(self.channels)), dtype=bool)
        state[self.sub_y, self.sub_x, self.channels["sub_front"]] = 1
        back_x = self.sub_x - 1 if self.sub_or else self.sub_x + 1
        state[self.sub_y, back_x, self.channels["sub_back"]] = 1
        state[9, 0 : self._oxygen_cells(), self.channels["oxygen_guage"]] = 1
        state[9, 9 - self.diver_count : 9, self.channels["diver_guage"]] = 1
        for bullet in self.f_bullets:
            state[bullet[1], bullet[0], self.channels["friendly_bullet"]] = 1
        for bullet in self.e_bullets:
            state[bullet[1], bullet[0], self.channels["enemy_bullet"]] = 1
        for fish in self.e_fish:
            self._draw_with_trail(state, fish, "enemy_fish")
        for sub in self.e_subs:
            self._draw_with_trail(state, sub, "enemy_sub")
        for diver in self.divers:
            self._draw_with_trail(state, diver, "diver")

        return state

    def reset(self):
        """Reset every piece of game state to the start of a new episode."""
        self.oxygen = max_oxygen
        self.diver_count = 0
        self.sub_x = 5
        self.sub_y = 0
        # False = facing left, True = facing right
        self.sub_or = False
        self.f_bullets = []
        self.e_bullets = []
        self.e_fish = []
        self.e_subs = []
        self.divers = []
        self.e_spawn_speed = init_spawn_speed
        self.e_spawn_timer = self.e_spawn_speed
        self.d_spawn_timer = diver_spawn_speed
        self.move_speed = init_move_interval
        self.ramp_index = 0
        self.shot_timer = 0
        # The player starts on the surface, so no surfacing event fires
        # until they have dived and come back up.
        self.surface = True
        self.terminal = False

    def state_shape(self):
        """Return the observation shape as ``[10, 10, n_channels]``."""
        return [10, 10, len(self.channels)]

    def minimal_action_set(self):
        """Return the indices of the actions that have a distinct effect."""
        minimal_actions = ["n", "l", "u", "r", "d", "f"]
        return [self.action_map.index(x) for x in minimal_actions]
