import time

import numpy as np
import torch
import gymnasium as gym
from k_level_policy_gradients.src.core.environment import Environment, MDPInfo
from k_level_policy_gradients.src.utils.spaces import *
from vmas import make_env


class VMAS(Environment):
    """
    Interface for Vectorized Multi Agent Simulator (VMAS) environments.

    Any VMAS scenario can be used by passing its name. A single VMAS
    environment is created (``num_envs=1``), and the leading batch dimension
    that VMAS adds to every output is stripped so that observations and
    rewards are returned per agent.
    """

    def __init__(
        self,
        scenario="",
        n_agents=2,
        continuous_actions=True,
        include_absolute_positions=False,
        horizon=None,
        gamma=0.99,
        seed=0,
        bool_render=False,
        use_cuda=False,
    ):
        """
        Constructor.

        Args:
            scenario (str): name of the VMAS scenario to build;
            n_agents (int, 2): number of agents requested from the scenario;
            continuous_actions (bool, True): whether the agents have continuous or discrete actions.
                NOTE(review): the action space below is built from ``space.low``/``space.high``,
                so discrete action spaces are presumably not supported yet — TODO confirm;
            include_absolute_positions (bool, False): whether to keep the first 2 elements of each
                agent's observation (the absolute position); when False they are dropped;
            horizon (int): the horizon. If None, use ``max_steps`` of the VMAS environment;
            gamma (float, 0.99): the discount factor;
            seed (int, 0): the seed for the environment;
            bool_render (bool, False): whether to render the environment;
            use_cuda (bool, False): whether to run the simulator on cuda. When True, observations
                and states are returned as torch tensors instead of numpy arrays.

        Raises:
            RuntimeError: if ``horizon`` is None and the VMAS environment defines no step limit.

        Agent names are taken from the VMAS scenario (``agent.name`` of each agent).
        """
        self.use_cuda = use_cuda
        render_mode = "human" if bool_render else None
        device = "cuda" if use_cuda else "cpu"
        self._env = make_env(
            scenario=scenario,
            n_agents=n_agents,
            num_envs=1,  # one environment; its batch dimension is stripped in the helpers
            device=device,
            continuous_actions=continuous_actions,
            seed=seed,
            render_mode=render_mode,
            clamp_actions=True,
        )
        self._include_absolute_positions = include_absolute_positions

        # Simulation time step, used both as the MDP dt and as the render delay.
        # Prefer the public ``dt`` attribute, then the private ``_dt``, and fall
        # back to 0.01 when the world exposes neither.
        world = self._env.world
        self._render_dt = getattr(world, "dt", getattr(world, "_dt", 0.01))

        self._n_agents = self._env.n_agents
        self._agents = [agent.name for agent in self._env.agents]

        horizon = self._set_horizon(self._env, horizon)

        action_space = [Box(space.low, space.high) for space in self._env.action_space]
        if self._include_absolute_positions:
            observation_space = [
                Box(space.low, space.high) for space in self._env.observation_space
            ]
        else:
            # Truncate the first 2 elements of the observation space to remove the absolute position
            observation_space = [
                Box(space.low[2:], space.high[2:])
                for space in self._env.observation_space
            ]
        # The state is the concatenation of all agents' observations.
        state_space_low = np.concatenate([space.low for space in observation_space])
        state_space_high = np.concatenate([space.high for space in observation_space])
        state_space = Box(state_space_low, state_space_high)

        mdp_info = MDPInfo(
            state_space,
            observation_space,
            action_space,
            gamma,
            horizon,
            self._render_dt,
            has_obs=True,
            has_action_masks=False,
            n_agents=self._n_agents,
        )

        super().__init__(mdp_info)

    def reset(self):
        """
        Reset the VMAS environment.

        Returns:
            dict: ``{"state": state, "obs": obs}`` where ``obs`` is the list of
            per-agent observations and ``state`` their concatenation.
        """
        obs_list = self._env.reset()
        obs = self._preprocess_obs_list(obs_list)
        state = self._state_from_obs_list(obs)
        step = {"state": state, "obs": obs}
        return step

    def step(self, actions):
        """
        Execute a step of the environment.

        Args:
            actions (list): one action per agent (np.ndarray or torch.Tensor),
                without the batch dimension.

        Returns:
            dict: with keys ``"state"``, ``"obs"``, ``"rewards"`` (one scalar per
            agent), ``"absorbing"`` (bool) and ``"info"`` (as returned by VMAS).
        """
        actions_expanded = self._expand_actions(actions)
        observations, rewards, absorbing, info = self._env.step(actions_expanded)

        obs = self._preprocess_obs_list(observations)
        state = self._state_from_obs_list(obs)
        rewards = self._flatten_twice_tensor_list(rewards)
        # With num_envs=1 the done flag holds a single element, so bool() is well defined.
        absorbing = bool(absorbing)

        step = {
            "state": state,
            "obs": obs,
            "rewards": rewards,
            "absorbing": absorbing,
            "info": info,
        }

        return step

    def render(self, render_info, mode="human"):
        """
        Render the current frame, then sleep for one simulation time step.

        Args:
            render_info: not used here; presumably kept for compatibility with the
                ``Environment`` interface — TODO confirm.
            mode (str, "human"): render mode forwarded to VMAS.
        """
        self._env.render(mode=mode)
        time.sleep(self._render_dt)

    def stop(self):
        """No resources to release for VMAS."""
        pass

    def _state_from_obs_list(self, obs_list):
        """
        VMAS only outputs observations.
        Construct state by concatenating all observations.

        When ``use_cuda`` is set the observations are (possibly CUDA) torch
        tensors, which ``np.concatenate`` cannot consume, so ``torch.cat`` is
        used instead; numpy observations are concatenated with numpy.
        """
        if obs_list and isinstance(obs_list[0], torch.Tensor):
            return torch.cat(list(obs_list))
        return np.concatenate(obs_list)

    def _preprocess_obs_list(self, obs_list):
        """
        VMAS outputs everything as tensors with extra dimension for parallel envs.
        Flatten each agent's observation and, unless absolute positions are
        requested, drop its first 2 elements (the absolute position).
        """
        flattened_obs_list = self._flatten_tensor_list(obs_list)
        if not self._include_absolute_positions:
            return [obs[2:] for obs in flattened_obs_list]
        return flattened_obs_list

    def _expand_actions(self, action):
        """
        VMAS needs actions in the shape (batch_dim, action_dim).
        Since we only use 1 env at a time, add a dimension at axis=0.

        Torch tensors are unsqueezed (keeping their device); anything else
        (e.g. np.ndarray) is expanded with numpy, regardless of ``use_cuda``.

        Args:
            action (list): one action per agent.

        Returns:
            list: the per-agent actions with a leading dimension of size 1.
        """
        return [
            torch.unsqueeze(a, dim=0)
            if isinstance(a, torch.Tensor)
            else np.expand_dims(a, axis=0)
            for a in action
        ]

    def _flatten_tensor_list(self, tensor_list):
        """
        VMAS outputs everything as tensors with extra dimension for parallel envs.
        Flatten each tensor; on cuda keep torch tensors, otherwise convert to numpy.
        """
        if self.use_cuda:
            return [item.flatten() for item in tensor_list]
        return [item.detach().cpu().numpy().flatten() for item in tensor_list]

    @staticmethod
    def _flatten_twice_tensor_list(tensor_list):
        """
        VMAS outputs everything as tensors with extra dimension for parallel envs.
        Return the first element of each flattened tensor as a numpy scalar
        (e.g. one reward per agent).
        """
        return [item.detach().cpu().numpy().flatten()[0] for item in tensor_list]

    @staticmethod
    def _set_horizon(env, horizon):
        """
        Resolve the episode horizon and disable VMAS's own step limit.

        Args:
            env: the VMAS environment.
            horizon (int or None): explicit horizon; if None, ``env.max_steps`` is used.

        Returns:
            int: the resolved horizon.

        Raises:
            RuntimeError: if no horizon is given and ``env.max_steps`` is missing or None.
        """
        if horizon is None:
            # VMAS environments may define ``max_steps`` but leave it as None,
            # so checking for the attribute alone is not enough.
            horizon = getattr(env, "max_steps", None)
            if horizon is None:
                raise RuntimeError("This VMAS environment has no specified time limit!")

        if hasattr(env, "max_steps"):
            # Hack: turn off VMAS's internal time limit; the horizon is passed to MDPInfo instead.
            env.max_steps = None
        return horizon
