#  Copyright (c) 2023.
#  ProrokLab (https://www.proroklab.org/)
#  All rights reserved.
from typing import Dict, Callable

import torch
import copy
import numpy as np
from torch import Tensor
from tensordict.tensordict import TensorDict
from torch.distributions import MultivariateNormal

from vmas import render_interactively
from vmas.simulator.core import World, Line, Agent, Sphere, Entity
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.sensors import Lidar
from vmas.simulator.utils import Color, X, Y


class Scenario(BaseScenario):
    """Multi-agent sampling scenario on a bounded, gridded 2-D arena.

    The arena ``[-xdim, xdim] x [-ydim, ydim]`` is split into square cells of
    side ``grid_spacing``. A mixture of ``n_gaussians`` isotropic Gaussians
    defines a density over the arena. An agent collects the (normalised)
    density at its position, and a cell stops paying out once it has been
    flagged as sampled.

    Besides the standard VMAS hooks, the class exposes helpers to snapshot
    and restore the environment (``replay_info`` / ``set_env_to_obs``) and to
    flatten observations/actions into joint inputs.
    """

    def make_world(self, batch_dim: int, device: torch.device, seed: int, **kwargs):
        """Build the world, the agents and the per-environment buffers.

        Args:
            batch_dim: Number of vectorised environments.
            device: Torch device all tensors are allocated on.
            seed: Seed for the scenario's private ``torch.Generator``.
            **kwargs: Optional overrides (defaults in parentheses):
                ``n_agents`` (3), ``shared_rew`` (True), ``state_dim`` (12),
                ``action_dim`` (1), ``comms_range`` (0.0),
                ``lidar_range`` (0.0), ``agent_radius`` (0.025),
                ``xdim`` / ``ydim`` (1), ``grid_spacing`` (0.05),
                ``n_gaussians`` (3), ``cov`` (0.05).

        Returns:
            The constructed ``World``.
        """
        self.n_agents = kwargs.get("n_agents", 3)
        self.device = device
        # When True, ``reward`` hands every agent the team total.
        self.shared_rew = kwargs.get("shared_rew", True)
        # Private RNG used for all reset-time sampling.
        self.generator = torch.Generator(self.device).manual_seed(seed)
        # Per-agent feature sizes used by the joint-input helpers.
        # NOTE(review): ``observation`` emits 13 + 2 * n_gaussians features
        # per agent (19 with the defaults), which differs from the default
        # state_dim of 12 -- confirm callers pass a matching ``state_dim``.
        self.state_dim = kwargs.get("state_dim", 12)
        self.action_dim = kwargs.get("action_dim", 1)
        self.input_dim = (self.state_dim + self.action_dim) * self.n_agents
        self.pos_vel_dim = 4  # for the position and velocity

        self.comms_range = kwargs.get("comms_range", 0.0)
        self.lidar_range = kwargs.get("lidar_range", 0.0)  # used to be 0.2
        self.agent_radius = kwargs.get("agent_radius", 0.025)
        self.xdim = kwargs.get("xdim", 1)
        self.ydim = kwargs.get("ydim", 1)
        self.grid_spacing = kwargs.get("grid_spacing", 0.05)
        # Offsets to the 8 neighbouring cells probed in ``observation``.
        self.deltas = [
            [self.grid_spacing, 0],
            [-self.grid_spacing, 0],
            [0, self.grid_spacing],
            [0, -self.grid_spacing],
            [-self.grid_spacing, -self.grid_spacing],
            [self.grid_spacing, -self.grid_spacing],
            [-self.grid_spacing, self.grid_spacing],
            [self.grid_spacing, self.grid_spacing],
        ]

        self.n_gaussians = kwargs.get("n_gaussians", 3)
        # Per-axis variance of every (isotropic) Gaussian.
        self.cov = kwargs.get("cov", 0.05)

        # The arena must split into a whole number of cells along each axis.
        assert (self.xdim / self.grid_spacing) % 1 == 0 and (
            self.ydim / self.grid_spacing
        ) % 1 == 0

        self.plot_grid = False
        self.n_x_cells = int((2 * self.xdim) / self.grid_spacing)
        self.n_y_cells = int((2 * self.ydim) / self.grid_spacing)
        # Per-environment normaliser for the density (see ``normalize_pdf``).
        self.max_pdf = torch.zeros((batch_dim,), device=device, dtype=torch.float32)
        self.alpha_plot: float = 0.5

        # Make world. The semi-dimensions are shrunk by the agent radius so
        # that agents stay fully inside the arena.
        world = World(
            batch_dim,
            device,
            x_semidim=self.xdim - self.agent_radius,
            y_semidim=self.ydim - self.agent_radius,
        )
        entity_filter_agents: Callable[[Entity], bool] = lambda e: isinstance(e, Agent)
        for i in range(self.n_agents):
            agent = Agent(
                name=f"agent {i}",
                render_action=True,
                collide=True,
                shape=Sphere(radius=self.agent_radius),
                sensors=[
                    Lidar(
                        world,
                        angle_start=0.05,
                        angle_end=2 * torch.pi + 0.05,
                        n_rays=12,
                        max_range=self.lidar_range,
                        entity_filter=entity_filter_agents,
                    ),
                ],
            )

            world.add_agent(agent)

        # sampled[env, ix, iy] is True once cell (ix, iy) has been collected.
        self.sampled = torch.zeros(
            (batch_dim, self.n_x_cells, self.n_y_cells),
            device=device,
            dtype=torch.bool,
        )

        # One (batch_dim, dim_p) tensor of centres per Gaussian.
        self.locs = [
            torch.zeros((batch_dim, world.dim_p), device=device, dtype=torch.float32)
            for _ in range(self.n_gaussians)
        ]
        self.cov_matrix = torch.tensor(
            [[self.cov, 0], [0, self.cov]], dtype=torch.float32, device=device
        ).expand(batch_dim, world.dim_p, world.dim_p)

        return world

    def reset_world_at(self, env_index: int = None):
        """Re-draw the Gaussian centres and agent positions.

        Centres and agent positions are drawn uniformly over the arena from
        ``self.generator``.

        Args:
            env_index: Environment to reset, or None to reset all of them.

        Side effects: rebuilds ``self.gaussians``, clears the sampled flags
        and recomputes ``max_pdf`` for the reset environment(s), and refreshes
        ``agent.sample`` for every agent.
        """
        for i, loc in enumerate(self.locs):
            x = torch.zeros(
                (1,) if env_index is not None else (self.world.batch_dim, 1),
                device=self.world.device,
                dtype=torch.float32,
            ).uniform_(-self.xdim, self.xdim, generator=self.generator)
            y = torch.zeros(
                (1,) if env_index is not None else (self.world.batch_dim, 1),
                device=self.world.device,
                dtype=torch.float32,
            ).uniform_(-self.ydim, self.ydim, generator=self.generator)
            new_loc = torch.cat([x, y], dim=-1)
            if env_index is None:
                self.locs[i] = new_loc
            else:
                self.locs[i][env_index] = new_loc

        self.gaussians = [
            MultivariateNormal(
                loc=loc,
                covariance_matrix=self.cov_matrix,
            )
            for loc in self.locs
        ]

        if env_index is None:
            self.max_pdf[:] = 0
            self.sampled[:] = False
        else:
            self.max_pdf[env_index] = 0
            self.sampled[env_index] = False
        self.normalize_pdf(env_index=env_index)

        for agent in self.world.agents:
            agent.set_pos(
                torch.cat(
                    [
                        torch.zeros(
                            (1, 1)
                            if env_index is not None
                            else (self.world.batch_dim, 1),
                            device=self.world.device,
                            dtype=torch.float32,
                        ).uniform_(-self.xdim, self.xdim, generator=self.generator),
                        torch.zeros(
                            (1, 1)
                            if env_index is not None
                            else (self.world.batch_dim, 1),
                            device=self.world.device,
                            dtype=torch.float32,
                        ).uniform_(-self.ydim, self.ydim, generator=self.generator),
                    ],
                    dim=-1,
                ),
                batch_index=env_index,
            )
            agent.sample = self.sample(agent.state.pos)

    def _out_of_bounds_and_clamp(self, pos):
        """Return the out-of-arena mask for ``pos``, then clamp ``pos`` in place.

        The mask is computed against ``[-xdim, xdim] x [-ydim, ydim]`` *before*
        clamping. ``pos`` is then clamped to the world's semi-dimensions
        (``xdim - agent_radius`` / ``ydim - agent_radius``), which keeps grid
        indices derived from it inside the grid when ``agent_radius > 0``.

        NOTE: this mutates the caller's tensor (e.g. ``agent.state.pos``).
        """
        out_of_bounds = (
            (pos[:, X] < -self.xdim)
            + (pos[:, X] > self.xdim)
            + (pos[:, Y] < -self.ydim)
            + (pos[:, Y] > self.ydim)
        )
        pos[:, X].clamp_(-self.world.x_semidim, self.world.x_semidim)
        pos[:, Y].clamp_(-self.world.y_semidim, self.world.y_semidim)
        return out_of_bounds

    def sample(
        self,
        pos,
        update_sampled_flag: bool = False,
        get_sample_value: bool = False,
        get_seen_value: bool = False,
        env_indices: torch.tensor = None,
        norm: bool = True,
    ):
        """Evaluate the Gaussian-mixture density at one position per environment.

        Args:
            pos: One row per environment, indexed as ``pos[:, X]`` /
                ``pos[:, Y]``. It is clamped in place to the world bounds.
            update_sampled_flag: If True, mark the cells at ``pos`` as sampled
                for ``env_indices``. This happens after the returned values
                are computed.
            get_sample_value: If True, zero only out-of-bounds positions, even
                if the cell was already sampled.
            get_seen_value: If True (and ``get_sample_value`` is False), zero
                positions whose cell has *not* been sampled yet, or that are
                out of bounds.
            env_indices: Environments whose sampled flags are updated. When
                None, all environments are updated.
            norm: Divide by the per-environment ``max_pdf``.

        Returns:
            One density value per environment. With the default flags,
            already-sampled and out-of-bounds cells give 0.
        """
        out_of_bounds = self._out_of_bounds_and_clamp(pos)
        index = self.get_index_from_pos(pos)

        v = torch.stack(
            [gaussian.log_prob(pos).exp() for gaussian in self.gaussians], dim=-1
        ).sum(-1)
        if norm:
            v = v / self.max_pdf

        sampled = self.sampled[
            torch.arange(self.world.batch_dim), index[:, 0], index[:, 1]
        ]

        # Boolean ``+`` acts as logical OR on bool tensors.
        if get_sample_value:
            v[out_of_bounds] = 0
        elif get_seen_value:
            v[~sampled + out_of_bounds] = 0
        else:
            v[sampled + out_of_bounds] = 0

        if update_sampled_flag:
            if env_indices is None:
                env_indices = torch.arange(self.world.batch_dim)
            self.sampled[env_indices, index[env_indices, 0], index[env_indices, 1]] = True

        return v

    def sample_single_env(
        self,
        pos,
        env_index,
        norm: bool = True,
    ):
        """Evaluate the density at any number of positions for one environment.

        Args:
            pos: Positions, viewed as ``(-1, dim_p)``. The view is clamped in
                place, so the caller's tensor is modified.
            env_index: Environment whose Gaussians, flags and normaliser are
                used.
            norm: Divide by ``max_pdf[env_index]``.

        Returns:
            One value per position. Already-sampled and out-of-bounds cells
            give 0.
        """
        pos = pos.view(-1, self.world.dim_p)

        out_of_bounds = self._out_of_bounds_and_clamp(pos)
        index = self.get_index_from_pos(pos)

        # The Gaussians are batched over all environments. Evaluate every
        # position in every environment, then keep the requested column.
        pos = pos.unsqueeze(1).expand(pos.shape[0], self.world.batch_dim, 2)

        v = torch.stack(
            [gaussian.log_prob(pos).exp() for gaussian in self.gaussians], dim=-1
        ).sum(-1)[:, env_index]
        if norm:
            v = v / self.max_pdf[env_index]

        sampled = self.sampled[env_index, index[:, 0], index[:, 1]]

        v[sampled + out_of_bounds] = 0

        return v

    def sampled_at_pos(self, pos):
        """Return -1 where the cell at ``pos`` is in bounds and unsampled, else 0.

        ``pos`` has one row per environment and is clamped in place.
        """
        out_of_bounds = self._out_of_bounds_and_clamp(pos)
        index = self.get_index_from_pos(pos)

        sampled = self.sampled[
            torch.arange(self.world.batch_dim), index[:, 0], index[:, 1]
        ]

        res = -torch.ones(self.world.batch_dim, device=pos.device)
        res[out_of_bounds + sampled] = 0
        return res

    def normalize_pdf(self, env_index: int = None):
        """Update ``max_pdf`` with the peak unnormalised density over the grid.

        The density is evaluated at every grid point ``(-xdim + k * spacing,
        -ydim + l * spacing)``. Points are clamped to the world bounds as in
        ``sample``, and already-sampled cells count as 0.

        Args:
            env_index: If given, ``max_pdf[env_index]`` is overwritten with
                that environment's peak. If None, ``max_pdf`` becomes the
                element-wise maximum of its current value and each
                environment's peak.
        """
        xpoints = torch.arange(
            -self.xdim, self.xdim, self.grid_spacing, device=self.world.device
        )
        ypoints = torch.arange(
            -self.ydim, self.ydim, self.grid_spacing, device=self.world.device
        )
        ygrid, xgrid = torch.meshgrid(ypoints, xpoints, indexing="ij")
        grid = torch.stack((xgrid, ygrid), dim=-1).reshape(-1, 2)

        if env_index is not None:
            sample = self.sample_single_env(grid, env_index, norm=False)
            self.max_pdf[env_index] = sample.max()
            return

        # Evaluate all grid points in all environments in one batched call,
        # instead of one full ``sample`` call per grid point.
        out_of_bounds = self._out_of_bounds_and_clamp(grid)  # (n_points,)
        index = self.get_index_from_pos(grid)  # (n_points, 2)
        expanded = grid.unsqueeze(1).expand(grid.shape[0], self.world.batch_dim, 2)
        density = torch.stack(
            [gaussian.log_prob(expanded).exp() for gaussian in self.gaussians],
            dim=-1,
        ).sum(-1)  # (n_points, batch_dim)
        # (batch_dim, n_points) -> (n_points, batch_dim)
        sampled = self.sampled[:, index[:, 0], index[:, 1]].transpose(0, 1)
        density[sampled + out_of_bounds.unsqueeze(-1)] = 0
        self.max_pdf = torch.maximum(self.max_pdf, density.max(dim=0).values)

    def reward(self, agent: Agent, update_sampled_flag: bool = True) -> Tensor:
        """Return the sampling reward for ``agent``.

        The call for the first agent refreshes every agent's ``sample`` (by
        default also flagging their cells as sampled) and caches the team sum
        in ``sampling_rew``. Later calls in the same step reuse those values.
        Returns the team sum when ``shared_rew`` is True, else ``agent.sample``.
        """
        is_first = self.world.agents.index(agent) == 0
        if is_first:
            for a in self.world.agents:
                a.sample = self.sample(a.state.pos, update_sampled_flag=update_sampled_flag)
            self.sampling_rew = torch.stack(
                [a.sample for a in self.world.agents], dim=-1
            ).sum(-1)

        return self.sampling_rew if self.shared_rew else agent.sample

    def observation(self, agent: Agent) -> Tensor:
        """Build ``agent``'s observation by concatenating, along the last dim:

        - position and velocity;
        - the density at its own position, zeroed only if out of bounds;
        - the density in each of the 8 neighbouring cells, zeroed if that
          cell is sampled or out of bounds;
        - the centre of every Gaussian.

        NOTE: ``sample`` clamps ``agent.state.pos`` in place before the
        position is read here.
        """
        agent_sample = self.sample(agent.state.pos, get_sample_value=True, update_sampled_flag=False)
        observations = [agent.state.pos, agent.state.vel, agent_sample.unsqueeze(-1)]  # , agent.sensors[0].measure()]

        for delta in self.deltas:
            pos = agent.state.pos + torch.tensor(
                delta,
                device=self.world.device,
                dtype=torch.float32,
            )
            sample = self.sample(
                pos,
                update_sampled_flag=False,
            ).unsqueeze(-1)
            observations.append(sample)

        for loc in self.locs:  # include the locs in order to learn the reward structure
            observations.append(loc)

        return torch.cat(observations, dim=-1)

    def info(self, agent: Agent) -> Dict[str, Tensor]:
        """Expose the agent's most recent sample value."""
        return {"agent_sample": agent.sample}

    def replay_info(self):
        """Snapshot the state needed by ``set_env_to_obs``.

        Returns cloned tensors:
        - ``loc``: Gaussian centres, ``(batch_dim, n_gaussians, dim_p)``;
        - ``sampled``: the sampled flags;
        - ``max_pdf``: the per-environment normaliser;
        - ``steps``: the step counter.

        NOTE(review): ``self.steps`` is only assigned in ``set_env_to_obs``
        within this class. Confirm it is initialised elsewhere before the
        first call, otherwise this raises ``AttributeError``.
        """
        return {"loc": torch.stack(self.locs).permute(1, 0, 2).clone(), "sampled": self.sampled.clone(), "max_pdf": self.max_pdf.clone(), "steps": self.steps.clone()}

    def set_env_to_obs(self, prev_obs: torch.Tensor, replay_info: TensorDict):
        """Restore the scenario from a ``replay_info`` snapshot and observations.

        Args:
            prev_obs: Indexed as ``prev_obs[:, agent_idx, feature]``. Features
                0:2 are used as the position and 2:4 as the velocity.
            replay_info: Mapping with the keys produced by ``replay_info``.
        """
        prev_locs = replay_info.get("loc")
        prev_max_pdf = replay_info.get("max_pdf")
        prev_sampled = replay_info.get("sampled")
        prev_steps = replay_info.get("steps")

        # Set the sample distribution to match the previous version. Store a
        # *list* of independent per-Gaussian copies. A bare permuted view
        # would alias the caller's replay data, so later in-place writes in
        # ``reset_world_at`` would corrupt it. It would also break
        # ``torch.stack(self.locs)`` in ``replay_info``, which needs a
        # sequence of tensors.
        restored_locs = prev_locs.squeeze(dim=1).permute(1, 0, 2).to(self.world.device)
        self.locs = [loc.clone() for loc in restored_locs]
        self.gaussians = [
            MultivariateNormal(
                loc=loc,
                covariance_matrix=self.cov_matrix,
            )
            for loc in self.locs
        ]

        # Reset the world
        self.max_pdf[:] = prev_max_pdf
        self.sampled[:] = prev_sampled
        self.steps = prev_steps

        # Set agents to the specified position/velocity
        for i, agent in enumerate(self.world.agents):
            agent.set_pos(prev_obs[:, i, 0:2], None)
            agent.set_vel(prev_obs[:, i, 2:4], None)
            agent.set_rot(torch.zeros((1, 1), device=self.world.device), None)
            agent.set_ang_vel(torch.zeros((1, 1), device=self.world.device), None)
            agent.sample = self.sample(agent.state.pos)

    def get_input_from_obs(self, obs):
        """Flatten per-agent observations to ``(obs.shape[0], n_agents * state_dim)``."""
        return obs.reshape(obs.shape[0], self.n_agents * self.state_dim)

    def get_joint_input(self, obs, action):
        """Concatenate flattened observations and flattened actions along dim 1."""
        reshaped_obs = self.get_input_from_obs(obs)
        reshaped_action = action.reshape(action.shape[0], self.n_agents * self.action_dim)
        return torch.cat((reshaped_obs, reshaped_action), dim=1)

    def get_pos_vel_from_obs(self, obs):
        """Return the first ``pos_vel_dim`` features of each agent, flattened.

        ``obs`` is indexed as ``obs[:, agent, feature]``.
        """
        return obs[:, :, :self.pos_vel_dim].reshape(obs.shape[0], self.n_agents * self.pos_vel_dim)

    def set_agent_pos_vel(self, obs):
        """Set the first agent's position (features 0:2) and velocity (2:4)."""
        # TODO: support more agents
        agent = self.world.agents[0]
        agent.set_pos(obs[:, 0:2], None)
        agent.set_vel(obs[:, 2:4], None)

    def get_perturbed_obs(self, obs_mean, beta, stddev):
        """Perturb ``obs_mean``, apply it to the first agent and observe.

        The perturbation is ``beta * stddev * U(-1, 1)``, clamped to
        ``+/- grid_spacing`` per element.

        NOTE(review): noise comes from numpy's global RNG, not
        ``self.generator``, so it is not controlled by the scenario seed.
        """
        # Add plausible displacement to agents' state (pos/vel)
        agent_etas = torch.tensor(
            np.random.uniform(low=-1, high=1, size=(obs_mean.shape)),
            dtype=torch.float32,
        ).to(self.device)
        perturbation = torch.clamp(beta * stddev * agent_etas, min=-self.grid_spacing, max=self.grid_spacing)
        agent_obs = obs_mean + perturbation
        self.set_agent_pos_vel(agent_obs)
        return self.observation(self.world.agents[0])  # TODO: support more agents

    def get_index_from_obs(self, obs):
        """Grid cell index for a single observation.

        ``obs`` must hold exactly ``state_dim`` elements. Its first two
        elements are taken as the position. No clamping is applied.
        """
        obs = obs.reshape(self.state_dim)
        pos = obs[:2]
        index = pos / self.grid_spacing
        index[X] += self.n_x_cells / 2
        index[Y] += self.n_y_cells / 2
        index = index.to(torch.long)
        return index

    def get_index_from_pos(self, pos):
        """Convert positions ``pos[:, X/Y]`` to integer grid cell indices.

        No clamping is done; the cast to long truncates toward zero.
        """
        index = pos / self.grid_spacing
        index[:, X] += self.n_x_cells / 2
        index[:, Y] += self.n_y_cells / 2
        index = index.to(torch.long)
        return index

    def density_for_plot(self, env_index):
        """Return a callable mapping points to normalised densities for ``env_index``."""
        def f(x):
            sample = self.sample_single_env(
                torch.tensor(x, dtype=torch.float32, device=self.world.device),
                env_index=env_index,
            )
            return sample
        return f

    def extra_render(self, env_index: int = 0):
        """Return extra render geometry for ``env_index``.

        This includes the density heat-map, a line between every agent pair
        within ``comms_range``, and the four arena walls.
        """
        from vmas.simulator import rendering
        from vmas.simulator.rendering import render_function_util

        # Function
        geoms = [
            render_function_util(
                f=self.density_for_plot(env_index=env_index),
                plot_range=(self.xdim, self.ydim),
                cmap_alpha=self.alpha_plot,
            )
        ]

        # Communication lines
        for i, agent1 in enumerate(self.world.agents):
            for j, agent2 in enumerate(self.world.agents):
                if j <= i:
                    continue
                agent_dist = torch.linalg.vector_norm(
                    agent1.state.pos - agent2.state.pos, dim=-1
                )
                if agent_dist[env_index] <= self.comms_range:
                    color = Color.BLACK.value
                    line = rendering.Line(
                        (agent1.state.pos[env_index]),
                        (agent2.state.pos[env_index]),
                        width=1,
                    )
                    xform = rendering.Transform()
                    line.add_attr(xform)
                    line.set_color(*color)
                    geoms.append(line)

        # Perimeter: i = 0/2 are the right/left walls (rotated vertical),
        # i = 1/3 are the top/bottom walls.
        for i in range(4):
            geom = Line(
                length=2
                * ((self.ydim if i % 2 == 0 else self.xdim) - self.agent_radius)
                + self.agent_radius * 2
            ).get_geometry()
            xform = rendering.Transform()
            geom.add_attr(xform)

            xform.set_translation(
                0.0
                if i % 2
                else (
                    self.world.x_semidim + self.agent_radius
                    if i == 0
                    else -self.world.x_semidim - self.agent_radius
                ),
                0.0
                if not i % 2
                else (
                    self.world.y_semidim + self.agent_radius
                    if i == 1
                    else -self.world.y_semidim - self.agent_radius
                ),
            )
            xform.set_rotation(torch.pi / 2 if not i % 2 else 0.0)
            color = Color.BLACK.value
            if isinstance(color, torch.Tensor) and len(color.shape) > 1:
                color = color[env_index]
            geom.set_color(*color)
            geoms.append(geom)

        return geoms


if __name__ == "__main__":
    # Launch VMAS's interactive viewer on this scenario file.
    render_interactively(
        __file__,
        control_two_agents=True,
    )