"""
Based on:
https://github.com/atb033/multi_agent_path_planning/
"""

import numpy as np

# hyperparameters
# Simulation time step of the demo; compute_velocity also uses it as the
# one-step horizon when turning workspace limits into velocity bounds.
DT: float = 0.1
# Radius of each (disc-shaped) robot; velocity obstacles use 2.2 * ROBOT_RADIUS
# as the required centre-to-centre clearance.
ROBOT_RADIUS: float = 0.3
# Largest speed sampled by compute_velocity, and the cap applied to its
# obstacle-free output.
V_MAX: float = 1
# Speed along y that compute_velocities requests before a robot reaches its
# target region.
V_DESIRED: float = 1
# Axis-aligned workspace; compute_velocity discards candidate velocities that
# would leave it within one DT.
WORKSPACE_X_LIM: tuple[float, float] = (-1.0, 1.0)
WORKSPACE_Y_LIM: tuple[float, float] = (-3.0, 3.0)
# compute_velocities sets a robot's desired velocity to zero once it has moved
# past these y-lines in its travel direction.
TARGET_REGION_Y_LOWER: float = WORKSPACE_Y_LIM[0] + 0.5
TARGET_REGION_Y_UPPER: float = WORKSPACE_Y_LIM[1] - 0.5
# Other robots whose centres are within this distance (inclusive) are treated
# as obstacles.
VISIBILITY_DIST: float = 1.5
# Goal-based planners: within this distance of the goal the desired speed is
# scaled down in proportion to the remaining distance.
TARGET_REGION: float = 2.0


# Velocity planners (public entry points)
def compute_velocities(
    robots: np.ndarray, target_directions: np.ndarray
) -> np.ndarray | None:
    """
    robots: [px, py, vx, vy, direciton] * N
    v_desired: [vx, vy] * N
    """
    vels = []
    for i in range(len(robots)):
        robot = robots[i]

        # compute target velocity
        target_direction = target_directions[i]
        v_desired = np.array([0, V_DESIRED]) * target_direction
        py = robot[1]
        if (target_direction > 0 and py > TARGET_REGION_Y_UPPER) or (
            target_direction < 0 and py < TARGET_REGION_Y_LOWER
        ):
            v_desired = np.zeros(2)

        # create obstacles
        obstacles = []
        for j in range(len(robots)):
            if j == i:
                continue
            if np.linalg.norm(robot[[0, 1]] - robots[j, [0, 1]]) <= VISIBILITY_DIST:
                obstacles.append(robots[j])
        obstacles = np.array(obstacles)

        # compute target vel
        v = compute_velocity(robot, obstacles, v_desired)

        # failure case
        if v is None:
            return None

        vels.append(v)

    return np.array(vels)


