import gym
from gym import spaces
from gym import spec
import numpy as np
from enum import IntEnum
from .assets import MEDIUM


class Actions(IntEnum):
    """Named discrete directions the agent can move in."""

    Left = 0   # move one column towards x == 0
    Right = 1  # move one column towards x == ncols - 1


class Entity:
    """Something placed on the grid: an identifier plus an (x, y) cell position."""

    def __init__(self, id, x, y):
        # x indexes grid columns and y indexes grid rows.
        self.id, self.x, self.y = id, x, y


class Agent(Entity):
    """The player-controlled entity; its id is always the string 'agent'."""

    def __init__(self, x, y):
        super().__init__('agent', x, y)


class Pixel(Entity):
    """A falling pixel identified by its column, descending ``speed`` rows per step."""

    def __init__(self, col, x, y, speed):
        super().__init__(col, x, y)
        # Number of rows the pixel advances on each environment step.
        self.speed = speed


class FallingPixels(gym.Env):
    """
    Catch-the-falling-pixels environment.

    The agent sits on the bottom row and tries to move one column left or
    right per step (moves that would leave the grid are ignored). Pixels
    spawn on the top row, at most one per column, and fall ``speed`` rows
    per step. A pixel whose next position would reach the bottom row is
    removed; the agent is rewarded if it stands in that pixel's column.

    y-axis is the rows of mat
    x-axis is the cols of mat

    (0,0) is top left of mat

    Cell values: 0 = empty, 125 = falling pixel, 255 = agent.
    """
    metadata = {"render.modes": ["rgb_array"]}

    # layout name -> (spawning chance per empty column per step, max episode steps).
    # NOTE(review): every layout currently uses the MEDIUM grid, so 'small' and
    # 'large' differ from 'medium' only in these two parameters.
    _LAYOUTS = {
        'small': (0.15, 125),
        'medium': (0.025, 250),
        'large': (0.15, 250),
    }

    # Values written into the grid for each kind of cell.
    _EMPTY = 0
    _PIXEL = 125
    _AGENT = 255

    def __init__(self, layout, reward):
        """
        :param layout: one of 'small', 'medium' or 'large'
        :param reward: 'identical' gives +1 per caught pixel; any other value
            gives a reward equal to the caught pixel's speed
        :raises ValueError: if ``layout`` is not a known layout name
        """
        try:
            self.spawning_chance, self._max_episode_steps = self._LAYOUTS[layout]
        except KeyError:
            raise ValueError(
                "unknown layout {!r}; expected one of {}".format(
                    layout, sorted(self._LAYOUTS)
                )
            ) from None

        # Only used here to fix the grid dimensions; reset() rebuilds the grid.
        self.grid = np.array(MEDIUM)
        self.nrows = self.grid.shape[0]
        self.ncols = self.grid.shape[1]

        self.rewards = reward

        self.pixels = []

        self.action_space = spaces.Box(-1.0, 1.0, (1,))
        # Cells hold 0, 125 or 255, so the upper bound must be 255 (not 1)
        # for returned observations to lie inside the declared space.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(1, self.nrows, self.ncols),
            dtype=np.uint8
        )
        self.curr_step = 0

    def step(self, action):
        """Advance the environment by one step.

        :param action: value from ``action_space``; positive moves the agent
            right, anything else moves it left
        :return: (observation, reward, done, info)
        """
        reward = 0
        move = 1 if action > 0 else -1

        # Move the agent unless the move would take it off the grid.
        proposed_pos = self.agent.x + move
        if not self._detect_collision_agent(proposed_pos):
            self.grid[self.agent.y, self.agent.x] = self._EMPTY
            self.agent.x = proposed_pos
            self.grid[self.agent.y, self.agent.x] = self._AGENT

        # Iterate over a snapshot: _delete_pixel rebinds self.pixels.
        for pixel in list(self.pixels):
            proposed_pos = pixel.y + pixel.speed

            if not self._detect_collision_pixel(proposed_pos):
                self.grid[pixel.y, pixel.x] = self._EMPTY
                pixel.y = proposed_pos
                self.grid[pixel.y, pixel.x] = self._PIXEL
            else:
                # The pixel would reach the agent's row: it is caught if the
                # agent is in its column, and removed either way.
                if pixel.x == self.agent.x:
                    reward += 1 if self.rewards == 'identical' else pixel.speed

                self._delete_pixel(pixel.x)
                self.grid[pixel.y, pixel.x] = self._EMPTY
                self._restore_agent()

        # Re-spawn: each column without a pixel gets one with probability
        # spawning_chance. A random number is drawn only for free columns.
        occupied = {pixel.x for pixel in self.pixels}
        for col in range(self.ncols):
            if col not in occupied and np.random.rand() <= self.spawning_chance:
                self._spawn_pixel(col)

        self.curr_step += 1
        done = self.curr_step >= self._max_episode_steps

        return self._observation(), reward, done, {}

    def _observation(self):
        """Return a (1, nrows, ncols) snapshot of the grid.

        A copy is returned because ``self.grid[None]`` is a view: without it,
        observations already handed to the caller would change as the
        environment keeps stepping.
        """
        return self.grid[None].copy()

    def _detect_collision_agent(self, proposed_pos):
        """Return True if column ``proposed_pos`` is outside the grid."""
        return not 0 <= proposed_pos < self.ncols

    def _detect_collision_pixel(self, proposed_pos):
        """Return True if row ``proposed_pos`` is at or past the agent's (bottom) row."""
        return bool(proposed_pos >= self.nrows - 1)

    def _restore_agent(self):
        """Redraw the agent if its cell has been cleared."""
        if self.grid[self.agent.y, self.agent.x] == self._EMPTY:
            self.grid[self.agent.y, self.agent.x] = self._AGENT

    def reset(self):
        """Start a new episode.

        The agent is placed in a random bottom-row column and one pixel is
        spawned in every column.

        :return: initial observation of shape (1, nrows, ncols)
        """
        self.grid = np.zeros((self.nrows, self.ncols), dtype=np.uint8)

        self.agent = Agent(x=np.random.choice(self.ncols), y=self.nrows - 1)
        self.grid[self.agent.y, self.agent.x] = self._AGENT

        # Every existing pixel belongs to some column in range(ncols), so
        # clearing the list is the same as deleting pixels column by column.
        self.pixels = []
        for col in range(self.ncols):
            self._spawn_pixel(col)

        self.curr_step = 0

        return self._observation()

    def _spawn_pixel(self, col):
        """Create a pixel on the top row of ``col`` with a random speed of 1, 3 or 5."""
        speed = np.random.choice([1, 3, 5])
        pixel = Pixel(col=col, x=col, y=0, speed=speed)
        self.pixels.append(pixel)
        self.grid[pixel.y, pixel.x] = self._PIXEL

    def _delete_pixel(self, col):
        """
        Drop every pixel whose id (its spawn column) equals ``col``.

        :param col: col of pixel we wish to delete
        :return: None
        """
        self.pixels = [pixel for pixel in self.pixels if pixel.id != col]

    def close(self):
        pass

    def render(self, mode='rgb_array', height=None, width=None):
        """Return the grid stacked three times along a new leading axis.

        :param mode: only 'rgb_array' is advertised in ``metadata``; ignored
        :param height: ignored
        :param width: ignored
        :return: array of shape (3, nrows, ncols)
        """
        return np.repeat(self.grid[None], 3, axis=0)
