# mvadapter/utils/load_mesh.py
import datetime
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import json
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from typing import Tuple, List, Dict, Optional

# Configuration parameters
class Config:
    """Settings used when turning rendered views into control images."""

    def __init__(self):
        # Target size every loaded image is resized to.
        self.height = 512
        self.width = 512
        # When True, normals are left as stored instead of being rotated by
        # the view's camera-to-world matrix.
        self.use_camera_space_normal = False
        # Name of the background color; resolved through get_bg_color().
        self.background_color = "gray"
        # Position maps are normalized as (position + offset) / scale.
        self.position_offset = 0.5
        self.position_scale = 1.0
        # View identifiers; each one names the image files of a view and,
        # parsed as an integer, indexes that view's camera entry.
        self.image_names = ["0006", "0000", "0002", "0004", "0008", "0009"]
        # File extension of the rendered images (without the leading dot).
        self.image_suffix = "exr"

    def get_bg_color(self, bg_color: str) -> np.ndarray:
        """Return the RGB triple of a named background color.

        Args:
            bg_color: One of ``"gray"``, ``"white"`` or ``"black"``.

        Returns:
            A new float32 array of three identical channel values.

        Raises:
            ValueError: If ``bg_color`` is not one of the known names.
        """
        known_levels = (("gray", 0.5), ("white", 1.0), ("black", 0.0))
        for name, level in known_levels:
            if bg_color == name:
                return np.full(3, level, dtype=np.float32)
        raise ValueError(f"Unknown background color: {bg_color}")

# Module-level settings instance; load_objaverse_mesh reads its fields.
cfg = Config()

def load_meta(scene_dir: str) -> Dict:
    """Load and parse ``meta.json`` from a scene directory.

    Args:
        scene_dir: Directory that contains ``meta.json``.

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file is not valid JSON.
    """
    meta_path = os.path.join(scene_dir, "meta.json")
    # JSON is UTF-8 by specification; decode it explicitly instead of relying
    # on the platform's locale encoding.
    with open(meta_path, "r", encoding="utf-8") as f:
        return json.load(f)

def get_c2w_matrices(meta: Dict) -> torch.Tensor:
    """Collect the camera-to-world (c2w) matrices of all views.

    Each entry of ``meta["locations"]`` contributes its ``transform_matrix``
    as a float32 tensor; the matrices are stacked along a new leading
    dimension in the order they appear.
    """
    per_view = [
        torch.tensor(entry["transform_matrix"], dtype=torch.float32)
        for entry in meta["locations"]
    ]
    return torch.stack(per_view, dim=0)

def get_ortho_scale(meta: Dict) -> float:
    """Return the orthographic projection scale stored in the metadata.

    A top-level ``"ortho_scale"`` entry takes precedence. Only when it is
    absent is the value of the first entry in ``meta["locations"]`` used.

    Args:
        meta: Parsed scene metadata.

    Returns:
        The orthographic scale.

    Raises:
        KeyError, IndexError: If neither the top-level entry nor the
            first location provides ``"ortho_scale"``.
    """
    # The fallback must be looked up lazily: passing it as the default of
    # dict.get() would evaluate it even when the top-level key exists, and
    # fail on metadata that has no per-location scale.
    if "ortho_scale" in meta:
        return meta["ortho_scale"]
    return meta["locations"][0]["ortho_scale"]

