import random
from typing import Sequence, Any, Tuple

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv


def add_posn(ab, cd):
    """Return the component-wise sum of two 2-element positions as a tuple."""
    (x0, y0), (x1, y1) = ab, cd
    return x0 + x1, y0 + y1


class OctothorpeGridWorld(MultiAgentSafetyEnv):
    """
    A three-agent grid world laid out in the following shape:

     B C
    A####
     # #
    #####
     # #

    A represents the goal for agent 0, B for agent 1, and C for agent 2.
    The goals are ordinary walkable cells (cell indices 2, 0 and 1 respectively,
    see the index map below). The agents themselves are randomly placed.

    Each agent has a position and a direction. It has the following cone of visibility
    (assuming the agent is facing up):

     #
    ###
     A

    The observations are as such:
    0 = Empty
    1 = Filled with another agent
    2 = Wall

    This environment generates a single AP, derived from whether any two agents have collided with each
    other (ended the step on the same cell) or crossed (swapped cells). `step` reports it as True when the
    step was free of both.
    The state space is (16 * 4)^3 = 262144, representing the position and direction of each agent.
    The observation space for each agent is 16 * 4 * 3^4 = 5184, representing the position and direction,
    along with whether each of the four spaces within its cone of visibility is empty, filled by an agent, or a wall.

    Agents have four actions.
    0 = Do nothing
    1 = Move forward
    2 = Turn left (counterclockwise)
    3 = Turn right (clockwise)

    Cell indices (written as hexadecimal digits, A=10, B=11, ...; unrelated to the goal letters above)
     0 1
    23456
     7 8
    9ABCD
     E F

    Agent directions:
    0 = Facing up
    1 = Right
    2 = Down
    3 = Left
    """

    # Goal cell index for agents 0, 1 and 2: goals A, B and C sit at grid
    # coordinates (0, 1), (1, 0) and (3, 0), i.e. cells 2, 0 and 1.
    GOAL_LOCS = (2, 0, 1)

    def __init__(self):
        """Precompute the lookup tables used by `step` and the observation projection."""
        # Create and cache a bunch of useful arrays so that we don't need to recalculate them every time step

        # (x, y) coordinate of every walkable cell, indexed by cell index; y grows downwards.
        self.grid_posns = [
            (1, 0),
            (3, 0),
            (0, 1),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 1),
            (1, 2),
            (3, 2),
            (0, 3),
            (1, 3),
            (2, 3),
            (3, 3),
            (4, 3),
            (1, 4),
            (3, 4),
        ]

        # Inverse mapping: (x, y) coordinate -> cell index.
        self.grid_posn_inv = {pos: idx for idx, pos in enumerate(self.grid_posns)}

        # loc_dir_next[direction][cell] is the cell reached by moving one step in
        # `direction` from `cell`; moving into a wall leaves the agent where it is.
        self.loc_dir_next = [
            [self.grid_posn_inv.get(add_posn(pos, offset), idx) for idx, pos in enumerate(self.grid_posns)]
            for offset in [(0, -1), (1, 0), (0, 1), (-1, 0)]]
        # Creates
        # [
        #     # 0, 1, 2, 3, 4, 5, 6,  7,  8,  9, 10, 11, 12, 13, 14, 15   <-- From here
        #     [0, 1, 2, 0, 4, 1, 6, 3, 5, 9, 7, 11, 8, 13, 10, 12],  # Facing up, what is the next state forward?
        #     [0, 1, 3, 4, 5, 6, 6, 7, 8, 10, 11, 12, 13, 13, 14, 15],  # When going right
        #     [3, 5, 2, 7, 4, 8, 6, 10, 12, 9, 14, 11, 15, 13, 14, 15],  # When going down
        #     [0, 1, 2, 2, 3, 4, 5, 7, 8, 9, 9, 10, 11, 12, 14, 15]  # When going left
        # ]

        # Coordinate offsets of the four observed cells, indexed by facing direction.
        self.obs_offsets = [
            #    f-l       f         f-r       2f
            [(-1, -1), (0, -1), (1, -1), (0, -2)],  # facing up
            [(1, -1), (1, 0), (1, 1), (2, 0)],  # right
            [(1, 1), (0, 1), (-1, 1), (0, 2)],  # down
            [(-1, 1), (-1, 0), (-1, -1), (-2, 0)]  # left
        ]

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one observation space per agent: (cell, direction, four observed cell contents)."""
        # One instance per agent so that no mutable Space object is shared between agents.
        return [spaces.MultiDiscrete([16, 4, 3, 3, 3, 3]) for _ in range(3)]

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Return one action space per agent (four discrete actions each)."""
        return [spaces.Discrete(4) for _ in range(3)]

    def state_space(self) -> spaces.Space:
        """Return the joint state space: (cell, direction) for each of the three agents."""
        return spaces.MultiDiscrete([16, 4, 16, 4, 16, 4])

    def initial_state(self):
        """
        Sample a random initial state.

        The three agents are placed on three distinct cells, each with an independently
        chosen direction.

        :return: (state, observations) where state is (loc1, dir1, loc2, dir2, loc3, dir3)
        """
        loc1, loc2, loc3 = random.sample(range(16), 3)  # without replacement: no initial overlap
        dir1, dir2, dir3 = random.choices(range(4), k=3)

        state = (loc1, dir1, loc2, dir2, loc3, dir3)
        return state, self.project_obs(state)

    def get_next_loc_dir(self, loc, dir, act):
        """
        Apply a single agent's action to its cell and direction.

        :param loc: current cell index
        :param dir: current direction (0 = up, 1 = right, 2 = down, 3 = left)
        :param act: 0 = stay, 1 = move forward, 2 = turn left; any other value turns right
        :return: (next cell index, next direction)
        """
        if act == 0:
            return loc, dir
        elif act == 1:
            return self.loc_dir_next[dir][loc], dir
        elif act == 2:
            return loc, (dir - 1) % 4
        else:
            return loc, (dir + 1) % 4

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance the environment by one time step.

        All three agents act simultaneously. Collisions and crossings are detected but the
        movement itself is not blocked.

        :param environment_state: (loc1, dir1, loc2, dir2, loc3, dir3)
        :param joint_action: one action per agent
        :return: (new_state, observations, rewards, done, safe). Each reward is 1 while the
            corresponding agent stands on its goal cell and 0 otherwise; done is True once all
            three agents are on their goals. safe is True when no two agents ended on the same
            cell and no two agents swapped cells.
            NOTE(review): the final element is presumably the safety flag expected by
            MultiAgentSafetyEnv -- confirm against the base class.
        """
        loc1, dir1, loc2, dir2, loc3, dir3 = environment_state
        act1, act2, act3 = joint_action

        nloc1, ndir1 = self.get_next_loc_dir(loc1, dir1, act1)
        nloc2, ndir2 = self.get_next_loc_dir(loc2, dir2, act2)
        nloc3, ndir3 = self.get_next_loc_dir(loc3, dir3, act3)

        # Two agents ending the step on the same cell.
        collisions = (nloc1 == nloc2) or (nloc2 == nloc3) or (nloc3 == nloc1)
        # Two agents swapping cells, i.e. passing through each other.
        crossings = (nloc1 == loc2 and nloc2 == loc1) or (nloc2 == loc3 and nloc3 == loc2) or (
                nloc3 == loc1 and nloc1 == loc3)

        # Each agent is rewarded for standing on its own goal cell.
        rewards = tuple(
            1 if nloc == goal else 0
            for nloc, goal in zip((nloc1, nloc2, nloc3), self.GOAL_LOCS)
        )
        done = sum(rewards) == 3

        new_state = (nloc1, ndir1, nloc2, ndir2, nloc3, ndir3)

        # Parenthesised on purpose: `not` binds tighter than `or`, and either event is unsafe.
        safe = not (collisions or crossings)

        return (new_state, self.project_obs(new_state), rewards, done, safe)

    def project_single_obs(self, loc, dir, other_locs):
        """
        Build one agent's observation.

        :param loc: the agent's cell index
        :param dir: the agent's direction
        :param other_locs: cell indices occupied by the other agents
        :return: (loc, dir, front-left, front, front-right, two-ahead) where each of the last
            four entries is 0 (empty), 1 (another agent) or 2 (wall / outside the grid)
        """
        here = self.grid_posns[loc]

        ahead_values = []
        for offset in self.obs_offsets[dir]:
            cell = self.grid_posn_inv.get(add_posn(here, offset))
            if cell is None:
                ahead_values.append(2)  # not a walkable cell
            elif cell in other_locs:
                ahead_values.append(1)
            else:
                ahead_values.append(0)

        return (loc, dir, *ahead_values)

    def project_obs(self, state) -> Sequence[Any]:
        """
        Project the joint state onto per-agent observations.

        :param state: (loc1, dir1, loc2, dir2, loc3, dir3)
        :return: a tuple with one observation per agent, see `project_single_obs`
        """
        nloc1, ndir1, nloc2, ndir2, nloc3, ndir3 = state
        return (self.project_single_obs(nloc1, ndir1, (nloc2, nloc3)),
                self.project_single_obs(nloc2, ndir2, (nloc1, nloc3)),
                self.project_single_obs(nloc3, ndir3, (nloc1, nloc2)))
