import typing
from typing import Dict, List

import torch
from torch import Tensor
from vmas.simulator.core import Agent, Sphere, World, Landmark, Line
from vmas.simulator.dynamics.holonomic import Holonomic
from vmas.simulator.dynamics.common import Dynamics
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import Color, ScenarioUtils, TorchUtils
from vmas.simulator.controllers.velocity_controller import VelocityController
from vmas.simulator import rendering

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

# Import RVO implementation
import numpy as np
# NOTE: the RVO module is imported lazily inside compute_rvo_velocities.

if typing.TYPE_CHECKING:
    from vmas.simulator.rendering import Geom

import numpy as np

# Custom dynamics class for agents using QP
class HolonomicQP(Dynamics):
    """Holonomic dynamics that applies the first two action columns as force."""

    @property
    def needed_action_size(self) -> int:
        """Number of leading action components read by ``process_action``."""
        planar_components = 2
        return planar_components

    def process_action(self):
        """Copy columns 0 and 1 of the agent's action into the agent's force."""
        owner = self.agent
        owner.state.force = owner.action.u[:, :2]


# Agent control types.
# The scenario reads one of these strings from the "agent_control_type" kwarg
# and uses it to pick the action layout (make_world) and the optimisation
# problem (init_qp_layer).

# 2D action; no optimisation layer is built for this mode.
VELOCITY_CONTROL = "VELOCITY_CONTROL"
# 5D action: 'a' (2D), 'b' (1D), x_target (1D), y_target (1D).
QP_CONTROL_OBJECTIVE_AND_CONSTRAINT = "QP_CONTROL_OBJECTIVE_AND_CONSTRAINT"
# 2D action: x_target, y_target.
QP_CONTROL_OBJECTIVE = "QP_CONTROL_OBJECTIVE"
# 3D action: 'a' (2D), 'b' (1D).
QP_CONTROL_CONSTRAINT = "QP_CONTROL_CONSTRAINT"
# 1D action: k_cbf.
QP_CONTROL_CBF = "QP_CONTROL_CBF"
# Same 3D action as QP_CONTROL_CONSTRAINT, with robomaster-specific geometry
# defaults and an additional slack variable in the optimisation problem.
QP_CONTROL_CONSTRAINT_ROBOMASTER = "QP_CONTROL_CONSTRAINT_ROBOMASTER"
# New agent control type: the 2D action is a dummy; the objective targets are
# presumably derived from RVO velocities (see compute_rvo_velocities) -- confirm.
QP_CONTROL_RVO = "QP_CONTROL_RVO"

