import os

import sys
# Absolute path of this script
# (e.g. /path/to/MV-Adapter/mvadapter/data/multiview_8views.py).
current_script_path = os.path.abspath(__file__)

# Project root, obtained by applying dirname three times to the script path
# (assumes the project root is MV-Adapter/).
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_script_path)))

# Put the project root at the front of sys.path.
sys.path.insert(0, project_root)
# Declare the package this module belongs to.
# NOTE(review): presumably needed so the relative imports further down resolve
# when this file is executed directly as a script -- confirm.
__package__ = "mvadapter.data"


import json
# import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import random
from dataclasses import dataclass, field

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from ..utils.config import parse_structured
from ..utils.geometry import (
    get_plucker_embeds_from_cameras,
    get_plucker_embeds_from_cameras_ortho,
    get_position_map_from_depth,
    get_position_map_from_depth_ortho,
)
from ..utils.typing import *

from pathlib import Path


def _parse_scene_list_single(scene_list_path: str, root_data_dir: str):
    """Expand one scene-list file into a list of scene directory paths.

    A ``.json`` file is parsed as JSON and iterated entry by entry; a ``.txt``
    file is read one entry per line (surrounding whitespace stripped).
    An entry containing ``/`` is joined to ``root_data_dir`` as is; any other
    entry ``p`` is placed under a sub-directory named after its first two
    characters, i.e. ``root_data_dir/p[:2]/p``.

    Empty entries (blank lines in a ``.txt`` file, empty strings in a
    ``.json`` file) are skipped; they would otherwise expand to the bare
    ``root_data_dir`` and be treated as a scene.

    Args:
        scene_list_path: Path to a ``.json`` or ``.txt`` scene list.
        root_data_dir: Directory the listed entries are relative to.

    Returns:
        List of scene directory paths, in file order.

    Raises:
        NotImplementedError: If the file extension is neither ``.json`` nor
            ``.txt``.
    """
    if scene_list_path.endswith(".json"):
        with open(scene_list_path) as f:
            entries = json.load(f)
    elif scene_list_path.endswith(".txt"):
        with open(scene_list_path) as f:
            entries = [line.strip() for line in f]
    else:
        raise NotImplementedError(
            f"Unsupported scene list format (expected .json or .txt): {scene_list_path}"
        )

    return [
        os.path.join(root_data_dir, entry)
        if "/" in entry
        else os.path.join(root_data_dir, entry[:2], entry)
        for entry in entries
        if entry  # drop blank lines / empty strings
    ]


def _parse_scene_list(
    scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]]
):
    """Expand one or several scene-list files into a flat list of scene paths.

    Each scene list is paired with the root directory at the same position.
    A single root directory is reused for every scene list, so one root can
    serve several list files.

    Args:
        scene_list_path: A scene-list file path, or a sequence of them.
        root_data_dir: A root directory, or a sequence with one root per
            scene list.

    Returns:
        Concatenation of the scenes of every list, in input order.

    Raises:
        ValueError: If the number of scene lists and root directories differ
            (and the roots cannot be broadcast). A plain ``zip`` would
            silently drop the unmatched lists instead.
    """
    if isinstance(scene_list_path, str):
        scene_list_path = [scene_list_path]
    if isinstance(root_data_dir, str):
        root_data_dir = [root_data_dir]
    # Materialise so that config containers (not only plain lists) support
    # len() and repetition below.
    list_paths = list(scene_list_path)
    roots = list(root_data_dir)

    if len(roots) == 1 and len(list_paths) > 1:
        roots = roots * len(list_paths)
    if len(list_paths) != len(roots):
        raise ValueError(
            f"Got {len(list_paths)} scene list(s) but {len(roots)} root "
            "directories; provide one root per scene list or a single shared root."
        )

    all_scenes = []
    for list_path, root in zip(list_paths, roots):
        all_scenes.extend(_parse_scene_list_single(list_path, root))
    return all_scenes


def _parse_reference_scene_list(reference_scenes: List[str], all_scenes: List[str]):
    """Pair every scene with the reference scene that shares its id.

    The id of a scene is the last ``/``-separated component of its path.
    Scenes without a reference of the same id are dropped; references without
    a matching scene are ignored.

    The pairing is a direct id lookup. Sorting the references and zipping them
    against the scenes shifts every following pair as soon as an id occurs
    more than once among the references. When an id is duplicated, its first
    reference wins.

    Args:
        reference_scenes: Paths of candidate reference scenes.
        all_scenes: Paths of the scenes to be paired.

    Returns:
        A tuple ``(scenes, scene2ref)``: the scenes that have a reference (in
        their original order) and a dict mapping each of them to its
        reference scene path.
    """

    def scene_id(path):
        return path.split("/")[-1]

    ref_by_id = {}
    for ref_scene in reference_scenes:
        ref_by_id.setdefault(scene_id(ref_scene), ref_scene)

    matched_scenes = [scene for scene in all_scenes if scene_id(scene) in ref_by_id]
    scene2ref = {scene: ref_by_id[scene_id(scene)] for scene in matched_scenes}

    return matched_scenes, scene2ref


