import os
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision.ops import box_convert

from tqdm import tqdm
import numpy as np
from PIL import Image
import random
import base64
from torchvision import transforms
import skimage
import pickle
import copy
import re

import warnings

warnings.filterwarnings("ignore")

from builder.builder import build_model
from builder.my_utils_torch import *
from builder.my_utils import *
from MoGe.moge.laserscan import LaserScan
import MoGe.utils3d as utils3d
from builder.util import *


class Config:
    """Run-time settings copied from the parsed command-line arguments.

    Every attribute except ``split`` and ``num_splits`` is read from the
    attribute of the same name on ``args``; those two are fixed to ``0`` and
    ``1`` respectively.
    """

    def __init__(self, args):
        # Dataset location and target object class.
        for name in ("basepath", "datapath", "target_class"):
            setattr(self, name, getattr(args, name))

        # Partitioning is not taken from ``args``: one split, index 0.
        self.split = 0
        self.num_splits = 1

        # Remaining options are forwarded unchanged.
        for name in (
            "resize_h",
            "resize_w",
            "H",
            "W",
            "no_sr",
            "no_gpt",
            "max_inpaint_attempts",
            "api_key",
            "sweeps",
            "num_samples",
            "rtol",
            "use_normal_mask",
            "erosion",
        ):
            setattr(self, name, getattr(args, name))


# NOTE(review): module-level value; nothing in the visible part of this file
# reads it (the `idx` names inside the classes below are parameters or loop
# locals). Presumably a default sample index used further down — confirm.
idx = 16


class DataLoader:
    """Access to the pickled per-sample info table and the raw LiDAR files.

    ``config`` must provide ``basepath`` (directory containing
    ``10sweep_infos.pkl``), ``datapath`` (prefix prepended to the LiDAR paths
    stored in the table) and ``split`` (index of the partition to process).
    """

    def __init__(self, config, device):
        self.config = config
        self.load_data()
        self.device = device

    def load_data(self):
        """Load ``<basepath>/10sweep_infos.pkl`` into ``self.infos``."""
        # NOTE: unpickling can execute arbitrary code; only point ``basepath``
        # at trusted data.
        with open(f"{self.config.basepath}/10sweep_infos.pkl", "rb") as f:
            # The freshly unpickled object is referenced by nothing else, so
            # no defensive deep copy is needed.
            self.infos = pickle.load(f)

    def get_data_indices(self, num_splits):
        """Return ``(start_idx, end_idx)`` of partition ``config.split``.

        The samples are divided into ``num_splits`` equal parts (integer
        division); the last partition also takes the remainder.
        """
        total = len(self.infos["lidarpath"])
        part = total // num_splits
        split = self.config.split
        start_idx = part * split
        end_idx = total if split == num_splits - 1 else part * (split + 1)
        return start_idx, end_idx

    def get_filtered_lidar(self, idx):
        """Read the LiDAR file of sample ``idx`` as a float tensor.

        The file is read as flat float32 values, reshaped to rows of five
        values, and only the first four columns of each row are kept.
        """
        path = self.config.datapath + self.infos["lidarpath"][idx]
        lidar = np.fromfile(path, dtype=np.float32).reshape([-1, 5])[:, :4]
        return torch.from_numpy(lidar).float().to(self.device)

    def _as_tensor(self, array):
        """Convert a NumPy array to a float tensor on ``self.device``."""
        return torch.from_numpy(array).float().to(self.device)

    def get_transforms(self, idx):
        """Collect the calibration and annotation entries of sample ``idx``.

        Array-valued entries are converted to float tensors on
        ``self.device``; ``img_path`` and ``gt_names`` are passed through
        unchanged. ``cam_map`` is a fixed list of six camera names.
        """
        infos = self.infos
        return {
            "idx": idx,
            "lidar2camera": self._as_tensor(infos["lidar2camera"][idx]),
            "lidar2image": self._as_tensor(infos["lidar2image"][idx]),
            "camera2lidar": self._as_tensor(infos["camera2lidar"][idx]),
            "camera_intrinsic": self._as_tensor(infos["camera_intrinsic"][idx]),
            "img_path": infos["imgpath"][idx],
            "gt_boxes": self._as_tensor(infos["gt_boxes"][idx]),
            "gt_names": infos["gt_names"][idx],
            "tr": self._as_tensor(infos["sweep_tr"][idx]),
            "time": self._as_tensor(infos["sweep_time"][idx]),
            "cam_map": [
                "CAM_FRONT",
                "CAM_FRONT_RIGHT",
                "CAM_FRONT_LEFT",
                "CAM_BACK",
                "CAM_BACK_LEFT",
                "CAM_BACK_RIGHT",
            ],
        }


