import abc
import random
from copy import copy
from itertools import chain
from typing import Sequence, Any, Tuple, List

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.envs.utils import add_posn


def flatten(iterator_to_flatten):
    """Concatenate the items of every iterable in *iterator_to_flatten* into one new list.

    Only one level of nesting is removed; the outer iterable is consumed completely.
    """
    flat = []
    for inner in iterator_to_flatten:
        flat.extend(inner)
    return flat


class FastGridWorld(MultiAgentSafetyEnv, abc.ABC):
    """
    A grid world with many pre-computed properties for extremely fast steps
    (well, as fast as you can get with python)

    The environment state holds one entry per agent: the index (into ``grid_posns``) of the cell
    that agent occupies. Every agent observes the full joint state, i.e. the tuple of all agents'
    cell indices (see ``project_obs`` and ``agent_obs_spaces``).

    NOTE(review): an earlier version of this docstring described a per-agent cone of visibility with
    cell codes 0 = empty, 1 = other agent, 2 = wall. Nothing in this class implements that; confirm
    whether a subclass is expected to.

    This environment generates a single AP, "collision". Two agents collide when they end a step on
    the same cell or when they swap cells (cross each other).

    Agents have five actions:
    0 = Do nothing
    1 = Move up      (offset (0, -1))
    2 = Right        (offset (1, 0))
    3 = Down         (offset (0, 1))
    4 = Left         (offset (-1, 0))

    A move whose target position is not in ``grid_posns`` leaves the agent where it is.
    """

    def __init__(self, grid_posns: Sequence[Any], num_agents: int, start_idx: Sequence[int],
                 ending_idx: Sequence[int], randomize_starts: bool = False,
                 collision_reward: float = -30, agents_bounce: bool = False,
                 terminate_on_collision: bool = False):
        """
        :param grid_posns: The positions that make up the grid; a cell is referred to by its index here.
            Positions must be hashable and accepted by ``add_posn``.
        :param num_agents: Number of agents in the environment.
        :param start_idx: Starting cell index of each agent (used when ``randomize_starts`` is False).
        :param ending_idx: Goal cell index of each agent.
        :param randomize_starts: Sample distinct random starting cells instead of using ``start_idx``.
        :param collision_reward: Value added to every agent's reward on a step with a collision.
        :param agents_bounce: On a collision, put all agents back where they were before the step.
        :param terminate_on_collision: End the episode on a collision.
        :raises ValueError: If ``start_idx`` or ``ending_idx`` does not have one entry per agent.
        """
        # Explicit exceptions instead of ``assert`` so the checks survive ``python -O``.
        if len(start_idx) != num_agents:
            raise ValueError(f"start_idx has {len(start_idx)} entries but num_agents is {num_agents}")
        if len(ending_idx) != num_agents:
            raise ValueError(f"ending_idx has {len(ending_idx)} entries but num_agents is {num_agents}")

        # Create and cache a bunch of useful arrays so that we don't need to recalculate them every time step
        self.grid_posns = grid_posns

        # posn -> idx
        self.grid_posn_inv = {pos: idx for idx, pos in enumerate(self.grid_posns)}

        # dir, idx -> idx: What position after moving in each direction one step.
        # Falls back to the current idx when the neighbouring position is not part of the grid.
        self.loc_dir_next = [
            [self.grid_posn_inv.get(add_posn(pos, offset), idx) for idx, pos in enumerate(self.grid_posns)]
            for offset in [(0, -1), (1, 0), (0, 1), (-1, 0)]]

        self.num_agents = num_agents

        self.start_idx = start_idx
        self.goal_idx = ending_idx

        self.randomize_starts = randomize_starts
        self.collision_cost = collision_reward
        self.agents_bounce = agents_bounce  # Instead of passing through each other
        self.terminate_on_collision = terminate_on_collision

    def ap_names(self) -> List[str]:
        """Return the names of the atomic propositions produced by this environment."""
        return ["collision"]

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one observation space per agent; each agent observes the full state space."""
        return [self.state_space()] * self.num_agents

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Return one action space per agent: five discrete actions (stay, up, right, down, left)."""
        return [spaces.Discrete(5)] * self.num_agents

    def state_space(self) -> spaces.Space:
        """Return the joint state space: one cell index per agent."""
        return spaces.MultiDiscrete([len(self.grid_posns)] * self.num_agents)

    def initial_state(self):
        """
        Build the state at the start of an episode.

        :return: ``(state, observations)`` where ``state`` is a tuple with one cell index per agent
            and ``observations`` is the per-agent projection of that state.
        """
        if self.randomize_starts:
            # random.sample draws without replacement, so no two agents start on the same cell
            starting_locs = tuple(random.sample(range(len(self.grid_posns)), self.num_agents))
        else:
            # A fresh tuple keeps the state type consistent with what step() returns and
            # protects self.start_idx from being modified through the returned state.
            starting_locs = tuple(self.start_idx)

        return starting_locs, self.project_obs(starting_locs)

    def get_next_loc(self, loc, act):
        """Return the cell index reached from cell ``loc`` by action ``act`` (0 keeps the agent in place)."""
        if act == 0:
            return loc
        return self.loc_dir_next[act - 1][loc]

    def _collision_or_crossing(self, old_state, new_state) -> bool:
        """Return True if any pair of agents ends on the same cell or swaps cells between the two states."""
        for first in range(1, self.num_agents):
            old_first = old_state[first]
            new_first = new_state[first]
            for second in range(first):
                new_second = new_state[second]
                if new_first == new_second:
                    return True
                if new_first == old_state[second] and new_second == old_first:
                    return True
        return False

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance the environment by one joint action.

        Rewards per agent: 100 when every agent is on its goal after the step; otherwise -10 if the
        agent tried to move but its move led nowhere (e.g. into a wall), else -1. On a collision,
        ``collision_cost`` is added to every agent's reward.

        :param environment_state: Current cell index of each agent.
        :param joint_action: One action (0-4) per agent.
        :return: ``(next_state, observations, rewards, done, no_collision)`` where ``no_collision``
            is True when no two agents collided or crossed during this step.
        """
        attempted_state = tuple(
            self.get_next_loc(loc, act) for loc, act in zip(environment_state, joint_action))

        collided = self._collision_or_crossing(environment_state, attempted_state)

        if collided and self.agents_bounce:
            # Everyone is put back; return a fresh tuple so the result never aliases the caller's
            # state object and has the same type as on the non-bouncing path.
            new_env_state = tuple(environment_state)
        else:
            new_env_state = attempted_state

        # Judge the goal on the state that is actually returned: agents that bounced back have
        # not reached the cells they tried to enter.
        reached_goal = all(loc == goal for loc, goal in zip(new_env_state, self.goal_idx))

        def rew_for_agent(attempted_loc, old_loc, action):
            if reached_goal:
                return 100
            # Compared against the attempted move (not the bounced state) so that bouncing
            # is not additionally punished as a wall hit.
            if attempted_loc == old_loc and action != 0:  # Hit a wall
                return -10
            return -1

        rewards = [rew_for_agent(attempted_loc, old_loc, action) for attempted_loc, old_loc, action in
                   zip(attempted_state, environment_state, joint_action)]

        if collided:
            rewards = [rew + self.collision_cost for rew in rewards]

        done = reached_goal or (collided and self.terminate_on_collision)

        return new_env_state, self.project_obs(new_env_state), rewards, done, (not collided)

    def project_obs(self, state) -> Sequence[Any]:
        """Return the per-agent observations: every agent receives the full joint state as a tuple."""
        joint_obs = tuple(state)
        return (joint_obs,) * self.num_agents
