import importlib
from enum import IntEnum
from typing import Any, Dict, List, Tuple, Union

import gymnasium as gym  # type:ignore
import numpy as np  # type:ignore
from gymnasium import spaces

from environments import Agent, Goal  # type:ignore

# Global elements
# Index of the grid layer that marks agent positions. The chest layers of the
# objectives follow it at indices 1..reward_dim (MOResourceGathering.goal_layers).
_LAYER_AGENTS = 0


class Actions(IntEnum):
    """Per-agent actions.

    Up/Down move an agent along the grid's y axis (y - 1 / y + 1) and
    Left/Right along its x axis (x - 1 / x + 1); Noop leaves it in place.
    A centralized controller packs one of these per agent as a digit of a
    base-``len(Actions)`` joint action.
    """

    Up = 0
    Down = 1
    Left = 2
    Right = 3
    Noop = 4


class MOResourceGathering(gym.Env):  # type:ignore
    """Multi-objective, multi-agent resource-gathering grid world.

    Chests belonging to ``reward_dim`` objectives are scattered on a grid.
    Agents move on the grid and pick up a chest by stepping on its cell, as
    long as their bag is not full. Rewards are vectors with one entry per
    objective.

    The grid is a stack of 0/1 layers indexed ``[layer, y, x]``: layer
    ``_LAYER_AGENTS`` (0) marks agent positions and layers ``1..reward_dim``
    mark the chests of each objective.
    """

    # Grid displacement (dx, dy) of each movement action. Noop (and any
    # unknown action value) is absent, so it leaves the agent in place.
    _MOVES = (
        (Actions.Up, (0, -1)),
        (Actions.Down, (0, 1)),
        (Actions.Left, (-1, 0)),
        (Actions.Right, (1, 0)),
    )

    def __init__(
        self,
        env_size: str,
        sampling_strategy: str,
        n_agents: int,
        reward_dim: int,
        agents_finite_size_bags: bool,
        agents_bags_size: List[int],
        partial_observability: bool = False,
        centralized_controller: bool = False,
        sensor_range: int = 4,
        local_reward: bool = False,
        agent_specific_objectives: bool = False,
        agents_objectives: Union[List[List[int]], None] = None,
        normalize_reward: bool = True,
    ):
        """Build the environment and its action, observation and reward spaces.

        Args:
            env_size: ``"small"`` (10x10), ``"medium"`` (15x15) or ``"large"``
                (20x20). Also fixes the range of chests per objective and the
                episode step limit.
            sampling_strategy: How chests are placed on ``reset``. Accepted:
                ``"uniform"``, ``"debug"``, ``"objective_clusters"``,
                ``"geographical_clusters"`` (only the first two are
                implemented; the others raise on ``reset``).
            n_agents: Number of agents.
            reward_dim: Number of objectives (chest types).
            agents_finite_size_bags: Stored on the instance; not read
                anywhere in this class.
            agents_bags_size: Bag capacity of each agent, indexed by agent.
            partial_observability: If True, each agent observes a square
                window around itself instead of the whole grid.
            centralized_controller: If True, one discrete joint action encodes
                the actions of all agents.
            sensor_range: The observation window spans ``sensor_range // 2``
                cells on each side of the agent. Ignored without partial
                observability.
            local_reward: If True, an agent is rewarded only for the chests it
                collected itself. Otherwise every agent receives the team
                reward restricted to its objectives.
            agent_specific_objectives: If True, ``agents_objectives`` gives the
                objectives each agent tracks; otherwise agents track all.
            agents_objectives: Per-agent lists of objective indices. Required
                when ``agent_specific_objectives`` is True.
            normalize_reward: If True, rewards are divided by
                ``max_reward_amplitude``.

        Raises:
            ValueError: On an unknown ``env_size`` or ``sampling_strategy``,
                or when ``agents_objectives`` is missing or has fewer entries
                than agents while ``agent_specific_objectives`` is True.
        """
        self.current_step = 0
        self.env_size = env_size
        # size -> (grid shape, min chests per objective, max chests per
        # objective, episode step limit)
        size_presets = {
            "small": ((10, 10), 3, 7, 1e3),
            "medium": ((15, 15), 5, 12, 5e3),
            "large": ((20, 20), 10, 20, 2e4),
        }
        if env_size not in size_presets:
            raise ValueError("Unknown environment size")
        (
            self.grid_size,
            self.min_reward_amplitude,
            self.max_reward_amplitude,
            self.__max_steps__,
        ) = size_presets[env_size]

        if sampling_strategy not in (
            "uniform",
            "debug",
            "objective_clusters",
            "geographical_clusters",
        ):
            raise ValueError("Unknown sampling strategy")
        self.sampling_strategy = sampling_strategy

        self.env_name = f"resource_gathering_{env_size}_{sampling_strategy}_{n_agents}a_{reward_dim}r"
        self.n_agents = n_agents
        self.reward_dim = reward_dim
        self.partial_observability = partial_observability
        self.centralized_controller = centralized_controller
        self.local_reward = local_reward
        self.agent_specific_objectives = agent_specific_objectives
        # The sensor range is only meaningful with partial observability.
        self.sensor_range = int(self.partial_observability) * sensor_range
        self.normalize_reward = normalize_reward

        self.reward_amplitude = np.zeros(self.reward_dim)
        self.agents_finite_size_bags = agents_finite_size_bags
        self.agents_bags_size = agents_bags_size
        self.reward_space = spaces.Box(
            low=np.array([0 for _ in range(reward_dim)]),
            high=np.array([self.max_reward_amplitude for _ in range(reward_dim)]),
            dtype=np.float32,
        )
        self.goal_layers = list(range(1, 1 + self.reward_dim))
        if agent_specific_objectives:
            if agents_objectives is None or len(agents_objectives) < n_agents:
                raise ValueError(
                    "agents_objectives must list the objectives of every agent "
                    "when agent_specific_objectives is True"
                )
            self.agents_objectives = agents_objectives
        else:
            self.agents_objectives = [
                list(range(self.reward_dim)) for _ in range(n_agents)
            ]

        # The grid has 1 + reward_dim layers, indexed [layer, y, x]:
        # layer 0 holds agent positions, layers 1..reward_dim hold the chests
        # of each objective.
        self.grid = np.zeros((1 + self.reward_dim, *self.grid_size))

        if self.centralized_controller:
            self.action_space = spaces.Discrete(len(Actions) ** n_agents)
        else:
            self.action_space = spaces.Tuple(
                tuple(n_agents * [spaces.Discrete(len(Actions))])
            )

        if self.partial_observability:
            # Per agent: a (window x window) crop of the agent layer and of
            # each tracked goal layer, the bag content, the remaining bag
            # space and the normalized (x, y) position. _get_obs crops
            # sensor_range // 2 cells on each side of the agent, so the window
            # side is 2 * (sensor_range // 2) + 1 (sensor_range + 1 when
            # sensor_range is even).
            window = 2 * (self.sensor_range // 2) + 1

            def agent_obs_dim(n_objectives: int) -> int:
                return window * window * (1 + n_objectives) + n_objectives + 3

            if agent_specific_objectives:
                self.state_space_dim = [
                    agent_obs_dim(len(self.agents_objectives[i]))
                    for i in range(self.n_agents)
                ]
                self.observation_space = spaces.Tuple(
                    tuple(
                        spaces.Box(np.array([0] * dim), np.array([1] * dim))
                        for dim in self.state_space_dim
                    )
                )
            else:
                self.state_space_dim = agent_obs_dim(self.reward_dim)
                self.observation_space = spaces.Tuple(
                    tuple(
                        n_agents
                        * [
                            spaces.Box(
                                np.array([0] * self.state_space_dim),
                                np.array([1] * self.state_space_dim),
                            )
                        ]
                    )
                )
        else:
            # The whole grid (every layer), then each agent's bag content and
            # remaining bag space.
            self.state_space_dim = (
                np.prod(self.grid_size) * (1 + self.reward_dim)
                + (self.reward_dim + 1) * self.n_agents
            )
            self.observation_space = spaces.Box(
                np.array([0] * self.state_space_dim),
                np.array([1] * self.state_space_dim),
            )

        self.agents: List[Agent] = []
        self.goals: List[List[Goal]] = [[] for _ in range(self.reward_dim)]

        self._rendering_initialized = False

        self.agent_order: List[int] = list(range(n_agents))
        self.viewer = None

        self.collected_chests = np.zeros(self.reward_dim)

    def reset(
        self,
        *,
        seed: Union[int, None] = None,
        options: Union[Dict[str, Any], None] = None,
    ) -> Tuple[Union[Tuple[Any], np.ndarray], Dict[str, Any]]:
        """Start a new episode by resampling the chests and the agent positions.

        Args:
            seed: Optional seed. Placement (and the agent ordering in
                ``step``) draws from NumPy's global random state, so a
                non-None seed is applied to ``np.random`` as well as to the
                Gymnasium generator.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            The first observation, and an info dict holding the number of
            chests of each objective (``"total_demand"``) and the bag sizes
            (``"agents_bags_size"``).

        Raises:
            ValueError: If the grid cannot hold all chests and agents.
            NotImplementedError: For the clustered sampling strategies.
        """
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(seed)

        def sample_distinct_coordinates() -> np.ndarray:
            """Return ``reward_dim`` distinct grid cells, excluding the one-cell border.

            The result is passed to ``sample_positions`` as ``means``. The
            implemented strategies ignore it.
            """
            interior = np.array(
                [
                    (i, j)
                    for i in range(1, self.grid_size[0] - 1)
                    for j in range(1, self.grid_size[1] - 1)
                ]
            )
            if self.reward_dim > len(interior):
                raise ValueError(
                    "Cannot sample more distinct coordinates than available in the grid."
                )
            np.random.shuffle(interior)
            return interior[: self.reward_dim]

        def uniform_sampling() -> Tuple[List[List[Tuple[int, int]]], List[Tuple[int, int]]]:
            """Place ``reward_amplitude[o]`` chests of each objective o on distinct cells.

            Returns the per-objective chest positions as (x, y) pairs and the
            flat list of every occupied cell, in sampling order.
            """
            n_rows, n_cols = self.grid_size
            n_chests = int(np.sum(self.reward_amplitude))
            # Rejection sampling below would never terminate if the grid could
            # not hold every chest, and the agents need free cells too.
            if n_chests + self.n_agents > n_rows * n_cols:
                raise ValueError(
                    "The grid is too small to hold all the chests and agents."
                )
            chests_positions: List[List[Tuple[int, int]]] = []
            all_samples: List[Tuple[int, int]] = []
            occupied = set()
            for amplitude in self.reward_amplitude:
                samples: List[Tuple[int, int]] = []
                while len(samples) < amplitude:
                    # x indexes grid columns and y grid rows (grid[layer, y, x]).
                    x_coord = np.random.randint(0, n_cols)
                    y_coord = np.random.randint(0, n_rows)
                    cell = (x_coord, y_coord)
                    if cell not in occupied:
                        occupied.add(cell)
                        samples.append(cell)
                        all_samples.append(cell)
                chests_positions.append(samples)
            return chests_positions, all_samples

        def sample_positions(means: np.ndarray):
            """Return ``(chests_positions, agent_positions)`` for the active strategy.

            ``means`` (cluster centroids) is not used by the implemented
            strategies.
            """
            if self.sampling_strategy == "debug":
                # Fixed layout: 3 objectives with 3 chests each, coordinates
                # up to 9 (a 10x10 grid), and at most 4 agents. It overrides
                # reward_dim and reward_amplitude.
                # NOTE(review): the grid was already allocated for the
                # constructor's reward_dim, so this presumably requires
                # reward_dim == 3 and env_size == "small".
                self.reward_dim = 3
                self.reward_amplitude = np.array([3] * self.reward_dim)
                chests_positions = np.array(
                    [
                        [[9, 9], [9, 8], [8, 9]],
                        [[0, 9], [0, 8], [1, 9]],
                        [[9, 0], [9, 1], [8, 0]],
                    ]
                )
                agent_positions = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
                return chests_positions, agent_positions[: self.n_agents]
            if self.sampling_strategy != "uniform":
                raise NotImplementedError(
                    f"Sampling strategy '{self.sampling_strategy}' is not implemented yet"
                )
            chests_positions, all_samples = uniform_sampling()
            occupied = set(all_samples)
            all_positions = [
                (x, y)
                for x in range(self.grid_size[1])
                for y in range(self.grid_size[0])
            ]
            mask = np.array([pos not in occupied for pos in all_positions])
            available_positions = np.array(all_positions)[mask]
            np.random.shuffle(available_positions)
            return chests_positions, available_positions[: self.n_agents]

        def generate_chests_and_agents() -> None:
            """Create the goals and the agents, and mark them on the grid."""
            centroids = sample_distinct_coordinates()
            chests_positions, agent_positions = sample_positions(centroids)
            self.goals = [
                [Goal("goal", pos[0], pos[1]) for pos in chests_positions[o]]
                for o in range(self.reward_dim)
            ]
            for o in range(self.reward_dim):
                for pos in chests_positions[o]:
                    self.grid[self.goal_layers[o], pos[1], pos[0]] = 1

            for i in range(self.n_agents):
                x, y = agent_positions[i]
                self.agents.append(
                    Agent(i, x, y, len(self.agents_objectives[i]), self.agents_bags_size[i])
                )
                self.grid[_LAYER_AGENTS, y, x] = 1

        # Wipe the grid and draw the number of chests of each objective.
        self.grid = np.zeros((1 + self.reward_dim, *self.grid_size))
        self.reward_amplitude = np.random.randint(
            self.min_reward_amplitude,
            self.max_reward_amplitude + 1,
            (self.reward_dim),
        )
        self.agents = []
        generate_chests_and_agents()

        self.current_step = 0
        self.collected_chests = np.zeros(self.reward_dim)
        return self._get_obs(), {
            "total_demand": self.reward_amplitude,
            "agents_bags_size": self.agents_bags_size,
        }

    def step(self, actions: Union[np.ndarray, int]) -> Tuple[
        Union[Tuple[Any], np.ndarray],
        Union[np.ndarray, List[np.ndarray]],
        Union[bool, List[bool]],
        bool,
        Dict[Any, Any],
    ]:
        """Move the agents, collect chests and compute the rewards.

        Args:
            actions: With a centralized controller, a single integer joint
                action whose base-``len(Actions)`` digits (most significant
                first) are the actions of agents 0..n_agents-1. Otherwise a
                sequence holding one action per agent.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``reward`` is the
            (possibly normalized) list of reward vectors, per agent, or a
            single one with a centralized controller. ``terminated`` is one
            flag with a centralized controller (all bags full, or step limit
            reached) and a per-agent "bag is full" list otherwise.
            ``truncated`` is True once the step limit is reached. ``info``
            holds the unnormalized team reward under ``"joint_reward"``.

        Raises:
            ValueError: If an integer action is given without a centralized
                controller.
        """
        if isinstance(actions, (int, np.integer)):
            if not self.centralized_controller:
                raise ValueError(
                    "Integer actions are only allowed when using a centralized controller"
                )
            # Decode the joint action: least significant digit first, then
            # reverse so agent 0 gets the most significant digit.
            joint = int(actions)
            digits: List[int] = []
            for _ in range(self.n_agents):
                joint, digit = divmod(joint, len(Actions))
                digits.append(digit)
            agent_specific_actions = digits[::-1]
        else:
            agent_specific_actions = actions

        self.current_step += 1
        # Agents are processed in a random order each step, which decides who
        # wins when two agents move toward the same cell.
        np.random.shuffle(self.agent_order)
        for i in self.agent_order:
            agent = self.agents[i]
            action = agent_specific_actions[i]
            for move, (dx, dy) in self._MOVES:
                if action == move:
                    new_pos = [agent.x + dx, agent.y + dy]
                    if not self._detect_collision(new_pos):
                        self.grid[_LAYER_AGENTS, agent.y, agent.x] = 0
                        agent.x, agent.y = new_pos
                        self.grid[_LAYER_AGENTS, agent.y, agent.x] = 1
                    break

        (normalized_reward_per_agent, _), (_, unnormalized_total_reward) = (
            self._compute_reward()
        )

        # >= rather than == so that stepping past the limit stays truncated.
        truncated = bool(self.current_step >= self.__max_steps__)
        agents_full = [a.curr_load == a.max_load for a in self.agents]
        if self.centralized_controller:
            terminated: Union[bool, List[bool]] = all(agents_full) or truncated
        else:
            # Per-agent flags; truncation is reported separately.
            terminated = agents_full

        return (
            self._get_obs(),
            normalized_reward_per_agent,
            terminated,
            truncated,
            {
                "joint_reward": unnormalized_total_reward,
            },
        )

    def _compute_reward(self) -> Tuple[Tuple[Any, Any], Tuple[Any, Any]]:
        """Collect the chests agents stand on and build the reward vectors.

        Returns:
            ``((reward_per_agent, unnormalized_reward_per_agent),
            (total_reward, unnormalized_total_reward))``. The normalized
            values are divided by ``max_reward_amplitude`` when
            ``normalize_reward`` is set. With a centralized controller both
            "per agent" and "total" are a one-element list holding the team
            reward vector.
        """

        def collect_chest(agent_obj_index: int, o: int, agent: Agent, goal: Goal):
            self.collected_chests[o] += 1
            agent.curr_load += 1
            agent.bag[agent_obj_index] += 1
            goal.achieved = True
            self.grid[self.goal_layers[o], goal.y, goal.x] = 0

        def can_collect(agent: Agent, goal: Goal) -> bool:
            return (
                not goal.achieved
                and [agent.x, agent.y] == [goal.x, goal.y]
                and agent.curr_load < agent.max_load
            )

        if self.centralized_controller:
            team_reward = np.zeros(self.reward_dim)
            for agent in self.agents:
                for o in range(self.reward_dim):
                    for goal in self.goals[o]:
                        if can_collect(agent, goal):
                            # NOTE(review): the bag slot is the objective
                            # index, which assumes every agent tracks all
                            # objectives (the default without
                            # agent_specific_objectives).
                            collect_chest(o, o, agent, goal)
                            team_reward[o] += 1
                            break
            unnormalized_reward = team_reward.copy()
            if self.normalize_reward:
                team_reward /= self.max_reward_amplitude
            return ([team_reward], unnormalized_reward), (
                [team_reward],
                unnormalized_reward,
            )

        total_r = np.zeros(self.reward_dim)
        reward_per_agent = [
            np.zeros(len(self.agents_objectives[i]), dtype=np.float32)
            for i in range(self.n_agents)
        ]
        for a, agent in enumerate(self.agents):
            for slot, obj in enumerate(self.agents_objectives[a]):
                for goal in self.goals[obj]:
                    if can_collect(agent, goal):
                        collect_chest(slot, obj, agent, goal)
                        total_r[obj] += 1
                        reward_per_agent[a][slot] += 1
                        break

        if not self.local_reward:
            # Every agent receives the team reward on its own objectives.
            for i in range(self.n_agents):
                reward_per_agent[i] = total_r[self.agents_objectives[i]]
        unnormalized_total_reward = total_r.copy()
        # Copy agent by agent: agents may track different numbers of
        # objectives, and np.copy of a ragged list would fail.
        unnormalized_agents_reward = [r.copy() for r in reward_per_agent]

        if self.normalize_reward:
            for r in reward_per_agent:
                r /= self.max_reward_amplitude
            total_r /= self.max_reward_amplitude
        return (reward_per_agent, unnormalized_agents_reward), (
            total_r,
            unnormalized_total_reward,
        )

    def _detect_collision(self, proposed_position: List[int]) -> bool:
        """Return True if ``proposed_position`` ([x, y]) is off the grid or holds an agent."""
        n_rows, n_cols = self.grid_size
        x, y = proposed_position[0], proposed_position[1]
        if not (0 <= x < n_cols and 0 <= y < n_rows):
            return True
        return any(proposed_position == [agent.x, agent.y] for agent in self.agents)

    def _get_observable_positions(self, x, y, s_range):
        """Return the on-grid cells of the square of half-width ``s_range`` around (x, y).

        Not called within this class.
        """
        n_rows, n_cols = self.grid_size
        return [
            (i, j)
            for i in range(x - s_range, x + s_range + 1)
            for j in range(y - s_range, y + s_range + 1)
            if 0 <= i < n_cols and 0 <= j < n_rows
        ]

    def _crop_around(self, layer: np.ndarray, x: int, y: int, pad: int) -> np.ndarray:
        """Return the flattened (2*pad+1) x (2*pad+1) window of ``layer`` centred on (x, y).

        ``layer`` is indexed ``[y, x]``. Window cells outside the grid are
        filled with zeros.
        """
        n_rows, n_cols = self.grid_size
        x_left, x_right = max(0, x - pad), min(n_cols - 1, x + pad)
        y_up, y_down = max(0, y - pad), min(n_rows - 1, y + pad)
        window = layer[y_up : y_down + 1, x_left : x_right + 1]
        window = np.pad(
            window,
            (
                (pad - (y - y_up), pad - (y_down - y)),
                (pad - (x - x_left), pad - (x_right - x)),
            ),
        )
        return window.reshape(-1)

    def _get_obs(self) -> Union[Tuple[np.ndarray], np.ndarray]:
        """Build the observation of the current state.

        With partial observability, each agent gets its agent-layer window,
        one window per tracked objective, its bag content, its remaining bag
        space and its normalized position. The result is a tuple of per-agent
        vectors, or their concatenation with a centralized controller.
        Otherwise a single vector holds the whole grid, every agent's bag
        content and every agent's remaining bag space.
        """
        if self.partial_observability:
            pad = self.sensor_range // 2
            n_rows, n_cols = self.grid_size
            obs = []
            for a, agent in enumerate(self.agents):
                space_left = (agent.max_load - agent.curr_load) / agent.max_load
                bag_content = [v / agent.max_load for v in agent.bag]
                x, y = agent.x, agent.y
                goal_windows = [
                    self._crop_around(self.grid[self.goal_layers[o]], x, y, pad)
                    for o in self.agents_objectives[a]
                ]
                obs.append(
                    np.concatenate(
                        (
                            self._crop_around(self.grid[_LAYER_AGENTS], x, y, pad),
                            *goal_windows,
                            np.array(bag_content),
                            np.array([space_left]),
                            # x is normalized by the column count, y by the row count.
                            np.array([x / n_cols, y / n_rows]),
                        ),
                        axis=0,
                        dtype=np.float32,
                    )
                )
            return (
                np.concatenate(obs, axis=0, dtype=np.float32)
                if self.centralized_controller
                else tuple(obs)
            )

        spaces_left = [
            (agent.max_load - agent.curr_load) / agent.max_load for agent in self.agents
        ]
        # Flatten bag by bag, not through a 2-D array, so agents may track
        # different numbers of objectives.
        agents_bags = np.array(
            [v / agent.max_load for agent in self.agents for v in agent.bag],
            dtype=float,
        )
        return np.concatenate(
            (
                self.grid.reshape(-1),
                agents_bags,
                np.array(spaces_left),
            ),
            axis=0,
            dtype=np.float32,
        )

    def get_action_mask(self) -> List[np.ndarray]:
        """Return, per agent, a mask over ``Actions`` (1 = allowed, 0 = blocked).

        A movement is blocked when it would leave the grid or step onto
        another agent. ``Noop`` is always allowed.
        """
        masks = []
        for agent in self.agents:
            mask = np.ones(len(Actions))
            for move, (dx, dy) in self._MOVES:
                mask[move] = 1 - self._detect_collision([agent.x + dx, agent.y + dy])
            masks.append(mask)
        return masks

    def _init_render(self) -> None:
        """Create the viewer lazily; its module is imported only when rendering."""
        from .rendering import Viewer  # type:ignore

        self.viewer = Viewer(self.grid_size, self.reward_dim)
        self._rendering_initialized = True

    def render(self, mode: str = "human"):  # type: ignore
        """Render the environment; ``mode == "rgb_array"`` requests an image."""
        if not self._rendering_initialized:
            self._init_render()
        return self.viewer.render(self, mode == "rgb_array")  # type:ignore

    def close(self) -> None:
        """Close the viewer, if any, so a later ``render`` creates a fresh one."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None
            self._rendering_initialized = False
