import collections
import os
import random

import gym
import numpy as np
import pygame
from gym.spaces import Box, Discrete
from pygame.compat import geterror

# Directory containing this module, and the sibling 'assets' folder that holds the sprite images.
main_dir = os.path.dirname(os.path.abspath(__file__))
assets_dir = os.path.join(main_dir, 'assets')


def _load_image(name):
    """Load an image from the assets directory with per-pixel alpha.

    :param name: file name relative to ``assets_dir``.
    :return: the loaded ``pygame.Surface`` after ``convert_alpha()``. ``convert_alpha`` needs a
        display mode to be set; ``CollectEnv.__init__`` sets one before loading any tiles.
    :raises SystemExit: if the image cannot be loaded (message is the underlying error).
    """
    fullname = os.path.join(assets_dir, name)
    try:
        image = pygame.image.load(fullname)
    # pygame 2 raises FileNotFoundError (not pygame.error) for a missing file; treat both the same.
    # The bound exception replaces pygame.compat.geterror(), which newer pygame releases drop.
    except (pygame.error, FileNotFoundError) as err:
        print('Cannot load image:', fullname)
        raise SystemExit(str(err)) from err
    return image.convert_alpha()


def _calculate_topleft_position(position, sprite_size):
    """Convert a grid cell ``(row, col)`` into the pixel ``(x, y)`` of its top-left corner.

    Columns map to x and rows map to y, each scaled by ``sprite_size`` pixels per cell.
    """
    row, col = position[0], position[1]
    x_pixels = sprite_size * col
    y_pixels = sprite_size * row
    return x_pixels, y_pixels


class _Collectible(pygame.sprite.Sprite):
    """A collectible object drawn as a coloured shape on the grid.

    Attributes: ``name`` (``'<shape>_<colour>'``), ``shape``, ``colour``, ``image``, ``rect``,
    and ``position`` (grid cell, ``None`` until ``reset`` places it).
    """

    # (shape, colour) -> image file inside the assets directory.
    _COLLECTIBLE_IMAGES = {
        ('square', 'purple'): 'purple_square.png',
        ('circle', 'purple'): 'purple_circle.png',
        ('square', 'beige'): 'beige_square.png',
        ('circle', 'beige'): 'beige_circle.png',
        ('square', 'blue'): 'blue_square.png',
        ('circle', 'blue'): 'blue_circle.png'
    }

    def __init__(self, sprite_size, shape, colour):
        self.name = '_'.join((shape, colour))
        self._sprite_size = sprite_size
        self.shape, self.colour = shape, colour
        super().__init__()
        raw_image = _load_image(self._COLLECTIBLE_IMAGES[shape, colour])
        cell_size = (sprite_size, sprite_size)
        self.image = pygame.transform.scale(raw_image, cell_size)
        self.rect = self.image.get_rect()
        self.position = None

    def reset(self, position):
        """Place the object at grid cell *position* and move its rect to match."""
        self.position = position
        topleft = _calculate_topleft_position(position, self._sprite_size)
        self.rect.topleft = topleft


class _Player(pygame.sprite.Sprite):
    """The agent's sprite; tracks its grid cell and keeps its screen rect in sync with it."""

    def __init__(self, sprite_size):
        self.name = 'player'
        self._sprite_size = sprite_size
        super().__init__()
        cell_size = (sprite_size, sprite_size)
        self.image = pygame.transform.scale(_load_image('character.png'), cell_size)
        self.rect = self.image.get_rect()
        self.position = None

    def _place(self, cell):
        """Record *cell* as the current position and move the rect onto it."""
        self.position = cell
        self.rect.topleft = _calculate_topleft_position(cell, self._sprite_size)

    def reset(self, position):
        """Put the player at grid cell *position*."""
        self._place(position)

    def step(self, move):
        """Shift the player by the ``(d_row, d_col)`` offset *move* (no wall check here)."""
        current = self.position
        self._place((current[0] + move[0], current[1] + move[1]))


