"""Source: https://github.com/kenjyoung/MinAtar/blob/master/minatar/environments/asterix.py"""

from typing import Optional
from gymnasium.spaces import Discrete

import numpy as np


#####################################################################################################################
# Constants
#
#####################################################################################################################

ramp_interval = 100
init_spawn_speed = 10
init_move_interval = 5
shot_cool_down = 5

#####################################################################################################################
# Env
#
# The player can move freely along the 4 cardinal directions. Enemies and treasure spawn from the sides. A reward of
# +1 is given for picking up treasure. Termination occurs if the player makes contact with an enemy. Enemy and
# treasure direction are indicated by a trail channel. Difficulty is periodically increased by increasing the speed
# and spawn rate of enemies and treasure.
#
#####################################################################################################################


class AsterixEnv:

    def __init__(self, ramping=True, time_limit: int = 1000):
        self.channels = {
            'player': 0,
            'enemy': 1,
            'trail': 2,
            'gold': 3
        }
        self.action_map = ['n', 'l', 'u', 'r', 'd', 'f']
        self.ramping = ramping
        self.random = np.random.RandomState()
        self.time_limit = time_limit
        self._timer = time_limit
        self.action_space = Discrete(len(self.action_map))
        self.reset()

        self.n_dim = 1
        self.n_actions = len(self.action_map)

    # Update environment according to agent action
    def act(self, a):
        r = 0.1
        if (self.terminal):
            return r, self.terminal

        a = self.action_map[a]

        # Spawn enemy if timer is up
        if (self.spawn_timer == 0):
            self._spawn_entity()
            self.spawn_timer = self.spawn_speed

        # Resolve player action
        if (a == 'l'):
            self.player_x = max(0, self.player_x-1)
        elif (a == 'r'):
            self.player_x = min(9, self.player_x+1)
        elif (a == 'u'):
            self.player_y = max(1, self.player_y-1)
        elif (a == 'd'):
            self.player_y = min(8, self.player_y+1)

        # Update entities
        for i in range(len(self.entities)):
            x = self.entities[i]
            if (x is not None):
                if (x[0:2] == [self.player_x, self.player_y]):
                    if (self.entities[i][3]):
                        self.entities[i] = None
                        r += 10
                    else:
                        self.terminal = True
        if (self.move_timer == 0):
            self.move_timer = self.move_speed
            for i in range(len(self.entities)):
                x = self.entities[i]
                if (x is not None):
                    x[0] += 1 if x[2] else -1
                    if (x[0] < 0 or x[0] > 9):
                        self.entities[i] = None
                    if (x[0:2] == [self.player_x, self.player_y]):
                        if (self.entities[i][3]):
                            self.entities[i] = None
                            r += 10
                        else:
                            self.terminal = True

        # Update various timers
        self.spawn_timer -= 1
        self.move_timer -= 1
        self._timer -= 1

        # Ramp difficulty if interval has elapsed
        if self.ramping and (self.spawn_speed > 1 or self.move_speed > 1):
            if (self.ramp_timer >= 0):
                self.ramp_timer -= 1
            else:
                if (self.move_speed > 1 and self.ramp_index % 2):
                    self.move_speed -= 1
                if (self.spawn_speed > 1):
                    self.spawn_speed -= 1
                self.ramp_index += 1
                self.ramp_timer = ramp_interval

        self.terminal = self.terminal or self._timer <= 0

        return r, self.terminal

    # Spawn a new enemy or treasure at a random location with random direction (if all rows are filled do nothing)
    def _spawn_entity(self):
        lr = self.random.rand() < 1/2
        is_gold = self.random.rand() < 1/3
        x = 0 if lr else 9
        slot_options = [i for i in range(
            len(self.entities)) if self.entities[i] == None]
        if (not slot_options):
            return
        slot = slot_options[self.random.randint(len(slot_options))]
        self.entities[slot] = [x, slot+1, lr, is_gold]

    # Query the current level of the difficulty ramp, could be used as additional input to agent for example
    def difficulty_ramp(self):
        return self.ramp_index

    # Process the game-state into the 10x10xn state provided to the agent and return
    def state(self):
        state = np.zeros((10, 10, len(self.channels)), dtype=bool)
        state[self.player_y, self.player_x, self.channels['player']] = 1
        for x in self.entities:
            if (x is not None):
                c = self.channels['gold'] if x[3] else self.channels['enemy']
                state[x[1], x[0], c] = 1
                back_x = x[0]-1 if x[2] else x[0]+1
                if (back_x >= 0 and back_x <= 9):
                    state[x[1], back_x, self.channels['trail']] = 1
        return state.astype(float).flatten()

    # Reset to start state for new episode
    def reset(self, **kwargs):
        self.player_x = 5
        self.player_y = 5
        self.entities = [None]*8
        self.shot_timer = 0
        self.spawn_speed = init_spawn_speed
        self.spawn_timer = self.spawn_speed
        self.move_speed = init_move_interval
        self.move_timer = self.move_speed
        self.ramp_timer = ramp_interval
        self.ramp_index = 0
        self.terminal = False
        self._timer = self.time_limit

        return self.state() #, None

    # Dimensionality of the game-state(10x10xn)
    def state_shape(self):
        return [10, 10, len(self.channels)]

    # Subset of actions that actually have a unique impact in this environment
    def minimal_action_set(self):
        minimal_actions = ['n', 'l', 'u', 'r', 'd']
        return [self.action_map.index(x) for x in minimal_actions]

    def step(self, action):
        reward, done = self.act(action)

        return self.state(), reward, done, {}