class BoxProcessor:
    def __init__(self, config, device, num_view, val_preprocess):
        """Store the shared configuration, device, view count and preprocessor."""
        self.config, self.device = config, device
        self.num_view, self.val_preprocess = num_view, val_preprocess

    def get_random_box(
        self, data_dict, depths, ground_masks, iou_mask, ref_box, num_samples=5
    ):
        """Sample candidate boxes on the ground region of every camera view.

        For each view a heading is drawn at random from the filtered
        ground-truth yaw values (column 6 of ``gt_boxes``), the ground points
        of that view are selected, and ``sample_lidar_points_outside_holes``
        is asked for ``num_samples`` boxes inside the view frustum (depth
        range 5 to 48).

        Returns:
            ``(samples, valid_inds)``: the sampler output of every view whose
            result has a non-zero sum, and the indices of those views.
        """
        ground_lidar = depth_to_pointmap(
            depths,
            ground_masks,
            data_dict["camera_intrinsic"][:, :3, :3],
            data_dict["camera2lidar"],
        )

        samples = []
        valid_inds = []
        for view in range(self.num_view):
            # Random heading taken from the existing annotations.
            yaw_pool = get_filtered_values(data_dict["gt_boxes"][:, 6])
            pick = torch.randperm(len(yaw_pool))[0]
            angle_rad = yaw_pool[pick].cpu().numpy().reshape(1, 1)

            # Points of this view that lie on the ground mask.
            view_lidar = ground_lidar[view][ground_masks[view] == 1]

            lidar2camera = data_dict["lidar2camera"][view]
            frustum_corners = get_frustum_corners_world(
                x_min=0,
                x_max=self.config.W,
                y_min=0,
                y_max=self.config.H,
                K=data_dict["camera_intrinsic"][view, :3, :3],
                R=lidar2camera[:3, :3],
                t=lidar2camera[:3, 3],
                z_near=5.0,
                z_far=48.0,
            )
            candidate = sample_lidar_points_outside_holes(
                frustum_corners[:, :2],
                view_lidar,
                data_dict["lidar2image"][view : view + 1],
                ground_masks,
                self.config.H,
                self.config.W,
                view,
                iou_mask,
                ref_box,
                angle_rad,
                num_samples=num_samples,
                max_attempts=5,
            )
            # An all-zero result is treated as "nothing sampled" for the view.
            if candidate.sum() != 0:
                samples.append(candidate)
                valid_inds.append(view)
        return samples, valid_inds

    def get_iou_mask(self, image_processor, imgs, data_dict):
        iou_mask = torch.zeros(
            (self.num_view, self.config.H, self.config.W),
            dtype=torch.float32,
            device=self.device,
        )

        if not len(data_dict["gt_boxes"]) == 0:
            gt_corners = boxes_to_corners_3d(data_dict["gt_boxes"])
            gt_project_2d = projection_2d_box(
                data_dict["lidar2image"],
                gt_corners,
                self.config.H,
                self.config.W,
                convert_int=True,
            )

            x_min = gt_project_2d[:, :, 0].long().clamp(0, self.config.H - 1)
            y_min = gt_project_2d[:, :, 1].long().clamp(0, self.config.W - 1)
            x_max = gt_project_2d[:, :, 2].long().clamp(0, self.config.H - 1)
            y_max = gt_project_2d[:, :, 3].long().clamp(0, self.config.W - 1)

            boxes_b = torch.stack([y_min, x_min, y_max, x_max], dim=2)

            boxes_batch = [
                boxes_b[i][torch.any(boxes_b[i] != 0, dim=1)].float().cpu().numpy()
                for i in range(len(data_dict["cam_map"]))
            ]
            img_batch = [
                imgs[i].permute(1, 2, 0).float().cpu().numpy() / 255.0
                for i in range(len(data_dict["cam_map"]))
            ]

            box_idx = [
                i for i in range(len(boxes_batch)) if boxes_batch[i].shape[0] != 0
            ]
            boxes_batch = [boxes_batch[i] for i in box_idx]
            img_batch = [img_batch[i] for i in box_idx]

            mask = image_processor.get_object_mask(
                img_batch, boxes_batch, erosion=1
            ).float()
            iou_mask[box_idx] = mask
        return iou_mask

    def get_box(self, iou_mask, data_dict, depths, ground_masks, ref_box, num_samples):
        """Sample boxes and rasterise their 2-D footprints as inpainting masks.

        Returns:
            ``(box_corners, valid_inds, valid_boxes, inpaint_mask, iou_mask)``
            or five ``None`` values when no view produced a sample.

            * ``box_corners``: the per-view samples concatenated along dim 0.
            * ``valid_inds``: view index of every sampled box.
            * ``valid_boxes``: clamped ``(col_min, row_min, col_max, row_max)``
              rectangle of every sampled box.
            * ``inpaint_mask``: uint8 tensor ``(num_boxes_total, 3, H, W)``
              holding 255 inside each rectangle except where ``iou_mask`` is
              set.
            * ``iou_mask``: the view mask repeated once per sampled box.
        """
        box_corners, valid_inds = self.get_random_box(
            data_dict, depths, ground_masks, iou_mask, ref_box, num_samples
        )
        if len(box_corners) == 0:
            return None, None, None, None, None

        H, W = self.config.H, self.config.W
        # NOTE(review): assumes every view returned the same number of boxes.
        num_boxes = box_corners[0].shape[0]

        inpaint_mask = torch.zeros(
            (self.num_view, num_boxes, H, W),
            dtype=torch.float32,
            device=self.device,
        )
        r_valid_inds = []
        r_valid_boxes = []
        for view, corners in zip(valid_inds, box_corners):
            rect = projection_2d_box(
                data_dict["lidar2image"][view : view + 1], corners, H, W
            )[0]
            row_min = rect[:, 0].long().clamp(0, H - 1)
            col_min = rect[:, 1].long().clamp(0, W - 1)
            row_max = rect[:, 2].long().clamp(0, H - 1)
            col_max = rect[:, 3].long().clamp(0, W - 1)
            valid_boxes = torch.stack(
                [col_min, row_min, col_max, row_max], dim=1
            )  # (K, 4)

            # Pixels already covered by existing objects are excluded.
            free_area = 1 - iou_mask[view]
            for n in range(num_boxes):
                inpaint_mask[
                    view, n, row_min[n] : row_max[n], col_min[n] : col_max[n]
                ] = 1
                inpaint_mask[view, n] = inpaint_mask[view, n] * free_area
                r_valid_inds.append(view)
                r_valid_boxes.append(valid_boxes[n])

        r_valid_boxes = torch.stack(r_valid_boxes)
        r_valid_inds = torch.tensor(r_valid_inds, device=self.device)

        # One 3-channel 0/255 mask per sampled box, valid views only.
        per_box = inpaint_mask[valid_inds].view(-1, 1, H, W) * 255
        inpaint_mask = per_box.repeat(1, 3, 1, 1).to(torch.uint8)

        # Duplicate each valid view's occupancy mask once per box.
        iou_mask = (
            iou_mask.view(self.num_view, -1, H, W)[valid_inds]
            .repeat(1, num_boxes, 1, 1)
            .view(-1, H, W)
        )
        box_corners = torch.cat(box_corners, dim=0)
        return box_corners, r_valid_inds, r_valid_boxes, inpaint_mask, iou_mask

    def get_box_orient(self, pcs_array):
        assert pcs_array.ndim == 2 and pcs_array.shape[1] == 3, "Input must be (N, 3)"
        # Height estimation
        z_vals = pcs_array[:, 2]
        z_min = np.min(z_vals)
        z_max = np.max(z_vals)

        # Center xy around mean
        xy = pcs_array[:, :2]
        mean_xy = np.mean(xy, axis=0, keepdims=True)
        xy_centered = xy - mean_xy

        # Covariance and eigen decomposition
        cov = np.cov(xy_centered.T)  # shape (2, 2)
        eigvals, eigvecs = np.linalg.eigh(cov)

        # Sort eigenvectors by eigenvalue descending
        idx_desc = np.argsort(eigvals)[::-1]
        main_dir = eigvecs[:, idx_desc[0]]  # principal direction

        # Heading from principal direction
        yaw = np.arctan2(main_dir[1], main_dir[0])
        return yaw

    def get_fit_box(self, pcs_tensor, yaw):
        R = torch.tensor(
            [[torch.cos(-yaw), -torch.sin(-yaw)], [torch.sin(-yaw), torch.cos(-yaw)]],
            device=pcs_tensor.device,
        )
        points_xy = pcs_tensor[:, :2] @ R.T
        min_xy = points_xy.min(dim=0)[0]
        max_xy = points_xy.max(dim=0)[0]
        dx_tight, dy_tight = max_xy - min_xy
        height = pcs_tensor[:, 2].max() - pcs_tensor[:, 2].min()
        # center in rotated frame → unrotate
        center_xy_rot = (min_xy + max_xy) / 2
        center_xy = center_xy_rot @ R
        center_z = height / 2 + pcs_tensor[:, 2].min()
        box = torch.tensor(
            [center_xy[0], center_xy[1], center_z, dx_tight, dy_tight, height, yaw],
            dtype=torch.float32,
        ).to(self.device)
        return box.reshape(1, -1)

    def get_temporal_box(
        self, pcs_list, boxes, c, resized_lidar2image, x_bin=3, y_bin=3, z_bin=3
    ):
        if c == "cv" or c == "construction":
            step_forward = np.random.uniform(0.05, 0.1)

        t_boxes, _ = simulate_car_motion(
            boxes,
            sweep=self.config.sweeps,
            step_forward=step_forward,
            step_heading=math.radians(0),
        )
        selected = torch.stack(
            [select_uniform_points(pcs, x_bin, y_bin, z_bin) for pcs in pcs_list]
        )
        selected_points = torch.stack(
            [pcs_list[i][selected[i]][:, :3] for i in range(len(pcs_list))]
        )

        center_0 = t_boxes[:, 0][:, :3]  # (B, 3)
        t_centers = t_boxes[:, :, :3]  # (B, T, 3)
        delta = t_centers - center_0.unsqueeze(1)  # (B, T, 3)
        moved_points = selected_points.unsqueeze(1) + delta.unsqueeze(2)  # (B, T, K, 3)

        tracking_points = []
        for b in range(moved_points.shape[0]):
            moved_points2d, _ = project_points(
                moved_points[b],
                resized_lidar2image[b : b + 1].repeat(self.config.sweeps, 1, 1),
            )
            moved_points2d[..., 0] = moved_points2d[..., 0].clip(
                0, self.config.resize_w - 1
            )
            moved_points2d[..., 1] = moved_points2d[..., 1].clip(
                0, self.config.resize_h - 1
            )

            points = []
            for i in range(x_bin * y_bin * z_bin):
                points.append(
                    [
                        (moved_points2d[j, i, 0].item(), moved_points2d[j, i, 1].item())
                        for j in range(self.config.sweeps)
                    ]
                )
            tracking_points.append(points)
        return tracking_points, moved_points


