import typing
from typing import Dict, List

import torch
import numpy as np
from torch import Tensor
from vmas.simulator.core import Agent, Sphere, World, Landmark, Line
from vmas.simulator.dynamics.holonomic import Holonomic
from vmas.simulator.dynamics.common import Dynamics
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import Color, TorchUtils
from vmas.simulator.controllers.velocity_controller import VelocityController
from vmas.simulator import rendering

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

if typing.TYPE_CHECKING:
    from vmas.simulator.rendering import Geom

import numpy as np


# Agent control types.
# Each constant is a value accepted for the ``agent_control_type`` keyword of
# ``WaypointScenario.make_world``; it selects the action layout and dynamics.
VELOCITY_CONTROL = "VELOCITY_CONTROL"  # 2-D action, Holonomic dynamics + VelocityController
QP_CONTROL_CONSTRAINT = "QP_CONTROL_CONSTRAINT"  # 3-D action feeding a learned QP constraint
QP_CONTROL_CBF = "QP_CONTROL_CBF"  # 1-D action bounded by ``k_cbf_range``


# Careful - the control types below are only partially usable.
# NOTE(review): QP_CONTROL_RVO is nevertheless the default used by make_world.
QP_CONTROL_OBJECTIVE_AND_CONSTRAINT = "QP_CONTROL_OBJECTIVE_AND_CONSTRAINT"  # 5-D action
QP_CONTROL_OBJECTIVE = "QP_CONTROL_OBJECTIVE"  # 2-D action bounded by the arena half-sizes
QP_CONTROL_CONSTRAINT_ROBOMASTER = "QP_CONTROL_CONSTRAINT_ROBOMASTER"  # 3-D action, slack on boundaries
QP_CONTROL_RVO = "QP_CONTROL_RVO"  # 2-D action; QP target derived from RVO velocities

class HolonomicQP(Dynamics):
    """Holonomic dynamics that applies the first two action columns as force."""

    @property
    def needed_action_size(self) -> int:
        """Number of action components consumed by ``process_action``."""
        planar_dims = 2
        return planar_dims

    def process_action(self):
        """Copy the two leading columns of the agent's action into its force."""
        action = self.agent.action.u
        self.agent.state.force = action[:, :2]


class WaypointScenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs) -> World:
        """
        Build the simulation world for this scenario.

        Reads the scenario options from ``kwargs``, then creates the agents,
        the four boundary walls and ``n_obstacles`` obstacles, places the
        walls and builds the QP layer.

        Args:
            batch_dim: Number of parallel environments; used as the leading
                dimension of every per-entity tensor allocated here.
            device: Torch device on which those tensors are allocated.
            **kwargs: Optional scenario settings. Every key read below has a
                default, so all of them may be omitted.

        Returns:
            The populated ``World``; it is also stored in ``self._world``.

        Raises:
            ValueError: If ``agent_control_type`` is not one of the control
                type constants handled below.
        """
        self.agent_control_type = kwargs.get("agent_control_type", QP_CONTROL_RVO)  # RVO is the default control type

        self.n_agents = kwargs.get("n_agents", 8)
        self.n_obstacles = kwargs.get("n_obstacles", 0)
        self.v_range = kwargs.get("v_range", 1.0)

        if self.agent_control_type == VELOCITY_CONTROL:
            self.a_range = kwargs.get("a_range", 0.3) # trains better if this is lower (use 0.6 under r=0.2, 0.3 otherwise)
        else:
            self.a_range = kwargs.get("a_range", 1.0)
        self.linear_friction = kwargs.get("linear_friction", 0.0)
        self.collisions = kwargs.get("collisions", True)

        # max_velocity bounds the norm of the QP control variable (see init_qp_layer)
        self.max_velocity = kwargs.get("max_velocity", 5.0)
        self.max_steps = kwargs.get("max_steps", 100)
        self.agent_radius = kwargs.get("agent_radius", 0.3)

        # RVO parameters
        self.rvo_time_horizon = kwargs.get("rvo_time_horizon", 0.1)
        self.rvo_visibility_radius = kwargs.get("rvo_visibility_radius", 1.75) # full visibility radius
        self.rvo_target_region = kwargs.get("rvo_target_region", 0.05)

        # QP / reward coefficients
        self.lambda_u = kwargs.get("LAMBDA_U", 10.0)  # collision penalty (also scales the delta penalty in agent_reward)
        self.lambda_g = kwargs.get("LAMBDA_G", 1.0)   # near-goal bonus (agent_reward applies it when distance < 0.2)
        self.final_reward = kwargs.get("final_reward", 0.5)  # reward for being on the goal (distance < 0.05 in agent_reward)
        # NOTE(review): delta_scaling is not read anywhere in the visible part of this file
        self.delta_scaling = kwargs.get("delta_scaling", 99.0)

        # Read QP constraint flag (only consulted for QP_CONTROL_CONSTRAINT in init_qp_layer)
        self.constraint_on = kwargs.get("constraint_on", True)

        # Read visualization options
        self.disable_collision_viz = kwargs.get("disable_collision_viz", False)
        self.debug_prints = kwargs.get("debug_prints", False)
        # Read action render flag (inverted below into the agents' render_action)
        self.disable_action_render = kwargs.get("disable_action_render", False)

        # Distance threshold below which two entities count as colliding in reward()
        self.collision_gap = 0.0002
        self.min_distance_between_entities = self.agent_radius * 2 + self.collision_gap

        # Arena bounds (nominal half-sizes; the World below is one agent radius larger)
        self.world_x_semidim = 1
        self.world_y_semidim = 1

        # Velocity actions with a smaller norm than this are zeroed in process_action
        self.min_input_norm = kwargs.get("min_input_norm", 0.03)

        self.f_range = self.a_range + self.linear_friction
        # For velocity control we clamp the velocity, otherwise clamp the force
        if self.agent_control_type == VELOCITY_CONTROL:
            self.u_range = self.v_range
        else:
            self.u_range = self.f_range

        # Velocity controller internal params
        controller_params = [2, 6, 0.002]

        # Unlike the options above, these two are popped, i.e. removed from kwargs
        self.resolution_factor = kwargs.pop("resolution_factor", 1200)
        self.viewer_zoom = 2
        self.viewer_size = kwargs.pop(
            "viewer_size",
            (
                int(self.resolution_factor),
                int(self.resolution_factor),
            ),
        )

        # Create the underlying world; its half-sizes are the arena half-sizes
        # plus one agent radius
        world = World(
            batch_dim,
            device,
            dt=0.1,
            substeps=1,
            linear_friction=self.linear_friction,
            drag=0.0,
            x_semidim=self.world_x_semidim + self.agent_radius,
            y_semidim=self.world_y_semidim + self.agent_radius,
        )

        # Agent colors are taken from this palette cyclically (index i % 4)
        known_colors = [
            (0.22, 0.40, 0.72),  # bluish
            (0.22, 0.72, 0.40),  # greenish
            (0.72, 0.40, 0.22),  # orange-brown
            (0.72, 0.22, 0.40),  # reddish-pink
        ]

        # Determine the appropriate action size / range for each control type
        if self.agent_control_type == VELOCITY_CONTROL:
            # Using an action size=2 for velocity is typical (x, y).
            action_size = 2
            u_range = self.u_range
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            # Five components: three bounded by u_range, two bounded by the arena half-sizes
            action_size = 2 + 1 + 1 + 1
            u_range = [
                self.u_range,
                self.u_range,
                self.u_range,
                self.world_x_semidim,
                self.world_y_semidim,
            ]
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE or self.agent_control_type == QP_CONTROL_RVO:
            # Two components bounded by the arena half-sizes
            action_size = 2
            u_range = [self.world_x_semidim, self.world_y_semidim]
        elif self.agent_control_type in [QP_CONTROL_CONSTRAINT, QP_CONTROL_CONSTRAINT_ROBOMASTER]:
            action_size = 2 + 1
            u_range = self.u_range
        elif self.agent_control_type == QP_CONTROL_CBF:
            # k_cbf_range is only defined for this control type
            self.k_cbf_range = kwargs.get("k_cbf_range", 2.0)
            action_size = 1
            u_range = self.k_cbf_range
        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        # Create agents
        for i in range(self.n_agents):
            # Velocity control uses the stock Holonomic dynamics; every QP
            # variant uses HolonomicQP, which reads only the first two action columns
            if self.agent_control_type == VELOCITY_CONTROL:
                dynamics = Holonomic()
            else:
                dynamics = HolonomicQP()

            agent = Agent(
                name=f"agent_{i}",
                collide=self.collisions,
                alpha=1.0,
                shape=Sphere(radius=self.agent_radius),
                # Use the flag read from kwargs
                render_action=(not self.disable_action_render),
                v_range=self.v_range,
                f_range=self.f_range,
                u_range=u_range,
                dynamics=dynamics,
                action_size=action_size,
            )

            # Each agent's goal (placeholder at the origin until assign_random_goals runs)
            agent.goal_pos = torch.zeros(batch_dim, 2, device=device)

            # Color
            agent.original_color = known_colors[i % len(known_colors)]
            agent.color = agent.original_color

            # Controller (attached to every agent, regardless of control type)
            agent.controller = VelocityController(agent, world, controller_params, "standard")

            # Initialize rewards/tensors
            agent.agent_collision_rew = torch.zeros(batch_dim, device=device)
            agent.individual_rew = torch.zeros(batch_dim, device=device)
            agent.computed_velocity = torch.zeros(batch_dim, 2, device=device)
            agent.collision_timer = torch.zeros(batch_dim, device=device)

            # QP control types additionally track the slack value (on the
            # action) and its reward term (on the agent itself)
            if self.agent_control_type != VELOCITY_CONTROL:
                agent.action.delta = torch.zeros(batch_dim, device=device)
                agent.delta_rew = torch.zeros(batch_dim, device=device)

            # We'll store a 'final_rew' in case the agent is truly on its goal
            agent.final_rew = torch.zeros(batch_dim, device=device)

            world.add_agent(agent)

        # Walls: indices 0-1 span the x extent, indices 2-3 span the y extent
        # (spawn_walls places them as top, bottom, left, right)
        self.walls = []
        for i in range(4):
            if i < 2:
                wall_length = 2 * (world.x_semidim)
            else:
                wall_length = 2 * (world.y_semidim)
            wall = Landmark(
                name=f"wall_{i}",
                collide=True,
                shape=Line(length=wall_length),
                color=Color.BLACK,
            )
            wall.original_color = (0.0, 0.0, 0.0)
            wall.collision_timer = torch.zeros(batch_dim, device=device)
            world.add_landmark(wall)
            self.walls.append(wall)

        # Obstacles (spheres with the same radius as the agents)
        self.obstacle_radius = self.agent_radius
        self.obstacles = []
        for i in range(self.n_obstacles):
            obs = Landmark(
                name=f"obstacle_{i}",
                collide=True,
                shape=Sphere(radius=self.obstacle_radius),
                color=Color.BLACK,
            )
            obs.collision_timer = torch.zeros(batch_dim, device=device)
            world.add_landmark(obs)
            self.obstacles.append(obs)

        # Must be set before spawn_walls / init_qp_layer, which both read self._world
        self._world = world

        # Spawn walls (fixed)
        self.spawn_walls()

        # Init QP layer (init_qp_layer returns None for VELOCITY_CONTROL)
        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        return world

    def reset_world_at(self, env_index: int = None):
        """
        Reset one environment, or all of them, to a fresh random episode.

        Rebuilds the QP layer, re-places the walls, respawns agents and
        obstacles, draws new goals and clears the per-episode bookkeeping.

        Args:
            env_index: Index of the environment to reset. ``None`` resets
                every environment in the batch.
        """
        # Re-init QP and clear the solver-failure flag that done() reports
        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        # Place walls
        self.spawn_walls(env_index)

        # Randomly spawn agents & obstacles
        self.respawn_agents_no_overlap(env_index)
        self.respawn_obstacles_no_overlap(env_index)

        # Random goals
        self.assign_random_goals(env_index)

        def cleared(values):
            # A single-environment reset must leave the other environments'
            # entries untouched, so only the selected row is zeroed.
            if env_index is None:
                return torch.zeros_like(values)
            values[env_index] = 0
            return values

        # Reset timers and controllers
        for agent in self.world.agents:
            agent.controller.reset(env_index)
            agent.collision_timer = cleared(agent.collision_timer)
            agent.final_rew = cleared(agent.final_rew)

            # Re-seed the shaping baseline. Without this the first reward
            # after a reset would be measured against the distance left over
            # from the previous episode (old position and old goal).
            dist_to_goal = torch.norm(agent.state.pos - agent.goal_pos, dim=-1)
            if env_index is None or not hasattr(agent, "prev_dist"):
                agent.prev_dist = dist_to_goal
            else:
                agent.prev_dist[env_index] = dist_to_goal[env_index]

        for wall in self.walls:
            wall.collision_timer = cleared(wall.collision_timer)
        for obs in self.obstacles:
            obs.collision_timer = cleared(obs.collision_timer)

    def assign_random_goals(self, env_index=None):
        self.goal_radius = 0.2
        batch_dim = self._world.batch_dim
        device = self._world.device

        x_min = -self._world.x_semidim + self.goal_radius
        x_max = self._world.x_semidim - self.goal_radius
        y_min = -self._world.y_semidim + self.goal_radius
        y_max = self._world.y_semidim - self.goal_radius

        max_tries = 10000
        for agent in self.world.agents:
            placed = torch.zeros(batch_dim, dtype=torch.bool, device=device)
            tries = 0

            while not torch.all(placed) and tries < max_tries:
                tries += 1
                rand_x = torch.rand(batch_dim, device=device) * (x_max - x_min) + x_min
                rand_y = torch.rand(batch_dim, device=device) * (y_max - y_min) + y_min
                candidate_positions = torch.stack([rand_x, rand_y], dim=-1)

                final_positions = agent.goal_pos.clone()

                for bi in range(batch_dim):
                    if env_index is not None and bi != env_index:
                        continue
                    if placed[bi]:
                        continue

                    candidate = candidate_positions[bi]
                    collision_detected = False

                    # 1) Check vs other agents' goals
                    for other_agent in self.world.agents:
                        if other_agent == agent:
                            continue
                        dist = torch.norm(candidate - other_agent.goal_pos[bi], p=2)
                        if dist < (2 * self.goal_radius + self.collision_gap):
                            collision_detected = True
                            break

                    # 2) Check vs agent itself
                    if not collision_detected:
                        dist_to_own_agent = torch.norm(candidate - agent.state.pos[bi], p=2)
                        if dist_to_own_agent < (self.goal_radius + self.agent_radius + self.collision_gap):
                            collision_detected = True

                    # 3) Check vs obstacles
                    if not collision_detected:
                        for obs in self.obstacles:
                            dist_to_obs = torch.norm(candidate - obs.state.pos[bi], p=2)
                            if dist_to_obs < (self.goal_radius + self.obstacle_radius + self.collision_gap):
                                collision_detected = True
                                break

                    if not collision_detected:
                        final_positions[bi] = candidate
                        placed[bi] = True

                agent.goal_pos = final_positions

            if tries >= max_tries:
                pass

    def respawn_agents_no_overlap(self, env_index=None, max_tries=50):
        batch_dim = self._world.batch_dim
        device = self._world.device
        x_min = -self._world.x_semidim + self.agent_radius
        x_max = self._world.x_semidim - self.agent_radius
        y_min = -self._world.y_semidim + self.agent_radius
        y_max = self._world.y_semidim - self.agent_radius
        gap_coef = 1.1
        for agent in self.world.agents:
            placed = torch.zeros(batch_dim, dtype=torch.bool, device=device)
            tries = 0
            while not torch.all(placed) and tries < max_tries:
                tries += 1
                rand_x = torch.rand(batch_dim, device=device) * (x_max - x_min) + x_min
                rand_y = torch.rand(batch_dim, device=device) * (y_max - y_min) + y_min
                candidate_positions = torch.stack([rand_x, rand_y], dim=-1)

                final_positions = agent.state.pos.clone()
                for bi in range(batch_dim):
                    if env_index is not None and bi != env_index:
                        continue
                    if placed[bi]:
                        continue

                    final_positions[bi] = candidate_positions[bi]
                    collision_detected = False

                    # vs other agents
                    for other_agent in self.world.agents:
                        if other_agent == agent:
                            continue
                        dist = torch.norm(final_positions[bi] - other_agent.state.pos[bi], p=2)
                        if dist < (2 * self.agent_radius + gap_coef*self.collision_gap):
                            collision_detected = True
                            break

                    # vs obstacles
                    if not collision_detected:
                        for obs in self.obstacles:
                            dist = torch.norm(final_positions[bi] - obs.state.pos[bi], p=2)
                            if dist < (self.agent_radius + self.obstacle_radius + gap_coef*self.collision_gap):
                                collision_detected = True
                                break

                    if not collision_detected:
                        placed[bi] = True

                agent.set_pos(final_positions, batch_index=None)
                agent.set_vel(torch.zeros_like(agent.state.vel), batch_index=None)

            if tries >= max_tries:
                pass

    def respawn_obstacles_no_overlap(self, env_index=None, max_tries=50):
        batch_dim = self._world.batch_dim
        device = self._world.device
        x_min = -self._world.x_semidim + self.obstacle_radius
        x_max = self._world.x_semidim - self.obstacle_radius
        y_min = -self._world.y_semidim + self.obstacle_radius
        y_max = self._world.y_semidim - self.obstacle_radius

        for obs in self.obstacles:
            placed = torch.zeros(batch_dim, dtype=torch.bool, device=device)
            tries = 0
            while not torch.all(placed) and tries < max_tries:
                tries += 1
                rand_x = torch.rand(batch_dim, device=device) * (x_max - x_min) + x_min
                rand_y = torch.rand(batch_dim, device=device) * (y_max - y_min) + y_min
                candidate_positions = torch.stack([rand_x, rand_y], dim=-1)

                final_positions = obs.state.pos.clone()
                for bi in range(batch_dim):
                    if env_index is not None and bi != env_index:
                        continue
                    if placed[bi]:
                        continue

                    collision_detected = False
                    final_positions[bi] = candidate_positions[bi]

                    # vs other obstacles
                    for other_obs in self.obstacles:
                        if other_obs == obs:
                            continue
                        dist = torch.norm(final_positions[bi] - other_obs.state.pos[bi], p=2)
                        if dist < (2 * self.obstacle_radius + self.collision_gap):
                            collision_detected = True
                            break

                    # vs agents
                    if not collision_detected:
                        for a in self.world.agents:
                            dist = torch.norm(final_positions[bi] - a.state.pos[bi], p=2)
                            if dist < (self.agent_radius + self.obstacle_radius + self.collision_gap):
                                collision_detected = True
                                break

                    if not collision_detected:
                        placed[bi] = True

                obs.set_pos(final_positions, batch_index=None)

            if tries >= max_tries:
                pass

    def spawn_walls(self, env_index=None):
        world_x_semidim = self._world.x_semidim
        world_y_semidim = self._world.y_semidim
        for i, wall in enumerate(self.walls):
            if i == 0:  # Top
                pos = torch.tensor([0.0, world_y_semidim], device=self._world.device)
                rot = 0.0
            elif i == 1:  # Bottom
                pos = torch.tensor([0.0, -world_y_semidim], device=self._world.device)
                rot = 0.0
            elif i == 2:  # Left
                pos = torch.tensor([-world_x_semidim, 0.0], device=self._world.device)
                rot = torch.pi / 2
            else:  # Right
                pos = torch.tensor([world_x_semidim, 0.0], device=self._world.device)
                rot = torch.pi / 2

            wall.set_pos(pos, batch_index=env_index)
            wall.set_rot(torch.tensor([rot], device=self._world.device), batch_index=env_index)

    def reward(self, agent: Agent):
        """
        Return the total reward of ``agent`` for the current step.

        The total is the shaped reward plus the final reward plus the
        collision penalties, plus the delta penalty for the CBF control type.
        All per-step bookkeeping (shaped rewards, collision penalties,
        collision timers) is refreshed for every agent when this is called
        for the first agent of the world.
        """
        sim = self._world
        team = self.world.agents

        if agent == team[0]:
            # Refresh shaped rewards and clear last step's collision penalties
            for member in team:
                self.agent_reward(member)
                member.agent_collision_rew[:] = 0

            def contact_mask(first, second):
                # None when the pair cannot collide at all; otherwise the
                # per-environment mask of pairs closer than the collision gap
                if not sim.collides(first, second):
                    return None
                return sim.get_distance(first, second) <= self.collision_gap

            # Agents against obstacles, then agents against walls
            for static_group in (self.obstacles, self.walls):
                for member in team:
                    for landmark in static_group:
                        hit = contact_mask(member, landmark)
                        if hit is None:
                            continue
                        member.agent_collision_rew[hit] -= self.lambda_u
                        member.collision_timer[hit] = 5
                        landmark.collision_timer[hit] = 5

            # Agents against agents; each unordered pair is visited once
            for i, first in enumerate(team):
                for j, second in enumerate(team):
                    if j > i:
                        hit = contact_mask(first, second)
                        if hit is None:
                            continue
                        first.agent_collision_rew[hit] -= self.lambda_u
                        second.agent_collision_rew[hit] -= self.lambda_u
                        first.collision_timer[hit] = 5
                        second.collision_timer[hit] = 5

            # Count every collision timer down by one, never below zero
            for entity in [*team, *self.walls, *self.obstacles]:
                entity.collision_timer = torch.clamp(entity.collision_timer - 1, min=0)

        total_rew = agent.individual_rew + agent.final_rew + agent.agent_collision_rew
        if self.agent_control_type == QP_CONTROL_CBF:
            total_rew = total_rew + agent.delta_rew
        return total_rew

    def agent_reward(self, agent: Agent):
        """
        Recompute the goal-related reward terms of ``agent``.

        Writes ``individual_rew`` (progress toward the goal since the last
        call, plus ``lambda_g`` when closer than 0.2), ``final_rew``
        (``final_reward`` when closer than 0.05) and, for the CBF control
        type, ``delta_rew`` (the slack value scaled by ``-lambda_u``).

        Returns:
            The agent's new ``individual_rew`` tensor.
        """
        uses_cbf = self.agent_control_type == QP_CONTROL_CBF

        # Clear the previous step's values in place
        agent.individual_rew[:] = 0
        agent.final_rew[:] = 0
        if uses_cbf:
            agent.delta_rew[:] = 0

        goal_dist = torch.norm(agent.state.pos - agent.goal_pos, dim=-1)

        # On the very first call there is no previous distance, so the
        # progress term is zero
        earlier_dist = agent.prev_dist if hasattr(agent, "prev_dist") else goal_dist
        progress = earlier_dist - goal_dist
        agent.prev_dist = goal_dist.clone()

        # Near-goal bonus
        progress = torch.where(goal_dist < 0.2, progress + self.lambda_g, progress)

        # Final reward when the agent is practically on the goal
        agent.final_rew[goal_dist < 0.05] = self.final_reward

        agent.individual_rew = progress

        # CBF only: penalise the slack variable
        if uses_cbf:
            if hasattr(agent.action, "delta"):
                agent.delta_rew = -self.lambda_u * agent.action.delta
            else:
                agent.delta_rew[:] = 0

        return agent.individual_rew

    def observation(self, agent: Agent):
        device = self._world.device
        dist_to_goal = torch.norm(agent.state.pos - agent.goal_pos, dim=-1, keepdim=True)
        rel_goal_pos = agent.goal_pos - agent.state.pos

        obs = {
            "agent_pos": agent.state.pos,
            "agent_vel": agent.state.vel,
            "goal_pos": agent.goal_pos,
            "distance_to_goal": dist_to_goal,
            "rel_goal_pos": rel_goal_pos,
        }
        return obs

    def done(self):
        dones = torch.zeros(self._world.batch_dim, device=self._world.device, dtype=torch.bool)
        if not self.qp_solver_initialized:
            dones[:] = True
        return dones

    def info(self, agent: Agent) -> Dict[str, Tensor]:
        """
        Collect the reward components of ``agent`` for logging.

        Returns:
            A dict with the collision, individual and final rewards and
            their sum. For the CBF control type it also holds the action
            saved before the QP solve (zeros if none was saved yet). When
            the agent tracks a delta reward, that is included too.
        """
        info = {
            "collision_reward": agent.agent_collision_rew,
            "individual_reward": agent.individual_rew,
            "final_reward": agent.final_rew,
            "total_reward_without_delta_rew": agent.individual_rew + agent.final_rew + agent.agent_collision_rew,
        }
        if self.agent_control_type == QP_CONTROL_CBF:
            if hasattr(agent, "original_action_u"):
                info["original_action_u"] = agent.original_action_u.squeeze(-1)
            else:
                info["original_action_u"] = torch.zeros_like(agent.agent_collision_rew)

        # delta_rew is stored on the agent itself, not on agent.action.
        # Checking the action object would never match, and the delta
        # reward would silently be missing from the info dict.
        if hasattr(agent, "delta_rew"):
            info["delta_reward"] = agent.delta_rew

        return info

    def init_qp_layer(self):
        """
        Set up the QP problem for controlling each agent:
          - Always includes an objective to move toward (x_goal, y_goal).
          - Delta is a slack variable for collision constraints / boundary constraints.

        The optimisation variables are the 2-D control ``u`` and the scalar
        slack ``delta``. Callers must pass parameter values to the layer in
        exactly the order of the ``parameters`` list built below: the six
        base parameters first, then the control-type specific ones.

        Returns:
            ``None`` for VELOCITY_CONTROL, otherwise a ``CvxpyLayer`` whose
            outputs are ``(u, delta)``.

        Raises:
            ValueError: If the control type has no objective defined here.
        """
        if self.agent_control_type == VELOCITY_CONTROL:
            return None

        # "Obstacle" = other agents + obstacles
        N_ = (self.n_agents - 1) + self.n_obstacles

        # Decision variables
        u = cp.Variable(2)
        delta = cp.Variable(1)

        # State of the agent the problem is solved for
        ego_position = cp.Parameter(2)
        ego_velocity = cp.Parameter(2)

        # One row / entry per "obstacle"
        obstacle_positions = cp.Parameter((N_, 2))
        h_values = cp.Parameter(N_)
        # NOTE(review): presumably the precomputed velocity-dependent x / y
        # terms of each constraint - confirm against the caller that fills them
        dx_vel_0 = cp.Parameter(N_)
        dy_vel_1 = cp.Parameter(N_)

        parameters = [
            ego_position,
            ego_velocity,
            obstacle_positions,
            h_values,
            dx_vel_0,
            dy_vel_1,
        ]

        constraints = []
        # CBF-like constraints or simple constraints for obstacles
        for j in range(N_):
            h = h_values[j]
            # Affine in u: twice the supplied terms plus u dotted with the
            # offset from obstacle j to the ego agent
            lgh = 2 * (
                dx_vel_0[j] + dy_vel_1[j]
                + u[0] * (ego_position[0] - obstacle_positions[j, 0])
                + u[1] * (ego_position[1] - obstacle_positions[j, 1])
            )
            # Only these two control types may relax the obstacle constraint with the slack
            if self.agent_control_type in [QP_CONTROL_CBF, QP_CONTROL_CONSTRAINT_ROBOMASTER]:
                constraints.append(lgh + h + delta >= 0)
            else:
                constraints.append(lgh + h >= 0)

        # Boundaries: position + (velocity + u) must stay inside the world,
        # shrunk by a margin slightly larger than the agent radius
        safety_margin = 1.05 * (self.agent_radius + self.collision_gap)
        if self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            # This block demonstrates how boundary constraints might be handled with a slack variable.
            constraints.extend(
                [
                    ego_position[0] + (ego_velocity[0] + u[0])
                    <= self._world.x_semidim - safety_margin + delta,
                    ego_position[0] + (ego_velocity[0] + u[0])
                    >= -self._world.x_semidim + safety_margin - delta,
                    ego_position[1] + (ego_velocity[1] + u[1])
                    <= self._world.y_semidim - safety_margin + delta,
                    ego_position[1] + (ego_velocity[1] + u[1])
                    >= -self._world.y_semidim + safety_margin - delta,
                    cp.norm(u) <= self.max_velocity + delta, # slightly inconsistent to use same slack variable in multiple places, but a_range is capped anyway
                    delta >= 0,
                ]
            )
        else:
            # Hard boundary and control-norm constraints (no slack)
            constraints.extend(
                [
                    ego_position[0] + (ego_velocity[0] + u[0])
                    <= self._world.x_semidim - safety_margin,
                    ego_position[0] + (ego_velocity[0] + u[0])
                    >= -self._world.x_semidim + safety_margin,
                    ego_position[1] + (ego_velocity[1] + u[1])
                    <= self._world.y_semidim - safety_margin,
                    ego_position[1] + (ego_velocity[1] + u[1])
                    >= -self._world.y_semidim + safety_margin,
                    cp.norm(u) <= self.max_velocity,
                    delta >= 0,
                ]
            )

        # Depending on the control type, define the objective.
        # Every branch minimises the distance between position + (velocity + u)
        # and a target point, plus a heavy penalty (999) on the slack.
        if self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            # Extra constraint: u must lie within distance b + delta of a
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([a, b, x_target, y_target])
            constraints.append(cp.norm(u - a) <= b + delta)
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + 999.0 * delta
            )

        elif self.agent_control_type in [QP_CONTROL_OBJECTIVE, QP_CONTROL_RVO]:
            # For both objective and RVO, we use the same approach:
            # The difference is in how the target positions (x_target, y_target) are determined
            # For QP_CONTROL_OBJECTIVE, these come from agent.goal_pos
            # For QP_CONTROL_RVO, these come from RVO computed velocities
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([x_target, y_target])
            # delta >= 0 repeats the boundary block's constraint; the upper bound is new
            constraints.append(delta >= 0)
            constraints.append(delta <= 10000)
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + 999.0 * delta
            )

        elif self.agent_control_type == QP_CONTROL_CONSTRAINT:
            # Norm-ball constraint ||u - a|| <= b + delta (when enabled)
            # plus objective to approach (x_goal,y_goal).
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            x_goal = cp.Parameter(1)
            y_goal = cp.Parameter(1)
            parameters.extend([a, b, x_goal, y_goal])

            # Use the instance variable read from kwargs
            if self.constraint_on:
                constraints.append(cp.norm(u - a) <= b + delta)
            else:
                # Disabled: a and b are multiplied by zero, leaving 0 <= delta,
                # but both parameters still appear in the problem
                constraints.append(0*a.T @ u <= 0*b + delta)
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_goal,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_goal,
                        ]
                    ),
                    2,
                )
                + 999.0 * delta
            )

        elif self.agent_control_type == QP_CONTROL_CBF:
            # Let agent param be e.g. "k_cbf" in action or so,
            # but also pass x_goal, y_goal for the objective.
            x_goal = cp.Parameter(1)
            y_goal = cp.Parameter(1)
            parameters.extend([x_goal, y_goal])

            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_goal,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_goal,
                        ]
                    ),
                    2
                )
                + 999.0 * delta
            )
        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        # Wrap the problem in a layer that maps batched parameter values to (u, delta)
        prob = cp.Problem(objective, constraints)
        qp_layer = CvxpyLayer(prob, parameters=parameters, variables=[u, delta])
        return qp_layer

    def solve_qp_layer(self, parameters_batch, solver_args):
        """
        Solve batched QP. If fail, mark done.
        """
        try:
            with torch.no_grad():
                u_qp_batch, delta_batch = self.qp_layer(*parameters_batch, solver_args=solver_args)
            return u_qp_batch, delta_batch
        except Exception as e:
            print(f"SolverError encountered: {e}. Marking the environment as done.")
            self.qp_solver_initialized = False
            return None, None

    def process_action(self, agent: Agent):
        """
        Turn the action stored in ``agent.action.u`` into a controller input.

        With ``VELOCITY_CONTROL`` only ``agent`` is touched: its action is
        clamped by norm, inputs below ``self.min_input_norm`` are zeroed and
        its controller is run.

        For every other control type all the work is done on the call made for
        the first agent of the world; calls for the remaining agents return
        immediately. On that call the states of all agents are gathered, one
        batched QP is solved through ``self.solve_qp_layer`` and the solution is
        written back to every agent's ``action.u`` before each agent's
        controller is run. If the solve fails, every agent gets a zero action.

        Side effects on the QP path (per agent): ``original_action_u`` (copy of
        the incoming action), ``action.delta`` and ``computed_velocity`` are
        set; with ``QP_CONTROL_RVO`` also ``rvo_velocity``, and the first two
        action columns are overwritten with the RVO target position.

        Args:
            agent: The agent whose action is being processed.
        """
        if self.agent_control_type == VELOCITY_CONTROL:
            # Clamp square to circle
            agent.action.u = TorchUtils.clamp_with_norm(
                agent.action.u, agent.action.u_range
            )

            # Zero small input
            action_norm = torch.linalg.vector_norm(agent.action.u, dim=1)
            agent.action.u[action_norm < self.min_input_norm] = 0

            # Reset controller
            vel_is_zero = torch.linalg.vector_norm(agent.action.u, dim=1) < 1e-3
            agent.controller.reset(vel_is_zero)
            agent.controller.process_force()
            return

        # QP-based control: the batch for all agents is solved once per step,
        # on the call for the first agent.
        is_first = agent == self.world.agents[0]
        if not is_first:
            return

        agents = self.world.agents
        n_envs = self._world.batch_dim

        pos_list = []
        vel_list = []
        for a in agents:
            pos_list.append(a.state.pos)
            vel_list.append(a.state.vel)
            # Keep the incoming action: action.u is overwritten further down.
            a.original_action_u = a.action.u.clone()

        # Special handling for RVO control type - compute RVO velocities
        if self.agent_control_type == QP_CONTROL_RVO:
            from RVO import compute_velocities_with_goals

            for env_idx in range(n_envs):
                # Positions, velocities and goals of all agents in this environment
                positions = np.array([a.state.pos[env_idx].cpu().numpy() for a in agents])
                velocities = np.array([a.state.vel[env_idx].cpu().numpy() for a in agents])
                goals = np.array([a.goal_pos[env_idx].cpu().numpy() for a in agents])

                rvo_velocities = compute_velocities_with_goals(
                    positions=positions,
                    velocities=velocities,
                    goals=goals,
                    dt=self._world.dt,
                    max_speed=self.v_range,
                    visibility_radius=self.rvo_visibility_radius,
                    target_region=self.rvo_target_region,
                    simple_mode=False,  # Explicitly use standard mode
                )

                # If RVO computation failed, use simple goal-directed velocities
                if rvo_velocities is None:
                    print(f"RVO computation failed for environment {env_idx}, using fallback.")
                    # Imported lazily so the fallback is only required when used.
                    from RVO import compute_velocities_simple

                    rvo_velocities = compute_velocities_simple(
                        positions=positions,
                        velocities=velocities,
                        goals=goals,
                        max_speed=self.v_range,
                        visibility_radius=self.rvo_visibility_radius,
                        target_region=self.rvo_target_region,
                    )

                for i, a in enumerate(agents):
                    # NOTE: this attribute is overwritten on every iteration of
                    # the environment loop, so after the loop it holds the
                    # velocity of the last environment only.
                    a.rvo_velocity = torch.tensor(rvo_velocities[i], device=self._world.device)

                    # The QP objective tracks "current position + RVO velocity";
                    # that target is handed over through the action columns.
                    target_pos = a.state.pos[env_idx] + a.rvo_velocity
                    a.action.u[env_idx, 0] = target_pos[0]  # x target
                    a.action.u[env_idx, 1] = target_pos[1]  # y target

        # Build up the collision parameters, one entry per (agent, environment)
        # pair in agent-major order, matching torch.cat(pos_list) below.
        obstacle_positions_list = []
        current_h_values_list = []
        dx_vel_0_list = []
        dy_vel_1_list = []

        min_dist_sq = self.min_distance_between_entities**2
        p_cbf = 1.0  # exponent applied to the scaled h-values

        for agent_idx in range(self.n_agents):
            other_indices = [i for i in range(self.n_agents) if i != agent_idx]
            for env_idx in range(n_envs):
                ego_pos = pos_list[agent_idx][env_idx]
                ego_vel = vel_list[agent_idx][env_idx]

                # Other agents first, then (if any) the static obstacles.
                entity_pos = torch.stack([pos_list[i][env_idx] for i in other_indices], dim=0)
                if self.n_obstacles > 0:
                    obs_positions = torch.stack(
                        [obs.state.pos[env_idx] for obs in self.obstacles],
                        dim=0,
                    )
                    entity_pos = torch.cat([entity_pos, obs_positions], dim=0)
                obstacle_positions_list.append(entity_pos)

                # Squared distance to every entity minus the squared safety margin
                all_distances = torch.norm(entity_pos - ego_pos, dim=1) ** 2
                clearance = all_distances - min_dist_sq

                # Build h-values
                if self.agent_control_type == QP_CONTROL_CBF:
                    # The first action column is mapped to the gain k_cbf.
                    k_cbf_val = (
                        agents[agent_idx].action.u[env_idx][0] + self.k_cbf_range
                    ) / 2
                    h_values = k_cbf_val * clearance**p_cbf
                elif self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
                    k_cbf_val = 0.9
                    h_values = k_cbf_val * clearance**p_cbf
                else:
                    # Default
                    h_values = clearance
                current_h_values_list.append(h_values)

                # Per-axis offset to each entity, multiplied by the ego velocity
                # component of the same axis.
                dx_vel_0_list.append((ego_pos[0] - entity_pos[:, 0]) * ego_vel[0])
                dy_vel_1_list.append((ego_pos[1] - entity_pos[:, 1]) * ego_vel[1])

        # Stack into big batch
        parameters_batch = [
            torch.cat(pos_list, dim=0).view(n_envs * self.n_agents, 2),
            torch.cat(vel_list, dim=0).view(n_envs * self.n_agents, 2),
            torch.stack(obstacle_positions_list, dim=0),
            torch.stack(current_h_values_list, dim=0),
            torch.stack(dx_vel_0_list, dim=0),
            torch.stack(dy_vel_1_list, dim=0),
        ]

        def action_columns(start, stop):
            # Columns [start, stop) of every agent's action, stacked agent-major.
            return torch.cat([ag.action.u[:, start:stop] for ag in agents], dim=0)

        def goal_column(col):
            # One goal coordinate of every agent, stacked agent-major.
            return torch.cat([ag.goal_pos[:, col : col + 1] for ag in agents], dim=0)

        # Now attach the extra parameters for each control type.
        # NOTE(review): the order presumably has to match the parameter list the
        # QP layer was built with - confirm against the layer construction.
        control_type = self.agent_control_type
        if control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            # action = (a_x, a_y, b, x_target, y_target)
            parameters_batch.extend(
                [
                    action_columns(0, 2),
                    action_columns(2, 3),
                    action_columns(3, 4),
                    action_columns(4, 5),
                ]
            )
        elif control_type in (QP_CONTROL_OBJECTIVE, QP_CONTROL_CBF):
            # Only the goal position is passed.
            parameters_batch.extend([goal_column(0), goal_column(1)])
        elif control_type == QP_CONTROL_RVO:
            # The target positions computed from the RVO velocities above.
            parameters_batch.extend([action_columns(0, 1), action_columns(1, 2)])
        elif control_type in (QP_CONTROL_CONSTRAINT, QP_CONTROL_CONSTRAINT_ROBOMASTER):
            # action = (a_x, a_y, b), followed by the goal position.
            parameters_batch.extend(
                [
                    action_columns(0, 2),
                    action_columns(2, 3),
                    goal_column(0),
                    goal_column(1),
                ]
            )

        # Solve QP
        u_qp_batch, delta_batch = self.solve_qp_layer(parameters_batch, solver_args={"eps": 1e-8})

        if u_qp_batch is None or delta_batch is None:
            print("QP solver failed, fallback to zero.")
            for ag in agents:
                ag.action.u = torch.zeros((ag.action.u.size(0), 2), device=self._world.device)
                vel_is_zero = torch.linalg.vector_norm(ag.action.u[:, :2], dim=1) < 1e-3
                ag.controller.reset(vel_is_zero)
                ag.controller.process_force()
            return

        # Distribute solutions: consecutive row blocks belong to consecutive agents.
        split_sizes = [a.action.u.size(0) for a in agents]
        u_qp_list = torch.split(u_qp_batch, split_sizes, dim=0)
        delta_list = torch.split(delta_batch, split_sizes, dim=0)

        for i, ag in enumerate(agents):
            ag.action.delta = delta_list[i].squeeze()
            ag.action.u = TorchUtils.clamp_with_norm(u_qp_list[i], self.u_range)

            # Zero small input
            action_norm = torch.linalg.vector_norm(ag.action.u, dim=1)
            ag.action.u[action_norm < self.min_input_norm] = 0
            ag.computed_velocity = ag.action.u.clone()

            # Reset controller. The norm is recomputed after zeroing (as in the
            # VELOCITY_CONTROL branch) so that inputs that were just zeroed are
            # also treated as zero velocity.
            vel_is_zero = torch.linalg.vector_norm(ag.action.u, dim=1) < 1e-3
            ag.controller.reset(vel_is_zero)
            ag.controller.process_force()

    def extra_render(self, env_index: int = 0) -> "List[rendering.Geom]":
        """
        Build the additional geometries to draw for one environment.

        Walls and agents are recoloured in place: red while their
        ``collision_timer`` entry for ``env_index`` is positive (for agents
        only if ``self.disable_collision_viz`` is false), otherwise black for
        walls and ``original_color`` for agents. With ``QP_CONTROL_CBF`` the
        agent colour is additionally darkened according to the ``k_cbf`` value
        read from the agent's stored ``original_action_u``.

        Obstacles, goal markers and lines between agents closer than a fixed
        radius are returned as new geoms.

        Args:
            env_index: Index of the batched environment to render.

        Returns:
            The list of geoms created for ``env_index``.
        """
        red = (1.0, 0.0, 0.0)
        black = (0.0, 0.0, 0.0)
        geoms: List[rendering.Geom] = []

        for wall in self.walls:
            wall._color = red if wall.collision_timer[env_index] > 0 else black

        # Obstacles
        for obs in self.obstacles:
            obstacle_color = red if obs.collision_timer[env_index] > 0 else black
            obstacle_geom = rendering.make_circle(self.obstacle_radius, res=20, filled=True)
            obstacle_transform = rendering.Transform(translation=obs.state.pos[env_index].cpu().numpy())
            obstacle_geom.add_attr(obstacle_transform)
            obstacle_geom.set_color(*obstacle_color)
            geoms.append(obstacle_geom)

        # Agents & goals
        for agent in self._world.agents:
            # Determine agent color based on collision_timer
            in_collision = (
                hasattr(agent, "collision_timer")
                and agent.collision_timer[env_index] > 0
                and not self.disable_collision_viz
            )
            agent_color = red if in_collision else agent.original_color

            # Shade the colour by k_cbf under QP_CONTROL_CBF. The shaded colour
            # used to be computed but never assigned; it is applied here. Other
            # control types keep the unshaded colour object untouched.
            if self.agent_control_type == QP_CONTROL_CBF and hasattr(agent, "original_action_u"):
                k_cbf_range = getattr(self, "k_cbf_range", 4.0)
                k_cbf = (agent.original_action_u[env_index][0].item() + k_cbf_range) / 2
                # Clamp so a raw action outside the expected range cannot yield
                # colour components outside the shaded interval.
                normalized_k_cbf = min(max(k_cbf / k_cbf_range, 0.0), 1.0)
                min_shade = 0.5
                shade_factor = min_shade + (1.0 - min_shade) * normalized_k_cbf
                agent_color = (
                    agent_color[0] * shade_factor,
                    agent_color[1] * shade_factor,
                    agent_color[2] * shade_factor,
                )
            agent.color = agent_color

            # Draw goal as a translucent disc in the agent's original colour
            goal_pos = agent.goal_pos[env_index].cpu().numpy()
            goal_geom = rendering.make_circle(radius=0.2, res=20, filled=True)
            goal_geom.add_attr(rendering.Transform(translation=goal_pos))
            goal_geom.set_color(
                agent.original_color[0], agent.original_color[1], agent.original_color[2], alpha=0.3
            )
            geoms.append(goal_geom)

        # Communication lines: one line per unordered pair of nearby agents
        edge_radius = 1.75
        agents = self.world.agents
        for i, agent1 in enumerate(agents):
            pos1 = agent1.state.pos[env_index]
            for agent2 in agents[i + 1 :]:
                pos2 = agent2.state.pos[env_index]
                # Only the rendered environment is needed, not the whole batch.
                if torch.linalg.vector_norm(pos1 - pos2) <= edge_radius:
                    line = rendering.Line(pos1, pos2, width=1)
                    line.add_attr(rendering.Transform())
                    line.set_color(*Color.BLACK.value)
                    geoms.append(line)

        return geoms


# Example usage: run five rendered rollouts of the scenario with random actions.
if __name__ == "__main__":
    import os
    import sys

    # Make the parent of this file's directory importable so `utils` resolves.
    this_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.dirname(this_dir))
    from utils import use_vmas_env

    demo_settings = dict(
        render=True,
        save_render=False,
        num_envs=1,
        n_steps=200,
        device="cpu",
        continuous_actions=True,
        random_action=True,
        deterministic_action_value=0.0,
        n_agents=6,
        n_obstacles=0,
    )

    for i in range(5):
        # A fresh scenario instance is created for every rollout.
        use_vmas_env(scenario=WaypointScenario(), **demo_settings)
