import collections
import itertools
import os.path
import tkinter as tk

import gym
import gym.envs.registration
import gym.spaces

import numpy as np


from expground.types import Dict, AgentID, Any, List
from expground.utils.data import EpisodeKeys
from expground.envs import Environment


# Action indices understood by ``_GatheringEnv``. Indices 0-3 double as the
# agent orientations (an agent starts facing UP); movement actions are
# interpreted relative to that orientation.
(
    UP,
    RIGHT,
    DOWN,
    LEFT,
    ROTATE_RIGHT,
    ROTATE_LEFT,
    LASER,
    NOOP,
) = range(8)


# Directory of this module; bundled maps are looked up in ``BASE_DIR/maps``.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


class _GatheringEnv(gym.Env):
    """Grid-world "Gathering" game: agents collect apples and tag rivals.

    The map is read from a text file (``O`` = apple, ``#`` = wall). Every
    step each agent picks one of eight actions (``UP`` ... ``NOOP``).
    Stepping onto an apple gives reward 1 and the apple regrows 16 steps
    later. ``LASER`` fires a beam 5 cells wide and 20 deep in front of the
    agent; an untagged agent standing in a beam is tagged for 25 steps.
    While tagged, it is removed from the board. It is then moved to a free
    spawn point, if one exists.

    Grids are numpy arrays indexed ``[x, y]`` with ``y`` growing downwards.
    """

    metadata = {"render.modes": ["human"]}
    scale = 20  # pixels per grid cell when rendering
    viewbox_width = 10  # observation window size across the agent's facing
    viewbox_depth = 20  # observation window size along the agent's facing
    # Empty border added around the map so that the observation viewbox of an
    # agent near the map edge still fits inside the arrays.
    padding = max(viewbox_width // 2, viewbox_depth - 1)
    agent_colors = ["red", "yellow"]  # cycled when there are more agents
    # Set to True to overlay agent 0's observation in the top-left corner of
    # the rendered window.
    debug_viewbox = False
    # Global (dx, dy) step for facing UP / RIGHT / DOWN / LEFT (indices 0-3).
    _moves = ((0, -1), (1, 0), (0, 1), (-1, 0))

    def _text_to_map(self, text):
        """Parse an ASCII map into padded apple and wall grids.

        The text is transposed so the arrays are indexed ``[x, y]``. It is then
        padded with ``padding + 1`` empty cells on every side. This sets
        ``self.initial_food`` and ``self.walls`` (integer arrays, 1 = present).

        Raises:
            ValueError: if the map is empty or its rows differ in length.
        """
        rows = [list(row) for row in text.splitlines()]
        if not rows:
            raise ValueError("the map is empty")
        row_length = len(rows[0])
        if any(len(row) != row_length for row in rows):
            raise ValueError("the rows in the map are not all the same length")

        def pad(mask):
            return np.pad(mask, self.padding + 1, "constant")

        grid = np.array(rows).T
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # documented replacement and selects the same integer dtype.
        self.initial_food = pad(grid == "O").astype(int)
        self.walls = pad(grid == "#").astype(int)

    def __init__(self, n_agents=1, map_name="default"):
        """Load a map and reset the game.

        Args:
            n_agents: number of agents. They spawn side by side along the top
                row of the map, starting at its left edge.
            map_name: path to a map file, or the name of a bundled map that
                resolves to ``BASE_DIR/maps/<map_name>.txt``.

        Raises:
            ValueError: if the map cannot be found or is malformed.
        """
        self.n_agents = n_agents
        self.root = None
        if not os.path.exists(map_name):
            expanded = os.path.join(BASE_DIR, "maps", map_name + ".txt")
            if not os.path.exists(expanded):
                raise ValueError("map not found: " + map_name)
            map_name = expanded
        with open(map_name) as f:
            self._text_to_map(f.read().strip())
        self.width = self.initial_food.shape[0]
        self.height = self.initial_food.shape[1]
        # One value per viewbox cell and per channel (4 channels).
        self.state_size = self.viewbox_width * self.viewbox_depth * 4
        self.action_space = gym.spaces.MultiDiscrete([8] * n_agents)
        self.observation_space = gym.spaces.MultiDiscrete(
            [[2] * self.state_size] * n_agents
        )
        self._spec = gym.envs.registration.EnvSpec(**_spec)
        self.reset()
        self.done = False

    def reset(self) -> Any:
        """Reset the game and return the initial ``state_n`` observations."""
        return self._reset()

    def step(self, action_n):
        """Public alias of ``_step`` so the env works with any gym version."""
        return self._step(action_n)

    def _step(self, action_n):
        """Advance the game by one tick.

        Args:
            action_n: one action index per agent. Actions of tagged agents are
                replaced with ``NOOP``.

        Returns:
            ``(obs_n, reward_n, done_n, info_n)``, one entry per agent.
            ``obs_n`` is taken after movement, rotation and firing, but before
            apples are eaten and tags are applied.
        """
        assert len(action_n) == self.n_agents
        action_n = [NOOP if self.tagged[i] else a for i, a in enumerate(action_n)]
        self.beams[:] = 0

        # 1. Movement. Movement actions are relative to the agent's facing,
        #    so add the orientation to get the global direction.
        movement_n = [(0, 0)] * len(action_n)
        for i, (a, orientation) in enumerate(zip(action_n, self.orientations)):
            if a in (UP, RIGHT, DOWN, LEFT):
                movement_n[i] = self._moves[(a + orientation) % 4]

        # An agent blocked by a wall stays put. All agents that would end up
        # on the same cell stay where they were. Tagged agents neither move
        # nor claim a cell.
        next_locations = list(self.agents)
        claims = collections.defaultdict(list)
        for i, ((dx, dy), (x, y)) in enumerate(zip(movement_n, self.agents)):
            if self.tagged[i]:
                continue
            target = (x + dx, y + dy)
            if self.walls[target]:
                target = (x, y)
            next_locations[i] = target
            claims[target].append(i)
        for claimants in claims.values():
            if len(claimants) > 1:
                for i in claimants:
                    next_locations[i] = self.agents[i]
        self.agents = next_locations

        # 2. Rotation and firing (beams start from the post-move position).
        for i, act in enumerate(action_n):
            if act == ROTATE_RIGHT:
                self.orientations[i] = (self.orientations[i] + 1) % 4
            elif act == ROTATE_LEFT:
                self.orientations[i] = (self.orientations[i] - 1) % 4
            elif act == LASER:
                self.beams[self._viewbox_slice(i, 5, 20, offset=1)] = 1

        obs_n = self.state_n
        reward_n = [0] * self.n_agents
        done_n = [self.done] * self.n_agents
        # One fresh dict per agent: ``[{}] * n`` would alias a single dict, so
        # writing into one agent's info would silently change all of them.
        info_n = [{} for _ in range(self.n_agents)]

        # 3. Apple regrowth: eaten cells count up from -15 back to 1.
        self.food = (self.food + self.initial_food).clip(max=1)

        # 4. Eating and tagging for agents that are still on the board.
        for i, a in enumerate(self.agents):
            if self.tagged[i]:
                continue
            if self.food[a] == 1:
                self.food[a] = -15
                reward_n[i] = 1
            if self.beams[a]:
                self.tagged[i] = 25

        # 5. Agents whose tag expires this step move to a free spawn point.
        for i, tag in enumerate(self.tagged):
            if tag == 1:
                for spawn_point in self.spawn_points:
                    if spawn_point not in self.agents:
                        self.agents[i] = spawn_point
                        break

        self.tagged = [max(t - 1, 0) for t in self.tagged]

        return obs_n, reward_n, done_n, info_n

    def _viewbox_slice(self, agent_index, width, depth, offset=0):
        """Return ``(x_slice, y_slice)`` selecting a window in front of an agent.

        The window is ``width`` cells across the agent's facing and ``depth``
        cells along it, starting ``offset`` cells ahead of the agent. The
        slices are in array order ``(x, y)``. For UP/DOWN the x slice spans
        the width and the y slice the depth; for LEFT/RIGHT it is the other
        way round. The width slice runs from the agent's left to its right,
        and the depth slice runs outward from the agent.
        """
        left = width // 2
        right = left if width % 2 == 0 else left + 1
        x, y = self.agents[agent_index]

        def span(start, stop, step=None):
            # A reversed span that should reach index 0 gets a negative stop.
            # numpy reads that as counted from the end, which yields an empty
            # slice; ``None`` keeps index 0 included instead.
            if step is not None and step < 0 and stop < 0:
                stop = None
            return slice(start, stop, step)

        orientation = self.orientations[agent_index]
        if orientation == UP:
            return span(x - left, x + right), span(y - offset, y - offset - depth, -1)
        if orientation == RIGHT:
            return span(x + offset, x + offset + depth), span(y - left, y + right)
        if orientation == DOWN:
            return span(x + left, x - right, -1), span(y + offset, y + offset + depth)
        return span(x - offset, x - offset - depth, -1), span(y + left, y - right, -1)

    @property
    def state_n(self):
        """Per-agent observations with shape ``(n_agents, state_size)``.

        Each row is a flattened ``(viewbox_width, viewbox_depth, 4)`` window
        in the agent's own frame. The channels are:

        - apples present (regrowing cells read 0);
        - an all-zero channel;
        - other untagged agents (the observer itself is cleared);
        - walls.

        Tagged agents observe all zeros.
        """
        agents = np.zeros_like(self.food)
        for i, a in enumerate(self.agents):
            if not self.tagged[i]:
                agents[a] = 1

        food = self.food.clip(min=0)
        s = np.zeros((self.n_agents, self.viewbox_width, self.viewbox_depth, 4))
        for i, (orientation, (x, y)) in enumerate(zip(self.orientations, self.agents)):
            if self.tagged[i]:
                continue
            full_state = np.stack(
                [food, np.zeros_like(food), agents, self.walls], axis=-1
            )
            # Hide the observer from its own "agents" channel.
            full_state[x, y, 2] = 0

            xs, ys = self._viewbox_slice(i, self.viewbox_width, self.viewbox_depth)
            observation = full_state[xs, ys, :]

            # LEFT/RIGHT windows come back as (depth, width); transpose them so
            # every agent gets the same (width, depth) layout.
            s[i] = (
                observation
                if orientation in [UP, DOWN]
                else observation.transpose(1, 0, 2)
            )

        return s.reshape((self.n_agents, self.state_size))

    def _reset(self):
        """Restore apples, agents, orientations and tags; return ``state_n``.

        ``self.done`` is not touched here.
        """
        self.food = self.initial_food.copy()

        # Draw a ring of walls just outside the map so agents cannot leave it.
        p = self.padding
        self.walls[p:-p, p] = 1
        self.walls[p:-p, -p - 1] = 1
        self.walls[p, p:-p] = 1
        self.walls[-p - 1, p:-p] = 1

        self.beams = np.zeros_like(self.food)

        self.agents = [
            (i + self.padding + 1, self.padding + 1) for i in range(self.n_agents)
        ]
        self.spawn_points = list(self.agents)
        self.orientations = [UP for _ in self.agents]
        self.tagged = [0 for _ in self.agents]

        return self.state_n

    def _close_view(self):
        """Destroy the render window, if any, and mark the game as done."""
        # ``getattr`` keeps this safe when called on a half-built instance.
        root = getattr(self, "root", None)
        if root:
            root.destroy()
            self.root = None
            self.canvas = None
        self.done = True

    def render(self, mode="human", close=False):
        """Public alias of ``_render`` so the env works with any gym version."""
        return self._render(mode=mode, close=close)

    def _render(self, mode="human", close=False):
        """Draw the board in a Tk window.

        Args:
            mode: unused; only human rendering is implemented.
            close: if true, close the window instead of drawing.
        """
        if close:
            self._close_view()
            return

        canvas_width = self.width * self.scale
        canvas_height = self.height * self.scale

        if self.root is None:
            self.root = tk.Tk()
            self.root.title("Gathering")
            self.root.protocol("WM_DELETE_WINDOW", self._close_view)
            self.canvas = tk.Canvas(self.root, width=canvas_width, height=canvas_height)
            self.canvas.pack()

        self.canvas.delete(tk.ALL)
        self.canvas.create_rectangle(0, 0, canvas_width, canvas_height, fill="black")

        def fill_cell(x, y, color):
            self.canvas.create_rectangle(
                x * self.scale,
                y * self.scale,
                (x + 1) * self.scale,
                (y + 1) * self.scale,
                fill=color,
            )

        # Later fills paint over earlier ones: walls over apples over beams.
        for x in range(self.width):
            for y in range(self.height):
                if self.beams[x, y] == 1:
                    fill_cell(x, y, "yellow")
                if self.food[x, y] == 1:
                    fill_cell(x, y, "green")
                if self.walls[x, y] == 1:
                    fill_cell(x, y, "grey")

        for i, (x, y) in enumerate(self.agents):
            if not self.tagged[i]:
                # Cycle the palette so more agents than colours do not crash.
                fill_cell(x, y, self.agent_colors[i % len(self.agent_colors)])

        if self.debug_viewbox:
            # Debug view: draw the first player's observation window.
            p1_state = self.state_n[0].reshape(
                self.viewbox_width, self.viewbox_depth, 4
            )
            for x in range(self.viewbox_width):
                for y in range(self.viewbox_depth):
                    food, me, other, wall = p1_state[x, y]
                    assert sum((food, me, other, wall)) <= 1
                    y_ = self.viewbox_depth - y - 1
                    if food:
                        fill_cell(x, y_, "green")
                    elif me:
                        fill_cell(x, y_, "cyan")
                    elif other:
                        fill_cell(x, y_, "red")
                    elif wall:
                        fill_cell(x, y_, "gray")
            self.canvas.create_rectangle(
                0,
                0,
                self.viewbox_width * self.scale,
                self.viewbox_depth * self.scale,
                outline="blue",
            )

        self.root.update()

    def close(self):
        """Close the render window.

        This is defined explicitly because some gym versions make
        ``Env.close`` a no-op, and ``__del__`` relies on it.
        """
        self._close()

    def _close(self):
        """Close the render window and mark the game as done."""
        self._close_view()

    def __del__(self):
        self.close()


# Registration metadata for the raw gym environment. ``_GatheringEnv.__init__``
# expands it into a ``gym.envs.registration.EnvSpec``.
_spec = dict(
    id="Gathering-v0",
    entry_point=_GatheringEnv,
    reward_threshold=100,
)


class Gathering(Environment):
    """Expose ``_GatheringEnv`` through expground's per-agent dict interface.

    Agents are named ``agent_0`` ... ``agent_{n-1}``. They all share one
    discrete action space and one flat observation box. ``configs`` must
    provide ``env_id`` and ``scenario_config``; the latter is forwarded as
    keyword arguments to ``_GatheringEnv``.
    """

    def __init__(self, **configs):
        super().__init__(**configs)

        # Only looked up so that a missing ``env_id`` still raises KeyError.
        self._configs["env_id"]
        self.is_sequential = False
        self._env = _GatheringEnv(**self._configs["scenario_config"])

        n_agents = self._env.n_agents
        joint_actions = self._env.action_space
        joint_observations = self._env.observation_space

        # Every agent has the same number of actions. Index 0 always exists,
        # whereas index 1 does not when there is a single agent.
        action_space = gym.spaces.Discrete(int(joint_actions.nvec[0]))
        # MultiDiscrete values lie in ``[0, nvec - 1]``. Take the bounds from
        # that rather than from random samples, so they are deterministic.
        # NOTE(review): ``state_n`` produces float arrays while this box keeps
        # the joint space's dtype, as before — confirm consumers cope.
        observation_space = gym.spaces.Box(
            low=0,
            high=int(joint_observations.nvec.max()) - 1,
            shape=joint_observations.nvec[0].shape,
            dtype=joint_observations.dtype,
        )
        self._possible_agents = ["agent_{}".format(i) for i in range(n_agents)]
        self._action_spaces = dict.fromkeys(self._possible_agents, action_space)
        self._observation_spaces = dict.fromkeys(
            self._possible_agents, observation_space
        )
        self._trainable_agents = self._possible_agents

    @property
    def possible_agents(self):
        """Agent ids in the order used by the underlying env."""
        return self._possible_agents

    @property
    def action_spaces(self) -> Dict[AgentID, gym.Space]:
        """Per-agent action spaces (all identical)."""
        return self._action_spaces

    @property
    def observation_spaces(self):
        """Per-agent observation spaces (all identical)."""
        return self._observation_spaces

    def step(self, actions: Dict[AgentID, Any]):
        """Step the game with one action per agent id.

        Returns:
            A dict keyed by ``EpisodeKeys`` values (observation, reward, done,
            info). Each maps agent ids to that agent's entry.
        """
        joint_action = [actions[aid] for aid in self.possible_agents]
        obs_n, reward_n, done_n, info_n = self._env._step(joint_action)
        obs_dict = dict(zip(self.possible_agents, obs_n))
        reward_dict = dict(zip(self.possible_agents, reward_n))
        done_dict = dict(zip(self.possible_agents, done_n))
        info_dict = dict(zip(self.possible_agents, info_n))

        return {
            EpisodeKeys.OBSERVATION.value: obs_dict,
            EpisodeKeys.REWARD.value: reward_dict,
            EpisodeKeys.DONE.value: done_dict,
            EpisodeKeys.INFO.value: info_dict,
        }

    def reset(self, **kwargs):
        """Reset the game; ``kwargs`` are accepted but ignored."""
        obs_n = self._env.reset()
        return {EpisodeKeys.OBSERVATION.value: dict(zip(self.possible_agents, obs_n))}


if __name__ == "__main__":
    # Smoke test: drive two agents with uniformly random actions forever and
    # reset whenever any agent reports done.
    env = Gathering(
        env_id="gathering", scenario_config={"n_agents": 2, "map_name": "default"}
    )

    action_spaces = env.action_spaces
    print(action_spaces, env.observation_spaces)
    rets = env.reset()

    while True:
        joint_action = {agent: space.sample() for agent, space in action_spaces.items()}
        rets = env.step(joint_action)
        if any(rets[EpisodeKeys.DONE.value].values()):
            rets = env.reset()