def load_normal_image(
    path: str,
    height: int,
    width: int,
    background_color: torch.Tensor,
    camera_space: bool = False,
    c2w: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Load a normal image, optionally rotate it, and composite it on a background.

    Args:
        path: Image file to read with OpenCV.
        height: Target height after resizing.
        width: Target width after resizing.
        background_color: Color blended in where alpha is below one.
        camera_space: If True, the stored values are kept as they are.
        c2w: Matrix whose upper-left 3x3 part is applied to the normals
            when ``camera_space`` is False. Ignored when None.

    Returns:
        The composited image as a float tensor with three channels last.

    Raises:
        FileNotFoundError: If OpenCV could not read ``path``.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"Normal image not found: {path}")

    resized = cv2.resize(raw, (width, height), interpolation=cv2.INTER_NEAREST)

    # First three channels with their order reversed (OpenCV loads BGR).
    # np.array() materializes the reversed view, which has a negative stride.
    color = resized[:, :, :3]
    normals = torch.from_numpy(np.array(color[..., ::-1])).float()
    # The fourth channel is used as the alpha / coverage value.
    coverage = torch.from_numpy(np.array(resized[:, :, 3:4])).float()

    if not camera_space and c2w is not None:
        rotation = c2w[:3, :3]
        # Decode from [0, 1] to [-1, 1] (presumably the storage encoding of
        # the normals - TODO confirm), then apply the 3x3 matrix per pixel.
        signed = normals * 2 - 1
        rotated = (signed[:, :, None, :] * rotation).sum(-1)
        # Re-normalize and encode back into [0, 1].
        normals = F.normalize(rotated, dim=-1) * 0.5 + 0.5

    # Alpha-composite over the background color.
    return normals * coverage + background_color * (1 - coverage)

def load_depth(
    path: str,
    height: int,
    width: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a depth image and derive a validity mask from it.

    Args:
        path: Depth file to read with OpenCV.
        height: Target height after resizing.
        width: Target width after resizing.

    Returns:
        ``(depth, mask)``, both float tensors of shape ``(height, width, 1)``.
        ``mask`` is 0 where the stored depth exceeds 1000.0 and 1 elsewhere;
        ``depth`` is set to 0 at the masked-out pixels.

    Raises:
        FileNotFoundError: If OpenCV could not read ``path``.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"Depth image not found: {path}")

    raw = cv2.resize(raw, (width, height), interpolation=cv2.INTER_NEAREST)
    if raw.ndim == 2:
        # A single-channel file comes back as (H, W). Without a channel axis
        # the slice below would cut the width down to one column instead of
        # selecting the first channel.
        raw = raw[..., None]
    depth = torch.from_numpy(raw[..., 0:1]).float()

    # Depths above 1000.0 are treated as invalid (presumably a background
    # sentinel written by the renderer - TODO confirm).
    invalid = depth > 1000.0
    mask = torch.ones_like(depth)
    mask[invalid] = 0.0
    depth[invalid] = 0.0

    return depth, mask

def get_position_map_from_depth_ortho(
    depth: torch.Tensor,
    mask: torch.Tensor,
    extrinsics: torch.Tensor,
    ortho_scale: float,
    image_wh: Tuple[int, int] = (512, 512)
) -> torch.Tensor:
    """Compute a position map from orthographic depth maps.

    Args:
        depth: Depth maps of shape ``(B, H, W, 1)``.
        mask: Multiplied onto the result; must broadcast against
            ``(B, H, W, 3)``.
        extrinsics: One 4x4 matrix per batch element, applied to the
            homogeneous camera-space coordinates.
        ortho_scale: Extent covered by the full image width / height.
        image_wh: ``(width, height)`` of the pixel grid; expected to match
            the spatial size of ``depth``.

    Returns:
        Tensor of shape ``(B, H, W, 3)`` holding the transformed xyz
        coordinates multiplied by ``mask``.
    """
    batch, _, _, _ = depth.shape
    width = image_wh[0]
    height = image_wh[1]
    z = depth.squeeze(-1)

    # Pixel index grids; with 'xy' indexing both have shape (height, width).
    cols = torch.linspace(0, width - 1, width)
    rows = torch.linspace(0, height - 1, height)
    col_grid, row_grid = torch.meshgrid(cols, rows, indexing='xy')
    col_grid = col_grid.type_as(z).unsqueeze(0).expand(batch, -1, -1)
    row_grid = row_grid.type_as(z).unsqueeze(0).expand(batch, -1, -1)

    # Center the pixel indices and scale them so the full image spans
    # ortho_scale along each axis.
    x = (col_grid - width / 2) * ortho_scale / width
    y = (row_grid - height / 2) * ortho_scale / height

    # Homogeneous camera-space points (x, y, z, 1) for every pixel.
    homogeneous = torch.stack([x, y, z, torch.ones_like(z)], dim=-1)
    transformed = torch.einsum('bij,bhwj->bhwi', extrinsics, homogeneous)

    return transformed[..., :3] * mask

