import typing
from typing import Dict, List

import torch
from torch import Tensor
from vmas.simulator.core import Agent, Sphere, World, Landmark, Line
from vmas.simulator.dynamics.holonomic import Holonomic
from vmas.simulator.dynamics.common import Dynamics
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.utils import Color, TorchUtils, ScenarioUtils
from vmas.simulator.controllers.velocity_controller import VelocityController
from vmas.simulator import rendering

import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer

if typing.TYPE_CHECKING:
    from vmas.simulator.rendering import Geom

import numpy as np


# Agent control types
# Accepted values of the "agent_control_type" scenario kwarg. The selected
# value determines the per-agent action size and range configured in
# ConnectivityScenario.make_world and which QP problem init_qp_layer builds.
VELOCITY_CONTROL = "VELOCITY_CONTROL"
QP_CONTROL_OBJECTIVE_AND_CONSTRAINT = "QP_CONTROL_OBJECTIVE_AND_CONSTRAINT"
QP_CONTROL_OBJECTIVE = "QP_CONTROL_OBJECTIVE"
QP_CONTROL_CONSTRAINT = "QP_CONTROL_CONSTRAINT"
QP_CONTROL_CBF = "QP_CONTROL_CBF"
QP_CONTROL_CONSTRAINT_ROBOMASTER = "QP_CONTROL_CONSTRAINT_ROBOMASTER"

class HolonomicQP(Dynamics):
    """Dynamics that turn the first two action columns into the agent force.

    Any action columns beyond the first two are ignored by this class.
    """

    @property
    def needed_action_size(self) -> int:
        """Number of action components consumed by these dynamics."""
        return 2

    def process_action(self):
        """Write columns 0 and 1 of the agent action into its state force."""
        planar_action = self.agent.action.u[:, :2]
        self.agent.state.force = planar_action