class ImageProcessor:
    def __init__(
        self, config, device, inpainter, invsr, imagesam, grounding_model, moge
    ):
        """Store the model handles and build the fixed inpainting prompts."""
        self.config = config
        self.device = device

        # Model handles used by the methods of this class.
        self.inpainter = inpainter
        self.invsr = invsr
        self.imagesam = imagesam
        self.grounding_model = grounding_model
        self.moge = moge
        self.totensor = transforms.ToTensor()

        # Both prefixes are currently empty; note that the positive prompts
        # keep a leading space while the negative ones do not.
        pos_prefix, neg_prefix = "", ""
        self.promptA = [f"{pos_prefix} P_obj"]
        self.promptB = [f"{pos_prefix} P_obj"]
        self.negative_promptA = [f"{neg_prefix}P_obj"]
        self.negative_promptB = [f"{neg_prefix}P_obj"]
        self.negative_prompt = ["low quality"]

    @torch.no_grad()
    def get_road(self, data_dict, caption="road."):
        """Load every camera image and estimate its depth and ground mask.

        For each view the image is read from ``config.datapath`` +
        ``img_path``, a depth map is inferred with ``self.moge``, and boxes
        for ``caption`` are predicted by the grounding model (kept when the
        maximum sigmoid logit exceeds 0.3) and passed to ``self.imagesam`` as
        box prompts. Views whose mean pixel value is below 30 are flagged as
        night and, like views without any box, get an all-zero ground mask.

        Returns:
            ``(imgs, depths, ground_masks, nights)``: concatenated
            channel-first images, concatenated depth maps, stacked ground
            masks (multiple box masks are merged with a logical OR), and the
            per-view night flags.
        """
        num_views = len(data_dict["cam_map"])
        nights = [False] * num_views
        imgs = []
        depths = []
        ground_mask = []

        for v in range(num_views):
            pil_img = Image.open(
                self.config.datapath + data_dict["img_path"][v]
            ).convert("RGB")
            img = np.array(pil_img)
            # Very dark frames are treated as night.
            if img.mean() < 30:
                nights[v] = True

            img_chw = torch.from_numpy(img).permute(2, 0, 1)
            imgs.append(img_chw.to(self.device).unsqueeze(0))

            depth_input = self.totensor(pil_img).unsqueeze(0).to(self.device)
            depths.append(self.moge.infer(depth_input)["depth"])

            # Text-prompted box detection.
            ground_input = load_tr_image(pil_img, self.device)
            outputs = self.grounding_model(ground_input[None], captions=[caption])
            prediction_boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
            prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
            keep = prediction_logits.max(dim=1)[0] > 0.3

            # Scale the kept boxes by the image size and convert
            # centre/size form to corner form.
            boxes = prediction_boxes[keep] * torch.Tensor(
                [self.config.W, self.config.H, self.config.W, self.config.H]
            )
            input_boxes = box_convert(
                boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy"
            ).numpy()

            if (input_boxes.shape[0] == 0) or nights[v]:
                masks = np.zeros((1, self.config.H, self.config.W), dtype=np.float32)
            else:
                self.imagesam.set_image(img)
                masks, _, _ = self.imagesam.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_boxes,
                    multimask_output=False,
                )

            rank = len(masks.shape)
            if rank == 3:
                ground_mask.append(torch.from_numpy(masks[0]))
            elif rank == 4:
                # Several boxes: union of their first masks.
                combined_mask = np.any(masks[:, 0, :, :], axis=0).astype(masks.dtype)
                ground_mask.append(torch.from_numpy(combined_mask))

        depths = torch.cat(depths).to(self.device)
        imgs = torch.cat(imgs).to(self.device)
        ground_masks = torch.stack(ground_mask).to(self.device)
        return imgs, depths, ground_masks, nights

    @torch.no_grad()
    def prepare_inpaint_inputs(
        self,
        inpaint_mask,
        iou_mask,
        valid_inds,
        valid_imgs,
        valid_scale_depths,
        valid_lidar_depths,
        crop_x1,
        crop_x2,
        crop_y1,
        crop_y2,
    ):
        cropped_imgs = torch.zeros(
            (
                len(valid_inds),
                valid_imgs.shape[1],
                self.config.resize_h,
                self.config.resize_w,
            ),
            device=self.device,
        )
        cropped_masks = torch.zeros(
            (
                len(valid_inds),
                inpaint_mask.shape[1],
                self.config.resize_h,
                self.config.resize_w,
            ),
            device=self.device,
        )
        cropped_iou_masks = torch.zeros(
            (len(valid_inds), self.config.resize_h, self.config.resize_w),
            device=self.device,
        )
        cropped_depths = torch.zeros(
            (len(valid_inds), self.config.resize_h, self.config.resize_w),
            device=self.device,
        )
        cropped_lidar_depths = torch.zeros(
            (len(valid_inds), self.config.resize_h, self.config.resize_w),
            device=self.device,
        )

        for i in range(len(valid_inds)):
            cropped_img = (
                valid_imgs[
                    i, :, crop_y1[i] : crop_y2[i], crop_x1[i] : crop_x2[i]
                ].float()
                / 255.0
            )
            cropped_mask = inpaint_mask[
                i, :, crop_y1[i] : crop_y2[i], crop_x1[i] : crop_x2[i]
            ]
            cropped_iou_mask = iou_mask[
                i, crop_y1[i] : crop_y2[i], crop_x1[i] : crop_x2[i]
            ]
            cropped_depth = valid_scale_depths[
                i, crop_y1[i] : crop_y2[i], crop_x1[i] : crop_x2[i]
            ]
            cropped_lidar_depth = valid_lidar_depths[
                i, crop_y1[i] : crop_y2[i], crop_x1[i] : crop_x2[i]
            ]

            if self.config.no_sr and cropped_img.shape[1] < self.config.resize_h:
                cropped_img = self.invsr.sample_func(cropped_img.unsqueeze(0)).squeeze(
                    0
                )

            cropped_imgs[i] = F.interpolate(
                cropped_img.unsqueeze(0),
                size=(self.config.resize_h, self.config.resize_w),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
            cropped_masks[i] = (
                F.interpolate(
                    cropped_mask.unsqueeze(0).float(),
                    size=(self.config.resize_h, self.config.resize_w),
                    mode="nearest",
                ).squeeze(0)
                / 255.0
            )
            cropped_iou_masks[i] = F.interpolate(
                cropped_iou_mask.unsqueeze(0).unsqueeze(0),
                size=(self.config.resize_h, self.config.resize_w),
                mode="nearest",
            ).squeeze()
            cropped_depths[i] = F.interpolate(
                cropped_depth.unsqueeze(0).unsqueeze(0),
                size=(self.config.resize_h, self.config.resize_w),
                mode="nearest",
            ).squeeze()
            cropped_lidar_depths[i] = F.interpolate(
                cropped_lidar_depth.unsqueeze(0).unsqueeze(0),
                size=(self.config.resize_h, self.config.resize_w),
                mode="nearest",
            ).squeeze()

        input_dict = {"image": cropped_imgs, "mask": cropped_masks}
        input_dict["image"] = input_dict["image"] * (1 - input_dict["mask"])

        return (
            input_dict,
            cropped_iou_masks,
            cropped_depths,
            cropped_lidar_depths,
            cropped_imgs,
        )

    @torch.no_grad()
    def attempt_inpaint(self, input_dict, gpt_validator, object_class, prompts):
        attempts = 0
        all_imgs = input_dict["image"]
        all_masks = input_dict["mask"]
        total = len(all_imgs)

        remaining_indices = list(range(total))
        final_results = [None] * total

        while attempts < self.config.max_inpaint_attempts and remaining_indices:
            current_imgs = torch.stack([all_imgs[i] for i in remaining_indices])
            current_masks = torch.stack([all_masks[i] for i in remaining_indices])
            current_prompts = [prompts[i] for i in remaining_indices]

            seed = np.random.randint(0, 2**32 - 1)
            inpaint_results = self.inpainter.pipe(
                promptA=self.promptA * len(current_imgs),
                promptB=self.promptB * len(current_imgs),
                promptU=current_prompts,
                image=current_imgs,
                mask=current_masks,
                num_inference_steps=45,
                generator=torch.Generator(self.device).manual_seed(seed),
                brushnet_conditioning_scale=1.0,
                # tradoff=0.5,
                negative_promptA=self.negative_promptA * len(current_imgs),
                negative_promptB=self.negative_promptB * len(current_imgs),
                output_type="pt",
            ).images

            if self.config.no_gpt:
                validity = gpt_validator.validate_all(inpaint_results, object_class)
                new_remaining_indices = []
                for idx_in_batch, is_valid in enumerate(validity):
                    original_idx = remaining_indices[idx_in_batch]
                    if is_valid:
                        final_results[original_idx] = inpaint_results[idx_in_batch]
                    else:
                        new_remaining_indices.append(original_idx)

                remaining_indices = new_remaining_indices
            else:
                for i, idx in enumerate(remaining_indices):
                    final_results[idx] = inpaint_results[i]
                break

            attempts += 1

        return final_results

    @torch.no_grad()
    def get_object_mask(self, img_batch, boxes_batch, erosion=3):
        """Segment the boxed objects of every image and erode the result.

        Args:
            img_batch: list of images passed to ``imagesam.set_image_batch``.
            boxes_batch: list of per-image box arrays used as prompts.
            erosion: radius of the disk footprint used for binary erosion.

        Returns:
            Tensor on ``self.device`` with one eroded boolean mask per image.
            When an image has several boxes, their masks are summed first, so
            any pixel covered by at least one object is kept.

        Raises:
            ValueError: if the predictor returns an output that is neither
                3-D nor 4-D for some image.
        """
        self.imagesam.set_image_batch(img_batch)
        sam_output, _, _ = self.imagesam.predict_batch(
            None, None, box_batch=boxes_batch, multimask_output=False
        )

        # The footprint only depends on ``erosion``; build it once.
        footprint = skimage.morphology.disk(erosion)

        sam_output_masks = []
        for i in range(len(img_batch)):
            output = sam_output[i]
            rank = len(output.shape)
            if rank == 3:
                plane = output[0]
            elif rank == 4:
                plane = output.sum(axis=0)[0]
            else:
                # Previously this case silently reused the previous image's
                # mask (or hit a NameError on the first image).
                raise ValueError(
                    f"Unexpected mask shape {output.shape} for image {i}"
                )
            sam_output_masks.append(
                torch.from_numpy(
                    skimage.morphology.binary_erosion(plane, footprint=footprint)
                )
            )

        return torch.stack(sam_output_masks).to(self.device)

    def save_inpaint_results(
        self,
        inpaint_results,
        valid_imgs,
        valid_boxes,
        crop_y1,
        crop_y2,
        crop_x1,
        crop_x2,
        valid_inds,
        object_mask,
        lidar_paths,
        output_dir,
        cam_map,
    ):
        """Paste the inpainted boxes into the full images and write JPEGs.

        Each inpainted crop is resized back to its crop-window size and the
        part that falls inside the sample's box ``(x1, y1, x2, y2)`` replaces
        the same region in a copy of ``valid_imgs``. For every sample two
        files are written below ``output_dir``:
        ``OBJECT_<camera>/<file_name>/<c>/0.jpg`` (the inpainted crop) and
        ``OBJECT_<camera>/<file_name>/<c>_recon/0.jpg`` (the full image),
        where ``<file_name>`` is the stem of ``lidar_paths`` and ``<c>``
        counts the samples of the same view. ``object_mask`` is not used.

        Returns:
            The reconstructed full-size images.
        """
        reconstructed_imgs = valid_imgs.clone()
        for v in range(valid_imgs.shape[0]):
            top, left = crop_y1[v], crop_x1[v]
            resized_back_img = F.interpolate(
                inpaint_results[v].unsqueeze(0),
                size=(crop_y2[v] - top, crop_x2[v] - left),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)

            # Box coordinates are in full-image pixels; shift them into the
            # crop frame to cut out the matching patch.
            x1, y1, x2, y2 = valid_boxes[v].tolist()
            patch = resized_back_img[:, y1 - top : y2 - top, x1 - left : x2 - left]
            reconstructed_imgs[v, :, y1:y2, x1:x2] = (patch * 255).to(
                reconstructed_imgs.dtype
            )

        def write_jpeg(hwc_array, object_dir):
            os.makedirs(object_dir, exist_ok=True)
            Image.fromarray(hwc_array.astype("uint8")).save(
                os.path.join(object_dir, "0.jpg")
            )

        file_name = lidar_paths.split("/")[-1].split(".")[0]
        unique_vals, _ = valid_inds.unique(return_counts=True)
        # Running sample counter per view.
        v_to_counter = dict.fromkeys((u.item() for u in unique_vals), 0)

        for i, v in enumerate(valid_inds):
            v_int = v.item()
            c = v_to_counter[v_int]
            v_to_counter[v_int] += 1

            write_jpeg(
                inpaint_results[i].permute(1, 2, 0).cpu().numpy() * 255,
                os.path.join(output_dir, f"OBJECT_{cam_map[v_int]}/{file_name}/{c}"),
            )
            write_jpeg(
                reconstructed_imgs[i].permute(1, 2, 0).cpu().numpy(),
                os.path.join(
                    output_dir, f"OBJECT_{cam_map[v_int]}/{file_name}/{c}_recon"
                ),
            )
        return reconstructed_imgs


class GPTValidator:
    """Checks inpainted images with a vision-language chat model."""

    def __init__(self, config, gpt):
        self.config = config
        self.gpt = gpt

    def validate_all(self, images, object_class):
        """Ask the chat model whether each image shows one well-scaled object.

        Args:
            images: iterable of channel-first image tensors with values in
                ``[0, 1]`` (they are scaled by 255 before JPEG encoding).
            object_class: class name inserted into the question.

        Returns:
            A list of booleans, one per image; ``True`` only when the reply,
            stripped of non-letters and lower-cased, is exactly ``"yes"``.
        """
        results = []
        for img in images:
            rgb = (img.permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
            # OpenCV encodes from BGR channel order.
            img_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
            _, buffer = cv2.imencode(".jpg", img_bgr)
            base64_image = base64.b64encode(buffer).decode("utf-8")

            response = self.gpt.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"Does the image contain only a single instance of the {object_class}, and is it appropriately scaled relative to the scene? Answer with only yes or no.",
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    # image/jpeg is the registered MIME type.
                                    "url": f"data:image/jpeg;base64,{base64_image}",
                                    "detail": "low",
                                },
                            },
                        ],
                    }
                ],
                # Sampling temperature is a request-level parameter; it was
                # previously placed inside the message dict, where it is not
                # a sampling setting.
                temperature=0.2,
                max_tokens=10,
            )
            # ``content`` may be None (e.g. a refusal); treat it as "not yes".
            content = response.choices[0].message.content or ""
            answer = re.sub(r"[^a-zA-Z]", "", content).lower()
            results.append(answer == "yes")
        return results


