import numpy as np
from vggt.models.vggt import VGGT
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map

from dataclasses import dataclass
import torch
from PIL import Image
from torchvision import transforms as TF
from scipy.spatial import cKDTree


@dataclass
class VGGTPipelineConfig:
    """Settings consumed by ``VGGTPipeline`` when turning images into a point cloud.

    Attributes are grouped below by purpose; field order is part of the positional
    constructor signature and must not change.
    """

    # --- model loading -----------------------------------------------------
    # Identifier or path handed to ``VGGT.from_pretrained``.
    model_path: str = "facebook/VGGT-1B"
    # Torch device string the model is moved to.
    device: str = "cuda"
    # Name of a torch dtype (e.g. "float16").
    dtype: str = "float16"

    # --- point filtering ---------------------------------------------------
    # Percentile (0-100) of the confidence distribution below which points are
    # dropped; 0 disables the percentile cut.
    conf_thres: float = 50.0
    # Drop near-black points (RGB channel sum below 16).
    mask_black_bg: bool = False
    # Drop near-white points (all channels above 240).
    mask_white_bg: bool = False
    # Drop grayish points close to #828282 (mean in [110, 150], channel std < 20).
    mask_gray_bg: bool = False

    # --- inference strategy ------------------------------------------------
    # Values containing "Pointmap" use the predicted point map when present;
    # anything else (e.g. "Depthmap") unprojects the predicted depth map.
    prediction_mode: str = "Predicted Pointmap"
    # True: run once per reference view and fuse; False: a single forward pass.
    multi_reference: bool = True



