import numpy as np


#####################################################################################################################
# Constants
#
#####################################################################################################################
from gym.core import ObsType

# Value loaded into the player's move timer after every move and on reset; the player can move again only once
# that timer has counted back down to 0.
player_speed = 3
# Not read by any code visible in this file: Env takes its own `time_limit` constructor argument (default 1000).
# NOTE(review): presumably kept for external importers -- confirm before relying on or removing it.
time_limit = 2500


#####################################################################################################################
# Env
#
# The player begins at the bottom of the screen and motion is restricted to traveling up and down. Player speed is
# also restricted such that the player can only move every 3 frames. A reward of +1 is given when the player reaches
# the top of the screen, at which point the player is returned to the bottom. Cars travel horizontally on the screen
# and teleport to the other side when the edge is reached. When hit by a car, the player is returned to the bottom of
# the screen. Car direction and speed are indicated by 5 trail channels: the location of the trail gives the
# direction, while the specific channel indicates how frequently the car moves (channel `speedK` marks a car that
# waits K frames between moves, K = 1..5). Each time the player successfully reaches the top of the screen, the car
# speeds are randomized. Termination occurs once the episode's time limit (the `time_limit` argument of `Env`,
# counted in frames) has elapsed.
#
#####################################################################################################################
class Env:
    """Road-crossing game on a 10x10 grid (a chicken crossing lanes of cars).

    The player lives in column 4 and only moves vertically; row 9 is the
    bottom (start) and row 0 is the top (goal). Eight cars occupy rows 1-8,
    one per row, and move horizontally with wrap-around.

    Each car is stored as a 4-element list ``[x, y, countdown, speed]``:
    ``x``/``y`` are the column/row, ``speed`` is a signed value whose sign is
    the direction (positive = increasing ``x``) and whose magnitude (1-5) is
    the value the countdown is reset to after each move. A car advances one
    cell on a frame where its countdown is 0, otherwise the countdown drops
    by one.
    """

    def __init__(self, ramping=None, time_limit: int = 1000, seed=None):
        """Create the environment and reset it to the start state.

        Args:
            ramping: Accepted for interface compatibility; not used by this
                game (see ``difficulty_ramp``).
            time_limit: Value the termination countdown is set to on reset.
            seed: Optional seed forwarded to ``np.random.RandomState`` so
                that car speeds are reproducible. ``None`` keeps the
                previous behaviour of an unseeded generator.
        """
        self.channels = {
            'chicken': 0,
            'car': 1,
            'speed1': 2,
            'speed2': 3,
            'speed3': 4,
            'speed4': 5,
            'speed5': 6,
        }
        self.action_map = ['n', 'l', 'u', 'r', 'd', 'f']
        self.random = np.random.RandomState(seed)
        # The speed-trail channels are hidden from `observation`.
        self.channels_to_exclude = [f'speed{i:d}' for i in range(1, 6)]
        self.channels_to_keep = [
            index for name, index in self.channels.items()
            if name not in self.channels_to_exclude
        ]
        self.time_limit = time_limit
        self.reset()

    def act(self, a):
        """Advance the game by one frame using action index ``a``.

        Args:
            a: Index into ``self.action_map``. Only 'u' (up) and 'd' (down)
                change the player position; every other action is a no-op.

        Returns:
            ``(reward, terminal)``. The reward is 1 on the frame the player
            reaches row 0 and 0 otherwise. Once the episode is terminal the
            call returns ``(0, True)`` without changing any state.
        """
        r = 0
        if self.terminal:
            return r, self.terminal

        action = self.action_map[a]

        # The player may only move once its cooldown has run out.
        if self.move_timer == 0 and action in ('u', 'd'):
            self.move_timer = player_speed
            if action == 'u':
                self.pos = max(0, self.pos - 1)
            else:
                self.pos = min(9, self.pos + 1)

        # Win condition: reward, new car speeds, player back to the bottom.
        if self.pos == 0:
            r += 1
            self._randomize_cars(initialize=False)
            self.pos = 9

        # Update cars. Collisions are checked both before and after a car
        # moves, so the player is caught whether it stepped into the car or
        # the car drove into it.
        for car in self.cars:
            if self._hits_player(car):
                self.pos = 9
            if car[2] == 0:
                car[2] = abs(car[3])
                step = 1 if car[3] > 0 else -1
                car[0] = (car[0] + step) % 10  # wrap around the screen edge
                if self._hits_player(car):
                    self.pos = 9
            else:
                car[2] -= 1

        # Update timers. Expiry is tested with `< 0`, so the terminal flag
        # is first returned by call number `time_limit + 1` after a reset.
        if self.move_timer > 0:
            self.move_timer -= 1
        self.terminate_timer -= 1
        if self.terminate_timer < 0:
            self.terminal = True
        return r, self.terminal

    def _hits_player(self, car):
        """Return True when ``car`` occupies the player's cell (column 4)."""
        return car[0] == 4 and car[1] == self.pos

    def difficulty_ramp(self):
        """Return the difficulty-ramp level; always None, this game has none."""
        return None

    def state(self):
        """Render the full game state as a boolean 10x10xn array.

        The array is indexed ``[row, column, channel]`` with one channel per
        entry of ``self.channels``. Each car marks its own cell in the 'car'
        channel and the cell directly behind it (with wrap-around) in the
        ``speedK`` channel matching its speed magnitude.

        Raises:
            KeyError: If a car's speed magnitude has no ``speedK`` channel
                (i.e. it is outside 1-5).
        """
        state = np.zeros((10, 10, len(self.channels)), dtype=bool)
        state[self.pos, 4, self.channels['chicken']] = True
        for x, y, _countdown, speed in self.cars:
            state[y, x, self.channels['car']] = True
            # The trail sits one cell behind the car relative to its heading.
            back_x = (x - 1 if speed > 0 else x + 1) % 10
            trail = self.channels[f'speed{int(abs(speed))}']
            state[y, back_x, trail] = True
        return state

    def _randomize_cars(self, initialize=False):
        """Draw a new random speed and direction for each of the 8 cars.

        Args:
            initialize: When True, rebuild ``self.cars`` with every car at
                column 0 of rows 1-8. When False, keep car positions and
                only replace each car's countdown and speed.
        """
        magnitudes = self.random.randint(1, 6, 8)
        # np.sign(rand - 0.5) would give direction 0 for a draw of exactly
        # 0.5, producing a speed-0 car that has no trail channel. Comparing
        # against 0.5 uses the same draws but always yields -1 or +1.
        directions = np.where(self.random.rand(8) < 0.5, -1, 1)
        speeds = [int(m) * int(d) for m, d in zip(magnitudes, directions)]
        if initialize:
            self.cars = [
                [0, row, abs(speed), speed]
                for row, speed in enumerate(speeds, start=1)
            ]
        else:
            for car, speed in zip(self.cars, speeds):
                car[2:4] = [abs(speed), speed]

    def reset(self, **kwargs):
        """Reset to the start state for a new episode.

        Keyword arguments are accepted and ignored.
        """
        self._randomize_cars(initialize=True)
        self.pos = 9
        self.move_timer = player_speed
        self.terminate_timer = self.time_limit
        self.terminal = False

    @property
    def observation(self):
        """The state restricted to ``channels_to_keep`` (no speed trails)."""
        return self.state()[..., self.channels_to_keep]

    def state_shape(self):
        """Return the dimensionality of the full game state (10x10xn)."""
        return [10, 10, len(self.channels)]

    def observation_shape(self):
        """Return the dimensionality of ``observation``."""
        return [10, 10, len(self.channels_to_keep)]

    def minimal_action_set(self):
        """Return the action indices that have a unique effect in this game."""
        minimal_actions = ['n', 'u', 'd']
        return [self.action_map.index(x) for x in minimal_actions]
