import torch
import numpy as np
from matplotlib.path import Path
import math

from sam_hq2.builder.utils.box_utils import boxes_to_corners_3d
from sam_hq2.builder.utils.common_utils import rotate_points_along_z


def get_crop_coordinates(box_centers, crop_sizes, img_h, img_w):
    """Return square crop windows of side ``crop_sizes`` centred on ``box_centers``.

    Only the outer edges are clamped: the low edges are clamped at 0 and the
    high edges at ``img_w`` / ``img_h``.

    Args:
        box_centers: tensor whose columns 0 and 1 are the x and y centres.
        crop_sizes: crop side length(s); halved with floor division.
        img_h: image height, upper bound for y.
        img_w: image width, upper bound for x.

    Returns:
        Tuple ``(crop_x1, crop_x2, crop_y1, crop_y2)``.
    """
    half = crop_sizes // 2
    cx = box_centers[:, 0]
    cy = box_centers[:, 1]

    left = torch.clamp(cx - half, min=0)
    right = torch.clamp(cx + half, max=img_w)
    top = torch.clamp(cy - half, min=0)
    bottom = torch.clamp(cy + half, max=img_h)
    return left, right, top, bottom


def project_points(points_3d, proj_matrices):
    """
    Project 3D points to the image plane with homogeneous 4x4 matrices.

    Args:
        points_3d: (B, N, 3)
        proj_matrices: (B, 4, 4). The batch dimension broadcasts, e.g. points
            of shape (1, N, 3) with matrices of shape (V, 4, 4) yield V results.

    Returns:
        projected_2d: (B, N, 2) as (x / z, y / z) of the transformed points
        depth: (B, N) transformed z (not offset by eps)
    """
    B, N, _ = points_3d.shape

    # Homogeneous coordinates in the same dtype/device as the input points.
    ones = torch.ones((B, N, 1), dtype=points_3d.dtype, device=points_3d.device)
    points_h = torch.cat([points_3d, ones], dim=-1).unsqueeze(-1)  # (B, N, 4, 1)

    # (B, 1, 4, 4) @ (B, N, 4, 1) -> (B, N, 4, 1); matmul broadcasts over N so
    # the matrices do not have to be materialised once per point.
    proj = torch.matmul(proj_matrices.unsqueeze(1), points_h).squeeze(-1)  # (B, N, 4)

    x, y, z = proj[..., 0], proj[..., 1], proj[..., 2]

    # Small offset avoids division by exactly zero for points on the z=0 plane.
    eps = 1e-6
    u = x / (z + eps)
    v = y / (z + eps)

    projected_2d = torch.stack([u, v], dim=-1)  # (B, N, 2)
    depth = z  # (B, N)

    return projected_2d, depth


def compute_2d_bounding_boxes(points_xyz_padded, proj_matrices):
    """
    Axis-aligned 2D bounding box of the projected points for each batch entry.

    Args:
        points_xyz_padded: (B, N, 3)
        proj_matrices: (B, 4, 4)

    Returns:
        bboxes_2d: (B, 4) [xmin, ymin, xmax, ymax] per batch, in the dtype of
            the projected points. A row is all-NaN when none of its points is
            finite and in front of the camera (depth > 0). For B == 0 an empty
            (0, 4) tensor is returned.
    """
    projected_2d, depth = project_points(
        points_xyz_padded, proj_matrices
    )  # (B, N, 2), (B, N)

    # A point counts only if its projection is finite and it lies in front of the camera.
    valid_mask = (
        torch.isfinite(projected_2d).all(dim=-1) & torch.isfinite(depth) & (depth > 0)
    )

    # Build the fallback row in the same dtype/device as real rows so the
    # stacked result has a consistent dtype.
    nan_box = torch.full(
        (4,), float("nan"), dtype=projected_2d.dtype, device=projected_2d.device
    )

    bboxes = []
    for pts_2d, keep in zip(projected_2d, valid_mask):
        valid_points = pts_2d[keep]  # (n_i, 2)
        if valid_points.shape[0] == 0:
            bboxes.append(nan_box)
        else:
            xy_min = valid_points.min(dim=0).values  # (2,)
            xy_max = valid_points.max(dim=0).values  # (2,)
            bboxes.append(torch.cat([xy_min, xy_max], dim=0))  # (4,)

    if not bboxes:
        return projected_2d.new_empty((0, 4))
    return torch.stack(bboxes)  # (B, 4)


def projection(points, lidar2image):
    """Project every point group into every view.

    Args:
        points: (N, P, 3) points.
        lidar2image: (V, >=3, >=4) matrices; only the top-left 3x4 is used.

    Returns:
        points_2d: (V, N, P, 2) pixel coordinates; set to -1 where the point is
            not in front of the camera.
        mask_z: (V, N, P) bool, True where the eps-shifted depth is positive.
        z: (V, N, P) depth with the 1e-6 offset already added.
    """
    num_views = lidar2image.shape[0]
    num_groups, num_pts = points.shape[:2]

    pad = torch.ones((num_groups, num_pts, 1), dtype=points.dtype, device=points.device)
    homog = torch.cat([points, pad], dim=2)  # (N, P, 4)
    flat = homog.unsqueeze(0).expand(num_views, -1, -1, -1).reshape(num_views, -1, 4)

    cam = torch.einsum(
        "vij,vjk->vik", lidar2image[:, :3, :4], flat.transpose(1, 2)
    ).view(num_views, 3, num_groups, num_pts)  # (V, 3, N, P)

    depth = cam[:, 2, :, :] + 1e-6
    uv = torch.stack([cam[:, 0, :, :] / depth, cam[:, 1, :, :] / depth], dim=-1)

    in_front = depth > 0
    keep = in_front.unsqueeze(-1).float()
    drop = (~in_front).unsqueeze(-1).float()
    # Points behind the camera are replaced by -1 (via arithmetic masking).
    uv = uv * keep + (-1) * drop

    return uv, in_front, depth