class ConnectivityScenario(BaseScenario):
    def make_world(self, batch_dim: int, device: torch.device, **kwargs):
        """Create the world with its agents, four boundary walls and obstacles.

        Scenario options are read from ``kwargs`` (see the ``kwargs.get`` /
        ``kwargs.pop`` calls below for names and defaults). The chosen
        ``agent_control_type`` fixes each agent's action size and range.

        Args:
            batch_dim: Number of vectorized environments.
            device: Torch device on which tensors are allocated.
            **kwargs: Scenario configuration overrides.

        Returns:
            The populated ``World``.

        Raises:
            ValueError: If ``agent_control_type`` is not a known control type.
        """
        self.agent_control_type = kwargs.get(
            "agent_control_type", QP_CONTROL_CONSTRAINT
        )

        self.n_agents = kwargs.get("n_agents", 4)
        self.n_obstacles = kwargs.get("n_obstacles", 2)  # Number of obstacles
        self.v_range = kwargs.get("v_range", 1)
        self.a_range = kwargs.get("a_range", 1.0)
        self.linear_friction = kwargs.get("linear_friction", 0.0)
        self.collisions = kwargs.get("collisions", True)

        self.max_velocity = kwargs.get("max_velocity", 0.5)
        self.max_steps = kwargs.get("max_steps", 100)
        self.agent_radius = kwargs.get("agent_radius", 0.25)

        self.lambda_u = kwargs.get("LAMBDA_U", 10.0)
        self.lambda_g = kwargs.get("LAMBDA_G", 1.0)
        self.delta_scaling = kwargs.get("delta_scaling", 99.0)
        self.formation_radius = kwargs.get("formation_radius", 1.75)

        # Read QP constraint flag
        self.constraint_on = kwargs.get("constraint_on", True)

        # Read experimental obstacle / rendering flags
        self.experimental_obstacle_placement = kwargs.get(
            "experimental_obstacle_placement", False
        )
        self.disable_collision_viz = kwargs.get("disable_collision_viz", False)
        self.disable_action_render = kwargs.get("disable_action_render", False)
        self.secondary_experimental_placement = kwargs.get(
            "secondary_experimental_placement", False
        )

        self.collision_gap = 0.001
        self.min_distance_between_entities = self.agent_radius * 2 + self.collision_gap
        self.world_x_semidim = 2.2 * (self.agent_radius + self.collision_gap)
        self.world_y_semidim = 3.0

        self.bottom_region_height = 1.5
        self.top_region_height = 1.5

        # Always define the flag: observation() reads it for the
        # QP_CONTROL_CONSTRAINT_ROBOMASTER control type, so assigning it only
        # under QP_CONTROL_CONSTRAINT left it undefined exactly where it is used.
        self.robomaster_deployed = kwargs.get("robomaster_deployed", True)

        if self.agent_control_type == QP_CONTROL_CONSTRAINT:
            self.min_input_norm = kwargs.get("min_input_norm", 0.14)
            # This control type overrides any user-supplied delta_scaling.
            self.delta_scaling = 99999.0
        else:
            self.min_input_norm = kwargs.get("min_input_norm", 0.08)

        self.f_range = self.a_range + self.linear_friction
        self.u_range = (
            self.v_range
            if self.agent_control_type == VELOCITY_CONTROL
            else self.f_range
        )

        controller_params = [2, 6, 0.002]

        self.resolution_factor = kwargs.pop("resolution_factor", 1200)
        self.viewer_zoom = 2
        self.viewer_size = kwargs.pop(
            "viewer_size",
            (
                int(self.resolution_factor),
                int(self.resolution_factor),
            ),
        )

        # The world is larger than the nominal arena by one agent radius.
        world = World(
            batch_dim,
            device,
            dt=0.1,
            substeps=1,
            linear_friction=self.linear_friction,
            drag=0.0,
            x_semidim=self.world_x_semidim + self.agent_radius,
            y_semidim=self.world_y_semidim + self.agent_radius,
        )

        known_colors = [
            (0.22, 0.40, 0.72),
            (0.22, 0.72, 0.40),
        ]

        # Per-control-type action layout (size and per-component range).
        if self.agent_control_type == VELOCITY_CONTROL:
            action_size = 2
            u_range = [self.u_range, self.u_range]
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            action_size = 2 + 1 + 1 + 1
            u_range = [
                self.u_range,
                self.u_range,
                self.u_range,
                self.world_x_semidim,
                self.world_y_semidim,
            ]
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
            action_size = 2
            u_range = [self.world_x_semidim, self.world_y_semidim]
        elif self.agent_control_type in (
            QP_CONTROL_CONSTRAINT,
            QP_CONTROL_CONSTRAINT_ROBOMASTER,
        ):
            action_size = 2 + 1
            u_range = [self.u_range, self.u_range, self.u_range]
        elif self.agent_control_type == QP_CONTROL_CBF:
            self.k_cbf_range = kwargs.get("k_cbf_range", 4.0)
            action_size = 1
            u_range = self.k_cbf_range
        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        # Add agents
        for i in range(self.n_agents):
            if self.agent_control_type == VELOCITY_CONTROL:
                dynamics = Holonomic()
            else:
                dynamics = HolonomicQP()

            agent = Agent(
                name=f"agent_{i}",
                collide=self.collisions,
                alpha=1.0,
                shape=Sphere(radius=self.agent_radius),
                render_action=not self.disable_action_render,
                v_range=self.v_range,
                f_range=self.f_range,
                u_range=u_range,
                dynamics=dynamics,
                action_size=action_size,
            )

            agent.target_direction = torch.ones(batch_dim, device=device)
            agent.original_color = known_colors[1]

            agent.color = agent.original_color

            agent.controller = VelocityController(
                agent, world, controller_params, "standard"
            )

            agent.agent_collision_rew = torch.zeros(batch_dim, device=device)
            agent.individual_rew = agent.agent_collision_rew.clone()
            agent.computed_velocity = torch.zeros(batch_dim, 2, device=device)

            if self.agent_control_type != VELOCITY_CONTROL:
                agent.action.delta = torch.zeros(batch_dim, device=device)
                agent.delta_rew = torch.zeros(batch_dim, device=device)

            agent.collision_timer = torch.zeros(batch_dim, device=device)
            world.add_agent(agent)

        # Initialize walls: the first two span the x extent, the last two the
        # y extent (poses are assigned in spawn_walls).
        self.walls = []
        for i in range(4):
            if i < 2:
                wall_length = 2 * (world.x_semidim)
            else:
                wall_length = 2 * (world.y_semidim)
            wall = Landmark(
                name=f"wall_{i}",
                collide=True,
                shape=Line(length=wall_length),
                color=Color.BLACK,
            )
            wall.original_color = (0.0, 0.0, 0.0)
            wall.collision_timer = torch.zeros(batch_dim, device=device)
            world.add_landmark(wall)
            self.walls.append(wall)

        # Add multiple static circular obstacles, the same size as an agent
        self.obstacle_radius = self.agent_radius
        self.obstacles = []
        for i in range(self.n_obstacles):
            obs = Landmark(
                name=f"obstacle_{i}",
                collide=True,
                shape=Sphere(radius=self.obstacle_radius),
                color=Color.BLACK,
            )
            obs.collision_timer = torch.zeros(batch_dim, device=device)
            world.add_landmark(obs)
            self.obstacles.append(obs)

        # spawn_walls / spawn_obstacles read self._world, so it is assigned
        # here before this method returns.
        self._world = world
        self.spawn_walls()
        self.spawn_obstacles()

        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        return world

    def reset_world_at(self, env_index: int = None):
        self.qp_layer = self.init_qp_layer()
        self.qp_solver_initialized = True

        safety_margin = 1.5*(self.agent_radius + self.collision_gap)

        x_positions = [-0.5, 0.5] 
        y_start = (-self._world.y_semidim) + safety_margin
        y_spacing = 0.6  # vertical spacing between rows

        # Position agents near the bottom
        for i, agent in enumerate(self.world.agents):
            row = i // 2
            col = i % 2
            x = x_positions[col]
            y = y_start + row * y_spacing

            pos = torch.tensor([x, y], device=self._world.device).unsqueeze(0)
            agent.set_pos(pos, batch_index=env_index)
            # Set agent's initial velocity to zero
            agent.set_vel(torch.zeros_like(agent.state.vel), batch_index=env_index)

        for i, agent in enumerate(self.world.agents):
            agent.controller.reset(env_index)

        self.spawn_walls(env_index)
        self.spawn_obstacles(env_index)

        for obs in self.obstacles:
            obs.collision_timer = torch.zeros_like(obs.collision_timer)

    def _spawn_obstacles_randomly(self, env_index=None, start_index=0):
        """Randomly place ``self.obstacles[start_index:]`` in the upper half.

        Does nothing when the slice is empty. Placement is delegated to
        ``ScenarioUtils.spawn_entities_randomly`` with x limited to the world
        width shrunk by one obstacle radius and y limited to ``[0, top]``,
        which keeps obstacles above y = 0.
        """
        pending = self.obstacles[start_index:]
        if len(pending) == 0:
            return

        world = self._world
        radius = self.obstacle_radius
        x_bounds = (-world.x_semidim + radius, world.x_semidim - radius)
        y_bounds = (0, world.y_semidim - radius)  # only above y = 0

        ScenarioUtils.spawn_entities_randomly(
            pending,
            world,
            env_index,
            self.min_distance_between_entities,
            x_bounds,
            y_bounds,
        )

    def spawn_obstacles(self, env_index=None):
        batch_dim = self._world.batch_dim
        device = self._world.device

        if self.secondary_experimental_placement:
            # print("Using SECONDARY experimental obstacle placement (horizontal gap).") # Removed debug print
            # Parameters for the gap
            gap_width = self.agent_radius * 2.5  # Slightly wider than two agents
            obstacle_y_position = 0.0
            obstacle_x_offset = (gap_width / 2) + self.obstacle_radius - 0.15

            if len(self.obstacles) < 2:
                 # print("Warning: Need at least 2 obstacles for secondary experimental placement. Falling back to random.") # Commented out warning
                 self._spawn_obstacles_randomly(env_index)
                 return

            # Place the first obstacle to the left of the gap
            obs1_pos_x = obstacle_x_offset + 0.1
            obs1_pos_y = obstacle_y_position - 0.4
            # Expand pos for batch dimension only if env_index is None
            obs1_pos_tensor = torch.tensor([[obs1_pos_x, obs1_pos_y]], device=device)
            if env_index is None:
                obs1_pos_batch = obs1_pos_tensor.repeat(batch_dim, 1)
                self.obstacles[0].set_pos(obs1_pos_batch, batch_index=None)
            else:
                self.obstacles[0].set_pos(obs1_pos_tensor, batch_index=env_index)

            # Place the second obstacle to the right of the gap and slightly higher
            obs2_pos_x = obstacle_x_offset
            obs2_pos_y = obstacle_y_position + self.obstacle_radius * 1.5 # Increase offset
            obs2_pos_tensor = torch.tensor([[obs2_pos_x, obs2_pos_y]], device=device)
            if env_index is None:
                obs2_pos_batch = obs2_pos_tensor.repeat(batch_dim, 1)
                self.obstacles[1].set_pos(obs2_pos_batch, batch_index=None)
            else:
                self.obstacles[1].set_pos(obs2_pos_tensor, batch_index=env_index)

            # Place any additional obstacles randomly (if n_obstacles > 2)
            if len(self.obstacles) > 2:
                # print("Placing additional obstacles randomly for secondary exp setup.") # Removed debug print
                self._spawn_obstacles_randomly(env_index, start_index=2)

        elif self.experimental_obstacle_placement:
            # print("Using PRIMARY experimental obstacle placement (fixed positions).") # Removed debug print
            obstacle_positions = [
                torch.tensor([-0.0, 1.5], device=device),
                torch.tensor([0.0, 0.0], device=device),
            ]
            num_to_place = min(self.n_obstacles, len(obstacle_positions))
            for i in range(num_to_place):
                pos_tensor = obstacle_positions[i].unsqueeze(0)
                if env_index is None:
                    pos_batch = pos_tensor.repeat(batch_dim, 1)
                    self.obstacles[i].set_pos(pos_batch, batch_index=None)
                else:
                    self.obstacles[i].set_pos(pos_tensor, batch_index=env_index)

            if self.n_obstacles > num_to_place:
                 # print(f"Warning: Primary exp. placement requested {self.n_obstacles} obstacles, but only {num_to_place} fixed positions defined. Placing remaining randomly.") # Commented out warning
                 remaining_obstacles = self.obstacles[num_to_place:]
                 if remaining_obstacles:
                    self._spawn_obstacles_randomly(env_index, start_index=num_to_place)

        else:
            # Original random spawning logic using helper
            self._spawn_obstacles_randomly(env_index)

    def spawn_walls(self, env_index=None):
        world_x_semidim = self._world.x_semidim
        world_y_semidim = self._world.y_semidim
        for i, wall in enumerate(self.walls):
            if i == 0:  # Top wall
                pos = torch.tensor([0.0, world_y_semidim], device=self._world.device,)
                rot = 0.0
            elif i == 1:  # Bottom wall
                pos = torch.tensor([0.0, -world_y_semidim], device=self._world.device,)
                rot = 0.0
            elif i == 2:  # Left wall
                pos = torch.tensor([-world_x_semidim, 0.0], device=self._world.device,)
                rot = torch.pi / 2
            elif i == 3:  # Right wall
                pos = torch.tensor([world_x_semidim, 0.0], device=self._world.device,)
                rot = torch.pi / 2

            wall.set_pos(pos, batch_index=env_index)
            wall.set_rot(torch.tensor([rot], device=self._world.device), batch_index=env_index)

    def reward(self, agent: Agent):
        """Return the reward of ``agent`` for the current step.

        When called for the first agent, the per-agent reward terms, collision
        penalties and collision timers of *all* entities are refreshed; calls
        for the other agents only read the cached values.

        The result is ``individual_rew + agent_collision_rew``, plus
        ``delta_rew`` for the QP_CONTROL_CBF control type.
        NOTE(review): ``formation_rew`` is computed in agent_reward but is not
        part of the returned value - confirm this is intended.
        """
        if agent == self.world.agents[0]:
            agents = self.world.agents

            for a in agents:
                self.agent_reward(a)
                a.agent_collision_rew[:] = 0

            # Penalize each agent for contact with obstacles, walls and the
            # other agents; both parties get their collision timer set.
            for a in agents:
                for other in (*self.obstacles, *self.walls, *agents):
                    if a == other:
                        continue
                    if not self._world.collides(a, other):
                        continue
                    touching = (
                        self._world.get_distance(a, other) <= self.collision_gap
                    )
                    a.agent_collision_rew[touching] -= self.lambda_u
                    a.collision_timer[touching] = 5
                    other.collision_timer[touching] = 5

            # Count every collision timer down by one, never below zero.
            for entity in (*agents, *self.walls, *self.obstacles):
                if hasattr(entity, 'collision_timer'):
                    entity.collision_timer = (entity.collision_timer - 1).clamp(min=0)

        total_rew = agent.individual_rew + agent.agent_collision_rew
        if self.agent_control_type == QP_CONTROL_CBF:
            total_rew = total_rew + agent.delta_rew

        return total_rew

    def agent_reward(self, agent: Agent):
        """Recompute and store the per-agent reward terms on ``agent``.

        Sets ``agent.individual_rew`` (out-of-bounds penalty plus goal/progress
        term), ``agent.formation_rew`` (penalty when any other agent is farther
        than ``formation_radius``) and, for QP_CONTROL_CBF, ``agent.delta_rew``.

        Returns:
            ``agent.individual_rew``.
        """
        pos = agent.state.pos
        n_envs = pos.size(0)
        device = self._world.device

        agent.individual_rew = torch.zeros(n_envs, device=device)
        agent.formation_rew = torch.zeros(n_envs, device=device)

        # Out-of-bounds penalty
        outside = (torch.abs(pos[:, 0]) > self._world.x_semidim) | (
            torch.abs(pos[:, 1]) > self._world.y_semidim
        )
        agent.individual_rew += torch.where(outside, -self.lambda_u, 0)

        # Full goal reward inside the target region, otherwise a small reward
        # proportional to the velocity along the target direction.
        region_start = self._world.y_semidim - self.top_region_height
        reached = agent.target_direction * pos[:, 1] > region_start
        progress = 0.01 * self.lambda_g * agent.state.vel[:, 1] * agent.target_direction
        agent.individual_rew += torch.where(reached, self.lambda_g, progress)

        if self.agent_control_type == QP_CONTROL_CBF:
            if hasattr(agent.action, 'delta'):
                agent.delta_rew = -self.lambda_u * agent.action.delta
            else:
                agent.delta_rew = torch.zeros(n_envs, device=device)

        # Connectivity penalty: applied once if ANY other agent is too far away
        teammates = [a.state.pos for a in self.world.agents if a != agent]
        if len(self.world.agents) > 1:
            stacked = torch.stack(teammates, dim=1)  # (n_envs, n_agents-1, 2)
            gaps = torch.norm(stacked - pos.unsqueeze(1), dim=-1)
            separated = (gaps > self.formation_radius).any(dim=1)
            agent.formation_rew -= self.lambda_u * separated

        return agent.individual_rew

    
    def observation(self, agent: Agent):
        """Build the observation dictionary for ``agent``.

        Keys:
            agent_pos, agent_vel: the agent's position and velocity.
            target_direction: the agent's target direction, as a column.
            distance_to_boundary: signed distance to the target-region edge.
            obstacle_rel_positions: flattened obstacle offsets from the agent.
            obstacle_distances: distance to each obstacle.
            agent_diameter: ``formation_radius`` minus the largest distance to
                any other agent.

        For QP_CONTROL_CONSTRAINT_ROBOMASTER with ``robomaster_deployed``
        false, Gaussian noise (std 0.01) is added to every entry except
        ``target_direction``. The noise is added out of place so that the
        agent's real state tensors are never modified.
        """
        n_envs = agent.state.pos.size(0)
        device = self._world.device

        top_region_boundary = self._world.y_semidim - self.top_region_height
        bottom_region_boundary = -self._world.y_semidim + self.bottom_region_height
        distance_to_boundary = torch.where(
            agent.target_direction > 0,
            top_region_boundary - agent.state.pos[:, 1],
            agent.state.pos[:, 1] - bottom_region_boundary,
        )

        obs = {
            "agent_pos": agent.state.pos,
            "agent_vel": agent.state.vel,
            "target_direction": agent.target_direction.unsqueeze(-1),
            "distance_to_boundary": distance_to_boundary.unsqueeze(-1),
        }

        # Obstacle relative positions and distances
        if self.obstacles:
            obstacle_positions = torch.stack(
                [obs_ent.state.pos for obs_ent in self.obstacles], dim=1
            )  # (n_envs, n_obstacles, 2)
            obstacle_rel_positions = obstacle_positions - agent.state.pos.unsqueeze(1)

            obs["obstacle_rel_positions"] = obstacle_rel_positions.reshape(
                obstacle_rel_positions.shape[0],
                obstacle_rel_positions.shape[1] * 2,
            )
            obs["obstacle_distances"] = torch.norm(obstacle_rel_positions, dim=2)
        else:
            obs["obstacle_rel_positions"] = torch.zeros(n_envs, 0, device=device)
            obs["obstacle_distances"] = torch.zeros(n_envs, 0, device=device)

        # Largest distance from this agent to any other agent
        if len(self.world.agents) > 1:
            other_positions = torch.stack(
                [a.state.pos for a in self.world.agents if a != agent], dim=1
            )
            distances = torch.norm(
                other_positions - agent.state.pos.unsqueeze(1), dim=-1
            )  # (n_envs, n_agents-1)
            agent_diameter = torch.max(distances, dim=1)[0]
        else:
            # With a single agent the diameter is taken to be zero
            agent_diameter = torch.zeros(n_envs, device=device)

        obs["agent_diameter"] = self.formation_radius - agent_diameter.unsqueeze(-1)

        # The flag is read with a default because it may be undefined when the
        # scenario was configured for the ROBOMASTER control type.
        add_noise = (
            self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER
            and not getattr(self, "robomaster_deployed", True)
        )
        if add_noise:
            noisy_keys = (
                "agent_pos",
                "agent_vel",
                "distance_to_boundary",
                "obstacle_rel_positions",
                "obstacle_distances",
                "agent_diameter",
            )
            for key in noisy_keys:
                # Out-of-place add: obs["agent_pos"] / obs["agent_vel"] alias
                # agent.state, and an in-place += would corrupt the simulation.
                obs[key] = obs[key] + torch.randn_like(obs[key]) * 0.01

        return obs



    def done(self):
        dones = torch.zeros(
            self._world.batch_dim, device=self._world.device, dtype=torch.bool
        )
        if not self.qp_solver_initialized:
            dones[:] = True
        return dones

    def info(self, agent: Agent) -> Dict[str, Tensor]:
        """Collect the reward components of ``agent`` for logging.

        Missing optional attributes (``delta_rew``, ``formation_rew`` and, for
        QP_CONTROL_CBF, ``original_action_u``) are reported as zeros shaped
        like the collision reward.
        """
        collision_rew = agent.agent_collision_rew

        def attr_or_zeros(name):
            # Fall back to zeros when the agent never had this term computed.
            if hasattr(agent, name):
                return getattr(agent, name)
            return torch.zeros_like(collision_rew)

        report = {"collision_reward": collision_rew}

        if self.agent_control_type == QP_CONTROL_CBF:
            if hasattr(agent, 'original_action_u'):
                report["original_action_u"] = agent.original_action_u.squeeze(-1)
            else:
                report["original_action_u"] = torch.zeros_like(collision_rew)

        report["delta_reward"] = attr_or_zeros("delta_rew")
        report["formation_reward"] = attr_or_zeros("formation_rew")
        report["individual_reward"] = agent.individual_rew
        report["total_reward_without_delta_rew"] = (
            agent.individual_rew + collision_rew + report["formation_reward"]
        )

        return report

    def init_qp_layer(self):
        """Build the differentiable per-agent optimization layer.

        Returns ``None`` for VELOCITY_CONTROL. Otherwise a cvxpy problem over
        the control ``u`` (2 entries) and a non-negative slack ``delta`` is
        assembled and wrapped in a ``CvxpyLayer`` whose outputs are
        ``[u, delta]``.

        The layer's positional inputs follow the order of ``parameters``:
        ego_position, ego_velocity, obstacle_positions, h_values, dx_vel_0,
        dy_vel_1, followed by the control-type specific parameters appended
        further down. Changing that order changes the layer's call signature.

        Raises:
            ValueError: If ``agent_control_type`` is not a known control type.
        """
        if self.agent_control_type == VELOCITY_CONTROL:
            return None

        # Number of "obstacle" entities = other agents + all obstacles
        N_ = (self.n_agents - 1) + self.n_obstacles

        # Decision variables
        u = cp.Variable(2)
        delta = cp.Variable(1)

        ego_position = cp.Parameter(2)
        ego_velocity = cp.Parameter(2)

        # One row / entry per entity; the constraint loop below treats the
        # first (n_agents - 1) entries as the other agents.
        obstacle_positions = cp.Parameter((N_, 2))
        h_values = cp.Parameter(N_)
        # NOTE(review): presumably per-entity products of position offset and
        # velocity along x and y - confirm against the code that fills them.
        dx_vel_0 = cp.Parameter(N_)
        dy_vel_1 = cp.Parameter(N_)

        parameters = [
            ego_position,
            ego_velocity,
            obstacle_positions,
            h_values,
            dx_vel_0,
            dy_vel_1,
        ]

        constraints = []

        for j in range(N_):
            h = h_values[j]
            # Affine in u: the supplied terms plus u dotted with the offset
            # from entity j to the ego agent, all scaled by 2.
            lgh = 2 * (
                dx_vel_0[j] + dy_vel_1[j] +
                u[0]*(ego_position[0]-obstacle_positions[j,0]) +
                u[1]*(ego_position[1]-obstacle_positions[j,1])
            )

            if self.agent_control_type in [QP_CONTROL_CBF, QP_CONTROL_CONSTRAINT_ROBOMASTER]:
                # Slack-relaxed variant of the per-entity constraint
                constraints.append(lgh + h + delta >= 0)
                # For other agents (first n_agents-1), apply formation radius constraint
                if j < (self.n_agents - 1):
                    constraints.append(cp.norm((ego_position + ego_velocity + u - obstacle_positions[j])) <= self.formation_radius + delta)
            else:
                # Hard (unrelaxed) variant of the same two constraints
                constraints.append(lgh + h >= 0)
                if j < (self.n_agents - 1):
                    constraints.append(cp.norm((ego_position + ego_velocity + u - obstacle_positions[j])) <= self.formation_radius)

        # Keep ego_position + ego_velocity + u inside the world box shrunk by
        # safety_margin, and bound the norm of u by max_velocity. Only the
        # ROBOMASTER variant relaxes these bounds with delta.
        safety_margin = 1.05*(self.agent_radius + self.collision_gap)
        if self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
            constraints.extend(
                [
                    ego_position[0] + (ego_velocity[0] + u[0]) <= self._world.x_semidim - safety_margin + delta,
                    ego_position[0] + (ego_velocity[0] + u[0]) >= -self._world.x_semidim + safety_margin - delta,
                    ego_position[1] + (ego_velocity[1] + u[1]) <= self._world.y_semidim - safety_margin + delta,
                    ego_position[1] + (ego_velocity[1] + u[1]) >= -self._world.y_semidim + safety_margin - delta,
                    cp.norm(u) <= self.max_velocity + delta,
                    delta >= 0,
                ]
            )
        else:
            constraints.extend(
                [
                    ego_position[0] + (ego_velocity[0] + u[0]) <= self._world.x_semidim - safety_margin,
                    ego_position[0] + (ego_velocity[0] + u[0]) >= -self._world.x_semidim + safety_margin,
                    ego_position[1] + (ego_velocity[1] + u[1]) <= self._world.y_semidim - safety_margin,
                    ego_position[1] + (ego_velocity[1] + u[1]) >= -self._world.y_semidim + safety_margin,
                    cp.norm(u) <= self.max_velocity,
                    delta >= 0,
                ]
            )

        # Control-type specific parameters, extra constraint and objective.
        # Every objective adds delta_scaling * delta to penalize the slack.
        if self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            a = cp.Parameter(2)
            b = cp.Parameter(1)
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([a, b, x_target, y_target])
            # u must lie within distance b (+ delta) of a
            constraints.append(cp.norm(u - a) <= b + delta)
            # Minimize the distance from ego_position + ego_velocity + u to the target
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + self.delta_scaling*delta
            )
        elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
            x_target = cp.Parameter(1)
            y_target = cp.Parameter(1)
            parameters.extend([x_target, y_target])
            # Same target-distance objective, without the ball constraint on u
            objective = cp.Minimize(
                cp.norm(
                    cp.hstack(
                        [
                            ego_position[0] + (ego_velocity[0] + u[0]) - x_target,
                            ego_position[1] + (ego_velocity[1] + u[1]) - y_target,
                        ]
                    ),
                    2,
                )
                + self.delta_scaling*delta
            )
        elif self.agent_control_type in [QP_CONTROL_CONSTRAINT, QP_CONTROL_CONSTRAINT_ROBOMASTER]:
            # Parameters specific to this type
            a = cp.Parameter(2, name="a")
            b = cp.Parameter(1, name="b")
            target_direction = cp.Parameter(1, name="target_direction")

            parameters.extend([a, b, target_direction]) # Always add a, b, target_direction

            # When the constraint is switched off, both sides are multiplied
            # by 0 so that a and b still appear in the problem (they are
            # always listed in `parameters`) while the constraint reduces to
            # 0 <= delta.
            if self.constraint_on:
                constraints.append(cp.norm(u - a) <= b + delta)
            else:
                constraints.append(0*cp.norm(u - a) <= 0*b + delta)
            # Maximize u[1] along target_direction
            objective = cp.Minimize(-target_direction * (u[1]) + self.delta_scaling*delta)
        elif self.agent_control_type == QP_CONTROL_CBF:
            target_direction = cp.Parameter(1)
            parameters.extend([target_direction])
            # Maximize u[1] along target_direction
            objective = cp.Minimize(-target_direction * (u[1]) + self.delta_scaling*delta)
        else:
            raise ValueError(f"Unknown agent control type: {self.agent_control_type}")

        prob = cp.Problem(objective, constraints)
        qp_layer = CvxpyLayer(
            prob,
            parameters=parameters,
            variables=[u, delta],
        )

        return qp_layer

    def solve_qp_layer(
        self,
        parameters_batch,
        solver_args,
    ):
        try:
            with torch.no_grad():
                u_qp_batch, delta_batch = self.qp_layer(
                    *parameters_batch,
                    solver_args=solver_args,
                )
            return u_qp_batch, delta_batch
        except Exception as e:
            print(f"SolverError encountered: {e}. Marking the environment as done.")
            self.qp_solver_initialized = False
            return None, None
        
    def pre_step(self):
        """Snapshot agent positions and velocities before the physics step.

        Only active for VELOCITY_CONTROL. The snapshots are stored as
        ``pre_step_positions`` and ``pre_step_velocities`` with one slot per
        agent along dimension 1.
        """
        if self.agent_control_type != VELOCITY_CONTROL:
            return

        agents = self.world.agents
        snapshot_shape = (self.world.batch_dim, len(agents), 2)
        device = self.world.device

        self.pre_step_positions = torch.zeros(snapshot_shape, device=device)
        self.pre_step_velocities = torch.zeros(snapshot_shape, device=device)

        for slot, member in enumerate(agents):
            self.pre_step_positions[:, slot, :] = member.state.pos
            self.pre_step_velocities[:, slot, :] = member.state.vel

    def post_step(self):
        """Undo moves that break the formation-radius limit.

        Only active for VELOCITY_CONTROL with more than one agent. After the
        step, every agent that is farther than ``formation_radius`` from any
        other agent has its position restored to the pre-step snapshot in the
        affected environments. Velocities are left as they are.
        """
        agents = self.world.agents
        if self.agent_control_type != VELOCITY_CONTROL or len(agents) <= 1:
            return

        n_agents = len(agents)
        positions = torch.stack([a.state.pos for a in agents], dim=1)  # (B, N, 2)

        # pairwise[b, i, j] = distance between agents i and j in env b
        pairwise = torch.norm(
            positions.unsqueeze(2) - positions.unsqueeze(1), dim=-1
        )

        # An agent's distance to itself never counts as a violation.
        not_self = ~torch.eye(n_agents, dtype=torch.bool, device=self.world.device)
        violates = ((pairwise > self.formation_radius) & not_self).any(dim=2)  # (B, N)

        for slot, member in enumerate(agents):
            env_mask = violates[:, slot]
            if env_mask.any():
                member.state.pos[env_mask] = self.pre_step_positions[env_mask, slot, :]

    def process_action(self, agent: Agent):
        """Convert the policy output in ``agent.action.u`` into a controller command.

        ``VELOCITY_CONTROL``: the action of ``agent`` is norm-clamped to
        ``self.u_range`` and then ``self.max_velocity``, actions with norm
        below ``self.min_input_norm`` are zeroed, and the result is handed to
        the agent's controller.

        All QP modes: the work is done once per step, on the call for the
        first agent in ``self.world.agents``; calls for the other agents
        return immediately. On that call the method

        1. stores a copy of every agent's raw action in ``original_action_u``;
        2. builds, for every (agent, environment) pair, the QP parameters in
           this order: ego position, ego velocity, positions of all other
           entities (other agents first, then obstacles), barrier values
           ``h``, and the per-axis products ``(ego - entity) * ego_vel``;
           mode-specific parameters taken from the action / the agent's
           ``target_direction`` are appended afterwards;
        3. solves the QP through ``solve_qp_layer``;
        4. writes the clamped solution back to ``action.u`` (and the slack to
           ``action.delta``), or a zero action if the solver failed, and
           triggers each agent's controller.

        Batched rows are ordered agent-major: all environments of agent 0,
        then all environments of agent 1, and so on.

        Args:
            agent: The agent whose action is being processed.
        """
        if self.agent_control_type == VELOCITY_CONTROL:
            # Clamp the action to its allowed range
            agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, self.u_range)
            agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, self.max_velocity)
            # Zero out actions that are too small
            action_norm = torch.linalg.vector_norm(agent.action.u, dim=1)
            agent.action.u[action_norm < self.min_input_norm] = 0

            vel_is_zero = torch.linalg.vector_norm(agent.action.u, dim=1) < 1e-3
            agent.controller.reset(vel_is_zero)

            # Hand the processed command to the controller
            agent.controller.process_force()
            return

        # QP modes: one joint solve per step, triggered by the first agent only.
        is_first = agent == self.world.agents[0]
        if not is_first:
            return

        agents = self.world.agents
        batch_dim = self._world.batch_dim

        pos_list = [a.state.pos for a in agents]
        vel_list = [a.state.vel for a in agents]
        for a in agents:
            # Keep the raw policy output; action.u is overwritten further down.
            a.original_action_u = a.action.u.clone()

        # Obstacle positions are the same for every ego agent, so stack them once.
        # Shape (batch, n_obstacles, 2), assuming each obs.state.pos is (batch, 2).
        obs_positions = torch.stack(
            [obs.state.pos for obs in self.obstacles], dim=1
        )
        min_dist_sq = self.min_distance_between_entities ** 2
        p_cbf = 1.0

        entity_positions_per_agent = []
        h_values_per_agent = []
        dx_vel_per_agent = []
        dy_vel_per_agent = []

        for agent_idx in range(self.n_agents):
            ego_pos = pos_list[agent_idx]  # (batch, 2)
            ego_vel = vel_list[agent_idx]  # (batch, 2)

            other_pos = torch.stack(
                [pos_list[i] for i in range(self.n_agents) if i != agent_idx],
                dim=1,
            )  # (batch, n_agents - 1, 2)
            entity_pos = torch.cat([other_pos, obs_positions], dim=1)
            entity_positions_per_agent.append(entity_pos)

            # Squared distance from the ego agent to every other entity.
            all_distances = torch.norm(entity_pos - ego_pos.unsqueeze(1), dim=-1) ** 2
            margin = all_distances - min_dist_sq

            if self.agent_control_type == QP_CONTROL_CBF:
                # First action component, shifted by k_cbf_range and halved,
                # is used as the barrier gain.
                k_cbf = (agents[agent_idx].action.u[:, 0] + self.k_cbf_range) / 2
                h_values = k_cbf.unsqueeze(1) * margin ** p_cbf
            elif self.agent_control_type == QP_CONTROL_CONSTRAINT_ROBOMASTER:
                k_cbf = 0.9
                h_values = k_cbf * margin ** p_cbf
            else:
                h_values = margin
            h_values_per_agent.append(h_values)

            # Per-axis offset (ego - entity) multiplied by the ego velocity component.
            dx_all = ego_pos[:, 0:1] - entity_pos[:, :, 0]
            dy_all = ego_pos[:, 1:2] - entity_pos[:, :, 1]
            dx_vel_per_agent.append(dx_all * ego_vel[:, 0:1])
            dy_vel_per_agent.append(dy_all * ego_vel[:, 1:2])

        ego_positions_batch = torch.cat(pos_list, dim=0).view(
            batch_dim * self.n_agents, 2
        )
        ego_velocities_batch = torch.cat(vel_list, dim=0).view(
            batch_dim * self.n_agents, 2
        )
        obstacle_positions_batch = torch.cat(entity_positions_per_agent, dim=0)
        current_h_values_batch = torch.cat(h_values_per_agent, dim=0)
        dx_vel_0_batch = torch.cat(dx_vel_per_agent, dim=0)
        dy_vel_1_batch = torch.cat(dy_vel_per_agent, dim=0)

        parameters_batch = [
            ego_positions_batch,
            ego_velocities_batch,
            obstacle_positions_batch,
            current_h_values_batch,
            dx_vel_0_batch,
            dy_vel_1_batch,
        ]

        # Mode-specific parameters, sliced out of the raw actions.
        if self.agent_control_type == QP_CONTROL_OBJECTIVE_AND_CONSTRAINT:
            a_batch = torch.cat([a.action.u[:, :2] for a in agents], dim=0)
            b_batch = torch.cat([a.action.u[:, 2:3] for a in agents], dim=0)
            x_target_batch = torch.cat([a.action.u[:, 3:4] for a in agents], dim=0)
            y_target_batch = torch.cat([a.action.u[:, 4:5] for a in agents], dim=0)
            parameters_batch.extend(
                [a_batch, b_batch, x_target_batch, y_target_batch]
            )

        elif self.agent_control_type == QP_CONTROL_OBJECTIVE:
            x_target_batch = torch.cat([a.action.u[:, 0:1] for a in agents], dim=0)
            y_target_batch = torch.cat([a.action.u[:, 1:2] for a in agents], dim=0)
            parameters_batch.extend([x_target_batch, y_target_batch])

        elif self.agent_control_type in [
            QP_CONTROL_CONSTRAINT,
            QP_CONTROL_CONSTRAINT_ROBOMASTER,
        ]:
            a_batch = torch.cat([a.action.u[:, :2] for a in agents], dim=0)
            b_batch = torch.cat([a.action.u[:, 2:3] for a in agents], dim=0)
            target_direction_batch = torch.cat(
                [a.target_direction.unsqueeze(1) for a in agents], dim=0
            )
            parameters_batch.extend([a_batch, b_batch, target_direction_batch])

        elif self.agent_control_type == QP_CONTROL_CBF:
            target_direction_batch = torch.cat(
                [a.target_direction.unsqueeze(1) for a in agents], dim=0
            )
            parameters_batch.extend([target_direction_batch])

        u_qp_batch, delta_batch = self.solve_qp_layer(
            parameters_batch,
            solver_args={"eps": 1e-8},
        )

        if u_qp_batch is None or delta_batch is None:
            # Solver failure: command zero for every agent this step.
            print("QP solver failed, applying fallback strategy.")
            for a in agents:
                a.action.u = torch.zeros(
                    (a.action.u.size(0), 2), device=self._world.device
                )
                vel_is_zero = torch.linalg.vector_norm(
                    a.action.u[:, :2], dim=1
                ) < 1e-3
                a.controller.reset(vel_is_zero)
                a.controller.process_force()
            return

        # Undo the agent-major concatenation: one chunk of rows per agent.
        split_sizes = [a.action.u.size(0) for a in agents]
        u_qp_list = torch.split(u_qp_batch, split_sizes, dim=0)
        delta_list = torch.split(delta_batch, split_sizes, dim=0)

        for i, a in enumerate(agents):
            a.action.u = u_qp_list[i]
            a.action.delta = delta_list[i].squeeze()
            a.action.u = TorchUtils.clamp_with_norm(a.action.u, self.u_range)
            action_norm = torch.linalg.vector_norm(a.action.u, dim=1)
            a.action.u[action_norm < self.min_input_norm] = 0
            a.computed_velocity = a.action.u.clone()

            # Recompute the norm AFTER zeroing small actions (as the
            # VELOCITY_CONTROL branch does); the pre-zeroing norm would miss
            # actions that were just set to zero.
            vel_is_zero = torch.linalg.vector_norm(a.action.u, dim=1) < 1e-3
            a.controller.reset(vel_is_zero)
            a.controller.process_force()

    def extra_render(self, env_index: int = 0) -> "List[rendering.Geom]":
        """Build the scenario-specific geometry for one batched environment.

        Side effects: walls get their ``_color`` updated and agents get their
        ``color`` updated (red while their ``collision_timer`` is positive,
        subject to ``self.disable_collision_viz`` for agents and obstacles).

        Args:
            env_index: Index of the environment to draw.

        Returns:
            Geoms in draw order: bottom region, top region, obstacles, then
            per agent an optional target marker followed by the agent disc,
            and finally a line between every pair of agents whose distance is
            at most ``self.formation_radius``.
        """
        geoms: List[rendering.Geom] = []

        # Walls are only recoloured here; no geom is appended for them.
        for wall in self.walls:
            if hasattr(wall, 'collision_timer') and wall.collision_timer[env_index] > 0:
                wall_color = (1.0, 0.0, 0.0)
            else:
                wall_color = (0.0, 0.0, 0.0)
            wall._color = wall_color

        world_x_semidim = self._world.x_semidim
        world_y_semidim = self._world.y_semidim

        # Translucent band along the bottom edge of the world.
        bottom_edge = -world_y_semidim + self.bottom_region_height
        bottom_region = rendering.make_polygon(
            [
                (-world_x_semidim, -world_y_semidim),
                (world_x_semidim, -world_y_semidim),
                (world_x_semidim, bottom_edge),
                (-world_x_semidim, bottom_edge),
            ],
            filled=True,
            draw_border=True,
        )
        bottom_region.set_color(0.3, 0.3, 1, alpha=0.5)
        geoms.append(bottom_region)

        # Translucent band along the top edge of the world.
        top_edge = world_y_semidim - self.top_region_height
        top_region = rendering.make_polygon(
            [
                (-world_x_semidim, world_y_semidim),
                (world_x_semidim, world_y_semidim),
                (world_x_semidim, top_edge),
                (-world_x_semidim, top_edge),
            ],
            filled=True,
            draw_border=True,
        )
        top_region.set_color(0.2, 0.75, 0.2, alpha=0.5)
        geoms.append(top_region)

        # Draw the obstacles
        for obs in self.obstacles:
            if hasattr(obs, 'collision_timer') and obs.collision_timer[env_index] > 0 and not self.disable_collision_viz:
                obstacle_color = (1.0, 0.0, 0.0)
            else:
                obstacle_color = (0.0, 0.0, 0.0)  # black
            obstacle_geom = rendering.make_circle(
                radius=self.obstacle_radius,
                res=20,
                filled=True,
            )
            obstacle_transform = rendering.Transform(translation=obs.state.pos[env_index].cpu().numpy())
            obstacle_geom.add_attr(obstacle_transform)
            obstacle_geom.set_color(*obstacle_color)
            geoms.append(obstacle_geom)

        for agent in self._world.agents:
            if hasattr(agent, 'collision_timer') and agent.collision_timer[env_index] > 0 and not self.disable_collision_viz:
                agent_color = (1.0, 0.0, 0.0)
            else:
                agent_color = agent.original_color

            agent.color = agent_color

            # original_action_u is written by process_action; before the first
            # step it does not exist yet, so everything derived from it is optional.
            has_raw_action = hasattr(agent, "original_action_u")

            if has_raw_action and self.agent_control_type in [
                QP_CONTROL_OBJECTIVE,
                QP_CONTROL_OBJECTIVE_AND_CONSTRAINT,
            ]:
                # Target coordinates sit at different action indices per mode.
                if self.agent_control_type == QP_CONTROL_OBJECTIVE:
                    x_t = agent.original_action_u[env_index, 0].cpu().numpy()
                    y_t = agent.original_action_u[env_index, 1].cpu().numpy()
                else:
                    x_t = agent.original_action_u[env_index, 3].cpu().numpy()
                    y_t = agent.original_action_u[env_index, 4].cpu().numpy()

                target_pos = (x_t, y_t)
                target_circle = rendering.make_circle(
                    radius=0.05, res=10, filled=True
                )
                target_transform = rendering.Transform(translation=target_pos)
                target_circle.add_attr(target_transform)
                target_circle.set_color(0.0, 1.0, 0.0, alpha=0.8)
                geoms.append(target_circle)

            shade_factor = 1.0
            if has_raw_action and self.agent_control_type == QP_CONTROL_CBF:
                # Same mapping as in process_action: shift by k_cbf_range and
                # halve, then normalise by k_cbf_range to scale the brightness.
                k_cbf = agent.original_action_u[env_index][0].item()
                k_cbf = (k_cbf + self.k_cbf_range) / 2
                normalized_k_cbf = k_cbf / self.k_cbf_range
                min_shade = 0.5
                shade_factor = min_shade + (1.0 - min_shade) * normalized_k_cbf

            original_color = agent.color
            shaded_color = (
                original_color[0] * shade_factor,
                original_color[1] * shade_factor,
                original_color[2] * shade_factor,
            )

            agent.color = shaded_color

            agent_geom = rendering.make_circle(
                radius=self.agent_radius,
                res=20,
                filled=True,
            )
            agent_transform = rendering.Transform(translation=agent.state.pos[env_index].cpu().numpy())
            agent_geom.add_attr(agent_transform)
            agent_geom.set_color(*shaded_color)
            geoms.append(agent_geom)

        # Connectivity edges: one line per unordered agent pair within range.
        edge_radius = self.formation_radius
        for i, agent1 in enumerate(self.world.agents):
            for j, agent2 in enumerate(self.world.agents):
                if j <= i:
                    continue
                agent_dist = torch.linalg.vector_norm(
                    agent1.state.pos - agent2.state.pos, dim=-1
                )
                if agent_dist[env_index] <= edge_radius:
                    color = (0.0, 0.0, 0.0)
                    line = rendering.Line(
                        (agent1.state.pos[env_index]),
                        (agent2.state.pos[env_index]),
                        width=1,
                    )
                    xform = rendering.Transform()
                    line.add_attr(xform)
                    line.set_color(*color)
                    geoms.append(line)

        return geoms


if __name__ == "__main__":
    import os
    import sys

    # Make the parent directory (supplementary_material) importable so that
    # the `utils` module below can be found.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(parent_dir)
    from utils import use_vmas_env

    # Ten consecutive rendered rollouts with random actions, each on a fresh scenario.
    for _ in range(10):
        use_vmas_env(
            scenario=ConnectivityScenario(),
            render=True,
            save_render=False,
            num_envs=1,
            n_steps=800,
            device="cpu",
            continuous_actions=True,
            random_action=True,
            deterministic_action_value=0.0,
            n_agents=4,
            n_obstacles=2,
        )
