# -*- coding: utf-8 -*-
# Author: , Hao Xiang <haxiang@g.ucla.edu>,
# License: TDG-Attribution-NonCommercial-NoDistrib

import time
import os

import cv2
import numpy as np
import open3d as o3d
import matplotlib
import matplotlib.pyplot as plt

from matplotlib import cm
import torch

from opencood.utils import box_utils
from opencood.utils import common_utils

# Colour lookup table used by ``color_encoding`` to turn a normalized lidar
# intensity into RGB. NOTE: despite its name, the table holds the 'plasma'
# colormap; the name is kept so existing references keep working.
# ``pyplot.get_cmap`` is used because ``matplotlib.cm.get_cmap`` was
# deprecated in Matplotlib 3.7 and removed in 3.9.
VIRIDIS = np.array(plt.get_cmap('plasma').colors)
# Evenly spaced sample positions in [0, 1], one per colormap entry.
VID_RANGE = np.linspace(0.0, 1.0, VIRIDIS.shape[0])


def bbx2linset(bbx_corner, order='hwl', color=(0, 1, 0)):
    """
    Convert bounding boxes to a list of o3d LineSets for visualization.

    The input is never modified: the x-axis flip needed for the o3d
    coordinate convention is applied to a per-box copy, so calling this
    function twice on the same data yields the same result.

    Parameters
    ----------
    bbx_corner : torch.Tensor or np.ndarray
        Box corners, shape (n, 8, 3). A 2-D input (documented as (n, 7)) is
        first converted to corners with ``box_utils.boxes_to_corners_3d``.

    order : str
        The order of the bounding box if shape is (n, 7)

    color : tuple
        The bounding box color, applied to every edge.

    Returns
    -------
    line_set : list
        The list containing one o3d LineSet per box.
    """
    if not isinstance(bbx_corner, np.ndarray):
        bbx_corner = common_utils.torch_tensor_to_numpy(bbx_corner)

    if bbx_corner.ndim == 2:
        bbx_corner = box_utils.boxes_to_corners_3d(bbx_corner, order)

    # The 12 box edges: corners 0-3 form one face, 4-7 the opposite face,
    # and the last four edges connect the two faces.
    lines = [[0, 1], [1, 2], [2, 3], [0, 3],
             [4, 5], [5, 6], [6, 7], [4, 7],
             [0, 4], [1, 5], [2, 6], [3, 7]]

    # Use the same color for all lines
    colors = [list(color) for _ in lines]

    bbx_linset = []
    for corners in bbx_corner:
        # o3d use right-hand coordinate: negate x on a copy so the caller's
        # array (possibly sharing memory with a tensor) stays untouched.
        flipped = corners.copy()
        flipped[:, 0] = -flipped[:, 0]

        line_set = o3d.geometry.LineSet()
        line_set.points = o3d.utility.Vector3dVector(flipped)
        line_set.lines = o3d.utility.Vector2iVector(lines)
        line_set.colors = o3d.utility.Vector3dVector(colors)
        bbx_linset.append(line_set)

    return bbx_linset


def bbx2oabb(bbx_corner, order='hwl', color=(0, 0, 1)):
    """
    Convert bounding boxes to o3d oriented bounding boxes for visualization.

    The input is never modified: the x-axis flip needed for the o3d
    coordinate convention is applied to a per-box copy.

    Parameters
    ----------
    bbx_corner : torch.Tensor or np.ndarray
        Box corners, shape (n, 8, 3). A 2-D input (documented as (n, 7)) is
        first converted to corners with ``box_utils.boxes_to_corners_3d``.

    order : str
        The order of the bounding box if shape is (n, 7)

    color : tuple
        The bounding box color.

    Returns
    -------
    oabbs : list
        The list containing all oriented bounding boxes.
    """
    if not isinstance(bbx_corner, np.ndarray):
        bbx_corner = common_utils.torch_tensor_to_numpy(bbx_corner)

    if bbx_corner.ndim == 2:
        bbx_corner = box_utils.boxes_to_corners_3d(bbx_corner, order)

    oabbs = []
    for corners in bbx_corner:
        # o3d use right-hand coordinate: negate x on a copy so the caller's
        # array (possibly sharing memory with a tensor) stays untouched.
        flipped = corners.copy()
        flipped[:, 0] = -flipped[:, 0]

        # The oriented box is fitted to a temporary cloud of the 8 corners.
        tmp_pcd = o3d.geometry.PointCloud()
        tmp_pcd.points = o3d.utility.Vector3dVector(flipped)

        oabb = tmp_pcd.get_oriented_bounding_box()
        oabb.color = color
        oabbs.append(oabb)

    return oabbs


def bbx2aabb(bbx_center, order):
    """
    Build o3d axis-aligned bounding boxes from center-format boxes.

    Parameters
    ----------
    bbx_center : torch.Tensor
        shape: (n, 7). A numpy array is accepted as well.

    order: str
        hwl or lwh.

    Returns
    -------
    aabbs : list
        One o3d axis-aligned bounding box per input box, colored (0, 0, 1).
    """
    if isinstance(bbx_center, np.ndarray):
        centers = bbx_center
    else:
        centers = common_utils.torch_tensor_to_numpy(bbx_center)
    corner_sets = box_utils.boxes_to_corners_3d(centers, order)

    def _corners_to_aabb(corners):
        # o3d use right-hand coordinate, so the x column is negated
        corners[:, :1] = - corners[:, :1]

        # fit the axis-aligned box to a temporary cloud of the corners
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(corners)

        box = cloud.get_axis_aligned_bounding_box()
        box.color = (0, 0, 1)
        return box

    return [_corners_to_aabb(corner_sets[k])
            for k in range(corner_sets.shape[0])]


def linset_assign_list(vis,
                       lineset_list1,
                       lineset_list2,
                       update_mode='update'):
    """
    Copy the linesets of one list onto another and register them with vis.

    Every entry of lineset_list1 receives the attributes of the entry at the
    same position in lineset_list2; once lineset_list2 runs out, its last
    entry is reused for the remaining positions.

    Parameters
    ----------
    vis : open3d.Visualizer
    lineset_list1 : list
        Destination linesets, updated in place.
    lineset_list2 : list
        Source linesets.
    update_mode : str
        'add' adds each geometry to vis, anything else updates it.
    """
    source_count = len(lineset_list2)
    for position, destination in enumerate(lineset_list1):
        source_index = position if position < source_count else -1
        refreshed = lineset_assign(destination, lineset_list2[source_index])
        lineset_list1[position] = refreshed

        if update_mode != 'add':
            vis.update_geometry(refreshed)
        else:
            vis.add_geometry(refreshed)