def compute_velocities_with_goals(
    positions: np.ndarray, velocities: np.ndarray, goals: np.ndarray, dt: float = DT,
    max_speed: float = V_MAX, visibility_radius: float = VISIBILITY_DIST,
    target_region: float = TARGET_REGION, max_iters: int = 3, simple_mode: bool = False
) -> np.ndarray | None:
    """
    Compute collision-avoiding velocities for robots heading to individual goals.

    Each robot's desired velocity points straight at its goal with speed
    max_speed. Within target_region of the goal, that speed is scaled by
    distance / target_region. Other robots within visibility_radius become
    velocity obstacles (plain VO, see compute_velocity). If compute_velocity
    finds no feasible velocity, the robot falls back to half its desired
    velocity. That fallback ignores the obstacles.

    Parameters:
    - positions: shape (n_robots, 2), (px, py) per robot
    - velocities: shape (n_robots, 2), (vx, vy) per robot
    - goals: shape (n_robots, 2), (gx, gy) goal position per robot
    - dt: accepted for API compatibility but currently unused. compute_velocity
      applies the module-level DT in its workspace-bound check.
    - max_speed: speed of the desired velocity.
      NOTE(review): compute_velocity only samples speeds up to the module-level
      V_MAX, so max_speed > V_MAX is not honoured when obstacles are present.
    - visibility_radius: other robots within this distance are obstacles
    - target_region: distance to goal below which the desired speed is scaled
      down linearly
    - max_iters: accepted for API compatibility but currently unused
    - simple_mode: if True, delegate to compute_velocities_simple, a heuristic
      without velocity-obstacle constraints

    NOTE(review): compute_velocity also applies the module-level
    WORKSPACE_X_LIM / WORKSPACE_Y_LIM, so goals outside that box cannot be
    reached while obstacles are visible.

    Returns:
    - array of shape (n_robots, 2) with (vx, vy) per robot; (0, 2) when there
      are no robots. ``None`` stays in the annotation for backward
      compatibility, but this function never actually returns None.
    """
    positions = np.asarray(positions)
    velocities = np.asarray(velocities)
    goals = np.asarray(goals)

    n_robots = positions.shape[0]
    if n_robots == 0:
        return np.zeros((0, 2))

    if simple_mode:
        return compute_velocities_simple(positions, velocities, goals, max_speed,
                                         visibility_radius, target_region)

    # Rows in the [px, py, vx, vy] layout that compute_velocity expects.
    states = np.concatenate([positions, velocities], axis=1)
    # All centre-to-centre distances, computed once for every robot.
    pair_dists = np.linalg.norm(
        positions[:, None, :] - positions[None, :, :], axis=-1
    )

    vels = []
    for i in range(n_robots):
        to_goal = goals[i] - positions[i]
        dist_to_goal = np.linalg.norm(to_goal)
        if dist_to_goal > 1e-6:
            v_desired = (to_goal / dist_to_goal) * max_speed
            if dist_to_goal < target_region:
                # Slow down linearly inside the target region.
                v_desired = v_desired * (dist_to_goal / target_region)
        else:
            # Effectively at the goal; also avoids dividing by ~0.
            v_desired = np.zeros(2)

        visible = pair_dists[i] <= visibility_radius
        visible[i] = False  # a robot is not its own obstacle
        if visible.any():
            v = compute_velocity(states[i], states[visible], v_desired)
            if v is None:
                v = v_desired * 0.5  # fallback: slower desired velocity
        else:
            v = v_desired  # nothing in sight: head straight for the goal
        vels.append(v)

    return np.array(vels)


def compute_velocities_simple(
    positions: np.ndarray, velocities: np.ndarray, goals: np.ndarray,
    max_speed: float = V_MAX, visibility_radius: float = VISIBILITY_DIST,
    target_region: float = TARGET_REGION
) -> np.ndarray:
    """
    Heuristic collision avoidance without velocity-obstacle constraints.

    Used when simple_mode=True in compute_velocities_with_goals. Steps:
    1. Desired velocity straight toward the goal at max_speed, scaled by
       distance / target_region inside target_region. It is zero when the
       robot is within 1e-6 of its goal.
    2. For every robot with at least one other robot within
       visibility_radius:
       - scale the velocity by 0.9;
       - add a sideways component of magnitude 0.1 * max_speed, rotated
         90 degrees counter-clockwise from that velocity;
       - cap the speed at max_speed.

    Parameters:
    - positions, goals: shape (n_robots, 2)
    - velocities: accepted for signature parity with
      compute_velocities_with_goals; not used
    - max_speed, visibility_radius, target_region: as in
      compute_velocities_with_goals

    Returns:
    - float array of shape (n_robots, 2) with (vx, vy) per robot
    """
    # Work in float: with integer inputs, zeros_like(positions) would be an
    # integer array, silently truncating the velocities written into it and
    # making the in-place scaling below raise.
    positions = np.asarray(positions, dtype=float)
    goals = np.asarray(goals, dtype=float)
    n_robots = positions.shape[0]
    output_velocities = np.zeros_like(positions)
    if n_robots == 0:
        return output_velocities

    # 1. Desired velocities toward the goals.
    for i in range(n_robots):
        to_goal = goals[i] - positions[i]
        dist_to_goal = np.linalg.norm(to_goal)
        if dist_to_goal > 1e-6:
            desired_vel = to_goal / dist_to_goal * max_speed
            if dist_to_goal < target_region:
                desired_vel *= dist_to_goal / target_region
            output_velocities[i] = desired_vel
        # else: leave the zero velocity in place (robot is at its goal)

    # 2. Robots that see at least one other robot within visibility_radius.
    within = np.linalg.norm(
        positions[:, None, :] - positions[None, :, :], axis=-1
    ) <= visibility_radius
    np.fill_diagonal(within, False)  # ignore self-distances

    for i in np.flatnonzero(within.any(axis=1)):
        vel = output_velocities[i] * 0.9  # reduce speed
        # Sideways nudge so robots meeting head-on veer apart.
        perp = np.array([-vel[1], vel[0]])
        perp_norm = np.linalg.norm(perp)
        if perp_norm > 1e-6:
            vel = vel + perp / perp_norm * max_speed * 0.1
        speed = np.linalg.norm(vel)
        if speed > max_speed:
            vel = vel / speed * max_speed  # enforce the speed limit
        output_velocities[i] = vel

    return output_velocities


