import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix
from utils.general_utils import PILtoTorch, load_mask_files, load_mask_files_from_npy
import cv2
import os
from PIL import Image

class Camera(nn.Module):
    """One view: camera parameters, ground-truth image and optional segmentation masks.

    The ground-truth RGB image (and the alpha channel, if the input has one) is
    stored on ``data_device``. The view/projection matrices are always created
    on CUDA.
    """

    def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, cx, cy, image,
                 image_name, uid,
                 trans=None, scale=1.0, data_device="cuda"
                 ):
        """
        Args:
            resolution: target resolution forwarded to ``PILtoTorch``.
            colmap_id: COLMAP identifier of this view.
            R, T: rotation / translation forwarded to ``getWorld2View2``.
            FoVx, FoVy: fields of view forwarded to ``getProjectionMatrix``.
            cx, cy: presumably the principal point; only stored here.
            image: input image converted with ``PILtoTorch``. The first three
                channels form the RGB image. If there are more, channel 3 is
                kept as ``alpha_mask``.
            image_name: name used to build mask file names.
            uid: unique id of this camera.
            trans: extra translation for ``getWorld2View2``. Defaults to a new
                zero vector per instance (not a shared default array).
            scale: scale factor for ``getWorld2View2``.
            data_device: device for the image tensors. Falls back to "cuda"
                if ``torch.device`` rejects it.
        """
        super().__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        self.cx = cx
        self.cy = cy
        self.image_name = image_name

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        resized_image_rgb = PILtoTorch(image, resolution)
        gt_image = resized_image_rgb[:3, ...]
        # Keep an alpha channel only when the converted image has more than 3 channels.
        if resized_image_rgb.shape[0] > 3:
            self.alpha_mask = resized_image_rgb[3, ...].clamp(0.0, 1.0).to(self.data_device)
        else:
            self.alpha_mask = None

        self.original_image = gt_image.clamp(0.0, 1.0).to(self.data_device)
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        self.zfar = 100.0
        self.znear = 0.01

        # A fresh array per instance: a default-argument array would be shared
        # (and mutable through self.trans) across every Camera.
        if trans is None:
            trans = np.array([0.0, 0.0, 0.0])
        self.trans = trans
        self.scale = scale

        # Matrices are always placed on CUDA, independent of data_device.
        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        # Camera centre: row 3 (first three entries) of the inverted, transposed world-to-view matrix.
        self.camera_center = self.world_view_transform.inverse()[3, :3]

        # Optional inverse depth map; to_gpu/to_cpu move it when it is set.
        self.invdepthmap = None

        # Mask state.
        self.masks = None             # all per-class masks, when loaded via load_masks
        self.mask_class_ids = None    # class ids matching self.masks
        self.current_mask = None      # single-class mask from load_mask_by_class_id
        self.current_class_id = None  # class id of current_mask

        # Status flag only: reflects where the image actually lives, so it is
        # correct even when data_device is the CPU.
        self.data_on_gpu = self.original_image.is_cuda

    def load_masks(self, masks_dir, use_npy_masks=False):
        """Load all per-class masks belonging to this camera's image.

        Args:
            masks_dir: directory containing the mask files.
            use_npy_masks: if True use ``load_mask_files_from_npy``, otherwise
                ``load_mask_files``.

        On success ``masks`` / ``mask_class_ids`` hold the loader's results.
        Otherwise they are empty lists. The masks are not moved here; call
        ``to_gpu`` to move them (the loader is expected to return CPU tensors
        — TODO confirm).
        """
        resolution = (self.image_width, self.image_height)  # (W, H)

        self.masks = []
        self.mask_class_ids = []

        loader = load_mask_files_from_npy if use_npy_masks else load_mask_files
        masks, class_ids = loader(self.image_name, masks_dir, resolution)

        if masks:
            self.masks = masks
            self.mask_class_ids = class_ids
            # Freshly loaded masks are not on the GPU, so the view's data is no
            # longer entirely on the GPU.
            self.data_on_gpu = False

    @staticmethod
    def _mask_array_to_tensor(mask_np, target_hw):
        """Convert a decoded mask array to a binary float tensor of shape ``target_hw``.

        - A 3-D (H, W, C) array keeps only its first channel.
        - Bool arrays (1-bit images) map directly to 0/1. Any other dtype is
          divided by 255.
        - If the shape differs from ``target_hw``, the result is resized with
          nearest-neighbour sampling.
        - Finally it is thresholded at 0.5.
        """
        if mask_np.ndim == 3:
            mask_np = mask_np[:, :, 0]

        if mask_np.dtype == np.bool_:
            # Dividing a bool (0/1) mask by 255 would push every pixel below
            # the 0.5 threshold and erase the mask.
            mask_tensor = torch.from_numpy(mask_np.astype(np.float32))
        else:
            mask_tensor = torch.from_numpy(mask_np.astype(np.float32)) / 255.0

        target_hw = tuple(target_hw)
        if tuple(mask_tensor.shape) != target_hw:
            mask_tensor = torch.nn.functional.interpolate(
                mask_tensor[None, None],  # [N=1, C=1, H, W]
                size=target_hw,
                mode='nearest',
            )[0, 0]  # index rather than squeeze() so a size-1 H or W survives

        return (mask_tensor > 0.5).float()

    def load_mask_by_class_id(self, masks_dir, class_id, use_npy_masks=False):
        """Load only the PNG mask of one class into ``current_mask``.

        Reads ``{masks_dir}/{class_id}_{image_name}.png`` and converts it to a
        binary float CPU tensor of shape (image_height, image_width). It is
        stored with its class id; no other class is loaded.

        Args:
            masks_dir: directory containing the mask PNGs.
            class_id: class whose mask should be loaded.
            use_npy_masks: currently ignored; only PNG masks are read here.

        Returns:
            True if a non-empty mask was loaded. False if the file is missing,
            the mask is all zeros, or loading fails; in that case
            ``current_mask`` and ``current_class_id`` are None.
        """
        # Drop any previously loaded single-class mask.
        self.current_mask = None
        self.current_class_id = None

        mask_path = os.path.join(masks_dir, f"{class_id}_{self.image_name}.png")
        print(f"尝试加载掩码: {mask_path}, 文件存在: {os.path.exists(mask_path)}")

        if not os.path.exists(mask_path):
            print(f"  [Info CAM {self.image_name}] Mask file not found: {mask_path}")
            return False

        try:
            # The context manager closes the file even if decoding raises.
            with Image.open(mask_path) as mask_img_pil:
                mask_np = np.array(mask_img_pil)
            mask_tensor = self._mask_array_to_tensor(
                mask_np, (self.image_height, self.image_width))
        except FileNotFoundError:
            # Only possible if the file vanished between exists() and open().
            print(f"  [Error CAM {self.image_name}] Mask file disappeared before loading: {mask_path}")
            return False
        except Exception as e:
            print(f"  [Error CAM {self.image_name}] Failed to load mask for class {class_id} from {mask_path}: {e}")
            return False

        if mask_tensor.sum().item() == 0:
            return False

        self.current_mask = mask_tensor  # CPU tensor
        self.current_class_id = class_id
        return True

    def clear_masks(self):
        """Drop all mask references (mask list and single-class mask) so they can be freed."""
        self.masks = None
        self.mask_class_ids = None
        self.current_mask = None
        self.current_class_id = None

    def _move_data(self, move):
        """Apply ``move`` to the image and to alpha, inverse depth and mask list when set.

        Transfers are not gated on ``data_on_gpu``: that flag can be stale
        (load_masks clears it while the image stays on the GPU). ``.cpu()`` and
        ``.cuda()`` on a tensor already on the target device return it without
        copying.
        """
        self.original_image = move(self.original_image)
        if self.alpha_mask is not None:
            self.alpha_mask = move(self.alpha_mask)
        if self.invdepthmap is not None:
            self.invdepthmap = move(self.invdepthmap)
        if self.masks is not None:
            self.masks = [move(mask) for mask in self.masks]

    def to_gpu(self):
        """Move the image, alpha, inverse depth and mask list to the current CUDA device."""
        self._move_data(lambda t: t.cuda())
        self.data_on_gpu = True

    def to_cpu(self):
        """Move the image, alpha, inverse depth and mask list to the CPU to free GPU memory."""
        self._move_data(lambda t: t.cpu())
        self.data_on_gpu = False
        
class MiniCam:
    """Lightweight camera holding just the values needed to describe a view.

    ``camera_center`` is the first three entries of row 3 of the inverse of
    ``world_view_transform``.
    """

    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        # Output image size.
        self.image_width, self.image_height = width, height
        # Fields of view.
        self.FoVx, self.FoVy = fovx, fovy
        # Near / far clipping distances.
        self.znear, self.zfar = znear, zfar
        # Caller-supplied transforms, stored as given.
        self.world_view_transform = world_view_transform
        self.full_proj_transform = full_proj_transform

        inverse_view = torch.inverse(world_view_transform)
        self.camera_center = inverse_view[3, :3]