def projection_2d_box(
    lidar2image, points, H, W, return_center=False, convert_int=False
):
    """
    Project groups of 3D points and return their clamped 2D bounding boxes.

    Args:
        lidar2image: (V, 4, 4) projection matrices, forwarded to ``projection``
            (only the top 3x4 block is used there).
        points: (N, P, 3), i.e. N groups of P points (e.g. box corners).
        H, W: image size; coordinates are clamped to [0, H-1] / [0, W-1].
        return_center: if True, also return the clamped box centres.
        convert_int: if True, cast coordinates and centres to int64 after clamping.

    Returns:
        coordss: (V, N, 4) as [min_v, min_u, max_v, max_u] (row/column order).
            Boxes with non-positive width or height after clamping are zeroed.
        center: (V, N, 2) as [center_u, center_v]; only when ``return_center``.
    """
    points_2d, mask_z, z = projection(points, lidar2image)

    # Points behind the camera get sentinels so they never win the min/max.
    BIG_POS, BIG_NEG = 1.0e10, -1.0e10
    valid_points_min = torch.where(
        mask_z.unsqueeze(-1),
        points_2d,
        torch.full_like(points_2d, BIG_POS, dtype=points.dtype),
    )
    valid_points_max = torch.where(
        mask_z.unsqueeze(-1),
        points_2d,
        torch.full_like(points_2d, BIG_NEG, dtype=points.dtype),
    )

    min_u = torch.min(valid_points_min[..., 0], dim=-1)[0]  # (V, N)
    max_u = torch.max(valid_points_max[..., 0], dim=-1)[0]  # (V, N)
    min_v = torch.min(valid_points_min[..., 1], dim=-1)[0]  # (V, N)
    max_v = torch.max(valid_points_max[..., 1], dim=-1)[0]  # (V, N)

    center_u = 0.5 * (min_u + max_u)
    center_v = 0.5 * (min_v + max_v)

    min_u = torch.clamp(min_u, 0, W - 1)
    max_u = torch.clamp(max_u, 0, W - 1)
    min_v = torch.clamp(min_v, 0, H - 1)
    max_v = torch.clamp(max_v, 0, H - 1)
    center_u = torch.clamp(center_u, 0, W - 1)
    center_v = torch.clamp(center_v, 0, H - 1)

    if convert_int:
        min_u, max_u = min_u.to(torch.int64), max_u.to(torch.int64)
        min_v, max_v = min_v.to(torch.int64), max_v.to(torch.int64)
        center_u, center_v = center_u.to(torch.int64), center_v.to(torch.int64)

    # Defined for both branches (previously missing when convert_int=True).
    center = torch.stack([center_u, center_v], dim=-1)

    coordss = torch.stack([min_v, min_u, max_v, max_u], dim=-1)  # (V, N, 4)

    # Degenerate boxes (e.g. fully off-image or fully behind the camera) are zeroed.
    valid_mask = (max_u > min_u) & (max_v > min_v)
    coordss[~valid_mask] = 0

    if return_center:
        return coordss, center
    return coordss


def get_frustum_corners_world(
    x_min, x_max, y_min, y_max, K, R, t, z_near=5, z_far=50.0
):
    """
    Back-project the bottom edge of a 2D box to two depths and map it to world.

    Only the two bottom image corners (x_max, y_max) and (x_min, y_max) are
    used (the top corners are intentionally omitted), so ``y_min`` is unused.

    Args:
        x_min, x_max, y_min, y_max: 2D box bounds in pixels.
        K: (3, 3) intrinsic matrix.
        R: (3, 3) matrix and t: (3,) vector, applied as ``R @ (xyz_cam - t)``.
            NOTE(review): confirm this matches the caller's extrinsic convention.
        z_near, z_far: camera-frame depths at which the corners are placed.

    Returns:
        (4, 3) tensor ordered [near bottom-right, near bottom-left,
        far bottom-right, far bottom-left], in K's floating dtype.
    """
    # Follow K's floating dtype so float64 intrinsics no longer fail in matmul.
    dtype = K.dtype if K.is_floating_point() else torch.float32
    corners_2d = torch.tensor(
        [
            [x_max, y_max],  # bottom-right
            [x_min, y_max],  # bottom-left
        ],
        dtype=dtype,
        device=K.device,
    )

    K_inv = torch.inverse(K)

    frustum_corners_world = []
    for z in (z_near, z_far):
        for corner in corners_2d:
            uv1 = torch.cat([corner, corner.new_ones(1)])  # (3,)
            xyz_cam = torch.matmul(K_inv, uv1) * z  # (3,)
            frustum_corners_world.append(torch.matmul(R, (xyz_cam - t)))

    return torch.stack(frustum_corners_world)  # (4, 3)


def sort_vertices_counterclockwise(vertices):
    """Order polygon vertices by polar angle about their centroid.

    The angle is ``atan2`` of the (y, x) offsets from the mean vertex, so the
    result runs counter-clockwise starting from the smallest angle (-pi).
    """
    centroid = vertices.mean(dim=0)
    dx = vertices[:, 0] - centroid[0]
    dy = vertices[:, 1] - centroid[1]
    order = torch.argsort(torch.atan2(dy, dx))
    return vertices[order]


def expand_polygon(vertices, expansion_factor=2.0):
    """Grow (positive) or shrink (negative) a polygon with a mitred shapely buffer.

    Args:
        vertices: (K, 2) polygon vertices.
        expansion_factor: buffer distance; negative values shrink the polygon.

    Returns:
        Exterior ring vertices of the buffered polygon, without the repeated
        closing point, in the input's dtype and device.
    """
    from shapely.geometry import Polygon

    # join_style=2 is shapely's mitre join, so corners stay sharp.
    buffered = Polygon(vertices.cpu().numpy()).buffer(expansion_factor, join_style=2)
    ring = buffered.exterior.coords[:-1]
    return torch.tensor(ring, dtype=vertices.dtype, device=vertices.device)