class VideoProcessor:
    def __init__(
        self, config, i2v, videosam, video_depth_anything, depth_anything, device
    ):
        """Store the configuration, the device and the video model handles."""
        self.config, self.device = config, device
        self.i2v, self.videosam = i2v, videosam
        self.video_depth_anything = video_depth_anything
        self.depth_anything = depth_anything

    @torch.no_grad()
    def apply_i2v_tracking(self, tracking_points, temporal_mask, first_frame_path):
        """Generate the follow-up frames of a clip and save them next to frame 0.

        ``self.i2v.run`` is called with the first frame, the point
        trajectories and the temporal mask (control scale 0.6). Frames
        ``1 .. config.sweeps - 1`` of its output are written as ``<i>.jpg``
        into the directory of ``first_frame_path``.

        Returns:
            ``(img_np, output_RGB)``: the frames as a uint8 array in
            channel-last layout (the model output multiplied by 255), and the
            raw model output.

        Raises:
            OSError: if a frame cannot be written.
        """
        output_RGB, _ = self.i2v.run(
            first_frame_path,
            tracking_points,
            1,
            temporal_mask.cpu().numpy(),
            ctrl_scale=0.6,
        )
        img_np = output_RGB.permute(0, 2, 3, 1).mul(255).cpu().numpy().astype(np.uint8)

        # Build sibling paths from the directory instead of string-replacing
        # "0.jpg", which would rewrite every occurrence in the whole path.
        frame_dir = os.path.dirname(first_frame_path)
        for i in range(1, self.config.sweeps):
            img_bgr = cv2.cvtColor(img_np[i], cv2.COLOR_RGB2BGR)
            image_path = os.path.join(frame_dir, f"{i}.jpg")
            # imwrite reports failure through its return value only.
            if not cv2.imwrite(image_path, img_bgr):
                raise OSError(f"Failed to write frame to {image_path}")

        return img_np, output_RGB

    @torch.no_grad()
    def segment_video_objects(
        self, first_frame_path, resized_valid_box, cropped_iou_mask, erosion=2
    ):
        """Track the boxed object through the saved frames and return its masks.

        The directory of ``first_frame_path`` is opened as a video, the box is
        registered on frame 0 as object 0, and the masks are propagated
        through the clip. Each frame mask (logits > 0) is eroded with a disk
        of radius ``erosion`` and cleared wherever ``cropped_iou_mask`` is 1.

        Returns:
            Stacked per-frame boolean masks on ``self.device``.
        """
        frame_dir = "/".join(first_frame_path.split("/")[:-1])
        inference_state = self.videosam.init_state(video_path=frame_dir)

        # Only the registration matters here; the returned values are
        # produced again by the propagation below.
        _, _, _ = self.videosam.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=0,
            obj_id=0,
            box=resized_valid_box.float().cpu().numpy(),
        )

        # frame index -> {object id -> boolean mask}
        video_segments = {}
        for frame_idx, obj_ids, mask_logits in self.videosam.propagate_in_video(
            inference_state
        ):
            video_segments[frame_idx] = {
                obj_id: (mask_logits[k] > 0.0).cpu().numpy()
                for k, obj_id in enumerate(obj_ids)
            }

        footprint = skimage.morphology.disk(erosion)
        occupied = cropped_iou_mask == 1
        video_masks = []
        for frame in range(len(video_segments)):
            object_mask = video_segments[frame][0]
            eroded = torch.from_numpy(
                skimage.morphology.binary_erosion(object_mask[0], footprint=footprint)
            ).to(self.device)
            # Never claim pixels that belong to existing objects.
            eroded[occupied] = 0
            video_masks.append(eroded)

        self.videosam.reset_state(inference_state)

        return torch.stack(video_masks).to(self.device)

    def compute_scale_and_shift_full(self, prediction, target, mask=None):
        prediction = prediction.astype(np.float32)
        target = target.astype(np.float32)
        if mask is None:
            mask = np.ones_like(target) == 1
        mask = mask.astype(np.float32)

        a_00 = np.sum(mask * prediction * prediction)
        a_01 = np.sum(mask * prediction)
        a_11 = np.sum(mask)

        b_0 = np.sum(mask * prediction * target)
        b_1 = np.sum(mask * target)

        x_0 = 1
        x_1 = 0

        det = a_00 * a_11 - a_01 * a_01

        if det != 0:
            x_0 = (a_11 * b_0 - a_01 * b_1) / det
            x_1 = (-a_01 * b_0 + a_00 * b_1) / det

        return x_0, x_1

    @torch.no_grad()
    def estimate_video_depth(self, img_np):
        """Estimate a depth map for every frame of a clip.

        The video model yields one disparity map per frame; the single-image
        model yields a per-frame depth whose inverse serves as the reference.
        A single scale and shift aligning disparity to inverse depth is
        fitted over all frames and the aligned disparity is inverted again.

        Args:
            img_np: array of frames, indexed by frame along axis 0.

        Returns:
            Tensor on ``self.device`` with the reconstructed depth maps.
        """
        sweeps = self.config.sweeps
        disparity, _ = self.video_depth_anything.infer_video_depth(
            img_np.astype(np.float32),
            sweeps,
            input_size=self.config.resize_h,
            device=self.device,
            fp32=False,
        )

        # Convert the clip once for the per-frame model; the conversion used
        # to be repeated for the whole clip on every loop iteration.
        frames = img_np.astype(np.float32)
        disp, metric = [], []
        for f in range(sweeps):
            metric_depth = self.depth_anything.infer_image(
                frames[f], self.config.resize_h, device=self.device
            )
            # NOTE(review): zero depth values would produce inf here and
            # poison the least-squares sums — confirm the model never
            # returns exact zeros.
            metric.append(1 / metric_depth)
            disp.append(disparity[f])

        scale, shift = self.compute_scale_and_shift_full(
            np.concatenate(disp), np.concatenate(metric)
        )
        inverse_reconstructed_metric_depth = (disparity * scale) + shift

        return torch.from_numpy(1 / inverse_reconstructed_metric_depth).to(self.device)


