import torch
import sys
from datetime import datetime
import numpy as np
import random
import os
from PIL import Image
from functools import lru_cache
import time

def inverse_sigmoid(x):
    """Return the logit of ``x`` (the inverse of ``torch.sigmoid``).

    Computed as ``log(x / (1 - x))``; for float tensors an input of exactly 0
    yields ``-inf`` and exactly 1 yields ``inf``.
    """
    odds = x / (1 - x)
    return torch.log(odds)

def PILtoTorch(pil_image, resolution):
    """Resize a PIL image and return it as a channel-first float tensor in [0, 1].

    Args:
        pil_image: Image object; only its ``resize`` method is used.
        resolution: Size forwarded to ``pil_image.resize`` (PIL uses (width, height)).

    Returns:
        Tensor of shape (C, H, W) for images that come back 3-D from NumPy,
        or (1, H, W) for single-channel (2-D) images; values are divided by 255.
    """
    pixels = np.array(pil_image.resize(resolution))
    scaled = torch.from_numpy(pixels) / 255.0
    # Grayscale images come back as (H, W): add a trailing channel axis first.
    if scaled.ndim != 3:
        scaled = scaled.unsqueeze(dim=-1)
    return scaled.permute(2, 0, 1)

def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """Build a step -> learning-rate schedule (copied from Plenoxels).

    The returned function interpolates log-linearly from ``lr_init`` at step 0
    to ``lr_final`` at ``max_steps`` (clamped beyond that). When
    ``lr_delay_steps > 0`` the rate is additionally multiplied by a factor that
    rises from ``lr_delay_mult`` to 1 along ``sin(0.5 * pi * step / lr_delay_steps)``.
    Negative steps, or ``lr_init == lr_final == 0``, give a rate of 0.0.
    """

    def lr_at(step):
        # Negative step or both rates zero disables this parameter group.
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            return 0.0

        if lr_delay_steps <= 0:
            warmup = 1.0
        else:
            ramp = np.clip(step / lr_delay_steps, 0, 1)
            warmup = lr_delay_mult + (1 - lr_delay_mult) * np.sin(0.5 * np.pi * ramp)

        progress = np.clip(step / max_steps, 0, 1)
        interpolated = np.exp(np.log(lr_init) * (1 - progress) + np.log(lr_final) * progress)
        return warmup * interpolated

    return lr_at

def strip_lowerdiag(L):
    """Pack the six unique entries of each 3x3 matrix in a batch into an (N, 6) tensor.

    Despite the name, the entries read are the upper triangle including the
    diagonal, in the order (0,0), (0,1), (0,2), (1,1), (1,2), (2,2). For the
    symmetric matrices this is used on, that equals the lower triangle.

    Args:
        L: Tensor indexed as ``L[n, i, j]``; assumed (N, 3, 3).

    Returns:
        float32 tensor of shape (N, 6) on the same device as ``L``.
    """
    # Follow the input's device instead of hard-coding "cuda" so the helper
    # also works on CPU and never copies across devices.
    rows = [0, 0, 0, 1, 1, 2]
    cols = [0, 1, 2, 1, 2, 2]
    # Advanced indexing gathers all six entries per matrix in one op (new tensor).
    return L[:, rows, cols].to(dtype=torch.float)

def strip_symmetric(sym):
    """Pack a batch of symmetric 3x3 matrices into their six unique entries.

    Thin alias: delegates entirely to ``strip_lowerdiag`` and returns its result.
    """
    packed = strip_lowerdiag(sym)
    return packed

def build_rotation(r):
    """Convert a batch of quaternions into 3x3 rotation matrices.

    Args:
        r: Tensor of shape (N, 4) holding quaternions as (w, x, y, z), with the
            scalar part first. Quaternions need not be unit length; they are
            normalized here. An all-zero quaternion produces NaNs (0 / 0).

    Returns:
        Tensor of shape (N, 3, 3), in the default float dtype, on ``r``'s device.
    """
    norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
    q = r / norm[:, None]

    # Allocate on the input's device rather than a hard-coded "cuda".
    R = torch.zeros((q.size(0), 3, 3), device=r.device)

    # Use `w` for the scalar part instead of rebinding the parameter name `r`.
    w = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y * y + z * z)
    R[:, 0, 1] = 2 * (x * y - w * z)
    R[:, 0, 2] = 2 * (x * z + w * y)
    R[:, 1, 0] = 2 * (x * y + w * z)
    R[:, 1, 1] = 1 - 2 * (x * x + z * z)
    R[:, 1, 2] = 2 * (y * z - w * x)
    R[:, 2, 0] = 2 * (x * z - w * y)
    R[:, 2, 1] = 2 * (y * z + w * x)
    R[:, 2, 2] = 1 - 2 * (x * x + y * y)
    return R

