#  Copyright (c) 2022-2023.
#  ProrokLab (https://www.proroklab.org/)
#  All rights reserved.
import typing
from typing import Dict, Callable, List
import numpy as np

import torch
from torch import Tensor
from tensordict.tensordict import TensorDict
from vmas import render_interactively
from vmas.simulator.core import Agent, Box, Landmark, World, Sphere, Entity
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.sensors import Lidar
from vmas.simulator.utils import Color, ScenarioUtils

if typing.TYPE_CHECKING:
    from vmas.simulator.rendering import Geom


class Scenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, seed: int, **kwargs):
        """Create the world with agents, one goal landmark per agent, and obstacles.

        Args:
            batch_dim: Batch size forwarded to ``World`` and used to size the
                per-environment reward buffers.
            device: Device forwarded to ``World`` and used for all tensors
                allocated here.
            seed: Seed of the scenario-owned ``torch.Generator`` that
                ``reset_world_at`` hands to the placement helpers.
            **kwargs: Optional overrides; see the ``kwargs.get`` calls below
                for the recognised keys and their defaults.

        Returns:
            The populated ``World`` instance.

        Raises:
            AssertionError: If ``agents_with_same_goal`` is out of range, if
                goals are shared while ``collisions`` is enabled, or if
                ``split_goals`` is requested for an incompatible team layout.
        """
        self.generator = torch.Generator(device).manual_seed(seed)
        self.plot_grid = False
        self.n_agents = kwargs.get("n_agents", 4)
        self.n_obstacles = kwargs.get("n_obstacles", 2)
        self.collisions = kwargs.get("collisions", True)

        # Sizes used by the joint-input / observation slicing helpers.
        self.state_dim = kwargs.get("state_dim", 6)
        self.action_dim = kwargs.get("action_dim", 2)
        self.input_dim = (self.state_dim + self.action_dim) * self.n_agents
        self.pos_vel_dim = 4  # for the position and velocity

        self.agents_with_same_goal = kwargs.get("agents_with_same_goal", 1)
        self.split_goals = kwargs.get("split_goals", False)
        self.observe_all_goals = kwargs.get("observe_all_goals", False)

        self.lidar_range = kwargs.get("lidar_range", 0)
        self.agent_radius = kwargs.get("agent_radius", 0.1)
        self.comms_range = kwargs.get("comms_range", 0)

        self.shared_rew = kwargs.get("shared_rew", True)
        self.pos_shaping_factor = kwargs.get("pos_shaping_factor", 1)
        self.final_reward = kwargs.get("final_reward", 0.01)

        self.agent_collision_penalty = kwargs.get("agent_collision_penalty", -1)

        self.min_distance_between_entities = self.agent_radius * 2 + 0.05
        self.world_semidim = 1
        self.min_collision_distance = 0.005

        assert 1 <= self.agents_with_same_goal <= self.n_agents
        if self.agents_with_same_goal > 1:
            assert (
                not self.collisions
            ), "If agents share goals they cannot be collidables"
        # agents_with_same_goal == n_agents: all agent same goal
        # agents_with_same_goal = x: the first x agents share the goal
        # agents_with_same_goal = 1: all independent goals
        if self.split_goals:
            assert (
                self.n_agents % 2 == 0
                and self.agents_with_same_goal == self.n_agents // 2
            ), "Splitting the goals is allowed when the agents are even and half the team has the same goal"

        # Make world
        world = World(batch_dim, device, substeps=2)

        known_colors = [
            (0.22, 0.49, 0.72),
            (1.00, 0.50, 0),
            (0.30, 0.69, 0.29),
            (0.97, 0.51, 0.75),
            (0.60, 0.31, 0.64),
            (0.89, 0.10, 0.11),
            (0.87, 0.87, 0),
        ]
        # Agents beyond the fixed palette get random colors. Sample uniformly
        # from [0, 1) so every component is a valid RGB value; a normal
        # distribution would yield negative and >1 components.
        colors = torch.rand(
            (max(self.n_agents - len(known_colors), 0), 3), device=device
        )

        # Add agents, each paired with a goal landmark of the same color
        for i in range(self.n_agents):
            color = (
                known_colors[i]
                if i < len(known_colors)
                else colors[i - len(known_colors)]
            )

            # Constraint: all agents have same action range and multiplier
            # NOTE(review): agents are always created collidable, independent
            # of the `collisions` setting -- confirm this is intended.
            # Lidar sensing is currently disabled, so `lidar_range` is unused here.
            agent = Agent(
                name=f"agent {i}",
                collide=True,
                color=color,
                shape=Sphere(radius=self.agent_radius),
                render_action=True,
            )
            agent.pos_rew = torch.zeros(batch_dim, device=device)
            agent.agent_collision_rew = agent.pos_rew.clone()
            world.add_agent(agent)

            # Add goals
            goal = Landmark(
                name=f"goal {i}",
                collide=False,
                color=color,
            )
            world.add_landmark(goal)
            agent.goal = goal

        # Static, collidable obstacles with the same radius as the agents
        self.obstacles = []
        for i in range(self.n_obstacles):
            obstacle = Landmark(
                name=f"obstacle {i}",
                collide=True,
                movable=False,
                mass=50,
                shape=Sphere(radius=self.agent_radius),
                color=Color.RED,
            )
            world.add_landmark(obstacle)
            self.obstacles.append(obstacle)

        # Team-level reward buffers, one entry per batched environment
        self.pos_rew = torch.zeros(batch_dim, device=device)
        self.final_rew = self.pos_rew.clone()

        return world

    def reset_world_at(self, env_index: int = None):
        """Re-place agents, obstacles and goals, then reset the shaping baseline.

        Args:
            env_index: Index of the single environment to reset, or ``None``
                to reset every environment in the batch.
        """
        bounds = (-self.world_semidim, self.world_semidim)
        min_gap = self.min_distance_between_entities

        # Agents and obstacles are placed together so that they keep the
        # minimum spacing from one another.
        placed = self.world.agents + self.obstacles
        ScenarioUtils.spawn_entities_randomly(
            placed,
            self.world,
            env_index,
            min_gap,
            bounds,
            bounds,
            generator=self.generator,
        )

        taken = torch.stack([item.state.pos for item in placed], dim=1)
        if env_index is not None:
            taken = taken[env_index].unsqueeze(0)

        # Draw one candidate goal position per agent; each new goal is added
        # to the occupied set before the next one is drawn.
        goal_positions = []
        for _ in self.world.agents:
            candidate = ScenarioUtils.find_random_pos_for_entity(
                occupied_positions=taken,
                env_index=env_index,
                world=self.world,
                min_dist_between_entities=min_gap,
                x_bounds=bounds,
                y_bounds=bounds,
                generator=self.generator,
            )
            goal_positions.append(candidate.squeeze(1))
            taken = torch.cat([taken, candidate], dim=1)

        for idx, agent in enumerate(self.world.agents):
            # Map each agent to the candidate goal it should use.
            if self.split_goals:
                chosen = int(idx // self.agents_with_same_goal)
            elif idx < self.agents_with_same_goal:
                chosen = 0
            else:
                chosen = idx

            agent.goal.set_pos(goal_positions[chosen], batch_index=env_index)

            # Baseline for the shaping reward: scaled distance to the goal.
            if env_index is None:
                offset = agent.state.pos - agent.goal.state.pos
                agent.pos_shaping = (
                    torch.linalg.vector_norm(offset, dim=1) * self.pos_shaping_factor
                )
            else:
                offset = agent.state.pos[env_index] - agent.goal.state.pos[env_index]
                agent.pos_shaping[env_index] = (
                    torch.linalg.vector_norm(offset) * self.pos_shaping_factor
                )

    def reward(self, agent: Agent, update_sampled_flag=False):
        is_first = agent == self.world.agents[0]

        if is_first:
            self.pos_rew[:] = 0
            self.final_rew[:] = 0

            for a in self.world.agents:
                self.pos_rew += self.agent_reward(a)
                a.agent_collision_rew[:] = 0

            self.all_goal_reached = torch.all(
                torch.stack([a.on_goal for a in self.world.agents], dim=-1), dim=-1
            )

            self.final_rew[self.all_goal_reached] = self.final_reward

            for i, a in enumerate(self.world.agents):
                for j, b in enumerate(self.world.agents):
                    if i <= j:
                        continue
                    if self.world.collides(a, b):
                        distance = self.world.get_distance(a, b)
                        a.agent_collision_rew[
                            distance <= self.min_collision_distance
                        ] += self.agent_collision_penalty
                        b.agent_collision_rew[
                            distance <= self.min_collision_distance
                        ] += self.agent_collision_penalty

        pos_reward = self.pos_rew if self.shared_rew else agent.pos_rew
        return pos_reward + self.final_rew + agent.agent_collision_rew

    def agent_reward(self, agent: Agent):
        agent.distance_to_goal = torch.linalg.vector_norm(
            agent.state.pos - agent.goal.state.pos,
            dim=-1,
        )
        agent.on_goal = agent.distance_to_goal < agent.goal.shape.radius

        pos_shaping = agent.distance_to_goal * self.pos_shaping_factor
        agent.pos_rew = agent.pos_shaping - pos_shaping
        agent.pos_shaping = pos_shaping
        return agent.pos_rew

    def observation(self, agent: Agent):
        goal_poses = []
        if self.observe_all_goals:
            for a in self.world.agents:
                goal_poses.append(agent.state.pos - a.goal.state.pos)
        else:
            # goal_poses.append(agent.state.pos - agent.goal.state.pos)
            goal_poses.append(agent.goal.state.pos)
        return torch.cat(
            [
                agent.state.pos,
                agent.state.vel,
            ]
            + goal_poses,
            # + (
            #     [agent.sensors[0]._max_range - agent.sensors[0].measure()]
            #     if self.collisions
            #     else []
            # ),
            dim=-1,
        )

    def done(self):
        return torch.stack(
            [
                torch.linalg.vector_norm(
                    agent.state.pos - agent.goal.state.pos,
                    dim=-1,
                )
                < agent.shape.radius
                for agent in self.world.agents
            ],
            dim=-1,
        ).all(-1)

    def info(self, agent: Agent) -> Dict[str, Tensor]:
        return {
            "pos_rew": self.pos_rew if self.shared_rew else agent.pos_rew,
            "final_rew": self.final_rew,
            "agent_collisions": agent.agent_collision_rew,
        }

    def replay_info(self):
        obstacles_pos = torch.stack([obstacle.state.pos for obstacle in self.obstacles], dim=1)
        return {"agent_goal": self.world.agents[0].goal.state.pos.clone(), "steps": self.steps.clone(), "obstacles_pos": obstacles_pos}

    def set_env_to_obs(self, prev_obs: torch.Tensor, replay_info: TensorDict):
        agent = self.world.agents[0] # TODO: support more agents
        agent_goal = replay_info.get(("agent_goal"))
        prev_steps = replay_info.get(("steps"))

        # Set the obstacles and goal to match the previous version
        agent.goal.set_pos(agent_goal, None)
        if self.n_obstacles > 0:
            obstacles_pos = replay_info.get(("obstacles_pos"))
            for i, obstacle in enumerate(self.obstacles):
                obstacle.set_pos(obstacles_pos[:, i], None)
        self.steps = prev_steps

        # Set agents to the specified position/velocity
        for i in range(len(self.world.agents)):
            agent = self.world.agents[i]
            agent.set_pos(prev_obs[:, i, 0:2], None)
            agent.set_vel(prev_obs[:, i, 2:4], None)
            agent.set_rot(torch.zeros((1,1)), None)
            agent.set_ang_vel(torch.zeros((1,1)), None)

    def get_joint_input(self, obs, action):
        # Takes a joint obs [(agent_pos, agent_vel, delta_0_val, delta_1_val, ...) x n_agents, (action_1, action_2,...) x n_agents]
        reshaped_obs = obs.reshape(obs.shape[0], self.n_agents * self.state_dim)
        reshaped_action = action.reshape(action.shape[0], self.n_agents * self.action_dim)
        return torch.cat((reshaped_obs, reshaped_action), dim=1)

    def get_pos_vel_from_obs(self, obs):
        # Just return position/velocities for each agent
        return obs[:, :, :self.pos_vel_dim].reshape(obs.shape[0], self.n_agents * self.pos_vel_dim)

    def set_agent_pos_vel(self, obs):
        # TODO: support more agents
        agent = self.world.agents[0]
        agent.set_pos(obs[:, 0:2], None)
        agent.set_vel(obs[:, 2:4], None)

    def get_perturbed_obs(self, obs_mean, beta, stddev):
        # Add plausible displacement to agents' state (pos/vel)
        agent_etas = torch.tensor(
                np.random.uniform(low=-0.2, high=0.2, size=(obs_mean.shape)),
                dtype=torch.float32,
            ).to(obs_mean.device)
        perturbation = torch.clamp(beta * stddev * agent_etas, min=-0.2, max=0.2)
        agent_obs = obs_mean + perturbation
        self.set_agent_pos_vel(agent_obs)
        return self.observation(self.world.agents[0]) # TODO: support more agents

    def extra_render(self, env_index: int = 0) -> "List[Geom]":
        """Return line geoms joining agent pairs within communication range.

        Args:
            env_index: Which batched environment to read positions from.

        Returns:
            One black line per unordered agent pair whose distance in that
            environment is at most ``self.comms_range``.
        """
        from vmas.simulator import rendering

        links: List[Geom] = []
        agents = self.world.agents

        # Communication lines
        for idx, src in enumerate(agents):
            for dst in agents[idx + 1 :]:
                separation = torch.linalg.vector_norm(
                    src.state.pos - dst.state.pos, dim=-1
                )
                if separation[env_index] <= self.comms_range:
                    rgb = Color.BLACK.value
                    link = rendering.Line(
                        (src.state.pos[env_index]),
                        (dst.state.pos[env_index]),
                        width=1,
                    )
                    link.add_attr(rendering.Transform())
                    link.set_color(*rgb)
                    links.append(link)

        return links


if __name__ == "__main__":
    # Run this scenario file in the interactive renderer, passing
    # control_two_agents=True.
    render_interactively(__file__, control_two_agents=True)