def lineset_assign(lineset1, lineset2):
    """
    Copy points, lines and colors from lineset2 onto lineset1.

    Parameters
    ----------
    lineset1 : open3d.LineSet
        Destination, modified in place.
    lineset2 : open3d.LineSet
        Source of the attributes.

    Returns
    -------
    The same lineset1 object, now carrying lineset2's attributes.
    """
    for attribute in ('points', 'lines', 'colors'):
        setattr(lineset1, attribute, getattr(lineset2, attribute))

    return lineset1


def color_encoding(intensity, mode='intensity'):
    """
    Map a single-channel value per point to a 3-channel color.

    Parameters
    ----------
    intensity : np.ndarray
        Per-point scalar, shape (n,). Callers in this file pass the lidar
        intensity for 'intensity' mode and the z coordinate otherwise.

    mode : str
        The color rendering mode. intensity, z-value and constant are
        supported.

    Returns
    -------
    color : np.ndarray
        Encoded Lidar color, shape (n, 3)
    """
    assert mode in ['intensity', 'z-value', 'constant']

    if mode == 'constant':
        # every point gets the same fixed color
        int_color = np.ones((intensity.shape[0], 3))
        for channel, level in enumerate((247, 244, 237)):
            int_color[:, channel] *= level / 255
        return int_color

    if mode == 'z-value':
        # jet colormap over the fixed value range [-1.5, 0.5]
        scalar_map = cm.ScalarMappable(
            norm=matplotlib.colors.Normalize(vmin=-1.5, vmax=0.5),
            cmap=cm.jet)

        rgba = scalar_map.to_rgba(intensity)
        # swap the first and third channels, keep the others in place
        rgba[:, [2, 1, 0, 3]] = rgba[:, [0, 1, 2, 3]]
        rgba[:, 3] = 0.5
        return rgba[:, :3]

    # mode == 'intensity': log-scale the value, then look each channel up
    # in the module-level colormap table
    log_scale = np.log(np.exp(-0.004 * 100))
    intensity_col = 1.0 - np.log(intensity) / log_scale
    channels = [np.interp(intensity_col, VID_RANGE, VIRIDIS[:, c])
                for c in range(3)]
    return np.c_[channels[0], channels[1], channels[2]]


def save_feature_response_map(feature_tensor,
                              save_path,
                              reduction='mean_abs',
                              cmap='plasma',
                              upsample_factor=4,
                              title=None,
                              symmetric=False):
    """
    Convert a BEV feature tensor into a 2D response map and save it as an image.

    A 4-D input is averaged over its first dimension, a 3-D input is collapsed
    over its first dimension with ``reduction``, and a 2-D input is used as-is.
    The map is normalized to [0, 1], colored with ``cmap``, optionally
    upscaled and annotated, and written with ``cv2.imwrite``.

    Parameters
    ----------
    feature_tensor : torch.Tensor or np.ndarray
        Expected shapes: (C, H, W), (N, C, H, W) or (H, W). ``None`` makes the
        function return without doing anything. Integer and boolean inputs
        are converted to float before any reduction.
    save_path : str
        Output path for the saved image. Missing parent directories are
        created.
    reduction : str
        How to collapse channel dimension. Supported:
        ['mean_abs', 'max_abs', 'sum_squares', 'mean', 'sum'].
    cmap : str
        Matplotlib colormap name.
    upsample_factor : int
        Factor to upscale the heatmap for readability.
    title : str
        Optional text annotation drawn on the image.
    symmetric : bool
        Whether to normalize the response around zero (useful for delta maps).

    Raises
    ------
    ValueError
        If a 3-D input is given with an unsupported ``reduction``, or the
        input is not 2-, 3- or 4-dimensional.
    """
    if feature_tensor is None:
        return

    if torch.is_tensor(feature_tensor):
        tensor = feature_tensor.detach().cpu()
    else:
        tensor = torch.from_numpy(np.asarray(feature_tensor))

    # ``mean`` is undefined for integer/bool tensors, and integer arrays
    # cannot hold the 0.5 midpoint used below, so promote them to float.
    if not (tensor.is_floating_point() or tensor.is_complex()):
        tensor = tensor.float()

    if tensor.dim() == 4:
        tensor = tensor.mean(dim=0)

    if tensor.dim() == 3:
        reducers = {
            'mean_abs': lambda t: t.abs().mean(dim=0),
            'max_abs': lambda t: t.abs().amax(dim=0),
            'sum_squares': lambda t: t.pow(2).sum(dim=0),
            'mean': lambda t: t.mean(dim=0),
            'sum': lambda t: t.sum(dim=0),
        }
        if reduction not in reducers:
            raise ValueError(f'Unsupported reduction mode: {reduction}')
        response = reducers[reduction](tensor)
    elif tensor.dim() == 2:
        response = tensor
    else:
        raise ValueError(f'Unsupported feature tensor shape: {tensor.shape}')

    response_np = response.numpy()

    if symmetric:
        # Map [-max_abs, max_abs] onto [0, 1] so that zero lands on 0.5.
        max_val = np.max(np.abs(response_np))
        if max_val > 0:
            normalized = (response_np / (2 * max_val)) + 0.5
        else:
            normalized = np.full_like(response_np, 0.5)
    else:
        # Min-max normalization; a constant map becomes all zeros.
        response_np = response_np - np.min(response_np)
        max_val = np.max(response_np)
        if max_val > 0:
            normalized = response_np / max_val
        else:
            normalized = np.zeros_like(response_np)

    normalized = np.clip(normalized, 0.0, 1.0)
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    cmap_fn = plt.get_cmap(cmap)
    heatmap = (cmap_fn(normalized)[..., :3] * 255).astype(np.uint8)
    # OpenCV writes images in BGR channel order.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_RGB2BGR)

    if upsample_factor > 1:
        h, w = heatmap.shape[:2]
        heatmap = cv2.resize(
            heatmap,
            (w * upsample_factor, h * upsample_factor),
            interpolation=cv2.INTER_CUBIC
        )

    if title:
        # Scale the font with the image size, with a lower bound of 0.4.
        font_scale = max(0.4, min(heatmap.shape[0], heatmap.shape[1]) / 512.0)
        cv2.putText(
            heatmap,
            title,
            (12, 28),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            (255, 255, 255),
            1,
            cv2.LINE_AA
        )

    dir_name = os.path.dirname(save_path)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    cv2.imwrite(save_path, heatmap)