def build_scaling_rotation(s, r):
    """Return ``R @ diag(s)`` for each element of a batch.

    Args:
        s: Tensor whose first three columns ``s[:, 0:3]`` are the per-axis
            scales; assumed (N, 3).
        r: Quaternions forwarded to ``build_rotation`` to obtain ``R``.

    Returns:
        Tensor of shape (N, 3, 3): each rotation matrix with its columns scaled
        by the corresponding entries of ``s``.
    """
    # Allocate on the scales' device rather than a hard-coded "cuda".
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
    R = build_rotation(r)

    L[:, 0, 0] = s[:, 0]
    L[:, 1, 1] = s[:, 1]
    L[:, 2, 2] = s[:, 2]

    L = R @ L
    return L

def safe_state(silent):
    """Install a timestamping stdout wrapper and seed every RNG with 0.

    Args:
        silent: When truthy, everything written to stdout is discarded.

    Side effects: replaces ``sys.stdout``; seeds ``random``, ``numpy`` and
    ``torch`` with 0; makes ``cuda:0`` the current CUDA device.
    """
    original_stdout = sys.stdout

    class _TimestampedStdout:
        def __init__(self, silent):
            self.silent = silent

        def write(self, text):
            if self.silent:
                return
            if not text.endswith("\n"):
                original_stdout.write(text)
                return
            # Every newline in a completed line gets a " [dd/mm HH:MM:SS]" suffix.
            stamp = datetime.now().strftime("%d/%m %H:%M:%S")
            original_stdout.write(text.replace("\n", " [{}]\n".format(stamp)))

        def flush(self):
            original_stdout.flush()

    sys.stdout = _TimestampedStdout(silent)

    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(torch.device("cuda:0"))


# Modified helper: load a single mask from an image file (e.g. PNG).
def load_mask(mask_path, resolution):
    """Load one mask image and convert it to a channel-first float tensor.

    Args:
        mask_path: Path to the mask image file.
        resolution: Target size forwarded to PIL ``resize`` (width, height).

    Returns:
        The tensor produced by ``PILtoTorch`` (values scaled to [0, 1]), or
        None if ``mask_path`` does not exist.
    """
    # `os` and `Image` are already imported at module level; no local re-import needed.
    if not os.path.exists(mask_path):
        return None

    # Context manager closes the underlying file handle deterministically.
    with Image.open(mask_path) as mask_image:
        return PILtoTorch(mask_image, resolution)

# Modified helper: load all per-class masks for one image from image files (PNG/JPG).
def load_mask_files(image_name, masks_dir, resolution):
    """Load every per-class binary mask that belongs to ``image_name``.

    A file in ``masks_dir`` is used when its name ends with
    ``_<base_name>.png`` or ``_<base_name>.jpg``. Its class id is the integer
    before the first ``_`` (e.g. ``3_frame_0001.png`` -> class 3).

    Args:
        image_name: Image path or file name; only its base name (without
            extension) is used for matching.
        masks_dir: Directory containing the mask files.
        resolution: Target size forwarded to PIL ``resize`` (width, height).

    Returns:
        ``(mask_tensors, mask_class_ids)``: float tensors thresholded to 0/1 at
        0.5, and their class ids, in sorted file-name order. Files that fail to
        parse or load are reported with ``print`` and skipped.
    """
    mask_tensors = []
    mask_class_ids = []

    base_name = os.path.splitext(os.path.basename(image_name))[0]

    if not os.path.exists(masks_dir):
        print(f"Error: Masks directory {masks_dir} does not exist.")
        return mask_tensors, mask_class_ids

    suffixes = (f"_{base_name}.png", f"_{base_name}.jpg")
    # os.listdir order is arbitrary; sort so the output order is deterministic.
    for file_name in sorted(os.listdir(masks_dir)):
        # NOTE(review): suffix matching also accepts masks of other images whose
        # base name ends in "_<base_name>" (e.g. "3_cam_0001.png" for "0001").
        if not file_name.endswith(suffixes):
            continue

        try:
            # Class id is the leading integer of the file name.
            class_id = int(file_name.split('_')[0])

            mask_path = os.path.join(masks_dir, file_name)
            with Image.open(mask_path) as mask_image:
                mask_tensor = PILtoTorch(mask_image, resolution)

            # Make sure the mask is binary.
            mask_tensor = (mask_tensor > 0.5).float()

            mask_tensors.append(mask_tensor)
            mask_class_ids.append(class_id)
        except Exception as e:
            print(f"Error loading mask {file_name}: {e}")

    return mask_tensors, mask_class_ids



@lru_cache(maxsize=128)
def cached_load_npy(file_path):
    """Load a ``.npy`` file through an LRU cache (up to 128 paths) to avoid repeated disk reads.

    The same array object is returned for repeated calls with the same path,
    so callers must not modify it in place. On failure a message is printed
    and None is returned (and cached for that path).
    """
    loaded = None
    try:
        loaded = np.load(file_path)
    except Exception as err:
        print(f"无法加载NPY文件 {file_path}: {err}")
    return loaded

