#  Copyright (c) 2022-2023.
#  ProrokLab (https://www.proroklab.org/)
#  All rights reserved.


import torch
import numpy as np

from vmas import render_interactively
from vmas.simulator.core import Agent, Box, Landmark, Sphere, World
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import Color


class Scenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """Build the world: agents, a goal marker and one movable package.

        Args:
            batch_dim: Passed to ``World`` as its batch dimension.
            device: Torch device passed to ``World``; also stored on the
                scenario as ``self.device``.
            **kwargs: Optional scenario settings, read with these defaults:
                ``n_agents`` (4), ``package_width`` (0.6),
                ``package_length`` (0.6), ``package_mass`` (50).

        Returns:
            The populated ``World``.
        """
        # Scenario configuration (caller overrides fall back to defaults).
        self.device = device
        self.n_agents = kwargs.get("n_agents", 4)
        self.package_width = kwargs.get("package_width", 0.6)
        self.package_length = kwargs.get("package_length", 0.6)
        self.package_mass = kwargs.get("package_mass", 50)

        # Scale applied to the package-to-goal distance in the shaping reward.
        self.shaping_factor = 100

        world = World(batch_dim, device, contact_margin=6e-3)

        # Agents: small colliding spheres, each tagged with its index.
        for agent_index in range(self.n_agents):
            world.add_agent(
                Agent(
                    name=f"agent {agent_index}",
                    shape=Sphere(0.03),
                    u_multiplier=0.5,
                    index=agent_index,
                    collide=True,
                )
            )

        # Goal: a non-colliding marker the package has to reach.
        goal = Landmark(
            name="goal",
            collide=False,
            shape=Sphere(radius=0.09),
            color=Color.LIGHT_GREEN,
        )
        world.add_landmark(goal)

        # Package: a movable, colliding hollow box that remembers its goal.
        package_shape = Box(
            length=self.package_length,
            width=self.package_width,
            hollow=True,
        )
        self.package = Landmark(
            name="package",
            collide=True,
            movable=True,
            mass=self.package_mass,
            shape=package_shape,
            color=Color.RED,
        )
        self.package.goal = goal
        world.add_landmark(self.package)

        # Sizes of the per-agent (pos/vel) and package (vel/pos) parts of the
        # condensed joint observation.
        self.local_dim = 4
        self.global_dim = 4

        return world
    
    def local_dim(self):
        # number of obs dimensions corresponding to agents' local values (agent position/velocity)
        return self.local_dim
    
    def global_dim(self):
        # number of obs dimensions corresponding to global values (package position/velocity)
        return self.global_dim

    def reset_world_at(self, env_index: int = None):
        package_pos = torch.zeros(
            (1, self.world.dim_p)
            if env_index is not None
            else (self.world.batch_dim, self.world.dim_p),
            device=self.world.device,
            dtype=torch.float32,
        ).uniform_(
            -1.0,
            1.0,
        )

        self.package.set_pos(
            package_pos,
            batch_index=env_index,
        )
        for i, agent in enumerate(self.world.agents):
            agent.set_pos(
                torch.cat(
                    [
                        torch.zeros(
                            (1, 1)
                            if env_index is not None
                            else (self.world.batch_dim, 1),
                            device=self.world.device,
                            dtype=torch.float32,
                        ).uniform_(
                            -self.package_length / 2 + agent.shape.radius,
                            self.package_length / 2 - agent.shape.radius,
                        ),
                        torch.zeros(
                            (1, 1)
                            if env_index is not None
                            else (self.world.batch_dim, 1),
                            device=self.world.device,
                            dtype=torch.float32,
                        ).uniform_(
                            -self.package_width / 2 + agent.shape.radius,
                            self.package_width / 2 - agent.shape.radius,
                        ),
                    ],
                    dim=1,
                )
                + package_pos,
                batch_index=env_index,
            )

        self.package.goal.set_pos(
            torch.zeros(
                (1, self.world.dim_p)
                if env_index is not None
                else (self.world.batch_dim, self.world.dim_p),
                device=self.world.device,
                dtype=torch.float32,
            ).uniform_(
                -1.0,
                1.0,
            ),
            batch_index=env_index,
        )

        if env_index is None:
            self.package.global_shaping = (
                torch.linalg.vector_norm(
                    self.package.state.pos - self.package.goal.state.pos, dim=1
                )
                * self.shaping_factor
            )
            self.package.on_goal = torch.zeros(
                self.world.batch_dim, dtype=torch.bool, device=self.world.device
            )
        else:
            self.package.global_shaping[env_index] = (
                torch.linalg.vector_norm(
                    self.package.state.pos[env_index]
                    - self.package.goal.state.pos[env_index]
                )
                * self.shaping_factor
            )
            self.package.on_goal[env_index] = False

    def reward(self, agent: Agent):
        """Return the shared shaping reward, one value per environment.

        The reward is recomputed only when called for the first agent in
        ``self.world.agents``; calls for any other agent return the value
        cached in ``self.rew`` by that first call.

        For environments where the package is not on the goal, the reward is
        the decrease in scaled package-to-goal distance since the previous
        computation (``global_shaping`` minus the new shaping value).

        Args:
            agent: The agent the reward is requested for.

        Returns:
            ``self.rew``, a float32 tensor of length ``world.batch_dim``.
        """
        is_first = agent == self.world.agents[0]

        if is_first:
            self.rew = torch.zeros(
                self.world.batch_dim, device=self.world.device, dtype=torch.float32
            )

            self.package.dist_to_goal = torch.linalg.vector_norm(
                self.package.state.pos - self.package.goal.state.pos, dim=1
            )
            self.package.on_goal = self.world.is_overlapping(
                self.package, self.package.goal
            )

            # Package colour per environment: red by default, green where the
            # package is on the goal.
            self.package.color = torch.tensor(
                Color.RED.value, device=self.world.device, dtype=torch.float32
            ).repeat(self.world.batch_dim, 1)
            self.package.color[self.package.on_goal] = torch.tensor(
                Color.GREEN.value, device=self.world.device, dtype=torch.float32
            )

            # Reward progress towards the goal only where not yet delivered.
            # This update is applied once: a second identical update used to
            # follow, but it ran after global_shaping had been overwritten
            # with package_shaping and therefore always contributed zero.
            not_delivered = ~self.package.on_goal
            package_shaping = self.package.dist_to_goal * self.shaping_factor
            self.rew[not_delivered] += (
                self.package.global_shaping[not_delivered]
                - package_shaping[not_delivered]
            )
            self.package.global_shaping = package_shaping

        return self.rew

    def observation(self, agent: Agent):
        return torch.cat(
            [
                agent.state.pos,
                agent.state.vel,
                self.package.state.vel,
                self.package.state.pos - agent.state.pos,
                self.package.state.pos - self.package.goal.state.pos,
            ],
            dim=-1,
        )

    def done(self):
        return self.package.on_goal
    
    def set_env_state(self, new_obs: torch.Tensor):
        """Overwrite agent and package state from a condensed joint observation.

        Args:
            new_obs: 2-D tensor indexed as ``[:, column]``, laid out as the
                output of ``get_condensed_obs``: ``local_dim`` columns per
                agent (position in the first two, velocity in the next two),
                followed by package velocity (2) and package position relative
                to the goal (2).

        Every ``set_pos`` / ``set_vel`` call passes ``None`` as its second
        argument (presumably the batch index, i.e. all environments; confirm
        against the simulator API).
        """
        # First column of the package section, right after all agent slots.
        package_start = self.local_dim * self.n_agents

        for entity in self.world.entities:
            if isinstance(entity, Agent):
                # Agent slots are local_dim wide; this stride used to be a
                # hard-coded 4 while the package offset used self.local_dim.
                start = entity._index * self.local_dim
                entity.set_pos(new_obs[:, start : start + 2], None)
                entity.set_vel(new_obs[:, start + 2 : start + 4], None)
            elif entity.name == "package":
                # The observation stores the package position relative to the
                # goal, so add the goal position back to get world coordinates.
                # TODO: verify that package pos is correct
                package_pos = (
                    new_obs[:, package_start + 2 : package_start + 4]
                    + self.package.goal.state.pos
                )
                entity.set_pos(package_pos, None)
                entity.set_vel(new_obs[:, package_start : package_start + 2], None)
    
    def get_condensed_obs(self, obs):
        # Takes a joint obs [(agent_pos, agent_vel, package_vel, package_pos (relative to agent), package_pos (relative to goal)) x n_agents]
        # Returns [(agent_pos, agent_vel) x n_agents, package_vel, package_pos (relative to goal)]
        local_indices = torch.tensor([i for i in range(self.local_dim)]).to(self.device)
        local_agent_obs = torch.index_select(obs, dim=-1, index=local_indices)
        joint_local_obs = local_agent_obs.reshape(local_agent_obs.shape[0], -1)
        global_indices = torch.tensor([4, 5, 8, 9]).to(self.device)
        global_obs = torch.index_select(obs, dim=-1, index=global_indices) # the position/velocity for the package
        global_obs = torch.select(global_obs, dim=-2, index=0) # package position/velocity is the same for all agents, just pick one
        joint_obs = torch.cat((joint_local_obs, global_obs), dim=1)
        return joint_obs
    
    def get_expanded_obs(self, obs):
        joint_local_obs = obs[:, :self.local_dim * self.n_agents]
        agent_local_obs = joint_local_obs.reshape((obs.shape[0], self.n_agents, self.local_dim)) # per-agent pos/vel
        package_vel = obs[:, self.local_dim * self.n_agents: self.local_dim * self.n_agents + 2]
        agent_package_vel = package_vel.unsqueeze(dim=1).expand(obs.shape[0], self.n_agents, package_vel.shape[-1])
        package_pos = obs[:, self.local_dim * self.n_agents + 2:] + self.package.goal.state.pos 
        package_agent_pos = package_pos.unsqueeze(dim=1) - agent_local_obs[:, :, :2]
        package_goal_pos = obs[:, self.local_dim * self.n_agents + 2:] # position relative to the goal
        agent_package_goal_pos = package_goal_pos.unsqueeze(dim=1).expand(obs.shape[0], self.n_agents, package_vel.shape[-1])
        agent_obs = torch.cat((agent_local_obs, agent_package_vel, package_agent_pos, agent_package_goal_pos), dim=2)
        return agent_obs
    
    def set_plausible_obs(self, obs_mean, beta, stddev):
        # Add plausible displacement to agents' state (pos/vel)
        etas = torch.tensor(
                np.random.uniform(low=-1, high=1, size=(obs_mean.shape[0], self.local_dim * self.n_agents)), 
                dtype=torch.float32,
            ).to(self.device)
        breakpoint()
        plausible_local_obs =  obs_mean[:self.local_dim * self.n_agents] + beta * stddev * etas
        # TODO: compute where the package will be
        # TODO: this function must set the scenario state to the plausible obs


if __name__ == "__main__":
    # Run this scenario in the interactive renderer with two controllable
    # agents. The settings use the keyword names that make_world reads from
    # its kwargs (presumably forwarded there by render_interactively).
    scenario_kwargs = {
        "n_agents": 2,
        "package_width": 0.3,
        "package_length": 0.3,
        "package_mass": 10,
    }
    render_interactively(__file__, control_two_agents=True, **scenario_kwargs)