@dataclass
class MultiviewDataModuleConfig:
    """Configuration of the multi-view dataset.

    The per-field notes describe how the dataset code in this file uses each
    value. Fields marked "not read here" are not referenced by the active
    dataset code reviewed alongside this class; they may be consumed elsewhere.
    """

    # Root directory (or list of roots) that scene-list entries are joined to.
    root_dir: Any = ""
    # Not read here (only mentioned in commented-out code).
    pbr_root_dir: Any = ""
    # Not read here.
    rgb_root_dir: Any = ""
    # Path (or list of paths) to .json / .txt scene-list files.
    scene_list: Any = ""
    # Image file extension; alternative: webp. NOTE(review): the active loading
    # code hard-codes its extensions (.png / .exr / .webp) and does not read
    # this field -- confirm whether that is intended.
    image_suffix: str = "png"
    # Background specification handed to MultiviewDataset.get_bg_color.
    background_color: Union[str, float] = "gray"
    # View identifiers; used to build per-view file names such as
    # render_<name>.png and to look up cameras by "index" in meta.json.
    image_names: List[str] = field(default_factory=lambda: [])
    # Not read here.
    image_modality: str = "render"
    # Not read here.
    num_views: int = 1
    # Not read here.
    random_view_list: Optional[List[List[int]]] = None

    # JSON file loaded as the prompt database; looked up by scene directory name.
    prompt_db_path: Optional[str] = None
    # Whether a sample carries a "prompts" entry.
    return_prompt: bool = False
    # Use "" as prompt instead of querying the prompt database.
    use_empty_prompt: bool = False
    # String (shared by all views) or per-view sequence prepended to the prompt.
    prompt_prefix: Optional[Any] = None
    # Return only the first view's prompt instead of the per-view list.
    return_one_prompt: bool = True

    # "ORTHO" or "PERSP"; selects the camera model used for the conditions.
    projection_type: str = "ORTHO"

    # source conditions
    # One modality name or a list of them: "position", "normal", "plucker".
    source_image_modality: Any = "position"
    # Forwarded as camera_space to load_normal_image.
    use_camera_space_normal: bool = False
    # Position maps are mapped as (p + position_offset) / position_scale,
    # then clamped to [0, 1].
    position_offset: float = 0.5
    position_scale: float = 1.0
    # Plucker embeddings are mapped as (e + plucker_offset) / plucker_scale,
    # then clamped to [0, 1].
    plucker_offset: float = 1.0
    plucker_scale: float = 2.0

    # reference image
    # Reference scenes are used only when both the root and the list are set.
    reference_root_dir: Optional[Any] = None
    reference_scene_list: Optional[Any] = None
    # File-name prefix of reference images (<modality>_<name>.webp).
    reference_image_modality: str = "render"
    # Candidate reference view names; one is picked at random per sample.
    reference_image_names: List[str] = field(default_factory=lambda: [])
    # If set, the reference image is first resized to a randomly chosen square
    # resolution from this list before being resized to (height, width).
    reference_augment_resolutions: Optional[List[int]] = None
    # Forwarded as mask_aug to load_image for the reference image.
    reference_mask_aug: bool = False

    # The training scene list is repeated this many times (applied only when
    # train_indices is set).
    repeat: int = 1  # for debugging purpose

    # (start, stop) slice bounds applied to the scene list of each split.
    train_indices: Optional[Tuple[Any, Any]] = None
    val_indices: Optional[Tuple[Any, Any]] = None
    test_indices: Optional[Tuple[Any, Any]] = None

    # Target size every loaded image is resized to.
    height: int = 768
    width: int = 768

    # Not read here -- presumably consumed by the surrounding data module's
    # DataLoader setup; TODO confirm.
    batch_size: int = 1
    eval_batch_size: int = 1

    num_workers: int = 16


class MultiviewDataset(Dataset):
    def __init__(self, cfg: Any, split: str = "train") -> None:
        """Collect the scenes served by this dataset for one split.

        Steps, in order:
          1. Expand ``cfg.scene_list`` / ``cfg.root_dir`` into scene directories.
          2. Keep only the scenes accepted by ``_validate_scene``. The result
             is cached in a text file under the project root so the file-system
             scan runs only once per configuration.
          3. If a reference root and a reference scene list are configured,
             keep only scenes that have a reference scene and build the
             scene -> reference mapping.
          4. Slice the list with the ``*_indices`` of the requested split (the
             training slice is additionally repeated ``cfg.repeat`` times).
          5. Load the optional prompt database.

        Args:
            cfg: Dataset configuration (see ``MultiviewDataModuleConfig``).
            split: One of ``"train"``, ``"val"`` or ``"test"``.
        """
        super().__init__()
        assert split in ["train", "val", "test"]
        self.cfg: MultiviewDataModuleConfig = cfg
        self.all_scenes = _parse_scene_list(self.cfg.scene_list, self.cfg.root_dir)

        # Debug-dump settings: flag, output directory and a counter meant to
        # limit how many dumps are written.
        self.debug_save = True
        self.debug_save_path = "./debug_images"
        self.debug_save_counter = 0
        os.makedirs(self.debug_save_path, exist_ok=True)

        # ---- cache of validated scenes ---------------------------------
        # The cache file name is derived from everything the validation result
        # depends on. The view names are part of the key because
        # _validate_scene checks one set of files per view; without them a
        # cache built for a different view list would be reused.
        # NOTE: caches written with the older key (without the view names)
        # are not picked up; the scan runs once more to rebuild them.
        import hashlib

        cache_key = (
            f"{self.cfg.scene_list}_{self.cfg.root_dir}_{list(self.cfg.image_names)}"
        )
        cache_hash = hashlib.md5(cache_key.encode()).hexdigest()
        self.cache_file = os.path.join(project_root, f"valid_scenes_{cache_hash}.txt")

        if os.path.exists(self.cache_file):
            print(f"从缓存加载有效场景: {self.cache_file}")
            with open(self.cache_file, "r") as f:
                self.valid_scenes = [line.strip() for line in f if line.strip()]
            print(f"加载 {len(self.valid_scenes)} 个有效场景")
        else:
            print("执行完整场景验证...")
            from tqdm import tqdm

            self.valid_scenes = [
                scene_dir
                for scene_dir in tqdm(self.all_scenes, desc="验证数据完整性")
                if self._validate_scene(scene_dir)
            ]
            print(f"验证完成，有效场景数: {len(self.valid_scenes)}/{len(self.all_scenes)}")

            with open(self.cache_file, "w") as f:
                f.writelines(f"{scene}\n" for scene in self.valid_scenes)
            print(f"结果已保存到缓存: {self.cache_file}")

        # ---- reference scenes -----------------------------------------
        # Pairing is done against the validated scenes, so scenes without a
        # reference are dropped here as well.
        if (
            self.cfg.reference_root_dir is not None
            and self.cfg.reference_scene_list is not None
        ):
            reference_scenes = _parse_scene_list(
                self.cfg.reference_scene_list, self.cfg.reference_root_dir
            )
            self.valid_scenes, self.reference_scenes = _parse_reference_scene_list(
                reference_scenes, self.valid_scenes
            )
        else:
            self.reference_scenes = None

        # ---- split slicing --------------------------------------------
        self.split = split
        split_indices = {
            "train": self.cfg.train_indices,
            "val": self.cfg.val_indices,
            "test": self.cfg.test_indices,
        }[split]
        if split_indices is not None:
            self.valid_scenes = self.valid_scenes[split_indices[0] : split_indices[1]]
            if split == "train":
                # Repetition is only applied together with train_indices.
                self.valid_scenes = self.valid_scenes * self.cfg.repeat
        print(f"Final dataset size after filtering: {len(self.valid_scenes)} scenes")

        # ---- prompt database ------------------------------------------
        if self.cfg.prompt_db_path is not None:
            # Context manager so the file handle is closed deterministically.
            with open(self.cfg.prompt_db_path) as f:
                self.prompt_db = json.load(f)
        else:
            self.prompt_db = None