def visualize_single_sample_output_gt(pred_tensor,
                                      gt_tensor,
                                      pcd,
                                      show_vis=True,
                                      save_path='',
                                      mode='constant',
                                      target_tensor=None,
                                      oabb=True,
                                      ego_tensor=None):
    """
    Visualize the prediction, groundtruth with point cloud together.

    The caller's point cloud is not modified: the x-axis flip required by the
    o3d coordinate convention is applied to a private copy.

    Parameters
    ----------
    pred_tensor : torch.Tensor
        (N, 8, 3) prediction, drawn with color (1, 0, 0). ``None`` draws no
        prediction boxes.

    gt_tensor : torch.Tensor
        (N, 8, 3) groundtruth bbx, drawn with color (0, 1, 0). ``None`` draws
        no groundtruth boxes.

    pcd : torch.Tensor
        PointCloud, (N, 4). For a 3-D input only the first entry is shown.

    show_vis : bool
        Whether to show visualization.

    save_path : str
        Save the visualization results to given path.

    mode : str
        Color rendering mode.

    target_tensor : torch.Tensor or np.ndarray, optional
        Extra set of boxes, drawn with color (0, 0, 1).

    oabb : bool
        Kept for backward compatibility and ignored: boxes are always drawn
        as LineSets (the original note says oabb=True was for OPV2V and
        False for V2V4Real).

    ego_tensor : torch.Tensor or np.ndarray, optional
        Ego vehicle bounding box, shape (1, 7) or (1, 8, 3). Will be shown in blue.
    """
    # Always use bbx2linset for bounding boxes - LineSet supports line_width setting
    # OrientedBoundingBox (oabb) does NOT support line_width, so lines appear thin
    conversion = bbx2linset

    def custom_draw_geometry(pcd, pred, gt, target=None, ego=None):
        """Open an interactive o3d window showing the cloud and all boxes."""
        vis = o3d.visualization.Visualizer()
        vis.create_window()

        opt = vis.get_render_option()
        opt.background_color = np.asarray([0, 0, 0])
        opt.point_size = 1.0
        opt.line_width = 150.0  # Much thicker lines for better visibility

        vis.add_geometry(pcd)
        for group in (pred, gt, target, ego):
            if group is None:
                continue
            for ele in group:
                vis.add_geometry(ele)

        vis.run()
        vis.destroy_window()

    def boxes_or_empty(boxes, color):
        """Convert boxes to LineSets; ``None`` yields an empty list."""
        if boxes is None:
            return []
        return conversion(boxes, color=color)

    if len(pcd.shape) == 3:
        pcd = pcd[0]
    if not isinstance(pcd, np.ndarray):
        pcd = common_utils.torch_tensor_to_numpy(pcd)
    # Work on a copy so the sign flip below never leaks into the caller's
    # array (or into a tensor sharing its memory).
    origin_lidar = pcd.copy()

    origin_lidar_intcolor = \
        color_encoding(origin_lidar[:, -1] if mode == 'intensity'
                       else origin_lidar[:, 2], mode=mode)
    # left -> right hand
    origin_lidar[:, :1] = -origin_lidar[:, :1]

    o3d_pcd = o3d.geometry.PointCloud()
    o3d_pcd.points = o3d.utility.Vector3dVector(origin_lidar[:, :3])
    o3d_pcd.colors = o3d.utility.Vector3dVector(origin_lidar_intcolor)

    oabbs_pred = boxes_or_empty(pred_tensor, (1, 0, 0))
    oabbs_gt = boxes_or_empty(gt_tensor, (0, 1, 0))

    oabbs_target = None
    if target_tensor is not None:
        oabbs_target = conversion(target_tensor, color=(0, 0, 1))

    # Create ego vehicle box (blue)
    oabbs_ego = None
    if ego_tensor is not None:
        oabbs_ego = conversion(ego_tensor, color=(0, 0, 1))

    # Build visualization elements list
    visualize_elements = [o3d_pcd] + oabbs_pred + oabbs_gt
    if oabbs_target is not None:
        visualize_elements += oabbs_target
    if oabbs_ego is not None:
        visualize_elements += oabbs_ego

    if show_vis:
        custom_draw_geometry(o3d_pcd, oabbs_pred, oabbs_gt,
                             oabbs_target, oabbs_ego)

    if save_path:
        save_o3d_visualization(visualize_elements, save_path)


def visualize_single_sample_output_bev(pred_box, gt_box, pcd, dataset,
                                       show_vis=True,
                                       save_path='', target_tensor=None):
    """
    Visualize the prediction, groundtruth with point cloud together in
    a bev format.

    Parameters
    ----------
    pred_box : torch.Tensor
        (N, 4, 2) prediction. May be None.

    gt_box : torch.Tensor
        (N, 4, 2) groundtruth bbx. May be None.

    pcd : torch.Tensor
        PointCloud, (N, 4).

    dataset : object
        Must provide ``params["preprocess"]`` (with ``args.res`` and
        ``cav_lidar_range``) and ``project_points_to_bev_map(pcd, ratio)``.

    show_vis : bool
        Whether to show visualization.

    save_path : str
        Save the visualization results to given path.

    target_tensor : torch.Tensor or np.ndarray, optional
        Extra set of boxes in the same layout as gt_box. May be None.
    """

    def _as_numpy(data):
        """Return ``data`` as numpy; None and ndarrays pass through."""
        if data is None or isinstance(data, np.ndarray):
            return data
        return common_utils.torch_tensor_to_numpy(data)

    if not isinstance(pcd, np.ndarray):
        pcd = common_utils.torch_tensor_to_numpy(pcd)
    pred_box = _as_numpy(pred_box)
    gt_box = _as_numpy(gt_box)
    # Always bound, including when target_tensor is already an ndarray.
    target_box = _as_numpy(target_tensor)

    ratio = dataset.params["preprocess"]["args"]["res"]
    # Only the first two range entries (the lower x/y bounds) are needed.
    lidar_range = dataset.params["preprocess"]["cav_lidar_range"]
    bev_origin = np.array([lidar_range[0], lidar_range[1]]).reshape(1, -1)
    # (img_row, img_col)
    bev_map = dataset.project_points_to_bev_map(pcd, ratio)
    # (img_row, img_col, 3)
    bev_map = \
        np.repeat(bev_map[:, :, np.newaxis], 3, axis=-1).astype(np.float32)
    bev_map = bev_map * 255

    def _draw_boxes(boxes, color):
        """Draw each box's first four corners as a closed polygon."""
        if boxes is None or not len(boxes):
            return
        for box in boxes:
            # metric coordinates -> pixel indices relative to the map origin
            pts = ((box[:4, :2] - bev_origin) / ratio).astype(int)
            # swap the two coordinate columns before drawing
            pts = pts[:, ::-1]
            cv2.polylines(bev_map, [pts], True, color, 1)

    _draw_boxes(pred_box, (0, 0, 255))
    _draw_boxes(gt_box, (255, 0, 0))
    # Target boxes are drawn from target_box itself, not from gt_box.
    _draw_boxes(target_box, (0, 255, 0))

    if show_vis:
        plt.axis("off")
        plt.imshow(bev_map)
        plt.show()
    if save_path:
        plt.axis("off")
        plt.imshow(bev_map)
        plt.savefig(save_path)