class PointCloudProcessor:
    def __init__(self, config, device):
        self.config = config
        self.device = device

    def depth_to_pointcloud(
        self, depth, object_mask, result, data_dict, filter_ratio=0.015, video=False
    ):
        """Back-project masked depth into per-sample point clouds with intensity.

        For every sample the masked pixels are lifted to 3D with
        ``depth_to_pointmap``, a synthetic intensity is attached as a fourth
        column, and points at or beyond the ``filter_ratio`` /
        ``1 - filter_ratio`` quantiles along x, y or z are discarded.

        Args:
            depth: Depth tensor.  In video mode it is flattened to
                ``(-1, resize_h, resize_w)``.
            object_mask: Boolean mask selecting the pixels to keep; reshaped
                like ``depth`` in video mode.
            result: Image tensor indexed as ``(B, 3, H, W)``; in video mode it
                is first viewed as ``(-1, 3, resize_h, resize_w)``.
                # NOTE(review): the ``* 255`` below suggests values in [0, 1]
                # -- confirm against the caller.
            data_dict: Provides ``"K_crop_resize"`` and ``"camera2lidar"``.
                In video mode each entry is repeated ``config.sweeps`` times
                so every frame gets its own copy.
            filter_ratio: Quantile trimmed from both ends of each axis.
            video: Whether the inputs carry an extra per-frame dimension.

        Returns:
            ``(filtered_pcs, filtered_colors)``: two lists with one tensor per
            sample -- ``(N_i, 4)`` points ``[x, y, z, intensity]`` and the
            matching rows of ``result``.  A sample without any valid pixel
            yields empty tensors.
        """
        if video:
            height, width = self.config.resize_h, self.config.resize_w
            sweeps = self.config.sweeps
            result = result.view(-1, 3, height, width)
            # One intrinsic/extrinsic matrix per (sample, frame) pair.
            intrinsics = (
                data_dict["K_crop_resize"]
                .unsqueeze(1)
                .repeat(1, sweeps, 1, 1)
                .view(-1, 3, 3)
            )
            extrinsics = (
                data_dict["camera2lidar"]
                .unsqueeze(1)
                .repeat(1, sweeps, 1, 1)
                .view(-1, 4, 4)
            )
            depth = depth.view(-1, height, width)
            object_mask = object_mask.view(-1, height, width)
        else:
            intrinsics = data_dict["K_crop_resize"]
            extrinsics = data_dict["camera2lidar"]

        # Luma-weighted grayscale scaled by 255, floored at 50.
        grayscale_image = torch.clamp(
            (
                0.2989 * result[:, 0] + 0.5870 * result[:, 1] + 0.1140 * result[:, 2]
            ).float()
            * 255,
            min=50,
        )

        point_cloud_map = depth_to_pointmap(depth, object_mask, intrinsics, extrinsics)
        normals, _ = points_to_normals(point_cloud_map, mask=object_mask)
        # Keep masked pixels whose x, y and z are all defined.
        valid = object_mask & ~torch.isnan(point_cloud_map[..., :3]).any(dim=-1)

        # Synthetic intensity: brightness attenuated by sqrt(|normal_z|) and
        # by an exponential fall-off in the point's norm.
        normal_attenuation = torch.abs(normals[:, :, :, 2]) ** 0.5
        distance_attenuation = torch.exp(-0.03 * torch.norm(point_cloud_map, dim=-1))
        intensity = torch.clamp(
            grayscale_image * normal_attenuation * distance_attenuation, 0, 255
        )

        filtered_pcs = []
        filtered_colors = []
        for i in range(point_cloud_map.shape[0]):
            pixel_mask = valid[i]
            pcs = torch.concat(
                [point_cloud_map[i][pixel_mask], intensity[i][pixel_mask].view(-1, 1)],
                dim=1,
            )
            colors = result[i].permute(1, 2, 0)[pixel_mask]

            # quantile() raises on an empty tensor; a sample without valid
            # pixels is passed through instead of aborting the whole batch.
            if pcs.shape[0] == 0:
                filtered_pcs.append(pcs)
                filtered_colors.append(colors)
                continue

            keep = torch.ones(pcs.shape[0], dtype=torch.bool, device=pcs.device)
            for dim in range(3):  # x, y, z
                axis_values = pcs[:, dim]
                lower = axis_values.quantile(filter_ratio)
                upper = axis_values.quantile(1 - filter_ratio)
                # Strict comparison: points sitting on a bound are dropped.
                keep &= (axis_values > lower) & (axis_values < upper)
            filtered_pcs.append(pcs[keep])
            filtered_colors.append(colors[keep])
        return filtered_pcs, filtered_colors

    def scale_depthmap(self, depth, object_mask, ref_box, data_dict, video=False):
        """Rescale each sample's depth so its fitted box matches ``ref_box``.

        The masked depth is lifted to 3D with ``depth_to_pointmap``, a box is
        fitted per sample with ``compute_box``, and every sample of ``depth``
        is multiplied by ``ref_box[0, 5] / box[5]``.
        # NOTE(review): column 5 is presumably the box's z extent (the factor
        # is called ``z_scale``) -- confirm against ``compute_box``.

        Args:
            depth: Depth tensor, one entry per sample along dim 0.
            object_mask: Boolean mask of the pixels used to fit the boxes.
            ref_box: Reference boxes; only ``ref_box[0, 5]`` is read.
            data_dict: Provides ``"K_crop_resize"`` and ``"camera2lidar"``.
            video: If true, only ``depth[:, 0]`` is used to fit the boxes,
                while the returned tensor still covers all of ``depth``.

        Returns:
            Tensor shaped and typed like ``depth`` holding the rescaled depth.
        """
        reference_depth = depth
        if video:
            reference_depth = depth[:, 0]

        pointmap = depth_to_pointmap(
            reference_depth,
            object_mask,
            data_dict["K_crop_resize"],
            data_dict["camera2lidar"],
        )

        # A pixel is unusable as soon as any of its coordinates is NaN.
        undefined = (
            torch.isnan(pointmap[..., 2])
            | torch.isnan(pointmap[..., 1])
            | torch.isnan(pointmap[..., 0])
        )
        usable = object_mask & ~undefined

        per_sample_points = [
            sample_points[usable[i]] for i, sample_points in enumerate(pointmap)
        ]
        padded_xyz, _ = pad_points_list_to_tensor(per_sample_points)
        fitted_boxes = compute_box(padded_xyz)

        reference_extent = ref_box[0, 5].repeat(fitted_boxes.shape[0])
        z_scale = reference_extent / fitted_boxes[:, 5]

        scaled_depth = torch.zeros_like(depth)
        for v, sample_depth in enumerate(depth):
            scaled_depth[v] = sample_depth * z_scale[v]
        return scaled_depth

    def transform_lidar_points(self, points, transformation_matrix):
        """Apply a homogeneous transform to a point array and return xyz.

        Args:
            points: ``(N, 3)`` array of xyz coordinates, or an array that
                already has more columns, in which case it is used as given.
            transformation_matrix: Matrix left-multiplied onto the
                transposed points (4x4 for the ``(N, 4)`` case).

        Returns:
            The first three columns of the transformed points, shape
            ``(N, 3)``.
        """
        homogeneous = points
        if points.shape[1] == 3:
            # Append w = 1 so the translation part of the matrix is applied.
            w_column = np.ones((points.shape[0], 1), dtype=np.float32)
            homogeneous = np.hstack((points, w_column))  # (N, 4)

        moved = transformation_matrix @ homogeneous.T  # (4, N)
        return moved.T[:, :3]

    def pcs_to_lidar(self, points, intensity):
        """Round-trip points through a ``LaserScan`` projection.

        A ``LaserScan`` configured with H=32, W=1070 and a field of view from
        -30 to +10 is given the points and their intensities as remissions;
        its ``proj_range`` and ``proj_remission`` are then passed to
        ``unproject``.
        # NOTE(review): presumably this resamples the cloud onto a 32-beam
        # sensor pattern -- confirm against ``LaserScan``.

        Args:
            points: Point coordinates handed to ``LaserScan.set_points``.
            intensity: Per-point values passed as ``remissions``.

        Returns:
            The output of ``LaserScan.unproject`` cast to ``np.float32``.
        """
        beam_rows = 32  # used for both the projection and the un-projection
        scanner = LaserScan(
            project=True, H=beam_rows, W=1070, fov_up=10.0, fov_down=-30.0
        )
        scanner.set_points(points=points, remissions=intensity)

        range_image = scanner.proj_range
        remission_image = scanner.proj_remission
        resampled = scanner.unproject(range_image, remission_image, H=beam_rows)
        return resampled.astype(np.float32)

    def save_10sweep_pointcloud(
        self, pcs_list, valid_inds, data_dict, t_boxes, dynamic_inds=None
    ):
        """Build a multi-sweep point cloud for each object.

        For every sweep ``s`` the object's points are transformed with the
        inverse of ``data_dict["tr"][s]``, passed through ``pcs_to_lidar``,
        transformed back with ``tr[s]`` and tagged with
        ``data_dict["time"][s]``.  The sweeps are concatenated and the
        object's ``t_boxes[i, :3]`` is subtracted from the xyz columns.
        Despite the name, nothing is written to disk; the clouds are returned.

        Args:
            pcs_list: One entry per object.  When ``dynamic_inds`` is None an
                entry is a single ``(N, >=4)`` tensor reused for every sweep;
                otherwise it is indexed per sweep as ``pcs[s]``.
            valid_inds: One index per entry of ``pcs_list``; only consulted
                when ``dynamic_inds`` is given.
            data_dict: Provides ``"tr"`` (one matrix per sweep) and ``"time"``
                (one value per sweep).
            t_boxes: Tensor whose row ``i`` starts with the offset subtracted
                from object ``i``.
            dynamic_inds: Optional collection of indices.  When given, objects
                whose ``valid_inds`` value is not in it are skipped.

        Returns:
            List of arrays, one per processed object, with columns
            ``[x, y, z, intensity, time]``.
        """
        tr = data_dict["tr"].clone().detach().cpu().numpy()
        time = data_dict["time"].float().clone().detach().cpu().numpy()
        num_sweeps = tr.shape[0]
        use_dynamic = dynamic_inds is not None

        # The inverse poses depend only on the sweep, so invert them once
        # instead of once per (object, sweep) pair.
        tr_inverses = [np.linalg.inv(tr[s]) for s in range(num_sweeps)]

        tr_pcses = []
        for i, (pcs, v) in enumerate(zip(pcs_list, valid_inds)):
            if use_dynamic and (v.item() not in dynamic_inds):
                continue

            if not use_dynamic:
                # Static object: the same points are reused for every sweep.
                static_pcs = pcs.clone().detach().cpu().numpy()
                points, intensity = static_pcs[:, :3], static_pcs[:, 3:4]

            tr_pcs = []
            for s in range(num_sweeps):
                if use_dynamic:
                    # Dynamic object: each sweep has its own points; move
                    # them to the host once and slice afterwards.
                    sweep_pcs = pcs[s].clone().detach().cpu().numpy()
                    points, intensity = sweep_pcs[:, :3], sweep_pcs[:, 3:4]
                t_lag = time[s]

                tr_points = self.transform_lidar_points(points, tr_inverses[s])
                tr_points = np.hstack((tr_points, intensity))
                rv_tr_points = self.pcs_to_lidar(tr_points[:, :3], tr_points[:, 3])

                # Bring the resampled points back with the forward transform.
                tr_points_recovered = self.transform_lidar_points(
                    rv_tr_points[:, :3], tr[s]
                )
                tr_points_recovered = np.hstack(
                    (tr_points_recovered, rv_tr_points[:, 3:4])
                )
                time_lag = (
                    np.ones((tr_points_recovered.shape[0], 1), dtype=np.float32) * t_lag
                )
                tr_points_recovered = np.hstack((tr_points_recovered, time_lag))
                tr_pcs.append(tr_points_recovered)

            tr_pcs = np.concatenate(tr_pcs)
            # Express the points relative to the object's box offset.
            tr_pcs[:, :3] -= t_boxes[i, :3].cpu().numpy()
            tr_pcses.append(tr_pcs)
        return tr_pcses