def sample_lidar_points_outside_holes(
    main_vertices,
    lidar,
    lidar2image,
    ground_masks,
    H,
    W,
    v,
    iou_mask,
    cv_boxes,
    angle_rad,
    expansion_factor=5,
    num_samples=10,
    max_attempts=10,
):
    """
    Sample up to ``num_samples`` boxes anchored on LiDAR points inside a region.

    Candidate centres come from LiDAR points whose xy lies inside
    ``main_vertices``. The polygon is sorted counter-clockwise, then shrunk
    inward by ``expansion_factor`` via a negative buffer. Candidates must also
    satisfy -3 < z < 3.

    Each attempt:
      - jitters candidate xy by U(-0.2, 0.2);
      - builds boxes with size ``cv_boxes[0, 3:6]`` and yaw ``angle_rad``;
      - projects them into the image.

    A box is kept when its projected 2D patch has ``iou_mask[v]`` sum / area
    below 0.1. Its corners 0-3 must also all project onto a non-zero pixel of
    ``ground_masks[v]``.

    Args:
        main_vertices: polygon vertices (xy) defining the sampling region.
        lidar: point tensor; columns 0-2 are used as xyz.
        lidar2image: projection matrices passed to ``projection_2d_box`` and
            ``project_points``. NOTE(review): only view/batch index 0 of their
            outputs is read, while masks are indexed by ``v``. Presumably
            lidar2image holds the single view ``v``; confirm with callers.
        ground_masks: indexed as ``[v, row, col]``.
        H, W: image height and width used for clamping.
        v: view index into ``ground_masks`` and ``iou_mask``.
        iou_mask: indexed as ``[v, row, col]``.
        cv_boxes: reference boxes. Row 0, columns 3:6 give the box size, and
            ``cv_boxes[0, 5] / 2`` is added to the candidate z (presumably half
            the box height; confirm).
        angle_rad: numpy array concatenated along dim 1 with a (1, 3) tensor,
            so it must be 2-D with one row (presumably shape (1, 1)).
        expansion_factor: inward buffer distance applied to the polygon.
        num_samples: number of boxes to return.
        max_attempts: maximum number of jitter/score rounds.

    Returns:
        (num_samples, 8, 3) tensor of accepted box corners, zero-padded when
        fewer boxes are found. If no LiDAR point survives the initial
        filtering, CPU float32 zeros are returned.
    """
    sorted_main_vertices = sort_vertices_counterclockwise(main_vertices)
    sorted_main_vertices = expand_polygon(sorted_main_vertices, -expansion_factor)
    main_trapezoid_path = Path(sorted_main_vertices.cpu().numpy())
    inside_main = main_trapezoid_path.contains_points(lidar[:, :2].cpu().numpy())
    lidar = lidar[inside_main]
    # Keep only points with -3 < z < 3.
    lidar = lidar[lidar[:, 2] < 3]
    lidar = lidar[lidar[:, 2] > -3]
    lidar_points = lidar[:, :3]
    permutation = torch.randperm(lidar_points.shape[0])
    lidar_points = lidar_points[permutation]

    # Order by xy radius plus Gaussian noise (std 20) so nearer points tend to
    # come first. NOTE(review): lidar_points is reshuffled at the end of every
    # attempt (below), so this ordering only affects the first attempt.
    radius = torch.sqrt(lidar_points[:, 0] ** 2 + lidar_points[:, 1] ** 2)
    radius_noisy = radius + torch.randn(radius.size(), device=radius.device) * 20
    idx = torch.argsort(radius_noisy)
    lidar_points = lidar_points[idx]
    if not len(lidar_points) == 0:
        attempts = 0
        worlds = []

        while len(worlds) < num_samples and attempts < max_attempts:
            # Jitter candidate xy by up to +/-0.2 and keep those still inside the polygon.
            rand_x = torch.empty(lidar_points.shape[0]).uniform_(-0.2, 0.2)
            rand_y = torch.empty(lidar_points.shape[0]).uniform_(-0.2, 0.2)
            points = lidar_points[:, :2] + torch.stack([rand_x, rand_y], dim=1).to(
                lidar_points.device
            )

            inside_main = main_trapezoid_path.contains_points(points.cpu().numpy())
            valid_points = points[inside_main]
            valid_zs = lidar_points[:, 2][inside_main] + cv_boxes[0, 5] / 2

            if valid_points.shape[0] == 0:
                attempts += 1
                continue

            # Generate full random box data: [x, y, z, cv_boxes[0, 3:6], angle_rad].
            cv_voxes = torch.cat(
                (
                    cv_boxes[0:1, 3:6],
                    torch.from_numpy(angle_rad).to(lidar_points.device),
                ),
                dim=1,
            ).float()
            cv_voxes = cv_voxes.repeat(valid_zs.shape[0], 1)
            random_box_data = torch.cat(
                (valid_points, valid_zs.view(-1, 1), cv_voxes), dim=1
            ).float()

            # 3D -> 2D projection
            random_box_corner = boxes_to_corners_3d(random_box_data)
            project_2d = projection_2d_box(
                lidar2image, random_box_corner, H, W, convert_int=True
            )

            # projection_2d_box columns are [min_v, min_u, max_v, max_u] (rows, cols).
            row_min = torch.clip(project_2d[0, :, 0], 0, H - 1)
            row_max = torch.clip(project_2d[0, :, 2], 0, H - 1)
            col_min = torch.clip(project_2d[0, :, 1], 0, W - 1)
            col_max = torch.clip(project_2d[0, :, 3], 0, W - 1)

            N = row_min.shape[0]
            scores = torch.empty(N, device=iou_mask.device)

            # Score = iou_mask sum over the patch divided by patch area. A zero-area
            # patch yields inf/nan, which fails the `< 0.1` test below, so it is rejected.
            for i in range(N):
                rmin, rmax = row_min[i].item(), row_max[i].item()
                cmin, cmax = col_min[i].item(), col_max[i].item()
                patch_iou = iou_mask[v, rmin:rmax, cmin:cmax]
                scores[i] = patch_iou.sum() / ((rmax - rmin) * (cmax - cmin))

            # Corners 0-3 are presumably the bottom face in boxes_to_corners_3d's
            # ordering (confirm); all four must land on the ground mask.
            bottom_corners = random_box_corner[:, [0, 1, 2, 3]]  # (N, 4, 3)
            points_3d = bottom_corners.reshape(1, -1, 3)
            projected_2d, _ = project_points(points_3d, lidar2image)  # (1, N*4, 2)
            projected_2d = projected_2d[0].view(-1, 4, 2)

            rows = torch.clamp(projected_2d[:, :, 1].long(), 0, H - 1)
            cols = torch.clamp(projected_2d[:, :, 0].long(), 0, W - 1)
            ground_vals = ground_masks[v, rows, cols]  # (N, 4)
            fully_on_ground = ground_vals.bool().all(dim=1)
            valid_idxs = torch.nonzero(
                (scores < 0.1) & fully_on_ground, as_tuple=False
            ).squeeze(1)

            if len(valid_idxs) > 0:
                selected_boxes = random_box_corner[valid_idxs]
                worlds.extend(selected_boxes[: (num_samples - len(worlds))])

            attempts += 1
            lidar_points = lidar_points[torch.randperm(lidar_points.size(0))]

        # Zero-pad to exactly num_samples boxes.
        num_missing = num_samples - len(worlds)
        if num_missing > 0:
            existing = (
                torch.stack(worlds) if len(worlds) > 0 else torch.empty((0, 8, 3))
            )
            padding = torch.zeros(
                (num_missing, 8, 3), dtype=existing.dtype, device=existing.device
            )
            return torch.cat([existing, padding], dim=0)
        return torch.stack(worlds[:num_samples])
    else:
        return torch.zeros((num_samples, 8, 3))