def visualize_single_sample_dataloader(batch_data,
                                       o3d_pcd,
                                       order,
                                       key='origin_lidar',
                                       visualize=False,
                                       save_path='',
                                       oabb=False,
                                       mode='constant'):
    """
    Visualize a single frame of a single CAV for validation of data pipeline.

    The lidar array stored in ``batch_data`` is not modified: the x-axis flip
    required by the o3d coordinate convention is applied to a copy.

    Parameters
    ----------
    batch_data : dict
        The dictionary that contains current timestamp's data. The entries
        ``batch_data[key]``, ``'object_bbx_center'`` and
        ``'object_bbx_mask'`` are read.

    o3d_pcd : o3d.PointCloud
        Open3d PointCloud. Its points and colors are overwritten.

    order : str
        The bounding box order.

    key : str
        origin_lidar for late fusion and stacked_lidar for early fusion.

    visualize : bool
        Whether to visualize the sample.

    save_path : str
        If set, save the visualization image to the path.

    oabb : bool
        If oriented bounding box is used.

    mode : str
        Color rendering mode passed to ``color_encoding``.

    Returns
    -------
    o3d_pcd : o3d.PointCloud
        The updated point cloud (same object as the argument).

    aabbs : list
        The box geometries: LineSets, or oriented bounding boxes when
        ``oabb`` is set.
    """

    origin_lidar = batch_data[key]
    if not isinstance(origin_lidar, np.ndarray):
        origin_lidar = common_utils.torch_tensor_to_numpy(origin_lidar)
    # we only visualize the first cav for single sample
    if len(origin_lidar.shape) > 2:
        origin_lidar = origin_lidar[0]
    # Copy so the sign flip below does not alter the batch data in place.
    origin_lidar = origin_lidar.copy()

    origin_lidar_intcolor = \
        color_encoding(origin_lidar[:, -1] if mode == 'intensity'
                       else origin_lidar[:, 2], mode=mode)

    # left -> right hand
    origin_lidar[:, :1] = -origin_lidar[:, :1]

    o3d_pcd.points = o3d.utility.Vector3dVector(origin_lidar[:, :3])
    o3d_pcd.colors = o3d.utility.Vector3dVector(origin_lidar_intcolor)

    # keep only the boxes whose mask entry equals 1
    object_bbx_center = batch_data['object_bbx_center']
    object_bbx_mask = batch_data['object_bbx_mask']
    object_bbx_center = object_bbx_center[object_bbx_mask == 1]

    if oabb:
        aabbs = bbx2oabb(object_bbx_center, order)
    else:
        aabbs = bbx2linset(object_bbx_center, order)

    visualize_elements = [o3d_pcd] + aabbs
    if visualize:
        o3d.visualization.draw_geometries(visualize_elements)

    if save_path:
        save_o3d_visualization(visualize_elements, save_path)

    return o3d_pcd, aabbs


def visualize_inference_sample_dataloader(pred_box_tensor,
                                          gt_box_tensor,
                                          origin_lidar,
                                          o3d_pcd,
                                          mode='constant'):
    """
    Visualize a frame during inference for video stream.

    The caller's lidar array is not modified: the x-axis flip required by the
    o3d coordinate convention is applied to a copy.

    Parameters
    ----------
    pred_box_tensor : torch.Tensor
        (N, 8, 3) prediction. May be None, in which case no prediction boxes
        are produced.

    gt_box_tensor : torch.Tensor
        (N, 8, 3) groundtruth bbx

    origin_lidar : torch.Tensor
        PointCloud, (N, 4). For inputs with more than 4 columns the first
        column is dropped.

    o3d_pcd : open3d.PointCloud
        Used to visualize the pcd. Its points and colors are overwritten.

    mode : str
        lidar point rendering mode.

    Returns
    -------
    o3d_pcd : open3d.PointCloud
        The updated point cloud (same object as the argument).

    pred_o3d_box : list or None
        Prediction LineSets, or None when pred_box_tensor is None.

    gt_o3d_box : list
        Groundtruth LineSets.
    """

    if not isinstance(origin_lidar, np.ndarray):
        origin_lidar = common_utils.torch_tensor_to_numpy(origin_lidar)
    # we only visualize the first cav for single sample
    if len(origin_lidar.shape) > 2:
        origin_lidar = origin_lidar[0]
    # this is for 2-stage origin lidar, it has different format
    if origin_lidar.shape[1] > 4:
        origin_lidar = origin_lidar[:, 1:]
    # The indexing above only creates views; copy before the sign flip so
    # the caller's data is left untouched.
    origin_lidar = origin_lidar.copy()

    origin_lidar_intcolor = \
        color_encoding(origin_lidar[:, -1] if mode == 'intensity'
                       else origin_lidar[:, 2], mode=mode)

    if pred_box_tensor is not None and \
            not isinstance(pred_box_tensor, np.ndarray):
        pred_box_tensor = common_utils.torch_tensor_to_numpy(pred_box_tensor)
    if not isinstance(gt_box_tensor, np.ndarray):
        gt_box_tensor = common_utils.torch_tensor_to_numpy(gt_box_tensor)

    # left -> right hand
    origin_lidar[:, :1] = -origin_lidar[:, :1]

    o3d_pcd.points = o3d.utility.Vector3dVector(origin_lidar[:, :3])
    o3d_pcd.colors = o3d.utility.Vector3dVector(origin_lidar_intcolor)

    gt_o3d_box = bbx2linset(gt_box_tensor, order='hwl', color=(0, 1, 0))
    if pred_box_tensor is not None:
        pred_o3d_box = bbx2linset(pred_box_tensor, color=(1, 0, 0))
    else:
        pred_o3d_box = None

    return o3d_pcd, pred_o3d_box, gt_o3d_box


