import torch
import sam_hq2.grounding_dino.groundingdino.datasets.transforms as T
from sam_hq2.builder.my_utils_torch import *


def is_invalid_frame(depth, mask):
    """Tell whether a frame must be discarded.

    A frame is invalid when ``depth`` contains any NaN or +/-Inf value, or
    when ``mask`` sums to zero. The checks run in that order and stop at the
    first one that fires.

    Returns:
        The first truthy check result or, if none fire, the result of
        ``mask.sum() == 0``. For tensor inputs this is a 0-d boolean tensor.
    """
    nan_present = torch.isnan(depth).any()
    if nan_present:
        return nan_present

    inf_present = torch.isinf(depth).any()
    if inf_present:
        return inf_present

    # Last check: an all-zero mask means there is nothing to use in this frame.
    return mask.sum() == 0


def load_tr_image(image_source, device):
    """Preprocess ``image_source`` with the detection transform pipeline.

    The pipeline is ``RandomResize([800], max_size=1333)`` -> ``ToTensor`` ->
    ``Normalize`` with the commonly used ImageNet channel mean/std. The size
    list has a single entry, so presumably the resize is deterministic
    (NOTE(review): confirm against the transforms module). The transform's
    second output (the target, fed ``None`` here) is discarded.

    Args:
        image_source: Input image accepted by the transforms; presumably a
            PIL image (TODO confirm against callers).
        device: Device the resulting tensor is moved to.

    Returns:
        The transformed image tensor on ``device``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize(imagenet_mean, imagenet_std),
        ]
    )
    tensor_image, _target = pipeline(image_source, None)
    return tensor_image.to(device)


# Ordered camera channel names. They match the nuScenes sensor channel
# naming — NOTE(review): confirm which dataset these refer to.
CAM = [
    "CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_FRONT_LEFT",
    "CAM_BACK", "CAM_BACK_LEFT", "CAM_BACK_RIGHT",
]
# Number of camera views, derived from CAM so the two cannot drift apart.
num_view = len(CAM)


def update_infos(
    w, h, valid_inds, K, lidar2camera, crop_x1, crop_x2, crop_y1, crop_y2, valid_boxes
):
    """Rescale camera geometry and 2D boxes for a crop followed by a resize.

    The window ``[crop_x1, crop_x2] x [crop_y1, crop_y2]`` of the original
    image is resized to ``w x h``. A pixel coordinate ``u`` therefore maps to
    ``(u - crop_x1) * w / (crop_x2 - crop_x1)``, and ``v`` maps the same way
    using the y-crop and ``h``. No half-pixel offset is applied.

    Args:
        w, h: Output image width and height.
        valid_inds: Index tensor or boolean mask selecting the views to keep
            from ``K`` and ``lidar2camera``.
        K: Per-view 3x3 intrinsics, presumably shaped ``(N, 3, 3)``.
        lidar2camera: Per-view 4x4 extrinsics, presumably ``(N, 4, 4)``. They
            are transposed before use, so they appear to be stored in the
            row-vector convention (translation in the last row) — TODO confirm.
        crop_x1, crop_x2, crop_y1, crop_y2: Crop window in original pixels.
        valid_boxes: Boxes whose first four columns are ``(x1, y1, x2, y2)``
            in original pixels. Any extra columns are dropped.

    Returns:
        Tuple ``(K_crop_resize, resized_lidar2image, resized_valid_boxes)``:
        the ``(n, 3, 3)`` updated intrinsics, the ``(n, 4, 4)`` lidar-to-image
        projection, and the ``(M, 4)`` boxes clamped to the output image
        (x to ``[0, w]``, y to ``[0, h]``).
    """
    f_x = K[valid_inds, 0, 0]
    f_y = K[valid_inds, 1, 1]
    c_x = K[valid_inds, 0, 2]
    c_y = K[valid_inds, 1, 2]

    s_x = w / (crop_x2 - crop_x1)
    s_y = h / (crop_y2 - crop_y1)

    # Build the constant entries with K's dtype/device so that no row
    # hard-codes float32.
    zeros = torch.zeros_like(f_x)
    ones = torch.ones_like(f_x)
    K_crop_resize = torch.stack(
        [
            torch.stack([f_x * s_x, zeros, (c_x - crop_x1) * s_x], dim=-1),
            torch.stack([zeros, f_y * s_y, (c_y - crop_y1) * s_y], dim=-1),
            torch.stack([zeros, zeros, ones], dim=-1),
        ],
        dim=1,
    )

    # Count the selected views from the gathered data rather than with
    # len(valid_inds), so that a boolean mask also works.
    num_selected = K_crop_resize.shape[0]
    # Use the extrinsics' dtype so the matmul below never mixes dtypes.
    resized_lidar2image = (
        torch.eye(4, dtype=lidar2camera.dtype, device=K.device)
        .unsqueeze(0)
        .repeat(num_selected, 1, 1)
    )
    resized_lidar2image[:, :3, :3] = K_crop_resize
    resized_lidar2image = resized_lidar2image @ lidar2camera[valid_inds].transpose(1, 2)

    # Clamp x to the output width and y to the output height. The old code
    # clamped every coordinate to w, which is wrong whenever w != h.
    x1 = ((valid_boxes[:, 0] - crop_x1) * s_x).clamp(0, w)
    y1 = ((valid_boxes[:, 1] - crop_y1) * s_y).clamp(0, h)
    x2 = ((valid_boxes[:, 2] - crop_x1) * s_x).clamp(0, w)
    y2 = ((valid_boxes[:, 3] - crop_y1) * s_y).clamp(0, h)
    resized_valid_boxes = torch.stack([x1, y1, x2, y2], dim=-1)

    return K_crop_resize, resized_lidar2image, resized_valid_boxes