def _find_frame_npy(masks_dir, frame_idx):
    """Return the first existing mask file for ``frame_idx`` (unpadded, then 4-digit padded name), else None."""
    for file_name in (f"mask_{frame_idx}.npy", f"mask_{frame_idx:04d}.npy"):
        npy_path = os.path.join(masks_dir, file_name)
        if os.path.exists(npy_path):
            return npy_path
    return None


def _mask_to_binary_tensor(mask, resolution):
    """Convert one 2-D class mask to a (1, H, W) 0/1 float CPU tensor at ``resolution``; None if it has no foreground."""
    # Cheap emptiness check before any resizing work.
    if mask.dtype == bool:
        if not np.any(mask):
            return None
    elif np.max(mask) == 0:
        return None

    # resolution is (width, height), as PIL expects.
    if mask.shape[0] != resolution[1] or mask.shape[1] != resolution[0]:
        if mask.dtype == bool:
            mask_uint8 = mask.astype(np.uint8) * 255
        else:
            # Clip to [0, 1] before scaling: integer masks stored as 0/255 would
            # otherwise wrap around when cast to uint8 and be silently dropped.
            mask_uint8 = (np.clip(mask, 0, 1) * 255.0).astype(np.uint8)
        mask_tensor = PILtoTorch(Image.fromarray(mask_uint8), resolution)
    else:
        # Already at the target size: convert directly.
        mask_tensor = torch.from_numpy(mask).float()
        if mask_tensor.dim() == 2:
            mask_tensor = mask_tensor.unsqueeze(0)

    # Make sure the mask is binary.
    mask_tensor = (mask_tensor > 0.5).float()
    if mask_tensor.sum().item() == 0:
        return None
    # Keep on CPU; moving to the GPU is left to the caller.
    return mask_tensor.cpu()


def load_mask_files_from_npy(image_name, masks_dir, resolution):
    """Load per-class binary masks for one frame from a ``mask_<idx>.npy`` file.

    The frame index is the last ``_``-separated token of the image's base name
    (e.g. ``frame_12.png`` -> 12). It is looked up as ``mask_<idx>.npy`` and then
    ``mask_<idx:04d>.npy`` inside ``masks_dir``. The stored array must be
    (C, H, W) or (C, 1, H, W); slice ``c`` along the first axis is class ``c``.

    Args:
        image_name: Image path or file name; only its base name is used.
        masks_dir: Directory containing the ``.npy`` mask files.
        resolution: Target (width, height); masks of a different size are resized.

    Returns:
        ``(mask_tensors, mask_class_ids)``: (1, H, W) float CPU tensors with 0/1
        values, and the matching class indices. Classes without foreground are
        skipped. Failures are reported with ``print`` and never raised.
    """
    start_time = time.time()
    mask_tensors = []
    mask_class_ids = []

    base_name = os.path.splitext(os.path.basename(image_name))[0]

    if not os.path.exists(masks_dir):
        print(f"错误: 掩码目录 {masks_dir} 不存在。")
        return mask_tensors, mask_class_ids

    try:
        frame_idx = int(base_name.split('_')[-1])

        npy_path = _find_frame_npy(masks_dir, frame_idx)
        if npy_path is None:
            return mask_tensors, mask_class_ids
        masks_data = cached_load_npy(npy_path)
        if masks_data is None:
            return mask_tensors, mask_class_ids

        # Accept (C, 1, H, W) by dropping the singleton axis; require (C, H, W) otherwise.
        # squeeze returns a view, so the cached array is not modified.
        if masks_data.ndim == 4 and masks_data.shape[1] == 1:
            masks_data = masks_data.squeeze(1)
        elif masks_data.ndim != 3:
            print(f"错误 - 数据维度不是3维或4维，当前维度为 {masks_data.ndim}")
            return mask_tensors, mask_class_ids

        n_classes = masks_data.shape[0]
        for class_id in range(n_classes):
            try:
                mask_tensor = _mask_to_binary_tensor(masks_data[class_id], resolution)
            except Exception as e:
                # Best-effort per class: report and keep going instead of failing silently.
                print(f"警告: 类别 {class_id} 的掩码处理失败 ({base_name}): {e}")
                continue
            if mask_tensor is not None:
                mask_tensors.append(mask_tensor)
                mask_class_ids.append(class_id)

        elapsed = time.time() - start_time
        if mask_tensors:
            print(f"图像 {base_name}: 处理了 {n_classes} 个类别，找到 {len(mask_tensors)} 个有效掩码，用时 {elapsed:.2f}s")

    except Exception as e:
        print(f"加载NPY掩码时出错 {base_name}: {e}")

    return mask_tensors, mask_class_ids

    