def visualize_sequence_dataloader(dataloader, order, color_mode='constant'):
    """
    Visualize the batch data in animation.

    The dataloader is replayed in an endless loop until the function is
    interrupted by an exception (e.g. KeyboardInterrupt); the window is
    destroyed on the way out. An empty dataloader ends the loop immediately.

    Parameters
    ----------
    dataloader : torch.Dataloader
        Pytorch dataloader. Each batch must provide an ``'ego'`` entry that
        ``visualize_single_sample_dataloader`` can consume.

    order : str
        Bounding box order(N, 7).

    color_mode : str
        Color rendering mode.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window()

    render_opt = vis.get_render_option()
    render_opt.background_color = [0.05, 0.05, 0.05]
    render_opt.point_size = 1.0
    render_opt.show_coordinate_frame = True

    # used to visualize lidar points
    vis_pcd = o3d.geometry.PointCloud()
    # used to visualize object bounding box, maximum 50
    vis_aabbs = [o3d.geometry.LineSet() for _ in range(50)]
    # Geometries are registered with the window exactly once; afterwards
    # they are only updated, even when the dataloader is replayed.
    geometries_added = False

    try:
        while True:
            saw_batch = False
            for i_batch, sample_batched in enumerate(dataloader):
                saw_batch = True
                print(i_batch)
                pcd, aabbs = \
                    visualize_single_sample_dataloader(sample_batched['ego'],
                                                       vis_pcd,
                                                       order,
                                                       mode=color_mode)
                # A frame without boxes falls back to one empty LineSet so
                # the slots are cleared instead of indexing an empty list.
                sources = aabbs if len(aabbs) > 0 \
                    else [o3d.geometry.LineSet()]

                # Slots beyond the number of boxes reuse the last box.
                for slot, target in enumerate(vis_aabbs):
                    index = slot if slot < len(sources) else -1
                    vis_aabbs[slot] = lineset_assign(target, sources[index])

                if not geometries_added:
                    vis.add_geometry(pcd)
                    for box in vis_aabbs:
                        vis.add_geometry(box)
                    geometries_added = True

                for box in vis_aabbs:
                    vis.update_geometry(box)

                vis.update_geometry(pcd)
                vis.poll_events()
                vis.update_renderer()
                time.sleep(0.001)

            if not saw_batch:
                # Nothing to show; avoid spinning forever on an empty loader.
                break
    finally:
        vis.destroy_window()


def convert_lineset_to_cylinders(line_set, radius=0.15):
    """
    Turn every segment of a LineSet into a cylinder and merge them.

    Parameters
    ----------
    line_set : open3d.geometry.LineSet
    radius : float
        Radius of the cylinders (determines line thickness).

    Returns
    -------
    open3d.geometry.TriangleMesh
        One mesh holding all cylinders, or None when the LineSet has no
        segment of non-zero length.
    """
    vertices = np.asarray(line_set.points)
    segments = np.asarray(line_set.lines)
    segment_colors = np.asarray(line_set.colors)
    color_count = len(segment_colors)

    def _color_for(position):
        # per-segment color, else the first color, else a default
        if color_count > position:
            return segment_colors[position]
        if color_count > 0:
            return segment_colors[0]
        return [1, 0, 0]

    merged = None

    for position, (start_idx, end_idx) in enumerate(segments):
        start = vertices[start_idx]
        end = vertices[end_idx]
        tint = _color_for(position)

        span = end - start
        span_length = np.linalg.norm(span)
        if span_length == 0:
            # zero-length segments produce no cylinder
            continue

        # Open3D creates the cylinder along the Z axis, centered at origin
        tube = o3d.geometry.TriangleMesh.create_cylinder(
            radius=radius, height=span_length, resolution=8)
        tube.compute_vertex_normals()
        tube.paint_uniform_color(tint)

        # rotation taking the Z axis onto the segment direction
        z_axis = np.array([0, 0, 1])
        direction = span / span_length
        turn_axis = np.cross(z_axis, direction)
        turn_axis_norm = np.linalg.norm(turn_axis)

        if turn_axis_norm >= 1e-6:
            turn_axis = turn_axis / turn_axis_norm
            turn_angle = np.arccos(
                np.clip(np.dot(z_axis, direction), -1.0, 1.0))
            rotation = o3d.geometry.get_rotation_matrix_from_axis_angle(
                turn_axis * turn_angle)
        elif np.dot(z_axis, direction) < 0:
            # anti-parallel to Z: half turn
            rotation = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]])
        else:
            # already aligned with Z
            rotation = np.eye(3)

        # rotate about the origin first, then move to the segment midpoint
        tube.rotate(rotation, center=[0, 0, 0])
        tube.translate((start + end) / 2)

        # accumulate into the first cylinder
        if merged is None:
            merged = tube
        else:
            merged += tube

    return merged


# Global visualizer cache to balance performance and stability.
# These module-level names are read and rebound by save_o3d_visualization.
# Cached Open3D Visualizer, or None when no window is currently open.
_global_vis = None
# Geometries added to the cached visualizer by the most recent call; they are
# removed again at the start of the next call.
_global_vis_geoms = []
# Number of screenshots captured with the current visualizer instance.
_usage_count = 0
RESET_INTERVAL = 30  # Reset visualizer every 30 uses to prevent memory leaks

def save_o3d_visualization(element, save_path):
    """
    Save the open3d drawing to folder.
    Uses a hybrid approach: Reuse visualizer for N frames, then reset.
    This avoids both black screens (from aggressive reuse) and crashes (from frequent creation).

    LineSets are rendered as cylinder meshes so that boxes stay visible on
    platforms that ignore ``line_width``. If anything in the Open3D path
    fails, a matplotlib 3D plot of the original geometries is saved instead;
    if that fails too, the errors are printed and nothing is saved.

    While the function runs, ``DISPLAY`` may be set to ':0' and
    ``OPEN3D_USE_HEADLESS_RENDERING`` to '1' (when no DISPLAY is set or the
    platform release string contains 'microsoft'). ``DISPLAY`` is restored
    to its previous state on exit.

    Parameters
    ----------
    element : list
        List of o3d.geometry objects.

    save_path : str
        The save path.
    """
    import os
    import gc
    import platform
    global _global_vis, _global_vis_geoms, _usage_count

    # Set environment variable for offscreen rendering.
    # platform.uname() is used because os.uname() does not exist on Windows.
    old_display = os.environ.get('DISPLAY', None)
    if 'microsoft' in platform.uname().release.lower() or old_display is None:
        os.environ['OPEN3D_USE_HEADLESS_RENDERING'] = '1'
        if old_display is None:
            os.environ['DISPLAY'] = ':0'

    # The matplotlib fallback draws PointClouds and LineSets only, so it must
    # see the geometries as passed in, not the cylinder meshes built below.
    original_elements = list(element)

    vis = None
    try:
        # Pre-processing: Convert LineSets to Cylinders for better visibility (especially on Windows)
        # This fixes the issue where line_width is ignored on some platforms
        vis_elements = []
        for geom in original_elements:
            if isinstance(geom, o3d.geometry.LineSet):
                # radius=0.1 -> 20cm diameter lines, should be very visible
                cylinder_mesh = convert_lineset_to_cylinders(geom, radius=0.1)
                if cylinder_mesh is not None:
                    vis_elements.append(cylinder_mesh)
                    continue
            vis_elements.append(geom)

        # 1. Get or create visualizer
        if _global_vis is None:
            _global_vis = o3d.visualization.Visualizer()
            try:
                _global_vis.create_window(visible=False, width=3840, height=2160)
            except TypeError:
                # create_window without a ``visible`` keyword
                _global_vis.create_window(width=3840, height=2160)

            # Set render options
            opt = _global_vis.get_render_option()
            opt.background_color = np.asarray([0.05, 0.05, 0.05])
            opt.point_size = 2
            opt.line_width = 50.0

            _global_vis_geoms = []
            _usage_count = 0

        vis = _global_vis

        # 2. Remove the geometries left over from the previous call so they
        # do not stack up in the scene.
        for geom in _global_vis_geoms:
            vis.remove_geometry(geom, reset_bounding_box=False)
        _global_vis_geoms = []

        # 3. Add new geometries
        for i, geom in enumerate(vis_elements):
            # Reset bounding box on first geometry to ensure camera frustum is valid
            reset_bbox = (i == 0)
            vis.add_geometry(geom, reset_bounding_box=reset_bbox)
            _global_vis_geoms.append(geom)

        # 4. Calculate scene bounds for camera: the look-at point is the
        # mean of all points/vertices, or the origin if none are available.
        try:
            all_points = []
            for geom in vis_elements:
                if isinstance(geom, (o3d.geometry.PointCloud,
                                     o3d.geometry.LineSet)):
                    points = np.asarray(geom.points)
                elif isinstance(geom, o3d.geometry.TriangleMesh):
                    points = np.asarray(geom.vertices)
                else:
                    continue
                if len(points) > 0:
                    all_points.append(points)

            if len(all_points) > 0:
                all_points = np.vstack(all_points)
                center = np.mean(all_points, axis=0)
            else:
                center = np.array([0, 0, 0])
        except Exception:
            center = np.array([0, 0, 0])

        # 5. Update and Render
        # Important: Update geometry if it was modified (though we added fresh ones)
        for geom in vis_elements:
            vis.update_geometry(geom)

        vis.poll_events()
        vis.update_renderer()

        # Set camera parameters
        ctr = vis.get_view_control()
        ctr.set_lookat(center.tolist())
        ctr.set_front([0.3, 0.3, -0.9])
        ctr.set_up([0, 1, 0])
        ctr.set_zoom(0.2)

        vis.poll_events()
        vis.update_renderer()  # Render twice to ensure updates propagate

        # Capture screenshot
        vis.capture_screen_image(save_path, do_render=True)

        # 6. Periodic Reset Logic
        _usage_count += 1
        if _usage_count >= RESET_INTERVAL:
            vis.destroy_window()
            _global_vis = None
            _global_vis_geoms = []
            gc.collect()

    except Exception as e:
        # Fallback: use matplotlib for visualization
        print(f"Warning: Open3D visualization failed ({str(e)[:100]}), trying matplotlib fallback...")
        try:
            import matplotlib
            matplotlib.use('Agg')  # Use non-interactive backend
            import matplotlib.pyplot as plt
            # Imported for its side effect of enabling the '3d' projection
            # on Matplotlib versions that need it.
            from mpl_toolkits.mplot3d import Axes3D

            # Create matplotlib figure as fallback with higher resolution
            fig = plt.figure(figsize=(24, 18))  # Larger figure size
            try:
                ax = fig.add_subplot(111, projection='3d')

                # Plot point clouds and boxes
                for geom in original_elements:
                    if isinstance(geom, o3d.geometry.PointCloud):
                        points = np.asarray(geom.points)
                        colors = np.asarray(geom.colors) if geom.has_colors() else None
                        if colors is not None and len(colors) > 0:
                            ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                                       c=colors, s=1.0, alpha=0.7)  # Larger points, more opaque
                        else:
                            ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                                       s=1.0, alpha=0.7, c='gray')
                    elif isinstance(geom, o3d.geometry.LineSet):
                        points = np.asarray(geom.points)
                        lines = np.asarray(geom.lines)
                        colors = np.asarray(geom.colors) if geom.has_colors() else None
                        if colors is None or len(colors) == 0:
                            colors = [(1, 0, 0)] * len(lines)
                        for line, color in zip(lines, colors):
                            pts = points[line]
                            ax.plot3D(pts[:, 0], pts[:, 1], pts[:, 2],
                                      color=color, linewidth=3)  # Thicker lines

                ax.set_xlabel('X', fontsize=14)
                ax.set_ylabel('Y', fontsize=14)
                ax.set_zlabel('Z', fontsize=14)
                ax.view_init(elev=20, azim=45)
                ax.set_box_aspect([1, 1, 0.5])

                # Save with higher DPI for better quality
                plt.savefig(save_path, dpi=600, bbox_inches='tight', facecolor='black')
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)
            print(f"✓ Saved visualization using matplotlib fallback to {save_path}")

        except Exception as e2:
            print("Error: Both visualization methods failed.")
            print(f"  Open3D error: {str(e)[:100]}")
            print(f"  Matplotlib error: {str(e2)[:100]}")
            print(f"  Skipping visualization for {save_path}")

    finally:
        # Restore DISPLAY environment variable
        if old_display is not None:
            os.environ['DISPLAY'] = old_display
        else:
            os.environ.pop('DISPLAY', None)


def visualize_bev(batch_data):
    """
    Show the BEV input and the label map of a batch in two matplotlib figures.

    Parameters
    ----------
    batch_data : dict
        Must provide ``batch_data["processed_lidar"]["bev_input"]`` and
        ``batch_data["label_dict"]["label_map"]``. Each may be a numpy array
        or a torch tensor.

    Notes
    -----
    Blocks on ``plt.show()`` until the figures are closed when an
    interactive backend is in use.
    """
    bev_input = batch_data["processed_lidar"]["bev_input"]
    label_map = batch_data["label_dict"]["label_map"]
    if not isinstance(bev_input, np.ndarray):
        bev_input = common_utils.torch_tensor_to_numpy(bev_input)

    if not isinstance(label_map, np.ndarray):
        # Only the first entry is visualized. Detach before converting so
        # that tensors which require grad (CPU or CUDA) convert as well;
        # calling .numpy() directly on such a tensor raises.
        label_map = label_map[0].detach().cpu().numpy()

    # Drop the leading dimension when more than three are present
    # (presumably the batch dimension -- confirm against the dataloader).
    if len(bev_input.shape) > 3:
        bev_input = bev_input[0, ...]

    # Collapse the first axis of the BEV input into a single 2D image.
    plt.matshow(np.sum(bev_input, axis=0))
    plt.axis("off")
    # NOTE(review): a label_map that is already a numpy array is indexed
    # here without dropping a leading dimension first -- confirm callers
    # pass it as a 3D array.
    plt.matshow(label_map[0, :, :])
    plt.axis("off")
    plt.show()


def draw_box_plt(boxes_dec, ax, color=None, linewidth_scale=1.0):
    """
    Draw rotated 2D box outlines on a matplotlib axis.

    Parameters
    ----------
    boxes_dec : np.ndarray or torch.Tensor
        Boxes of shape (N, 5) or (N, 7) in metric. With more than five
        columns, columns 0, 1, 3, 4 and 6 are used. The five resulting
        columns are read as center x, center y, extent along x, extent
        along y and rotation angle in radians.

    ax : matplotlib axis
        The axis to draw on.

    color : matplotlib color, optional
        Color forwarded to every ``ax.plot`` call.

    linewidth_scale : float
        Multiplier applied to the outline (0.5) and front-edge (2) widths.

    Returns
    -------
    ax : matplotlib axis
        The same axis, with the boxes drawn.
    """
    if len(boxes_dec) <= 0:
        return ax

    if isinstance(boxes_dec, np.ndarray):
        boxes = boxes_dec
    else:
        boxes = boxes_dec.cpu().detach().numpy()
    if boxes.shape[-1] > 5:
        boxes = boxes[:, [0, 1, 3, 4, 6]]

    center_x, center_y = boxes[:, 0], boxes[:, 1]
    extent_x, extent_y = boxes[:, 2], boxes[:, 3]
    yaw = boxes[:, 4:5]

    # Axis-aligned box limits before rotation.
    x_min = center_x - extent_x / 2
    y_min = center_y - extent_y / 2
    x_max = center_x + extent_x / 2
    y_max = center_y + extent_y / 2

    # (N, 4) corner coordinates, ordered bl, fl, fr, br.
    corner_x = np.stack([x_min, x_min, x_max, x_max], axis=1)
    corner_y = np.stack([y_min, y_max, y_max, y_min], axis=1)

    # Rotate every corner around its box center by the yaw angle.
    offset_x = corner_x - center_x[:, None]
    offset_y = corner_y - center_y[:, None]
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    rotated_x = offset_x * cos_yaw - offset_y * sin_yaw + center_x[:, None]
    rotated_y = offset_x * sin_yaw + offset_y * cos_yaw + center_y[:, None]

    outline = [0, 1, 2, 3, 0]  # closed polygon through all four corners
    front = [2, 3]             # edge drawn thicker to mark the front
    for xs, ys in zip(rotated_x, rotated_y):
        ax.plot(xs[outline], ys[outline], color=color,
                linewidth=0.5 * linewidth_scale)
        ax.plot(xs[front], ys[front], color=color,
                linewidth=2 * linewidth_scale)
    return ax


def draw_points_boxes_plt(pc_range, points=None, boxes_pred=None, boxes_gt=None, save_path=None,
                          points_c='y.', bbox_gt_c='green', bbox_pred_c='red', return_ax=False, ax=None):
    """
    Plot a 2D view of points with ground-truth and predicted boxes.

    Parameters
    ----------
    pc_range : sequence
        Axis limits; indices 0 and 3 set the x limits, indices 1 and 4 the
        y limits. Only used when a new axis is created (``ax`` is None).

    points : array-like, optional
        Points to plot; columns 0 and 1 are used as x and y.

    boxes_pred : np.ndarray or torch.Tensor, optional
        Predicted boxes, forwarded to ``draw_box_plt``.

    boxes_gt : np.ndarray or torch.Tensor, optional
        Ground-truth boxes, forwarded to ``draw_box_plt``.

    save_path : str, optional
        Where to save the figure. Nothing is written when None.

    points_c : str
        Matplotlib format string for the points.

    bbox_gt_c : matplotlib color
        Color of the ground-truth boxes.

    bbox_pred_c : matplotlib color
        Color of the predicted boxes.

    return_ax : bool
        If True, return the axis and keep its figure open so the caller can
        continue drawing; otherwise the figure is closed.

    ax : matplotlib axis, optional
        Existing axis to draw on. A new 15x6 figure is created if None.

    Returns
    -------
    ax : matplotlib axis or None
        The axis when ``return_ax`` is True, otherwise None.
    """
    if ax is None:
        ax = plt.figure(figsize=(15, 6)).add_subplot(1, 1, 1)
        ax.set_aspect('equal', 'box')
        ax.set(xlim=(pc_range[0], pc_range[3]),
               ylim=(pc_range[1], pc_range[4]))
    if points is not None:
        ax.plot(points[:, 0], points[:, 1], points_c, markersize=0.1)
    if boxes_gt is not None and len(boxes_gt) > 0:
        ax = draw_box_plt(boxes_gt, ax, color=bbox_gt_c)
    if boxes_pred is not None and len(boxes_pred) > 0:
        ax = draw_box_plt(boxes_pred, ax, color=bbox_pred_c)

    # Label, save and close through the axis' own figure instead of
    # pyplot's "current" figure, which may be a different one when the
    # caller supplied ``ax``.
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    fig = ax.figure

    # save_path defaults to None; only write a file when one was requested.
    if save_path is not None:
        fig.savefig(save_path)
    if return_ax:
        return ax
    plt.close(fig)


# ==============================================================================
# Standalone t-SNE Visualization Function
# ==============================================================================

def plot_tsne_separation(benign_features, attacked_features, save_path='tsne_feature_separation.png'):
    """
    Visualize the distribution of feature vectors using t-SNE.

    This is a standalone function that does not depend on any existing
    class structures. It embeds the benign and attacked feature vectors
    together into 2D and saves a scatter plot showing both groups.

    Parameters
    ----------
    benign_features : numpy.ndarray
        Array of shape (N, D) containing clean/benign feature vectors,
        where N is the number of benign samples and D is the feature dimension.
        Arrays with more than two dimensions are flattened to (N, -1).
    attacked_features : numpy.ndarray
        Array of shape (M, D) containing malicious/attacked feature vectors,
        where M is the number of attacked samples and D is the feature dimension.
        Arrays with more than two dimensions are flattened to (M, -1).
    save_path : str
        Full path (including filename) where the plot will be saved.
        Missing parent directories are created.
        Default: 'tsne_feature_separation.png'

    Returns
    -------
    None
        The plot is saved to disk at the specified path. If either input
        is None or empty, a warning is printed and nothing is saved.

    Raises
    ------
    ValueError
        If, after flattening, the inputs are not both 2D or their feature
        dimensions differ.

    Notes
    -----
    Switches matplotlib to the non-interactive 'Agg' backend via
    ``matplotlib.use``; this is a process-wide setting.

    Example
    -------
    >>> import numpy as np
    >>> benign = np.random.randn(100, 256)  # 100 benign samples, 256-dim features
    >>> attacked = np.random.randn(50, 256) + 2  # 50 attacked samples, shifted
    >>> plot_tsne_separation(benign, attacked, 'output/tsne_plot.png')
    """
    # Light-weight imports first; the heavy ones are deferred until the
    # inputs are known to be usable.
    import os
    import numpy as np

    # Validate inputs
    if benign_features is None or len(benign_features) == 0:
        print("Warning: benign_features is empty. Skipping t-SNE plot.")
        return
    if attacked_features is None or len(attacked_features) == 0:
        print("Warning: attacked_features is empty. Skipping t-SNE plot.")
        return

    # Ensure inputs are numpy arrays
    benign_features = np.asarray(benign_features)
    attacked_features = np.asarray(attacked_features)

    # Flatten if necessary (e.g., if shape is (N, C, H, W))
    if benign_features.ndim > 2:
        benign_features = benign_features.reshape(benign_features.shape[0], -1)
    if attacked_features.ndim > 2:
        attacked_features = attacked_features.reshape(attacked_features.shape[0], -1)

    # Stacking only makes sense for two (samples, features) matrices of equal
    # width; fail early with a clear message instead of a stacking or
    # indexing error further down.
    if (benign_features.ndim != 2 or attacked_features.ndim != 2
            or benign_features.shape[1] != attacked_features.shape[1]):
        raise ValueError(
            "benign_features and attacked_features must be 2D with the same "
            f"feature dimension, got shapes {benign_features.shape} and "
            f"{attacked_features.shape}"
        )

    # Imported inside the function for standalone usage
    from sklearn.manifold import TSNE
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend for saving
    import matplotlib.pyplot as plt

    # Benign rows come first, so the first n_benign rows of the embedding
    # belong to the benign group and the rest to the attacked group.
    n_benign = len(benign_features)
    all_features = np.vstack([benign_features, attacked_features])

    # Perplexity must be strictly less than n_samples. n_samples >= 2 here
    # (both inputs are non-empty), so the result is always >= 1.
    n_samples = len(all_features)
    perplexity = min(30, max(5, n_samples // 4), n_samples - 1)

    print(f"[t-SNE] Running with {n_samples} samples (benign: {len(benign_features)}, attacked: {len(attacked_features)})")
    print(f"[t-SNE] Feature dimension: {all_features.shape[1]}, perplexity: {perplexity}")

    # Run t-SNE to reduce to 2D. The iteration count is left at the library
    # default (1000): the keyword was renamed from n_iter to max_iter, so
    # passing either name explicitly breaks on some scikit-learn versions.
    tsne = TSNE(
        n_components=2,
        random_state=42,
        perplexity=perplexity,
        learning_rate='auto',
        init='pca'
    )
    features_2d = tsne.fit_transform(all_features)

    # Create scatter plot with a clean, publication-quality style
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        groups = (
            (features_2d[:n_benign], '#3498db', 'Benign'),      # blue
            (features_2d[n_benign:], '#e74c3c', 'Attacked'),    # red
        )
        for embedded, group_color, group_label in groups:
            ax.scatter(
                embedded[:, 0],
                embedded[:, 1],
                c=group_color,
                label=group_label,
                alpha=0.7,
                s=60,
                edgecolors='white',
                linewidths=0.5
            )

        # Add legend
        ax.legend(fontsize=14, loc='best', framealpha=0.9)

        # Remove axis ticks and spines for a cleaner appearance
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Add title
        ax.set_title('t-SNE: Feature Distribution (Benign vs Attacked)', fontsize=16, fontweight='bold')

        # Ensure output directory exists
        output_dir = os.path.dirname(save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Save the plot as high-resolution PNG
        fig.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    print(f"[t-SNE] Plot saved to: {save_path}")

