import numpy as np

import gym

class GoalSearchSimple(object):
    """Grid world in which an agent collects goals while dodging bouncing enemies.

    The state is a ``(size, size, 4)`` array with one binary channel per
    object type: 0 = agent, 1 = wall, 2 = goal, 3 = enemy.  ``size`` must be
    at least 3 because the wall sampler draws from ``[1, size - 1)``.

    Dynamics and rewards (see ``step`` / ``get_reward_done``):
      * actions 0/1/2/3 move the agent +y/-y/-x/+x; moves off the grid are
        clamped and moves into a wall leave the agent in place;
      * each enemy has a velocity of +-1 on both axes and bounces off walls
        and the grid border;
      * landing on a goal gives +1 and removes that goal;
      * sharing a cell with an enemy gives -1 and ends the episode;
      * otherwise the episode ends with reward 0 after 200 steps.
    """

    def __init__(self, size):
        self.size = size
        self.attention_size = 5
        self.map = np.zeros((self.size, self.size, 4))
        self.action_space = gym.spaces.Discrete(4)
        # NOTE(review): observations are float64 (np.zeros default) although the
        # space declares int32 -- confirm which dtype is intended.
        self.observation_space = gym.spaces.Box(0, 1, (self.size, self.size, 4), dtype=np.int32)

    def _random_free_cell(self):
        """Return a uniformly random ``(x, y)`` cell on which no channel is set."""
        # randint's upper bound is exclusive, so ``self.size`` makes the last
        # row/column reachable.
        x, y = np.random.randint(0, self.size, 2)
        while np.any(self.map[x, y]):
            x, y = np.random.randint(0, self.size, 2)
        return x, y

    def reset(self):
        """Sample a new layout and return the initial observation.

        Places one straight wall segment touching the grid edge, then the
        agent, 1-4 goals and 0-2 enemies, each on a previously empty cell.
        """
        self.ep_step = 0
        self.map = np.zeros((self.size, self.size, 4))
        # place wall: either x = wall_x for y in [0, wall_height)
        # or y = wall_y for x in [0, wall_height)
        if np.random.random() < 0.5:
            wall_x = np.random.randint(1, self.size - 1)
            wall_height = np.random.randint(1, self.size - 1)
            self.map[wall_x, 0:wall_height, 1] = 1  # wall on channel 1
        else:
            wall_y = np.random.randint(1, self.size - 1)
            wall_height = np.random.randint(1, self.size - 1)
            self.map[0:wall_height, wall_y, 1] = 1  # wall on channel 1
        # place agent (channel 0) off the wall
        self.agent_x, self.agent_y = self._random_free_cell()
        self.map[self.agent_x, self.agent_y, 0] = 1
        # place 1-4 goals (channel 2) on free cells
        n_goals = np.random.randint(1, 5)
        for _ in range(n_goals):
            goal_x, goal_y = self._random_free_cell()
            self.map[goal_x, goal_y, 2] = 1
        # place 0-2 enemies (channel 3) on free cells, each with a random
        # diagonal velocity
        n_enemy = np.random.randint(0, 3)
        self.enemies = []
        for _ in range(n_enemy):
            enemy_x, enemy_y = self._random_free_cell()
            self.map[enemy_x, enemy_y, 3] = 1
            enemy_vel = [np.random.choice([-1, 1]), np.random.choice([-1, 1])]
            self.enemies.append(([enemy_x, enemy_y], enemy_vel))
        return self.get_obs()

    def clip_attention(self, x, y):
        """Clamp ``(x, y)`` so an ``attention_size`` window centred there stays on the grid."""
        return np.clip(x, self.attention_size//2, self.size - 1 - self.attention_size//2), \
               np.clip(y, self.attention_size//2, self.size - 1 - self.attention_size//2)

    def get_obs(self):
        """Return a copy of the full state array.

        A copy is returned so observations kept by the caller are not changed
        by later steps.
        """
        return self.map.copy()

    def render(self, map=None):
        """Return an RGB ``uint8`` image of ``map`` (default: the current state).

        Each cell becomes a 40x40 pixel block, and the image is rotated with
        ``np.rot90(..., k=1)`` before it is returned.  ``map`` may hold
        non-binary scores, such as a model's prediction.  Each cell keeps only
        its highest channel, and ties go to the lowest channel index.  A cell
        is drawn empty (white) if that value is <= 0.5.
        """
        if map is None:
            map = self.map
            size = self.size
        else:
            size = map.shape[0]
        # one-hot encode the argmax channel of every cell
        # (see https://gist.github.com/frnsys/91a69f9f552cbeee7b565b3149f29e3e)
        map_onehot = np.zeros_like(map)
        indices = np.argmax(map, axis=-1)
        rows = np.arange(size).reshape((size, 1))
        cols = np.tile(np.arange(size), (size, 1))
        map_onehot[rows, cols, indices] = 1
        # a cell whose strongest channel is <= 0.5 counts as empty
        map_onehot *= (map > 0.5)
        map_image = 255 * np.ones((size, size, 3))
        # (channel, RGB colour).  After the one-hot step at most one channel is
        # set per cell, so the drawing order below never causes overlap.
        palette = (
            (1, (0, 0, 0)),      # walls are black
            (0, (0, 0, 255)),    # agent is blue
            (3, (255, 0, 0)),    # enemy is red
            (2, (0, 255, 0)),    # goal is green
        )
        for channel, colour in palette:
            map_image[map_onehot[:, :, channel] != 0] = colour
        # scale every cell up to 40x40 pixels
        map_image = np.repeat(map_image, 40, axis=0)
        map_image = np.repeat(map_image, 40, axis=1)
        return np.rot90(map_image.astype(np.uint8), k=1)

    def get_reward_done(self):
        """Return ``(reward, done)`` for the agent's current cell and step count."""
        # first check if we hit an enemy
        if self.map[self.agent_x, self.agent_y, 3]:
            return -1, True
        # now check if we hit a goal
        elif self.map[self.agent_x, self.agent_y, 2]:
            return 1, False
        # end due to timesteps
        elif self.ep_step >= 200:
            return 0, True
        else:
            return 0, False

    def _blocked(self, x, y):
        """Return True if ``(x, y)`` lies off the grid or contains a wall."""
        if not (0 <= x < self.size and 0 <= y < self.size):
            return True
        return bool(self.map[x, y, 1])

    def _move_enemies(self):
        """Move every enemy one step, bouncing off walls and the border; redraw channel 3."""
        # Clear all enemy markers up front and redraw each one after it moves.
        # Clearing per enemy would let an enemy that leaves a cell erase
        # another enemy that has just moved into it.
        self.map[:, :, 3] = 0
        for i, ((enemy_x, enemy_y), enemy_vel) in enumerate(self.enemies):
            new_x = enemy_x + enemy_vel[0]
            collision_x = self._blocked(new_x, enemy_y)
            if collision_x:
                enemy_vel[0] *= -1
                new_x = enemy_x + enemy_vel[0]
                if self._blocked(new_x, enemy_y):
                    new_x = enemy_x  # blocked on both sides along x: don't move in x
            new_y = enemy_y + enemy_vel[1]
            collision_y = self._blocked(enemy_x, new_y)
            if collision_y:
                enemy_vel[1] *= -1
                new_y = enemy_y + enemy_vel[1]
                if self._blocked(enemy_x, new_y):
                    new_y = enemy_y  # blocked on both sides along y: don't move in y
            # Each axis move is legal on its own, but the diagonal target can
            # still be a wall, for example at a wall's corner.  In that case
            # stay put and reverse every axis that has not already bounced.
            if self._blocked(new_x, new_y):
                new_x, new_y = enemy_x, enemy_y
                if not collision_x:
                    enemy_vel[0] *= -1
                if not collision_y:
                    enemy_vel[1] *= -1
            self.map[new_x, new_y, 3] = 1
            self.enemies[i] = ([new_x, new_y], enemy_vel)

    def step(self, action):
        """Advance the environment by one step.

        Moves the agent according to ``action`` (0 = +y, 1 = -y, 2 = -x,
        3 = +x), then moves every enemy, then scores the new state.  A goal
        the agent lands on is removed after scoring.

        Returns:
            ``(obs, reward, done, None)``.

        Raises:
            ValueError: if ``action`` is not one of 0-3.
        """
        if action == 0:  # up
            new_x, new_y = self.agent_x, min(self.agent_y + 1, self.size - 1)
        elif action == 1:  # down
            new_x, new_y = self.agent_x, max(self.agent_y - 1, 0)
        elif action == 2:  # left
            new_x, new_y = max(self.agent_x - 1, 0), self.agent_y
        elif action == 3:  # right
            new_x, new_y = min(self.agent_x + 1, self.size - 1), self.agent_y
        else:
            raise ValueError("action not recognized")
        # walls block the agent: stay where we are
        if self.map[new_x, new_y, 1]:
            new_x, new_y = self.agent_x, self.agent_y
        self.map[self.agent_x, self.agent_y, 0] = 0
        self.agent_x, self.agent_y = new_x, new_y
        self.map[self.agent_x, self.agent_y, 0] = 1
        self._move_enemies()
        self.ep_step += 1
        r, done = self.get_reward_done()
        # a collected goal disappears (no-op if the cell had none)
        self.map[self.agent_x, self.agent_y, 2] = 0
        return self.get_obs(), r, done, None

# import os
# import torch
# env = GoalSearchSimple(10)
# done = True
# ep = -1
# total_steps = 0
# savedir = '../trainingdata-GoalSearch-v2/'
# while total_steps < 20000:
#     if done:
#         obs = env.reset()
#         ep += 1
#         step = 0
#         actions = []
#         rewards = []
#         os.makedirs(savedir + str(ep))
#         torch.save(torch.from_numpy(obs.transpose(2,0,1)),
#                    open(savedir + '{}/{}.pt'.format(ep, step), 'wb'))
#     actions.append(np.random.randint(4))
#     obs, r, done, _ = env.step(actions[-1])
#     rewards.append(r)
#     step += 1
#     torch.save(torch.from_numpy(obs.transpose(2,0,1)),
#                open(savedir + '{}/{}.pt'.format(ep, step), 'wb'))
#     torch.save(torch.LongTensor(actions),
#                open(savedir + '{}/actions.pt'.format(ep), 'wb'))
#     torch.save(torch.FloatTensor(rewards),
#                open(savedir + '{}/rewards.pt'.format(ep), 'wb'))
#     total_steps += 1


class GoalSearchEnv(object):
    """Two-goal grid world: reach the goal named by a blinking cue and avoid the enemy.

    The state is a ``(size, size, 6)`` array:
      0 / 1 = goal cue.  It sits at cell ``(1, size - 1)`` on channel
              ``self.goal`` (0: the left goal pays off, 1: the right goal pays
              off) and is present only on even steps (see ``get_obs``).
      2 = both goals, 3 = wall, 4 = agent, 5 = enemy.
    A wall segment stands at ``x = wall_x`` for ``y`` in ``[0, wall_height)``.
    One goal is placed left of it and one right of it, both below
    ``wall_height``.  The agent starts at ``(0, size - 1)``.  The enemy moves
    horizontally and bounces off the border and off anything on channels 0-3.
    ``size`` must be at least 3 because the wall sampler draws from
    ``[1, size - 1)``.

    Rewards: +1 for the cued goal and -1 for the other goal or for meeting the
    enemy; all three end the episode.  Otherwise the reward is 0 and the
    episode ends on a step-count limit (see ``get_reward_done``).
    """

    def __init__(self, size):
        self.size = size
        self.attention_size = 5
        self.map = np.zeros((self.size, self.size, 6))
        self.action_space = gym.spaces.Discrete(4)
        # NOTE(review): observations are float64 (np.zeros default) although the
        # space declares int32 -- confirm which dtype is intended.
        self.observation_space = gym.spaces.Box(0, 1, (self.size, self.size, 6), dtype=np.int32)

    def reset(self):
        """Sample a new layout and return the initial observation."""
        self.ep_step = 0
        self.map = np.zeros((self.size, self.size, 6))
        # place wall (channel 3) at x = wall_x for y in [0, wall_height)
        wall_x = np.random.randint(1, self.size - 1)
        wall_height = np.random.randint(1, self.size - 1)
        self.map[wall_x, 0:wall_height, 3] = 1
        # place agent (channel 4) in the top-left corner
        self.agent_x = 0
        self.agent_y = self.size - 1
        self.map[self.agent_x, self.agent_y, 4] = 1
        self.attention_x, self.attention_y = self.clip_attention(self.agent_x, self.agent_y)
        # pick which goal pays off; the cue goes on channel 0 or 1
        self.goal = np.random.randint(2)
        self.map[1, self.size - 1, self.goal] = 1
        # one goal on each side of the wall, both on channel 2
        self.left_goal_x = np.random.randint(wall_x)
        self.left_goal_y = np.random.randint(wall_height)
        self.map[self.left_goal_x, self.left_goal_y, 2] = 1
        self.right_goal_x = np.random.randint(wall_x + 1, self.size)
        self.right_goal_y = np.random.randint(wall_height)
        self.map[self.right_goal_x, self.right_goal_y, 2] = 1
        # spawn the enemy (channel 5) on a cell that holds nothing else;
        # keep resampling until such a cell is found
        self.enemy_x, self.enemy_y = np.random.randint(0, self.size, size=(2,))
        while np.any(self.map[self.enemy_x, self.enemy_y]):
            self.enemy_x, self.enemy_y = np.random.randint(0, self.size, size=(2,))
        self.map[self.enemy_x, self.enemy_y, 5] = 1
        self.enemy_vel = [np.random.choice([-1, 1]), 0]
        return self.get_obs()

    def clip_attention(self, x, y):
        """Clamp ``(x, y)`` so an ``attention_size`` window centred there stays on the grid."""
        return np.clip(x, self.attention_size//2, self.size - 1 - self.attention_size//2), \
               np.clip(y, self.attention_size//2, self.size - 1 - self.attention_size//2)

    def get_obs(self):
        """Return a copy of the full state, after toggling the goal cue by step parity.

        The cue at ``(1, size - 1)`` is written into ``self.map`` on even
        ``ep_step`` values and erased on odd ones.  A copy is returned so
        observations kept by the caller are not changed by later steps.
        """
        # attention crop, currently disabled:
        # return self.map[(self.attention_x - self.attention_size//2):(self.attention_x + 1 + self.attention_size//2),
        #        (self.attention_y - self.attention_size//2):(self.attention_y + 1 + self.attention_size//2), :]
        if self.ep_step % 2 == 0:
            self.map[1, self.size - 1, self.goal] = 1
        else:
            self.map[1, self.size - 1, self.goal] = 0
        return self.map.copy()

    def render(self, map=None):
        """Return an RGB ``uint8`` image of ``map`` (default: the current state).

        Each cell becomes a 40x40 pixel block, and the image is rotated with
        ``np.rot90(..., k=1)`` before it is returned.  Each cell keeps only
        its highest channel, and ties go to the lowest channel index.  A cell
        is drawn empty (white) if that value is <= 0.5.
        """
        if map is None:
            map = self.map
            size = self.size
        else:
            size = map.shape[0]
        # one-hot encode the argmax channel of every cell
        # (see https://gist.github.com/frnsys/91a69f9f552cbeee7b565b3149f29e3e)
        map_onehot = np.zeros_like(map)
        indices = np.argmax(map, axis=-1)
        rows = np.arange(size).reshape((size, 1))
        cols = np.tile(np.arange(size), (size, 1))
        map_onehot[rows, cols, indices] = 1
        # a cell whose strongest channel is <= 0.5 counts as empty
        map_onehot *= (map > 0.5)
        map_image = 255 * np.ones((size, size, 3))
        # (channel, RGB colour); at most one channel is set per cell here
        palette = (
            (0, (255, 0, 255)),  # cue "left goal pays off" is pink
            (1, (0, 0, 255)),    # cue "right goal pays off" is blue
            (2, (0, 255, 0)),    # goals are green
            (3, (0, 0, 0)),      # walls are black
            (4, (255, 255, 0)),  # agent is yellow
            (5, (255, 0, 0)),    # enemy is red
        )
        for channel, colour in palette:
            map_image[map_onehot[:, :, channel] != 0] = colour
        # scale every cell up to 40x40 pixels
        map_image = np.repeat(map_image, 40, axis=0)
        map_image = np.repeat(map_image, 40, axis=1)
        return np.rot90(map_image.astype(np.uint8), k=1)

    def get_reward_done(self):
        """Return ``(reward, done)`` for the agent's current cell and step count.

        NOTE(review): ``step`` calls this before incrementing ``ep_step``, so
        episodes run up to 201 steps (GoalSearchSimple stops at 200) -- confirm
        which limit is intended.
        """
        # first check if we hit an enemy
        if self.map[self.agent_x, self.agent_y, 5]:
            return -1, True
        if (self.agent_x == self.left_goal_x) and (self.agent_y == self.left_goal_y):
            if self.goal == 0:
                return 1, True
            else:
                return -1, True
        if (self.agent_x == self.right_goal_x) and (self.agent_y == self.right_goal_y):
            if self.goal == 1:
                return 1, True
            else:
                return -1, True
        # end due to timesteps
        if self.ep_step < 200:
            return 0, False
        else:
            return 0, True

    def _enemy_blocked(self, x, y):
        """Return True if ``(x, y)`` is off the grid or holds a cue, goal or wall (channels 0-3)."""
        if not (0 <= x < self.size and 0 <= y < self.size):
            return True
        return bool(np.any(self.map[x, y, 0:4]))

    def _move_enemy(self):
        """Move the enemy one cell along its velocity, bouncing off the border and obstacles."""
        # turn around at the grid border
        if self.enemy_x == self.size - 1:
            self.enemy_vel[0] = -1
        elif self.enemy_x == 0:
            self.enemy_vel[0] = 1
        new_x = self.enemy_x + self.enemy_vel[0]
        new_y = self.enemy_y + self.enemy_vel[1]
        if self._enemy_blocked(new_x, new_y):
            # bounce back the other way ...
            self.enemy_vel[0] *= -1
            new_x = self.enemy_x + self.enemy_vel[0]
            new_y = self.enemy_y + self.enemy_vel[1]
            # ... unless that side is blocked too (border or obstacle): stay put
            if self._enemy_blocked(new_x, new_y):
                new_x, new_y = self.enemy_x, self.enemy_y
        self.map[self.enemy_x, self.enemy_y, 5] = 0
        self.enemy_x, self.enemy_y = new_x, new_y
        self.map[self.enemy_x, self.enemy_y, 5] = 1

    def step(self, action):
        """Advance one step: move the agent, then the enemy, then score.

        Actions: 0 = +y, 1 = -y, 2 = -x, 3 = +x.  Moves off the grid are
        clamped, and moves into a wall leave the agent in place.

        Returns:
            ``(obs, reward, done, None)``.

        Raises:
            ValueError: if ``action`` is not one of 0-3.
        """
        if action == 0:  # up
            new_x, new_y = self.agent_x, min(self.agent_y + 1, self.size - 1)
        elif action == 1:  # down
            new_x, new_y = self.agent_x, max(self.agent_y - 1, 0)
        elif action == 2:  # left
            new_x, new_y = max(self.agent_x - 1, 0), self.agent_y
        elif action == 3:  # right
            new_x, new_y = min(self.agent_x + 1, self.size - 1), self.agent_y
        else:
            raise ValueError("action not recognized")
        # walls block the agent: stay where we are
        if self.map[new_x, new_y, 3]:
            new_x, new_y = self.agent_x, self.agent_y
        self.map[self.agent_x, self.agent_y, 4] = 0
        self.agent_x, self.agent_y = new_x, new_y
        self.map[self.agent_x, self.agent_y, 4] = 1
        self._move_enemy()
        r, done = self.get_reward_done()
        # attention (for now) moves to a random location; get_obs does not use it
        self.attention_x, self.attention_y = self.clip_attention(
            np.random.randint(self.size), np.random.randint(self.size))
        self.ep_step += 1
        return self.get_obs(), r, done, None
