from dataclasses import dataclass, field
from typing import Tuple
import numpy as np
import torch 
import tree
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
from vggt.models.vggt import VGGT
from ultralytics import YOLO
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from PIL import Image
import torchvision.transforms.functional as TF
from collections import defaultdict
import hdbscan
from tqdm import tqdm 
from torch import nn
from torch.nn.utils.rnn import pad_sequence
import math
import torchvision.ops as ops
import time
from .action_head.flow_matching_action_head import (
    FlowmatchingActionHead,
    FlowmatchingActionHeadConfig,
)
from .backbone import EagleBackbone
# Key that must be present in the backbone's BatchFeature output (checked in validate_data).
BACKBONE_FEATURE_KEY = "backbone_features"
# Key under which the action head returns predicted actions (checked in validate_data).
ACTION_KEY = "action_pred"
# Key under which the action head returns its loss; accepted in place of ACTION_KEY while training.
LOSS_KEY = "loss"
# Common prefix of every input/output validation error message.
ERROR_MSG = "Error: unexpected input/output"
# Expected channel count of the "video" input (checked in validate_inputs).
N_COLOR_CHANNELS = 3
@dataclass
class GR00T_N1_5_Config(PretrainedConfig):
    """Hugging Face config for the GR00T N1.5-based model (``model_type = "gr00t_n1_5"``).

    All constructor keyword arguments are forwarded to ``PretrainedConfig.__init__``
    and then also set as instance attributes. The fields below without a default
    (``backbone_cfg``, ``action_head_cfg``, ``action_horizon``, ``action_dim``) are
    declared with ``init=False``, so they only exist when passed as keyword arguments.
    """
    model_type = "gr00t_n1_5"
    # Keyword arguments for the backbone constructor (SOMA asserts this is a dict).
    backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
    # Keyword arguments for the action-head config (SOMA asserts this is a dict).
    action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
    # Number of action steps predicted per call.
    action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
    # Dimensionality of each action vector.
    action_dim: int = field(init=False, metadata={"help": "Action dimension."})
    # Falls back to the class-level default "float32" when not passed.
    compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
    # @dataclass does not replace an __init__ defined in the class body, so this
    # hand-written constructor is the one actually used.
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Set every kwarg explicitly so the dataclass-declared fields are always
        # available as plain instance attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)