def compute_velocity(
    robot: np.ndarray, obstacles: np.ndarray, v_desired: np.ndarray
) -> np.ndarray | None:
    """
    Choose a velocity for one robot that avoids the given obstacles.

    Candidate velocities are sampled on a polar grid: 100 headings x 5 speeds
    from 0 to V_MAX. A candidate is discarded in either of these cases:
    - it would carry the robot outside WORKSPACE_X_LIM / WORKSPACE_Y_LIM
      within one DT;
    - it lies inside any obstacle's velocity obstacle. That is a cone from the
      robot toward the obstacle with half-angle
      arcsin(2.2 * ROBOT_RADIUS / distance), its apex shifted by the
      obstacle's velocity (plain VO, not reciprocal).
    The surviving candidate closest to v_desired is returned; ties go to the
    first candidate.

    Parameters:
    - robot: [px, py, vx, vy]; only the position is used
    - obstacles: shape (n_obstacles, 4), rows [px, py, vx, vy]. A single
      obstacle may also be given as a 1-D array of length 4. An empty array
      means there are no obstacles.
    - v_desired: preferred velocity (vx, vy)

    Returns:
    - the chosen velocity as a length-2 array
    - with no obstacles: v_desired itself, rescaled to V_MAX if faster. No
      workspace check is applied in that case.
    - None when every candidate is ruled out, or when a step fails (the
      failure is printed)
    """
    robot = np.asarray(robot)
    obstacles = np.asarray(obstacles)
    v_desired = np.asarray(v_desired)
    pA = robot[[0, 1]]

    if obstacles.size == 0:
        desired_norm = np.linalg.norm(v_desired)
        if desired_norm > V_MAX:
            return v_desired / desired_norm * V_MAX
        return v_desired

    # Candidate velocities as the columns of a (2, 500) array.
    th = np.linspace(0, 2 * np.pi, 100)
    vel = np.linspace(0, V_MAX, 5)
    vv, thth = np.meshgrid(vel, th)
    v_sample = np.stack(((vv * np.cos(thth)).flatten(), (vv * np.sin(thth)).flatten()))

    obstacles = np.atleast_2d(obstacles)  # accept a single obstacle as a flat row
    # Required centre-to-centre clearance: two robot radii inflated by 10%.
    clearance = 2.2 * ROBOT_RADIUS

    # Two half-plane constraints (left and right cone edge) per obstacle.
    number_of_obstacles = obstacles.shape[0]
    Amat = np.empty((number_of_obstacles * 2, 2))
    bvec = np.empty(number_of_obstacles * 2)
    try:
        for i, obstacle in enumerate(obstacles):
            pB = obstacle[[0, 1]]
            vB = obstacle[[2, 3]]
            dispBA = pA - pB
            thetaBA = np.arctan2(dispBA[1], dispBA[0])
            # When already overlapping, clamp so arcsin's argument stays <= 1
            # (the cone then opens up to a half-plane).
            distBA = max(np.linalg.norm(dispBA), clearance)
            phi_obst = np.arcsin(clearance / distBA)
            # Cone edges, translated by the obstacle velocity vB (VO).
            Amat[2 * i], bvec[2 * i] = create_constraints(vB, thetaBA + phi_obst, "left")
            Amat[2 * i + 1], bvec[2 * i + 1] = create_constraints(vB, thetaBA - phi_obst, "right")
    except Exception as e:
        print(f"Error computing constraints: {e}")
        return None

    # Workspace bounds: px + vx * DT and py + vy * DT must stay strictly inside
    # the workspace limits.
    px, py = pA
    vx_lb = (WORKSPACE_X_LIM[0] - px) / DT
    vx_ub = (WORKSPACE_X_LIM[1] - px) / DT
    vy_lb = (WORKSPACE_Y_LIM[0] - py) / DT
    vy_ub = (WORKSPACE_Y_LIM[1] - py) / DT
    try:
        in_bounds = (
            (vx_lb < v_sample[0])
            & (v_sample[0] < vx_ub)
            & (vy_lb < v_sample[1])
            & (v_sample[1] < vy_ub)
        )
        v_sample = v_sample[:, in_bounds]
    except Exception as e:
        print(f"Error applying boundary constraints: {e}")
        return None
    if v_sample.size == 0:
        return None

    # VO constraints: drop candidates inside any obstacle's cone.
    try:
        v_sample = check_constraints(v_sample, Amat, bvec)
    except Exception as e:
        print(f"Error checking VO constraints: {e}")
        return None
    if v_sample.size == 0:
        return None

    # Objective: the surviving candidate nearest to v_desired.
    try:
        dist_to_desired = np.linalg.norm(v_sample - v_desired.reshape(2, 1), axis=0)
    except (ValueError, TypeError) as e:
        print(f"Error computing objective function: {e}")
        return None
    if np.isnan(dist_to_desired).any():
        # e.g. NaN in v_desired: there is no meaningful nearest candidate.
        print("Error computing objective function: NaN distance to v_desired")
        return None
    return v_sample[:, np.argmin(dist_to_desired)]


