from typing import Optional

import numpy as np


#####################################################################################################################
# Constants
#
#####################################################################################################################

# Timing parameters shared by Env (all measured in calls to Env.act).
(
    ramp_interval,       # steps between successive difficulty increases
    init_spawn_speed,    # initial number of steps between entity spawns
    init_move_interval,  # initial number of steps between entity moves
    shot_cool_down,      # not referenced by Env
) = (100, 10, 5, 5)


#####################################################################################################################
# Env
#
# The player can move freely along the 4 cardinal directions. Enemies and treasure spawn from the sides. A reward of
# +1 is given for picking up treasure. Termination occurs if the player makes contact with an enemy or once
# `time_limit` steps have elapsed. Enemy and treasure direction are indicated by a trail channel, which is included
# in `state()` but omitted from `observation`. Difficulty is periodically increased by increasing the speed and
# spawn rate of enemies and treasure.
#
#####################################################################################################################
class Env:
    """Grid game in which the player dodges enemies and collects treasure.

    The board is 10x10. Each of the 8 entity slots maps to one row (rows 1-8);
    an entity is stored as ``[x, y, moves_right, is_gold]`` and shifts one
    column per move tick. The player is confined to columns 0-9 and rows 1-8.

    Args:
        ramping: If True, the spawn and move intervals shrink over time.
        time_limit: The episode terminates once this many steps have been taken.
        random_state: Optional ``np.random.RandomState`` driving entity spawns.
            Defaults to a fresh unseeded ``RandomState`` (the previous
            behavior); pass a seeded instance for reproducible episodes.
    """

    def __init__(self, ramping=True, time_limit: int = 1000,
                 random_state: Optional[np.random.RandomState] = None):
        self.channels = {
            'player': 0,
            'enemy': 1,
            'trail': 2,
            'gold': 3,
        }
        # Action symbols: no-op, left, up, right, down, fire ('f' has no effect in act()).
        self.action_map = ['n', 'l', 'u', 'r', 'd', 'f']
        self.ramping = ramping
        self.random = np.random.RandomState() if random_state is None else random_state
        # The trail channel appears in state() but is dropped from observation.
        self.channels_to_exclude = ['trail']
        self.channels_to_keep = [i for key, i in self.channels.items() if key not in self.channels_to_exclude]
        self.time_limit = time_limit
        self._timer = time_limit
        self.reset()

    # Update environment according to agent action
    def act(self, a):
        """Advance the game by one step.

        Args:
            a: Index into ``action_map``.

        Returns:
            ``(reward, terminal)`` where reward is the number of treasures
            collected this step. Once terminal, further calls return
            ``(0, True)`` without changing the state.
        """
        r = 0
        if self.terminal:
            return r, self.terminal

        a = self.action_map[a]

        # Spawn an entity if the spawn timer is up
        if self.spawn_timer == 0:
            self._spawn_entity()
            self.spawn_timer = self.spawn_speed

        # Resolve player action; y is clamped to rows 1-8, the only rows entities occupy
        if a == 'l':
            self.player_x = max(0, self.player_x - 1)
        elif a == 'r':
            self.player_x = min(9, self.player_x + 1)
        elif a == 'u':
            self.player_y = max(1, self.player_y - 1)
        elif a == 'd':
            self.player_y = min(8, self.player_y + 1)

        # Check for contact after the player has moved
        for i in range(len(self.entities)):
            r += self._resolve_contact(i)

        # Entities move on their own timer; those leaving the board are removed
        if self.move_timer == 0:
            self.move_timer = self.move_speed
            for i, entity in enumerate(self.entities):
                if entity is None:
                    continue
                entity[0] += 1 if entity[2] else -1
                if not 0 <= entity[0] <= 9:
                    self.entities[i] = None
                r += self._resolve_contact(i)

        # Update various timers
        self.spawn_timer -= 1
        self.move_timer -= 1
        self._timer -= 1

        # Ramp difficulty if interval has elapsed
        if self.ramping and (self.spawn_speed > 1 or self.move_speed > 1):
            if self.ramp_timer >= 0:
                self.ramp_timer -= 1
            else:
                # The move interval only shrinks on odd ramp indices
                if self.move_speed > 1 and self.ramp_index % 2:
                    self.move_speed -= 1
                if self.spawn_speed > 1:
                    self.spawn_speed -= 1
                self.ramp_index += 1
                self.ramp_timer = ramp_interval

        self.terminal = self.terminal or self._timer <= 0

        return r, self.terminal

    def _resolve_contact(self, i):
        """Handle contact between the player and entity slot ``i``; return the reward earned.

        Treasure on the player's cell is removed and yields 1. An enemy on the
        player's cell sets ``terminal`` and yields 0. An empty slot yields 0.
        """
        entity = self.entities[i]
        if entity is None or entity[0:2] != [self.player_x, self.player_y]:
            return 0
        if entity[3]:
            self.entities[i] = None
            return 1
        self.terminal = True
        return 0

    # Spawn a new enemy or treasure at a random location with random direction (if all rows are filled do nothing)
    def _spawn_entity(self):
        """Place an enemy (p=2/3) or treasure (p=1/3) in a random free row.

        The entity starts at column 0 moving right or column 9 moving left,
        each with probability 1/2. Does nothing if every slot is occupied.
        """
        # Direction and type are drawn before the free-slot check, so two RNG
        # draws happen even when no slot is free.
        lr = self.random.rand() < 1/2
        is_gold = self.random.rand() < 1/3
        x = 0 if lr else 9
        slot_options = [i for i, entity in enumerate(self.entities) if entity is None]
        if not slot_options:
            return
        slot = slot_options[self.random.randint(len(slot_options))]
        self.entities[slot] = [x, slot + 1, lr, is_gold]

    # Query the current level of the difficulty ramp, could be used as additional input to agent for example
    def difficulty_ramp(self):
        """Return the number of difficulty increases applied so far this episode."""
        return self.ramp_index

    # Process the game-state into the 10x10xn state provided to the agent and return
    def state(self):
        """Return the full boolean state of shape (10, 10, len(channels)), indexed [row, col, channel]."""
        state = np.zeros((10, 10, len(self.channels)), dtype=bool)
        state[self.player_y, self.player_x, self.channels['player']] = True
        for entity in self.entities:
            if entity is None:
                continue
            channel = self.channels['gold'] if entity[3] else self.channels['enemy']
            state[entity[1], entity[0], channel] = True
            # Mark the cell behind the entity so its direction of travel is visible
            back_x = entity[0] - 1 if entity[2] else entity[0] + 1
            if 0 <= back_x <= 9:
                state[entity[1], back_x, self.channels['trail']] = True
        return state

    # Reset to start state for new episode
    def reset(self, **kwargs):
        """Restore the start-of-episode state. Extra keyword arguments are accepted and ignored."""
        self.player_x = 5
        self.player_y = 5
        self.entities = [None] * 8
        self.shot_timer = 0  # not read by act(); kept for compatibility
        self.spawn_speed = init_spawn_speed
        self.spawn_timer = self.spawn_speed
        self.move_speed = init_move_interval
        self.move_timer = self.move_speed
        self.ramp_timer = ramp_interval
        self.ramp_index = 0
        self.terminal = False
        self._timer = self.time_limit

    @property
    def observation(self):
        """State with the excluded channels (the trail) removed."""
        return self.state()[..., self.channels_to_keep]

    def observation_shape(self):
        """Shape of ``observation`` as a list: [10, 10, number of kept channels]."""
        return [10, 10, len(self.channels_to_keep)]

    # Dimensionality of the game-state (10x10xn)
    def state_shape(self):
        """Shape of ``state()`` as a list: [10, 10, number of channels]."""
        return [10, 10, len(self.channels)]

    # Subset of actions that actually have a unique impact in this environment
    def minimal_action_set(self):
        """Return the action indices that have an effect (all except 'f')."""
        minimal_actions = ['n', 'l', 'u', 'r', 'd']
        return [self.action_map.index(x) for x in minimal_actions]