class CollectEnv(gym.Env):
    """
    Grid-world environment in which an agent collects objects.

    The agent has four actions that move it one cell North, East, South or West. A move into a
    wall, or off the walkable area, leaves it where it is. The collectible objects are squares
    and circles, each in one of three colours (purple, beige and blue), for a total of six
    objects by default.

    The task (a colour or a shape, chosen with the ``task`` constructor argument) decides which
    objects end the episode. Each step yields a reward of -0.1, except the step on which a
    task-satisfying object is collected, which yields 1.0 and ends the episode. Collecting any
    other object removes it from the board without ending the episode.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 5
    }

    # Alternative 10-row layout with interior walls, kept for reference (not registered):
    #_BOARDS = {
    #    'original': ['**********',
    #                 '##########',
    #                 '#        #',
    #                 '#    #   #',
    #                 '#   ##   #',
    #                 '#  ##    #',
    #                 '#   #    #',
    #                 '#        #',
    #                 '#        #',
    #                 '##########'],
    #}

    # Board legend: '#' wall, '*' blank header cell, ' ' walkable ground.
    # Row 0 is a header row that reset() rewrites to show the current task index as a digit.
    # These class-level layouts are treated as read-only templates.
    _BOARDS = {
        'original': ['**********',
                     '##########',
                     '#        #',
                     '#        #',
                     '#        #',
                     '#        #',
                     '##########'],
    }

    _AVAILABLE_COLLECTIBLES = [
        ('square', 'purple'),
        ('circle', 'purple'),
        ('square', 'beige'),
        ('circle', 'beige'),
        ('square', 'blue'),
        ('circle', 'blue')
    ]

    _WALL_IMAGE = 'wall.png'
    _GROUND_IMAGE = 'ground.png'
    _BLANK_IMAGE = 'blank.png'
    # Indexed by task index: a header digit d is drawn with _TASK_IMAGES[d].
    _TASK_IMAGES = ['blue.png', 'beige.png', 'purple.png', 'square.png', 'circle.png']

    # Action index -> (d_row, d_col) grid offset.
    _ACTIONS = {
        0: (-1, 0),  # North
        1: (0, 1),  # East
        2: (1, 0),  # South
        3: (0, -1)  # West
    }

    _SCREEN_SIZE = (400, 400)  # (width, height) in pixels
    _SPRITE_SIZE = 40  # pixels per grid cell

    def __init__(self, board='original', available_collectibles=None, start_positions=None,
                 task='blue'):
        """
        Create a new CollectEnv. The observation space is a single RGB image of size 400x400,
        and the action space has four discrete actions: North, East, South and West.

        :param board: key into ``_BOARDS`` selecting the layout of walls and free space
        :param available_collectibles: iterable of ``(shape, colour)`` pairs to place in the
            environment; defaults to all six pairs in ``_AVAILABLE_COLLECTIBLES``
        :param start_positions: optional mapping giving each sprite a ``(row, col)`` start cell,
            or ``None`` to leave that sprite out of the episode. Its keys are sorted and matched
            in order against the name-sorted sprites, so they are presumably the sprite names
            (``'player'``, ``'<shape>_<colour>'``). When omitted, start cells are sampled at
            random on every reset.
        :param task: one of ``'blue'``, ``'beige'``, ``'purple'``, ``'square'`` or ``'circle'``;
            the episode ends when an object with that colour or shape is collected
        :raises ValueError: if ``task`` is not one of the supported tasks
        """
        self.viewer = None
        self.start_positions = start_positions
        self.action_space = Discrete(4)
        self.tasks = [['blue'], ['beige'], ['purple'], ['square'], ['circle']]
        if [task] not in self.tasks:
            raise ValueError('Unknown task {!r}; expected one of {}'.format(
                task, [name for name, in self.tasks]))
        self.task = self.tasks.index([task])
        self.task_conditions = {
            'blue': lambda x: x.colour == 'blue',
            'beige': lambda x: x.colour == 'beige',
            'purple': lambda x: x.colour == 'purple',
            'square': lambda x: x.shape == 'square',
            'circle': lambda x: x.shape == 'circle'
        }
        self.goal_condition = self.task_conditions[self.tasks[self.task][0]]

        self.available_collectibles = available_collectibles \
            if available_collectibles is not None else self._AVAILABLE_COLLECTIBLES

        # Per-instance copy of the layout. reset() rewrites the header row, and that must not
        # leak into the class-level _BOARDS shared by every instance.
        self._board_rows = list(self._BOARDS[board])
        self.board = np.array([list(row) for row in self._board_rows])

        # _draw_screen returns pygame's uint8 pixels swapped to (height, width, 3).
        self.observation_space = Box(0, 255, (self._SCREEN_SIZE[1], self._SCREEN_SIZE[0], 3),
                                     dtype=np.uint8)
        pygame.init()
        pygame.display.init()
        # A display mode must exist before _load_image can call convert_alpha().
        pygame.display.set_mode((1, 1))

        self._bestdepth = pygame.display.mode_ok(self._SCREEN_SIZE, 0, 32)
        self._surface = pygame.Surface(self._SCREEN_SIZE, 0, self._bestdepth)
        self._background = pygame.Surface(self._SCREEN_SIZE)
        self._clock = pygame.time.Clock()
        self._tile_cache = {}  # asset file name -> surface scaled to one grid cell

        # Spawn cells: everything that is neither wall nor header.
        self.free_spaces = list(map(tuple, np.argwhere((self.board != '#') & (self.board != '*'))))
        self._build_board()

        self.initial_positions = None
        self.collectibles = pygame.sprite.Group()
        self.collected = pygame.sprite.Group()
        self.render_group = pygame.sprite.RenderPlain()
        self.player = _Player(self._SPRITE_SIZE)
        self.render_group.add(self.player)

        for shape, colour in self.available_collectibles:
            self.collectibles.add(_Collectible(self._SPRITE_SIZE, shape, colour))

    def _tile_name(self, cell):
        """Return the asset file used to draw a board character *cell*."""
        if cell == '#':
            return self._WALL_IMAGE
        if cell == '*':
            return self._BLANK_IMAGE
        # Any task digit (including 4, 'circle') maps to its indicator image.
        if cell.isdecimal() and int(cell) < len(self._TASK_IMAGES):
            return self._TASK_IMAGES[int(cell)]
        return self._GROUND_IMAGE

    def _tile_image(self, name):
        """Return asset *name* scaled to one grid cell, loading it at most once per instance."""
        tile = self._tile_cache.get(name)
        if tile is None:
            tile = pygame.transform.scale(_load_image(name), (self._SPRITE_SIZE, self._SPRITE_SIZE))
            self._tile_cache[name] = tile
        return tile

    def _build_board(self):
        """Redraw the static tiles of ``self.board`` onto ``self._background``."""
        # Start from black, like a freshly created Surface, so translucent tiles do not
        # accumulate across repeated rebuilds (reset() rebuilds every episode).
        self._background.fill((0, 0, 0))
        n_rows, n_cols = self.board.shape
        for row in range(n_rows):
            for col in range(n_cols):
                tile = self._tile_image(self._tile_name(self.board[row, col]))
                position = _calculate_topleft_position((row, col), self._SPRITE_SIZE)
                self._background.blit(tile, position)

    def _is_walkable(self, cell):
        """Return True if *cell* ``(row, col)`` is on the board and is not wall or header."""
        row, col = cell
        n_rows, n_cols = self.board.shape
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            return False
        tile = self.board[row, col]
        return tile not in ('#', '*') and not tile.isdecimal()

    def _draw_screen(self, surface):
        """Draw background and sprites onto *surface*; return the frame as a (H, W, 3) array."""
        surface.blit(self._background, (0, 0))
        self.render_group.draw(surface)
        # array3d already returns a copy indexed (x, y); swap to row-major (y, x).
        return pygame.surfarray.array3d(surface).swapaxes(0, 1)

    def reset(self):
        """Start a new episode and return the first observation.

        Rewrites the header row to show the current task and returns every collected object to
        play. It then places the player and objects either at random free cells or at
        ``start_positions``. Sprites whose start position is ``None`` are left out of the episode.
        """
        # Task randomisation is disabled; the task is fixed at construction.
        # self.task = np.random.randint(len(self.tasks))
        rows = list(self._board_rows)
        header = ['*'] * len(rows[0])
        header[self.task] = str(self.task)
        rows[0] = ''.join(header)
        self.board = np.array([list(row) for row in rows])
        self._build_board()

        self.goal_condition = self.task_conditions[self.tasks[self.task][0]]

        self.collectibles.add(self.collected.sprites())
        self.collected.empty()
        self.render_group.empty()
        self.render_group.add(self.player)
        self.render_group.add(self.collectibles.sprites())

        # Sort by name so sampled or user-supplied positions line up deterministically.
        sprites = sorted(self.render_group, key=lambda sprite: sprite.name)
        if self.start_positions is None:
            positions = random.sample(self.free_spaces, k=len(sprites))
        else:
            positions = [pos for _, pos in sorted(self.start_positions.items())]

        self.initial_positions = collections.OrderedDict()

        to_remove = []
        for position, sprite in zip(positions, sprites):
            if position is None:
                to_remove.append(sprite)
            else:
                if sprite.name != 'player':
                    self.initial_positions[sprite] = position
                sprite.reset(position)

        self.collected.add(to_remove)
        self.render_group.remove(to_remove)
        return self._draw_screen(self._surface)

    def step(self, action):
        """Apply *action* and return ``(observation, reward, done, info)``.

        The player moves one cell unless the target cell is not walkable. Any object on the
        player's cell is collected. ``info['collected']`` is the (live) group of collected
        objects.
        """
        move = self._ACTIONS[action]
        row, col = self.player.position
        if self._is_walkable((row + move[0], col + move[1])):
            self.player.step(move)

        collected = pygame.sprite.spritecollide(self.player, self.collectibles, True)
        self.collected.add(collected)
        self.render_group.remove(collected)
        done, reward = False, -0.1
        if collected and self.goal_condition(collected[0]):
            done, reward = True, 1.0

        return self._draw_screen(self._surface), reward, done, {'collected': self.collected}

    def render(self, mode='human', close=False):
        """Draw the current state to a pygame window and return it as an RGB array.

        :param mode: ``'human'`` ticks the clock at 2 FPS; any other mode at 10 FPS
        :param close: if True, quit pygame (when a window is open) and return None
        """
        if close:
            if self.viewer is not None:
                pygame.quit()
                self.viewer = None
            return

        if self.viewer is None:
            self.viewer = pygame.display.set_mode(self._SCREEN_SIZE, 0, self._bestdepth)

        self._clock.tick(10 if mode != 'human' else 2)
        arr = self._draw_screen(self.viewer)
        pygame.display.flip()
        return arr

if __name__ == "__main__":
    # Manual smoke test: drive the environment with uniformly random actions and show it.
    demo_env = CollectEnv()
    demo_env.reset()
    demo_env.render()
    for _step in range(10000):
        random_action = demo_env.action_space.sample()
        _obs, _reward, episode_over, _info = demo_env.step(random_action)
        demo_env.render()
        if episode_over:
            demo_env.reset()