def sample_random_box(
    main_vertices,
    ground_masks,
    lidar2image,
    H,
    W,
    v,
    iou_mask,
    cv_boxes,
    angle_rad,
    expansion_factor=5,
    num_samples=10,
    max_attempts=10,
):
    """
    Sample up to ``num_samples`` boxes at uniformly random positions inside a region.

    Each attempt draws ``H * W // 2`` candidate centres with x, y ~ U(-47, 47)
    and z ~ U(-3, 1). Candidates are kept only if their xy lies inside
    ``main_vertices``, which is sorted counter-clockwise and then shrunk
    inward by ``expansion_factor`` via a negative buffer.

    Boxes get size ``cv_boxes[0, 3:6]`` and yaw ``angle_rad`` and are
    projected into the image. A box is accepted when its 2D patch has
    ``iou_mask[v]`` sum / area below 0.3. Its corners 0-3 must also all
    project onto a non-zero pixel of ``ground_masks[v]``.

    Args:
        main_vertices: polygon vertices (xy) defining the sampling region.
        ground_masks: indexed as ``[v, row, col]``; its device is used for candidates.
        lidar2image: projection matrices passed to ``projection_2d_box`` and
            ``project_points``. NOTE(review): only view/batch index 0 of their
            outputs is read while masks use index ``v``; confirm with callers.
        H, W: image height and width (also sets the candidate count).
        v: view index into ``ground_masks`` and ``iou_mask``.
        iou_mask: indexed as ``[v, row, col]``.
        cv_boxes: reference boxes. Row 0, columns 3:6 give the box size, and
            ``cv_boxes[0, 5] / 2`` is added to the candidate z (presumably half
            the height; confirm).
        angle_rad: numpy array concatenated along dim 1 with a (1, 3) tensor,
            so it must be 2-D with one row (presumably shape (1, 1)).
        expansion_factor: inward buffer distance applied to the polygon.
        num_samples: number of boxes to return.
        max_attempts: maximum number of sampling rounds.

    Returns:
        (num_samples, 8, 3) tensor of accepted box corners, zero-padded when
        fewer boxes are found.
    """
    sorted_main_vertices = sort_vertices_counterclockwise(main_vertices)
    sorted_main_vertices = expand_polygon(sorted_main_vertices, -expansion_factor)
    main_trapezoid_path = Path(sorted_main_vertices.cpu().numpy())

    def generate_random_xyz(N):
        # Uniform candidates in a fixed 94 x 94 xy square, z in [-3, 1) (numpy RNG).
        x = np.random.uniform(-47, 47, N)
        y = np.random.uniform(-47, 47, N)
        z = np.random.uniform(-3, 1, N)
        return np.stack([x, y, z], axis=1)

    attempts = 0
    worlds = []

    while len(worlds) < num_samples and attempts < max_attempts:
        lidar_points = generate_random_xyz(H * W // 2)
        inside_main = main_trapezoid_path.contains_points(lidar_points[:, :2])
        lidar_points = torch.from_numpy(lidar_points[inside_main]).to(
            ground_masks.device
        )

        if lidar_points.shape[0] == 0:
            attempts += 1
            continue

        permutation = torch.randperm(lidar_points.shape[0])
        lidar_points = lidar_points[permutation]

        valid_zs = lidar_points[:, 2] + cv_boxes[0, 5] / 2
        lidar_points_xy = lidar_points[:, :2]

        # Box rows: [x, y, z, cv_boxes[0, 3:6], angle_rad].
        target_boxes = torch.cat(
            (cv_boxes[0:1, 3:6], torch.from_numpy(angle_rad).to(lidar_points.device)),
            dim=1,
        ).float()
        target_boxes = target_boxes.repeat(valid_zs.shape[0], 1)
        random_box_data = torch.cat(
            (lidar_points_xy, valid_zs.view(-1, 1), target_boxes), dim=1
        ).float()

        # 3D -> 2D projection
        random_box_corner = boxes_to_corners_3d(random_box_data)
        project_2d = projection_2d_box(
            lidar2image, random_box_corner, H, W, convert_int=True
        )

        # projection_2d_box columns are [min_v, min_u, max_v, max_u] (rows, cols).
        row_min = torch.clip(project_2d[0, :, 0], 0, H - 1)
        row_max = torch.clip(project_2d[0, :, 2], 0, H - 1)
        col_min = torch.clip(project_2d[0, :, 1], 0, W - 1)
        col_max = torch.clip(project_2d[0, :, 3], 0, W - 1)

        N = row_min.shape[0]
        scores = torch.empty(N, device=iou_mask.device)

        # Score = iou_mask sum over the patch / patch area; zero-area patches give
        # inf/nan, which fails the `< 0.3` test below.
        for i in range(N):
            rmin, rmax = row_min[i].item(), row_max[i].item()
            cmin, cmax = col_min[i].item(), col_max[i].item()
            patch_iou = iou_mask[v, rmin:rmax, cmin:cmax]
            scores[i] = patch_iou.sum() / ((rmax - rmin) * (cmax - cmin))

        # Corners 0-3 are presumably the bottom face (confirm against boxes_to_corners_3d).
        bottom_corners = random_box_corner[:, [0, 1, 2, 3]]  # (N, 4, 3)
        points_3d = bottom_corners.reshape(1, -1, 3)
        projected_2d, _ = project_points(points_3d, lidar2image)  # (1, N*4, 2)
        projected_2d = projected_2d[0].view(-1, 4, 2)

        rows = torch.clamp(projected_2d[:, :, 1].long(), 0, H - 1)
        cols = torch.clamp(projected_2d[:, :, 0].long(), 0, W - 1)
        ground_vals = ground_masks[v, rows, cols]  # (N, 4)
        fully_on_ground = ground_vals.bool().all(dim=1)
        valid_idxs = torch.nonzero(
            (scores < 0.3) & fully_on_ground, as_tuple=False
        ).squeeze(1)

        if len(valid_idxs) > 0:
            selected_boxes = random_box_corner[valid_idxs]
            worlds.extend(selected_boxes[: (num_samples - len(worlds))])

        attempts += 1
        # NOTE(review): this shuffle is dead — lidar_points_xy is reassigned at the
        # next iteration — but it still consumes torch RNG state.
        lidar_points_xy = lidar_points_xy[torch.randperm(lidar_points_xy.size(0))]

    # Zero-pad to exactly num_samples boxes.
    num_missing = num_samples - len(worlds)
    if num_missing > 0:
        existing = torch.stack(worlds) if len(worlds) > 0 else torch.empty((0, 8, 3))
        padding = torch.zeros(
            (num_missing, 8, 3), dtype=existing.dtype, device=existing.device
        )
        return torch.cat([existing, padding], dim=0)
    return torch.stack(worlds[:num_samples])


def depth_to_pointmap(depth, mask, intrinsic, extrinsic):
    """
    Back-project per-view depth maps into a world-space point map.

    Each pixel (u, v) with depth d becomes ``extrinsic @ [d * K^-1 @ (u, v, 1), 1]``.
    Pixels where ``mask`` is zero/False are set to 0.

    Args:
        depth (torch.Tensor): [v, h, w]
        mask (torch.Tensor): [v, h, w]
        intrinsic (torch.Tensor): [v, 3, 3]
        extrinsic (torch.Tensor): [v, 4, 4]

    Returns:
        point_cloud_map (torch.Tensor): [v, h, w, 3] world xyz (no intensity channel)
    """
    v, h, w = depth.shape
    device = depth.device

    u = (
        torch.arange(w, device=device, dtype=torch.float32)
        .view(1, 1, w)
        .expand(v, h, w)
    )  # [v, h, w]
    v_grid = (
        torch.arange(h, device=device, dtype=torch.float32)
        .view(1, h, 1)
        .expand(v, h, w)
    )  # [v, h, w]
    ones = torch.ones_like(u)  # [v, h, w]

    pixel_coords = torch.stack([u, v_grid, ones], dim=1).reshape(v, 3, -1)  # [v, 3, N]

    K_inv = torch.inverse(intrinsic)  # [v, 3, 3]

    camera_coords = torch.matmul(K_inv, pixel_coords)  # [v, 3, N]
    camera_coords = camera_coords * depth.reshape(v, 1, -1)  # [v, 3, N]

    # Homogeneous row in the same dtype as the camera coordinates.
    ones_hom = torch.ones(
        (v, 1, camera_coords.shape[2]), dtype=camera_coords.dtype, device=device
    )  # [v, 1, N]
    camera_coords_hom = torch.cat([camera_coords, ones_hom], dim=1)  # [v, 4, N]

    world_coords = torch.matmul(extrinsic, camera_coords_hom)[:, :3, :]  # [v, 3, N]
    world_coords = (
        world_coords.reshape(v, 3, h, w).permute(0, 2, 3, 1).contiguous()
    )  # [v, h, w, 3]

    return world_coords * mask.unsqueeze(-1)  # [v, h, w, 3]


def points_to_normals(point: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Estimate a per-pixel normal map from a point map.

    Around every pixel, the four neighbour offsets (up, left, down, right) are
    crossed pairwise: up x left, left x down, down x right, right x up. Each
    face normal is normalised, faces whose two neighbours and centre are all
    valid are summed, and the sum is normalised again. Pixels outside the
    image are treated as invalid neighbours.

    Args:
        point (torch.Tensor): (B, H, W, 3) point map.
        mask (torch.Tensor, optional): (B, H, W) validity mask.

    Returns:
        normal (torch.Tensor): (B, H, W, 3) unit normals.
        normal_mask (torch.Tensor): (B, H, W), only returned when ``mask`` was
            given; pixels without any valid face have their normal zeroed.
    """
    batch, height, width, _ = point.shape
    dev = point.device
    caller_mask = mask is not None

    if not caller_mask:
        mask = torch.ones((batch, height, width), dtype=torch.bool, device=dev)

    # Pad by one pixel so every pixel has four (possibly invalid) neighbours.
    padded_mask = torch.zeros((batch, height + 2, width + 2), dtype=torch.bool, device=dev)
    padded_mask[:, 1:-1, 1:-1] = mask
    padded_pts = torch.zeros((batch, height + 2, width + 2, 3), dtype=point.dtype, device=dev)
    padded_pts[:, 1:-1, 1:-1, :] = point

    centre = padded_pts[:, 1:-1, 1:-1, :]
    to_up = padded_pts[:, :-2, 1:-1, :] - centre
    to_left = padded_pts[:, 1:-1, :-2, :] - centre
    to_down = padded_pts[:, 2:, 1:-1, :] - centre
    to_right = padded_pts[:, 1:-1, 2:, :] - centre

    ok_centre = padded_mask[:, 1:-1, 1:-1]
    ok_up = padded_mask[:, :-2, 1:-1]
    ok_left = padded_mask[:, 1:-1, :-2]
    ok_down = padded_mask[:, 2:, 1:-1]
    ok_right = padded_mask[:, 1:-1, 2:]

    faces = (
        (to_up, to_left, ok_up, ok_left),
        (to_left, to_down, ok_left, ok_down),
        (to_down, to_right, ok_down, ok_right),
        (to_right, to_up, ok_right, ok_up),
    )

    face_normals = torch.stack([torch.cross(a, b, dim=-1) for a, b, _, _ in faces], dim=1)
    face_normals = face_normals / (torch.norm(face_normals, dim=-1, keepdim=True) + 1e-12)

    face_ok = torch.stack([ma & mb for _, _, ma, mb in faces], dim=1) & ok_centre[:, None]

    normal = (face_normals * face_ok[..., None]).sum(dim=1)
    normal = normal / (torch.norm(normal, dim=-1, keepdim=True) + 1e-12)

    if not caller_mask:
        return normal

    normal_mask = face_ok.any(dim=1)
    normal = torch.where(
        normal_mask[..., None],
        normal,
        torch.tensor(0.0, device=dev, dtype=normal.dtype),
    )
    return normal, normal_mask


def lidar_xy_to_top_coords(
    x, y, res=0.1, side_range=(-52.0, 52 - 0.05), fwd_range=(-52.0, 52 - 0.05)
):
    """Map LiDAR x/y coordinates to integer top-view image pixel coordinates.

    Image column comes from -y and image row from -x (both divided by ``res``
    and truncated to int32), then shifted by the floored range offsets.

    Returns:
        Tuple ``(x_img, y_img)`` of int32 tensors (column, row).
    """
    col_shift = int(torch.floor(torch.tensor(side_range[0] / res)))
    row_shift = int(torch.floor(torch.tensor(fwd_range[1] / res)))

    col = (-y / res).to(torch.int32) - col_shift
    row = (-x / res).to(torch.int32) + row_shift
    return col, row


def point_cloud_2_top(
    points,
    res=0.1,
    zres=1.0,
    side_range=(-52.0, 52 - 0.05),
    fwd_range=(-52.0, 52 - 0.05),
    height_range=(-2.0, 1.0),
):
    """
    Rasterise a point cloud into a multi-channel bird's-eye-view image.

    Channels 0..z_max-1 store (z - height_range[0]) for points falling in each
    height slice of size ``zres``. Channel z_max stores reflectance (column 3).
    The grid is scaled so its maximum maps to 255 and is returned as uint8. If
    the grid maximum is not positive (e.g. no points in range), an all-zero
    image is returned instead of dividing by zero.

    Args:
        points: (N, >=4) tensor with columns x, y, z, reflectance.
        res: xy cell size.
        zres: height slice size.
        side_range, fwd_range, height_range: metric extents of the grid.

    Returns:
        uint8 tensor of shape (y_max + 1, x_max + 1, z_max + 1).
    """
    device = points.device

    x_points = points[:, 0]
    y_points = points[:, 1]
    z_points = points[:, 2]
    reflectance = points[:, 3]

    x_max = int((side_range[1] - side_range[0]) / res)
    y_max = int((fwd_range[1] - fwd_range[0]) / res)
    z_max = int((height_range[1] - height_range[0]) / zres)

    top = torch.zeros(
        (y_max + 1, x_max + 1, z_max + 1), dtype=torch.float32, device=device
    )

    f_filt = torch.logical_and(x_points > fwd_range[0], x_points < fwd_range[1])
    s_filt = torch.logical_and(y_points > -side_range[1], y_points < -side_range[0])
    filt = torch.logical_and(f_filt, s_filt)

    x_points, y_points, z_points, reflectance = (
        x_points[filt],
        y_points[filt],
        z_points[filt],
        reflectance[filt],
    )

    # Pixel offsets are loop-invariant; compute them once.
    x_offset = int(torch.floor(torch.tensor(side_range[0] / res)))
    y_offset = int(torch.floor(torch.tensor(fwd_range[1] / res)))

    height_bins = torch.arange(height_range[0], height_range[1], zres, device=device)

    for i, height in enumerate(height_bins):
        in_bin = torch.logical_and(z_points >= height, z_points < height + zres)

        xi_points = x_points[in_bin]
        yi_points = y_points[in_bin]
        zi_points = z_points[in_bin]
        ref_i = reflectance[in_bin]

        x_img = (-yi_points / res).to(torch.int32) - x_offset
        y_img = (-xi_points / res).to(torch.int32) + y_offset

        valid_mask = (x_img >= 0) & (x_img < x_max) & (y_img >= 0) & (y_img < y_max)
        # int64 indices work on every torch version (int32 indexing is newer).
        rows = y_img[valid_mask].long()
        cols = x_img[valid_mask].long()

        top[rows, cols, i] = zi_points[valid_mask] - height_range[0]
        top[rows, cols, z_max] = ref_i[valid_mask]

    peak = top.max()
    if peak <= 0:
        # Nothing to normalise (0/0 would give NaN, whose uint8 cast is undefined).
        return torch.zeros_like(top, dtype=torch.uint8)
    return (top / peak * 255).to(torch.uint8)


def scale_depth_map(depth_map, new_min, new_max):
    """
    Linearly rescale ``depth_map`` from its own [min, max] to [new_min, new_max].

    A constant map (max == min) has no range to stretch. It is mapped entirely
    to ``new_min`` instead of producing NaN from a 0/0 division.

    Args:
        depth_map: tensor or array supporting ``min()``, ``max()`` and arithmetic.
        new_min, new_max: target range.

    Returns:
        Rescaled map of the same shape.
    """
    old_min = depth_map.min()
    old_max = depth_map.max()
    span = old_max - old_min

    if span == 0:
        return depth_map * 0 + new_min

    scaled_depth = (depth_map - old_min) / span
    return scaled_depth * (new_max - new_min) + new_min


def get_filtered_values(values):
    """
    Drop outliers using the 1.5 * IQR rule.

    Args:
        values: 1-D floating tensor.

    Returns:
        The elements within [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR].
    """
    # torch.quantile requires q to have the same dtype as the input.
    q = torch.tensor([0.25, 0.75], dtype=values.dtype, device=values.device)
    q1, q3 = torch.quantile(values, q)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    return values[(values >= lower_bound) & (values <= upper_bound)]


def pad_points_list_to_tensor(points_xyz_list):
    """
    Stack variable-length point sets into one padded batch tensor.

    Each set is padded by repeating its last point. Empty sets stay all-zero,
    since they have no last point to repeat.

    Args:
        points_xyz_list: non-empty list of (N_i, D) Tensors (D is taken from the first).

    Returns:
        padded_tensor: (B, max_N, D) Tensor
        mask: (B, max_N) BoolTensor indicating valid points
    """
    B = len(points_xyz_list)
    device = points_xyz_list[0].device
    dtype = points_xyz_list[0].dtype
    max_len = max(p.shape[0] for p in points_xyz_list)
    feature_dim = points_xyz_list[0].shape[1]

    padded = torch.zeros((B, max_len, feature_dim), dtype=dtype, device=device)
    mask = torch.zeros((B, max_len), dtype=torch.bool, device=device)

    for i, p in enumerate(points_xyz_list):
        N = p.shape[0]
        padded[i, :N] = p
        mask[i, :N] = True
        if 0 < N < max_len:
            padded[i, N:] = p[-1]  # pad with last value

    return padded, mask


def compute_box(points_xyz_batch, heading=None):
    """
    Fit a z-axis-rotated 3D box around each batch of points.

    Heading convention: ``heading`` is the counter-clockwise yaw of the box's
    length_x axis, i.e. that axis points along (cos h, sin h). When
    ``heading`` is None it is estimated as the principal xy axis (PCA). The
    PCA direction is sign-ambiguous, so it is defined only modulo pi.

    Args:
        points_xyz_batch (Tensor): shape=(B, N, 3)
        heading: None, or float/int, np.ndarray, list, or torch.Tensor with B yaw values.

    Returns:
        boxes: (B, 7) torch.Tensor
            Each row: [cx, cy, cz, length_x, length_y, height, heading]

    Raises:
        TypeError: if ``heading`` has an unsupported type.
    """
    B, N, _ = points_xyz_batch.shape
    device = points_xyz_batch.device
    points_xyz_batch = points_xyz_batch.float()

    z_vals = points_xyz_batch[:, :, 2]
    z_min = z_vals.min(dim=1).values
    z_max = z_vals.max(dim=1).values
    height = z_max - z_min  # (B,)

    xy = points_xyz_batch[:, :, :2]  # (B, N, 2)
    mean_xy = xy.mean(dim=1, keepdim=True)  # (B, 1, 2)
    xy_centered = xy - mean_xy  # (B, N, 2)

    # `is None` (not `== None`): `ndarray == None` is elementwise and ambiguous in `if`.
    if heading is None:
        cov = torch.bmm(xy_centered.transpose(1, 2), xy_centered) / (N - 1)  # (B, 2, 2)
        # eigh returns eigenvalues in ascending order, with eigenvectors as
        # columns, so the principal axis is the last column.
        _, eigvecs = torch.linalg.eigh(cov)
        main_dir = eigvecs[:, :, -1]  # (B, 2)
        heading = torch.atan2(main_dir[:, 1], main_dir[:, 0])  # (B,)
    else:
        if isinstance(heading, (float, int)):
            heading = torch.tensor([heading] * B, device=device)
        elif isinstance(heading, np.ndarray):
            heading = torch.tensor(heading, device=device)
        elif isinstance(heading, list):
            heading = torch.tensor(heading, device=device)
        elif isinstance(heading, torch.Tensor):
            heading = heading.to(device)
        else:
            raise TypeError(
                "heading must be float, int, np.ndarray, list, or torch.Tensor"
            )
        heading = heading.view(-1).float()

    cos_h = torch.cos(heading).unsqueeze(1)  # (B, 1)
    sin_h = torch.sin(heading).unsqueeze(1)  # (B, 1)

    # Express points in the box frame, i.e. rotate by -heading.
    local_x = xy_centered[..., 0] * cos_h + xy_centered[..., 1] * sin_h
    local_y = -xy_centered[..., 0] * sin_h + xy_centered[..., 1] * cos_h
    xy_rot = torch.stack([local_x, local_y], dim=-1)  # (B, N, 2)

    min_xy_rot = xy_rot.min(dim=1).values  # (B, 2)
    max_xy_rot = xy_rot.max(dim=1).values  # (B, 2)
    box_size_2d = max_xy_rot - min_xy_rot  # (B, 2)

    center_xy_rot = 0.5 * (min_xy_rot + max_xy_rot)  # (B, 2)

    # Rotate the box-frame centre back by +heading and undo the centring.
    c, s = cos_h[:, 0], sin_h[:, 0]
    center_x = center_xy_rot[:, 0] * c - center_xy_rot[:, 1] * s
    center_y = center_xy_rot[:, 0] * s + center_xy_rot[:, 1] * c
    center_xy_abs = torch.stack([center_x, center_y], dim=1) + mean_xy.squeeze(1)  # (B, 2)

    center_z = 0.5 * (z_min + z_max)  # (B,)
    center_3d = torch.cat([center_xy_abs, center_z.unsqueeze(1)], dim=1)  # (B, 3)

    size_3d = torch.cat([box_size_2d, height.unsqueeze(1)], dim=1)  # (B, 3)

    boxes = torch.cat([center_3d, size_3d, heading.unsqueeze(1)], dim=1)  # (B, 7)
    return boxes


def update_bbox_dimension(bboxes, delta_d, dim="l"):
    """
    Grow (or shrink) one box dimension while keeping one face fixed.

    The centre moves by half of ``delta_d`` along the box axis that belongs
    to ``dim``. For ``'w'`` that axis is (cos h, sin h); for ``'l'`` it is
    (sin h, -cos h). The shift direction is flipped whenever it would bring the
    centre closer to the origin, so boxes always grow away from it.

    bboxes: (B, 7) tensor [x, y, z, w, l, h, heading]
    delta_d: scalar or (B,) tensor
    dim: 'w' or 'l'
    Returns:
        updated_bboxes: (B, 7)
    Raises:
        ValueError: if ``dim`` is neither 'w' nor 'l'.
    """
    count = bboxes.shape[0]
    x, y, z, w, l, h, heading = (bboxes[:, k] for k in range(7))

    if isinstance(delta_d, (float, int)):
        delta_d = torch.full(
            (count,), float(delta_d), device=bboxes.device, dtype=bboxes.dtype
        )

    if dim == "w":
        axis_x, axis_y = torch.cos(heading), torch.sin(heading)
        new_w, new_l = w + delta_d, l
    elif dim == "l":
        axis_x, axis_y = torch.sin(heading), -torch.cos(heading)
        new_w, new_l = w, l + delta_d
    else:
        raise ValueError("dim must be 'w' or 'l'")

    shift_x = 0.5 * delta_d * axis_x
    shift_y = 0.5 * delta_d * axis_y

    radius_now = torch.sqrt(x**2 + y**2)
    radius_after = torch.sqrt((x + shift_x) ** 2 + (y + shift_y) ** 2)
    toward_origin = radius_after < radius_now

    shift_x = torch.where(toward_origin, -shift_x, shift_x)
    shift_y = torch.where(toward_origin, -shift_y, shift_y)

    return torch.stack(
        [x + shift_x, y + shift_y, z, new_w, new_l, h, heading], dim=1
    )


def simulate_car_motion(
    boxes, sweep=10, step_forward=0.2, step_heading=math.radians(0.0)
):
    """
    Roll boxes forward along their heading for ``sweep`` time steps.

    Heading sequence per box: steps 0 and 1 use the initial heading h0, and
    step k >= 1 uses h0 + (k - 1) * step_heading. Position at step k is the
    initial position plus the sum of ``step_forward`` moves along the headings
    of steps 0..k-1, so step 0 is the input pose. z and sizes are unchanged.

    Args:
        boxes: (B, 7) tensor [x, y, z, w, l, h, heading]
        sweep: number of steps
        step_forward: forward distance per step
        step_heading: heading increment per step (in radians)

    Returns:
        all_boxes: (B, sweep, 7)
        all_corners: (B, sweep, 8, 3)
    """
    num_boxes = boxes.shape[0]
    like = dict(device=boxes.device, dtype=boxes.dtype)

    t = torch.arange(sweep, **like).view(1, sweep)
    h0 = boxes[:, 6].unsqueeze(1)  # (B, 1)

    turned = h0 + step_heading * t
    heading_seq = torch.cat([h0, turned[:, :-1]], dim=1)  # (B, T)

    # Displacements are delayed by one step so the first pose is the input pose.
    no_move = torch.zeros((num_boxes, 1), **like)
    step_x = torch.cat([no_move, (step_forward * torch.cos(heading_seq))[:, :-1]], dim=1)
    step_y = torch.cat([no_move, (step_forward * torch.sin(heading_seq))[:, :-1]], dim=1)

    x_track = boxes[:, 0].unsqueeze(1) + step_x.cumsum(dim=1)
    y_track = boxes[:, 1].unsqueeze(1) + step_y.cumsum(dim=1)
    z_track = boxes[:, 2].unsqueeze(1).expand(-1, sweep)

    w_seq, l_seq, h_seq = (boxes[:, k].unsqueeze(1).expand(-1, sweep) for k in (3, 4, 5))

    all_boxes = torch.stack(
        [x_track, y_track, z_track, w_seq, l_seq, h_seq, heading_seq], dim=-1
    )  # (B, T, 7)

    all_corners = boxes_to_corners_3d(all_boxes.view(-1, 7)).view(num_boxes, sweep, 8, 3)
    return all_boxes, all_corners


def move_points_along_with_box(points, old_box, new_box):
    """
    Rigidly carry points from ``old_box``'s pose to ``new_box``'s pose.

    Steps: translate by -old centre, rotate about z by -old heading, rotate by
    +new heading (via ``rotate_points_along_z``), then translate by the new
    centre. Box sizes are ignored.

    The input ``points`` tensor is no longer modified: the previous version
    subtracted the old centre from the caller's tensor in place.

    Args:
        points: (M, >=3) tensor; columns 0-2 are xyz.
        old_box, new_box: 7-element boxes [cx, cy, cz, dx, dy, dz, heading];
            heading must be a tensor (``.unsqueeze`` is called on it).

    Returns:
        A new (M, ...) tensor of moved points.
    """
    cx_old, cy_old, cz_old, dx_old, dy_old, dz_old, heading_old = old_box
    cx_new, cy_new, cz_new, dx_new, dy_new, dz_new, heading_new = new_box

    # Work on a copy so the caller's tensor is left untouched.
    points = points.clone()

    points[:, 0] -= cx_old
    points[:, 1] -= cy_old
    points[:, 2] -= cz_old

    points = rotate_points_along_z(
        points.unsqueeze(0), -heading_old.unsqueeze(0)
    ).squeeze(0)
    points = rotate_points_along_z(
        points.unsqueeze(0), heading_new.unsqueeze(0)
    ).squeeze(0)

    points[:, 0] += cx_new
    points[:, 1] += cy_new
    points[:, 2] += cz_new

    return points


def select_uniform_points(points, x_bin=3, y_bin=3, z_bin=3):
    """
    Pick roughly spatially uniform point indices via a coarse voxel grid.

    The xyz bounding box of ``points`` is split into x_bin * y_bin * z_bin
    cells. One index is kept per occupied cell: the first one in
    ``torch.argsort`` order of cell ids. If fewer than x_bin * y_bin * z_bin
    cells are occupied, the result is topped up with random indices (possibly
    repeated) drawn from the torch RNG.

    Returns:
        1-D long tensor of length >= x_bin * y_bin * z_bin (exactly that when
        some cells are empty, otherwise the number of occupied cells).
    """
    total = points.shape[0]
    lo = points.min(dim=0).values
    hi = points.max(dim=0).values + 1e-6  # open the upper edge so max lands inside

    def cell_along(axis, bins):
        frac = (points[:, axis] - lo[axis]) / (hi[axis] - lo[axis])
        return (frac * bins).long().clamp(max=bins - 1)

    cell = (
        cell_along(0, x_bin) * (y_bin * z_bin)
        + cell_along(1, y_bin) * z_bin
        + cell_along(2, z_bin)
    )  # (N,)

    order = torch.argsort(cell)
    ordered = cell[order]
    first_in_cell = torch.ones_like(ordered, dtype=torch.bool)
    first_in_cell[1:] = ordered[1:] != ordered[:-1]
    picked = order[first_in_cell]

    shortfall = x_bin * y_bin * z_bin - picked.shape[0]
    if shortfall > 0:
        extra = torch.randint(0, total, (shortfall,), device=points.device)
        picked = torch.cat([picked, extra], dim=0)

    return picked