# class MultiviewDataset(Dataset):
#     def __init__(self, cfg: Any, split: str = "train") -> None:
#         super().__init__()
#         assert split in ["train", "val", "test"]
#         self.cfg: MultiviewDataModuleConfig = cfg
#         self.all_scenes = _parse_scene_list(self.cfg.scene_list, self.cfg.pbr_root_dir)

#         # 新增调试标志和计数器
#         self.debug_save = True  # 开启调试保存
#         self.debug_save_path = "./debug_images"  # 调试保存路径
#         self.debug_save_counter = 0  # 限制保存次数
#         os.makedirs(self.debug_save_path, exist_ok=True)


#         # 新增：预过滤有效场景 ==============================================
#         # self.valid_scenes = []
#         # kkk = 0
#         # for scene_dir in self.all_scenes:
#         #     # kkk += 1
#         #     # if kkk > 3000:
#         #     #     break
#         #     valid = True
#         #     # 检查每个视图的PBR文件是否存在
#         #     for f in self.cfg.image_names:
#         #         required_files = [
#         #             os.path.join(scene_dir, f"base-color_{f}.{self.cfg.image_suffix}"),
#         #             os.path.join(scene_dir, f"metallic_{f}.{self.cfg.image_suffix}"),
#         #             os.path.join(scene_dir, f"roughness_{f}.{self.cfg.image_suffix}"),
#         #         ]
#         #         # 检查所有必需文件是否存在
#         #         if not all(os.path.exists(p) for p in required_files):
#         #             valid = False
#         #             break
#         #     if valid:
#         #         self.valid_scenes.append(scene_dir)

#         from tqdm import tqdm

#         # 增强型预过滤（包含文件完整性检查）
#         # 在数据集初始化时严格验证所有文件的完整性,
#         # 确保valid_scenes只包含所有文件均正常的场景 
#         self.valid_scenes = []
#         kkk = 0
#         for scene_dir in tqdm(self.all_scenes, desc="验证数据完整性"):
#             kkk += 1
#             if kkk > 500:
#                 break
#             if self._validate_scene(scene_dir):
#                 self.valid_scenes.append(scene_dir)

#         print(f"最终有效场景数: {len(self.valid_scenes)}/{len(self.all_scenes)}")



#         # 原始引用场景处理逻辑（需要根据过滤后的场景调整）=======================
#         if (
#             self.cfg.reference_root_dir is not None
#             and self.cfg.reference_scene_list is not None
#         ):
#             reference_scenes = _parse_scene_list(
#                 self.cfg.reference_scene_list, self.cfg.reference_root_dir
#             )
#             # 使用过滤后的场景重新建立引用关系
#             self.valid_scenes, self.reference_scenes = _parse_reference_scene_list(
#                 reference_scenes, self.valid_scenes  # 使用valid_scenes
#             )
#         else:
#             self.reference_scenes = None

#         self.split = split
#         # 修改原始的分片逻辑使用valid_scenes ================================
#         if self.split == "train" and self.cfg.train_indices is not None:
#             self.valid_scenes = self.valid_scenes[
#                 self.cfg.train_indices[0] : self.cfg.train_indices[1]
#             ]
#             self.valid_scenes = self.valid_scenes * self.cfg.repeat
#         elif self.split == "val" and self.cfg.val_indices is not None:
#             self.valid_scenes = self.valid_scenes[
#                 self.cfg.val_indices[0] : self.cfg.val_indices[1]
#             ]
#         elif self.split == "test" and self.cfg.test_indices is not None:
#             self.valid_scenes = self.valid_scenes[
#                 self.cfg.test_indices[0] : self.cfg.test_indices[1]
#             ]
#         print(f"Final dataset size after filtering: {len(self.valid_scenes)} scenes")


