from k_level_policy_gradients.src.core.environment import Environment, MDPInfo
from k_level_policy_gradients.src.utils.spaces import *
from k_level_policy_gradients.src.utils.viewer import Viewer


class Attraction(Environment):
    """
    Multi-agent environment where n agents want to come together as soon as possible.

    Agents move inside a square box of side ``_box_size``. The internal state is a
    flat array of length ``2 * n_agents`` laid out as
    ``[x_1, ..., x_n, y_1, ..., y_n]`` (all x coordinates first, then all y
    coordinates), expressed in box units. Everything handed back to the caller
    (state and observations) is normalized by the box size.
    """

    def __init__(
        self,
        n_agents=2,
        obs_state=False,
        wrap_position=False,
        horizon=100,
        gamma=0.99,
        bool_render=False,
    ):
        """Create a new multi-agent Attraction env compatible with MushroomRL.

        Arguments:
            n_agents (int): Number of agents in the environment.
            obs_state (bool): If True, the returned "state" is the vector of
                unique pairwise (x, y) differences built by
                ``get_state_from_obs``; otherwise it is the normalized
                absolute positions.
            wrap_position (bool): If True, positions wrap around the box edges
                (modulo the box size); otherwise they are clamped to the box.
            horizon (int): Horizon forwarded to ``MDPInfo``.
            gamma (float): Discount factor forwarded to ``MDPInfo``.
            bool_render (bool): If True, ``stop`` closes the viewer.

        The position state has the shape (2 * n_agents,): the first n_agents
        entries are the x coordinates and the last n_agents entries are the
        y coordinates.

        e.g. for n_agents=2, state = [x1, x2, y1, y2]
        """
        self._n_agents = n_agents
        self._box_size = 10
        self._obs_state = obs_state
        self._wrap_position = wrap_position
        self._bool_render = bool_render
        self._dt = 0.01  # passed to the viewer's display() call in render()

        # Set the state, observation, and action spaces
        self.action_space = [Box(-1, 1, shape=(2,)) for _ in range(self._n_agents)]
        if self._obs_state:
            # n * (n - 1) / 2 unordered pairs, two components (x, y) per pair.
            state_space = Box(
                -np.inf, np.inf, shape=((self._n_agents * (self._n_agents - 1)),)
            )
        else:
            state_space = Box(-np.inf, np.inf, shape=(2 * (self._n_agents),))
        # Each agent sees its x and y offsets to the other n - 1 agents.
        observation_space = [
            Box(-np.inf, np.inf, shape=(2 * (self._n_agents - 1),))
            for _ in range(self._n_agents)
        ]

        window_size = 1000
        self._viewer = Viewer(
            width=window_size,
            height=window_size,
            env_width=self._box_size,
            env_height=self._box_size,
            background=(255, 255, 255),
        )

        # Set the MDP info
        mdp_info = MDPInfo(
            state_space=state_space,
            observation_space=observation_space,
            action_space=self.action_space,
            discrete_actions=False,
            gamma=gamma,
            horizon=horizon,
            has_obs=True,
            has_action_masks=False,
            n_agents=self._n_agents,
        )

        super().__init__(mdp_info)

    def _compute_observations(self, state_normalized):
        """Return one observation per agent from the normalized positions.

        The observation of agent i is the concatenation of its x offsets and
        its y offsets (own coordinate minus the other agent's coordinate) to
        every other agent, in increasing agent index with i itself skipped.
        """
        xs = state_normalized[: self._n_agents]
        ys = state_normalized[self._n_agents :]
        return [
            np.concatenate((xs[i] - np.delete(xs, i), ys[i] - np.delete(ys, i)))
            for i in range(self._n_agents)
        ]

    def _select_state(self, state_normalized, observations):
        """Return the "state" entry of a step dict according to ``obs_state``."""
        if self._obs_state:
            return self.get_state_from_obs(observations)
        return state_normalized

    def _confine(self, coordinate):
        """Keep a coordinate inside the box by wrapping or clamping."""
        if self._wrap_position:
            return coordinate % self._box_size  # wrap around edges
        return np.clip(coordinate, 0, self._box_size)  # clamp at the edges

    def reset(self):
        """
        Sample new uniformly random positions and return the initial step dict.

        Returns:
            dict with keys "state", "obs" and "info".
        """
        start_state = np.random.uniform(0, self._box_size, size=(2 * self._n_agents,))
        state_normalized = start_state / self._box_size
        observations = self._compute_observations(state_normalized)

        # Positions are stored un-normalized (box units).
        self._state = start_state

        return {
            "state": self._select_state(state_normalized, observations),
            "obs": observations,
            "info": {},
        }

    def step(self, actions):
        """
        Returns the next state, obs, reward, done, and info.

        Arguments:
            actions (np.ndarray): The actions to take in the environment.

        actions are 2D movements between -1 and 1; each one is clipped to its
        agent's action space before being applied.

        Returns:
            dict with keys "state", "obs", "rewards", "absorbing" and "info".
            Every agent receives the same reward.
        """
        n = self._n_agents

        # Update the state
        for i, action in enumerate(actions):
            action = np.clip(
                action, self.action_space[i].low, self.action_space[i].high
            )
            self._state[i] = self._confine(self._state[i] + action[0])  # x
            self._state[i + n] = self._confine(self._state[i + n] + action[1])  # y

        state_normalized = self._state / self._box_size
        observations = self._compute_observations(state_normalized)

        # Calculate the rewards: negative sum of the pairwise (normalized)
        # Euclidean distances, each unordered pair counted once.
        coordinates = np.stack((state_normalized[:n], state_normalized[n:]), axis=1)
        diff = coordinates[:, np.newaxis, :] - coordinates[np.newaxis, :, :]
        distances = np.linalg.norm(diff, axis=2)
        total_distance = np.sum(np.triu(distances, k=1))

        # NOTE: the penalty is computed on the raw actions as passed in, not on
        # the clipped values that were actually applied above.
        action_penalty = 0
        for i in range(n):
            action_penalty -= np.linalg.norm(actions[i]) ** 2

        total_reward = -total_distance + action_penalty

        return {
            "state": self._select_state(state_normalized, observations),
            "obs": observations,
            "rewards": [total_reward] * n,
            "absorbing": False,  # episodes never terminate from inside step()
            "info": {},
        }

    def get_state_from_obs(self, obs):
        """
        Build a state vector made of the unique pairwise differences.

        For every pair (i, j) with j > i, the x and then the y offset of agent
        i relative to agent j are taken from agent i's observation.

        Arguments:
            obs (list): Per-agent observations laid out as x offsets followed
                by y offsets, as produced by ``reset`` and ``step``.

        Returns:
            np.ndarray of length n_agents * (n_agents - 1).
        """
        n_others = self._n_agents - 1
        unique_differences = []
        for i in range(self._n_agents):
            for j in range(i + 1, self._n_agents):  # Only consider j > i
                # Agent i's own index is skipped in its observation, so agent
                # j > i sits at position j - 1 of each half.
                unique_differences.append(obs[i][j - 1])  # x(i to j)
                unique_differences.append(obs[i][n_others + (j - 1)])  # y(i to j)
        return np.array(unique_differences)

    def render(self, render_info):
        """
        Draw every agent as a small circle and return the rendered frame.

        Arguments:
            render_info: Unused here; kept in the signature.
        """
        colours = ["red", "blue", "green"]
        for i in range(self._n_agents):
            x = self._state[i]
            y = self._state[i + self._n_agents]
            self._viewer.circle(  # agent
                np.array([x, y]),
                0.1,
                # Cycle through the palette so more than len(colours) agents
                # can be drawn without an IndexError.
                color=colours[i % len(colours)],
            )

        self._viewer.display(self._dt)
        return self._viewer.get_frame()

    def stop(self):
        """Close the viewer, but only when rendering was enabled."""
        if self._bool_render:
            self._viewer.close()