class VGGTPipeline:
    """Run VGGT on a set of views and turn its predictions into a coloured point cloud.

    Two strategies are available, selected by ``cfg.multi_reference``:

    * single pass - all views go through the model once (``_run_single_inference``);
    * multi-reference - every view in turn is placed first in the sequence, each run's
      points are moved into view 0's frame, normalised and fused per voxel
      (``_run_multi_reference_inference_fused``).
    """

    def __init__(self, cfg: "VGGTPipelineConfig"):
        self.cfg = cfg
        # Resolve the configured dtype name to a torch.dtype; names that are not
        # torch dtypes fall back to float16.
        # NOTE(review): self.dtype is not used anywhere in this class - inference
        # runs with autocast disabled in whatever dtype from_pretrained produced.
        resolved = getattr(torch, cfg.dtype, None)
        self.dtype = resolved if isinstance(resolved, torch.dtype) else torch.float16
        self.model = VGGT.from_pretrained(cfg.model_path).to(cfg.device)

    def __del__(self):
        # __init__ may have failed before self.model was assigned.
        if getattr(self, "model", None) is not None:
            del self.model
        torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Pre-processing
    # ------------------------------------------------------------------
    @staticmethod
    def _pad_center(img: torch.Tensor, target_h: int, target_w: int) -> torch.Tensor:
        """Centre-pad a (C, H, W) tensor with white (1.0) up to ``target_h`` x ``target_w``."""
        h_pad = target_h - img.shape[1]
        w_pad = target_w - img.shape[2]
        if h_pad <= 0 and w_pad <= 0:
            return img
        top, left = h_pad // 2, w_pad // 2
        return torch.nn.functional.pad(
            img,
            (left, w_pad - left, top, h_pad - top),
            mode="constant",
            value=1.0,
        )

    def preprocess_images(self, images=None, mode="crop"):
        """Resize, crop/pad and stack images into a (N, 3, H, W) float tensor in [0, 1].

        Args:
            images: non-empty sequence of PIL images (each must provide ``.size`` and
                ``.resize``).
            mode: "crop" - the width becomes 518 px, the height keeps the aspect ratio
                (rounded to a multiple of 14) and is centre-cropped to 518 if taller;
                "pad" - the longer side becomes 518 px and the image is white-padded
                to 518 x 518.

        Returns:
            torch.Tensor of shape (N, 3, H, W). If the images end up with different
            sizes, each one is white-padded (centred) to the largest height and width.

        Raises:
            ValueError: if ``images`` is None/empty or ``mode`` is not "crop"/"pad".
        """
        if images is None or len(images) == 0:
            raise ValueError("At least 1 image is required")
        if mode not in ["crop", "pad"]:
            raise ValueError("Mode must be either 'crop' or 'pad'")

        to_tensor = TF.ToTensor()
        target_size = 518  # 518 = 37 * 14, so it is itself divisible by 14
        result_images = []
        shapes = set()

        for img in images:
            width, height = img.size
            if mode == "pad" and height > width:
                # Portrait image in pad mode: the height is the long side.
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14
            else:
                # Crop mode, or a landscape/square image in pad mode: fix the width.
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14

            tensor = to_tensor(img.resize((new_width, new_height), Image.Resampling.BICUBIC))

            if mode == "crop" and new_height > target_size:
                start_y = (new_height - target_size) // 2
                tensor = tensor[:, start_y : start_y + target_size, :]
            elif mode == "pad":
                tensor = self._pad_center(tensor, target_size, target_size)

            shapes.add((tensor.shape[1], tensor.shape[2]))
            result_images.append(tensor)

        # Only crop mode can produce differing heights; pad them to a common size so
        # the images can be stacked.
        if len(shapes) > 1:
            print(f"Warning: Found images with different shapes: {shapes}")
            max_height = max(h for h, _ in shapes)
            max_width = max(w for _, w in shapes)
            result_images = [self._pad_center(t, max_height, max_width) for t in result_images]

        return torch.stack(result_images)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def run(self, images=None):
        """Run VGGT end-to-end and return a point cloud ``(coords, rgb)``.

        Args:
            images: non-empty sequence of PIL images (see ``preprocess_images``).

        Returns:
            coords: (M, 3) point positions with Y and Z negated relative to the model's
                output frame (multi-reference mode additionally aligns and normalises
                them; see ``_run_multi_reference_inference_fused``).
            rgb: (M, 3) uint8 colours taken from the input images.

        Raises:
            ValueError: if ``images`` is None or empty.
        """
        if images is None or len(images) == 0:
            raise ValueError("At least 1 image is required")
        cfg = self.cfg
        self.model.eval()
        print(f"🚀 开始 VGGT 推理，输入图片数量: {len(images)}，设备: {cfg.device}")

        # Step 1: preprocessing.
        images_tensor = self.preprocess_images(images).to(cfg.device)
        print(f"预处理后图片形状: {images_tensor.shape} (batch, channels, height, width)")

        # Step 2: pick the inference strategy.
        if not getattr(cfg, "multi_reference", False):
            print("🧩 当前为【单次推理模式】")
            all_coords, all_rgb = self._run_single_inference(images_tensor)
        else:
            print("🔁 当前为【多参考视角模式】")
            all_coords, all_rgb = self._run_multi_reference_inference_fused(images_tensor)

        print(f"✅ VGGT 点云生成完成，总点数: {all_coords.shape[0]}")
        torch.cuda.empty_cache()
        return all_coords, all_rgb

    def _infer(self, model_input: torch.Tensor) -> dict:
        """Run the model once and return its predictions as numpy arrays.

        ``model_input`` is the (S, 3, H, W) view stack or its (1, S, 3, H, W) batched
        form, as passed by the two callers below. Axis 0 of every output tensor is
        squeezed (presumably a batch axis of size 1 - TODO confirm for S == 1 inputs).
        Besides the raw model outputs the dict gains "extrinsic" and "intrinsic"
        (decoded from "pose_enc") and "world_points_from_depth" (the "depth" output
        unprojected with those cameras), so both point sources are always available.
        """
        with torch.no_grad(), torch.autocast("cuda", enabled=False):
            predictions = self.model(model_input)
            extrinsic, intrinsic = pose_encoding_to_extri_intri(
                predictions["pose_enc"], model_input.shape[-2:]
            )
        predictions["extrinsic"] = extrinsic
        predictions["intrinsic"] = intrinsic

        predictions = {
            key: value.cpu().numpy().squeeze(0) if isinstance(value, torch.Tensor) else value
            for key, value in predictions.items()
        }
        predictions["world_points_from_depth"] = unproject_depth_map_to_point_map(
            predictions["depth"], predictions["extrinsic"], predictions["intrinsic"]
        )
        return predictions

    def _run_single_inference(self, images_tensor: torch.Tensor):
        """Run a single forward pass over all views and return ``(coords, rgb)``."""
        predictions = self._infer(images_tensor)
        return self._predictions_to_pointcloud(predictions)

    def _run_multi_reference_inference_fused(
        self,
        images_tensor: torch.Tensor,
        voxel_size: float = 0.015,
        max_points: int = 268324,
    ):
        """Multi-reference inference with cross-run alignment and per-voxel fusion.

        Every view is used once as the reference (placed first in the sequence). For
        each run:

        1. points, image colours and confidences are extracted with
           ``_conf_to_pointcloud_with_conf``;
        2. the points are moved into the base run's frame (the run whose reference is
           view 0) using the extrinsic each run predicts for view 0
           (``_align_to_base``);
        3. the points are centred on their mean and scaled so their longest
           bounding-box axis spans [-1, 1] (``_normalize_to_unit_cube``). Each run is
           normalised independently, presumably to reconcile per-run scale, so the
           output is in normalised units rather than the model's units.

        Args:
            images_tensor: (N, 3, H, W) stack of preprocessed views.
            voxel_size: edge length of the fusion voxels, in normalised units.
            max_points: cap on the number of returned points; the most confident points
                are kept. The default 268324 equals 518 * 518.

        Returns:
            ``(coords, rgb)``: the most confident point of every occupied voxel.
        """
        num_views = images_tensor.shape[0]
        base_view_idx = 0
        base_extrinsic = None
        all_pts, all_colors, all_confs = [], [], []

        for ref_idx in range(num_views):
            print(f"🔄 正在处理参考视角 {ref_idx}/{num_views-1}...")

            # Put the reference view first; the remaining views keep their order.
            indices = [ref_idx] + [i for i in range(num_views) if i != ref_idx]
            reordered = images_tensor[indices].unsqueeze(0)
            predictions = self._infer(reordered)

            # Points are in this run's world frame (Y/Z negated by the extractor).
            coords, rgb, conf = self._conf_to_pointcloud_with_conf(predictions)

            # Extrinsic that this run predicts for the base view (view 0).
            base_cam_extrinsic = predictions["extrinsic"][indices.index(base_view_idx)]
            if ref_idx == base_view_idx:
                base_extrinsic = base_cam_extrinsic
                coords_unified = coords
            else:
                coords_unified = self._align_to_base(coords, base_extrinsic, base_cam_extrinsic)

            all_pts.append(self._normalize_to_unit_cube(coords_unified))
            all_colors.append(rgb)
            all_confs.append(conf)

            del predictions
            torch.cuda.empty_cache()

        # Fusion: keep the most confident point of every voxel across all runs.
        print("💡 正在跨视角合并精华点云...")
        all_pts = np.concatenate(all_pts, axis=0)
        all_colors = np.concatenate(all_colors, axis=0)
        all_confs = np.concatenate(all_confs, axis=0)

        final_indices = self._best_point_per_voxel(all_pts, all_confs, voxel_size, max_points)
        return all_pts[final_indices], all_colors[final_indices]

    # ------------------------------------------------------------------
    # Point extraction
    # ------------------------------------------------------------------
    def _select_points_and_conf(self, predictions: dict):
        """Pick the point source and its confidence according to ``cfg.prediction_mode``.

        Uses "world_points"/"world_points_conf" when the mode contains "Pointmap" and
        that key exists; otherwise "world_points_from_depth"/"depth_conf". Missing
        confidences default to all ones.
        """
        if "Pointmap" in self.cfg.prediction_mode and "world_points" in predictions:
            world_points = predictions["world_points"]
            conf = predictions.get("world_points_conf", np.ones_like(world_points[..., 0]))
        else:
            world_points = predictions["world_points_from_depth"]
            conf = predictions.get("depth_conf", np.ones_like(world_points[..., 0]))
        return world_points, conf

    def _flatten_predictions(self, predictions: dict):
        """Flatten predictions into per-point ``(coords, rgb, conf)`` arrays (unfiltered).

        ``coords`` is a fresh (P, 3) copy with Y and Z negated. The copy keeps the
        caller's arrays untouched; ``reshape`` would otherwise return a view and the
        negation would write through. ``rgb`` is (P, 3) uint8 from ``predictions["images"]``
        (channel-first (N, 3, H, W) input is transposed to channel-last first);
        ``conf`` is (P,).

        Raises:
            ValueError: if ``predictions`` is not a dict.
        """
        if not isinstance(predictions, dict):
            raise ValueError("predictions 必须是一个字典")

        world_points, conf = self._select_points_and_conf(predictions)

        images = predictions["images"]
        if images.ndim == 4 and images.shape[1] == 3:
            images = np.transpose(images, (0, 2, 3, 1))

        coords = world_points.reshape(-1, 3).copy()
        coords[:, 1:] *= -1
        rgb = (images.reshape(-1, 3) * 255).astype(np.uint8)
        return coords, rgb, conf.reshape(-1)

    def _point_mask(self, conf: np.ndarray, rgb: np.ndarray) -> np.ndarray:
        """Boolean keep-mask from the confidence percentile and the background filters.

        ``rgb`` must be the original image colours (uint8, (P, 3)), so the black/white/gray
        background tests refer to actual image content.
        """
        cfg = self.cfg
        conf_threshold = np.percentile(conf, cfg.conf_thres) if cfg.conf_thres > 0 else 0.0
        mask = (conf >= conf_threshold) & (conf > 1e-5)

        if cfg.mask_black_bg:
            mask &= rgb.sum(axis=1) >= 16
        if cfg.mask_white_bg:
            mask &= ~((rgb[:, 0] > 240) & (rgb[:, 1] > 240) & (rgb[:, 2] > 240))
        if cfg.mask_gray_bg:
            # Gray close to #828282 (130, 130, 130): channel mean within 130 +/- 20
            # and the three channels close to each other (std below 20).
            target_gray = 130
            tolerance = 20
            rgb_mean = rgb.mean(axis=1)
            rgb_std = rgb.std(axis=1)
            mask &= ~(
                (rgb_mean >= target_gray - tolerance)
                & (rgb_mean <= target_gray + tolerance)
                & (rgb_std < 20)
            )
        return mask

    def _predictions_to_pointcloud(self, predictions: dict):
        """Return filtered ``(coords, rgb)`` from a numpy predictions dict.

        Falls back to a single white point at the origin when every point is filtered out.
        """
        coords, rgb, conf = self._flatten_predictions(predictions)
        mask = self._point_mask(conf, rgb)
        coords, rgb = coords[mask], rgb[mask]

        if coords.size == 0:
            coords = np.array([[0.0, 0.0, 0.0]])
            rgb = np.array([[255, 255, 255]], dtype=np.uint8)
        return coords, rgb

    def _conf_to_pointcloud_with_conf(self, predictions: dict):
        """Return filtered ``(coords, rgb, conf)``, keeping the per-point confidence.

        Falls back to one white point at the origin with confidence 0 when every point
        is filtered out.
        """
        coords, rgb, conf = self._flatten_predictions(predictions)
        mask = self._point_mask(conf, rgb)
        coords, rgb, conf = coords[mask], rgb[mask], conf[mask]

        if coords.size == 0:
            return np.zeros((1, 3)), np.ones((1, 3), dtype=np.uint8) * 255, np.zeros(1)
        return coords, rgb, conf

    def _conf_to_pointcloud(self, predictions: dict):
        """Like ``_predictions_to_pointcloud`` but colours points by confidence.

        Confidence is min-max normalised over all points (before filtering) and mapped
        through Matplotlib's "turbo" colormap. The background masks are still evaluated
        on the image colours, not on the pseudo-colours.
        """
        coords, raw_rgb, conf = self._flatten_predictions(predictions)
        conf_norm = (conf - conf.min()) / (conf.max() - conf.min() + 1e-8)

        try:
            from matplotlib import colormaps  # Matplotlib >= 3.5

            colormap = colormaps["turbo"]
        except ImportError:
            import matplotlib.cm as cm  # older Matplotlib; cm.get_cmap is gone in 3.9+

            colormap = cm.get_cmap("turbo")
        rgb = (colormap(conf_norm)[:, :3] * 255).astype(np.uint8)

        mask = self._point_mask(conf, raw_rgb)
        coords, rgb = coords[mask], rgb[mask]

        if coords.size == 0:
            coords = np.array([[0.0, 0.0, 0.0]])
            rgb = np.array([[255, 255, 255]], dtype=np.uint8)
        return coords, rgb

    # ------------------------------------------------------------------
    # Multi-reference helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _align_to_base(coords: np.ndarray, base_extrinsic, curr_extrinsic) -> np.ndarray:
        """Move Y/Z-negated points from the current run's world frame to the base run's.

        VGGT documents its extrinsics as camera-from-world (OpenCV convention), so the
        extrinsic of view 0 maps each run's world frame onto view 0's camera frame:
        ``p_cam0 = E_curr @ p_curr = E_base @ p_base``, i.e.
        ``p_base = inv(E_base) @ E_curr @ p_curr``. ``coords`` have Y and Z negated,
        so the transform is conjugated by that axis flip.

        Args:
            coords: (P, 3) points with Y/Z negated, in the current run's world frame.
            base_extrinsic: (3, 4) extrinsic of view 0 predicted in the base run.
            curr_extrinsic: (3, 4) extrinsic of view 0 predicted in the current run.
        """

        def to_4x4(extrinsic):
            mat = np.eye(4)
            mat[:3, :4] = extrinsic[:3, :4]
            return mat

        transform = np.linalg.inv(to_4x4(base_extrinsic)) @ to_4x4(curr_extrinsic)
        flip = np.diag([1.0, -1.0, -1.0, 1.0])
        transform = flip @ transform @ flip

        hom = np.concatenate([coords, np.ones((coords.shape[0], 1))], axis=-1)
        return (hom @ transform.T)[:, :3]

    @staticmethod
    def _normalize_to_unit_cube(points: np.ndarray) -> np.ndarray:
        """Centre points on their mean and scale the longest bounding-box axis to [-1, 1]."""
        centered = points - np.mean(points, axis=0)
        half_extent = (centered.max(axis=0) - centered.min(axis=0)).max() / 2
        if half_extent > 0:
            centered = centered / half_extent
        return centered

    @staticmethod
    def _best_point_per_voxel(points, confs, voxel_size, max_points=None):
        """Return indices of the highest-confidence point in each occupied voxel.

        Ties within a voxel go to the lower index. Without a cap the indices come back
        in ascending order; when ``max_points`` is set and exceeded, only the
        ``max_points`` most confident survivors are kept, ordered by ascending
        confidence.
        """
        if points.shape[0] == 0:
            return np.empty(0, dtype=np.int64)

        voxel_idx = np.floor(points / voxel_size).astype(np.int64)
        # Stable sort by descending confidence: the first entry of each voxel in this
        # order is that voxel's best point (lowest index among equal confidences), and
        # np.unique's return_index reports first occurrences.
        order = np.argsort(-confs, kind="stable")
        _, first = np.unique(voxel_idx[order], axis=0, return_index=True)
        best = np.sort(order[first])

        if max_points is not None and best.size > max_points:
            ranked = np.argsort(confs[best], kind="stable")
            best = best[ranked[best.size - max_points:]]
        return best