def check_constraints(v_sample, Amat, bvec):
    """
    Remove every candidate velocity that lies inside any velocity obstacle.

    Parameters:
    - v_sample: candidate velocities as columns, shape (2, n_samples). A
      single velocity may be given as a 1-D array of length 2.
    - Amat: shape (2 * n_obstacles, 2). Rows 2k and 2k+1 are the two
      half-plane normals of obstacle k.
    - bvec: shape (2 * n_obstacles,), the matching right-hand sides

    A sample v is inside obstacle k when both Amat[2k] @ v < bvec[2k] and
    Amat[2k+1] @ v < bvec[2k+1] hold; such samples are dropped. A trailing
    unpaired row (odd length) is ignored. With no constraints, v_sample is
    returned unchanged.

    Returns:
    - the surviving samples, shape (2, n_kept), columns in their original
      order; np.zeros((2, 0)) when the input is None/empty or nothing survives
    """
    if v_sample is None or v_sample.size == 0:
        return np.zeros((2, 0))

    n_obstacles = np.shape(bvec)[0] // 2
    if n_obstacles == 0:
        return v_sample

    samples = v_sample.reshape((-1, 1)) if v_sample.ndim == 1 else v_sample
    n_rows = 2 * n_obstacles

    # Evaluate every half-plane inequality for every sample in one shot:
    # shape (2 * n_obstacles, n_samples).
    satisfied = (np.asarray(Amat)[:n_rows] @ samples) < np.asarray(bvec)[:n_rows, None]
    # Inside obstacle k <=> both of its inequalities hold; drop the sample if
    # it is inside any obstacle.
    inside_any = satisfied.reshape(n_obstacles, 2, -1).all(axis=1).any(axis=0)

    kept = samples[:, ~inside_any]
    return kept if kept.size else np.zeros((2, 0))


def check_inside(v, Amat, bvec):
    """
    Keep only the sample velocities that are NOT inside one velocity obstacle.

    Parameters:
    - v: candidate velocities as columns, shape (2, n_samples). A 1-D array
      is treated as a single column.
    - Amat: half-plane normals as rows (check_constraints passes the two rows
      belonging to one obstacle)
    - bvec: the matching right-hand sides

    A column c counts as inside when every entry of Amat @ c < bvec holds.

    Returns:
    - the surviving columns in their original order, shape (2, n_kept)
    - np.zeros((2, 0)) when v is None/empty or no column survives
    """
    if v is None or np.size(v) == 0:
        return np.zeros((2, 0))

    columns = v.reshape((-1, 1)) if np.ndim(v) == 1 else v
    survivors = [
        columns[:, k]
        for k in range(np.shape(columns)[1])
        if not (Amat @ columns[:, k] < bvec).all()
    ]
    if survivors:
        return np.array(survivors).T
    return np.zeros((2, 0))


