import typing
from typing import Dict, List

import torch
import numpy as np
from torch import Tensor
from vmas.simulator.core import Agent, Sphere, World, Landmark, Line
from vmas.simulator.dynamics.holonomic import Holonomic
from vmas.simulator.dynamics.common import Dynamics
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import Color, TorchUtils
from vmas.simulator.controllers.velocity_controller import VelocityController
from vmas.simulator import rendering

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

if typing.TYPE_CHECKING:
    from vmas.simulator.rendering import Geom

import numpy as np


# Agent control types.
# Each constant is the string compared against the scenario's
# ``agent_control_type`` setting to choose the action size, action range and
# dynamics model used for every agent.
VELOCITY_CONTROL = "VELOCITY_CONTROL"
QP_CONTROL_CONSTRAINT = "QP_CONTROL_CONSTRAINT"
QP_CONTROL_CBF = "QP_CONTROL_CBF"

# LEGACY, NOT IMPLEMENTED, or only partially used:
QP_CONTROL_OBJECTIVE_AND_CONSTRAINT = "QP_CONTROL_OBJECTIVE_AND_CONSTRAINT"
QP_CONTROL_OBJECTIVE = "QP_CONTROL_OBJECTIVE"
QP_CONTROL_CONSTRAINT_ROBOMASTER = "QP_CONTROL_CONSTRAINT_ROBOMASTER"


class HolonomicQP(Dynamics):
    """Holonomic dynamics that apply the leading two action columns as a force."""

    @property
    def needed_action_size(self) -> int:
        """Return the number of action components this model requires (2)."""
        return 2

    def process_action(self):
        """Copy the first two columns of the agent's action into its force."""
        agent = self.agent
        planar_action = agent.action.u[:, :2]
        agent.state.force = planar_action


class SensorCoverageScenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """Build the world: agents, boundary walls, obstacles and the QP layer.

        Args:
            batch_dim: Leading (batch) dimension of every per-environment
                tensor; also passed to ``World``.
            device: Torch device on which the tensors are allocated.
            **kwargs: Optional settings. Keys read here are
                ``agent_control_type``, ``constraint_on``,
                ``disable_collision_viz``, ``n_agents``, ``n_obstacles``,
                ``v_range``, ``a_range``, ``linear_friction``, ``collisions``,
                ``max_velocity``, ``max_steps``, ``agent_radius``,
                ``formation_radius``, ``LAMBDA_U``, ``LAMBDA_G``,
                ``delta_scaling``, ``min_input_norm``, ``resolution_factor``,
                ``viewer_size`` and (CBF control only) ``k_cbf_range``.

        Returns:
            The constructed ``World``; it is also stored on ``self._world``.

        Raises:
            ValueError: If ``agent_control_type`` is not a known control type.
        """
        # Use float32 as the default dtype while the agents are being built.
        # The previous process-wide default is restored in ``finally`` so that
        # an error raised part-way through (for example the unknown control
        # type ValueError below) cannot leave the global default modified.
        original_default = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        try:
            self.agent_control_type = kwargs.get("agent_control_type", QP_CONTROL_CONSTRAINT)
            # Set to True to enable constraint design; otherwise runs the
            # expert controller alone.
            self.constraint_on = kwargs.get("constraint_on", False)
            self.disable_collision_viz = kwargs.get("disable_collision_viz", False)

            # Debug: Print the control type being set
            print(f"Setting agent_control_type to: {self.agent_control_type}")

            self.n_agents = kwargs.get("n_agents", 4)
            self.n_obstacles = kwargs.get("n_obstacles", 2)
            self.v_range = kwargs.get("v_range", 1.0)

            if self.agent_control_type == VELOCITY_CONTROL:
                # trains better if this is lower (use 0.6 under r=0.2, 0.3 ow)
                self.a_range = kwargs.get("a_range", 0.3)
            else:
                self.a_range = kwargs.get("a_range", 1.0)
            self.linear_friction = kwargs.get("linear_friction", 0.0)
            self.collisions = kwargs.get("collisions", True)

            self.max_velocity = kwargs.get("max_velocity", 1.0)
            self.max_steps = kwargs.get("max_steps", 100)
            self.agent_radius = kwargs.get("agent_radius", 0.3)

            # Connectivity parameters
            self.formation_radius = kwargs.get("formation_radius", 2)

            # QP / reward coefficients
            self.lambda_u = kwargs.get("LAMBDA_U", 10.0)  # collision/connectivity penalty
            self.lambda_g = kwargs.get("LAMBDA_G", 1.0)   # near-goal bonus (distance < 0.1)
            self.delta_scaling = kwargs.get("delta_scaling", 99.0)

            self.collision_gap = 0.01
            self.min_distance_between_entities = self.agent_radius * 2 + self.collision_gap

            # Arena bounds (half-extent of the arena along each axis)
            self.world_x_semidim = 3.0
            self.world_y_semidim = 3.0

            self.min_input_norm = kwargs.get("min_input_norm", 0.03)

            self.f_range = self.a_range + self.linear_friction
            # For velocity control we clamp the velocity, otherwise clamp the force
            if self.agent_control_type == VELOCITY_CONTROL:
                self.u_range = self.v_range
            else:
                self.u_range = self.f_range

            # Velocity controller internal params
            controller_params = [2, 6, 0.002]

            # Viewer settings
            self.resolution_factor = kwargs.pop("resolution_factor", 1200)
            self.viewer_zoom = 2.0  # Slightly zoomed out to see the entire arena
            self.viewer_size = kwargs.pop(
                "viewer_size",
                (
                    int(self.resolution_factor),
                    int(self.resolution_factor),
                ),
            )

            # Create the underlying world. The semidims are padded by one agent
            # radius beyond the nominal arena bounds.
            world = World(
                batch_dim,
                device,
                dt=0.1,
                substeps=1,
                linear_friction=self.linear_friction,
                drag=0.0,
                x_semidim=self.world_x_semidim + self.agent_radius,
                y_semidim=self.world_y_semidim + self.agent_radius,
            )

            known_colors = [
                (0.22, 0.40, 0.72),  # bluish
                (0.22, 0.72, 0.40),  # greenish
                (0.72, 0.40, 0.22),  # more
                (0.72, 0.22, 0.40),
            ]

            # Determine the appropriate action size / range for each control type
            if self.agent_control_type == VELOCITY_CONTROL:
                # Using an action size=2 for velocity is typical (x, y).
                action_size = 2
                u_range = self.u_range
            elif self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
                action_size = 2 + 1 + 1 + 1
                u_range = [
                    self.u_range,
                    self.u_range,
                    self.u_range,
                    self.world_x_semidim,
                    self.world_y_semidim,
                ]
            elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
                action_size = 2
                u_range = [self.world_x_semidim, self.world_y_semidim]
            elif self.agent_control_type in [QP_CONTROL_CONSTRAINT, QP_CONTROL_CONSTRAINT_ROBOMASTER]:
                action_size = 2 + 1
                u_range = self.u_range
            elif self.agent_control_type == QP_CONTROL_CBF:
                self.k_cbf_range = kwargs.get("k_cbf_range", 2.0)
                action_size = 1
                u_range = self.k_cbf_range
            else:
                raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

            # Create agents
            for i in range(self.n_agents):
                if self.agent_control_type == VELOCITY_CONTROL:
                    dynamics = Holonomic()
                else:
                    dynamics = HolonomicQP()

                agent = Agent(
                    name=f"agent_{i}",
                    collide=self.collisions,
                    alpha=1.0,
                    shape=Sphere(radius=self.agent_radius),
                    render_action=True,
                    v_range=self.v_range,
                    f_range=self.f_range,
                    u_range=u_range,
                    dynamics=dynamics,
                    action_size=action_size,
                )

                # Each agent's goal
                agent.goal_pos = torch.zeros(batch_dim, 2, device=device, dtype=torch.float32)

                # Target direction used by QP controllers
                agent.target_direction = torch.ones(batch_dim, device=device, dtype=torch.float32)

                # Color (cycled when there are more agents than known colors)
                agent.original_color = known_colors[i % len(known_colors)]
                agent.color = agent.original_color

                # Controller
                agent.controller = VelocityController(agent, world, controller_params, "standard")

                # Initialize rewards/tensors
                agent.agent_collision_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
                agent.individual_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)
                agent.computed_velocity = torch.zeros(batch_dim, 2, device=device, dtype=torch.float32)
                agent.collision_timer = torch.zeros(batch_dim, device=device, dtype=torch.float32)

                # Add formation reward for connectivity
                agent.formation_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)

                if self.agent_control_type != VELOCITY_CONTROL:
                    agent.action.delta = torch.zeros(batch_dim, device=device, dtype=torch.float32)
                    agent.delta_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)

                # We'll store a 'final_rew' in case the agent is truly on its goal
                agent.final_rew = torch.zeros(batch_dim, device=device, dtype=torch.float32)

                world.add_agent(agent)
        finally:
            # Restore original default dtype, on success and on failure alike
            torch.set_default_dtype(original_default)

        # Walls: the first two span the world's x extent, the last two its y extent
        self.walls = []
        for i in range(4):
            if i < 2:
                wall_length = 2 * (world.x_semidim)
            else:
                wall_length = 2 * (world.y_semidim)
            wall = Landmark(
                name=f"wall_{i}",
                collide=True,
                shape=Line(length=wall_length),
                color=Color.BLACK,
            )
            wall.original_color = (0.0, 0.0, 0.0)
            wall.collision_timer = torch.zeros(batch_dim, device=device, dtype=torch.float32)
            world.add_landmark(wall)
            self.walls.append(wall)

        # Obstacles (same radius as the agents)
        self.obstacle_radius = self.agent_radius
        self.obstacles = []
        for i in range(self.n_obstacles):
            obs = Landmark(
                name=f"obstacle_{i}",
                collide=True,
                shape=Sphere(radius=self.obstacle_radius),
                color=Color.BLACK,
            )
            obs.collision_timer = torch.zeros(batch_dim, device=device, dtype=torch.float32)
            world.add_landmark(obs)
            self.obstacles.append(obs)

        self._world = world

        # Spawn walls (fixed)
        self.spawn_walls()

        # Init QP layer
        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        return world

    def reset_world_at(self, env_index: int = None):
        """Reinitialize the world with agents in a valid, compact grid formation.

        The QP layer is rebuilt, agents are laid out on a grid near the bottom
        of the arena with zero velocity, and walls, obstacles and goals are
        respawned. Per-agent and per-obstacle timers/reward buffers are cleared.

        Args:
            env_index: Index of the single environment to reset, or ``None``
                to reset every environment.

        Raises:
            AssertionError: If any two agents start at a distance that is not
                strictly below ``formation_radius``.
        """
        # Reset the QP layer
        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        agents = self.world.agents
        n_agents = len(agents)

        # Safety margin from the edge of the world
        safety_margin = 3 * (self.agent_radius + self.collision_gap)

        # Keep the largest initial pairwise distance at 50% of formation_radius
        max_allowed_distance = 0.5 * self.formation_radius

        # Grid dimensions: at most 2 rows for up to 4 agents, otherwise roughly
        # square. ``max(1, ...)`` avoids a zero divisor when there are no agents.
        if n_agents <= 4:
            rows = max(1, min(2, n_agents))
        else:
            rows = int(np.ceil(np.sqrt(n_agents)))
        cols = (n_agents + rows - 1) // rows  # Ceiling division

        # Largest distance in the grid is its diagonal:
        # sqrt((rows-1)^2 + (cols-1)^2) * spacing
        max_diagonal_units = np.sqrt((rows - 1) ** 2 + (cols - 1) ** 2)
        if max_diagonal_units == 0:
            # Single agent: avoid division by zero
            max_diagonal_units = 1
        spacing = max_allowed_distance / max_diagonal_units

        # Start from the bottom of the arena with a safety margin
        x_center = 0
        y_start = (-self._world.y_semidim) + safety_margin + (rows - 1) * spacing / 2

        # Reduce spacing if the formation would exceed the usable arena width
        usable_width = 2 * (self._world.x_semidim - safety_margin)
        if (cols - 1) * spacing > usable_width:
            spacing = usable_width / (cols - 1 if cols > 1 else 1)

        # Position agents in a grid formation
        for i, agent in enumerate(agents):
            row, col = divmod(i, cols)

            # Columns are centred on x_center; rows go downwards from y_start
            x = x_center + (col - (cols - 1) / 2) * spacing
            y = y_start - row * spacing

            pos = torch.tensor([x, y], device=self._world.device, dtype=torch.float32).unsqueeze(0)
            agent.set_pos(pos, batch_index=env_index)

            # Zero the velocity using a single-row tensor, the same shape
            # convention as ``pos`` above. A full-batch tensor cannot be
            # written into one indexed environment when batch_dim > 1.
            zero_vel = torch.zeros_like(agent.state.vel[:1], dtype=torch.float32)
            agent.set_vel(zero_vel, batch_index=env_index)

            # Reset controller
            agent.controller.reset(env_index)

        # Verify that all pairwise distances are within formation_radius
        if n_agents > 1:
            positions = torch.stack([a.state.pos for a in agents], dim=1)  # [batch, n_agents, 2]
            if env_index is not None:
                # Only the environment being reset needs checking
                positions = positions[env_index].unsqueeze(0)
            # Self-distances are zero, so the overall maximum equals the
            # maximum over distinct pairs.
            max_distance = torch.cdist(positions, positions).max().item()
            if not max_distance < self.formation_radius:
                # Raised explicitly so the check also runs under ``python -O``
                raise AssertionError("Initial agent placement violates formation radius constraint")

        # Spawn walls
        self.spawn_walls(env_index)

        # Respawn obstacles
        self.respawn_obstacles_no_overlap(env_index)

        # Assign random goals to agents
        self.assign_random_goals(env_index)

        def cleared(buffer):
            """Return ``buffer`` zeroed for the environment(s) being reset."""
            if env_index is None:
                return torch.zeros_like(buffer)
            # Single-environment reset: leave the other environments untouched
            updated = buffer.clone()
            updated[env_index] = 0
            return updated

        # Reset collision timers and reward buffers
        for agent in agents:
            agent.collision_timer = cleared(agent.collision_timer)
            agent.final_rew = cleared(agent.final_rew)
            agent.formation_rew = cleared(agent.formation_rew)

        for obs in self.obstacles:
            obs.collision_timer = cleared(obs.collision_timer)

    def pre_step(self):
        """Snapshot every agent's position and velocity before the step.

        Only does anything under velocity control. The snapshots are stored on
        ``self.pre_step_positions`` and ``self.pre_step_velocities``, each of
        shape ``(batch_dim, n_agents, 2)``.
        """
        if self.agent_control_type != VELOCITY_CONTROL:
            return

        world = self.world
        agents = world.agents
        snapshot_shape = (world.batch_dim, len(agents), 2)

        self.pre_step_positions = torch.zeros(snapshot_shape, device=world.device)
        self.pre_step_velocities = torch.zeros(snapshot_shape, device=world.device)

        # Copy each agent's state into its own slot along dimension 1
        for slot, agent in enumerate(agents):
            self.pre_step_positions[:, slot, :] = agent.state.pos
            self.pre_step_velocities[:, slot, :] = agent.state.vel

    def post_step(self):
        """Check the formation constraint and revert agents that violate it.

        Only active under velocity control with more than one agent. An agent
        that ends the step farther than ``formation_radius`` from ANY other
        agent is moved back to the position stored in
        ``self.pre_step_positions``, its velocity is zeroed and its controller
        is reset for the affected environments.
        """
        if self.agent_control_type != VELOCITY_CONTROL:
            return

        agents = self.world.agents
        n_agents = len(agents)
        if n_agents <= 1:
            return

        # (batch_dim, n_agents, 2)
        current_positions = torch.stack([a.state.pos for a in agents], dim=1)

        # Pairwise distances via broadcasting: (batch_dim, n_agents, n_agents)
        pos_expanded_1 = current_positions.unsqueeze(2)  # (batch_dim, n_agents, 1, 2)
        pos_expanded_2 = current_positions.unsqueeze(1)  # (batch_dim, 1, n_agents, 2)
        dist_matrix = torch.norm(pos_expanded_1 - pos_expanded_2, dim=-1)

        too_far = dist_matrix > self.formation_radius

        # Ignore self-distance by setting the diagonal to False
        idx = torch.arange(n_agents, device=self.world.device)
        too_far[:, idx, idx] = False

        # revert_mask[b, i] is True if agent i in environment b violates formation
        revert_mask = too_far.any(dim=2)  # (batch_dim, n_agents)

        # Apply reversion where needed
        for i, agent in enumerate(agents):
            mask = revert_mask[:, i]
            if not mask.any():
                continue

            # Revert position to pre-step position
            agent.state.pos[mask] = self.pre_step_positions[mask, i, :]

            # Zero the velocity to prevent immediate re-violation
            agent.state.vel[mask] = 0.0

            # Reset controller for this agent so it doesn't retain motion
            if hasattr(agent, 'controller'):
                agent.controller.reset(mask)

    def assign_random_goals(self, env_index=None):
        """Sample a goal for every agent in the band near the arena border.

        Candidates are drawn uniformly inside the world bounds (shrunk by the
        goal radius); candidates falling in the central region are pushed into
        the border band along a randomly chosen axis. A candidate is rejected
        when it is too close to another agent's current goal, to the agent's
        own position, or to an obstacle. Sampling is retried up to 10000 times
        per agent; environments still unplaced after that keep their previous
        goal.

        Also sets ``self.goal_radius``.

        Args:
            env_index: Index of the single environment to assign goals in, or
                ``None`` for all environments.
        """
        self.goal_radius = 0.2
        batch_dim = self._world.batch_dim
        device = self._world.device

        # Outer boundary (world bounds shrunk by the goal radius)
        outer_x_min = -self._world.x_semidim + self.goal_radius
        outer_x_max = self._world.x_semidim - self.goal_radius
        outer_y_min = -self._world.y_semidim + self.goal_radius
        outer_y_max = self._world.y_semidim - self.goal_radius

        # Inner boundary: the central area where goals will NOT be placed.
        # Goals go in the "edge zone" between the outer and inner boundaries.
        edge_zone_width = 0.4 * self._world.x_semidim
        inner_x_min = outer_x_min + edge_zone_width
        inner_x_max = outer_x_max - edge_zone_width
        inner_y_min = outer_y_min + edge_zone_width
        inner_y_max = outer_y_max - edge_zone_width

        # Environments that are not being reset count as already placed.
        # Without this the retry loop could never see "all placed" during a
        # single-environment reset and would always burn every try.
        if env_index is None:
            already_done = torch.zeros(batch_dim, dtype=torch.bool, device=device)
        else:
            already_done = torch.ones(batch_dim, dtype=torch.bool, device=device)
            already_done[env_index] = False

        goal_goal_gap = 2 * self.goal_radius + self.collision_gap
        goal_agent_gap = self.goal_radius + self.agent_radius + self.collision_gap
        goal_obstacle_gap = self.goal_radius + self.obstacle_radius + self.collision_gap

        max_tries = 10000
        for agent in self.world.agents:
            placed = already_done.clone()
            tries = 0

            while not torch.all(placed) and tries < max_tries:
                tries += 1

                # Draw candidates over the full range, then conditionally adjust
                rand_x = torch.rand(batch_dim, device=device, dtype=torch.float32) * (outer_x_max - outer_x_min) + outer_x_min
                rand_y = torch.rand(batch_dim, device=device, dtype=torch.float32) * (outer_y_max - outer_y_min) + outer_y_min

                # Candidates whose x AND y both fall in the central area
                x_in_middle = (rand_x >= inner_x_min) & (rand_x <= inner_x_max)
                y_in_middle = (rand_y >= inner_y_min) & (rand_y <= inner_y_max)
                both_in_middle = x_in_middle & y_in_middle

                # For points in the middle, randomly choose whether to adjust x or y
                adjust_x = torch.rand(batch_dim, device=device) > 0.5

                # Where x is adjusted, resample it in the left or right band
                push_left = torch.rand(batch_dim, device=device) > 0.5
                to_left = both_in_middle & adjust_x & push_left
                to_right = both_in_middle & adjust_x & ~push_left
                rand_x[to_left] = torch.rand(int(to_left.sum()), device=device, dtype=torch.float32) * (inner_x_min - outer_x_min) + outer_x_min
                rand_x[to_right] = torch.rand(int(to_right.sum()), device=device, dtype=torch.float32) * (outer_x_max - inner_x_max) + inner_x_max

                # Where y is adjusted, resample it in the lower or upper band
                push_down = torch.rand(batch_dim, device=device) > 0.5
                to_bottom = both_in_middle & ~adjust_x & push_down
                to_top = both_in_middle & ~adjust_x & ~push_down
                rand_y[to_bottom] = torch.rand(int(to_bottom.sum()), device=device, dtype=torch.float32) * (inner_y_min - outer_y_min) + outer_y_min
                rand_y[to_top] = torch.rand(int(to_top.sum()), device=device, dtype=torch.float32) * (outer_y_max - inner_y_max) + inner_y_max

                candidate_positions = torch.stack([rand_x, rand_y], dim=-1)

                # Per-environment rejection mask, evaluated for the whole batch
                rejected = torch.zeros(batch_dim, dtype=torch.bool, device=device)

                # 1) Check vs other agents' goals
                for other_agent in self.world.agents:
                    if other_agent is agent:
                        continue
                    dist = torch.norm(candidate_positions - other_agent.goal_pos, p=2, dim=-1)
                    rejected |= dist < goal_goal_gap

                # 2) Check vs agent itself
                dist_to_own_agent = torch.norm(candidate_positions - agent.state.pos, p=2, dim=-1)
                rejected |= dist_to_own_agent < goal_agent_gap

                # 3) Check vs obstacles
                for obs in self.obstacles:
                    dist_to_obs = torch.norm(candidate_positions - obs.state.pos, p=2, dim=-1)
                    rejected |= dist_to_obs < goal_obstacle_gap

                # Accept only in environments that still need a goal
                accepted = ~placed & ~rejected
                final_positions = agent.goal_pos.clone().to(torch.float32)
                final_positions[accepted] = candidate_positions[accepted]
                agent.goal_pos = final_positions
                placed = placed | accepted

            # If max_tries is exhausted, proceed anyway: unplaced environments
            # keep whatever goal they had before this call.

    def respawn_agents_no_overlap(self, env_index=None, max_tries=50):
        """Respawn agents close to each other without overlapping.

        The first agent is placed uniformly at random in the central half of
        the arena. Every following agent is sampled within 75% of
        ``formation_radius`` of a randomly chosen, already placed agent, and is
        accepted when it does not overlap an earlier agent or an obstacle and
        lies within ``formation_radius`` of at least one earlier agent. If no
        candidate is accepted within ``max_tries``, the agent is placed just
        beside agent 0 instead. Velocities of respawned agents are zeroed.

        Only the environment(s) selected by ``env_index`` are modified.

        Args:
            env_index: Index of the single environment to respawn, or ``None``
                for all environments.
            max_tries: Maximum number of sampling rounds per agent.
        """
        batch_dim = self._world.batch_dim
        device = self._world.device
        agents = self.world.agents
        if not agents:
            return

        # Boolean mask of the environments this call is allowed to modify
        if env_index is None:
            resetting = torch.ones(batch_dim, dtype=torch.bool, device=device)
        else:
            resetting = torch.zeros(batch_dim, dtype=torch.bool, device=device)
            resetting[env_index] = True
        resetting_rows = resetting.unsqueeze(-1)

        def commit(target, proposed_pos):
            """Write position and zero velocity only where ``resetting`` is set."""
            current_pos = target.state.pos
            current_vel = target.state.vel
            new_pos = torch.where(resetting_rows, proposed_pos.to(current_pos.dtype), current_pos)
            new_vel = torch.where(resetting_rows, torch.zeros_like(current_vel), current_vel)
            target.set_pos(new_pos, batch_index=None)
            target.set_vel(new_vel, batch_index=None)

        # Tighter central region for the first agent
        x_min = -0.5 * self.world_x_semidim + self.agent_radius
        x_max = 0.5 * self.world_x_semidim - self.agent_radius
        y_min = -0.5 * self.world_y_semidim + self.agent_radius
        y_max = 0.5 * self.world_y_semidim - self.agent_radius

        # Arena bounds used to clamp every later candidate
        x_low = -self.world_x_semidim + self.agent_radius
        x_high = self.world_x_semidim - self.agent_radius
        y_low = -self.world_y_semidim + self.agent_radius
        y_high = self.world_y_semidim - self.agent_radius

        # Minimum centre distance between two agents (10% slack over touching)
        min_gap = 1.1 * (2 * self.agent_radius + self.collision_gap)
        obstacle_gap = self.agent_radius + self.obstacle_radius + self.collision_gap
        # Sample within 75% of formation_radius to stay well inside the limit
        max_distance = 0.75 * self.formation_radius

        # First agent: random position inside the central region
        rand_x = torch.rand(batch_dim, device=device) * (x_max - x_min) + x_min
        rand_y = torch.rand(batch_dim, device=device) * (y_max - y_min) + y_min
        commit(agents[0], torch.stack([rand_x, rand_y], dim=-1))

        env_ids = torch.arange(batch_dim, device=device)

        for i in range(1, len(agents)):
            agent = agents[i]
            # Positions of the agents placed so far: (batch_dim, i, 2)
            earlier_pos = torch.stack([a.state.pos for a in agents[:i]], dim=1)

            # Environments that are not being reset count as already placed, so
            # the loop can stop as soon as the requested environment succeeds.
            placed = ~resetting
            final_positions = agent.state.pos.clone()
            tries = 0

            while not torch.all(placed) and tries < max_tries:
                tries += 1

                # Per environment, pick a random already-placed agent to attach to
                reference_indices = torch.randint(0, i, (batch_dim,), device=device)
                reference_positions = earlier_pos[env_ids, reference_indices]

                # Random polar offset around the reference agent
                random_distance = torch.rand(batch_dim, device=device) * max_distance
                random_angle = torch.rand(batch_dim, device=device) * 2 * torch.pi
                dx = random_distance * torch.cos(random_angle)
                dy = random_distance * torch.sin(random_angle)
                candidate_positions = reference_positions + torch.stack([dx, dy], dim=-1)

                # Ensure positions are within world bounds
                candidate_positions[:, 0] = torch.clamp(candidate_positions[:, 0], min=x_low, max=x_high)
                candidate_positions[:, 1] = torch.clamp(candidate_positions[:, 1], min=y_low, max=y_high)

                # Distance from each candidate to every earlier agent: (batch_dim, i)
                agent_dists = torch.norm(candidate_positions.unsqueeze(1) - earlier_pos, p=2, dim=-1)

                # 1. Collisions with existing agents
                hits_agent = (agent_dists < min_gap).any(dim=1)

                # 2. Collisions with obstacles
                hits_obstacle = torch.zeros(batch_dim, dtype=torch.bool, device=device)
                for obs in self.obstacles:
                    obstacle_dist = torch.norm(candidate_positions - obs.state.pos, p=2, dim=-1)
                    hits_obstacle |= obstacle_dist < obstacle_gap

                # 3. Connectivity: within formation_radius of ANY existing agent
                # (clamping may have moved the candidate, hence the re-check)
                is_connected = (agent_dists <= self.formation_radius).any(dim=1)

                accepted = ~placed & ~hits_agent & ~hits_obstacle & is_connected
                final_positions[accepted] = candidate_positions[accepted].to(final_positions.dtype)
                placed = placed | accepted

            if not torch.all(placed):
                # Fallback: place directly next to agent 0, ensuring connectivity
                direction = torch.tensor([0.1, 0.1], device=device)
                direction = direction / torch.norm(direction)
                fallback = agents[0].state.pos + direction * (2 * self.agent_radius + 0.01)

                # Ensure it's within bounds
                fallback[:, 0] = torch.clamp(fallback[:, 0], min=x_low, max=x_high)
                fallback[:, 1] = torch.clamp(fallback[:, 1], min=y_low, max=y_high)

                unplaced_rows = (~placed).unsqueeze(-1)
                final_positions = torch.where(unplaced_rows, fallback.to(final_positions.dtype), final_positions)

            commit(agent, final_positions)

    def respawn_obstacles_no_overlap(self, env_index=None, max_tries=50):
        """
        Re-sample obstacle positions so they do not overlap agents or each other.

        Obstacles are handled one after another. For each obstacle, up to
        ``max_tries`` uniform candidates are drawn per environment inside the
        world bounds (shrunk by the obstacle radius). A candidate is rejected
        when it is closer than ``2 * obstacle_radius + collision_gap`` to another
        obstacle or closer than ``2 * (agent_radius + obstacle_radius)`` to any
        agent. Only the environments that still have no valid candidate after
        ``max_tries`` rounds get the obstacle parked at (100, 100).

        Args:
            env_index: Environment to respawn; ``None`` respawns every
                environment. Environments that are not selected keep their
                current obstacle positions.
            max_tries: Maximum number of sampling rounds per obstacle.
        """
        batch_dim = self._world.batch_dim
        device = self._world.device
        x_min = -self._world.x_semidim + self.obstacle_radius
        x_max = self._world.x_semidim - self.obstacle_radius
        y_min = -self._world.y_semidim + self.obstacle_radius
        y_max = self._world.y_semidim - self.obstacle_radius

        # Minimum centre-to-centre distances a candidate has to respect.
        safety_margin = 2.0 * (self.agent_radius + self.obstacle_radius)
        obstacle_clearance = 2 * self.obstacle_radius + self.collision_gap

        # Environments this call is allowed to touch. Everything else counts as
        # already placed, so the sampling loop can terminate and the fallback
        # never overwrites environments that are not being reset.
        if env_index is None:
            targeted = torch.ones(batch_dim, dtype=torch.bool, device=device)
        else:
            targeted = torch.zeros(batch_dim, dtype=torch.bool, device=device)
            targeted[env_index] = True

        for obs in self.obstacles:
            pending = targeted.clone()
            positions = obs.state.pos.clone()
            tries = 0
            while tries < max_tries and bool(pending.any()):
                tries += 1
                rand_x = torch.rand(batch_dim, device=device) * (x_max - x_min) + x_min
                rand_y = torch.rand(batch_dim, device=device) * (y_max - y_min) + y_min
                candidates = torch.stack([rand_x, rand_y], dim=-1)

                # Per-environment rejection mask for this round of candidates.
                blocked = torch.zeros(batch_dim, dtype=torch.bool, device=device)
                for other_obs in self.obstacles:
                    if other_obs is obs:
                        continue
                    dist = torch.norm(candidates - other_obs.state.pos, p=2, dim=-1)
                    blocked |= dist < obstacle_clearance
                for a in self.world.agents:
                    dist = torch.norm(candidates - a.state.pos, p=2, dim=-1)
                    blocked |= dist < safety_margin

                accepted = pending & ~blocked
                positions = torch.where(
                    accepted.unsqueeze(-1),
                    candidates.to(positions.dtype),
                    positions,
                )
                pending = pending & ~accepted

            # Park the obstacle far away only where no valid spot was found,
            # to reduce constraint conflicts in those environments.
            positions[pending] = 100.0
            obs.set_pos(positions, batch_index=None)

    def spawn_walls(self, env_index=None):
        """
        Put the boundary walls on the edges of the world.

        Walls are assigned by their index in ``self.walls``: 0 goes to the top
        edge, 1 to the bottom, 2 to the left, and every remaining wall to the
        right edge. Top/bottom walls get rotation 0, left/right walls pi/2.

        Args:
            env_index: Forwarded as ``batch_index`` to ``set_pos``/``set_rot``.
        """
        device = self._world.device
        semidims = (self._world.x_semidim, self._world.y_semidim)
        quarter_turn = torch.pi / 2

        # (axis the wall is offset along, offset is negative?, rotation)
        slots = (
            (1, False, 0.0),           # top
            (1, True, 0.0),            # bottom
            (0, True, quarter_turn),   # left
            (0, False, quarter_turn),  # right
        )
        last_slot = len(slots) - 1

        for index, wall in enumerate(self.walls):
            axis, negative, angle = slots[min(index, last_slot)]
            centre = [0.0, 0.0]
            centre[axis] = -semidims[axis] if negative else semidims[axis]

            wall.set_pos(torch.tensor(centre, device=device), batch_index=env_index)
            wall.set_rot(torch.tensor([angle], device=device), batch_index=env_index)

    def reward(self, agent: Agent):
        """
        Return the total reward of ``agent``.

        The total is individual reward + final_rew + collision penalties +
        formation reward, plus the delta penalty when the control type is
        QP_CONTROL_CBF.

        When called for the first agent in ``self.world.agents`` the shared
        bookkeeping is refreshed for everyone: shaped rewards are recomputed via
        ``agent_reward``, collision penalties are reset and re-accumulated
        (``lambda_u`` per pair whose distance is at most ``collision_gap``),
        collision timers of the involved entities are set to 5, and finally all
        timers are decremented by one without going below zero.
        """
        agents = self.world.agents

        if agent == agents[0]:
            # Refresh shaped rewards and clear last step's collision penalties
            for member in agents:
                self.agent_reward(member)
                member.agent_collision_rew[:] = 0

            # Agent vs. static entities: obstacles first, then walls
            for static_group in (self.obstacles, self.walls):
                for member in agents:
                    for entity in static_group:
                        if not self._world.collides(member, entity):
                            continue
                        hit = self._world.get_distance(member, entity) <= self.collision_gap
                        member.agent_collision_rew[hit] -= self.lambda_u
                        member.collision_timer[hit] = 5
                        entity.collision_timer[hit] = 5

            # Agent vs. agent: each unordered pair once, both sides penalised
            for i, first in enumerate(agents):
                for j in range(i + 1, len(agents)):
                    second = agents[j]
                    if not self._world.collides(first, second):
                        continue
                    hit = self._world.get_distance(first, second) <= self.collision_gap
                    first.agent_collision_rew[hit] -= self.lambda_u
                    second.agent_collision_rew[hit] -= self.lambda_u
                    first.collision_timer[hit] = 5
                    second.collision_timer[hit] = 5

            # Count every collision timer down by one, floored at zero
            for entity in (*agents, *self.walls, *self.obstacles):
                entity.collision_timer = torch.clamp(entity.collision_timer - 1, min=0)

        total_rew = (
            agent.individual_rew
            + agent.final_rew
            + agent.agent_collision_rew
            + agent.formation_rew
        )
        if self.agent_control_type == QP_CONTROL_CBF:
            total_rew = total_rew + agent.delta_rew

        return total_rew

    def agent_reward(self, agent: Agent):
        """
        Recompute the shaped reward terms stored on ``agent``.

        * ``individual_rew``: ``lambda_g * exp(-2 * distance_to_goal)`` plus half
          of the distance gained towards the goal since the previous call.
        * ``final_rew``: reset to zero (the on-goal bonus is scaled by 0.0).
        * ``formation_rew``: ``-lambda_u`` in environments where any other agent
          is farther away than ``formation_radius``, zero otherwise.
        * ``delta_rew`` (QP_CONTROL_CBF only, when the action carries ``delta``):
          ``-lambda_u * delta``.

        Returns:
            The updated ``agent.individual_rew``.
        """
        # Start every term from zero
        for term in (agent.individual_rew, agent.final_rew, agent.formation_rew):
            term[:] = 0

        own_pos = agent.state.pos
        dist_to_goal = torch.norm(own_pos - agent.goal_pos, dim=-1).to(torch.float32)

        # Exponential proximity term; a smaller decay factor flattens the curve
        decay_factor = 2.0
        proximity_reward = self.lambda_g * torch.exp(-decay_factor * dist_to_goal)

        # On-goal bonus is currently switched off by the 0.0 multiplier
        agent.final_rew[dist_to_goal < 0.05] = 0.0 * self.lambda_g

        # Progress term: positive when the agent got closer since the last call
        previous_dist = agent.prev_dist if hasattr(agent, "prev_dist") else dist_to_goal.clone()
        progress_reward = previous_dist - dist_to_goal
        agent.prev_dist = dist_to_goal.clone()

        agent.individual_rew = proximity_reward + 0.5 * progress_reward

        # Formation term: penalise losing range to any teammate
        if len(self.world.agents) > 1:
            teammate_pos = torch.stack(
                [other.state.pos for other in self.world.agents if other != agent],
                dim=1,
            ).to(torch.float32)
            own_pos_col = own_pos.unsqueeze(1).to(torch.float32)  # (batch_dim, 1, 2)
            separation = torch.norm(teammate_pos - own_pos_col, dim=-1)  # (batch_dim, n_agents-1)
            out_of_range = (separation > self.formation_radius).any(dim=1)
            agent.formation_rew = -self.lambda_u * out_of_range.to(torch.float32)

        # Slack penalty, only for CBF control and only when a delta is available
        if self.agent_control_type == QP_CONTROL_CBF and hasattr(agent.action, "delta"):
            agent.delta_rew = -self.lambda_u * agent.action.delta.to(torch.float32)

        return agent.individual_rew

    def observation(self, agent: Agent):
        """
        Build the observation dictionary of ``agent``.

        Keys: ``agent_pos``, ``agent_vel``, ``goal_pos``, ``distance_to_goal``,
        ``rel_goal_pos``, ``obstacle_rel_positions`` (flattened per environment),
        ``obstacle_distances`` and ``agent_diameter`` (``formation_radius`` minus
        the distance to the farthest other agent). Every tensor is returned as
        float32.
        """
        device = self._world.device
        own_pos = agent.state.pos
        goal = agent.goal_pos

        features = {
            "agent_pos": own_pos,
            "agent_vel": agent.state.vel,
            "goal_pos": goal,
            "distance_to_goal": torch.norm(own_pos - goal, dim=-1, keepdim=True),
            "rel_goal_pos": goal - own_pos,
        }

        # Obstacles relative to the agent
        if self.obstacles:
            stacked = torch.stack([entity.state.pos for entity in self.obstacles], dim=1)
            relative = stacked - own_pos.unsqueeze(1)  # (batch_dim, n_obstacles, 2)
            n_envs, n_obs = relative.shape[0], relative.shape[1]
            features["obstacle_rel_positions"] = relative.reshape(n_envs, n_obs * 2)
            features["obstacle_distances"] = torch.norm(relative, dim=2)  # (batch_dim, n_obstacles)
        else:
            features["obstacle_rel_positions"] = torch.zeros(own_pos.size(0), 0, device=device)
            features["obstacle_distances"] = torch.zeros(own_pos.size(0), 0, device=device)

        # Distance to the farthest other agent (zero when the agent is alone)
        if len(self.world.agents) > 1:
            teammates = torch.stack(
                [other.state.pos for other in self.world.agents if other != agent],
                dim=1,
            )
            spread = torch.norm(teammates - own_pos.unsqueeze(1), dim=-1)  # (batch_dim, n_agents-1)
            farthest = spread.max(dim=1).values  # (batch_dim,)
        else:
            farthest = torch.zeros(own_pos.size(0), device=device)

        # Reported as the remaining margin to formation_radius, shape (batch_dim, 1)
        features["agent_diameter"] = self.formation_radius - farthest.unsqueeze(-1)

        # Uniform float32 output
        return {
            key: (value.to(torch.float32) if torch.is_tensor(value) else value)
            for key, value in features.items()
        }

    def done(self):
        """
        Return the per-environment done flags.

        Every environment is flagged done when ``self.qp_solver_initialized`` is
        falsy (it is cleared when the QP solve fails); otherwise none is.
        """
        solver_failed = not self.qp_solver_initialized
        return torch.full(
            (self._world.batch_dim,),
            solver_failed,
            dtype=torch.bool,
            device=self._world.device,
        )

    def info(self, agent: Agent) -> Dict[str, Tensor]:
        """
        Collect per-agent diagnostics.

        Always reports the collision, individual and formation rewards, their
        sum (``total_reward_without_delta_rew``) and the distance to the goal.
        Under QP_CONTROL_CBF it also reports ``original_action_u`` (zeros shaped
        like the collision reward when the agent has none recorded yet), and
        ``delta_reward`` is added whenever the agent carries a ``delta_rew``.
        """
        info = {
            "collision_reward": agent.agent_collision_rew,
            "individual_reward": agent.individual_rew,
            "formation_reward": agent.formation_rew,
            "total_reward_without_delta_rew": agent.individual_rew + agent.agent_collision_rew + agent.formation_rew,
            "distance_to_goal": torch.norm(agent.state.pos - agent.goal_pos, dim=-1),
        }
        if self.agent_control_type == QP_CONTROL_CBF:
            if hasattr(agent, "original_action_u"):
                info["original_action_u"] = agent.original_action_u.squeeze(-1)
            else:
                info["original_action_u"] = torch.zeros_like(agent.agent_collision_rew)

        # delta_rew is stored on the agent itself, not on agent.action, so the
        # presence check has to look at the agent.
        if hasattr(agent, "delta_rew"):
            info["delta_reward"] = agent.delta_rew

        return info

    def init_qp_layer(self):
        """
        Build the differentiable convex program that filters each agent's action.

        The decision variables are the 2-D control ``u`` and a scalar slack
        ``delta``. The program contains:
          - one barrier-style constraint per avoided entity (the other agents
            first, then the obstacles), relaxed by ``delta`` for QP_CONTROL_CBF
            and QP_CONTROL_CONSTRAINT_ROBOMASTER;
          - a formation constraint per other agent keeping the predicted
            position within 95% of ``formation_radius``;
          - world-boundary and speed limits (relaxed by ``delta`` only for
            QP_CONTROL_CONSTRAINT_ROBOMASTER);
          - an objective that minimises the distance between the predicted
            position and a target point plus ``delta_scaling * delta``.

        Side effect: stores the ordered parameter names in
        ``self.expected_parameter_names``.

        Returns:
            A ``CvxpyLayer`` solving for ``[u, delta]``, or ``None`` when the
            control type is VELOCITY_CONTROL.

        Raises:
            ValueError: if no objective is defined for ``self.agent_control_type``.
        """
        control_type = self.agent_control_type
        if control_type == VELOCITY_CONTROL:
            return None

        # Avoided entities: the other agents come first, then the obstacles
        n_neighbors = self.n_agents - 1
        n_avoid = n_neighbors + self.n_obstacles

        # Decision variables
        u = cp.Variable(2)
        delta = cp.Variable(1)

        # Inputs supplied at solve time
        ego_position = cp.Parameter(2)
        ego_velocity = cp.Parameter(2)
        obstacle_positions = cp.Parameter((n_avoid, 2))
        h_values = cp.Parameter(n_avoid)
        dx_vel_0 = cp.Parameter(n_avoid)
        dy_vel_1 = cp.Parameter(n_avoid)

        parameters = [
            ego_position,
            ego_velocity,
            obstacle_positions,
            h_values,
            dx_vel_0,
            dy_vel_1,
        ]
        # Kept in step with `parameters`, for debugging
        parameter_names = [
            "ego_position",
            "ego_velocity",
            "obstacle_positions",
            "h_values",
            "dx_vel_0",
            "dy_vel_1",
        ]

        def next_coord(axis):
            # Predicted coordinate along `axis` after applying velocity and u
            return ego_position[axis] + (ego_velocity[axis] + u[axis])

        def goal_objective(target_x, target_y):
            # Distance from the predicted position to the target, plus slack cost
            return cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            next_coord(0) - target_x,
                            next_coord(1) - target_y,
                        ]
                    ),
                    2,
                )
                + self.delta_scaling * delta
            )

        # Barrier and formation constraints, one entity at a time
        relax_barrier = control_type in [QP_CONTROL_CBF, QP_CONTROL_CONSTRAINT_ROBOMASTER]
        constraints = []
        for j in range(n_avoid):
            offset_x = ego_position[0] - obstacle_positions[j, 0]
            offset_y = ego_position[1] - obstacle_positions[j, 1]
            lgh = 2 * (dx_vel_0[j] + dy_vel_1[j] + u[0] * offset_x + u[1] * offset_y)

            barrier = lgh + h_values[j]
            if relax_barrier:
                barrier = barrier + delta
            constraints.append(barrier >= 0)

            # The first n_agents-1 entities are agents: stay within 95% of the formation radius
            if j < n_neighbors:
                constraints.append(
                    cp.norm(ego_position + ego_velocity + u - obstacle_positions[j])
                    <= self.formation_radius * 0.95
                )

        # World boundaries and speed limit
        safety_margin = 1.05 * (self.agent_radius + self.collision_gap)
        relax_bounds = control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER
        for axis, semidim in enumerate((self._world.x_semidim, self._world.y_semidim)):
            upper = semidim - safety_margin
            lower = -semidim + safety_margin
            if relax_bounds:
                upper = upper + delta
                lower = lower - delta
            constraints.append(next_coord(axis) <= upper)
            constraints.append(next_coord(axis) >= lower)

        # The same slack is reused for the speed limit in the relaxed variant
        speed_limit = self.max_velocity + delta if relax_bounds else self.max_velocity
        constraints.append(cp.norm(u) <= speed_limit)
        constraints.append(delta >= 0)
        if not relax_bounds:
            constraints.append(delta <= 1000)

        # Control-type specific parameters, extra constraints and objective
        if control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters += [a, b, x_target, y_target]
            parameter_names += ["a", "b", "x_target", "y_target"]
            constraints.append(cp.norm(u - a) <= b + delta)
            objective = goal_objective(x_target, y_target)

        elif control_type == QP_CONTROL_OBJECTIVE:
            # The target point comes from the action
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters += [x_target, y_target]
            parameter_names += ["x_target", "y_target"]
            constraints.append(delta >= 0)
            constraints.append(delta <= 10000)
            objective = goal_objective(x_target, y_target)

        elif control_type == QP_CONTROL_CONSTRAINT:
            a = cp.Parameter(2, name="a")
            b = cp.Parameter(1, name="b")
            x_goal = cp.Parameter(1, name="x_goal")
            y_goal = cp.Parameter(1, name="y_goal")
            # a and b are always registered so the layer signature is fixed
            parameters += [a, b, x_goal, y_goal]
            parameter_names += ["a", "b", "x_goal", "y_goal"]

            if self.constraint_on:
                constraints.append(cp.norm(u - a) <= b + delta)
            else:
                # Multiplying by 0 disables the cone while still referencing a and b,
                # which therefore have to remain layer parameters
                constraints.append(0 * cp.norm(u - a) <= 0 * b + delta)

            objective = goal_objective(x_goal, y_goal)

        elif control_type == QP_CONTROL_CBF:
            # Only the goal is needed for the objective here
            x_goal = cp.Parameter(1)
            y_goal = cp.Parameter(1)
            parameters += [x_goal, y_goal]
            parameter_names += ["x_goal", "y_goal"]
            objective = goal_objective(x_goal, y_goal)

        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        # Final bounds on the slack, applied for every control type
        constraints.append(delta >= 0)
        constraints.append(delta <= 1000)

        prob = cp.Problem(objective, constraints)
        qp_layer = CvxpyLayer(prob, parameters=parameters, variables=[u, delta])

        # Remember the expected parameter order for debugging
        self.expected_parameter_names = parameter_names

        return qp_layer

    def solve_qp_layer(self, parameters_batch, solver_args):
        """
        Run ``self.qp_layer`` on a batch of parameters without tracking gradients.

        Args:
            parameters_batch: Iterable of tensors, cast to float32 before the call.
            solver_args: Passed through to the layer as ``solver_args``.

        Returns:
            ``(u_qp_batch, delta_batch)`` on success. If the layer raises, a
            message is printed, ``self.qp_solver_initialized`` is set to False
            and ``(None, None)`` is returned.
        """
        # The layer is always fed float32 inputs
        float_params = tuple(param.to(torch.float32) for param in parameters_batch)

        try:
            with torch.no_grad():
                solution = self.qp_layer(*float_params, solver_args=solver_args)
            u_qp_batch, delta_batch = solution
        except Exception as e:
            print(f"SolverError encountered: {e}. Marking the environment as done.")
            self.qp_solver_initialized = False
            return None, None

        return u_qp_batch, delta_batch

    def process_action(self, agent: Agent):
        if self.agent_control_type == VELOCITY_CONTROL:
            # Clamp the action to its allowed range
            agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, self.u_range)
            agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, self.max_velocity)
            # Zero out actions that are too small
            action_norm = torch.linalg.vector_norm(agent.action.u, dim=1)
            agent.action.u[action_norm < self.min_input_norm] = 0

            vel_is_zero = torch.linalg.vector_norm(agent.action.u, dim=1) < 1e-3
            agent.controller.reset(vel_is_zero)

            # Apply the force (which updates pos and vel)
            agent.controller.process_force()
        else:
            # For QP-based control, prepare batched parameters for all agents before solving.
            is_first = agent == self.world.agents[0]

            if is_first:
                pos_list = []
                vel_list = []

                for idx, a in enumerate(self.world.agents):
                    pos_list.append(a.state.pos)
                    vel_list.append(a.state.vel)
                    a.original_action_u = a.action.u.clone()

                N_ = (self.n_agents - 1) + self.n_obstacles

                obstacle_positions_list = []
                current_h_values_list = []
                dx_vel_0_list = []
                dy_vel_1_list = []

                for agent_idx in range(self.n_agents):
                    for env_idx in range(self._world.batch_dim):
                        ego_pos = pos_list[agent_idx][env_idx]
                        ego_vel = vel_list[agent_idx][env_idx]

                        other_indices = [i for i in range(self.n_agents) if i != agent_idx]
                        other_pos = torch.stack(
                            [pos_list[i][env_idx] for i in other_indices],
                            dim=0,
                        )

                        obs_positions = torch.stack([obs.state.pos[env_idx] for obs in self.obstacles], dim=0)
                        full_obstacles_pos = torch.cat([other_pos, obs_positions], dim=0)

                        obstacle_positions_list.append(full_obstacles_pos)

                        dist_agents = torch.norm(other_pos - ego_pos, dim=1)**2
                        dist_obstacles = torch.norm(obs_positions - ego_pos.unsqueeze(0), dim=1)**2
                        all_distances = torch.cat([dist_agents, dist_obstacles], dim=0)

                        # --- Add a safety buffer to min_distance for QP h_values calculation ---
                        qp_safety_distance_buffer = 0.02 # Small buffer
                        min_dist_sq_buffered = (self.min_distance_between_entities + qp_safety_distance_buffer) ** 2
                        # --- End safety buffer addition ---

                        if self.agent_control_type == QP_CONTROL_CBF:
                            k_cbf = (self.world.agents[agent_idx].action.u[env_idx][0] + self.k_cbf_range) / 2
                            p_cbf = 1.0
                            h_values = k_cbf * (all_distances - min_dist_sq_buffered)**p_cbf # Use buffered distance
                        elif self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
                            k_cbf = 0.9
                            p_cbf = 1.0
                            h_values = k_cbf * (all_distances - min_dist_sq_buffered)**p_cbf # Use buffered distance
                        else:
                            h_values = all_distances - min_dist_sq_buffered # Use buffered distance

                        current_h_values_list.append(h_values)

                        dx_agents = ego_pos[0] - other_pos[:, 0]
                        dy_agents = ego_pos[1] - other_pos[:, 1]
                        dx_vel_agents = dx_agents * ego_vel[0]
                        dy_vel_agents = dy_agents * ego_vel[1]

                        dx_obs = ego_pos[0] - obs_positions[:, 0]
                        dy_obs = ego_pos[1] - obs_positions[:, 1]
                        dx_vel_obs = dx_obs * ego_vel[0]
                        dy_vel_obs = dy_obs * ego_vel[1]

                        dx_vel_all = torch.cat([dx_vel_agents, dx_vel_obs], dim=0)
                        dy_vel_all = torch.cat([dy_vel_agents, dy_vel_obs], dim=0)

                        dx_vel_0_list.append(dx_vel_all)
                        dy_vel_1_list.append(dy_vel_all)

                ego_positions_batch = torch.cat(pos_list, dim=0).view(
                    self._world.batch_dim * self.n_agents, 2
                )
                ego_velocities_batch = torch.cat(vel_list, dim=0).view(
                    self._world.batch_dim * self.n_agents, 2
                )
                obstacle_positions_batch = torch.stack(obstacle_positions_list, dim=0)
                current_h_values_batch = torch.stack(current_h_values_list, dim=0)
                dx_vel_0_batch = torch.stack(dx_vel_0_list, dim=0)
                dy_vel_1_batch = torch.stack(dy_vel_1_list, dim=0)

                parameters_batch = [
                    ego_positions_batch,
                    ego_velocities_batch,
                    obstacle_positions_batch,
                    current_h_values_batch,
                    dx_vel_0_batch,
                    dy_vel_1_batch,
                ]

                if self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
                    a_list = []
                    b_list = []
                    x_target_list = []
                    y_target_list = []
                    for a in self.world.agents:
                        a_list.append(a.action.u[:, :2])
                        b_list.append(a.action.u[:, 2:3])
                        x_target_list.append(a.action.u[:, 3:4])
                        y_target_list.append(a.action.u[:, 4:5])
                    a_batch = torch.cat(a_list, dim=0)
                    b_batch = torch.cat(b_list, dim=0)
                    x_target_batch = torch.cat(x_target_list, dim=0)
                    y_target_batch = torch.cat(y_target_list, dim=0)
                    parameters_batch.extend(
                        [a_batch, b_batch, x_target_batch, y_target_batch]
                    )

                elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
                    x_target_list = []
                    y_target_list = []
                    for a in self.world.agents:
                        x_target_list.append(a.action.u[:, 0:1])
                        y_target_list.append(a.action.u[:, 1:2])
                    x_target_batch = torch.cat(x_target_list, dim=0)
                    y_target_batch = torch.cat(y_target_list, dim=0)
                    parameters_batch.extend([x_target_batch, y_target_batch])

                elif self.agent_control_type in [QP_CONTROL_CONSTRAINT, QP_CONTROL_CONSTRAINT_ROBOMASTER]:
                    a_list = []
                    b_list = []
                    target_direction_list = []
                    x_goal_list = []
                    y_goal_list = []
                    for a in self.world.agents:
                        a_list.append(a.action.u[:, :2])
                        b_list.append(a.action.u[:, 2:3])
                        target_direction_list.append(a.target_direction.unsqueeze(1))
                        x_goal_list.append(a.goal_pos[:, 0:1])
                        y_goal_list.append(a.goal_pos[:, 1:2])
                    a_batch = torch.cat(a_list, dim=0)
                    b_batch = torch.cat(b_list, dim=0)
                    target_direction_batch = torch.cat(target_direction_list, dim=0)
                    x_goal_batch = torch.cat(x_goal_list, dim=0)
                    y_goal_batch = torch.cat(y_goal_list, dim=0)
                    parameters_batch.extend(
                        [a_batch, b_batch, x_goal_batch, y_goal_batch]
                    )
                
                elif self.agent_control_type == QP_CONTROL_CBF:
                    target_direction_list = []
                    x_goal_list = []
                    y_goal_list = []
                    for a in self.world.agents:
                        target_direction_list.append(a.target_direction.unsqueeze(1))
                        x_goal_list.append(a.goal_pos[:, 0:1])
                        y_goal_list.append(a.goal_pos[:, 1:2])
                    target_direction_batch = torch.cat(target_direction_list, dim=0)
                    x_goal_batch = torch.cat(x_goal_list, dim=0)
                    y_goal_batch = torch.cat(y_goal_list, dim=0)
                    parameters_batch.extend(
                        [x_goal_batch, y_goal_batch]
                    )

                u_qp_batch, delta_batch = self.solve_qp_layer(
                    parameters_batch,
                    solver_args={"eps": 1e-8, "max_iters": 10000},
                )

                if u_qp_batch is None or delta_batch is None:
                    print("QP solver failed, applying fallback strategy.")
                    for a in self.world.agents:
                        a.action.u = torch.zeros(
                            (a.action.u.size(0), 2), device=self._world.device
                        )
                        vel_is_zero = torch.linalg.vector_norm(
                            a.action.u[:, :2], dim=1
                        ) < 1e-3
                        a.controller.reset(vel_is_zero)
                        a.controller.process_force()
                else:
                    split_sizes = [a.action.u.size(0) for a in self.world.agents]
                    u_qp_list = torch.split(u_qp_batch, split_sizes, dim=0)
                    delta_list = torch.split(delta_batch, split_sizes, dim=0)

                    for i, a in enumerate(self.world.agents):
                        a.action.u = u_qp_list[i]
                        a.action.delta = delta_list[i].squeeze()
                        a.action.u = TorchUtils.clamp_with_norm(a.action.u, self.u_range)
                        action_norm = torch.linalg.vector_norm(a.action.u, dim=1)
                        a.action.u[action_norm < self.min_input_norm] = 0
                        a.computed_velocity = a.action.u.clone()

                        vel_is_zero = action_norm < 1e-3
                        a.controller.reset(vel_is_zero)
                        a.controller.process_force()

    def extra_render(self, env_index: int = 0) -> "List[rendering.Geom]":
        """Assemble the scenario-specific geometries for one environment.

        Geometries are appended in this order: per-agent reward heat-map
        cells, obstacles, inter-agent connection lines, and finally, for each
        agent, its goal marker, goal outline, agent-to-goal line, body and
        velocity indicator.

        Side effects: every wall's ``_color`` and every agent's ``color`` are
        overwritten so that they reflect the current collision highlighting.

        Args:
            env_index: Index of the batched environment to draw.

        Returns:
            The list of geometries built for ``env_index``.
        """
        collision_red = (1.0, 0.0, 0.0)
        highlight_collisions = not self.disable_collision_viz
        agents = self.world.agents
        geoms: List[rendering.Geom] = []

        def timer_running(entity):
            # Entities without a collision timer are never highlighted.
            return (
                hasattr(entity, "collision_timer")
                and entity.collision_timer[env_index] > 0
            )

        # Walls are not drawn here; only their colour is updated (black by
        # default, red while their collision timer is running).
        for wall in self.walls:
            if highlight_collisions and timer_running(wall):
                wall._color = collision_red
            else:
                wall._color = (0.0, 0.0, 0.0)

        # Coarse reward heat map. The cell corners and centres do not depend
        # on the agent, so they are computed once up front.
        grid_resolution = 15
        x_edges = torch.linspace(
            -self._world.x_semidim, self._world.x_semidim, grid_resolution
        ).tolist()
        y_edges = torch.linspace(
            -self._world.y_semidim, self._world.y_semidim, grid_resolution
        ).tolist()
        cells = [
            (x1, x2, y1, y2, (x1 + x2) / 2, (y1 + y2) / 2)
            for x1, x2 in zip(x_edges, x_edges[1:])
            for y1, y2 in zip(y_edges, y_edges[1:])
        ]

        # Exponential falloff with distance from the goal; presumably mirrors
        # the formula used by agent_reward -- confirm against that method.
        decay_factor = 2.0
        max_reward = 1.0
        for agent in agents:
            goal_xy = agent.goal_pos[env_index].cpu().numpy()
            tint = agent.original_color
            for x1, x2, y1, y2, center_x, center_y in cells:
                dist = np.sqrt(
                    (center_x - goal_xy[0]) ** 2 + (center_y - goal_xy[1]) ** 2
                )
                reward = max_reward * np.exp(-decay_factor * dist)
                # Cells with a negligible reward are not drawn at all.
                if reward > 0.05:
                    # Opacity follows the reward but is capped at 0.15.
                    alpha = min(reward * 0.3, 0.15)
                    cell_geom = rendering.make_polygon(
                        [(x1, y1), (x2, y1), (x2, y2), (x1, y2)], filled=True
                    )
                    cell_geom.set_color(tint[0], tint[1], tint[2], alpha=alpha)
                    geoms.append(cell_geom)

        # Obstacles: gray discs, red while their collision timer is running.
        for obs in self.obstacles:
            if highlight_collisions and timer_running(obs):
                obstacle_color = collision_red
            else:
                obstacle_color = (0.3, 0.3, 0.3)
            disc = rendering.make_circle(self.obstacle_radius, res=20, filled=True)
            disc.add_attr(
                rendering.Transform(
                    translation=obs.state.pos[env_index].cpu().numpy()
                )
            )
            disc.set_color(*obstacle_color)
            geoms.append(disc)

        # Every unordered pair of agents, enumerated once and used twice.
        agent_count = len(agents)
        pair_indices = [
            (i, j) for i in range(agent_count) for j in range(i + 1, agent_count)
        ]

        # Agents that are within the minimum separation in this very frame.
        agent_colliding = [False] * agent_count
        for i, j in pair_indices:
            separation = torch.norm(
                agents[i].state.pos[env_index] - agents[j].state.pos[env_index]
            )
            if separation <= self.min_distance_between_entities:
                agent_colliding[i] = True
                agent_colliding[j] = True

        # Black connection lines between agents within the formation radius.
        for i, j in pair_indices:
            first, second = agents[i], agents[j]
            pair_dist = torch.linalg.vector_norm(
                first.state.pos - second.state.pos, dim=-1
            )
            if pair_dist[env_index] <= self.formation_radius:
                link = rendering.Line(
                    (first.state.pos[env_index]),
                    (second.state.pos[env_index]),
                    width=1,
                )
                link.add_attr(rendering.Transform())
                link.set_color(0.0, 0.0, 0.0, 0.8)
                geoms.append(link)

        # Goal markers and agent bodies.
        goal_radius = 0.15
        star_points = 5
        outer_radius = goal_radius
        inner_radius = goal_radius * 0.5
        angle_increment = 2 * np.pi / (2 * star_points)
        for idx, agent in enumerate(agents):
            goal_pos = agent.goal_pos[env_index].cpu().numpy()
            base_color = agent.original_color

            # Star polygon: vertices alternate between outer and inner radius.
            star_vertices = []
            angle = 0
            for k in range(2 * star_points):
                radius = inner_radius if k % 2 else outer_radius
                star_vertices.append(
                    (
                        goal_pos[0] + radius * np.cos(angle),
                        goal_pos[1] + radius * np.sin(angle),
                    )
                )
                angle += angle_increment
            star = rendering.make_polygon(star_vertices, filled=True)
            star.set_color(
                base_color[0] * 0.8, base_color[1] * 0.8, base_color[2] * 0.8
            )
            geoms.append(star)

            # Unfilled ring slightly larger than the star.
            ring = rendering.make_circle(goal_radius * 1.2, res=20, filled=False)
            ring.add_attr(rendering.Transform(translation=goal_pos))
            ring.set_color(base_color[0], base_color[1], base_color[2], alpha=0.8)
            ring.set_linewidth(2)
            geoms.append(ring)

            # Faint line tying the agent to its assigned goal.
            assignment = rendering.Line(
                (agent.state.pos[env_index]),
                (agent.goal_pos[env_index]),
                width=1,
            )
            assignment.set_color(
                base_color[0], base_color[1], base_color[2], alpha=0.3
            )
            geoms.append(assignment)

            # Red if the collision timer is running or the agent overlaps
            # another agent in this frame; otherwise its original colour.
            agent_color = base_color
            if highlight_collisions and (
                timer_running(agent) or agent_colliding[idx]
            ):
                agent_color = collision_red
            # Store the colour on the agent as well as on the geometry.
            agent.color = agent_color

            body = rendering.make_circle(self.agent_radius, res=20, filled=True)
            body.add_attr(
                rendering.Transform(
                    translation=agent.state.pos[env_index].cpu().numpy()
                )
            )
            body.set_color(*agent_color)
            geoms.append(body)

            # White heading segment, drawn only when the agent moves
            # noticeably; its length is half the speed, capped at 0.3.
            velocity = agent.state.vel[env_index]
            speed = torch.norm(velocity)
            if speed > 0.05:
                vel_dir = velocity / speed
                vel_length = min(0.3, speed * 0.5)
                origin = agent.state.pos[env_index]
                end_x = origin[0] + vel_length * vel_dir[0]
                end_y = origin[1] + vel_length * vel_dir[1]
                heading = rendering.Line(
                    (origin),
                    (torch.tensor([end_x, end_y])),
                    width=2,
                )
                heading.set_color(1.0, 1.0, 1.0)
                geoms.append(heading)

        return geoms


# Example usage:
if __name__ == "__main__":
    import os
    import sys

    # `utils` lives one directory above this file's directory, so that parent
    # directory is appended to sys.path before the import.
    scenario_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.dirname(scenario_dir))
    from utils import use_vmas_env

    # Settings shared by every demo run; further scenario parameters
    # (e.g. formation_radius, lambda_u) can be added here as extra keys.
    demo_settings = dict(
        render=True,
        save_render=False,
        num_envs=1,
        n_steps=200,
        device="cpu",
        continuous_actions=True,
        random_action=True,
        n_agents=4,
        n_obstacles=2,
        agent_control_type=QP_CONTROL_CONSTRAINT,
    )

    # Five consecutive rendered runs, each with a freshly built scenario.
    for _ in range(5):
        use_vmas_env(scenario=SensorCoverageScenario(), **demo_settings)
