# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

import clip
import time
import torch
import pprint
import torchvision
import numpy as np
import torch.nn as nn
import bitsandbytes as bnb
import torch.nn.functional as F
from scipy.spatial.transform import Rotation
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel.distributed import DistributedDataParallel

from yarr.agents.agent import ActResult
import peract_colab.arm.utils as arm_utils
from peract_colab.arm.optim.lamb import Lamb

import dual_stream.mvt.utils as mvt_utils
from dual_stream.utils.cortical_utils import _clip_encode_text
from dual_stream.utils.lr_sched_utils import GradualWarmupScheduler
from dual_stream.utils import rvt_utils, peract_utils, cortical_utils
from dual_stream.mvt.augmentation import apply_se3_aug_con, aug_utils
from dual_stream.utils.visual_utils import visualize_comparison # ,IMG_IDX
from dual_stream.utils.vggt_utils import (
    DEBUG,
    # timeit,
    sigmoid,
    get_vggt_feature_map,
    interpolate_features,
    align_features_with_kp,
    get_3d_keypoints_from_gt
) 

def eval_con(gt, pred):
    """
    Evaluate a continuous prediction by its mean Euclidean error.

    :param gt: tensor of shape (bs, d) with the groundtruth vectors
    :param pred: tensor of shape (bs, d) with the predicted vectors
    :return: dict with the single key "avg err", a 0-dim tensor holding the
        mean over the batch of the L2 distance between matching rows
    :raises AssertionError: if the shapes differ or the inputs are not 2D
    """
    # The message must be the string itself: wrapping it in print() would
    # attach None to the AssertionError and only write to stdout.
    assert gt.shape == pred.shape, f"{gt.shape} {pred.shape}"
    assert len(gt.shape) == 2
    dist = torch.linalg.vector_norm(gt - pred, dim=1)
    return {"avg err": dist.mean()}


def eval_con_cls(gt, pred, num_bin=72, res=5, symmetry=1):
    """
    Evaluate continuous classification where floating point values are put
    into discrete bins. The error is the circular distance between the two
    bin values, converted to degrees with `res`.

    :param gt: (bs,) or scalar; tensor, numpy value or python number
    :param pred: (bs,) or scalar; same shape as `gt`
    :param num_bin: int for the number of rotation bins
    :param res: float to specify the resolution of each rotation bin
    :param symmetry: degrees of symmetry; 2 is 180 degree symmetry, 4 is 90
        degree symmetry
    :return: dict with the single key "avg err", a 0-dim float tensor
    """
    # as_tensor accepts tensors, numpy values and python numbers alike and,
    # unlike torch.tensor(tensor), does not emit a copy-construct warning.
    gt = torch.as_tensor(gt)
    pred = torch.as_tensor(pred)
    assert gt.shape == pred.shape
    assert len(gt.shape) in [0, 1], gt
    assert num_bin % symmetry == 0, (num_bin, symmetry)

    num_bin //= symmetry
    # Out-of-place modulo: as_tensor may share memory with the caller's
    # data, so the inputs must not be modified in place.
    pred = pred % num_bin
    gt = gt % num_bin

    # circular distance: going around the other way may be shorter
    dist = torch.abs(pred - gt)
    dist = torch.min(dist, num_bin - dist)
    dist_con = dist.float() * res
    return {"avg err": dist_con.mean()}


def eval_cls(gt, pred):
    """
    Compute the misclassification rate of a batch of predicted labels.

    :param gt: (bs,) groundtruth class labels
    :param pred: (bs,) predicted class labels
    :return: dict with the single key "per err", the fraction of entries
        where `pred` differs from `gt` (0-dim float tensor)
    """
    assert gt.shape == pred.shape
    assert len(gt.shape) == 1
    mismatch = gt.ne(pred)
    error_rate = mismatch.to(torch.float32).mean()
    return {"per err": error_rate}


def eval_all(
    wpt,
    pred_wpt,
    action_rot,
    pred_rot_quat,
    action_grip_one_hot,
    grip_q,
    action_collision_one_hot,
    collision_q,
):
    """
    Compute per-sample evaluation errors for a batch of predictions.

    :param wpt: (bs, 3) groundtruth waypoint
    :param pred_wpt: (bs, 3) predicted waypoint
    :param action_rot: (bs, 4) groundtruth rotation as a quaternion
    :param pred_rot_quat: (bs, 4) predicted rotation as a quaternion
    :param action_grip_one_hot: (bs, 2) groundtruth gripper one-hot
    :param grip_q: (bs, 2) predicted gripper scores
    :param action_collision_one_hot: (bs, 2) groundtruth collision one-hot
    :param collision_q: (bs, 2) predicted collision scores
    :return: tuple of six lists, each of length bs and holding python floats:
        translation error, rotation error around x, y and z (degrees, from
        "xyz" euler angles), gripper error and collision error
    """
    bs = len(wpt)
    assert wpt.shape == (bs, 3), wpt
    assert pred_wpt.shape == (bs, 3), pred_wpt
    assert action_rot.shape == (bs, 4), action_rot
    assert pred_rot_quat.shape == (bs, 4), pred_rot_quat
    assert action_grip_one_hot.shape == (bs, 2), action_grip_one_hot
    assert grip_q.shape == (bs, 2), grip_q
    assert action_collision_one_hot.shape == (bs, 2), action_collision_one_hot
    assert collision_q.shape == (bs, 2), collision_q

    eval_trans = []
    eval_rot_x = []
    eval_rot_y = []
    eval_rot_z = []
    eval_grip = []
    eval_coll = []

    for i in range(bs):
        # keep the batch dimension so the helpers see (1, ...) inputs
        sample = slice(i, i + 1)

        eval_trans.append(
            eval_con(wpt[sample], pred_wpt[sample])["avg err"].item()
        )

        euler_gt = Rotation.from_quat(action_rot[i]).as_euler("xyz", degrees=True)
        euler_pred = Rotation.from_quat(pred_rot_quat[i]).as_euler("xyz", degrees=True)
        # one-degree bins over the full circle, evaluated per axis
        for axis, axis_errors in enumerate((eval_rot_x, eval_rot_y, eval_rot_z)):
            axis_errors.append(
                eval_con_cls(
                    euler_gt[axis], euler_pred[axis], num_bin=360, res=1
                )["avg err"].item()
            )

        eval_grip.append(
            eval_cls(
                action_grip_one_hot[sample].argmax(-1),
                grip_q[sample].argmax(-1),
            )["per err"].item()
        )

        # .item() here as well: every list holds python floats, not a mix of
        # floats and 0-d numpy arrays
        eval_coll.append(
            eval_cls(
                action_collision_one_hot[sample].argmax(-1),
                collision_q[sample].argmax(-1),
            )["per err"].item()
        )

    return eval_trans, eval_rot_x, eval_rot_y, eval_rot_z, eval_grip, eval_coll


def manage_eval_log(
    self,
    tasks,
    wpt,
    pred_wpt,
    action_rot,
    pred_rot_quat,
    action_grip_one_hot,
    grip_q,
    action_collision_one_hot,
    collision_q,
    reset_log=False,
):
    """
    Evaluate a batch with `eval_all` and accumulate the per-sample errors in
    per-task logs stored as attributes of `self`.

    The logs are the dictionaries `eval_trans`, `eval_rot_x`, `eval_rot_y`,
    `eval_rot_z`, `eval_grip` and `eval_coll`, each mapping a task to the
    list of errors seen so far. They are (re)created when missing or when
    `reset_log` is set.

    :param self: object that carries the logs
    :param tasks: iterable with the task of each sample in the batch
    :param reset_log: drop all previously accumulated errors first
    :return: dict with the translation and the three rotation error lists of
        the current batch (keys "eval_trans", "eval_rot_x", "eval_rot_y",
        "eval_rot_z")
    """
    bs = len(wpt)
    assert wpt.shape == (bs, 3), wpt
    assert pred_wpt.shape == (bs, 3), pred_wpt
    assert action_rot.shape == (bs, 4), action_rot
    assert pred_rot_quat.shape == (bs, 4), pred_rot_quat
    assert action_grip_one_hot.shape == (bs, 2), action_grip_one_hot
    assert grip_q.shape == (bs, 2), grip_q
    assert action_collision_one_hot.shape == (bs, 2), action_collision_one_hot
    assert collision_q.shape == (bs, 2), collision_q

    # same order as the tuple returned by eval_all
    log_names = (
        "eval_trans",
        "eval_rot_x",
        "eval_rot_y",
        "eval_rot_z",
        "eval_grip",
        "eval_coll",
    )

    if reset_log or not hasattr(self, "eval_trans"):
        for log_name in log_names:
            setattr(self, log_name, {})

    batch_errors = dict(
        zip(
            log_names,
            eval_all(
                wpt=wpt,
                pred_wpt=pred_wpt,
                action_rot=action_rot,
                pred_rot_quat=pred_rot_quat,
                action_grip_one_hot=action_grip_one_hot,
                grip_q=grip_q,
                action_collision_one_hot=action_collision_one_hot,
                collision_q=collision_q,
            ),
        )
    )

    for idx, task in enumerate(tasks):
        # a task is considered new when the translation log has not seen it
        if task not in self.eval_trans:
            for log_name in log_names:
                getattr(self, log_name)[task] = []
        for log_name in log_names:
            getattr(self, log_name)[task].append(batch_errors[log_name][idx])

    # only translation and rotation errors are handed back to the caller
    return {log_name: batch_errors[log_name] for log_name in log_names[:4]}


def print_eval_log(self):
    """
    Summarise the accumulated evaluation logs of `self` and pretty-print them.

    For every metric log and every task in it, the mean, standard deviation
    and median of the logged values are computed.

    :param self: object carrying the `eval_*` dictionaries
    :return: dict mapping "<task>/<metric>_mean", "<task>/<metric>_std" and
        "<task>/<metric>_median" to the corresponding statistic
    """
    metric_logs = (
        ("trans", self.eval_trans),
        ("rot_x", self.eval_rot_x),
        ("rot_y", self.eval_rot_y),
        ("rot_z", self.eval_rot_z),
        ("grip", self.eval_grip),
        ("coll", self.eval_coll),
    )

    out = {}
    for metric, per_task in metric_logs:
        for task, values in per_task.items():
            values_np = np.array(values)
            prefix = f"{task}/{metric}"
            out[f"{prefix}_mean"] = np.mean(values_np)
            out[f"{prefix}_std"] = np.std(values_np)
            out[f"{prefix}_median"] = np.median(values_np)

    pprint.pprint(out)

    return out


def manage_loss_log(
    agent,
    loss_log,
    reset_log,
):
    """
    Append the values of `loss_log` to the history kept in `agent.loss_log`.

    :param agent: object carrying the `loss_log` attribute (created on demand)
    :param loss_log: dict mapping a loss name to its value for this step
    :param reset_log: if true, the history is cleared before appending
    """
    if reset_log or not hasattr(agent, "loss_log"):
        agent.loss_log = {}

    history = agent.loss_log
    for key, val in loss_log.items():
        history.setdefault(key, []).append(val)


def print_loss_log(agent):
    """
    Pretty-print and return the mean of every loss history in
    `agent.loss_log`.

    :param agent: object carrying the `loss_log` dict of value lists
    :return: dict mapping each loss name to the mean of its logged values
    """
    out = {
        key: np.mean(np.array(history))
        for key, history in agent.loss_log.items()
    }
    pprint.pprint(out)
    return out


class CorticalAgent:
    def __init__(
        self,
        network: nn.Module,
        num_rotation_classes: int,
        stage_two: bool,
        add_lang: bool,
        amp: bool,
        bnb: bool,
        move_pc_in_bound: bool,
        lr: float = 0.0001,
        lr_cos_dec: bool = False,
        cos_dec_max_step: int = 60000,
        warmup_steps: int = 0,
        image_resolution: list = None,
        lambda_weight_l2: float = 0.0,
        transform_augmentation: bool = True,
        transform_augmentation_xyz: list = [0.1, 0.1, 0.1],
        transform_augmentation_rpy: list = [0.0, 0.0, 20.0],
        place_with_mean: bool = True,
        transform_augmentation_rot_resolution: int = 5,
        optimizer_type: str = "lamb",
        gt_hm_sigma: float = 1.5,
        img_aug: bool = False,
        add_rgc_loss: bool = False,
        scene_bounds: list = peract_utils.SCENE_BOUNDS,
        cameras: list = peract_utils.CAMERAS,
        rot_ver: int = 0,
        rot_x_y_aug: int = 2,
        log_dir="",
    ):
        """
        Store the network and all hyper-parameters; no optimizer is created
        here (see `build`).

        :param network: the network to train/evaluate; when it is wrapped in
            DistributedDataParallel the bare module is kept in `_net_mod`
        :param num_rotation_classes: number of discrete rotation classes per
            axis; the rotation resolution is 360 / num_rotation_classes
        :param amp: enables the gradient scaler used for mixed precision
        :param bnb: selects the 8-bit LAMB optimizer in `build`
        :param gt_hm_sigma: the std of the groundtruth hm, currently for
            2d, if -1 then only single point is considered
        :type gt_hm_sigma: float
        :param rot_ver: version of the rotation prediction network
            Either:
                0: same as peract, independent discrete xyz predictions
                1: xyz prediction dependent on one another
        :param rot_x_y_aug: only applicable when rot_ver is 1, it specifies how
            much error we should add to groundtruth rotation while training
        :param log_dir: a folder location for saving some intermediate data
        """

        # --- network handles -------------------------------------------------
        self._network = network
        # `_net_mod` always gives access to the attributes of the bare module
        self._net_mod = (
            network.module
            if isinstance(network, DistributedDataParallel)
            else network
        )

        # --- rotation discretisation ----------------------------------------
        self._num_rotation_classes = num_rotation_classes
        self._rotation_resolution = 360 / num_rotation_classes
        # x, y and z each get `num_rotation_classes` logits
        self.num_all_rot = num_rotation_classes * 3
        self.rot_ver = rot_ver
        self.rot_x_y_aug = rot_x_y_aug

        # --- optimisation ----------------------------------------------------
        self._lr = lr
        self._optimizer_type = optimizer_type
        self._lambda_weight_l2 = lambda_weight_l2
        self.lr_cos_dec = lr_cos_dec
        self.cos_dec_max_step = cos_dec_max_step
        self.warmup_steps = warmup_steps
        self.amp = amp
        self.bnb = bnb
        self.scaler = GradScaler(enabled=amp)
        self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")

        # --- augmentation ----------------------------------------------------
        self._transform_augmentation = transform_augmentation
        self._transform_augmentation_xyz = torch.from_numpy(
            np.array(transform_augmentation_xyz)
        )
        self._transform_augmentation_rpy = transform_augmentation_rpy
        self._transform_augmentation_rot_resolution = (
            transform_augmentation_rot_resolution
        )
        self.img_aug = img_aug

        # --- scene / input handling -----------------------------------------
        self._image_resolution = image_resolution
        self._place_with_mean = place_with_mean
        self.scene_bounds = scene_bounds
        self.cameras = cameras
        self.move_pc_in_bound = move_pc_in_bound

        # --- model / loss options -------------------------------------------
        self.stage_two = stage_two
        self.add_lang = add_lang
        self.gt_hm_sigma = gt_hm_sigma
        self.add_rgc_loss = add_rgc_loss

        # --- bookkeeping -----------------------------------------------------
        self.log_dir = log_dir
        # counter incremented by `update`
        self.visual_idx = 0


    def build(self, training: bool, device: torch.device = None, vggt=None):
        """
        Create the optimizer and the learning-rate scheduler, and optionally
        attach a VGGT model.

        :param training: stored as `_training`
        :param device: stored as `_device`
        :param vggt: optional model; when given it is stored as `vggt_model`
            and switched to eval mode
        :raises Exception: if the configured optimizer type is neither "lamb"
            nor "adam"
        """
        self._training = training
        self._device = device

        optimizer_type = self._optimizer_type
        if optimizer_type not in ("lamb", "adam"):
            raise Exception("Unknown optimizer")

        if optimizer_type == "adam":
            self._optimizer = torch.optim.Adam(
                self._network.parameters(),
                lr=self._lr,
                weight_decay=self._lambda_weight_l2,
            )
        elif self.bnb:
            print("Using 8-Bit Optimizer")
            # the 8-bit variant runs with a reduced learning rate
            self._optimizer = bnb.optim.LAMB(
                self._network.parameters(),
                lr=self._lr * 0.85,
                weight_decay=self._lambda_weight_l2,
                betas=(0.9, 0.999),
            )
        else:
            # From: https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
            self._optimizer = Lamb(
                self._network.parameters(),
                lr=self._lr,
                weight_decay=self._lambda_weight_l2,
                betas=(0.9, 0.999),
                adam=False,
            )

        # optional cosine decay that takes over once the warmup is finished
        after_scheduler = (
            CosineAnnealingLR(
                self._optimizer,
                T_max=self.cos_dec_max_step,
                eta_min=self._lr / 100,  # minimum lr
            )
            if self.lr_cos_dec
            else None
        )
        self._lr_sched = GradualWarmupScheduler(
            self._optimizer,
            multiplier=1,
            total_epoch=self.warmup_steps,
            after_scheduler=after_scheduler,
        )

        if vggt is not None:
            self.vggt_model = vggt
            vggt.eval()


    def load_clip(self):
        """Load the "RN50" CLIP model on `self._device` and put it in eval mode."""
        model, preprocess = clip.load("RN50", device=self._device)
        self.clip_model = model
        self.clip_preprocess = preprocess
        model.eval()

    def unload_clip(self):
        """Drop the CLIP attributes and release cached CUDA memory on `self._device`."""
        for attr_name in ("clip_model", "clip_preprocess"):
            delattr(self, attr_name)
        with torch.cuda.device(self._device):
            torch.cuda.empty_cache()

    # copied from per-act and removed the translation part
    def _get_one_hot_expert_actions(
        self,
        batch_size,
        action_rot,
        action_grip,
        action_ignore_collisions,
        device,
    ):
        """
        Build integer one-hot targets for rotation, gripper and collision.

        :param batch_size: int
        :param action_rot: np.array of shape (bs, 4), quaternion xyzw format
        :param action_grip: torch.tensor of shape (bs)
        :param action_ignore_collisions: torch.tensor indexed as [b, :], the
            first entry of each row is used as the class index
        :param device: device on which the one-hot tensors are created
        :return: tuple (rot_x, rot_y, rot_z, grip, collision); the rotation
            one-hots have shape (bs, num_rotation_classes), the other two
            have shape (bs, 2)
        """
        bs = batch_size
        assert action_rot.shape == (bs, 4)
        assert action_grip.shape == (bs,), (action_grip, bs)

        def _zeros(num_classes):
            return torch.zeros((bs, num_classes), dtype=int, device=device)

        rot_x_one_hot, rot_y_one_hot, rot_z_one_hot = (
            _zeros(self._num_rotation_classes) for _ in range(3)
        )
        grip_one_hot = _zeros(2)
        collision_one_hot = _zeros(2)

        rot_one_hots = (rot_x_one_hot, rot_y_one_hot, rot_z_one_hot)
        for b in range(bs):
            # discrete euler bin of every axis for this sample
            euler_bins = aug_utils.quaternion_to_discrete_euler(
                action_rot[b], self._rotation_resolution
            )
            for axis, one_hot in enumerate(rot_one_hots):
                one_hot[b, euler_bins[axis]] = 1

            grip_one_hot[b, action_grip[b]] = 1
            collision_one_hot[b, action_ignore_collisions[b, 0]] = 1

        return (
            rot_x_one_hot,
            rot_y_one_hot,
            rot_z_one_hot,
            grip_one_hot,
            collision_one_hot,
        )

    def get_q(self, out, dims, only_pred=False, get_q_trans=True):
        """
        Split the network output into translation, rotation, gripper and
        collision scores.

        :param out: output of mvt; when `self.stage_two` is set, the rotation,
            gripper and collision scores are read from `out["mvt2"]`
        :param dims: tensor dimensions (bs, nc, h, w)
        :param only_pred: some speedups if the q values are meant only for
            prediction (the translation scores are not cloned)
        :param get_q_trans: when False the translation scores are skipped and
            returned as None
        :return: tuple (q_trans, rot_q, grip_q, collision_q, y_q, pts) that is
            used for training and prediction. q_trans is (bs, h*w, nc), or
            (bs, h*w, 2*nc) with two stages. y_q and pts are always None.
        :raises AssertionError: if `self.rot_ver` is neither 0 nor 1
        """
        bs, nc, h, w = dims
        assert isinstance(only_pred, bool)

        def _flatten_trans(net_out):
            # (bs, nc, h, w) -> (bs, h*w, nc)
            q = net_out["trans"].view(bs, nc, h * w).transpose(1, 2)
            return q if only_pred else q.clone()

        q_trans = _flatten_trans(out) if get_q_trans else None

        # if two stages, we concatenate the q_trans, and replace all other q
        if self.stage_two:
            out = out["mvt2"]
            if get_q_trans:
                q_trans = torch.cat((q_trans, _flatten_trans(out)), dim=2)

        if self.rot_ver == 0:
            # layout of the flat feature: [rotation | grip (2) | collision (2)]
            feat = out["feat"].view(bs, -1)
            rot_end = self.num_all_rot
            rot_q = feat[:, 0:rot_end]
            grip_q = feat[:, rot_end : rot_end + 2]
            collision_q = feat[:, rot_end + 2 : rot_end + 4]
        elif self.rot_ver == 1:
            rot_q = torch.cat(
                (out["feat_x"], out["feat_y"], out["feat_z"]), dim=-1
            ).view(bs, -1)
            feat_ex_rot = out["feat_ex_rot"].view(bs, -1)
            grip_q = feat_ex_rot[:, :2]
            collision_q = feat_ex_rot[:, 2:]
        else:
            # explicit raise: a bare `assert False` disappears under -O and
            # would surface later as an unbound local variable
            raise AssertionError(f"unsupported rot_ver: {self.rot_ver}")

        y_q = None
        pts = None

        return q_trans, rot_q, grip_q, collision_q, y_q, pts
    
    
    def calculate_depth_loss(self, depth_pred, depth_gt):
        """
        Intra-view relative depth (pairwise ranking) loss.

        Up to 500 pixel locations are sampled at random (with replacement)
        and every unordered pair of them is compared: the predicted depth
        difference of a pair is pushed to have the same sign as the reference
        depth difference. Pairs whose reference depths are equal are ignored.

        :param depth_pred: predicted depth map, [N, H, W] or [H, W]
        :param depth_gt: depth map provided by the foundation model,
            [N, H, W] or [H, W]
        :return: relative depth loss (scalar tensor); 0 when no pair has
            different reference depths
        """
        # make sure we work with floating point values (no integer math)
        depth_pred = depth_pred.float()
        depth_gt = depth_gt.float()
        device = depth_pred.device

        # randomly sample pixel coordinates, shared by all maps in the batch
        height, width = depth_pred.shape[-2], depth_pred.shape[-1]
        num_points = min(500, height * width)
        rows = torch.randint(0, height, (num_points,)).to(device)
        cols = torch.randint(0, width, (num_points,)).to(device)

        # depth at the sampled locations: [N, num_points] or [num_points]
        gt_depth_values = depth_gt[..., rows, cols]
        pred_depth_values = depth_pred[..., rows, cols]

        # all unordered pairs (i < j) of sampled points
        idx_i, idx_j = torch.triu_indices(
            num_points, num_points, offset=1, device=device
        )

        # reference ordering of each pair, in {-1, 0, +1}
        gt_diff = gt_depth_values[..., idx_i] - gt_depth_values[..., idx_j]
        s_xy = (gt_diff > 0).float() - (gt_diff < 0).float()

        # predicted score of each pair
        s_xy_pred = pred_depth_values[..., idx_i] - pred_depth_values[..., idx_j]

        # only pairs with a defined ordering contribute
        valid_mask = s_xy != 0

        # softplus(x) == log(1 + exp(x)) but does not overflow to inf for
        # large x, which would also poison the gradients
        loss_term = F.softplus(-s_xy * s_xy_pred)

        # average over the valid pairs
        num_valid = valid_mask.sum().float() + 1e-8  # avoid division by zero
        loss = (
            torch.where(valid_mask, loss_term, torch.zeros_like(loss_term)).sum()
            / num_valid
        )

        return loss

    
    # @timeit
    def calculate_matching_loss(
        self,
        out,
        point_map_view_1,
        point_map_view_2,
        point_map_view_3,
        kp_1,
        kp_2,
        kp_3,
        thres3d_neg=0.1,
        align=True,
        match_input_dict=None,
    ):
        """
        Multi-view (three views) feature matching loss based on a soft
        average precision (AP).

        The i-th keypoint of one view is the positive match of the i-th
        keypoint of the other view; keypoints whose 3D points are further
        away than `thres3d_neg` act as negatives. The AP is computed for the
        view pairs (1, 2), (2, 3) and (3, 1) and for every sample of the
        batch.

        :param out: dict with the view features under "feature_1",
            "feature_2" and "feature_3"
        :param point_map_view_1: point map of view 1, (B, H, W, 3) after
            `squeeze()`, storing the 3D coordinate of every pixel; same for
            `point_map_view_2` and `point_map_view_3`
        :param kp_1: keypoints of view 1, one (N_i, 2) entry of (x, y) pixel
            coordinates per sample after `squeeze()`; same for `kp_2`, `kp_3`
        :param thres3d_neg: 3D distance threshold for negatives
        :param align: whether the features must first be sampled at the
            keypoints (with `interpolate_features`) and L2-normalised; when
            False the features are used as (B, N, C) descriptors as they are
        :param match_input_dict: currently unused
        :return: scalar tensor, mean of (1 - AP)
        """
        # temperature of the sigmoid used as a soft ranking indicator
        temperature = 0.01

        def prepare_point_map(p_map, kp):
            """
            Look up the 3D coordinates of the keypoints in a point map.

            :param p_map: point map (B, H, W, 3)
            :param kp: per-sample keypoints, each (N_i, 2)
            :return: zero-padded 3D coordinates (B, max_kp, 3)
            """
            batch_size = len(kp)
            max_kp = max(k.shape[0] for k in kp)
            h, w = p_map.shape[1:3]

            pts3d = torch.zeros(batch_size, max_kp, 3, device=p_map.device)
            for i, k in enumerate(kp):
                # column 0 is x (width), column 1 is y (height)
                y_coords = torch.clamp(k[:, 1].long(), 0, h - 1)
                x_coords = torch.clamp(k[:, 0].long(), 0, w - 1)
                pts3d[i, : k.shape[0]] = p_map[i, y_coords, x_coords]

            return pts3d  # (B, N, 3)

        def compute_ap(desc_a, desc_b, pts3d_a, pts3d_b):
            """
            Soft average precision between two views.

            :param desc_a: descriptors of the first view (B, N, C)
            :param desc_b: descriptors of the second view (B, N, C)
            :param pts3d_a: 3D points of the first view (B, N, 3)
            :param pts3d_b: 3D points of the second view (B, N, 3)
            :return: AP of every keypoint of every sample (B, N)
            """
            _, num_kp, _ = desc_a.shape
            device = desc_a.device
            diag = torch.arange(num_kp, device=device)
            eye_mask = torch.eye(num_kp, dtype=torch.bool, device=device).unsqueeze(0)

            # negatives: further away than the threshold and not the match
            neg_mask = (torch.cdist(pts3d_a, pts3d_b) > thres3d_neg) & ~eye_mask

            # descriptor similarity matrix (B, N, N)
            sim = torch.bmm(desc_a, desc_b.transpose(-1, -2))

            # The positives are the diagonal of EVERY sample; indexing with
            # a fixed batch index 0 would drop all other samples.
            pos_sim = sim[:, diag, diag]  # (B, N)
            rpos = torch.sigmoid((1.0 - pos_sim) / temperature) + 1
            rall = rpos + torch.sum(
                torch.sigmoid((sim - 1.0) / temperature) * neg_mask.float(),
                dim=-1,
            )
            return rpos / rall

        desc_1, desc_2, desc_3 = (
            out["feature_1"],
            out["feature_2"],
            out["feature_3"],
        )  # view features

        if align:
            img_size = self._net_mod.img_size
            img_patch_size = self._net_mod.mvt1.img_patch_size

            def _descriptors_at_keypoints(desc, kp):
                # sample the feature map at the keypoints, then L2-normalise
                desc = interpolate_features(
                    desc,
                    kp.squeeze(),
                    h=img_size,
                    w=img_size,
                    patch_size=img_patch_size,
                    stride=img_patch_size,
                    normalize=False,
                )
                return F.normalize(desc.squeeze(2).permute(0, 2, 1), p=2, dim=-1)

            desc_1 = _descriptors_at_keypoints(desc_1, kp_1)
            desc_2 = _descriptors_at_keypoints(desc_2, kp_2)
            desc_3 = _descriptors_at_keypoints(desc_3, kp_3)

        # 3D coordinates of the keypoints of every view
        pts3d_1 = prepare_point_map(point_map_view_1.squeeze(), kp_1.squeeze())
        pts3d_2 = prepare_point_map(point_map_view_2.squeeze(), kp_2.squeeze())
        pts3d_3 = prepare_point_map(point_map_view_3.squeeze(), kp_3.squeeze())

        ap1 = compute_ap(desc_1, desc_2, pts3d_1, pts3d_2)
        ap2 = compute_ap(desc_2, desc_3, pts3d_2, pts3d_3)
        ap3 = compute_ap(desc_3, desc_1, pts3d_3, pts3d_1)

        ap = (ap1 + ap2 + ap3) / 3
        ap_loss = torch.mean(1.0 - ap)

        return ap_loss

    # @timeit
    def update(
        self,
        step: int,
        replay_sample: dict,
        backprop: bool = True,
        eval_log: bool = False,
        reset_log: bool = False,
        compute_ap:bool = False
    ) -> dict:
        assert replay_sample["rot_grip_action_indicies"].shape[1:] == (1, 4)
        assert replay_sample["ignore_collisions"].shape[1:] == (1, 1)
        assert replay_sample["gripper_pose"].shape[1:] == (1, 7)
        assert replay_sample["lang_goal_embs"].shape[1:] == (1, 77, 512)
        assert replay_sample["low_dim_state"].shape[1:] == (
            1,
            self._net_mod.proprio_dim,
        )

        # sample
        action_rot_grip = replay_sample["rot_grip_action_indicies"][
            :, -1
        ].int()  # (b, 4) of int
        action_ignore_collisions = replay_sample["ignore_collisions"][
            :, -1
        ].int()  # (b, 1) of int
        action_gripper_pose = replay_sample["gripper_pose"][:, -1]  # (b, 7)
        action_trans_con = action_gripper_pose[:, 0:3]  # (b, 3)
        # rotation in quaternion xyzw
        action_rot = action_gripper_pose[:, 3:7]  # (b, 4)
        action_grip = action_rot_grip[:, -1]  # (b,)
        lang_goal_embs = replay_sample["lang_goal_embs"][:, -1].float()
        tasks = replay_sample["tasks"]

        proprio = arm_utils.stack_on_channel(replay_sample["low_dim_state"])  # (b, 4)
        return_out = {}

        obs, pcd, dyn_cam_info = cortical_utils._preprocess_inputs(replay_sample, self.cameras)

        # ===== VGGT related inputs =====
        if compute_ap:
            vggt_features = replay_sample["vggt_features"]
            vggt_features_st2 = replay_sample["vggt_features_st2"]
            kp_1, kp_2, kp_3 = replay_sample["kp_1"], replay_sample["kp_2"], replay_sample["kp_3"]
            kp_1_st2, kp_2_st2, kp_3_st2 = replay_sample["kp_1_st2"], replay_sample["kp_2_st2"], replay_sample["kp_3_st2"]

        with torch.no_grad():
            pc, img_feat = rvt_utils.get_pc_img_feat(
                obs,
                pcd,
            )

            if self._transform_augmentation and backprop:
                action_trans_con, action_rot, pc = apply_se3_aug_con(
                    pcd=pc,
                    action_gripper_pose=action_gripper_pose,
                    bounds=torch.tensor(self.scene_bounds),
                    trans_aug_range=torch.tensor(self._transform_augmentation_xyz),
                    rot_aug_range=torch.tensor(self._transform_augmentation_rpy),
                )
                action_trans_con = torch.tensor(action_trans_con).to(pc.device)
                action_rot = torch.tensor(action_rot).to(pc.device)

            # TODO: vectorize
            action_rot = action_rot.cpu().numpy()
            for i, _action_rot in enumerate(action_rot):
                _action_rot = aug_utils.normalize_quaternion(_action_rot)
                if _action_rot[-1] < 0:
                    _action_rot = -_action_rot
                action_rot[i] = _action_rot

            pc, img_feat = rvt_utils.move_pc_in_bound(
                pc, img_feat, self.scene_bounds, no_op=not self.move_pc_in_bound
            )
            wpt = [x[:3] for x in action_trans_con]

            wpt_local = []
            rev_trans = []
            for _pc, _wpt in zip(pc, wpt):
                a, b = mvt_utils.place_pc_in_cube(
                    _pc,
                    _wpt,
                    with_mean_or_bounds=self._place_with_mean,
                    scene_bounds=None if self._place_with_mean else self.scene_bounds,
                )
                wpt_local.append(a.unsqueeze(0))
                rev_trans.append(b)

            wpt_local = torch.cat(wpt_local, axis=0)

            # TODO: Vectorize
            pc = [
                mvt_utils.place_pc_in_cube(
                    _pc,
                    with_mean_or_bounds=self._place_with_mean,
                    scene_bounds=None if self._place_with_mean else self.scene_bounds,
                )[0]
                for _pc in pc
            ]

            bs = len(pc)
            nc = self._net_mod.num_img
            h = w = self._net_mod.img_size

            if backprop and (self.img_aug != 0):
                img_aug = self.img_aug
            else:
                img_aug = 0


        with autocast(enabled=self.amp):
            (
                action_rot_x_one_hot,
                action_rot_y_one_hot,
                action_rot_z_one_hot,
                action_grip_one_hot,  # (bs, 2)
                action_collision_one_hot,  # (bs, 2)
            ) = self._get_one_hot_expert_actions(
                bs, action_rot, action_grip, action_ignore_collisions, device=self._device
            )

            if self.rot_ver == 1:
                rot_x_y = torch.cat(
                    [
                        action_rot_x_one_hot.argmax(dim=-1, keepdim=True),
                        action_rot_y_one_hot.argmax(dim=-1, keepdim=True),
                    ],
                    dim=-1,
                )
                if self.rot_x_y_aug != 0:
                    # add random interger between -rot_x_y_aug and rot_x_y_aug to rot_x_y
                    rot_x_y += torch.randint(
                        -self.rot_x_y_aug, self.rot_x_y_aug, size=rot_x_y.shape
                    ).to(rot_x_y.device)
                    rot_x_y %= self._num_rotation_classes

            # start = time.time()
            out = self._network(
                pc=pc,
                img_feat=img_feat,
                proprio=proprio,
                lang_emb=lang_goal_embs,
                img_aug=img_aug,
                wpt_local=wpt_local if self._network.training else None,
                rot_x_y=rot_x_y if self.rot_ver == 1 else None,
                dyn_cam_info=dyn_cam_info,
            )

            q_trans, rot_q, grip_q, collision_q, y_q, pts = self.get_q(
                out, dims=(bs, nc, h, w)
            )

            action_trans = self.get_action_trans(
                wpt_local, pts, out, dyn_cam_info, dims=(bs, nc, h, w)
            )
            # if DEBUG:
            #     print(f"rvt 推理动作耗时: {time.time() - start:.4f} 秒") 
        # # --------------------------------------------
        rgb_vggt_1 = out["imgs"][0][:,:,3:6].to(self._device).detach()
        kp_3d_pred = {
            "kp_1": replay_sample["kp_1"],
            "kp_2": replay_sample["kp_2"],
            "kp_3": replay_sample["kp_3"]
        }
        # visualize_comparison(kp_3d_pred, rgb_vggt_1, save_dir=f"debug_runs/comparison_results/stage_1/{self.visual_idx}")
        rgb_vggt_2 = out["imgs"][1][:,:,3:6].to(self._device).detach()
        kp_3d_pred_st2 = {
            "kp_1": replay_sample["kp_1_st2"],
            "kp_2": replay_sample["kp_2_st2"],
            "kp_3": replay_sample["kp_3_st2"]
        }
        # visualize_comparison(kp_3d_pred_st2, rgb_vggt_2, save_dir=f"debug_runs/comparison_results/stage_2/{self.visual_idx}")
        self.visual_idx += 1 
        # # VGGT在线提取特征点
        # if compute_ap:
        #     with torch.no_grad():          
        #         rgb_vggt_1 = out["imgs"][0][:,:,3:6].to(self._device).detach()             
        #         # 冻结的vggt前向推理获得预测点云、关键点等数据
        #         _ , vggt_data_1 = get_vggt_feature_map(
        #             rgb_vggt_1, self.vggt_model, device=self._device , visualize = True
        #         )              
        #         #可视化
        #         # if self.visual_idx <= 10:
        #         #     visualize_comparison(vggt_data_1 , rgb_vggt_1[IMG_IDX] , save_dir=f"debug_runs/comparison_results/stage_1/{self.visual_idx}")
        #         #match_kpts_img = {"kp_1":vggt_data["kp_1"],"kp_2":vggt_data["kp_2"],"kp_3":vggt_data["kp_3"]}
        #         match_kpts_img = [vggt_data_1["kp_1"] , vggt_data_1["kp_2"] , vggt_data_1["kp_3"]]
        #         match_input_dict_1 = {
        #             "normalize": True,
        #             "match_kpts_img": match_kpts_img,
        #         }
        #         if self.stage_two:
        #             rgb_vggt_2 = out["imgs"][1][:,:,3:6].to(self._device).detach()         
        #             # 冻结的vggt前向推理获得预测点云、关键点等数据
        #             _ , vggt_data_2 = get_vggt_feature_map(
        #                 rgb_vggt_2, self.vggt_model, device=self._device , visualize = True
        #             )          
        #             #可视化
        #             # if self.visual_idx <= 10:
        #             #     visualize_comparison(vggt_data_2 , rgb_vggt_2[IMG_IDX] , save_dir=f"debug_runs/comparison_results/stage_2/{self.visual_idx}")     
        #             match_kpts_img = [vggt_data_2["kp_1"] , vggt_data_2["kp_2"] , vggt_data_2["kp_3"]]
        #             match_input_dict_2 = {
        #                 "normalize": True,
        #                 "match_kpts_img": match_kpts_img,
        #             }
        #         self.visual_idx += 1    
        loss_log = {}
        if backprop:
            with autocast(enabled=self.amp):
                # cross-entropy loss
                trans_loss = self._cross_entropy_loss(q_trans, action_trans).mean()
                rot_loss_x = rot_loss_y = rot_loss_z = 0.0
                grip_loss = 0.0
                collision_loss = 0.0
                ap_loss = 0.0
                if compute_ap:
                    thres3d_neg = 0.1
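                    # Keypoints serve as the GT labels; the point maps are used to lift the keypoints into 3D space,
                    # and the matching_ap_loss is computed against the features stored in `out`.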
                    ap_loss = self.calculate_matching_loss(
                        out,
                        vggt_features['point_map'][:, 0, ...],   
                        vggt_features['point_map'][:, 1, ...],  
                        vggt_features['point_map'][:, 2, ...],  
                        kp_1 * 128 / 518, kp_2 * 128 / 518, kp_3 * 128 / 518,   
                        thres3d_neg = thres3d_neg,
                        align=True,
                    )

                    if self.stage_two:
                        ap_loss += self.calculate_matching_loss(
                            out["mvt2"],
                            vggt_features_st2['point_map'][:, 0, ...], 
                            vggt_features_st2['point_map'][:, 1, ...], 
                            vggt_features_st2['point_map'][:, 2, ...],  
                            kp_1_st2 * 128 / 518, kp_2_st2 * 128 / 518, kp_3_st2 * 128 / 518,   
                            thres3d_neg = thres3d_neg,
                            align=True,
                        )
                    # ap_loss = self.calculate_matching_loss(
                    #     out,
                    #     vggt_data_1["point_map_view_1"],
                    #     vggt_data_1["point_map_view_2"],
                    #     vggt_data_1["point_map_view_3"],
                    #     vggt_data_1["kp_1"],
                    #     vggt_data_1["kp_2"],
                    #     vggt_data_1["kp_3"],
                    #     thres3d_neg = thres3d_neg,
                    #     align=True,
                    #     match_input_dict=match_input_dict_1,
                    # )

                    # if self.stage_two:
                    #     ap_loss += self.calculate_matching_loss(
                    #         out["mvt2"],
                    #         vggt_data_2["point_map_view_1"],
                    #         vggt_data_2["point_map_view_2"],
                    #         vggt_data_2["point_map_view_3"],
                    #         vggt_data_2["kp_1"],
                    #         vggt_data_2["kp_2"],
                    #         vggt_data_2["kp_3"],
                    #         thres3d_neg = thres3d_neg,
                    #         align=True,
                    #         match_input_dict=match_input_dict_2,
                    #     )                    
                if self.add_rgc_loss:
                    rot_loss_x = self._cross_entropy_loss(
                        rot_q[
                            :,
                            0 * self._num_rotation_classes : 1 * self._num_rotation_classes,
                        ],
                        action_rot_x_one_hot.argmax(-1),
                    ).mean()

                    rot_loss_y = self._cross_entropy_loss(
                        rot_q[
                            :,
                            1 * self._num_rotation_classes : 2 * self._num_rotation_classes,
                        ],
                        action_rot_y_one_hot.argmax(-1),
                    ).mean()

                    rot_loss_z = self._cross_entropy_loss(
                        rot_q[
                            :,
                            2 * self._num_rotation_classes : 3 * self._num_rotation_classes,
                        ],
                        action_rot_z_one_hot.argmax(-1),
                    ).mean()

                    grip_loss = self._cross_entropy_loss(
                        grip_q,
                        action_grip_one_hot.argmax(-1),
                    ).mean()

                    collision_loss = self._cross_entropy_loss(
                        collision_q, action_collision_one_hot.argmax(-1)
                    ).mean()

                total_loss = (
                    trans_loss
                    + rot_loss_x
                    + rot_loss_y
                    + rot_loss_z
                    + grip_loss
                    + collision_loss
                    + ap_loss * 1
                )

            self._optimizer.zero_grad(set_to_none=True)
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self._optimizer)
            self.scaler.update()
            self._lr_sched.step()

            loss_log = {
                "total_loss": total_loss.item(),
                "trans_loss": trans_loss.item(),
                "rot_loss_x": rot_loss_x.item(),
                "rot_loss_y": rot_loss_y.item(),
                "rot_loss_z": rot_loss_z.item(),
                "grip_loss": grip_loss.item(),
                "collision_loss": collision_loss.item(),
                "matching_ap_loss": ap_loss.item() if compute_ap else 0.0,
                "lr": self._optimizer.param_groups[0]["lr"],
            }
            manage_loss_log(self, loss_log, reset_log=reset_log)
            return_out.update(loss_log)

        if eval_log:
            with torch.no_grad():
                wpt = torch.cat([x.unsqueeze(0) for x in wpt])
                pred_wpt, pred_rot_quat, _, _ = self.get_pred(
                    out,
                    rot_q,
                    grip_q,
                    collision_q,
                    y_q,
                    rev_trans,
                    dyn_cam_info=dyn_cam_info,
                )

                return_log = manage_eval_log(
                    self=self,
                    tasks=tasks,
                    wpt=wpt,
                    pred_wpt=pred_wpt,
                    action_rot=action_rot,
                    pred_rot_quat=pred_rot_quat,
                    action_grip_one_hot=action_grip_one_hot,
                    grip_q=grip_q,
                    action_collision_one_hot=action_collision_one_hot,
                    collision_q=collision_q,
                    reset_log=reset_log,
                )

                return_out.update(return_log)

        return return_out

    @torch.no_grad()
    def act(
        self, step: int, observation: dict, deterministic=True, pred_distri=False
    ) -> ActResult:
        """Predict a single action from ``observation`` without tracking gradients.

        :param step: not read by this method; kept for interface compatibility.
        :param observation: dict holding at least ``"low_dim_state"`` plus
            ``"lang_goal_tokens"`` (when ``self.add_lang`` is set) or
            ``"lang_goal_embs"`` (otherwise, only its shape is used). It is also
            handed as-is to ``cortical_utils._preprocess_inputs``.
        :param deterministic: not read by this method; kept for interface
            compatibility.
        :param pred_distri: when True, additionally return the per-axis
            rotation scores of the first batch element.
        :return: ``ActResult`` wrapping the concatenation of the predicted
            waypoint, rotation quaternion, gripper and collision values for
            batch element 0. When ``pred_distri`` is True, a tuple
            ``(ActResult, (x_distri, y_distri, z_distri))`` is returned instead,
            where each ``*_distri`` is a numpy array with
            ``self._num_rotation_classes`` entries.
        """
        if self.add_lang:
            # Index directly so a missing key raises a KeyError naming it,
            # instead of an AttributeError from calling ``.long()`` on None.
            lang_goal_tokens = observation["lang_goal_tokens"].long()
            _, lang_goal_embs = _clip_encode_text(self.clip_model, lang_goal_tokens[0])
            lang_goal_embs = lang_goal_embs.float()
        else:
            # Language disabled: feed an all-zero embedding of the same shape.
            lang_goal_embs = (
                torch.zeros(observation["lang_goal_embs"].shape)
                .float()
                .to(self._device)
            )

        proprio = arm_utils.stack_on_channel(observation["low_dim_state"])

        obs, pcd, dyn_cam_info = cortical_utils._preprocess_inputs(observation, self.cameras)
        pc, img_feat = rvt_utils.get_pc_img_feat(
            obs,
            pcd,
        )

        pc, img_feat = rvt_utils.move_pc_in_bound(
            pc, img_feat, self.scene_bounds, no_op=not self.move_pc_in_bound
        )

        # TODO: Vectorize
        # Place every point cloud in the cube and keep the second return value
        # of ``place_pc_in_cube``; ``get_pred`` later applies it to the
        # predicted local waypoint.
        pc_new = []
        rev_trans = []
        for _pc in pc:
            cube_pc, undo = mvt_utils.place_pc_in_cube(
                _pc,
                with_mean_or_bounds=self._place_with_mean,
                scene_bounds=None if self._place_with_mean else self.scene_bounds,
            )
            pc_new.append(cube_pc)
            rev_trans.append(undo)
        pc = pc_new

        bs = len(pc)
        nc = self._net_mod.num_img
        h = w = self._net_mod.img_size

        out = self._network(
            pc=pc,
            img_feat=img_feat,
            proprio=proprio,
            lang_emb=lang_goal_embs,
            img_aug=0,  # no img augmentation while acting
            dyn_cam_info=dyn_cam_info,
        )
        _, rot_q, grip_q, collision_q, y_q, _ = self.get_q(
            out, dims=(bs, nc, h, w), only_pred=True, get_q_trans=False
        )
        pred_wpt, pred_rot_quat, pred_grip, pred_coll = self.get_pred(
            out, rot_q, grip_q, collision_q, y_q, rev_trans, dyn_cam_info
        )

        continuous_action = np.concatenate(
            (
                pred_wpt[0].cpu().numpy(),
                pred_rot_quat[0],
                pred_grip[0].cpu().numpy(),
                pred_coll[0].cpu().numpy(),
            )
        )
        if not pred_distri:
            return ActResult(continuous_action)

        # The rotation scores live in ``rot_q`` as three consecutive groups of
        # ``_num_rotation_classes`` columns (x, y, z), the same layout that
        # ``get_pred`` slices. The previous code read an undefined name here.
        n_cls = self._num_rotation_classes
        x_distri, y_distri, z_distri = (
            rot_q[0, axis * n_cls : (axis + 1) * n_cls].cpu().numpy()
            for axis in range(3)
        )
        return ActResult(continuous_action), (x_distri, y_distri, z_distri)

    def get_pred(
        self,
        out,
        rot_q,
        grip_q,
        collision_q,
        y_q,
        rev_trans,
        dyn_cam_info,
    ):
        """Turn network outputs into waypoint, rotation, gripper and collision predictions.

        :param out: network output, forwarded to ``self._net_mod.get_wpt``.
        :param rot_q: rotation scores; columns are read as three consecutive
            groups of ``self._num_rotation_classes`` (x, y, z).
        :param grip_q: gripper scores; argmax is taken over dim 1.
        :param collision_q: collision scores; argmax is taken over dim 1.
        :param y_q: forwarded to ``get_wpt``; must be None when
            ``self.stage_two`` is set.
        :param rev_trans: one callable per batch element, applied to that
            element's predicted local waypoint (presumably mapping it back out
            of the cube frame -- confirm against ``place_pc_in_cube``).
        :param dyn_cam_info: forwarded to ``get_wpt``.
        :return: ``(pred_wpt, pred_rot_quat, pred_grip, pred_coll)``.
        """
        # Two-stage models take the waypoint from the second stage, which is
        # signalled to ``get_wpt`` by passing False for ``mvt1_or_mvt2``.
        use_first_stage = not self.stage_two
        if not use_first_stage:
            assert y_q is None

        local_wpts = self._net_mod.get_wpt(
            out, use_first_stage, dyn_cam_info, y_q
        )

        # Apply each sample's own transform, then stack along a new batch dim.
        pred_wpt = torch.stack(
            [undo(local_wpt) for local_wpt, undo in zip(local_wpts, rev_trans)]
        )

        # Most likely bin per rotation axis, concatenated as (x, y, z).
        n_cls = self._num_rotation_classes
        best_bins = [
            rot_q[:, axis * n_cls : (axis + 1) * n_cls].argmax(1, keepdim=True)
            for axis in range(3)
        ]
        pred_rot = torch.cat(best_bins, dim=-1)
        pred_rot_quat = aug_utils.discrete_euler_to_quaternion(
            pred_rot.cpu(), self._rotation_resolution
        )

        pred_grip = grip_q.argmax(1, keepdim=True)
        pred_coll = collision_q.argmax(1, keepdim=True)
        return pred_wpt, pred_rot_quat, pred_grip, pred_coll

    @torch.no_grad()
    def get_action_trans(
        self,
        wpt_local,
        pts,
        out,
        dyn_cam_info,
        dims,
    ):
        """Build the ground-truth translation heatmaps for the target waypoint.

        :param wpt_local: target waypoints; a singleton dim is inserted at
            position 1 before projection.
        :param pts: not read by this method.
        :param out: network output, only used for the second-stage projection.
        :param dyn_cam_info: forwarded to ``get_pt_loc_on_img``.
        :param dims: ``(bs, nc, h, w)``.
        :return: tensor viewed as ``(bs, views, h * w)`` and transposed to
            ``(bs, h * w, views)``, where ``views`` is ``nc`` or ``2 * nc``
            when ``self.stage_two`` is set.
        """
        bs, nc, h, w = dims
        num_views = nc

        # Image locations of the waypoint in the first-stage views.
        loc_on_img = self._net_mod.get_pt_loc_on_img(
            wpt_local.unsqueeze(1),
            mvt1_or_mvt2=True,
            dyn_cam_info=dyn_cam_info,
            out=None
        )
        assert loc_on_img.shape[1] == 1

        if self.stage_two:
            # Second-stage views need ``out`` to project the waypoint.
            loc_on_img_st2 = self._net_mod.get_pt_loc_on_img(
                wpt_local.unsqueeze(1),
                mvt1_or_mvt2=False,
                dyn_cam_info=dyn_cam_info,
                out=out,
            )
            assert loc_on_img_st2.shape[1] == 1

            # (bs, 1, 2 * num_img, 2): views of both stages side by side
            loc_on_img = torch.cat((loc_on_img, loc_on_img_st2), dim=-2)
            num_views = nc * 2

        # (bs, num_img, 2)
        loc_on_img = loc_on_img.squeeze(1)

        heatmaps = mvt_utils.generate_hm_from_pt(
            loc_on_img.reshape(-1, 2),
            (h, w),
            sigma=self.gt_hm_sigma,
            thres_sigma_times=3,
        )
        return heatmaps.view(bs, num_views, h * w).transpose(1, 2).clone()

    def reset(self):
        """Reset hook; currently a no-op."""
        return None

    def eval(self):
        """Put the wrapped network into evaluation mode."""
        network = self._network
        network.eval()

    def train(self):
        """Put the wrapped network into training mode."""
        network = self._network
        network.train()