def create_constraints(translation, angle, side):
    """
    Build one half-plane constraint A @ v < b bounding a velocity-obstacle cone.

    The boundary is the line through translation[:2] with direction angle.
    The line normal is n = (-sin(angle), cos(angle)), so n @ w > 0 means w is
    counter-clockwise of the direction.
    - side == "right" keeps n: A @ v < b holds when v - translation is
      clockwise of the edge direction.
    - side == "left" flips n: the constraint holds when v - translation is
      counter-clockwise of the edge direction.
    Any value other than "left" is treated as "right".

    Parameters:
    - translation: cone apex; only the first two entries are used
    - angle: direction of the cone edge, in radians
    - side: "left" or "right"

    Returns:
    - (A, b): A is a length-2 normal and b = A @ translation[:2]

    The normal is computed directly rather than via np.cross with a
    2-element vector, which NumPy 2.0 deprecates.
    """
    normal = np.array([-np.sin(angle), np.cos(angle)])
    if side == "left":
        normal = -normal
    return normal, normal @ translation[:2]


def translate_line(line, translation):
    """
    Shift a homogeneous 2-D line [a, b, c] (points with a*x + b*y + c = 0)
    by the vector translation[:2].

    Returns a new array [a, b, c - a*tx - b*ty].
    """
    shift = np.identity(3)
    shift[-1, :-1] = -translation[:2]
    return np.matmul(shift, line)


if __name__ == "__main__":
    from tqdm import tqdm

    # Initial states, one row [px, py, vx, vy] per robot. Rows alternate
    # between robots starting low (heading up) and high (heading down).
    robots = np.array(
        [
            [0, -1.5, 0, 0],
            [0, 1, 0, 0],
            #
            [0.5, -1, 0, 0],
            [-0.5, 1.3, 0, 0],
            #
            # [-0.5, -1.8, 0, 0],
            # [0.5, 2.0, 0, 0],
        ]
    )
    # Travel direction per robot (+1 up, -1 down). Six entries so the
    # commented-out pair above can be re-enabled; extra entries are ignored.
    target_directions = np.array([1, -1, 1, -1, 1, -1])

    # ---- simulation ----------------------------------------------------
    np.random.seed(42)

    num_robots = robots.shape[0]
    states = robots.copy()
    hist = [states.copy()]
    for t in tqdm(range(100)):
        vels = compute_velocities(states, target_directions)
        if vels is None:
            print(f"t={t}, failed to find feasible velocities, use zero vel")
            vels = np.zeros((num_robots, 2))
        states[:, [0, 1]] += vels * DT  # explicit Euler position update
        states[:, [2, 3]] = vels
        # Uniform noise in [-0.005, 0.005) on positions and velocities.
        states += np.random.rand(*states.shape) * 0.01 - 0.005
        hist.append(states.copy())
    hist = np.array(hist)  # (num_steps + 1, num_robots, 4)

    # ---- visualisation -------------------------------------------------
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from matplotlib.patches import Circle

    cmap = matplotlib.colormaps["jet"]

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    # Pad the workspace by one radius so discs on the boundary stay visible.
    ax.set_xlim(WORKSPACE_X_LIM[0] - ROBOT_RADIUS, WORKSPACE_X_LIM[1] + ROBOT_RADIUS)
    ax.set_ylim(WORKSPACE_Y_LIM[0] - ROBOT_RADIUS, WORKSPACE_Y_LIM[1] + ROBOT_RADIUS)
    ax.grid(True)

    circles = []
    for idx in range(num_robots):
        colour = cmap(idx / num_robots)
        ax.plot(*hist[:, idx, [0, 1]].T, color=colour)  # full trajectory
        disc = Circle(robots[idx, [0, 1]], ROBOT_RADIUS, color=colour)
        ax.add_patch(disc)
        circles.append(disc)

    def animate(frame):
        """Move every robot disc to its position at history index ``frame``."""
        for disc, centre in zip(circles, hist[frame][:, [0, 1]]):
            disc.center = centre

    ani = animation.FuncAnimation(fig, animate, range(len(hist)), interval=30)
    ani.save("demo.gif", fps=30)
    plt.show()