#         if self.cfg.prompt_db_path is not None:
#             self.prompt_db = json.load(open(self.cfg.prompt_db_path))
#         else:
#             self.prompt_db = None


    #  ===================
    def _validate_scene(self, scene_dir):
        """全面验证场景数据完整性"""
        try:
            # 检查meta.json
            if not os.path.exists(os.path.join(scene_dir, "meta.json")):
                return False
            
            # 验证所有图像文件
            for f in self.cfg.image_names:
                # 只用检查 render(rgb),normal, depth 图片
                required_files = [
                    os.path.join(scene_dir, f"render_{f}.png"),
                    os.path.join(scene_dir, f"normal_{f}.exr"),
                    os.path.join(scene_dir, f"depth_{f}.exr"),
                    # os.path.join(scene_dir, f"base-color_{f}.{self.cfg.image_suffix}"),
                    # os.path.join(scene_dir, f"metallic_{f}.{self.cfg.image_suffix}"),
                    # os.path.join(scene_dir, f"roughness_{f}.{self.cfg.image_suffix}"),
                ]
                
                for path in required_files:
                    # 存在性检查
                    if not os.path.exists(path):
                        # print(f"缺失文件: {path}")
                        return False
                    
                    # # 完整性检查
                    # try:
                    #     with Image.open(path) as img:
                    #         img.verify()  # 验证文件完整性
                    #         if img.format == "WEBP":
                    #             img.load()  # 对webp格式需要额外加载验证
                    # except (IOError, OSError, Image.DecompressionBombError) as e:
                    #     print(f"损坏文件: {path} ({str(e)})")
                    #     return False

            # # 验证参考场景（如果有）
            # if self.reference_scenes is not None:
            #     ref_scene = self.reference_scenes.get(scene_dir, None)
            #     if ref_scene and not self._validate_reference_scene(ref_scene):
            #         return False

            return True
        except Exception as e:
            print(f"场景验证异常: {scene_dir} - {str(e)}")
            return False

    def _validate_reference_scene(self, scene_dir):
        """验证参考场景完整性"""
        try:
            for f in self.cfg.reference_image_names:
                path = os.path.join(
                    scene_dir, 
                    f"{self.cfg.reference_image_modality}_{f}.webp"
                )
                with Image.open(path) as img:
                    img.verify()
            return True
        except Exception as e:
            print(f"参考场景损坏: {scene_dir} - {str(e)}")
            return False

    def __len__(self):
        """Return the number of scenes in ``self.valid_scenes`` (the filtered list, not ``all_scenes``)."""
        return len(self.valid_scenes)

    def get_bg_color(self, bg_color):
        """Turn a background specification into a float32 RGB array.

        Args:
            bg_color: One of
                * ``"white"``, ``"black"``, ``"gray"`` -- fixed colors;
                * ``"random"`` -- three independent uniform samples in [0, 1);
                * ``"random_gray"`` -- one uniform gray level in [0.3, 0.7];
                * a number -- gray level used for all three channels;
                * a list or tuple -- used as the channel values directly.

        Returns:
            ``numpy.ndarray`` of dtype float32. Every branch returns float32 so
            that compositing never promotes images to float64 (``"random"``
            used to return float64).

        Raises:
            NotImplementedError: For an unknown name or an unsupported type.
        """
        fixed_levels = {"white": 1.0, "black": 0.0, "gray": 0.5}

        if isinstance(bg_color, str):
            if bg_color in fixed_levels:
                return np.full(3, fixed_levels[bg_color], dtype=np.float32)
            if bg_color == "random":
                return np.random.rand(3).astype(np.float32)
            if bg_color == "random_gray":
                return np.full(3, random.uniform(0.3, 0.7), dtype=np.float32)
            raise NotImplementedError(f"Unknown background color name: {bg_color!r}")

        # bool is an int subclass but is not a meaningful gray level.
        if isinstance(bg_color, (int, float)) and not isinstance(bg_color, bool):
            return np.full(3, bg_color, dtype=np.float32)
        if isinstance(bg_color, (list, tuple)):
            return np.array(bg_color, dtype=np.float32)

        raise NotImplementedError(
            f"Unsupported background color specification: {bg_color!r}"
        )

    def load_image(
        self,
        image: Union[str, Image.Image],
        height: int,
        width: int,
        background_color: torch.Tensor,
        rescale: bool = False,
        mask_aug: bool = False,
    ):
        """Load an image, resize it and composite it over a background color.

        The image is converted to RGBA first, so inputs without an alpha
        channel (RGB, palette, grayscale) are treated as fully opaque instead
        of failing on the alpha slice.

        Args:
            image: File path or an already opened PIL image.
            height: Target height in pixels.
            width: Target width in pixels.
            background_color: Color blended in where alpha < 1; must broadcast
                against the last (channel) dimension.
            rescale: If True, map the result with ``x * 2 - 1``.
            mask_aug: If True, zero the alpha of one random rectangle centred
                on a random foreground pixel (alpha > 0.5). The rectangle is
                between 1/8 and 1/4 of the image size along each axis.

        Returns:
            Float tensor of shape (height, width, 3).
        """
        if isinstance(image, str):
            # convert() decodes the pixels and returns a new image, so the
            # file can be closed as soon as the block exits.
            with Image.open(image) as handle:
                image = handle.convert("RGBA")
        elif image.mode != "RGBA":
            image = image.convert("RGBA")

        image = image.resize((width, height))
        image = torch.from_numpy(np.array(image)).float() / 255.0

        if mask_aug:
            alpha = image[:, :, 3]  # view: edits below write into `image`
            h, w = alpha.shape
            y_indices, x_indices = torch.where(alpha > 0.5)
            if len(y_indices) > 0 and len(x_indices) > 0:
                idx = torch.randint(len(y_indices), (1,)).item()
                y_center = y_indices[idx].item()
                x_center = x_indices[idx].item()
                mask_h = random.randint(h // 8, h // 4)
                mask_w = random.randint(w // 8, w // 4)

                # Clip the rectangle to the image bounds.
                y1 = max(0, y_center - mask_h // 2)
                y2 = min(h, y_center + mask_h // 2)
                x1 = max(0, x_center - mask_w // 2)
                x2 = min(w, x_center + mask_w // 2)

                alpha[y1:y2, x1:x2] = 0.0

        # Alpha-composite over the background color.
        opacity = image[:, :, 3:4]
        image = image[:, :, :3] * opacity + background_color * (1 - opacity)
        if rescale:
            image = image * 2.0 - 1.0

        return image
    
    # copy from zz
    def load_normal_image(
        self,
        path,
        height,
        width,
        background_color,
        camera_space: bool = False,
        c2w: Optional[torch.FloatTensor] = None,
        resize_rate: float = 1.0,
        max_bbox: Tuple[int, int, int, int] = None,
    ):
        """Load a 4-channel normal map and composite it over a background color.

        The file is read with OpenCV, resized with nearest-neighbour sampling,
        its first three channels are reversed (OpenCV channel order) and its
        fourth channel is used as alpha. Unless ``camera_space`` is True, the
        normals are decoded with ``x * 2 - 1``, multiplied by the upper-left
        3x3 block of ``c2w``, re-normalised and encoded back with
        ``x * 0.5 + 0.5``.

        NOTE(review): the original author flagged that the channel reversal
        and the alpha handling still need to be confirmed against the data.

        Args:
            path: Path of the normal image (read with IMREAD_UNCHANGED).
            height: Target height in pixels.
            width: Target width in pixels.
            background_color: Color blended in where alpha < 1.
            camera_space: If True, return the stored normals untransformed.
            c2w: Matrix whose upper-left 3x3 block is applied to the normals;
                required when ``camera_space`` is False.
            resize_rate: Unused; kept for interface compatibility.
            max_bbox: Unused; kept for interface compatibility.

        Returns:
            Float tensor of shape (height, width, 3).

        Raises:
            ValueError: If ``c2w`` is missing for the world-space path, or the
                image does not have four channels.
            FileNotFoundError: If OpenCV cannot read ``path`` (``cv2.imread``
                returns None instead of raising).
        """
        if not camera_space and c2w is None:
            raise ValueError("c2w is required when camera_space is False")

        raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if raw is None:
            raise FileNotFoundError(f"cv2.imread could not read normal map: {path}")
        raw = cv2.resize(raw, (width, height), interpolation=cv2.INTER_NEAREST)
        if raw.ndim != 3 or raw.shape[2] < 4:
            raise ValueError(
                f"Expected a 4-channel normal map, got array of shape {raw.shape}: {path}"
            )

        alpha = torch.from_numpy(np.array(raw[:, :, 3:4])).float()
        # Reverse the first three channels; the copy makes the strides positive
        # so torch.from_numpy accepts the array.
        image = torch.from_numpy(np.array(raw[:, :, :3][..., ::-1])).float()

        if not camera_space:
            rotation = c2w[:3, :3]
            # Per pixel: rotation @ n, written as a broadcast multiply + sum.
            image = (
                F.normalize(
                    ((image * 2 - 1)[:, :, None, :] * rotation).sum(-1), dim=-1
                )
                * 0.5
                + 0.5
            )
        image = image * alpha + background_color * (1 - alpha)

        return image
    

    def load_depth(self, path, height, width):
        """Load a depth map and its validity mask.

        The file is read with OpenCV, resized with nearest-neighbour sampling
        and reduced to its first channel. Depth values above 1000 are treated
        as invalid (the renderer writes 65535 for "no hit"): their mask entry
        and their depth are both set to 0.

        Args:
            path: Path of the depth image (read with IMREAD_UNCHANGED).
            height: Target height in pixels.
            width: Target width in pixels.

        Returns:
            Tuple ``(depth, mask)`` of float tensors, both of shape
            (height, width, 1); ``mask`` is 1 for valid pixels and 0 otherwise.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``path`` (``cv2.imread``
                returns None instead of raising).
        """
        depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if depth is None:
            raise FileNotFoundError(f"cv2.imread could not read depth map: {path}")
        depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_NEAREST)
        if depth.ndim == 2:
            # Single-channel files come back as (H, W); without a channel axis
            # the slice below would cut the width dimension instead.
            depth = depth[..., None]
        depth = torch.from_numpy(depth[..., 0:1]).float()

        invalid = depth > 1000.0  # depth = 65535 is the invalid value
        mask = torch.ones_like(depth)
        mask[invalid] = 0.0
        depth[invalid] = 0.0

        return depth, mask

    def retrieve_prompt(self, scene_dir):
        assert self.prompt_db is not None
        source_id = os.path.basename(scene_dir)
        # print('source_id', source_id)
        return self.prompt_db.get(source_id, "")

    def __getitem__(self, index):
        # 直接使用预过滤后的有效场景
        scene_dir = self.valid_scenes[index]  # 修改这里
        # scene_dir /data/public/material_dataset/Objaverse-Ortho/texture_ortho10view_pbr_gt/2a/2a77a1c27b604f038138e16116091ae7
        # print('scene_dir', scene_dir)

        background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color))

        # background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color))
        # scene_dir = self.all_scenes[index]


        '''# 从Objaverse-Ortho/texture_rand_easylight_objaverse拿到rgb，先不考虑这种方案，还需要需要拿到对应的norm等，这里的norm cam有修改不知道怎么做  
        path_parts = Path(scene_dir).parts
        rgb_scene_dir = Path("/data/public/material_dataset/Objaverse-Ortho/texture_rand_easylight_objaverse")
        rgb_scene_dir = rgb_scene_dir / path_parts[-2] / path_parts[-1]  # 取最后两级目录


        try:
            with open(os.path.join(scene_dir, "meta.json")) as f:
                meta = json.load(f)
            name2loc = {loc["index"]: loc for loc in meta["locations"]}

            # target multi-view images, GT 多视图
            image_paths = [
                os.path.join(
                    rgb_scene_dir, f"color_{f}.webp" # rgb是webp，写死
                )
                for f in self.cfg.image_names
            ]
            images = [
                self.load_image(
                    p,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                )
                for p in image_paths
            ]
            images = torch.stack(images, dim=0).permute(0, 3, 1, 2)
        '''


        try:
            ''' 从 pbr数据拿到 render.png作为rgb, 不知道这种做法合理不，后期需要从texture_rand_easylight_objaverse拿到rgb和norm '''
            with open(os.path.join(scene_dir, "meta.json")) as f:
                meta = json.load(f)
            name2loc = {loc["index"]: loc for loc in meta["locations"]}

            # target multi-view images, GT 多视图
            image_paths = [
                os.path.join(
                    scene_dir, f"render_{f}.png" # pbr目录下的rgb是render_000x.png，写死
                )
                for f in self.cfg.image_names
            ]
            images = [
                self.load_image(
                    p,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                )
                for p in image_paths
            ]
            images = torch.stack(images, dim=0).permute(0, 3, 1, 2)

            ''' target pbr, 这里style transfer应该不需要 
            # target multi-view PBR images
            a_images = []  # 存储albedo (3通道)
            b_images = []  # 存储 [metallic, roughness, 0] (3通道)

            for f in self.cfg.image_names:
                # ----------------------------
                # 1. 加载albedo (base-color)
                # ----------------------------
                albedo_path = os.path.join(
                    scene_dir, 
                    f"base-color_{f}.{self.cfg.image_suffix}"  # 例如 base-color_0000.png
                )
                albedo = self.load_image(
                    albedo_path,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                )  # 形状 (H, W, 3)

                # ----------------------------
                # 2. 加载metallic和roughness
                # ----------------------------
                # 加载metallic（假设为单通道图像）
                metallic_path = os.path.join(
                    scene_dir,
                    f"metallic_{f}.{self.cfg.image_suffix}"  # 例如 metallic_0000.png
                )
                metallic = self.load_image(
                    metallic_path,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                )
                # 提取单通道,metallic和roughness RGB相同（取R通道）
                metallic = metallic[..., 0:1]  # 形状 (H, W, 1)

                # 加载roughness（同理）
                roughness_path = os.path.join(
                    scene_dir,
                    f"roughness_{f}.{self.cfg.image_suffix}"  # 例如 roughness_0000.png
                )
                roughness = self.load_image(
                    roughness_path,
                    height=self.cfg.height,
                    width=self.cfg.width,
                    background_color=background_color,
                )
                roughness = roughness[..., 0:1]  # 形状 (H, W, 1)

                # ----------------------------
                # 3. 组合成B通道 [M, R, 0]
                # ----------------------------
                # 创建空白通道（与metallic同形状）
                zero_channel = torch.zeros_like(metallic)
                # 拼接通道
                b = torch.cat([metallic, roughness, zero_channel], dim=-1)  # 形状 (H, W, 3)

                # ----------------------------
                # 4. 存入列表
                # ----------------------------
                a_images.append(albedo)
                b_images.append(b)

            # ----------------------------
            # 5. 堆叠多视角张量
            # ----------------------------
            # Albedo: (num_views, 3, H, W)
            a_images = torch.stack(a_images, dim=0).permute(0, 3, 1, 2)
            # Metallic+Roughness: (num_views, 3, H, W)
            b_images = torch.stack(b_images, dim=0).permute(0, 3, 1, 2)
            
            '''
            
            # camera
            # index: 视图唯一标识符（如 "0000"）。transform_matrix: 4x4 的相机到世界坐标系变换矩阵（c2w）。
            # projection_type: 投影类型（ORTHO 或 PERSP）。ortho_scale: 正交投影的缩放参数。
            c2w = [
                torch.as_tensor(name2loc[name]["transform_matrix"])
                for name in self.cfg.image_names
            ]
            c2w = torch.stack(c2w, dim=0)

            if self.cfg.projection_type == "PERSP":
                camera_angle_x = (
                    meta.get("camera_angle_x", None)
                    or meta["locations"][0]["camera_angle_x"]
                )
                focal_length = 0.5 * self.cfg.width / np.tan(0.5 * camera_angle_x)
                intrinsics = (
                    torch.as_tensor(
                        [
                            [focal_length, 0.0, 0.5 * self.cfg.width],
                            [0.0, focal_length, 0.5 * self.cfg.height],
                            [0.0, 0.0, 1.0],
                        ]
                    )
                    .unsqueeze(0)
                    .float()
                    .repeat(len(self.cfg.image_names), 1, 1)
                )
            elif self.cfg.projection_type == "ORTHO":
                ortho_scale = (
                    meta.get("ortho_scale", None) or meta["locations"][0]["ortho_scale"]
                )

            # source conditions
            # 这里 涉及到世界到相机坐标之间的转换（w2c），还不是很理解
            source_image_modality = self.cfg.source_image_modality
            if isinstance(source_image_modality, str):
                source_image_modality = [source_image_modality]
            source_images = []
            for modality in source_image_modality:
                if modality == "position":
                    depth_masks = [
                        self.load_depth(
                            os.path.join(scene_dir, f"depth_{f}.exr"),
                            self.cfg.height,
                            self.cfg.width,
                        )
                        for f in self.cfg.image_names
                    ]
                    depths = torch.stack([d for d, _ in depth_masks])
                    masks = torch.stack([m for _, m in depth_masks])
                    c2w_ = c2w.clone()
                    c2w_[:, :, 1:3] *= -1

                    if self.cfg.projection_type == "PERSP":
                        position_maps = get_position_map_from_depth(
                            depths,
                            masks,
                            intrinsics,
                            c2w_,
                            image_wh=(self.cfg.width, self.cfg.height),
                        )
                    elif self.cfg.projection_type == "ORTHO":
                        position_maps = get_position_map_from_depth_ortho(
                            depths,
                            masks,
                            c2w_,
                            ortho_scale,
                            image_wh=(self.cfg.width, self.cfg.height),
                        )
                    position_maps = (
                        (position_maps + self.cfg.position_offset) / self.cfg.position_scale
                    ).clamp(0.0, 1.0)
                    source_images.append(position_maps)
                elif modality == "normal":
                    normal_maps = [
                        self.load_normal_image(
                            os.path.join(
                                # scene_dir, f"{modality}_{f}.{self.cfg.image_suffix}"
                                scene_dir, f"{modality}_{f}.exr"
                            ),
                            height=self.cfg.height,
                            width=self.cfg.width,
                            background_color=background_color,
                            camera_space=self.cfg.use_camera_space_normal,
                            c2w=c,
                        )
                        for c, f in zip(c2w, self.cfg.image_names)
                    ]
                    source_images.append(torch.stack(normal_maps, dim=0))
                elif modality == "plucker":
                    if self.cfg.projection_type == "ORTHO":
                        plucker_embed = get_plucker_embeds_from_cameras_ortho(
                            c2w, [ortho_scale] * len(c2w), self.cfg.width
                        )
                    elif self.cfg.projection_type == "PERSP":
                        plucker_embed = get_plucker_embeds_from_cameras(
                            c2w, [camera_angle_x] * len(c2w), self.cfg.width
                        )
                    else:
                        raise NotImplementedError
                    plucker_embed = plucker_embed.permute(0, 2, 3, 1)
                    plucker_embed = (
                        (plucker_embed + self.cfg.plucker_offset) / self.cfg.plucker_scale
                    ).clamp(0.0, 1.0)
                    source_images.append(plucker_embed)
                else:
                    raise NotImplementedError
            source_images = torch.cat(source_images, dim=-1).permute(0, 3, 1, 2)
            # 这里可能需要修改， rgb(images)是6-view gt img; source_rgb (source_images)是conditions
            # rv = {"rgb": images, "c2w": c2w, "source_rgb": source_images} 

            # clip(0, 1) 确保数值在合理范围内
            source_images = source_images.clip(0, 1)

            # 异常值检测与处理
            # 检查 source_images 中的 NaN/Inf 值
            # 若存在异常值：
            # 设置 flag 为 True（后续可用于日志或跳过该样本）
            # 将异常值替换为 0
            flag = False
            if torch.isnan(source_images).any() or torch.isinf(source_images).any():
                flag = True
                source_images = torch.nan_to_num(source_images, 0)

            rv = {
                "rgb": images, 
                "c2w": c2w, 
                "source_rgb": source_images, # normal,position condition
                # "albedo": a_images,  # 用于diffuse颜色
                # "mr_channel": b_images  # 金属度+粗糙度+占位
                } 
            
            # 打印所有图片的 shape
            # print("Shape of 'rgb':", rv["rgb"].shape if hasattr(rv["rgb"], 'shape') else "Not a tensor/array")
            # print("Shape of 'source_rgb':", rv["source_rgb"].shape if hasattr(rv["source_rgb"], 'shape') else "Not a tensor/array")
            # print("Shape of 'albedo':", rv["albedo"].shape if hasattr(rv["albedo"], 'shape') else "Not a tensor/array")
            # print("Shape of 'mr_channel':", rv["mr_channel"].shape if hasattr(rv["mr_channel"], 'shape') else "Not a tensor/array")


            num_images = len(self.cfg.image_names)
            # prompt
            if self.cfg.return_prompt:
                if self.cfg.use_empty_prompt:
                    prompt = ""
                else:
                    prompt = self.retrieve_prompt(scene_dir)
                    # print('prompt', prompt)
                prompts = [prompt] * num_images

                if self.cfg.prompt_prefix is not None:
                    prompt_prefix = self.cfg.prompt_prefix
                    if isinstance(prompt_prefix, str):
                        prompt_prefix = [prompt_prefix] * num_images

                    for i, prompt in enumerate(prompts):
                        prompts[i] = f"{prompt_prefix[i]} {prompt}"

                if self.cfg.return_one_prompt:
                    rv.update({"prompts": prompts[0]})
                else:
                    rv.update({"prompts": prompts})

                # print('rv', rv)

            # reference image
            if self.reference_scenes is not None:
                reference_scene_dir = self.reference_scenes[scene_dir]
                reference_image_paths = [
                    os.path.join(
                        reference_scene_dir,
                        # f"{self.cfg.reference_image_modality}_{f}.{self.cfg.image_suffix}", 
                        f"{self.cfg.reference_image_modality}_{f}.webp", # 写死
                    )
                    for f in self.cfg.reference_image_names
                ]
                reference_image_path = random.choice(reference_image_paths)

                if self.cfg.reference_augment_resolutions is None:
                    reference_image = self.load_image(
                        reference_image_path,
                        height=self.cfg.height,
                        width=self.cfg.width,
                        background_color=background_color,
                        mask_aug=self.cfg.reference_mask_aug,
                    ).permute(2, 0, 1)
                    rv.update({"reference_rgb": reference_image})
                else:
                    random_resolution = random.choice(
                        self.cfg.reference_augment_resolutions
                    )
                    reference_image_ = Image.open(reference_image_path).resize(
                        (random_resolution, random_resolution)
                    )
                    reference_image = self.load_image(
                        reference_image_,
                        height=self.cfg.height,
                        width=self.cfg.width,
                        background_color=background_color,
                        mask_aug=self.cfg.reference_mask_aug,
                    ).permute(2, 0, 1)
                    rv.update({"reference_rgb": reference_image})

            # 新增调试保存逻辑 ================================================
            # if self.debug_save and self.debug_save_counter < 1:  # 只保存第一个样本
            #     scene_id = os.path.basename(scene_dir)
            #     save_dir = os.path.join(self.debug_save_path, f"sample_{scene_id}")
            #     os.makedirs(save_dir, exist_ok=True)
                
            #     # 保存albedo各视图
            #     for view_idx in range(a_images.shape[0]):
            #         albedo_view = a_images[view_idx].permute(1, 2, 0).numpy()  # CHW->HWC
            #         albedo_view = (albedo_view * 255).astype(np.uint8)
            #         Image.fromarray(albedo_view).save(
            #             os.path.join(save_dir, f"albedo_view{view_idx:02d}.png")
            #         )
                    
            #     # 保存metallic+roughness各视图
            #     for view_idx in range(b_images.shape[0]):
            #         mr_view = b_images[view_idx].permute(1, 2, 0).numpy()
            #         metallic = (mr_view[..., 0] * 255).astype(np.uint8)  # Metallic通道
            #         roughness = (mr_view[..., 1] * 255).astype(np.uint8)  # Roughness通道
            #         Image.fromarray(metallic).save(
            #             os.path.join(save_dir, f"metallic_view{view_idx:02d}.png")
            #         )
            #         Image.fromarray(roughness).save(
            #             os.path.join(save_dir, f"roughness_view{view_idx:02d}.png")
            #         )
                
            #     # 保存source条件图像（position/normal）
            #     for view_idx in range(source_images.shape[0]):
            #         source_view = source_images[view_idx].permute(1, 2, 0).numpy()
                    
            #         # 分通道保存
            #         for ch_idx in range(source_view.shape[-1]):
            #             channel = (source_view[..., ch_idx] * 255).astype(np.uint8)
            #             Image.fromarray(channel).save(
            #                 os.path.join(save_dir, f"source_view{view_idx:02d}_ch{ch_idx}.png")
            #             )
                
            #     # 保存参考图像
            #     if "reference_rgb" in rv:
            #         ref_img = rv["reference_rgb"].permute(1, 2, 0).numpy()
            #         ref_img = (ref_img * 255).astype(np.uint8)
            #         Image.fromarray(ref_img).save(
            #             os.path.join(save_dir, "reference.png")
            #         )
                
            #     self.debug_save_counter += 1
            #     print(f"Debug images saved to: {save_dir}")

            return rv
        
        except Exception as e:
            print(f"严重错误：预验证通过的场景加载失败！{scene_dir} - {str(e)}")
            # 返回空数据并在collate中过滤
            return None 

    # 这里可能需要修改
    def collate(self, batch):
        """Merge per-scene samples into one flattened multi-view batch.

        Args:
            batch: List of per-sample dicts. Entries that are ``None``
                (samples that failed to load) are dropped before collation.

        Returns:
            An empty dict when no valid sample remains. Otherwise the
            default-collated dict in which:

            * ``"rgb"``, ``"source_rgb"`` and ``"c2w"`` (when present) are
              restricted to the selected view indices along dim 1 and then
              have their first two dims merged into one;
            * ``"prompts"`` (when present and ``cfg.return_one_prompt`` is
              false) is flattened into one sample-major list;
            * ``"num_views"`` plus the SDXL size-conditioning entries
              ``"original_size"``, ``"target_size"`` and
              ``"crops_coords_top_left"`` are added.
        """
        # Filter out invalid samples (not expected to occur in practice).
        valid_samples = [sample for sample in batch if sample is not None]
        if not valid_samples:
            # Nothing usable: hand back an empty batch so the caller can skip it.
            return {}

        batch = torch.utils.data.default_collate(valid_samples)

        # Either draw one of the configured view subsets or keep every view.
        if self.cfg.random_view_list is not None:
            indices = random.choice(self.cfg.random_view_list)
        else:
            indices = list(range(self.cfg.num_views))
        num_views = len(indices)

        for k in ("rgb", "source_rgb", "c2w"):
            if k in batch:
                selected = batch[k][:, indices]
                # Merge the first two dims (sample, view) into one.
                batch[k] = selected.view(-1, *selected.shape[2:])

        # "prompts" only exists when the samples were built with prompts;
        # indexing it unconditionally raised KeyError for prompt-free batches.
        # NOTE(review): prompts are not restricted to `indices`; confirm that
        # downstream code expects one prompt per original view.
        if "prompts" in batch and not self.cfg.return_one_prompt:
            # default_collate yields one group per view; transpose back to
            # per-sample groups and flatten them in sample-major order.
            batch["prompts"] = [
                item for pair in zip(*batch["prompts"]) for item in pair
            ]

        batch.update(
            {
                "num_views": num_views,
                # For SDXL
                "original_size": (self.cfg.height, self.cfg.width),
                "target_size": (self.cfg.height, self.cfg.width),
                "crops_coords_top_left": (0, 0),
            }
        )
        return batch


class MultiviewDataModule(pl.LightningDataModule):
    """Lightning data module exposing the train/val/test multiview datasets."""

    cfg: MultiviewDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        """Parse ``cfg`` into a structured ``MultiviewDataModuleConfig``."""
        super().__init__()
        self.cfg = parse_structured(MultiviewDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        """Instantiate the dataset splits required by ``stage``.

        ``None`` builds all three splits, ``"fit"`` builds train and val,
        ``"validate"`` builds val only, and ``"test"``/``"predict"`` build
        the test split.
        """
        needs_train = stage in (None, "fit")
        needs_val = stage in (None, "fit", "validate")
        needs_test = stage in (None, "test", "predict")

        if needs_train:
            self.train_dataset = MultiviewDataset(self.cfg, "train")
        if needs_val:
            self.val_dataset = MultiviewDataset(self.cfg, "val")
        if needs_test:
            self.test_dataset = MultiviewDataset(self.cfg, "test")

    def prepare_data(self):
        """No-op: this module performs no data preparation step."""
        pass

    def _make_loader(self, dataset, batch_size, shuffle) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader that uses the dataset's own collate."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            shuffle=shuffle,
            collate_fn=dataset.collate,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training split."""
        return self._make_loader(
            self.train_dataset, self.cfg.batch_size, shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over the validation split."""
        return self._make_loader(
            self.val_dataset, self.cfg.eval_batch_size, shuffle=False
        )

    def test_dataloader(self) -> DataLoader:
        """Return the unshuffled loader over the test split."""
        return self._make_loader(
            self.test_dataset, self.cfg.eval_batch_size, shuffle=False
        )

    def predict_dataloader(self) -> DataLoader:
        """Prediction reuses the test loader."""
        return self.test_dataloader()


# if __name__ == "__main__":
#     import torchvision
#     from omegaconf import OmegaConf

#     # config_file = "configs/view-guidance/mvadapter_ig2mv_sdxl_debug.yaml"
#     config_file = "/data/home/yeyuteng/single_image_to_pbr/configs/geometry-guidance/mvadapter_ig2mv_sdxl_decouple.yaml"
#     data_cfg = OmegaConf.load(config_file)["data"]
#     cfg: MultiviewDataModuleConfig = MultiviewDataModuleConfig(**data_cfg)
#     data_module = MultiviewDataModule(cfg)
#     data_module.setup()

#     for batch in data_module.train_dataloader():
#         # ref_rgb = batch["reference_rgb"]  # bchw
#         rgb = batch["albedo"]
#         source_rgb = batch["source_rgb"]
#         ref_rgb = batch["reference_rgb"]

#         print(batch["prompts"])
#         print(rgb.shape, source_rgb.shape)

#         torchvision.utils.save_image(rgb[:, :3], "debug_albedo.png")
#         torchvision.utils.save_image(ref_rgb[:, :3], "debug_ref_rgb.png")
#         torchvision.utils.save_image(source_rgb[:, :3], "debug_xyz_rgb.png")
#         torchvision.utils.save_image(source_rgb[:, 3:], "debug_normal_rgb.png")
#         exit(0)