def load_objaverse_mesh(scene_dir: str) -> torch.Tensor:
    """Build the per-view control images (position + normal) of one scene.

    For every view in ``cfg.image_names`` the files
    ``normal_<view>.<suffix>`` and ``depth_<view>.<suffix>`` are read from
    ``scene_dir`` (``<suffix>`` is ``cfg.image_suffix``). The depth map is
    turned into a position map, and both maps are concatenated.

    Args:
        scene_dir: Scene directory containing ``meta.json`` and the rendered
            images (e.g. /path/to/029389da502d41a0aeada600137ae98b).

    Returns:
        Tensor of shape ``[V, 6, H, W]`` with values in [0, 1], where ``V`` is
        the number of views; ``[6, 6, 512, 512]`` with the default settings.
        Channels 0-2 hold the position map, channels 3-5 the normal map.

    Raises:
        RuntimeError: If ``meta.json`` cannot be loaded or a view fails to
            load or process. The original exception is attached as
            ``__cause__``.
    """
    # Load the scene metadata.
    try:
        meta = load_meta(scene_dir)
    except Exception as e:
        raise RuntimeError(f"Failed to load meta.json: {str(e)}") from e

    # Camera matrices and orthographic scale.
    c2w = get_c2w_matrices(meta)
    ortho_scale = get_ortho_scale(meta)
    # Position maps use a copy with the 2nd and 3rd matrix columns negated
    # (presumably a camera axis-convention change - TODO confirm); normals
    # use the unmodified matrices.
    c2w_ = c2w.clone()
    c2w_[:, :, 1:3] *= -1

    # Background color as a tensor.
    bg_color = cfg.get_bg_color(cfg.background_color)
    background_color = torch.from_numpy(bg_color)

    normal_images = []
    position_maps = []

    for view_name in cfg.image_names:
        # The view name doubles as the index of the view's camera matrix.
        view_idx = int(view_name)
        view_c2w = c2w[view_idx]

        # Honor the configured extension instead of hard-coding "exr".
        normal_path = os.path.join(
            scene_dir, f"normal_{view_name}.{cfg.image_suffix}"
        )
        depth_path = os.path.join(
            scene_dir, f"depth_{view_name}.{cfg.image_suffix}"
        )

        try:
            # Normal map.
            normal_img = load_normal_image(
                normal_path,
                cfg.height,
                cfg.width,
                background_color,
                cfg.use_camera_space_normal,
                view_c2w
            )
            normal_images.append(normal_img)

            # Depth map -> position map (processed as a batch of one).
            depth, mask = load_depth(depth_path, cfg.height, cfg.width)
            position_map = get_position_map_from_depth_ortho(
                depth.unsqueeze(0),
                mask.unsqueeze(0),
                c2w_[view_idx].unsqueeze(0),
                ortho_scale,
                (cfg.width, cfg.height)
            )
            position_maps.append(position_map.squeeze(0))

        except Exception as e:
            # Chain explicitly so the original traceback stays attached.
            raise RuntimeError(
                f"Error processing view {view_name}: {str(e)}"
            ) from e

    # Stack the per-view results: [V, H, W, 3] each.
    normal_tensor = torch.stack(normal_images, dim=0)
    position_tensor = torch.stack(position_maps, dim=0)

    # Normalize the position maps into [0, 1].
    position_tensor = (position_tensor + cfg.position_offset) / cfg.position_scale
    position_tensor = position_tensor.clamp(0.0, 1.0)

    # Concatenate position and normal maps: [V, H, W, 3+3] -> [V, H, W, 6].
    source_images = torch.cat([position_tensor, normal_tensor], dim=-1)

    # Reorder dimensions: [V, H, W, C] -> [V, C, H, W].
    source_images = source_images.permute(0, 3, 1, 2)

    # Make sure all channels (normals included) lie in [0, 1].
    source_images = torch.clamp(source_images, 0, 1)

    # Replace non-finite values with 0. The clamp above already bounds
    # +/-inf, so in practice only NaN can still be present here.
    if torch.isnan(source_images).any() or torch.isinf(source_images).any():
        source_images = torch.nan_to_num(source_images, nan=0.0, posinf=0.0, neginf=0.0)

    return source_images