class SOMA(PreTrainedModel):
    """Eagle backbone + flow-matching action head, extended with a spatial object memory.

    Contract with the submodules:

    - the backbone output is expected to be a ``BatchFeature`` with key
      ``'backbone_features'`` of shape (batch_size, n, hidden_size), where n is
      variable and can be e.g. time, 1, or user specified;
    - at inference time the action head output is expected to be a
      ``BatchFeature`` with key ``'action_pred'`` of shape
      (batch_size, time, action_dim).

    Both outputs may carry any number of additional user-specified keys.
    """
    supports_gradient_checkpointing = True
    config_class = GR00T_N1_5_Config
    def __init__(
        self,
        config: GR00T_N1_5_Config,
        local_model_path: str,
    ):
        """Build the backbone, the action head and the spatial-memory modules.

        Args:
            config: model config; ``backbone_cfg`` and ``action_head_cfg`` must be dicts.
            local_model_path: stored unchanged as ``self.local_model_path``.

        ``n_layer``, ``perception_head``, ``geometry_head`` and ``det_head`` start as
        ``None`` here (presumably attached after construction — not set in this method).
        """
        assert isinstance(config.backbone_cfg, dict)
        assert isinstance(config.action_head_cfg, dict)
        super().__init__(config)
        self.local_model_path = local_model_path
        self.backbone = EagleBackbone(**config.backbone_cfg)
        head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
        self.action_head = FlowmatchingActionHead(head_cfg)
        self.action_horizon = config.action_horizon
        self.action_dim = config.action_dim
        self.compute_dtype = config.compute_dtype

        # External heads, not created here.
        self.n_layer = None
        self.perception_head = None
        self.geometry_head = None
        self.det_head = None

        hidden = self.memory_mlp_hidden_dim = 384

        def two_layer_mlp(in_dim):
            # LayerNorm -> Linear(in, 2h) -> GELU -> Linear(2h, h) -> GELU
            return torch.nn.Sequential(
                torch.nn.LayerNorm(in_dim),
                torch.nn.Linear(in_dim, hidden * 2),
                torch.nn.GELU(),
                torch.nn.Linear(hidden * 2, hidden),
                torch.nn.GELU(),
            )

        # Creation order matters: it fixes parameter registration order and the RNG
        # draws consumed by the default initializers.
        self.bbox_embedding_mlp = two_layer_mlp(8 * 3)  # 8 corners x (x, y, z)
        self.spatial_pos_mlp = two_layer_mlp(hidden)
        self.memory_mlp = two_layer_mlp(hidden)

        from .action_head.cross_attention_dit import SelfAttentionTransformerMask
        attn_cfg = dict(
            attention_head_dim=64,
            dropout=0.2,
            final_dropout=True,
            num_attention_heads=32,
            num_layers=1,
            positional_embeddings=None,
        )
        self.boosting_ln = nn.LayerNorm(head_cfg.backbone_embedding_dim)
        self.boosting_transformer = SelfAttentionTransformerMask(**attn_cfg)
        # Projects memory vectors into the backbone embedding space.
        self.memory_alignment_mlp = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden),
            torch.nn.Linear(hidden, head_cfg.backbone_embedding_dim),
            torch.nn.GELU(),
        )
        self.current_overview_labels = None
        self.current_bbox_3d_labels = None
    def init_weights(self):
        """Re-initialize the memory-related submodules with ``_init_layer_weight``.

        Only ``bbox_embedding_mlp``, ``spatial_pos_mlp``, ``memory_mlp``,
        ``memory_alignment_mlp`` and ``boosting_transformer`` are visited (in that
        order); the backbone, action head and ``boosting_ln`` are left untouched.
        NOTE(review): this overrides ``PreTrainedModel.init_weights`` without calling
        it — confirm that skipping the base implementation is intended.
        """
        targets = (
            self.bbox_embedding_mlp,
            self.spatial_pos_mlp,
            self.memory_mlp,
            self.memory_alignment_mlp,
            self.boosting_transformer,
        )
        for module in targets:
            module.apply(self._init_layer_weight)
    def _init_layer_weight(self, m):
        """Re-initialize one submodule in place; intended for ``Module.apply``.

        - ``nn.Linear``: Kaiming-uniform weight (``a=sqrt(5)``) and, if present, a bias
          drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        - ``nn.Embedding``: weight ~ N(0, 0.02).
        - ``nn.LayerNorm``: weight = 1, bias = 0 (each only if present).

        Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
            if m.bias is not None:
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(m.weight)
                # Zero-width layers would otherwise divide by zero.
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                torch.nn.init.uniform_(m.bias, -bound, bound)
        elif isinstance(m, torch.nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, torch.nn.LayerNorm):
            # weight/bias are None for LayerNorm(elementwise_affine=False) or bias=False.
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
    def embed_bbox_3d(self, bbox_3d_corners):
        """Embed 3D bounding boxes given by their corners with ``bbox_embedding_mlp``.

        Args:
            bbox_3d_corners: torch.Tensor or np.ndarray of shape (N, 8, 3); it is
                reshaped to (N, 24) before the MLP.

        Returns:
            torch.Tensor of shape (N, memory_mlp_hidden_dim).

        Raises:
            TypeError: if ``bbox_3d_corners`` is neither a tensor nor an ndarray.
        """
        if not isinstance(bbox_3d_corners, (np.ndarray, torch.Tensor)):
            raise TypeError("bbox_3d_corners must be torch.Tensor or np.ndarray")
        # Match the MLP's own device and dtype so CPU tensors, float64 arrays or a
        # half-precision model do not hit a device/dtype mismatch in the first layer.
        ref_param = next(self.bbox_embedding_mlp.parameters())
        corners = torch.as_tensor(bbox_3d_corners, device=ref_param.device, dtype=ref_param.dtype)
        num_boxes = corners.shape[0]
        return self.bbox_embedding_mlp(corners.reshape(num_boxes, 8 * 3))
    def process_spatial_memory(self, input):
        """Combine each sample's object memory with its 3D box positions and the current frames.

        For every sample ``i``: each memory row goes through ``memory_mlp`` and is added
        to ``spatial_pos_mlp`` of the mean ``embed_bbox_3d`` embedding of all 3D boxes
        that carry the same label (a zero vector when no box has that label). The result
        is then updated with the current observation via ``fused_current_obs``.

        Args:
            input: dict-like batch holding per-sample sequences under
                ``'overview_memory'`` (np.ndarray or torch.Tensor, one row per label —
                presumably (num_labels, memory_mlp_hidden_dim); TODO confirm),
                ``'overview_labels'``, ``'bbox_3d_corners'`` ((N, 8, 3) per sample),
                ``'bbox_3d_labels'`` (one label per box) and ``'raw_images'``.

        Returns:
            The same ``input`` object; ``'overview_memory'[i]`` and
            ``'overview_labels'[i]`` are replaced in place by the fused results.

        Raises:
            TypeError: if an ``overview_memory`` entry is neither ndarray nor tensor.
        """
        batchsize = len(input['overview_memory'])
        for i in range(batchsize):
            overview_memory = input['overview_memory'][i]
            overview_labels = input['overview_labels'][i]
            if isinstance(overview_memory, np.ndarray):
                overview_memory = torch.from_numpy(overview_memory).float().to(self.device)
            elif not isinstance(overview_memory, torch.Tensor):
                raise TypeError("overview_memory must be torch.Tensor or np.ndarray")
            bbox_pos_embedding = self.embed_bbox_3d(input['bbox_3d_corners'][i])
            # Compare element by element so plain lists and empty arrays both give a
            # proper per-box mask (ndarray == str can degrade to a scalar).
            box_labels = list(input['bbox_3d_labels'][i])
            embed_dim = bbox_pos_embedding.shape[-1]
            per_label = []
            for label in overview_labels:
                mask = torch.tensor(
                    [bool(box_label == label) for box_label in box_labels],
                    dtype=torch.bool,
                    device=bbox_pos_embedding.device,
                )
                if mask.any():
                    per_label.append(bbox_pos_embedding[mask].mean(dim=0))
                else:
                    # Sized from the embedding width: row 0 does not exist when there are no boxes.
                    per_label.append(bbox_pos_embedding.new_zeros(embed_dim))
            if per_label:
                grouped_spatial_embeddings = torch.stack(per_label, dim=0)
            else:
                grouped_spatial_embeddings = bbox_pos_embedding.new_zeros((0, embed_dim))
            spatial_pos_cls = self.spatial_pos_mlp(grouped_spatial_embeddings)
            process_overview_memory = self.memory_mlp(overview_memory) + spatial_pos_cls
            final_spatial_memory, final_spatial_memory_labels = self.fused_current_obs(
                input['raw_images'][i], process_overview_memory, overview_labels
            )
            input['overview_memory'][i] = final_spatial_memory
            input['overview_labels'][i] = final_spatial_memory_labels
        return input
    def batched_bbox_to_3d_aabb_robust(
        self,
        bboxes: torch.Tensor,
        world_points: torch.Tensor,
        initial_transformation: torch.Tensor,
        z_percentile: float = 0.25,
        z_margin: float = 0.2
    ) -> tuple[torch.Tensor, torch.Tensor] | None:
        """Lift a batch of 2D boxes to the 8 corners of 3D axis-aligned boxes, on-device.

        For each box, the point map under it is resampled to a 32x32 grid with
        ``torchvision.ops.roi_align``. Samples that are exactly (0, 0, 0) count as
        missing. The z value at the ``z_percentile`` quantile of the valid samples is
        taken as the object's main-body depth. Only samples within ``z_margin`` of it
        are kept as foreground. Their AABB is expanded to 8 corners and mapped through
        ``initial_transformation``.

        Args:
            bboxes: (N, 4) boxes as ``[x1, y1, x2, y2]`` in the pixel frame of
                ``world_points`` (roi_align runs with ``spatial_scale=1.0``).
            world_points: (H, W, 3) point map of the current image.
            initial_transformation: (4, 4) homogeneous transform applied to the corners.
            z_percentile: quantile in [0, 1] of the valid depths used as main-body depth.
            z_margin: maximum |z - main-body z| for a sample to be foreground
                (same units as ``world_points``).

        Returns:
            ``None`` if ``bboxes`` is empty or no box has foreground points.
            Otherwise ``(corners, valid_box_mask)``:
            - ``valid_box_mask`` is an (N,) bool tensor.
            - ``corners`` is (M, 8, 3), one entry per True position of the mask, in box order.
        """
        if bboxes.shape[0] == 0:
            return None
        num_boxes = bboxes.shape[0]
        device = world_points.device
        points_dtype = world_points.dtype
        # roi_align takes a (1, 3, H, W) input and (K, 5) rois [batch_idx, x1, y1, x2, y2].
        # The rois must share the input's dtype and device.
        point_map = world_points.permute(2, 0, 1).unsqueeze(0)
        batch_index = torch.zeros((num_boxes, 1), device=device, dtype=points_dtype)
        rois = torch.cat([batch_index, bboxes.to(device=device, dtype=points_dtype)], dim=1)
        sampled = ops.roi_align(
            point_map,
            rois,
            output_size=(32, 32),
            spatial_scale=1.0,
            sampling_ratio=-1
        )
        points_list = sampled.permute(0, 2, 3, 1).reshape(num_boxes, -1, 3)
        valid_points_mask = torch.any(points_list != 0, dim=2)
        depth = points_list[:, :, 2]
        # Invalid samples sort to the end. +inf is representable in every float dtype,
        # unlike a large finite sentinel in fp16.
        sorted_z, _ = torch.sort(depth.masked_fill(~valid_points_mask, float('inf')), dim=1)
        num_valid = valid_points_mask.sum(dim=1).clamp_min(1)
        percentile_idx = (num_valid * z_percentile).long()
        # Keep the index on a valid sample. z_percentile == 1.0 would otherwise point at
        # the first invalid entry, or past the end when every sample is valid.
        percentile_idx = torch.minimum(percentile_idx.clamp_min(0), num_valid - 1)
        main_body_z = torch.gather(sorted_z, 1, percentile_idx.unsqueeze(1)).squeeze(1)
        foreground_mask = (torch.abs(depth - main_body_z.unsqueeze(1)) < z_margin) & valid_points_mask
        background = ~foreground_mask.unsqueeze(-1)
        min_corners = points_list.masked_fill(background, float('inf')).amin(dim=1)
        max_corners = points_list.masked_fill(background, float('-inf')).amax(dim=1)
        # A box is usable only if its foreground produced finite extents.
        valid_box_mask = torch.isfinite(min_corners).all(dim=1) & torch.isfinite(max_corners).all(dim=1)
        if not torch.any(valid_box_mask):
            return None
        x_min, y_min, z_min = min_corners[valid_box_mask].unbind(dim=1)
        x_max, y_max, z_max = max_corners[valid_box_mask].unbind(dim=1)
        # Corner order: the z_min face (x_min,y_min) -> (x_max,y_min) -> (x_max,y_max) -> (x_min,y_max),
        # followed by the same four on the z_max face.
        corners_local = torch.stack([
            torch.stack([x_min, y_min, z_min], dim=1),
            torch.stack([x_max, y_min, z_min], dim=1),
            torch.stack([x_max, y_max, z_min], dim=1),
            torch.stack([x_min, y_max, z_min], dim=1),
            torch.stack([x_min, y_min, z_max], dim=1),
            torch.stack([x_max, y_min, z_max], dim=1),
            torch.stack([x_max, y_max, z_max], dim=1),
            torch.stack([x_min, y_max, z_max], dim=1),
        ], dim=1)
        # Apply the transform in at least float32, whatever the point map's dtype.
        compute_dtype = torch.promote_types(corners_local.dtype, torch.float32)
        corners_local = corners_local.to(compute_dtype)
        ones = torch.ones(corners_local.shape[0], 8, 1, device=device, dtype=compute_dtype)
        corners_local_homo = torch.cat([corners_local, ones], dim=2)
        transform = torch.as_tensor(initial_transformation, device=device, dtype=compute_dtype)
        transformed_corners_homo = corners_local_homo @ transform.T
        return transformed_corners_homo[:, :, :3], valid_box_mask
    def filtering_keys(self, input):
        """Strip keys that downstream modules must not receive, caching the label lists.

        ``'overview_labels'`` and ``'bbox_3d_labels'`` are moved from ``input`` into
        ``self.current_overview_labels`` / ``self.current_bbox_3d_labels`` (``None``
        when absent). ``'raw_images'`` is dropped.

        Returns:
            The same ``input`` object, modified in place.
        """
        # pop() with a default: a batch lacking any of these keys must not raise KeyError.
        self.current_overview_labels = input.pop('overview_labels', None)
        self.current_bbox_3d_labels = input.pop('bbox_3d_labels', None)
        input.pop("raw_images", None)
        return input
    def validate_inputs(self, inputs):
        """Validate the optional ``"action"`` and ``"video"`` entries of ``inputs``.

        - ``"action"`` must be a ``torch.Tensor`` of shape
          ``(batch, self.action_horizon, self.action_dim)``.
        - ``"video"`` must be a ``np.uint8`` ``np.ndarray`` with 6 dimensions, and its
          axis 3 must have ``N_COLOR_CHANNELS`` entries.
          NOTE(review): confirm axis 3 is the channel axis of the video layout.

        All failures are collected into a single message.

        Raises:
            ValueError: if any check fails (also for objects of the wrong type, whose
                missing ``shape``/``dtype`` is reported as ``None``).
        """
        detected_error = False
        error_msg = ERROR_MSG
        if "action" in inputs:
            action = inputs["action"]
            action_shape = getattr(action, "shape", None)
            if not isinstance(action, torch.Tensor):
                error_msg += f"\n{type(action)=}"
                detected_error = True
            shape_ok = (
                action_shape is not None
                and len(action_shape) == 3
                and action_shape[1] == self.action_horizon
                and action_shape[2] == self.action_dim
            )
            if not shape_ok:
                error_msg += f"\naction.shape={action_shape!r}"
                detected_error = True
        if "video" in inputs:
            video = inputs["video"]
            video_shape = getattr(video, "shape", None)
            video_dtype = getattr(video, "dtype", None)
            if not isinstance(video, np.ndarray):
                error_msg += f"\n{type(video)=}"
                detected_error = True
            if video_dtype != np.uint8:
                error_msg += f"\nvideo.dtype={video_dtype!r}"
                detected_error = True
            shape_ok = (
                video_shape is not None
                and len(video_shape) == 6
                and video_shape[3] == N_COLOR_CHANNELS
            )
            if not shape_ok:
                error_msg += f"\nvideo.shape={video_shape!r}"
                detected_error = True
        if detected_error:
            raise ValueError(error_msg)
    def validate_data(self, action_head_outputs, backbone_outputs, is_training):
        """Check that backbone and action-head outputs have the expected structure.

        - ``backbone_outputs`` must be a ``BatchFeature`` containing ``BACKBONE_FEATURE_KEY``.
        - ``action_head_outputs`` must be a ``BatchFeature`` that satisfies one of:
          - it contains ``LOSS_KEY`` while ``is_training`` is true, or
          - its ``ACTION_KEY`` value has ``shape[1] == self.action_horizon`` and
            ``shape[2] == self.action_dim``.

        Raises:
            ValueError: with a diagnostic message when a check fails. The message is
                assembled defensively, so a missing key is reported as ``None``
                instead of raising ``KeyError`` while formatting.
        """

        def _contains(container, key):
            # `in` on an arbitrary object (e.g. None) may raise TypeError.
            try:
                return key in container
            except TypeError:
                return False

        def _shape_of(container, key):
            # Shape of container[key], or None when it cannot be looked up.
            try:
                return container[key].shape
            except (KeyError, TypeError, AttributeError, IndexError):
                return None

        backbone_is_batch_feature = isinstance(backbone_outputs, BatchFeature)
        backbone_has_features = _contains(backbone_outputs, BACKBONE_FEATURE_KEY)
        if not (backbone_is_batch_feature and backbone_has_features):
            error_msg = ERROR_MSG
            error_msg += f"\nisinstance(backbone_outputs, BatchFeature)={backbone_is_batch_feature}"
            error_msg += f"\nBACKBONE_FEATURE_KEY in backbone_outputs={backbone_has_features}"
            error_msg += (
                "\nbackbone_outputs[BACKBONE_FEATURE_KEY].shape="
                f"{_shape_of(backbone_outputs, BACKBONE_FEATURE_KEY)}"
            )
            raise ValueError(error_msg)

        head_is_batch_feature = isinstance(action_head_outputs, BatchFeature)
        has_loss = _contains(action_head_outputs, LOSS_KEY)
        action_shape = _shape_of(action_head_outputs, ACTION_KEY)
        action_ok = (
            action_shape is not None
            and len(action_shape) >= 3
            and action_shape[1] == self.action_horizon
            and action_shape[2] == self.action_dim
        )
        if not head_is_batch_feature or not ((has_loss and is_training) or action_ok):
            error_msg = ERROR_MSG
            error_msg += f"\nisinstance(action_head_outputs, BatchFeature)={head_is_batch_feature}"
            error_msg += f"\nLOSS_KEY in action_head_outputs={has_loss}"
            error_msg += f"\naction_head_outputs[ACTION_KEY].shape={action_shape}"
            error_msg += f"\n{self.action_horizon=}"
            error_msg += f"\n{self.action_dim=}"
            raise ValueError(error_msg)
    def get_opengl_conversion_matrix(self) -> np.ndarray:
        """Return the 4x4 "OpenGL conversion" matrix, which flips the y and z axes.

        Returns:
            numpy.ndarray: ``diag(1, -1, -1, 1)`` with dtype float64.
        """
        return np.diag([1.0, -1.0, -1.0, 1.0])
    def fused_current_obs(self, images, memory, memory_labels):
        """Update the per-object memory with the objects detected in ``images``.

        Runs ``forward_geometry``, ``forward_2d_detection`` and ``forward_perception``,
        crops per-detection features, and splits them into instances. For each instance
        it builds one vector: ``memory_mlp`` of its mean feature plus ``spatial_pos_mlp``
        of the mean 3D-box embedding for that label (zero when no box was lifted to 3D).

        Merging rules:
        - an instance already in ``memory_labels`` is blended as
          ``alpha * new + (1 - alpha) * old``, with alpha = cosine similarity clamped to [0.1, 0.9];
        - new instances are appended;
        - stored instances not seen in this frame are kept unchanged.

        Args:
            images: current observation images, handed unchanged to the three heads.
            memory: tensor with one row per entry of ``memory_labels``.
            memory_labels: labels aligned with the rows of ``memory``.

        Returns:
            ``(memory, memory_labels)``: the updated tensor and a list of labels (current
            instances first, then the untouched stored ones). The inputs are returned
            unchanged when nothing is detected.
        """
        geo_prediction, cam_transformation = self.forward_geometry(images)
        bbox_2d_results = self.forward_2d_detection(images)
        feature_imgs = self.forward_perception(images)
        cropped_features, cropped_classes = self.crop_feature_from_bbox(feature_imgs, bbox_2d_results, images)
        if len(cropped_features) == 0:
            return memory, memory_labels
        instance_labels = self.get_most_semanctic_instance(cropped_classes, cropped_features, bbox_2d_results)
        latest_memory, latest_labels = self.get_latest_object_semantics(cropped_features, instance_labels)
        latest_corners, latest_box_labels = self.get_latest_object_3d_position(
            instance_labels, geo_prediction, bbox_2d_results, cam_transformation
        )
        box_pos_embedding = self.embed_bbox_3d(latest_corners)
        embed_dim = box_pos_embedding.shape[-1]
        box_labels = list(latest_box_labels)
        grouped = []
        for label in latest_labels:
            mask = torch.tensor(
                [bool(box_label == label) for box_label in box_labels],
                dtype=torch.bool,
                device=box_pos_embedding.device,
            )
            if mask.any():
                grouped.append(box_pos_embedding[mask].mean(dim=0))
            else:
                # Sized from the embedding width: row 0 does not exist when no box
                # could be lifted to 3D in this frame.
                grouped.append(box_pos_embedding.new_zeros(embed_dim))
        spatial_latest_cls = self.spatial_pos_mlp(torch.stack(grouped, dim=0))
        process_latest_memory = self.memory_mlp(latest_memory) + spatial_latest_cls

        # label -> first row index in `memory` (same match as list.index), built once.
        memory_index = {}
        for idx, label in enumerate(memory_labels):
            memory_index.setdefault(label, idx)
        latest_label_set = set(latest_labels)

        updated_memory = []
        updated_labels = []
        for i, label in enumerate(latest_labels):
            new_vec = process_latest_memory[i]
            idx = memory_index.get(label)
            if idx is not None:
                mem_vec = memory[idx]
                sim = torch.nn.functional.cosine_similarity(mem_vec.unsqueeze(0), new_vec.unsqueeze(0)).item()
                alpha = min(max(sim, 0.1), 0.9)
                new_vec = alpha * new_vec + (1 - alpha) * mem_vec
            updated_memory.append(new_vec)
            updated_labels.append(label)
        for i, label in enumerate(memory_labels):
            if label not in latest_label_set:
                updated_memory.append(memory[i])
                updated_labels.append(label)
        return torch.stack(updated_memory), updated_labels
    def get_most_semanctic_instance(self, cropped_classes, cropped_features, bbox_pred):
        """Split the crops of each detected class into instances by feature similarity.

        Crops are grouped by class id. Within a class with several crops, the pairwise
        ``calculate_chamfer_distance`` matrix of their feature sets is clustered with
        HDBSCAN (``metric='precomputed'``, ``min_cluster_size=2``,
        ``allow_single_cluster=True``).

        The class name comes from ``bbox_pred[0].names``, falling back to ``"class_<id>"``.
        Labelling rules:
        - a single crop, or cluster ids all within {0, -1}, gives the bare class name;
        - otherwise ``"<name>_<cluster_id>"``, with HDBSCAN noise (-1) as ``"<name>_noise"``.

        Returns:
            list: one label per entry of ``cropped_features``, in the same order.
        """
        # class id -> list of (crop index, feature set), in first-seen order
        members_by_class = defaultdict(list)
        for crop_idx, feats in enumerate(cropped_features):
            members_by_class[cropped_classes[crop_idx]].append((crop_idx, feats))

        labels = [None] * len(cropped_features)
        for cls_id, members in members_by_class.items():
            class_name = bbox_pred[0].names.get(cls_id, f"class_{cls_id}")
            count = len(members)
            if count == 1:
                labels[members[0][0]] = f"{class_name}"
                continue
            # Symmetric distance matrix; each unordered pair is computed once.
            dist_matrix = np.zeros((count, count))
            for a in range(count):
                for b in range(a + 1, count):
                    dist_matrix[a, b] = dist_matrix[b, a] = self.calculate_chamfer_distance(
                        members[a][1], members[b][1]
                    )
            clusterer = hdbscan.HDBSCAN(
                min_cluster_size=2,
                metric='precomputed',  # dist_matrix already holds the distances
                allow_single_cluster=True
            )
            cluster_ids = clusterer.fit_predict(dist_matrix)
            single_instance = set(cluster_ids.tolist()) <= {0, -1}
            for (crop_idx, _), cluster_id in zip(members, cluster_ids):
                if single_instance:
                    labels[crop_idx] = f"{class_name}"
                elif cluster_id == -1:
                    labels[crop_idx] = f"{class_name}_noise"
                else:
                    labels[crop_idx] = f"{class_name}_{cluster_id}"
        return labels
    def get_latest_object_3d_position(self, final_instance_labels, geo_pred, bbox_pred, cam_transformation,
                                      conf_threshold=0.2):
        """Lift this frame's confident 2D detections to labelled 3D box corners.

        Detections are filtered with ``conf > conf_threshold``, the same filter that
        ``crop_feature_from_bbox`` applies. This keeps ``final_instance_labels`` (one
        label per kept crop, image by image) in step with the boxes. Each image's kept
        boxes go through ``batched_bbox_to_3d_aabb_robust`` with
        ``geo_pred['world_points'][0][img_idx]`` and ``cam_transformation``. Boxes
        without foreground points are dropped together with their labels.

        Args:
            final_instance_labels: instance labels aligned with the kept detections.
            geo_pred: geometry predictions containing ``'world_points'``.
            bbox_pred: per-image detection results (``boxes.xyxy``, ``boxes.conf``).
            cam_transformation: (4, 4) transform applied to the box corners.
            conf_threshold: confidence threshold; must match the one used for
                ``crop_feature_from_bbox`` (both default to 0.2).

        Returns:
            ``(corners, labels)``: a float32 np.ndarray of shape (M, 8, 3) and an
            np.ndarray of M labels, or two empty arrays when no box could be lifted.
        """
        cam_transformation_tensor = torch.as_tensor(cam_transformation).float().to(self.device)
        final_corners_list = []
        final_labels_list = []
        label_cursor = 0
        for img_idx, result in enumerate(bbox_pred):
            boxes = result.boxes
            if boxes is None or len(boxes) == 0:
                continue
            keep = boxes.conf > conf_threshold
            num_kept = int(keep.sum().item())
            labels_for_img = final_instance_labels[label_cursor:label_cursor + num_kept]
            label_cursor += num_kept
            if num_kept == 0:
                continue
            # NOTE(review): the boxes are in the detector's image pixels, while
            # world_points is at the geometry head's input resolution
            # (preprocess_images resizes to 518 px wide by default). Unless the raw
            # images already have that size, the boxes should be rescaled first — confirm.
            lifted = self.batched_bbox_to_3d_aabb_robust(
                bboxes=boxes.xyxy[keep],
                world_points=geo_pred['world_points'][0][img_idx],
                initial_transformation=cam_transformation_tensor,
                z_percentile=0.25,
                z_margin=0.2
            )
            if lifted is None:
                continue
            corners_3d_batch, valid_mask = lifted
            # One host transfer instead of a device sync per label.
            for label, is_valid in zip(labels_for_img, valid_mask.tolist()):
                if is_valid:
                    final_labels_list.append(label)
            final_corners_list.append(corners_3d_batch.float().cpu().numpy())
        if not final_corners_list:
            return np.array([]), np.array([])
        return np.concatenate(final_corners_list, axis=0), np.array(final_labels_list)
    def get_latest_object_semantics(self, cropped_features, final_instance_labels):
        """Average the crop features of each instance label.

        Args:
            cropped_features: 2-D feature tensors (points x channels), one per crop.
            final_instance_labels: instance label of each crop, in the same order.

        Returns:
            ``(fused_matrix, fused_labels)``: a stacked tensor with one mean feature row
            per distinct label, and those labels in order of first appearance.
        """
        points_by_label = defaultdict(list)
        for feat, label in zip(cropped_features, final_instance_labels):
            points_by_label[label].append(feat.reshape(-1, feat.shape[1]))
        fused_labels = list(points_by_label)
        fused_matrix = torch.stack(
            [torch.cat(points_by_label[label], dim=0).mean(dim=0) for label in fused_labels]
        )
        return fused_matrix, fused_labels
    def bbox_to_3d_aabb_robust(self, bbox, world_points, initial_transformation,
                               z_percentile=0.25, z_margin=0.2, min_points=20):
        """Get the 8 corners of a 3D axis-aligned box from the points inside a 2D box (NumPy).

        Args:
            bbox: ``[x1, y1, x2, y2]`` in pixel coordinates (truncated to int and
                clipped to the point map).
            world_points: (H, W, 3) point map of the current image.
            initial_transformation: (4, 4) np.ndarray applied to the corners.
            z_percentile: percentile (as a fraction) of the valid depths taken as the
                object's main-body depth; 0.25 takes the nearest quarter.
            z_margin: points farther than this from the main-body depth (same units as
                ``world_points``) are treated as background.
            min_points: minimum number of valid points, and of foreground points,
                required to produce a box.

        Returns:
            (8, 3) np.ndarray of corners after the transform, or ``None`` if the box is
            empty or has too few valid/foreground points.
        """
        x1, y1, x2, y2 = map(int, bbox)
        h, w = world_points.shape[:2]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        if y1 >= y2 or x1 >= x2:
            return None
        points_list = world_points[y1:y2, x1:x2].reshape(-1, 3)
        # Points that are exactly (0, 0, 0) are treated as missing geometry.
        valid_points = points_list[np.any(points_list != 0, axis=1)]
        if valid_points.shape[0] < min_points:
            return None
        main_body_z = np.percentile(valid_points[:, 2], z_percentile * 100)
        foreground_points = valid_points[np.abs(valid_points[:, 2] - main_body_z) < z_margin]
        if foreground_points.shape[0] < min_points:
            return None
        x_min, y_min, z_min = np.min(foreground_points, axis=0)
        x_max, y_max, z_max = np.max(foreground_points, axis=0)
        # Bottom (z_min) face first, then the top (z_max) face, same xy order on both.
        corners_local = np.array([
            [x_min, y_min, z_min],
            [x_max, y_min, z_min],
            [x_max, y_max, z_min],
            [x_min, y_max, z_min],
            [x_min, y_min, z_max],
            [x_max, y_min, z_max],
            [x_max, y_max, z_max],
            [x_min, y_max, z_max]
        ])
        corners_local_homo = np.concatenate([corners_local, np.ones((8, 1))], axis=1)
        transformed_corners_homo = corners_local_homo @ initial_transformation.T
        return transformed_corners_homo[:, :3]
    def calculate_chamfer_distance(self, point_cloud1, point_cloud2):
        """Return the symmetric Chamfer distance between two point sets as a Python float.

        The value is the sum of two terms: the mean nearest-neighbour Euclidean distance
        from each point of ``point_cloud1`` to ``point_cloud2``, and the same in reverse.

        Args:
            point_cloud1: torch.Tensor of shape (N, D) (here crops of (200, 384)).
            point_cloud2: torch.Tensor of shape (M, D).

        Returns:
            float: the Chamfer distance.
        """
        # One (N, M) matrix serves both directions. Row minima give cloud1 -> cloud2;
        # column minima give cloud2 -> cloud1.
        dists = torch.cdist(point_cloud1, point_cloud2)
        forward = dists.min(dim=1).values.mean()
        backward = dists.min(dim=0).values.mean()
        return (forward + backward).item()
    def crop_feature_from_bbox(self, feature_map, pred_2d, images, conf_threshold=0.2, num_points=200):
        """Crop each confident detection out of its image's feature map as a fixed-size point set.

        For image ``idx``, boxes of ``pred_2d[idx].boxes`` with ``conf > conf_threshold``
        are mapped from image pixels to feature-map cells. Each (C, h, w) crop is
        flattened to h*w locations and resampled to exactly ``num_points``:
        - evenly spaced when enough locations exist;
        - otherwise padded with Gaussian samples using the crop's per-channel mean/std.

        Args:
            feature_map: per-image feature maps indexed as (C, H_feat, W_feat).
            pred_2d: per-image detection results exposing ``boxes.xyxy``, ``boxes.cls``
                and ``boxes.conf``.
            images: per-image objects whose ``.size`` is ``(width, height)`` (PIL convention).
            conf_threshold: minimum detection confidence (exclusive).
            num_points: number of feature vectors returned per crop.

        Returns:
            ``(cropped_features, cropped_classes)``: a list of (num_points, C) tensors and
            the matching list of int class ids, ordered by image, then by box.
        """
        cropped_features = []
        cropped_classes = []
        for idx, fmap in enumerate(feature_map):
            boxes = pred_2d[idx].boxes
            keep = boxes.conf.cpu().numpy() > conf_threshold
            bboxes = boxes.xyxy.cpu().numpy()[keep]
            classes = boxes.cls.cpu().numpy()[keep]
            H_feat, W_feat = fmap.shape[1], fmap.shape[2]
            # PIL's Image.size is (width, height).
            W_img, H_img = images[idx].size
            scale_h = H_feat / H_img
            scale_w = W_feat / W_img
            for (x1, y1, x2, y2), cls in zip(bboxes, classes):
                # Clamp into the map and keep at least one cell per axis, so tiny boxes
                # never produce an empty crop (whose mean/std would be NaN).
                x1_f = min(max(int(x1 * scale_w), 0), W_feat - 1)
                y1_f = min(max(int(y1 * scale_h), 0), H_feat - 1)
                x2_f = min(max(int(x2 * scale_w), x1_f + 1), W_feat)
                y2_f = min(max(int(y2 * scale_h), y1_f + 1), H_feat)
                crop_flat = fmap[:, y1_f:y2_f, x1_f:x2_f].reshape(fmap.shape[0], -1)
                n_available = crop_flat.shape[1]
                if n_available >= num_points:
                    idxs = np.linspace(0, n_available - 1, num_points, dtype=int)
                    crop_sampled = crop_flat[:, idxs]
                else:
                    mean = crop_flat.mean(dim=1, keepdim=True)
                    # The unbiased std of a single location is NaN; use 0 noise instead.
                    std = torch.nan_to_num(crop_flat.std(dim=1, keepdim=True), nan=0.0)
                    noise = torch.randn(
                        (crop_flat.shape[0], num_points - n_available),
                        device=mean.device,
                        dtype=mean.dtype,
                    ) * std + mean
                    crop_sampled = torch.cat([crop_flat, noise], dim=1)
                cropped_features.append(crop_sampled.permute(1, 0))
                cropped_classes.append(int(cls))
        return cropped_features, cropped_classes
    def forward_perception(self, images):
        """Extract one dense feature map per image with ``self.perception_head``.

        Steps:
        - resize with ``resize_transform``;
        - normalize with mean (0.485, 0.456, 0.406) / std (0.229, 0.224, 0.225);
        - call ``perception_head.get_intermediate_layers`` with ``n=range(n_layers)``,
          ``reshape=True`` and ``norm=True``.
        The last returned layer is used.

        Args:
            images: batch of images; ``np.array(images)`` must yield per-image arrays
                accepted by ``PIL.Image.fromarray``.

        Returns:
            torch.Tensor: the last intermediate layer, presumably (B, C, H', W') feature
            maps (TODO confirm against the perception head). It is a regular
            (non-inference) tensor, safe to feed into autograd-tracked modules.

        Raises:
            AttributeError: if neither ``n_layers`` nor ``n_layer`` has been set.
        """
        # __init__ initializes ``n_layer`` (singular); accept either spelling.
        n_layers = getattr(self, "n_layers", None)
        if n_layers is None:
            n_layers = getattr(self, "n_layer", None)
        if n_layers is None:
            raise AttributeError("SOMA.n_layers (or n_layer) must be set before forward_perception")
        image_resized = self.resize_transform(np.array(images))
        image_resized_norm = TF.normalize(image_resized, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        # Send inputs where the head's weights live rather than to the default CUDA device.
        head_param = next(self.perception_head.parameters(), None)
        device = head_param.device if head_param is not None else self.device
        with torch.inference_mode():
            # NOTE(review): CUDA autocast is documented for float16/bfloat16; with
            # dtype=float32 it may be disabled with a warning — confirm intent.
            with torch.autocast(device_type='cuda', dtype=torch.float32):
                feats = self.perception_head.get_intermediate_layers(
                    image_resized_norm.to(device),
                    n=range(n_layers),
                    reshape=True,
                    norm=True
                )
        # Tensors created under inference_mode cannot be saved for backward. Cloning
        # outside the context yields a normal tensor that memory_mlp can consume with
        # autograd enabled.
        return feats[-1].clone()
    def resize_transform(
        self,
        mask_images: np.ndarray,
        image_size: int = 768,
        patch_size: int = 16,
    ) -> torch.Tensor:
        """Resize every image to a patch-aligned size and stack the results.

        The target height is ``int(image_size / patch_size) * patch_size``. The target
        width keeps the aspect ratio, rounded down to a multiple of ``patch_size``.

        Args:
            mask_images: array whose items are (H, W, C) color images or (H, W)
                grayscale images (the latter opened with mode 'L').
            image_size: nominal output height in pixels.
            patch_size: both output sides are multiples of this.

        Returns:
            torch.Tensor: ``TF.to_tensor`` outputs stacked along a new batch axis.
        """
        target_h = int(image_size / patch_size) * patch_size

        def _resize_one(img):
            pil_img = Image.fromarray(img) if img.ndim == 3 else Image.fromarray(img, mode='L')
            width, height = pil_img.size
            target_w = int((width * image_size) / (height * patch_size)) * patch_size
            return TF.to_tensor(TF.resize(pil_img, (target_h, target_w)))

        return torch.stack([_resize_one(mask_images[i]) for i in range(mask_images.shape[0])])
    def forward_2d_detection(self, images):
        """Run the 2D detector ``self.det_head`` on ``images`` and return its raw results."""
        return self.det_head(images)
    def forward_geometry(self, images):
        """Run ``self.geometry_head`` on ``images`` and derive a reference camera transform.

        The images are preprocessed with ``preprocess_images`` (default settings) and
        moved to ``self.device``. They are then run without gradients under CUDA
        autocast (bfloat16 on compute capability >= 8, float16 otherwise) and passed
        to ``postprocess_camera_transform``.

        Returns:
            ``(predictions, initial_transformation)`` from ``postprocess_camera_transform``.
        """
        batch = self.preprocess_images(images).to(self.device)
        major_capability = torch.cuda.get_device_capability()[0]
        amp_dtype = torch.bfloat16 if major_capability >= 8 else torch.float16
        with torch.no_grad(), torch.amp.autocast('cuda', dtype=amp_dtype):
            predictions = self.geometry_head(batch)
        return self.postprocess_camera_transform(predictions, batch)
    def postprocess_camera_transform(self, input, images):
        """Decode camera poses and pick a reference transform for the scene.

        Adds ``"extrinsic"`` and ``"intrinsic"`` (from ``pose_encoding_to_extri_intri``)
        to ``input`` in place. The first batch element's 3x4 extrinsics are padded to
        4x4. The camera whose translation column is closest to the mean translation is
        selected, and ``inv(extrinsic) @ self.get_opengl_conversion_matrix()`` is returned.

        Args:
            input: prediction dict containing ``"pose_enc"``.
            images: preprocessed image tensor; its last two dims are passed to the
                pose decoder as the image size.

        Returns:
            ``(input, initial_transformation)``, where ``initial_transformation`` is a
            (4, 4) float64 np.ndarray.
        """
        extrinsic, intrinsic = pose_encoding_to_extri_intri(input["pose_enc"], images.shape[-2:])
        input["extrinsic"] = extrinsic
        input["intrinsic"] = intrinsic
        # .float() first: Tensor.numpy() does not support bfloat16.
        camera_matrices = extrinsic[0].detach().float().cpu().numpy()
        num_cameras = camera_matrices.shape[0]
        extrinsics_matrices = np.zeros((num_cameras, 4, 4))
        extrinsics_matrices[:, :3, :4] = camera_matrices
        extrinsics_matrices[:, 3, 3] = 1
        translations = extrinsics_matrices[:, :3, 3]
        dists = np.linalg.norm(translations - translations.mean(axis=0), axis=1)
        middle_extrinsic = extrinsics_matrices[np.argmin(dists)]
        # Reuse the class-level helper instead of a duplicated nested definition.
        initial_transformation = np.linalg.inv(middle_extrinsic) @ self.get_opengl_conversion_matrix()
        return input, initial_transformation
    def preprocess_images(
        self,
        image_list: list[Image.Image],
        mode: str = "crop",
        target_size: int = 518,
        patch_size: int = 14,
    ) -> torch.Tensor:
        """
        Load and preprocess a batch of PIL images into one model-ready tensor.

        Args:
            image_list (list): A list of ``PIL.Image`` objects.
            mode (str, optional): Preprocessing mode, ``"crop"`` or ``"pad"``.
                - ``"crop"`` (default): resize so the width equals ``target_size``
                  (aspect ratio preserved), then center-crop the height to
                  ``target_size`` if it ends up taller than that.
                - ``"pad"``: resize so the longer side equals ``target_size``
                  (aspect ratio preserved), then pad with white (1.0) towards a
                  ``target_size`` x ``target_size`` square.
            target_size (int, optional): Target side length in pixels. Defaults to 518.
            patch_size (int, optional): Patch size required by the model. Resized
                sides are rounded to a multiple of it, and are never smaller than
                one patch. Defaults to 14.

        Returns:
            torch.Tensor: Batched images of shape (N, 3, H, W). When processed
            images differ in size, each is padded with white (1.0), centered, to
            the largest H and W in the batch.

        Raises:
            ValueError: If ``image_list`` is empty or ``mode`` is invalid.
        """
        if not image_list:
            raise ValueError("输入图像列表不能为空 (image_list cannot be empty)")
        if mode not in ["crop", "pad"]:
            raise ValueError("模式必须是 'crop' 或 'pad' (Mode must be either 'crop' or 'pad')")

        def snap_to_patch(length):
            # Round to the nearest multiple of patch_size, but keep at least one patch.
            return max(round(length / patch_size) * patch_size, patch_size)

        def centered_padding(pad_h, pad_w):
            # TF.pad order is [left, top, right, bottom]; odd remainders go right/bottom.
            return [pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2]

        tensors = []
        for picture in image_list:
            if picture.mode == "RGBA":
                # Composite onto a white background using the alpha channel as mask.
                canvas = Image.new("RGB", picture.size, (255, 255, 255))
                canvas.paste(picture, mask=picture.split()[3])
                picture = canvas
            elif picture.mode != "RGB":
                picture = picture.convert("RGB")

            src_w, src_h = picture.size
            if mode == "pad" and src_w < src_h:
                # Portrait image in pad mode: the height is the long side.
                out_h = target_size
                out_w = round(src_w * (out_h / src_h))
            else:
                # Crop mode, or a landscape/square image in pad mode: fix the width.
                out_w = target_size
                out_h = round(src_h * (out_w / src_w))
            out_w = snap_to_patch(out_w)
            out_h = snap_to_patch(out_h)

            tensor = TF.to_tensor(
                TF.resize(picture, (out_h, out_w), interpolation=TF.InterpolationMode.BICUBIC)
            )
            if mode == "pad":
                tensor = TF.pad(
                    tensor,
                    centered_padding(target_size - tensor.shape[1], target_size - tensor.shape[2]),
                    fill=1.0,
                )
            elif tensor.shape[1] > target_size:
                # Crop mode: trim the excess height, keep the full width.
                tensor = TF.center_crop(tensor, (target_size, tensor.shape[2]))
            tensors.append(tensor)

        # Bring every image to a common size so they can be stacked.
        tallest = max(t.shape[1] for t in tensors)
        widest = max(t.shape[2] for t in tensors)
        batch = []
        for tensor in tensors:
            pad_h = tallest - tensor.shape[1]
            pad_w = widest - tensor.shape[2]
            if pad_h > 0 or pad_w > 0:
                tensor = TF.pad(tensor, centered_padding(pad_h, pad_w), fill=1.0)
            batch.append(tensor)
        return torch.stack(batch)
    def boost_backbone_outputs_spatial(self, backbone_outputs, inputs):
        """
        Refine the backbone tokens by jointly processing them with per-sample memories.

        The variable-length memories in ``inputs['overview_memory']`` are steps 1 to 4:
        padded to a common length, projected with ``self.memory_alignment_mlp``,
        concatenated after the backbone tokens along dim 1, and passed through
        ``self.boosting_ln``. Step 5 feeds the result to ``self.boosting_transformer``
        together with a boolean mask that is True for every backbone token and every
        real memory row, and False for padding rows. Only the outputs at the
        backbone-token positions are kept.

        NOTE(review): the mask uses True = valid. Confirm that ``boosting_transformer``
        expects this polarity. For example, ``nn.TransformerEncoder``'s
        ``src_key_padding_mask`` uses True = ignore.

        Args:
            backbone_outputs: Mapping with key ``'backbone_features'``. The original
                comment describes it as (B, T, D).
            inputs: Mapping with key ``'overview_memory'``. This is a list with one
                (M_i, D) tensor per batch element, per the original comment, and its
                length must equal B for the concatenation to succeed.

        Returns:
            The same ``backbone_outputs`` mapping, mutated in place so that
            ``'backbone_features'`` holds the boosted tokens.
        """
        tokens = backbone_outputs['backbone_features']  # (B, T, D)
        memory_bank = inputs['overview_memory']  # list of (M_i, D)

        padded_memory = pad_sequence(memory_bank, batch_first=True)
        mem_device = padded_memory.device
        mem_lengths = torch.tensor([entry.size(0) for entry in memory_bank], device=mem_device)
        slot_ids = torch.arange(padded_memory.size(1), device=mem_device)
        # memory_valid[b, j] is True iff slot j is a real (non-padding) row of sample b.
        memory_valid = slot_ids.unsqueeze(0) < mem_lengths.unsqueeze(1)

        projected_memory = self.memory_alignment_mlp(padded_memory)
        joint_tokens = torch.cat([tokens, projected_memory], dim=1)

        num_samples, num_tokens = tokens.size(0), tokens.size(1)
        token_valid = torch.ones(num_samples, num_tokens, dtype=torch.bool, device=tokens.device)
        joint_valid = torch.cat([token_valid, memory_valid.to(joint_tokens.device)], dim=1)

        joint_tokens = self.boosting_ln(joint_tokens)
        refined = self.boosting_transformer(joint_tokens, joint_valid)
        # Keep only the positions that correspond to the original backbone tokens.
        backbone_outputs['backbone_features'] = refined[:, :num_tokens, :]
        return backbone_outputs
    def forward(
        self,
        inputs: dict,
    ) -> BatchFeature:
        """
        Run the training forward pass.

        Filters ``inputs`` with ``filtering_keys``, prepares the backbone and
        action-head inputs, and runs the backbone and then the action head. The
        head outputs are validated against the backbone outputs with
        ``is_training=True`` before being returned.
        """
        filtered = self.filtering_keys(inputs)
        backbone_in, action_in = self.prepare_input(filtered)
        backbone_out = self.backbone(backbone_in)
        head_out = self.action_head(backbone_out, action_in)
        self.validate_data(head_out, backbone_out, is_training=True)
        return head_out
    def get_action(
        self,
        inputs: dict,
    ) -> BatchFeature:
        """
        Predict actions for ``inputs`` (inference path).

        Unlike ``forward``, this method:
        - passes ``inputs`` to ``prepare_input`` without calling ``filtering_keys``;
        - calls ``self.action_head.get_action`` instead of calling the head directly;
        - validates with ``is_training=False``.
        """
        prepared_backbone, prepared_action = self.prepare_input(inputs)
        backbone_out = self.backbone(prepared_backbone)
        prediction = self.action_head.get_action(backbone_out, prepared_action)
        self.validate_data(prediction, backbone_out, is_training=False)
        return prediction
    def prepare_input(self, inputs) -> Tuple[BatchFeature, BatchFeature]:
        """
        Validate ``inputs`` and build the backbone and action-head inputs on the model's device.

        Every leaf of both structures is processed as follows:
        - an ``np.ndarray`` is converted to a tensor with ``torch.from_numpy``;
        - the leaf is moved to ``self.device``;
        - floating-point leaves are also cast to ``self.action_head.dtype``. This
          dtype is used for the backbone inputs as well.

        Returns:
            tuple: ``(backbone_inputs, action_inputs)``.
        """
        self.validate_inputs(inputs)
        backbone_inputs = self.backbone.prepare_input(inputs)
        action_inputs = self.action_head.prepare_input(inputs)

        def _move_leaf(leaf):
            # NOTE(review): leaves that are neither ndarrays nor tensors (e.g. str,
            # int, None) would make torch.is_floating_point raise TypeError.
            tensor = torch.from_numpy(leaf) if isinstance(leaf, np.ndarray) else leaf
            if not torch.is_floating_point(tensor):
                return tensor.to(self.device)
            return tensor.to(self.device, dtype=self.action_head.dtype)

        return (
            tree.map_structure(_move_leaf, backbone_inputs),
            tree.map_structure(_move_leaf, action_inputs),
        )
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load the pretrained policy and attach the frozen auxiliary perception heads.

        Loading steps:
        1. Resolve ``pretrained_model_name_or_path`` with ``snapshot_download``. If
           the Hub rejects the name or the repository is missing, fall back to
           treating it as a local path.
        2. Load the model through the base ``from_pretrained``.
        3. Set which backbone and action-head parts are trainable.
        4. Attach three auxiliary heads:
           - ``perception_head``: a DINOv3 model from a local ``torch.hub`` repo;
           - ``det_head``: an ultralytics ``YOLO`` detector;
           - ``geometry_head``: a ``VGGT`` model.
           All three are put in eval mode with ``requires_grad_(False)``.
        5. Call ``init_weights()``.

        Keyword Args (popped from ``kwargs`` before the base-class load):
            tune_visual (bool): Train the backbone vision tower. Default True.
            tune_llm (bool): Train the backbone LLM. Default False.
            tune_projector (bool): Train the action-head projector. Default True.
            tune_diffusion_model (bool): Train the action-head DiT. Default True.
            dinov3_repo_dir (str): Local DINOv3 ``torch.hub`` repo directory.
                Defaults to ``/workspaces/Jeff/Isaac-GR00T/dinov3``.
            dinov3_model_name (str): DINOv3 hub entry point. Must be one of the
                supported variants listed below. Default ``"dinov3_vits16plus"``.
            det_weights_path (str): YOLO weights file. Defaults to
                ``/workspaces/Jeff/Isaac-GR00T/yolov8x-worldv2.pt``.
            vggt_checkpoint_path (str): VGGT checkpoint directory. Defaults to
                ``/workspaces/Jeff/Isaac-GR00T/checkpoint/VGGT-1B``.

        All remaining ``kwargs`` are forwarded to the base ``from_pretrained``.

        Raises:
            ValueError: If ``dinov3_model_name`` is not a supported DINOv3 variant.
        """
        tune_visual = kwargs.pop("tune_visual", True)
        tune_llm = kwargs.pop("tune_llm", False)
        tune_projector = kwargs.pop("tune_projector", True)
        tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)

        # Auxiliary-head locations. The defaults are the previously hard-coded,
        # machine-specific values. They are popped so they never reach the base loader.
        dinov3_repo_dir = kwargs.pop("dinov3_repo_dir", "/workspaces/Jeff/Isaac-GR00T/dinov3")
        dinov3_model_name = kwargs.pop("dinov3_model_name", "dinov3_vits16plus")
        det_weights_path = kwargs.pop(
            "det_weights_path", "/workspaces/Jeff/Isaac-GR00T/yolov8x-worldv2.pt"
        )
        vggt_checkpoint_path = kwargs.pop(
            "vggt_checkpoint_path", "/workspaces/Jeff/Isaac-GR00T/checkpoint/VGGT-1B"
        )

        # Number of layers recorded in ``n_layers`` for each supported DINOv3 variant.
        dinov3_num_layers = {
            "dinov3_vits16": 12,
            "dinov3_vits16plus": 12,
            "dinov3_vitb16": 12,
            "dinov3_vitl16": 24,
            "dinov3_vith16plus": 32,
            "dinov3_vit7b16": 40,
        }
        # Fail fast, before any download or weight loading happens.
        if dinov3_model_name not in dinov3_num_layers:
            raise ValueError(
                f"Unknown DINOv3 model {dinov3_model_name!r}; "
                f"expected one of {sorted(dinov3_num_layers)}"
            )

        print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
        print(f"Tune backbone vision tower: {tune_visual}")
        print(f"Tune backbone LLM: {tune_llm}")
        print(f"Tune action head projector: {tune_projector}")
        print(f"Tune action head DiT: {tune_diffusion_model}")
        try:
            local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
        except (HFValidationError, RepositoryNotFoundError):
            print(
                f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
            )
            local_model_path = pretrained_model_name_or_path
        pretrained_model = super().from_pretrained(
            local_model_path, local_model_path=local_model_path, **kwargs
        )
        pretrained_model.backbone.set_trainable_parameters(
            tune_visual=tune_visual, tune_llm=tune_llm
        )
        pretrained_model.action_head.set_trainable_parameters(
            tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
        )

        # Auxiliary perception heads. All of them are frozen below.
        pretrained_model.n_layers = dinov3_num_layers[dinov3_model_name]
        pretrained_model.perception_head = torch.hub.load(
            dinov3_repo_dir, dinov3_model_name, source='local'
        )
        pretrained_model.det_head = YOLO(det_weights_path)
        pretrained_model.det_head.overrides['verbose'] = False
        pretrained_model.det_head.overrides['show'] = False

        # Shadow the detector's ``train`` so that any call to it, such as a parent
        # ``train()`` recursing into children, switches it to eval mode instead.
        # NOTE(review): this is presumably needed because ultralytics' ``Model.train``
        # starts a training run rather than toggling the mode. Confirm against the
        # pinned ultralytics version, and confirm its ``eval()`` does not route back
        # through ``self.train``, which would recurse.
        def frozen_train(*_args, **_kwargs):
            return pretrained_model.det_head.eval()

        pretrained_model.det_head.train = frozen_train
        pretrained_model.geometry_head = VGGT.from_pretrained(vggt_checkpoint_path)
        pretrained_model.geometry_head.eval()
        pretrained_model.geometry_head.requires_grad_(False)
        pretrained_model.det_head.eval()
        pretrained_model.det_head.model.requires_grad_(False)
        pretrained_model.perception_head.eval()
        pretrained_model.perception_head.requires_grad_(False)
        # NOTE(review): init_weights() runs after the pretrained weights and the
        # auxiliary heads are in place. Confirm it does not re-initialize them.
        pretrained_model.init_weights()
        return pretrained_model
# Register the config/model pair with the transformers Auto* factories under the
# config's model type. The type is read from the config class so it cannot drift
# from ``GR00T_N1_5_Config.model_type``.
AutoConfig.register(GR00T_N1_5_Config.model_type, GR00T_N1_5_Config)
AutoModel.register(GR00T_N1_5_Config, SOMA)
