import random
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch as t

from .base_simulator import BaseSimulator


class Navigation(BaseSimulator):
    """Grid-world navigation task on an ``N`` x ``N`` map.

    The map (see ``walls``) has a solid outer border and a walled "center
    hall" whose wall ring lies on rows/columns 5 and 14.  On every
    ``reset`` exactly ``n_exits`` single-cell openings are cut into that
    ring, the robot is placed inside the hall and the goal outside of it.

    The state is the tuple ``(grid, goal, robot)`` where ``grid`` is the
    flattened wall map (1 = wall, 0 = free) and ``goal`` / ``robot`` are
    flat cell indices ``row * N + col``.
    """

    name = "Navigation"

    N = 20  # side length of the square grid
    n_actions = 4
    n_exits = 2  # number of hall sides that get an opening (0..4)
    max_steps = 100  # an episode is reported done after this many steps
    failed = -1000.0  # reward returned when the robot runs into a wall

    # Index i is the meaning of action i in ``step``.
    ACTIONS = ["EAST", "SOUTH", "WEST", "NORTH"]

    walls = np.asarray(
        [
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        ],
    )

    def reset(self):
        """Sample a new layout, goal and start position and return the state.

        Returns:
            The new state tuple ``(grid, goal, robot)``.

        Raises:
            ValueError: if ``n_exits`` is not an integer value in 0..4.  At
                most one cell per hall side is ever opened, so any other
                value could never be satisfied by the sampling loop below.
        """
        if self.n_exits not in range(5):
            raise ValueError(
                f"n_exits must be an integer between 0 and 4, got {self.n_exits!r}"
            )

        # Rows/columns that carry the wall ring of the center hall.
        lo, hi = 5, 14
        # Cells of each side that may be opened (corners excluded): 6..13.
        side = slice(lo + 1, hi)
        first, last = lo + 1, hi - 1

        def ring_wall_count(grid):
            # Number of still-closed wall cells on the four hall sides.
            return (
                np.sum(grid[lo, side])
                + np.sum(grid[hi, side])
                + np.sum(grid[side, lo])
                + np.sum(grid[side, hi])
            )

        # 4 sides * 8 cells = 32 closed cells in the pristine map.
        closed_target = 4 * (hi - lo - 1) - self.n_exits

        # Rejection sampling: each side is opened with probability 0.5 at a
        # random cell; resample until exactly ``n_exits`` sides are open.
        walls = deepcopy(self.walls)
        while ring_wall_count(walls) != closed_target:
            walls = deepcopy(self.walls)
            if random.random() < 0.5:
                walls[lo, random.randint(first, last)] = 0
            if random.random() < 0.5:
                walls[hi, random.randint(first, last)] = 0
            if random.random() < 0.5:
                walls[random.randint(first, last), lo] = 0
            if random.random() < 0.5:
                walls[random.randint(first, last), hi] = 0

        # The goal must lie outside the hall (ring included); (10, 10) is
        # inside, so the loop always draws at least once.
        goal = (10, 10)
        while lo <= goal[0] <= hi and lo <= goal[1] <= hi:
            goal = (random.randint(1, 18), random.randint(1, 18))

        # The robot starts inside the hall.
        start = (random.randint(7, 12), random.randint(7, 12))

        self.state = (
            walls.reshape(-1),
            goal[0] * self.N + goal[1],
            start[0] * self.N + start[1],
        )
        self.timestep = 0
        return self.state

    def step(self, action):
        """Move the robot one cell and return ``(state, reward, done)``.

        Args:
            action: index into ``ACTIONS``; 0 moves one column right, 1 one
                row down, 2 one column left, 3 one row up.  Any other
                value is reported on stdout and leaves the robot in place.

        Returns:
            ``(state, reward, done)``.  Moving into a wall keeps the robot
            where it was, yields the ``failed`` reward and ends the
            episode.  Otherwise the reward is -1 and ``done`` is true once
            the goal is reached or ``max_steps`` steps have been taken.
        """
        self.timestep += 1
        grid, goal, robot = self.state

        # x is the row index, y the column index of the robot.
        x, y = robot // self.N, robot % self.N

        # Kept as an equality chain (rather than a lookup table) so that
        # array/tensor scalar actions compare by value.
        if action == 0:
            y = min(y + 1, self.N - 1)
        elif action == 1:
            x = min(x + 1, self.N - 1)
        elif action == 2:
            y = max(y - 1, 0)
        elif action == 3:
            x = max(x - 1, 0)
        else:
            print(f"Invalid action passed: {action}")

        robot_next = x * self.N + y

        # Collision with a wall: stay put, fail the episode.
        if grid[robot_next] == 1:
            self.state = (grid, goal, robot)
            return self.state, self.failed, True

        self.state = (grid, goal, robot_next)
        done = (robot_next == goal) or self.timestep >= self.max_steps
        return self.state, -1, done

    def render(
        self,
        as_image=True,
    ):
        """Display the current state.

        Args:
            as_image: when true, draw the grid with matplotlib (walls dark
                red, goal green, robot yellow); otherwise print an ASCII
                map to stdout using ``#`` for walls, ``X`` for the goal and
                ``@`` for the robot.
        """
        grid, goal, robot = self.state

        cells = np.array([" "] * self.N * self.N)
        cells[np.array(grid) == 1] = "#"
        cells[goal] = "X"
        # Written last so the robot hides the goal when both coincide.
        cells[robot] = "@"
        cells = cells.reshape(self.N, self.N)

        if as_image:
            image = np.zeros((self.N, self.N, 3))
            image[cells == "#", 0] = 0.5

            image[cells == "X", 1] = 0.85

            image[cells == "@", 0] = 1
            image[cells == "@", 1] = 1
            image[cells == "@", 2] = 0

            plt.clf()
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])
            plt.draw()
            plt.pause(0.0001)
            plt.show()
        else:
            for row in cells:
                # Every cell is followed by one space, including the last.
                print(" ".join(row), end=" \n")

    def is_solved(self):
        """Return whether the robot currently stands on the goal cell."""
        _, goal, robot = self.state
        return goal == robot

    def state_tensor(self):
        """Encode the current state as a float tensor of shape (1, 3, N, N).

        Channel 0 holds the wall map, channel 1 is one-hot at the goal cell
        and channel 2 is one-hot at the robot cell.
        """
        grid, goal, agent = self.state

        t_state = t.zeros(3, self.N * self.N).float()
        # Convert explicitly instead of relying on mixed torch/NumPy
        # arithmetic to coerce the integer wall map.
        t_state[0] = t.as_tensor(grid, dtype=t_state.dtype)
        t_state[1, goal] = 1
        t_state[2, agent] = 1

        return t_state.reshape(1, 3, self.N, self.N)
