#!/usr/bin/env python3
"""
VGGT validation visualization tool

Provides visualization during training validation:
- Predicted-depth visualization for each external camera view (ext1 ... extS)
- Point cloud generation and saving (with wrist origin red sphere)
- Wrist-view point cloud projection
"""

import os
from re import S
import sys
import numpy as np
import cv2
from PIL import Image, ImageDraw
import torch
import torch.nn.functional as F
import trimesh
import matplotlib.pyplot as plt
from pathlib import Path
import json
from typing import Tuple, List, Optional, Dict, Any
import logging
from datetime import datetime
from vggt.utils.pose_enc import extri_intri_to_pose_encoding,pose_encoding_to_extri_intri



class ValidationVisualizer:
    """
    Validation stage visualization tool
    """
    
    def __init__(self, output_base_dir: str = "val_vis", rank: int = 0, experiment_name: Optional[str] = None):
        """
        Visualizer for the validation phase.

        Every instance gets its own timestamped session directory:
        ``<output_base_dir>/validation_visualizations[/<experiment_name>]/<timestamp>``.

        Args:
            output_base_dir: output base directory
            rank: rank in distributed training; only rank 0 creates directories / writes files
            experiment_name: optional sub-directory inserted below ``validation_visualizations``
                so different experiments do not share a folder
        """
        self.output_base_dir = Path(output_base_dir)
        self.rank = rank
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Only visualize on rank 0 to avoid multi-process conflicts
        self.should_visualize = (rank == 0)

        # Paths are set on every rank so attribute access never fails; directories are
        # only created on the visualizing rank.
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.val_vis_base = self.output_base_dir / "validation_visualizations"
        if experiment_name is not None:
            self.val_vis_base = self.val_vis_base / experiment_name
        self.val_vis_dir = self.val_vis_base / timestamp
        # All per-batch output directories are created under this session directory
        self.projection_dir = self.val_vis_dir

        if self.should_visualize:
            self.val_vis_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f"Validation visualization directory created: {self.val_vis_dir}")
            logging.info(f"Timestamped session directory: {timestamp}")

        logging.info(f"ValidationVisualizer initialized (rank={rank}, visualize={self.should_visualize})")
    
    def visualize(self, predictions: Dict, batch: Dict, epoch: int, batch_idx: int) -> Dict:
        """
        Full validation result visualization for the first sample of ``batch``.

        Everything is written into a fresh
        ``<projection_dir>/epoch_{epoch}_batch_{batch_idx}_{timestamp}`` directory:

        - ``ext{i}_grid.png`` (RGB, predicted depth, point-cloud projection) plus the
          separate ``ext{i}_rgb.png``, ``ext{i}_depth.png`` and ``ext{i}_projection.png``
          for every external view ``i = 1..S``
        - ``wrist_rgb.png`` (GT wrist image)
        - ``wrist_projection.png`` and ``wrist_comparison.png`` when ``predictions``
          contains ``wrist_pose_enc_list``
        - ``pointcloud_with_red_sphere.glb`` and ``metadata.json``

        The caller's ``batch`` dict is not modified: a shallow copy is sliced down to
        sample 0 instead.

        Args:
            predictions: model predictions; must contain ``world_points``, ``depth``,
                ``pose_enc``, ``pose_enc_list`` and ``wrist_pose_enc``
            batch: input batch; must contain ``images`` (B, S, 3, H, W), ``extrinsics``,
                ``intrinsics``, ``wrist_extrinsics``, ``wrist_intrinsics`` and ``wrist_image``
            epoch: current epoch
            batch_idx: current batch index

        Returns:
            dict of generated file paths; ``{}`` on ranks that do not visualize

        Raises:
            AssertionError: on missing keys or unexpected shapes
            ValueError: if ``world_points`` / ``wrist_image`` is missing or malformed
        """
        if not self.should_visualize:
            return {}

        import gc  # not imported at module level in this file

        # Memory management: cleanup before starting
        gc.collect()

        # Slice a shallow copy down to sample 0 so the caller's batch is left untouched
        batch = self._first_sample_only(batch)
        try:
            return self._visualize_first_sample(predictions, batch, epoch, batch_idx)
        finally:
            # Memory management: always release figures and cached GPU memory, even on error
            plt.close('all')
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    @staticmethod
    def _first_sample_only(batch: Dict) -> Dict:
        """Return a shallow copy of ``batch`` with every array / tensor / non-empty list cut to its first entry."""
        sliced = dict(batch)
        for key, value in sliced.items():
            if isinstance(value, (np.ndarray, torch.Tensor)) and value.ndim > 0:
                sliced[key] = value[:1]
            elif isinstance(value, list) and len(value) > 0:
                sliced[key] = value[:1]
        return sliced

    @staticmethod
    def _to_numpy_array(value) -> np.ndarray:
        """Convert a tensor (any device, with or without grad, incl. bfloat16) or array-like to numpy."""
        if torch.is_tensor(value):
            value = value.detach().cpu()
            if value.dtype == torch.bfloat16:
                # numpy has no bfloat16; upcast first
                value = value.float()
            return value.numpy()
        return np.asarray(value)

    def _visualize_first_sample(self, predictions: Dict, batch: Dict, epoch: int, batch_idx: int) -> Dict:
        """Validate a batch already cut to one sample and write every visualization artifact."""
        # --- 1. Batch structure and image shape: (B=1, S, 3, H, W) ---
        assert "images" in batch, "'images' missing in batch"
        assert "extrinsics" in batch, "'extrinsics' missing in batch"
        assert "intrinsics" in batch, "'intrinsics' missing in batch"
        images = batch["images"]
        assert images.ndim == 5, f"Invalid image ndim: {images.ndim}, expect 5 (B,S,C,H,W)"
        assert images.shape[0] == 1, f"Expect batch size 1, got {images.shape[0]}"
        assert images.shape[2] == 3, f"Invalid image shape: {images.shape}, expect 3-channel"

        # --- 2. Camera parameters (any number of views S) ---
        extrinsics = self._to_numpy_array(batch["extrinsics"])
        intrinsics = self._to_numpy_array(batch["intrinsics"])
        # [0] removes the batch dim; the leading 1 that remains is the single wrist view
        wrist_extrinsics = self._to_numpy_array(batch["wrist_extrinsics"])[0]
        wrist_intrinsics = self._to_numpy_array(batch["wrist_intrinsics"])[0]
        assert extrinsics.ndim == 4 and extrinsics.shape[0] == 1 and extrinsics.shape[2:] == (3, 4), \
            f"extrinsic shape error: {extrinsics.shape}, expect (1,S,3,4)"
        assert intrinsics.ndim == 4 and intrinsics.shape[0] == 1 and intrinsics.shape[2:] == (3, 3), \
            f"intrinsic shape error: {intrinsics.shape}, expect (1,S,3,3)"
        assert wrist_extrinsics.shape == (1, 3, 4), \
            f"wrist extrinsic shape error: {wrist_extrinsics.shape}, expect (1,3,4)"
        assert wrist_intrinsics.shape == (1, 3, 3), \
            f"wrist intrinsic shape error: {wrist_intrinsics.shape}, expect (1,3,3)"

        # --- 3. Predicted point cloud, coloured by the input images of all views ---
        if "world_points" not in predictions:
            raise ValueError("predictions missing 'world_points'")
        world_points = self._to_numpy_array(predictions["world_points"][0])
        points_3d = world_points.reshape(-1, 3)
        images_sample = self._to_numpy_array(images[0])  # (S, 3, H, W)
        num_views = images_sample.shape[0]
        # (S, 3, H, W) -> (S, H, W, 3) -> (S*H*W, 3); assumes world_points[0] is laid out
        # (S, H, W, 3) so point i and colour i refer to the same pixel — only counts are checked
        colors = np.transpose(images_sample, (0, 2, 3, 1)).reshape(-1, 3)
        assert len(points_3d) == len(colors), \
            f"Point cloud and color counts mismatch: {len(points_3d)} vs {len(colors)}"

        # Per-view point maps; a single-view (H, W, 3) prediction gets a view axis
        per_view_points = world_points[None, ...] if world_points.ndim == 3 else world_points
        assert per_view_points.ndim == 4, f"Unexpected world_points shape: {per_view_points.shape}"

        # --- 4. Required prediction keys ---
        for key in ("pose_enc", "wrist_pose_enc", "depth", "pose_enc_list"):
            assert key in predictions, f"predictions missing '{key}'"

        # --- 5. Output directory ---
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_dir = Path(self.projection_dir) / f"epoch_{epoch}_batch_{batch_idx}_{timestamp}"
        output_dir.mkdir(parents=True, exist_ok=True)

        # --- 6. External views ---
        camera_names = [f"ext{i + 1}" for i in range(num_views)]
        final_pose_enc = predictions["pose_enc_list"][-1]  # last entry of pose_enc_list
        for camera_idx, camera_name in enumerate(camera_names):
            self._render_ext_camera(
                output_dir=output_dir,
                camera_name=camera_name,
                rgb_chw=images_sample[camera_idx],
                pred_depth=self._to_numpy_array(predictions["depth"][0, camera_idx]).squeeze(),
                view_points=per_view_points[camera_idx],
                pred_pose_enc=final_pose_enc[0, camera_idx].detach().cpu(),
            )

        # --- 7. GT wrist image ---
        wrist_rgb = self._load_wrist_rgb(batch)
        plt.imsave(str(output_dir / "wrist_rgb.png"), wrist_rgb)  # plt.imsave keeps RGB order

        # --- 8. Predicted point cloud seen from the predicted wrist pose ---
        has_wrist_projection = "wrist_pose_enc_list" in predictions
        if has_wrist_projection:
            # wrist_head emits a single wrist pose per sample: [B, 1, target_dim]
            wrist_pose_enc = predictions["wrist_pose_enc_list"][-1][0].detach().cpu()[0]
            self._render_wrist_projection(
                output_dir=output_dir,
                wrist_rgb=wrist_rgb,
                wrist_pose_enc=wrist_pose_enc,
                wrist_intrinsic=wrist_intrinsics[0],  # GT intrinsic, not the decoded one
                points_3d=points_3d,
                colors=colors,
            )

        # --- 9. Point cloud with a red sphere at the predicted wrist origin ---
        points_with_red_sphere = self._add_wrist_origin_sphere(
            predictions=predictions,
            points_3d=points_3d,
            colors=colors,
        )
        red_sphere_path = output_dir / "pointcloud_with_red_sphere.glb"
        self._save_point_cloud_as_glb(
            points=points_with_red_sphere["points"],
            colors=points_with_red_sphere["colors"],
            output_path=str(red_sphere_path),
        )

        # --- 10. Metadata ---
        metadata = {
            "epoch": epoch,
            "batch_idx": batch_idx,
            "timestamp": timestamp,
            "point_cloud_size": len(points_3d),
            "training_mode": f"multi_view_{num_views}",
            "cameras": camera_names + ["wrist"],
            "image_shapes": {name: list(images_sample[i].shape) for i, name in enumerate(camera_names)},
        }
        metadata_path = output_dir / "metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        # --- 11. Paths of everything that was actually written ---
        result = {
            "output_dir": str(output_dir),
            "metadata": str(metadata_path),
            "wrist_rgb": str(output_dir / "wrist_rgb.png"),
            "pointcloud_with_red_sphere": str(red_sphere_path),
        }
        for name in camera_names:
            for suffix in ("grid", "rgb", "depth", "projection"):
                result[f"{name}_{suffix}"] = str(output_dir / f"{name}_{suffix}.png")
        if has_wrist_projection:
            result["wrist_projection"] = str(output_dir / "wrist_projection.png")
            result["wrist_comparison"] = str(output_dir / "wrist_comparison.png")

        # --- 12. Optional projection-tracking visualization ---
        # ("world_points" is guaranteed to be present at this point)
        if "track_pairs" in batch and has_wrist_projection:
            projection_vis_result = self._visualize_projection_tracking(
                predictions=predictions,
                batch=batch,
                output_dir=output_dir,
            )
            result.update(projection_vis_result)

        return result

    def _render_ext_camera(
        self,
        output_dir: Path,
        camera_name: str,
        rgb_chw: np.ndarray,
        pred_depth: np.ndarray,
        view_points: np.ndarray,
        pred_pose_enc: torch.Tensor,
    ) -> None:
        """Write ``{camera_name}_grid.png`` plus separate RGB / depth / projection PNGs for one view."""
        # (3, H, W) image, assumed in [0, 1] -> (H, W, 3) uint8
        rgb = (np.transpose(rgb_chw, (1, 2, 0)) * 255).astype(np.uint8)
        image_hw = rgb.shape[:2]

        # Decode the predicted pose; the decoder is fed a (1, 1, D) batch and read back at [0, 0]
        pred_extrinsics, pred_intrinsics = pose_encoding_to_extri_intri(
            pred_pose_enc.unsqueeze(0).unsqueeze(0), image_size_hw=image_hw
        )
        projection = self.visualize_point_cloud_projection(
            points_3d=view_points.reshape(-1, 3),
            point_colors=rgb.reshape(-1, 3),
            camera_intrinsics=self._to_numpy_array(pred_intrinsics[0, 0]),
            camera_extrinsics=self._to_numpy_array(pred_extrinsics[0, 0]),
            image_shape=image_hw,
            need_inverse=False,  # decoded extrinsics are already world2camera
        )

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        try:
            fig.suptitle(f'{camera_name.upper()} Visualization', fontsize=16)

            axes[0, 0].imshow(rgb)
            axes[0, 0].set_title('RGB Image')
            axes[0, 0].axis('off')

            depth_img = axes[0, 1].imshow(pred_depth, cmap='viridis')
            axes[0, 1].set_title('Prediction Depth')
            axes[0, 1].axis('off')
            plt.colorbar(depth_img, ax=axes[0, 1], fraction=0.046, pad=0.04)

            # GT-depth panel is currently disabled; hide the empty axes instead of drawing a frame
            axes[1, 0].axis('off')

            axes[1, 1].imshow(projection)
            axes[1, 1].set_title('Point Cloud Projection')
            axes[1, 1].axis('off')

            fig.tight_layout()
            fig.savefig(str(output_dir / f"{camera_name}_grid.png"), dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

        # Separate images (their paths are reported in visualize()'s result dict)
        plt.imsave(str(output_dir / f"{camera_name}_rgb.png"), rgb)
        plt.imsave(str(output_dir / f"{camera_name}_depth.png"), pred_depth, cmap='viridis')
        plt.imsave(str(output_dir / f"{camera_name}_projection.png"), projection)

    def _load_wrist_rgb(self, batch: Dict) -> np.ndarray:
        """Return the GT wrist image of sample 0 as an (H, W, 3) uint8 array."""
        wrist_rgb = batch.get("wrist_image")
        if wrist_rgb is None:
            raise ValueError("batch missing 'wrist_image'")
        wrist_rgb = self._to_numpy_array(wrist_rgb)
        if wrist_rgb.ndim == 4:  # (B, H, W, 3)
            wrist_rgb = wrist_rgb[0]
        elif wrist_rgb.ndim != 3:  # anything but an already-unbatched (H, W, 3)
            raise ValueError(f"Invalid wrist_rgb dimension: {wrist_rgb.ndim}, expect 3 or 4")
        assert wrist_rgb.shape[2] == 3, f"Invalid wrist_rgb shape: {wrist_rgb.shape}, expect (H,W,3)"

        # Float images in [0, 1] would all truncate to 0 under a plain uint8 cast
        if np.issubdtype(wrist_rgb.dtype, np.floating) and wrist_rgb.size and wrist_rgb.max() <= 1.0:
            wrist_rgb = wrist_rgb * 255.0
        return np.clip(wrist_rgb, 0, 255).astype(np.uint8)

    def _render_wrist_projection(
        self,
        output_dir: Path,
        wrist_rgb: np.ndarray,
        wrist_pose_enc: torch.Tensor,
        wrist_intrinsic: np.ndarray,
        points_3d: np.ndarray,
        colors: np.ndarray,
    ) -> None:
        """Project the point cloud into the predicted wrist view; save projection and side-by-side comparison."""
        wrist_hw = wrist_rgb.shape[:2]
        # Only the decoded extrinsic is used (world2camera); intrinsics come from GT
        wrist_extrinsics, _ = pose_encoding_to_extri_intri(
            wrist_pose_enc.unsqueeze(0).unsqueeze(0), image_size_hw=wrist_hw
        )
        wrist_projection = self.visualize_point_cloud_projection(
            points_3d=points_3d,
            point_colors=colors,
            camera_extrinsics=self._to_numpy_array(wrist_extrinsics[0, 0]),
            camera_intrinsics=wrist_intrinsic,
            image_shape=wrist_hw,
            need_inverse=False,  # world2camera, no inversion needed
        )
        cv2.imwrite(
            str(output_dir / "wrist_projection.png"),
            cv2.cvtColor(wrist_projection, cv2.COLOR_RGB2BGR),
        )

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        try:
            fig.suptitle('Wrist Camera: Original vs Projection', fontsize=16)
            axes[0].imshow(wrist_rgb)
            axes[0].set_title('Original Wrist RGB')
            axes[0].axis('off')
            axes[1].imshow(wrist_projection)
            axes[1].set_title('Point Cloud Projection')
            axes[1].axis('off')
            fig.tight_layout()
            fig.savefig(str(output_dir / "wrist_comparison.png"), dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)
    
    def _pose_to_extrinsics(self, pose_enc: np.ndarray) -> np.ndarray:
        """
        Convert a 9D pose encoding to a (3, 4) ``[R | t]`` matrix.

        Args:
            pose_enc: pose encoding of shape (9,) laid out as
                ``[tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w]`` (scalar-last quaternion).
                The FoV entries are ignored.

        Returns:
            ``[R | t]`` of shape (3, 4): R is built from the quaternion (normalized first),
            t is copied verbatim. No frame conversion is applied, so the matrix is in the
            encoder's convention (NOTE(review): VGGT-style encodings decode to world2camera).
            A zero quaternion yields R = identity with a warning.

        Raises:
            AssertionError: if ``pose_enc`` is not of shape (9,)
        """
        pose_enc = np.asarray(pose_enc, dtype=np.float64)
        assert pose_enc.shape == (9,), f"Invalid pose_enc shape: {pose_enc.shape}, expect (9,)"
        translation = pose_enc[:3]  # [tx, ty, tz]
        qx, qy, qz, qw = pose_enc[3:7]

        norm_sq = qx * qx + qy * qy + qz * qz + qw * qw
        if norm_sq == 0.0:
            logging.warning("Zero quaternion in pose encoding; using identity rotation")
            rotation = np.eye(3)
        else:
            # Standard unit-quaternion -> rotation matrix; dividing by |q|^2 normalizes q
            s = 2.0 / norm_sq
            rotation = np.array([
                [1 - s * (qy * qy + qz * qz), s * (qx * qy - qz * qw), s * (qx * qz + qy * qw)],
                [s * (qx * qy + qz * qw), 1 - s * (qx * qx + qz * qz), s * (qy * qz - qx * qw)],
                [s * (qx * qz - qy * qw), s * (qy * qz + qx * qw), 1 - s * (qx * qx + qy * qy)],
            ])

        extrinsics = np.empty((3, 4))
        extrinsics[:, :3] = rotation
        extrinsics[:, 3] = translation
        return extrinsics
    
    def visualize_point_cloud_projection(
        self,
        points_3d: np.ndarray,
        point_colors: np.ndarray,
        camera_extrinsics: np.ndarray,
        camera_intrinsics: np.ndarray,
        image_shape: Tuple[int, int],
        need_inverse: bool = False,
        point_radius: int = 2,
    ) -> np.ndarray:
        """
        Splat a 3D point cloud into a pinhole camera view (painter's algorithm).

        Points with camera-space depth <= 0.01 or landing outside the image are dropped;
        the rest are drawn as filled discs from far to near so closer points end up on top.

        Args:
            points_3d: (N, 3) points in world coordinates
            point_colors: (N, 3) RGB colors; if their maximum is below 2 they are treated
                as [0, 1] floats and scaled to uint8
            camera_extrinsics: (3, 4) world2camera matrix, or camera2world when
                ``need_inverse`` is True
            camera_intrinsics: (3, 3) pinhole intrinsic matrix K
            image_shape: output image shape (H, W)
            need_inverse: invert ``camera_extrinsics`` (camera2world -> world2camera) first
            point_radius: disc radius in pixels used for every projected point

        Returns:
            (H, W, 3) uint8 image, black wherever no point lands

        Raises:
            AssertionError: on inconsistent input shapes
        """
        H, W = image_shape

        # Check input data
        assert points_3d.ndim == 2, f"Invalid point cloud ndim: {points_3d.ndim}, expect 2"
        assert points_3d.shape[1] == 3, f"Invalid point cloud shape: {points_3d.shape}, expect (N,3)"
        assert point_colors.ndim == 2, f"Invalid color ndim: {point_colors.ndim}, expect 2"
        assert point_colors.shape[1] == 3, f"Invalid color shape: {point_colors.shape}, expect (N,3)"
        assert len(points_3d) == len(point_colors), f"Point cloud and color counts mismatch: {len(points_3d)} vs {len(point_colors)}"
        assert camera_extrinsics.shape == (3, 4), f"Invalid camera extrinsic shape: {camera_extrinsics.shape}, expect (3,4)"
        assert camera_intrinsics.shape == (3, 3), f"Invalid camera intrinsic shape: {camera_intrinsics.shape}, expect (3,3)"

        if need_inverse:
            # camera2world -> world2camera via the 4x4 homogeneous inverse
            camera_ext_4x4 = np.vstack([camera_extrinsics, [0, 0, 0, 1]])
            world2camera_ext = np.linalg.inv(camera_ext_4x4)[:3, :4]
        else:
            world2camera_ext = camera_extrinsics

        image = np.zeros((H, W, 3), dtype=np.uint8)

        # World -> camera coordinates, (N, 3)
        points_homo = np.concatenate([points_3d, np.ones((len(points_3d), 1))], axis=1)
        points_cam = (world2camera_ext @ points_homo.T).T

        # Keep only points in front of the camera (comparison with NaN is False, so NaNs drop too)
        in_front = points_cam[:, 2] > 0.01
        if not np.any(in_front):
            return image
        points_cam = points_cam[in_front]
        point_colors = point_colors[in_front]

        # Perspective divide, then intrinsics -> pixel coordinates (u, v)
        normalized = points_cam[:, :2] / points_cam[:, 2:3]
        normalized_homo = np.concatenate([normalized, np.ones((len(normalized), 1))], axis=1)
        projected_uv = (camera_intrinsics @ normalized_homo.T).T[:, :2]

        inside = (
            (projected_uv[:, 0] >= 0) & (projected_uv[:, 0] < W)
            & (projected_uv[:, 1] >= 0) & (projected_uv[:, 1] < H)
        )
        if not np.any(inside):
            return image
        projected_uv = projected_uv[inside]
        point_colors = point_colors[inside]
        depths = points_cam[inside, 2]

        # Far-to-near order so nearer points overwrite farther ones
        order = np.argsort(depths)[::-1]
        projected_uv = projected_uv[order]
        point_colors = point_colors[order]

        u_coords = projected_uv[:, 0].astype(int)
        v_coords = projected_uv[:, 1].astype(int)

        # Heuristic: values below 2 are taken to be normalized [0, 1] colors
        if point_colors.max() < 2:
            point_colors = (point_colors * 255).astype(np.uint8)

        # cv2 is imported at module level
        for u, v, color in zip(u_coords.tolist(), v_coords.tolist(), point_colors.tolist()):
            cv2.circle(image, (u, v), point_radius, color, -1)

        return image
    
    def _add_wrist_origin_sphere(self, predictions: Dict, points_3d: np.ndarray, colors: np.ndarray) -> Dict:
        """
        Add a red sphere to the point cloud marking the predicted wrist camera centre.

        Args:
            predictions: predictions dict; ``predictions["wrist_pose_enc"]`` (tensor or array)
                is expected to be [B, 1, 9] and sample 0 is used
            points_3d: original point cloud coordinates (N, 3)
            colors: original point cloud colors (N, 3)

        Returns:
            {"points": ..., "colors": ...} with the sphere appended; the inputs are returned
            unchanged when no ``wrist_pose_enc`` is present

        Raises:
            AssertionError: if sample 0 of ``wrist_pose_enc`` is not shaped (1, 9)
        """
        wrist_pose_enc = predictions.get("wrist_pose_enc")
        if wrist_pose_enc is None:
            return {"points": points_3d, "colors": colors}

        if torch.is_tensor(wrist_pose_enc):
            wrist_pose_enc = wrist_pose_enc.detach().cpu()
            if wrist_pose_enc.dtype == torch.bfloat16:
                # numpy has no bfloat16; upcast first
                wrist_pose_enc = wrist_pose_enc.float()
            wrist_pose_enc = wrist_pose_enc.numpy()

        # wrist_head outputs a single wrist pose per sample: [B, 1, target_dim]
        wrist_pose_enc = wrist_pose_enc[0]  # take 0th sample
        assert wrist_pose_enc.shape == (1, 9), f"Invalid wrist_pose_enc shape: {wrist_pose_enc.shape}, expect (1,9)"

        # The decoded [R | t] maps world -> camera (visualize() projects this same encoding
        # with need_inverse=False). The camera centre in world coordinates is therefore the
        # translation of the inverse transform, mirroring _add_gt_wrist_origin_sphere.
        world2cam = self._pose_to_extrinsics(wrist_pose_enc[0])
        world2cam_4x4 = np.vstack([world2cam, [0, 0, 0, 1]])
        wrist_origin = np.linalg.inv(world2cam_4x4)[:3, 3]

        sphere_points, sphere_colors = self._generate_sphere_points(
            center=wrist_origin,
            radius=0.05,  # 5cm radius
            color=(255, 0, 0),  # Red
            num_points=100
        )

        # Combine original point cloud and red sphere
        combined_points = np.vstack([points_3d, sphere_points])
        combined_colors = np.vstack([colors, sphere_colors])

        return {"points": combined_points, "colors": combined_colors}
    
    def _add_gt_wrist_origin_sphere(self, batch: Dict, points_3d: np.ndarray, colors: np.ndarray) -> Dict:
        """
        Add a green sphere to the point cloud to represent the GT wrist origin
        
        Args:
            batch: batch data containing GT wrist pose
            points_3d: original point cloud coordinates (N, 3)
            colors: original point cloud colors (N, 3)
            
        Returns:
            Point cloud data with green sphere {"points": ..., "colors": ...}
        """
        # Get GT wrist pose
        wrist_extrinsics = batch.get("wrist_extrinsics")
        if wrist_extrinsics is None:
            return {"points": points_3d, "colors": colors}
        
        wrist_ext = wrist_extrinsics[0][0].cpu().numpy()
        
        assert wrist_ext.shape == (3, 4), f"Invalid wrist extrinsic shape: {wrist_ext.shape}, expect (3,4)"
        
        # Extract GT wrist origin position
        # wrist_ext is world2camera transformation matrix T_wc
        # To get the camera's position in the world coordinate system, we need to invert: T_cw = inv(T_wc)
        # Camera position = T_cw * [0,0,0,1] = Translation part of T_cw
        wrist_ext_4x4 = np.vstack([wrist_ext, [0, 0, 0, 1]])  # Extend to 4x4 homogeneous coordinate matrix
        wrist_ext_inv = np.linalg.inv(wrist_ext_4x4)  # Invert to get camera2world transformation
        gt_wrist_origin = wrist_ext_inv[:3, 3]  # take translation part of inverse matrix
        
        # Generate green sphere point cloud
        sphere_points, sphere_colors = self._generate_sphere_points(
            center=gt_wrist_origin,
            radius=0.05,  # 5cm radius
            color=(0, 255, 0),  # Green
            num_points=100
        )
        
        
        # Combine original point cloud and green sphere
        combined_points = np.vstack([points_3d, sphere_points])
        combined_colors = np.vstack([colors, sphere_colors])
        
        return {"points": combined_points, "colors": combined_colors}
    
    def _generate_sphere_points(self, center: np.ndarray, radius: float, color: Tuple[int, int, int], num_points: int = 100) -> Tuple[np.ndarray, np.ndarray]:
        """
        Generate a sphere point cloud
        
        Args:
            center: Sphere center coordinates (3,)
            radius: Sphere radius
            color: Sphere color (R, G, B)
            num_points: Number of points in the sphere
            
        Returns:
            Sphere point cloud coordinates and colors
        """
        # Generate points uniformly distributed on the sphere surface
        phi = np.linspace(0, 2*np.pi, int(np.sqrt(num_points)))
        theta = np.linspace(0, np.pi, int(np.sqrt(num_points)))
        phi, theta = np.meshgrid(phi, theta)
        
        # Spherical to Cartesian coordinates
        x = radius * np.sin(theta) * np.cos(phi)
        y = radius * np.sin(theta) * np.sin(phi)
        z = radius * np.cos(theta)
        
        # Flatten and add center offset
        sphere_points = np.concatenate([x.flatten(), y.flatten(), z.flatten()]).reshape(-1, 3)
        sphere_points = sphere_points + center  # Use numpy broadcasting
        
        # Generate colors
        sphere_colors = np.full((len(sphere_points), 3), color, dtype=np.uint8)
        
        return sphere_points, sphere_colors
    
    def _save_point_cloud_as_glb(self, points: np.ndarray, colors: np.ndarray, output_path: str):
        """
        Write a colored point cloud to ``output_path`` using trimesh.

        No explicit file type is passed to ``export``. trimesh presumably
        infers the format from the path's extension; this method is meant to
        be called with ``.glb`` paths.

        Args:
            points: Point coordinates, shape (N, 3).
            colors: Per-point colors, shape (N, 3).
            output_path: Destination file path.
        """
        cloud = trimesh.PointCloud(vertices=points, colors=colors)
        cloud.export(output_path)
    
    def _visualize_projection_tracking(
        self,
        predictions: Dict,
        batch: Dict,
        output_dir: Path,
        min_confidence: float = 0.1,
    ) -> Dict[str, Any]:
        """
        Visualize wrist/external track correspondences for batch sample 0.

        Builds a side-by-side image:
        - Left half: the real wrist RGB frame, resized to 518x294, with the
          wrist track points drawn in green.
        - Right half: the predicted point cloud projected into the wrist view.

        For each confident track pair, the external-view 3D point is
        bilinearly sampled from ``predictions["world_points"]``. It is then
        projected into the wrist camera using the predicted wrist extrinsics
        and the GT wrist intrinsics, drawn in red on the right half, and
        joined to its wrist keypoint with a white line.

        Args:
            predictions: Model predictions; reads the last entry of
                ``wrist_pose_enc_list`` and ``world_points``.
            batch: Input batch; reads ``track_pairs``, ``wrist_image`` and
                ``wrist_intrinsics``, plus whatever the point-cloud projection
                helper reads.
            output_dir: Directory the comparison PNG is written to.
            min_confidence: Track pairs with a confidence below this value are
                skipped.

        Returns:
            ``{}`` when there are no track pairs for sample 0. Otherwise a dict
            with the saved image path (``wrist_tracking_comparison``), the
            number of pairs drawn (``valid_projections``) and the number of
            candidate pairs for sample 0 (``total_track_pairs``).
        """
        track_pairs = batch.get("track_pairs") or {}
        if len(track_pairs.get("wrist_uv", [])) == 0:
            print("no track pairs")
            return {}

        wrist_uvs, ext_uvs, pair_types, confidences = self._select_sample0_track_pairs(track_pairs)
        if not wrist_uvs:
            return {}

        # Wrist RGB for sample 0, resized to 518x294 (cv2 dsize is (width, height)).
        wrist_rgb = batch["wrist_image"][0].cpu().numpy().astype(np.uint8)
        wrist_rgb = cv2.resize(wrist_rgb, (518, 294))
        H, W = wrist_rgb.shape[:2]

        # Predicted wrist pose of sample 0: last list entry, sample 0, the single wrist pose -> [9].
        wrist_pose_enc = predictions["wrist_pose_enc_list"][-1][0][0]
        wrist_ext_pred, _ = pose_encoding_to_extri_intri(
            wrist_pose_enc.unsqueeze(0).unsqueeze(0),  # [1, 1, 9]
            image_size_hw=(H, W),
            build_intrinsics=False,
        )
        # detach() so .numpy() also works on tensors that carry autograd history.
        wrist_ext_pred = wrist_ext_pred[0, 0].detach().cpu().numpy()  # [3, 4]

        # GT intrinsics of sample 0; reshape accepts both (3, 3) and (1, 3, 3) storage.
        wrist_intrinsics_gt = batch["wrist_intrinsics"][0].detach().cpu().numpy().reshape(3, 3)

        # Convert the predicted world points once instead of once per track pair.
        # .float() also covers reduced-precision tensors, which .numpy() rejects.
        world_points = predictions["world_points"][0].detach().cpu().float().numpy()  # [S, H', W', 3]

        wrist_projection = self._generate_wrist_point_cloud_projection(
            predictions=predictions,
            batch=batch,
            image_shape=(H, W),
        )

        # Left half: real wrist view + track points; right half: projected point cloud.
        comparison_img = np.zeros((H, W * 2, 3), dtype=np.uint8)
        comparison_img[:, :W] = wrist_rgb
        comparison_img[:, W:] = wrist_projection

        total_pairs = len(wrist_uvs)
        valid_projections = 0
        for wrist_uv, ext_uv, pair_type, confidence in zip(wrist_uvs, ext_uvs, pair_types, confidences):
            if confidence < min_confidence:
                continue

            wrist_u, wrist_v = int(wrist_uv[0]), int(wrist_uv[1])
            if 0 <= wrist_u < W and 0 <= wrist_v < H:
                cv2.circle(comparison_img, (wrist_u, wrist_v), 3, (0, 255, 0), -1)  # Green point

            # Negative coordinates mark an invalid external-view UV.
            if ext_uv[0] < 0 or ext_uv[1] < 0:
                continue

            try:
                # pair_type selects which view of world_points the ext UV refers to.
                point_3d = self._get_interpolated_3d_point_numpy(
                    world_points[int(pair_type)],  # [H', W', 3]
                    float(ext_uv[0]),
                    float(ext_uv[1]),
                )
                projected_uv, _depth, is_valid = self._project_3d_to_wrist_numpy(
                    point_3d,
                    wrist_ext_pred,
                    wrist_intrinsics_gt,
                )
                if not is_valid:
                    continue
                # Flatten so both (2,) and (1, k) projection outputs are handled.
                uv = np.asarray(projected_uv).reshape(-1)
                proj_u, proj_v = int(uv[0]), int(uv[1])
            except Exception:
                # Best-effort visualization: log the failure and move on to the next pair.
                logging.exception("Wrist tracking visualization: failed to project a track pair")
                continue

            if 0 <= proj_u < W and 0 <= proj_v < H:
                cv2.circle(comparison_img, (W + proj_u, proj_v), 3, (255, 0, 0), -1)  # Red point
                cv2.circle(comparison_img, (wrist_u, wrist_v), 3, (0, 255, 0), -1)  # Green point
                cv2.line(comparison_img, (wrist_u, wrist_v), (W + proj_u, proj_v), (255, 255, 255), 1)
                valid_projections += 1

        comparison_path = Path(output_dir) / "wrist_tracking_comparison.png"
        # The canvas is RGB; OpenCV writes BGR.
        cv2.imwrite(str(comparison_path), cv2.cvtColor(comparison_img, cv2.COLOR_RGB2BGR))

        return {
            "wrist_tracking_comparison": str(comparison_path),
            "valid_projections": valid_projections,
            "total_track_pairs": total_pairs,
        }

    @staticmethod
    def _select_sample0_track_pairs(track_pairs: Dict) -> Tuple[list, list, list, list]:
        """
        Return ``(wrist_uv, ext_uv, pair_type, confidence)`` lists for batch sample 0.

        An entry is kept when its ``track_pairs["batch_indices"]`` value is 0.
        If ``batch_indices`` is absent, every entry is kept. ``ext_uv`` comes
        from the unified ``ext_uv`` field when present. Otherwise it comes from
        the legacy ``ext1_uv``/``ext2_uv`` fields, chosen per pair by
        ``pair_type``.
        """
        num_pairs = len(track_pairs["wrist_uv"])
        batch_indices = track_pairs.get("batch_indices")
        if batch_indices is None:
            # NOTE(review): no per-pair batch index; assume all pairs belong to sample 0.
            keep = list(range(num_pairs))
        else:
            keep = [i for i, b in enumerate(batch_indices) if b == 0]

        wrist_uv = [track_pairs["wrist_uv"][i] for i in keep]
        pair_type = [track_pairs["pair_type"][i] for i in keep]
        confidence = [track_pairs["confidence"][i] for i in keep]

        if "ext_uv" in track_pairs:
            ext_uv = [track_pairs["ext_uv"][i] for i in keep]
        else:
            # Legacy layout. pair_type indexes the world_points views, so take
            # ext1_uv for view 0 and ext2_uv otherwise.
            # NOTE(review): confirm pair_type 0 <-> ext1 in the legacy track format.
            ext_uv = [
                track_pairs["ext1_uv"][i] if int(track_pairs["pair_type"][i]) == 0 else track_pairs["ext2_uv"][i]
                for i in keep
            ]

        return wrist_uv, ext_uv, pair_type, confidence
    
    def _get_interpolated_3d_point_numpy(
        self,
        world_points_map: np.ndarray,
        u: float,
        v: float
    ) -> np.ndarray:
        """
        Get 3D point using bilinear interpolation with numpy
        
        Args:
            world_points_map: 3D world point cloud map [H, W, 3]
            u: U coordinate (float)
            v: V coordinate (float)
            
        Returns:
            3D point [3]
        """
        H, W, _ = world_points_map.shape
        
        # Get indices of four corner points
        u0, u1 = int(np.floor(u)), int(np.ceil(u))
        v0, v1 = int(np.floor(v)), int(np.ceil(v))
        
        # Calculate interpolation weights
        wu = u - u0
        wv = v - v0
        
        # Handle boundary cases
        if u0 == u1:
            wu = 0.0
        if v0 == v1:
            wv = 0.0
        
        # Get four corner points (handle boundaries)
        u0_clamped = max(0, min(u0, W-1))
        u1_clamped = max(0, min(u1, W-1))
        v0_clamped = max(0, min(v0, H-1))
        v1_clamped = max(0, min(v1, H-1))
        
        p00 = world_points_map[v0_clamped, u0_clamped, :]
        p01 = world_points_map[v0_clamped, u1_clamped, :]
        p10 = world_points_map[v1_clamped, u0_clamped, :]
        p11 = world_points_map[v1_clamped, u1_clamped, :]
        
        # Bilinear interpolation
        point_3d = (1-wu)*(1-wv)*p00 + wu*(1-wv)*p01 + (1-wu)*wv*p10 + wu*wv*p11
        
        return point_3d
    
    def _project_3d_to_wrist_numpy(
        self,
        point_3d: np.ndarray,
        wrist_ext: np.ndarray,
        wrist_intrinsics: np.ndarray
    ) -> Tuple[np.ndarray, float, bool]:
        """
        Project 3D point to wrist camera view using numpy
        
        Args:
            point_3d: 3D point in world coordinate system [3]
            wrist_ext: wrist camera extrinsic matrix [3, 4] - world2camera transformation matrix
            wrist_intrinsics: wrist camera intrinsic matrix [3, 3]
            
        Returns:
            Tuple of (projected_uv, depth, is_valid)
            - projected_uv: [2] - projected UV coordinates
            - depth: Depth value
            - is_valid: Whether the projection is valid
        """
        # Invert extrinsic: from camera2world to world2camera
        wrist_ext_4x4 = np.vstack([wrist_ext, [0, 0, 0, 1]])  # Extend to 4x4 homogeneous coordinate matrix
        world2camera_ext = wrist_ext_4x4[:3, :4]  # Invert to get world2camera transformation
        
        # Convert to homogeneous coordinates
        point_homo = np.append(point_3d, 1.0)  # [4]
        
        # Project to camera coordinate system
        point_cam = world2camera_ext @ point_homo  # [3]
        
        # Check if depth is positive
        depth = point_cam[2]
        if depth <= 0.01:
            return np.array([0, 0]), depth, False
        
        # Project to image plane
        point_2d = point_cam[:2] / depth  # [2]
        
        # Apply intrinsic
        point_2d_homo = np.append(point_2d, 1.0)  # [3]
        point_pixel = wrist_intrinsics @ point_2d_homo  # [3]
        projected_uv = point_pixel[:2]  # [2]
        
        return projected_uv, depth, True
    
    def _generate_wrist_point_cloud_projection(
        self,
        predictions: Dict,
        batch: Dict,
        image_shape: Tuple[int, int]
    ) -> np.ndarray:
        """
        Render the predicted point cloud of batch sample 0 from the predicted wrist camera.

        All views of ``predictions["world_points"][0]`` are merged into one
        cloud. Each point is colored with its source pixel from
        ``batch["images"][0]``. The cloud is rendered with
        ``self.visualize_point_cloud_projection`` using the predicted wrist
        extrinsics (world-to-camera, so no inversion) and the GT wrist
        intrinsics.

        Args:
            predictions: Model predictions; reads the last entry of
                ``wrist_pose_enc_list`` and ``world_points``.
            batch: Input batch; reads ``wrist_intrinsics`` and ``images``.
            image_shape: Output image shape (H, W).

        Returns:
            The projection image from ``visualize_point_cloud_projection``,
            expected to be (H, W, 3).
        """
        H, W = image_shape

        # Predicted wrist pose of sample 0: last list entry, sample 0, the single wrist pose -> [9].
        wrist_pose_enc = predictions["wrist_pose_enc_list"][-1][0][0]
        wrist_ext_pred, _ = pose_encoding_to_extri_intri(
            wrist_pose_enc.unsqueeze(0).unsqueeze(0),  # [1, 1, 9]
            image_size_hw=(H, W),
            build_intrinsics=False,
        )
        # detach() so .numpy() also works on tensors that carry autograd history.
        wrist_ext_pred = wrist_ext_pred[0, 0].detach().cpu().numpy()  # [3, 4]

        # GT intrinsics of sample 0. reshape accepts both (3, 3) and (1, 3, 3)
        # storage; indexing [0] would pick a row of a plain (3, 3) matrix.
        wrist_intrinsics_gt = batch["wrist_intrinsics"][0].detach().cpu().numpy().reshape(3, 3)

        # Merge the predicted world points of every view into one (N, 3) cloud.
        all_world_points = predictions["world_points"][0].detach().reshape(-1, 3).cpu().numpy()

        # Per-point colors from the input images; each view is assumed (C, H, W) with C == 3.
        images = batch["images"][0]
        colors_list = []
        for rgb in images:
            if torch.is_tensor(rgb):
                rgb = rgb.detach().cpu().numpy()
            rgb = np.transpose(rgb, (1, 2, 0))
            assert rgb.shape[-1] == 3, f"Invalid RGB shape: {rgb.shape}"
            colors_list.append(rgb.reshape(-1, 3))
        colors = np.concatenate(colors_list, axis=0)

        if len(all_world_points) != len(colors):
            logging.warning(f"Wrist projection: Point cloud and color counts mismatch: {len(all_world_points)} vs {len(colors)}")
            # Best effort: truncate both to the shorter length.
            min_count = min(len(all_world_points), len(colors))
            all_world_points = all_world_points[:min_count]
            colors = colors[:min_count]

        return self.visualize_point_cloud_projection(
            points_3d=all_world_points,
            point_colors=colors,
            camera_extrinsics=wrist_ext_pred,
            camera_intrinsics=wrist_intrinsics_gt,  # GT intrinsics
            image_shape=(H, W),
            need_inverse=False,  # wrist extrinsics are already world2camera
        )