# Scenario class definition
class NarrowCorridorScenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """Build the corridor world, its agents and its four boundary walls.

        Scenario options are read from ``kwargs`` and stored on ``self``; the
        defaults are the literals visible below. After the world is populated
        the walls are positioned and the optimisation layer is created.

        Args:
            batch_dim: Number of vectorised environments; used as the leading
                dimension of every per-agent tensor created here.
            device: Torch device on which those tensors are allocated.
            **kwargs: Scenario options, e.g. ``agent_control_type``,
                ``n_agents``, ``agent_radius``, ``LAMBDA_U``, ``LAMBDA_G``.

        Returns:
            The populated ``World``.

        Raises:
            ValueError: If ``agent_control_type`` is not one of the known
                control-type constants.
        """

        # Set agent control type
        self.agent_control_type = kwargs.get(
            "agent_control_type", QP_CONTROL_RVO
        )

        self.n_agents = kwargs.get("n_agents", 2)
        self.v_range = kwargs.get("v_range", 1.0)
        self.a_range = kwargs.get("a_range", 1.0)
        self.linear_friction = kwargs.get("linear_friction", 0.0)
        self.collisions = kwargs.get("collisions", True)
        self.constraint_on = kwargs.get("constraint_on", True)

        self.max_velocity = kwargs.get(
            "max_velocity", 1.0
        )  # Maximum velocity for QP controller
        self.max_steps = kwargs.get("max_steps", 100)
        self.agent_radius = kwargs.get("agent_radius", 0.25)

        self.lambda_u = kwargs.get(
            "LAMBDA_U", 10.0
        )  # Penalty for agent-agent collision and out of bounds
        self.lambda_g = kwargs.get("LAMBDA_G", 1.0)  # Reward for reaching target region
        self.delta_scaling = kwargs.get(
            "delta_scaling", 99.0
        )  # Scaling factor for the slack variable

        # Read visualization options
        self.disable_collision_viz = kwargs.get("disable_collision_viz", False)

        if self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            self.min_input_norm = kwargs.get("min_input_norm", 0.12)
            self.robomaster_deployed = kwargs.get("robomaster_deployed", True)
            # NOTE: overrides any "delta_scaling" kwarg read above.
            self.delta_scaling = 999.0
            # NOTE: replaces the "agent_radius" kwarg read above.
            self.agent_radius = kwargs.get("robomaster_agent_radius", 0.19)#0.18
            self.collision_gap =  kwargs.get("robomaster_collision_gap",0.0015) #0.001
            self.world_x_semidim = kwargs.get("robomaster_x_semidim_scaling_factor",1.28)*(self.agent_radius + self.collision_gap) #1.23*(self.agent_radius + self.collision_gap) 

            # Selects between the conic constraint and a vacuous one in init_qp_layer.
            # Only defined in this mode.
            self.robomaster_model_on = kwargs.get("robomaster_model_on", True) # whether to use the QP controller

        else:
            self.min_input_norm = kwargs.get("min_input_norm", 0.08)
            self.collision_gap = 0.001
            self.world_x_semidim = 1.1*(self.agent_radius + self.collision_gap) 

        self.min_distance_between_entities = self.agent_radius * 2 + self.collision_gap
        self.world_y_semidim = 3.0

        # Heights of the bands next to the top and bottom walls that count as
        # target regions in the reward and observation.
        self.bottom_region_height = 1.5
        self.top_region_height = 1.5


        self.f_range = self.a_range + self.linear_friction
        self.u_range = (
            self.v_range
            if self.agent_control_type == VELOCITY_CONTROL
            else self.f_range
        )

        # Passed as-is to VelocityController; presumably controller gains -- TODO confirm.
        controller_params = [2, 6, 0.002]

        # NOTE: pop() (unlike get() above) removes these keys from kwargs.
        self.resolution_factor = kwargs.pop("resolution_factor", 1200)
        self.viewer_zoom = 2
        self.viewer_size = kwargs.pop(
            "viewer_size",
            (
                int(self.resolution_factor),
                int(self.resolution_factor),
            ),
        )


        # Make world
        # The world semi-dimensions are enlarged by one agent radius compared
        # to self.world_x_semidim / self.world_y_semidim.
        world = World(
            batch_dim,
            device,
            dt=0.1,
            substeps=1,
            linear_friction=self.linear_friction,
            drag=0.0,
            x_semidim=self.world_x_semidim + self.agent_radius,
            y_semidim=self.world_y_semidim + self.agent_radius,
        )

        known_colors = [
            (0.22, 0.40, 0.72),
            (0.22, 0.72, 0.40),
        ]

        # Determine action size and u_range based on control type
        if self.agent_control_type == VELOCITY_CONTROL:
            action_size = 2  # Velocity commands
            u_range = [self.u_range, self.u_range]
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            action_size = 2 + 1 + 1 + 1  # 'a' (2D), 'b' (1D), x_target (1D), y_target (1D)
            u_range = [
                self.u_range,  # a[0]
                self.u_range,  # a[1]
                self.u_range,  # b
                world.x_semidim,  # x_target
                world.y_semidim,  # y_target
            ]
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
            action_size = 2  # x_target, y_target
            u_range = [world.x_semidim, world.y_semidim]
        elif self.agent_control_type == QP_CONTROL_CONSTRAINT or self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            action_size = 2 + 1  # 'a' (2D), 'b' (1D)
            u_range = [self.u_range, self.u_range, self.u_range]
        elif self.agent_control_type == QP_CONTROL_CBF:
            self.k_cbf_range = kwargs.get("k_cbf_range", 2)
            action_size = 1 # just k_cbf -- p_cbf is problematic due to negative exponents
            # Scalar here, whereas every other mode passes a list.
            u_range = self.k_cbf_range
        elif self.agent_control_type == QP_CONTROL_RVO:
            action_size = 2  # Dummy action size, won't actually be used
            u_range = [1.0, 1.0]  # Dummy range, values won't be used
        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        # Add agents to the world
        for i in range(self.n_agents):

            if self.agent_control_type == VELOCITY_CONTROL:
                dynamics = Holonomic()
            else:
                dynamics = HolonomicQP()

            agent = Agent(
                name=f"agent_{i}",
                collide=self.collisions,
                alpha=1.0,
                shape=Sphere(radius=self.agent_radius),
                render_action=True,
                v_range=self.v_range,
                f_range=self.f_range,
                u_range=u_range,
                dynamics=dynamics,
                action_size=action_size,
            )

            # Assign target directions and colors
            # Even-indexed agents get +1, odd-indexed agents get -1; the sign
            # multiplies the y coordinate in the reward and observation.
            if i % 2 == 0:
                agent.target_direction = torch.ones(batch_dim, device=device)
                agent.original_color = known_colors[1]
            else:
                agent.target_direction = -torch.ones(batch_dim, device=device)
                agent.original_color = known_colors[0]

            agent.color = agent.original_color  # Initialize current color

            # A velocity controller is attached in every mode; it is reset in
            # reset_world_at and used by process_action in VELOCITY_CONTROL mode.
            agent.controller = VelocityController(
                agent, world, controller_params, "standard"
            )

            agent.agent_collision_rew = torch.zeros(batch_dim, device=device)
            agent.individual_rew = agent.agent_collision_rew.clone()

            agent.computed_velocity = torch.zeros(batch_dim, 2, device=device)  # Initialize computed velocity

            # Initialize delta to zero
            # NOTE: delta is stored on agent.action, delta_rew on the agent itself.
            if self.agent_control_type != VELOCITY_CONTROL:
                agent.action.delta = torch.zeros(batch_dim, device=device)
                agent.delta_rew = torch.zeros(batch_dim, device=device)  # Initialize delta_rew

            # Initialize collision_timer
            agent.collision_timer = torch.zeros(batch_dim, device=device)

            world.add_agent(agent)

        # Initialize walls
        # Walls 0 and 1 span the world width, walls 2 and 3 the world height;
        # spawn_walls places them as top, bottom, left, right in that order.
        self.walls = []
        for i in range(4):
            if i < 2:
                wall_length = 2 * (world.x_semidim)
            else:
                wall_length = 2 * (world.y_semidim)
            wall = Landmark(
                name=f"wall_{i}",
                collide=True,
                shape=Line(length=wall_length),
                color=Color.BLACK,
            )
            wall.original_color = (0.0, 0.0, 0.0)  # Store original color
            # Initialize collision_timer
            wall.collision_timer = torch.zeros(batch_dim, device=device)
            world.add_landmark(wall)
            self.walls.append(wall)

        # spawn_walls and init_qp_layer both read self._world, so it must be
        # assigned before they are called.
        self._world = world

        # Spawn walls to position them correctly
        self.spawn_walls()

        self.qp_layer = self.init_qp_layer()
        # Cleared by solve_qp_layer on failure and checked by done().
        self.qp_solver_initialized = True

        return world

    def reset_world_at(self, env_index: int = None):
        """Rebuild the QP layer, respawn the agents and re-place the walls.

        Args:
            env_index: Passed through unchanged to the random-spawn helper,
                the controller resets and ``spawn_walls``.
        """
        # A freshly built layer goes together with clearing the failure flag
        # that ``done`` inspects.
        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        world = self._world
        half_w, half_h = world.x_semidim, world.y_semidim
        safety_margin = 1.05*(self.agent_radius + self.collision_gap)
        spacing = self.min_distance_between_entities

        # The harder initial conditions below are currently switched off.
        hard_conditions = False

        if hard_conditions:
            # Agents whose target direction is negative in every environment
            # start in the top band; all remaining agents start in the bottom
            # band.
            top_starters, bottom_starters = [], []
            for ag in world.agents:
                bucket = (
                    top_starters
                    if (ag.target_direction < 0).all()
                    else bottom_starters
                )
                bucket.append(ag)

            x_span = (
                -half_w + 0.5*safety_margin,
                half_w - 0.5*safety_margin,
            )
            top_band = (half_h - self.top_region_height, half_h - safety_margin)
            bottom_band = (
                -half_h + safety_margin,
                -half_h + self.bottom_region_height,
            )

            # Top group first, then bottom group; empty groups are skipped.
            for group, y_span in (
                (top_starters, top_band),
                (bottom_starters, bottom_band),
            ):
                if group:
                    ScenarioUtils.spawn_entities_randomly(
                        group,
                        world,
                        env_index,
                        spacing,
                        x_span,
                        y_span,
                    )
        else:
            # Default: every agent may spawn anywhere inside the walls, kept
            # one safety margin away from them.
            ScenarioUtils.spawn_entities_randomly(
                world.agents,
                world,
                env_index,
                spacing,
                (-half_w + safety_margin, half_w - safety_margin),
                (-half_h + safety_margin, half_h - safety_margin),
            )

        # Reset the per-agent velocity controllers.
        for agent in self.world.agents:
            agent.controller.reset(env_index)

        # Put the walls back in place.
        self.spawn_walls(env_index)


    def spawn_walls(self, env_index=None):
        """Position the walls on the edges of the world rectangle.

        ``self.walls`` is read in the order top, bottom, left, right. The two
        horizontal walls get rotation 0, the two vertical ones pi/2.

        Args:
            env_index: Forwarded as ``batch_index`` to ``set_pos``/``set_rot``.
        """
        device = self._world.device
        half_w = self._world.x_semidim
        half_h = self._world.y_semidim
        quarter_turn = torch.pi / 2

        # (centre, rotation) for wall indices 0..3: top, bottom, left, right.
        poses = (
            ([0.0, half_h], 0.0),
            ([0.0, -half_h], 0.0),
            ([-half_w, 0.0], quarter_turn),
            ([half_w, 0.0], quarter_turn),
        )
        last = len(poses) - 1

        for idx, wall in enumerate(self.walls):
            # Any wall beyond the fourth falls back to the last pose.
            centre, angle = poses[min(idx, last)]
            wall.set_pos(
                torch.tensor(centre, device=device), batch_index=env_index
            )
            wall.set_rot(
                torch.tensor([angle], device=device), batch_index=env_index
            )

    def reward(self, agent: Agent):
        """Return the reward tensor for ``agent``.

        All reward components are refreshed for every agent only when this is
        called for the first agent of the world; calls for the other agents
        just sum the values stored on the agent by that refresh.

        Args:
            agent: Agent whose reward is requested.

        Returns:
            ``individual_rew + agent_collision_rew``, plus ``delta_rew`` in
            ``QP_CONTROL_CBF`` mode.
        """
        if agent == self.world.agents[0]:
            # Refresh the individual terms and clear the collision penalties.
            for a in self.world.agents:
                self.agent_reward(a)
                a.agent_collision_rew[:] = 0

            for a in self.world.agents:
                # Walls are checked first, then every other agent.
                counterparts = list(self.walls) + [
                    b for b in self.world.agents if not a == b
                ]
                for other in counterparts:
                    if not self._world.collides(a, other):
                        continue
                    touching = (
                        self._world.get_distance(a, other) <= self.collision_gap
                    )
                    a.agent_collision_rew[touching] -= self.lambda_u

                    # Start a 5-step timer on both participants.
                    a.collision_timer[touching] = 5
                    other.collision_timer[touching] = 5

            # Count every timer down by one step, never below zero. This
            # happens in the same call that may just have set it to 5.
            for entity in (*self.world.agents, *self.walls):
                if hasattr(entity, 'collision_timer'):
                    entity.collision_timer = torch.maximum(
                        entity.collision_timer - 1,
                        torch.zeros_like(entity.collision_timer),
                    )

        total_rew = agent.individual_rew + agent.agent_collision_rew
        if self.agent_control_type == QP_CONTROL_CBF:
            # Only the CBF mode adds the slack penalty to the total.
            total_rew = total_rew + agent.delta_rew

        return total_rew


    def agent_reward(self, agent: Agent):
        """Recompute ``agent.individual_rew`` (and ``delta_rew`` in CBF mode).

        The individual reward is the sum of:
          * ``-lambda_u`` where the agent is outside the world semi-dimensions;
          * ``lambda_g`` inside its target band, otherwise a small term
            proportional to its y-velocity along its target direction;
          * ``-0.1 * lambda_g`` inside the opposite band.

        Both bands are measured with ``top_region_height``.

        Args:
            agent: Agent whose reward attributes are overwritten.

        Returns:
            The freshly computed ``agent.individual_rew``.
        """
        device = self._world.device
        n_envs = agent.state.pos.size(0)
        heading = agent.target_direction
        y_pos = agent.state.pos[:, 1]

        # Start from zero; every term below is accumulated in place.
        rew = torch.zeros(n_envs, device=device)
        agent.individual_rew = rew

        # Out-of-bounds penalty.
        beyond_x = torch.abs(agent.state.pos[:, 0]) > self._world.x_semidim
        beyond_y = torch.abs(y_pos) > self._world.y_semidim
        rew += torch.where(beyond_x | beyond_y, -self.lambda_u, 0)

        # Target-band reward, or progress shaping outside the band.
        band_start = self._world.y_semidim - self.top_region_height
        reached = heading * y_pos > band_start
        progress = 0.01 * self.lambda_g * agent.state.vel[:, 1] * heading
        rew += torch.where(reached, self.lambda_g, progress)

        # Penalty for sitting in the band on the opposite side.
        penalize_wrong_region = True
        if penalize_wrong_region:
            wrong_side = -heading * y_pos > band_start
            rew += torch.where(wrong_side, -0.1*self.lambda_g, 0.0)

        if self.agent_control_type == QP_CONTROL_CBF:
            if hasattr(agent.action, 'delta'):
                # A larger slack means more constraint relaxation, so penalise it.
                agent.delta_rew = -self.lambda_u * agent.action.delta
            else:
                agent.delta_rew = torch.zeros(n_envs, device=device)

        return agent.individual_rew

    def observation(self, agent: Agent):
        """Return the observation dictionary for ``agent``.

        Keys: ``agent_pos``, ``agent_vel``, ``target_direction`` and
        ``distance_to_boundary``. The last one is the signed y-distance to
        the edge of the agent's target band (top band for a positive target
        direction, bottom band otherwise). All values are clones.
        """
        y_pos = agent.state.pos[:, 1]
        upper_edge = self._world.y_semidim - self.top_region_height
        lower_edge = -self._world.y_semidim + self.bottom_region_height

        heading_up = agent.target_direction > 0
        remaining = torch.where(
            heading_up,
            upper_edge - y_pos,
            y_pos - lower_edge,
        )

        obs = dict(
            agent_pos=agent.state.pos.clone(),
            agent_vel=agent.state.vel.clone(),
            target_direction=agent.target_direction.clone(),
            distance_to_boundary=remaining.clone(),
        )
        return obs


    def done(self):
        """Return the per-environment termination flags.

        Every environment is flagged done when the QP solver has been marked
        as failed (``qp_solver_initialized`` is falsy); otherwise none is.
        """
        solver_failed = not self.qp_solver_initialized
        return torch.full(
            (self._world.batch_dim,),
            solver_failed,
            dtype=torch.bool,
            device=self._world.device,
        )
    
    def info(self, agent: Agent) -> Dict[str, Tensor]:
        """Collect diagnostic tensors for ``agent``.

        Args:
            agent: Agent whose reward components are reported.

        Returns:
            A dict with ``collision_reward``, ``individual_reward`` and
            ``total_reward_without_delta_rew``; ``delta_reward`` whenever the
            agent carries a ``delta_rew`` attribute; and, in
            ``QP_CONTROL_CBF`` mode, ``original_action_u``.
        """
        info = {
            "collision_reward": agent.agent_collision_rew,
        }

        if self.agent_control_type == QP_CONTROL_CBF:
            if hasattr(agent, 'original_action_u'):
                info['original_action_u'] = agent.original_action_u.squeeze(-1)
            else:
                # No action has been stored yet; report zeros instead.
                info['original_action_u'] = torch.zeros_like(agent.agent_collision_rew)

        # delta_rew is stored on the agent itself (agent.action only carries
        # `delta`), so the agent is the object to probe. Probing agent.action
        # meant this entry was never reported.
        if hasattr(agent, 'delta_rew'):
            info["delta_reward"] = agent.delta_rew

        info["individual_reward"] = agent.individual_rew
        info["total_reward_without_delta_rew"] = agent.individual_rew + agent.agent_collision_rew

        return info

    def init_qp_layer(self):
        """Build the differentiable optimisation layer for a single agent.

        The problem has the control input ``u`` (2D) and a slack ``delta`` as
        variables (plus a second slack ``s`` in ROBOMASTER mode). Because the
        constraints and some objectives use ``cp.norm``, the problem is a
        cone program rather than a strict QP.

        The layer's parameters, in call order, are always
        ``ego_position, ego_velocity, obstacle_positions, h_values, dx_vel_0,
        dy_vel_1`` followed by the mode-specific ones appended below:

          * OBJECTIVE_AND_CONSTRAINT: ``a, b, x_target, y_target``
          * OBJECTIVE and RVO: ``x_target, y_target``
          * CONSTRAINT and CONSTRAINT_ROBOMASTER: ``a, b, target_direction``
          * CBF: ``target_direction``

        Returns:
            A ``CvxpyLayer`` whose outputs are ``(u, delta)``, or
            ``(u, delta, s)`` in ROBOMASTER mode; ``None`` in
            ``VELOCITY_CONTROL`` mode.

        Raises:
            ValueError: If ``self.agent_control_type`` is not a known mode.
        """
        if self.agent_control_type == VELOCITY_CONTROL:
            return None

        # Define control input and slack variable
        u = cp.Variable(2)  # Control inputs for a single agent
        delta = cp.Variable(1)  # Slack variable for a single agent

        # Parameters for the QP
        ego_position = cp.Parameter(2)  # Position of the ego-agent
        ego_velocity = cp.Parameter(2)  # Velocity of the ego-agent

        # Collision avoidance parameters
        # One row / entry per other agent, hence n_agents - 1.
        obstacle_positions = cp.Parameter(
            (self.n_agents - 1, 2)
        )  # Positions of other agents
        h_values = cp.Parameter(
            self.n_agents - 1
        )  # Current distances to other agents
        dx_vel_0 = cp.Parameter(self.n_agents - 1)  # dx * ego_velocity[0]
        dy_vel_1 = cp.Parameter(self.n_agents - 1)  # dy * ego_velocity[1]

        # Define constraints and objective based on agent control type
        constraints = []
        # The order of this list is the positional order in which values
        # must be passed when the layer is called.
        parameters = [
            ego_position,
            ego_velocity,
            obstacle_positions,
            h_values,
            dx_vel_0,
            dy_vel_1,
        ]

        # delta >= 0
        constraints.append(delta >= 0)

        # Collision avoidance constraints
        # One linear-in-u constraint per other agent: lgh + h >= 0, relaxed
        # by the slack s in ROBOMASTER mode only.
        if self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            # Add slack variable s for Robomaster constraints
            s = cp.Variable(1)  # Slack variable for a single agent
            for j in range(self.n_agents - 1):
                h = h_values[j]
                lgh = 2 * (
                    dx_vel_0[j] + dy_vel_1[j] + 
                    u[0] * (ego_position[0] - obstacle_positions[j, 0]) + 
                    u[1] * (ego_position[1] - obstacle_positions[j, 1])
                )
                constraints.append(lgh + h + s >= 0)

        else:
            for j in range(self.n_agents - 1):
                h = h_values[j]
                lgh = 2 * (
                    dx_vel_0[j] + dy_vel_1[j] + 
                    u[0] * (ego_position[0] - obstacle_positions[j, 0]) + 
                    u[1] * (ego_position[1] - obstacle_positions[j, 1])
                )
                constraints.append(lgh + h  >= 0)


        # Stay within bounds
        # The predicted position (position + velocity + u) must stay one
        # safety margin inside the world, and the norm of u is capped at
        # max_velocity. ROBOMASTER mode relaxes every bound by s.
        if self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            safety_margin = 1.18*(self.agent_radius + self.collision_gap)
            constraints.extend(
                [
                    ego_position[0] + (ego_velocity[0] + u[0]) <= self._world.x_semidim - safety_margin + s,
                    ego_position[0] + (ego_velocity[0] + u[0]) >= -self._world.x_semidim + safety_margin - s,
                    ego_position[1] + (ego_velocity[1] + u[1]) <= self._world.y_semidim - safety_margin + s,
                    ego_position[1] + (ego_velocity[1] + u[1]) >= -self._world.y_semidim + safety_margin - s,
                    cp.norm(u) <= self.max_velocity + s,
                    s >= 0
                ]
            )

        else:
            safety_margin = 1.05*(self.agent_radius + self.collision_gap)
            constraints.extend(
                [
                    ego_position[0] + (ego_velocity[0] + u[0]) <= self._world.x_semidim - safety_margin,
                    ego_position[0] + (ego_velocity[0] + u[0]) >= -self._world.x_semidim + safety_margin,
                    ego_position[1] + (ego_velocity[1] + u[1]) <= self._world.y_semidim - safety_margin,
                    ego_position[1] + (ego_velocity[1] + u[1]) >= -self._world.y_semidim + safety_margin,
                    cp.norm(u) <= self.max_velocity,
                    # Duplicate of the delta >= 0 constraint added above.
                    delta >= 0,
                ]
            )

        if self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            # Agents control both 'a', 'b' and x_target, y_target
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([a, b, x_target, y_target])

            # Circle constraint
            constraints.append(cp.norm(u - a) <= b + delta)

            # Objective: minimize distance to (x_target, y_target)
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + self.delta_scaling*delta
            )

        elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
            # Define target positions as symbolic parameters
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([x_target, y_target])

            # Objective: minimize distance to (x_target, y_target)
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + self.delta_scaling*delta
            )
        
        elif self.agent_control_type == QP_CONTROL_RVO:
            # Only symbolic target parameters are declared here. An earlier
            # draft computed the target values in this method (current
            # position + RVO velocity per agent); that belongs at solve time.
            # TODO: confirm that process_action supplies x_target / y_target
            #       from agent.rvo_velocity when calling the layer.

            # Define target positions as parameters for RVO control objective
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([x_target, y_target])
            # Define objective for RVO (e.g., minimize distance to RVO-derived target)
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + self.delta_scaling*delta # Keep delta term for consistency/potential use
            )

        elif self.agent_control_type == QP_CONTROL_CONSTRAINT:
            # Agent controls the constraints ('a', 'b')
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            target_direction = cp.Parameter(1)
            parameters.extend([a, b, target_direction])

            # Conic constraint
            if self.constraint_on:
                constraints.append(cp.norm(u - a) <= b + delta)
            else:
                # Vacuous constraint (0 <= delta) that still references a and
                # b, so both remain parameters of the problem.
                constraints.append(0*a.T @ u <= 0*b + delta)

            # Objective: maximize movement in target direction
            objective = cp.Minimize(-target_direction * (u[1]) + self.delta_scaling*delta)
        elif self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            # Agent controls the constraints ('a', 'b')
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            target_direction = cp.Parameter(1)
            parameters.extend([a, b, target_direction])

            # Conic constraint
            if self.robomaster_model_on:
                 constraints.append(cp.norm(u - a) <= b + delta + s)
            else:
                # Vacuous constraint (0 <= delta + s) that still references a
                # and b, so both remain parameters of the problem.
                constraints.append(0*a.T @ u <= 0*b + delta + s)

            # Objective: maximize movement in target direction
            # The slack s is weighted 500 times more heavily than delta.
            objective = cp.Minimize(-target_direction * (u[1]) + self.delta_scaling*delta + 500*self.delta_scaling*s)

        elif self.agent_control_type == QP_CONTROL_CBF:
            # Objective: maximize movement in target direction
            # No new parameters - h function controlled in process_action()
            target_direction = cp.Parameter(1)
            parameters.extend([target_direction])
            objective = cp.Minimize(-target_direction * (u[1]) + self.delta_scaling*delta)

        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        # Define the QP problem
        prob = cp.Problem(objective, constraints)

        # Create the CvxpyLayer
        # ROBOMASTER mode additionally returns the slack s as a third output.
        if self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            qp_layer = CvxpyLayer(
                prob,
                parameters=parameters,
                variables=[u, delta, s],
            )
        else:    
            qp_layer = CvxpyLayer(
                prob,
                parameters=parameters,
                variables=[u, delta],
            )

        return qp_layer

    def solve_qp_layer(
        self,
        parameters_batch,
        solver_args,
    ):
        """Run the optimisation layer on a batch of parameters, without gradients.

        Args:
            parameters_batch: Sequence of tensors, unpacked positionally into
                the layer in the order fixed by ``init_qp_layer``.
            solver_args: Passed to the layer as ``solver_args``.

        Returns:
            ``(u, delta)``, or ``(u, delta, s)`` in
            ``QP_CONTROL_CONSTRAINT_ROBOMASTER`` mode. If the solve raises,
            ``self.qp_solver_initialized`` is set to ``False`` and a tuple of
            ``None`` of that same length is returned, so callers can unpack
            the result the same way on success and on failure.
        """
        # The ROBOMASTER layer exposes one extra output (the slack s).
        with_extra_slack = (
            self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER
        )

        try:
            # Solve the QP for all agents in the batch
            with torch.no_grad():
                if with_extra_slack:
                    u_qp_batch, delta_batch, s_batch = self.qp_layer(
                        *parameters_batch,
                        solver_args=solver_args,
                    )
                    return u_qp_batch, delta_batch, s_batch

                u_qp_batch, delta_batch = self.qp_layer(
                    *parameters_batch,
                    solver_args=solver_args,
                )
                return u_qp_batch, delta_batch

        except Exception as e:
            # Deliberately broad: any failure while solving is turned into a
            # "done" signal (see done()) rather than crashing the rollout.
            print(f"SolverError encountered: {e}. Marking the environment as done.")
            # Set done flag for all environments
            self.qp_solver_initialized = False
            # Match the arity of the success path for the current mode.
            if with_extra_slack:
                return None, None, None
            return None, None

    def process_action(self, agent: Agent):
        """Turn the raw agent actions into controller forces.

        In ``VELOCITY_CONTROL`` mode the given agent's action is clamped,
        small inputs are zeroed and the agent's controller is run.

        For every other control type the work is done once, when this method
        is called for the first agent in ``self.world.agents``: parameters of
        all agents and environments are stacked agent-major (all environments
        of agent 0, then agent 1, ...), the QP layer is solved in one batched
        call and the resulting inputs are written back to every agent. Calls
        for any other agent return immediately.

        Args:
            agent: The agent whose action is being processed.
        """
        if self.agent_control_type == VELOCITY_CONTROL:
            # Zero small input
            agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, self.u_range)
            action_norm = torch.linalg.vector_norm(agent.action.u, dim=1)
            agent.action.u[action_norm < self.min_input_norm] = 0

            # Reset controller
            vel_is_zero = torch.linalg.vector_norm(agent.action.u, dim=1) < 1e-3
            agent.controller.reset(vel_is_zero)
            agent.controller.process_force()
            return

        # QP modes: everything is solved jointly while handling the first agent
        is_first = agent == self.world.agents[0]
        if not is_first:
            return

        # Collect parameters from all agents
        pos_list = []
        vel_list = []
        obstacle_positions_list = []
        current_h_values_list = []
        dx_vel_0_list = []
        dy_vel_1_list = []

        for a in self.world.agents:
            pos_list.append(a.state.pos)
            vel_list.append(a.state.vel)

            # Keep the raw action: a.action.u is overwritten with the QP
            # solution below, while extra_render still reads the original
            a.original_action_u = a.action.u.clone()

        # Handle RVO computation if needed
        if self.agent_control_type == QP_CONTROL_RVO:
            # Compute RVO velocities for all agents and environments
            rvo_velocities = self.compute_rvo_velocities(pos_list, vel_list)
            # Store computed RVO velocities for visualization
            for i, a in enumerate(self.world.agents):
                a.rvo_velocity = rvo_velocities[i]

        # Prepare obstacle positions and collision parameters.
        # The agent-outer / environment-inner order matches the agent-major
        # layout produced by torch.cat(pos_list, dim=0) below.
        for agent_idx in range(self.n_agents):
            for env_idx in range(self._world.batch_dim):
                ego_pos = pos_list[agent_idx][env_idx]
                ego_vel = vel_list[agent_idx][env_idx]

                # Positions of every other agent in this environment
                other_pos = torch.stack(
                    [
                        pos_list[i][env_idx]
                        for i in range(self.n_agents)
                        if i != agent_idx
                    ],
                    dim=0,
                )

                obstacle_positions_list.append(other_pos)

                # Squared distances from the ego agent to every other agent
                distances = torch.norm(other_pos - ego_pos, dim=1) ** 2
                if self.agent_control_type == QP_CONTROL_CBF:
                    # The gain comes from the first action component of agent
                    # agent_idx in environment env_idx, shifted and halved
                    k_cbf = (self.world.agents[agent_idx].action.u[env_idx][0] + self.k_cbf_range) / 2
                    p_cbf = 1.0

                    current_h_values_list.append(k_cbf * (distances - self.min_distance_between_entities ** 2)**p_cbf)
                elif self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
                    k_cbf = 2.0 # was 1.05
                    p_cbf = 1.0
                    current_h_values_list.append(k_cbf * (distances - (self.min_distance_between_entities + 0.01) ** 2)**p_cbf)
                else:
                    current_h_values_list.append(distances - self.min_distance_between_entities ** 2)

                dx = ego_pos[0] - other_pos[:, 0]
                dy = ego_pos[1] - other_pos[:, 1]
                dx_vel_0_list.append(dx * ego_vel[0])
                dy_vel_1_list.append(dy * ego_vel[1])

        # Stack parameters into batches
        ego_positions_batch = torch.cat(pos_list, dim=0).view(
            self._world.batch_dim * self.n_agents, 2
        )
        ego_velocities_batch = torch.cat(vel_list, dim=0).view(
            self._world.batch_dim * self.n_agents, 2
        )
        obstacle_positions_batch = torch.stack(obstacle_positions_list, dim=0)
        current_h_values_batch = torch.stack(current_h_values_list, dim=0)
        dx_vel_0_batch = torch.stack(dx_vel_0_list, dim=0)
        dy_vel_1_batch = torch.stack(dy_vel_1_list, dim=0)

        # Parameters shared by every QP control type.
        # NOTE(review): the order presumably has to match the parameter list
        # the QP layer was built with - confirm against the layer construction.
        parameters_batch = [
            ego_positions_batch,
            ego_velocities_batch,
            obstacle_positions_batch,
            current_h_values_batch,
            dx_vel_0_batch,
            dy_vel_1_batch,
        ]

        # Control-type specific parameters, concatenated agent-major as well
        if self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            # Get 'a', 'b', x_target, y_target from agent actions
            a_batch = torch.cat(
                [a.action.u[:, :2] for a in self.world.agents], dim=0
            )
            b_batch = torch.cat(
                [a.action.u[:, 2:3] for a in self.world.agents], dim=0
            )
            x_target_batch = torch.cat(
                [a.action.u[:, 3:4] for a in self.world.agents], dim=0
            )
            y_target_batch = torch.cat(
                [a.action.u[:, 4:5] for a in self.world.agents], dim=0
            )
            parameters_batch.extend(
                [a_batch, b_batch, x_target_batch, y_target_batch]
            )

        elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
            # Target positions are the first two action components
            x_target_batch = torch.cat(
                [a.action.u[:, 0:1] for a in self.world.agents], dim=0
            )
            y_target_batch = torch.cat(
                [a.action.u[:, 1:2] for a in self.world.agents], dim=0
            )
            parameters_batch.extend([x_target_batch, y_target_batch])

        elif self.agent_control_type == QP_CONTROL_RVO:
            # Target position is the current position plus the RVO velocity
            x_target_list = []
            y_target_list = []
            for a in self.world.agents:
                target_pos = a.state.pos + a.rvo_velocity
                x_target_list.append(target_pos[:, 0:1])
                y_target_list.append(target_pos[:, 1:2])

            x_target_batch = torch.cat(x_target_list, dim=0)
            y_target_batch = torch.cat(y_target_list, dim=0)
            parameters_batch.extend([x_target_batch, y_target_batch])

        elif self.agent_control_type == QP_CONTROL_CONSTRAINT or self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            # Get 'a' and 'b' from agent actions, plus the target direction
            a_batch = torch.cat(
                [a.action.u[:, :2] for a in self.world.agents], dim=0
            )
            b_batch = torch.cat(
                [a.action.u[:, 2:3] for a in self.world.agents], dim=0
            )
            target_direction_batch = torch.cat(
                [a.target_direction.unsqueeze(1) for a in self.world.agents],
                dim=0,
            )
            parameters_batch.extend(
                [a_batch, b_batch, target_direction_batch]
            )

        elif self.agent_control_type == QP_CONTROL_CBF:
            target_direction_batch = torch.cat(
                [a.target_direction.unsqueeze(1) for a in self.world.agents],
                dim=0,
            )
            parameters_batch.extend(
                [target_direction_batch]
            )

        # Solve the QP. Only the first two outputs are used here; indexing
        # (instead of fixed-size unpacking) keeps the failure path working no
        # matter how many values the solver wrapper returns.
        solver_result = self.solve_qp_layer(
            parameters_batch,
            solver_args={"eps": 1e-8},
        )
        u_qp_batch, delta_batch = solver_result[0], solver_result[1]

        # Check if the solver returned None (indicating a failure)
        if u_qp_batch is None or delta_batch is None:
            print("QP solver failed, applying fallback strategy.")
            # Apply a fallback strategy: set the agent's action to zero
            for a in self.world.agents:
                a.action.u = torch.zeros(
                    (a.action.u.size(0), 2), device=self._world.device
                )  # Zero out all action components as a fallback
                vel_is_zero = torch.linalg.vector_norm(
                    a.action.u[:, :2], dim=1
                ) < 1e-3
                a.controller.reset(vel_is_zero)
                a.controller.process_force()
            return

        # Split and assign the results back to agents
        split_sizes = [a.action.u.size(0) for a in self.world.agents]
        u_qp_list = torch.split(u_qp_batch, split_sizes, dim=0)
        delta_list = torch.split(delta_batch, split_sizes, dim=0)

        for i, a in enumerate(self.world.agents):
            a.action.u = u_qp_list[i]
            a.action.delta = delta_list[i].squeeze()

            # Clamp control inputs
            a.action.u = TorchUtils.clamp_with_norm(a.action.u, self.u_range)
            action_norm = torch.linalg.vector_norm(a.action.u, dim=1)
            a.action.u[action_norm < self.min_input_norm] = 0
            a.computed_velocity = a.action.u.clone()

            # Re-measure the norm after zeroing, as the VELOCITY_CONTROL
            # branch does, so inputs that were just zeroed also reset the
            # controller
            vel_is_zero = torch.linalg.vector_norm(a.action.u, dim=1) < 1e-3
            a.controller.reset(vel_is_zero)
            a.controller.process_force()

                        
    def compute_rvo_velocities(self, pos_list, vel_list):
        """Compute RVO velocities for all agents and all environments.

        Each environment is handed to ``RVO.compute_velocities`` separately.
        When that call returns ``None`` for an environment, a fallback is
        used: zero x-velocity and ``0.1 * target_direction`` as y-velocity.

        Args:
            pos_list: List of tensors with agent positions [batch_dim, 2]
            vel_list: List of tensors with agent velocities [batch_dim, 2]

        Returns:
            List with one tensor of RVO velocities [batch_dim, 2] per agent
            (``self.n_agents`` entries, also when ``batch_dim`` is 0).
        """
        import RVO  # Imported at call time; the module-level import was moved here

        batch_dim = self._world.batch_dim

        # Allocate every agent's output up front so the result always holds
        # one tensor per agent, independent of how many environments exist
        rvo_velocities = [
            torch.zeros((batch_dim, 2), device=self._world.device)
            for _ in range(self.n_agents)
        ]

        # Process each environment separately
        for env_idx in range(batch_dim):
            # One row per agent: [px, py, vx, vy]
            rvo_input = np.zeros((self.n_agents, 4))
            for i, (pos, vel) in enumerate(zip(pos_list, vel_list)):
                rvo_input[i, 0:2] = pos[env_idx, :2].tolist()
                rvo_input[i, 2:4] = vel[env_idx, :2].tolist()

            # Get target directions
            target_directions = np.array(
                [agent.target_direction[env_idx].item() for agent in self.world.agents]
            )

            # Call RVO to compute velocities
            rvo_output = RVO.compute_velocities(rvo_input, target_directions)

            # If RVO failed, use a simple fallback
            if rvo_output is None:
                rvo_output = np.zeros((self.n_agents, 2))
                # Set a small velocity in the target direction as fallback
                for i, agent in enumerate(self.world.agents):
                    rvo_output[i, 1] = 0.1 * agent.target_direction[env_idx].item()

            # Copy this environment's result into the per-agent tensors
            for i in range(self.n_agents):
                rvo_velocities[i][env_idx, 0] = rvo_output[i, 0]
                rvo_velocities[i][env_idx, 1] = rvo_output[i, 1]

        return rvo_velocities

    def extra_render(self, env_index: int = 0) -> "List[Geom]":
        """Build the scenario-specific geometries for one environment.

        Besides returning geometries, this also updates the colors of the
        walls and agents (red while their ``collision_timer`` is positive and
        collision visualisation is not disabled).

        The returned list contains, in order: the bottom and top regions, then
        per agent an optional RVO velocity line, an optional target marker and
        the agent circle, and finally a line between every pair of agents that
        are at most 1.5 apart.

        Args:
            env_index: Index of the environment to render.

        Returns:
            The list of extra geometries to draw.
        """
        geoms: List[rendering.Geom] = []

        # Color the walls: red while a collision is being displayed
        for wall in self.walls:
            if hasattr(wall, 'collision_timer') and wall.collision_timer[env_index] > 0 and not self.disable_collision_viz:
                wall_color = (1.0, 0.0, 0.0)  # Red color for collision
            else:
                wall_color = (0.0, 0.0, 0.0)  # Original color
            wall._color = wall_color

        # World boundaries
        world_x_semidim = self._world.x_semidim
        world_y_semidim = self._world.y_semidim

        # Draw bottom region
        bottom_region = rendering.make_polygon(
            [
                (-world_x_semidim, -world_y_semidim),
                (world_x_semidim, -world_y_semidim),
                (
                    world_x_semidim,
                    -world_y_semidim + self.bottom_region_height,
                ),
                (
                    -world_x_semidim,
                    -world_y_semidim + self.bottom_region_height,
                ),
            ],
            filled=True,
            draw_border=True,
        )
        bottom_region.set_color(0.3, 0.3, 1, alpha=0.5)
        geoms.append(bottom_region)

        # Draw top region
        top_region = rendering.make_polygon(
            [
                (-world_x_semidim, world_y_semidim),
                (world_x_semidim, world_y_semidim),
                (
                    world_x_semidim,
                    world_y_semidim - self.top_region_height,
                ),
                (
                    -world_x_semidim,
                    world_y_semidim - self.top_region_height,
                ),
            ],
            filled=True,
            draw_border=True,
        )
        top_region.set_color(0.2, 0.75, 0.2, alpha=0.5)
        geoms.append(top_region)

        # Draw agents with shaded colors based on k_cbf
        for agent in self._world.agents:
            # Only show red if timer > 0 AND collision viz is enabled
            if hasattr(agent, 'collision_timer') and agent.collision_timer[env_index] > 0 and not self.disable_collision_viz:
                agent_color = (1.0, 0.0, 0.0)  # Red color for collision
            else:
                agent_color = agent.original_color  # Original color

            agent.color = agent_color

            # Optionally, work out a target point based on the QP objective.
            # Nothing is drawn until process_action has stored the original
            # action on the agent.
            target_pos = None
            if hasattr(agent, "original_action_u"):
                if self.agent_control_type == QP_CONTROL_OBJECTIVE:
                    target_pos = (
                        agent.original_action_u[env_index, 0].cpu().numpy(),
                        agent.original_action_u[env_index, 1].cpu().numpy(),
                    )
                elif self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
                    target_pos = (
                        agent.original_action_u[env_index, 3].cpu().numpy(),
                        agent.original_action_u[env_index, 4].cpu().numpy(),
                    )
                elif self.agent_control_type == QP_CONTROL_RVO and hasattr(agent, 'rvo_velocity'):
                    # Without an RVO velocity only the marker is skipped; the
                    # agent itself is still drawn below
                    curr_pos = agent.state.pos[env_index].cpu().numpy()
                    rvo_vel = agent.rvo_velocity[env_index].cpu().numpy()

                    # Target point is current position + RVO velocity
                    target_pos = (
                        curr_pos[0] + rvo_vel[0],
                        curr_pos[1] + rvo_vel[1],
                    )

                    # Draw the RVO velocity as a line
                    rvo_line = rendering.Line(
                        (curr_pos[0], curr_pos[1]),
                        (curr_pos[0] + rvo_vel[0], curr_pos[1] + rvo_vel[1]),
                        width=3,
                    )
                    rvo_line.set_color(0.5, 0.0, 0.5, alpha=0.8)  # Purple color
                    geoms.append(rvo_line)

            if target_pos is not None:
                # Create a small circle to represent the target
                target_circle = rendering.make_circle(
                    radius=0.05, res=10, filled=True
                )
                target_transform = rendering.Transform(translation=target_pos)
                target_circle.add_attr(target_transform)

                if self.agent_control_type == QP_CONTROL_RVO:
                    target_circle.set_color(0.5, 0.0, 0.5, alpha=0.8)  # Purple color
                else:
                    target_circle.set_color(0.0, 1.0, 0.0, alpha=0.8)  # Green color

                geoms.append(target_circle)

            # Initialize shade factor
            shade_factor = 1.0  # Default shade (no shading)

            # The original action only exists once process_action has run, so
            # rendering before the first step keeps the unshaded color
            if self.agent_control_type == QP_CONTROL_CBF and hasattr(agent, "original_action_u"):
                # Same shift-and-halve mapping that process_action applies to
                # the first action component
                k_cbf = agent.original_action_u[env_index][0].item()
                k_cbf = (k_cbf + self.k_cbf_range) / 2
                normalized_k_cbf = k_cbf / self.k_cbf_range
                # Shade factor grows linearly from min_shade with the
                # normalized gain
                min_shade = 0.5
                shade_factor = min_shade + (1.0 - min_shade) * normalized_k_cbf

            # Compute shaded color by multiplying the current RGB with shade_factor
            original_color = agent.color  # Tuple of (R, G, B)
            shaded_color = (
                original_color[0] * shade_factor,
                original_color[1] * shade_factor,
                original_color[2] * shade_factor,
            )

            # Update agent color
            agent.color = shaded_color

            # Create the agent's graphical representation
            agent_geom = rendering.make_circle(
                radius=self.agent_radius,
                res=20,
                filled=True,
            )
            agent_transform = rendering.Transform(translation=agent.state.pos[env_index].cpu().numpy())
            agent_geom.add_attr(agent_transform)
            agent_geom.set_color(*shaded_color)
            geoms.append(agent_geom)

        # Communication lines between every unordered pair of nearby agents
        edge_radius = 1.5
        all_agents = self.world.agents
        for i, agent1 in enumerate(all_agents):
            for agent2 in all_agents[i + 1:]:
                # Only the rendered environment matters, so measure just that one
                agent_dist = torch.linalg.vector_norm(
                    agent1.state.pos[env_index] - agent2.state.pos[env_index],
                    dim=-1,
                )
                if agent_dist <= edge_radius:
                    color = Color.BLACK.value
                    line = rendering.Line(
                        (agent1.state.pos[env_index]),
                        (agent2.state.pos[env_index]),
                        width=1,
                    )
                    xform = rendering.Transform()
                    line.add_attr(xform)
                    line.set_color(*color)
                    geoms.append(line)
        return geoms

if __name__ == "__main__":
    import sys
    import os

    # Make the parent of this file's directory importable so that `utils`
    # can be resolved from there
    this_file = os.path.abspath(__file__)
    parent_of_file_dir = os.path.dirname(os.path.dirname(this_file))
    sys.path.append(parent_of_file_dir)
    from utils import use_vmas_env

    # Settings shared by every demo run
    run_settings = dict(
        render=True,
        save_render=False,
        num_envs=1,
        n_steps=500,
        device="cpu",
        continuous_actions=True,
        random_action=True,
        deterministic_action_value=0.0,
        n_agents=8,
    )

    # Ten consecutive runs, each with a freshly constructed scenario
    n_runs = 10
    for _ in range(n_runs):
        use_vmas_env(scenario=NarrowCorridorScenario(), **run_settings)
