#消融实验版本 - 完全删除中立点处理逻辑，只进行前景/背景二分类
import numpy as np
import gc
import json
import os
from tqdm import tqdm
from gaussian_renderer import render,flexirender
from scene import Scene
import os
import cv2 
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussian_renderer import GaussianModel
from PIL import Image
import torch
import torch.nn.functional as F



def save_mask(mask, save_path):
    """
    保存二值掩码图像
    
    Args:
        mask: 二值掩码张量
        save_path: 保存路径
    """
    # 确保掩码是二值的（0或1）
    mask = (mask > 0.5).float()
    # 转换为PIL图像并保存
    mask_np = mask.cpu().numpy()
    mask_np = np.squeeze(mask_np)  # 新增，去掉多余的维度
    mask_img = Image.fromarray((mask_np * 255).astype(np.uint8))
    mask_img.save(save_path)

def save_image_with_white_background(image_tensor, save_path):
    """
    保存图像，确保背景为白色
    
    Args:
        image_tensor: 图像张量 [C, H, W]
        save_path: 保存路径
    """
    # 确保图像值在[0,1]范围内
    image_tensor = torch.clamp(image_tensor, 0, 1)
    
    # 转换为numpy数组
    image_np = image_tensor.cpu().numpy()
    image_np = np.transpose(image_np, (1, 2, 0))  # [H, W, C]
    
    # 确保是RGB格式（3通道）
    if image_np.shape[2] == 4:  # RGBA
        # 分离RGB和Alpha通道
        rgb = image_np[:, :, :3]
        alpha = image_np[:, :, 3:4]
        
        # 合成到白色背景上
        result = rgb * alpha + (1 - alpha) * np.array([1, 1, 1])
        result = np.clip(result, 0, 1)
        image_np = result
    
    # 转换为PIL图像并保存
    image_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    image_pil.save(save_path)

def render_set(dataset_path, model_path, views, gaussians, pipeline, background, is_test=False, start_class=0):
    print("[Stage 1] Start flexigaussian matching (ablation version - no neutral points) ...")
    
    #1.1 读取info.json，获取类别总数
    all_class_ids = set()
    masks_dir = os.path.join(dataset_path, "masks")
    print(f"掩码目录: {masks_dir}")

    # 构建info.json的路径
    info_json_path = os.path.join(masks_dir, "info.json")
    print(f"尝试从 {info_json_path} 读取类别总数...")

    try:
        if not os.path.exists(info_json_path):
            raise FileNotFoundError(f"错误: {info_json_path} 文件未找到。")
        
        with open(info_json_path, "r") as f:
            mask_info = json.load(f)
        
        if "total_categories" not in mask_info:
            raise KeyError("错误: 'total_categories' 键在 info.json 文件中未找到。")
            
        total_categories = mask_info["total_categories"]
        
        if not isinstance(total_categories, int) or total_categories <= 0:
            raise ValueError(f"错误: 'total_categories' 的值 ({total_categories}) 无效。它必须是一个正整数。")

        # 根据 total_categories 初始化 all_class_ids (0 到 total_categories-1)
        all_class_ids = set(range(total_categories))
        print(f"成功从 info.json 读取类别总数: {total_categories}。类别ID已初始化为 0 到 {total_categories - 1}。")
        print(f"all_class_ids: {sorted(list(all_class_ids))[:10]}... (前10个)") # 打印一部分验证

    except FileNotFoundError as e:
        print(e)
        print("请确保 info.json 文件存在于掩码目录中，或者回退到扫描文件名以获取类别。")
        return 
    except json.JSONDecodeError as e:
        print(f"错误: 解析 {info_json_path} 文件失败: {e}")
        print("请确保 info.json 是一个有效的JSON文件。")
        return
    except (KeyError, ValueError) as e:
        print(e)
        print(f"请检查 {info_json_path} 文件的内容和格式。")
        return
    except Exception as e:
        print(f"读取或处理 {info_json_path} 时发生未知错误: {e}")
        return

    if not all_class_ids:
        print("错误: 未能成功初始化类别ID。程序将终止。")
        return

    #1.2 匹配

    # 准备相机视图
    sorted_views = sorted(views, key=lambda Camera: Camera.image_name)
    print(f"已按图像名称排序的相机数量: {len(sorted_views)}")
    
    # 确保内存清理
    torch.cuda.empty_cache()
    
    # 创建临时目录用于存储中间结果
    stats_counts_path = os.path.join(model_path, "flexirun_ablation_no_neutral_result")
    os.makedirs(stats_counts_path, exist_ok=True)

    #开始匹配
    for class_id in all_class_ids:
        if class_id < start_class:
            continue
        print(f"处理类别 {class_id}")
        all_counts = None
        n_add=None
        n_sub=None
        cur_label_dir = os.path.join(stats_counts_path, "class_id_{:03d}_total_categories_{:03d}_label.pth".format(class_id, total_categories))

        if is_test: 
            # 测试集：直接使用训练好的标签
            print(f"开始处理类别 {class_id} 的测试集")
            if os.path.exists(cur_label_dir):
                print(f"测试集：加载训练好的标签 {cur_label_dir}")
                unique_label = torch.load(cur_label_dir).cuda()
                
                # 渲染结果
                class_dir = os.path.join(stats_counts_path, f"class_id_{class_id:03d}")
                for obj_idx in range(2):
                    obj_used_mask = (unique_label == obj_idx)
                    merged_render_path = os.path.join(class_dir, f"merged_render{obj_idx}")
                    merged_gt_path = os.path.join(class_dir, f"merged_gt{obj_idx}")
                    merged_mask_path = os.path.join(class_dir, f"merged_mask{obj_idx}")
                    os.makedirs(merged_render_path, exist_ok=True)
                    os.makedirs(merged_gt_path, exist_ok=True)
                    os.makedirs(merged_mask_path, exist_ok=True)

                    # for idx, view in enumerate(tqdm(views, desc="Rendering merged object {:03d}".format(obj_idx))):
                    #     render_pkg = flexirender(view, gaussians, pipeline, background, used_mask=obj_used_mask)
                    #     rendering = render_pkg["render"]
                    #     render_alpha = render_pkg["alpha"]
                    #     gt = view.original_image[0:3, :, :]
                        # 保存渲染结果
                        # render_path = os.path.join(merged_render_path, f"{idx:05d}.png")
                        # save_image_with_white_background(rendering, render_path)
                        # save_image_with_white_background(gt, os.path.join(merged_gt_path, f"{idx:05d}.png"))
                        
                        # 提取并保存掩码
                        # render_mask = (render_alpha > 0.5)
                        # save_mask(render_mask, os.path.join(merged_mask_path, f"{idx:05d}.png"))
                
                print(f"类别 {class_id} 渲染完成")
                continue  # 直接处理下一个类别
            else:
                print(f"警告：测试集缺少训练好的标签文件 {cur_label_dir}")
                continue
        else:
            # 训练集：计算 all_counts 并进行点群合并
            print(f"开始处理类别 {class_id} 的训练集")
            #针对其中一个class的每一个view
            for idx, view in enumerate(sorted_views):
                print(f"处理相机 {idx+1} / {len(sorted_views)}")

                # get mask
                image_name = view.image_name
                mask_path = os.path.join(masks_dir,f"{class_id}_{image_name}.png")

                if not os.path.exists(mask_path):
                    print(f"错误: 掩码文件 {mask_path} 不存在。")
                    continue
                    
                # 从PNG文件读取掩码
                try:
                    # 使用PIL读取掩码图像
                    mask_img = Image.open(mask_path)
                    mask_np = np.array(mask_img)
                    
                    # 如果掩码是RGB或RGBA格式，转为灰度图
                    if mask_np.ndim == 3 and mask_np.shape[2] > 1:
                        mask_np = mask_np[..., 0]  # 只取第一个通道
                    
                    # 转为PyTorch张量，并归一化到[0-1]
                    gt_mask = torch.from_numpy(mask_np).to(torch.float32) / 255.0
                    
                    # 确保掩码是二值的（可选）
                    gt_mask = (gt_mask > 0.5).float()
                    
                    # 将掩码移到GPU（如果需要）
                    if torch.cuda.is_available():
                        gt_mask = gt_mask.cuda()
                    
                    print(f"成功加载掩码: {mask_path}, 形状: {gt_mask.shape}")
                    
                    # 检查掩码是否全为0或全为1
                    if gt_mask.sum() == 0:
                        continue
                    elif gt_mask.sum() == gt_mask.numel():
                        continue
                        
                except Exception as e:
                    print(f"读取掩码文件出错: {mask_path}, 错误: {e}")
                    continue
                
                render_pkg = flexirender(view, gaussians, pipeline, background, gt_mask=gt_mask, obj_num=1) 
                rendering = render_pkg["render"]
                used_count = render_pkg["used_count"]
                if all_counts is None:
                    all_counts = torch.zeros_like(used_count)
                    # 指定设备为CUDA
                    n_add = torch.zeros(used_count.size(1), device=used_count.device)
                    n_sub = torch.zeros(used_count.size(1), device=used_count.device)
                gt = view.original_image[0:3, :, :]

                # 不再保存训练集的中间渲染结果和掩码，只保留最终分类结果
                
                all_counts += used_count
                n_add += (used_count[1] != 0).float()  #属于前景
                n_sub += (used_count[0] != 0).float()  #属于背景 

        if all_counts is not None:
            print(f"类别 {class_id}匹配完毕")
            
            # 消融实验：完全删除中立点处理逻辑，直接进行二分类
            print("消融实验：跳过中立点检测，直接进行前景/背景二分类...")
            
            # 找出前景和背景计数都为0的点（这些点没有在mask视野内）
            both_zero_mask = (all_counts[0, :] == 0) & (all_counts[1, :] == 0)
            both_zero_count = both_zero_mask.sum().item()
            print(f"前景和背景计数都为0的点数量: {both_zero_count}")
            
            # 对这些点直接设置为背景点，给一个较大的负值
            all_counts[0, both_zero_mask] = 1000.0  # 给一个很大的背景值
            all_counts[1, both_zero_mask] = 0.0     # 前景值保持为0
            
            # 对非零计数点进行归一化
            non_zero_mask = ~both_zero_mask
            if non_zero_mask.sum() > 0:
                all_counts[:, non_zero_mask] = torch.nn.functional.normalize(all_counts[:, non_zero_mask], dim=0)
            
            unique_label = all_counts.max(dim=0)[1]
            
            # 调试：统计分类结果
            foreground_count = (unique_label == 1).sum().item()
            background_count = (unique_label == 0).sum().item()
            total_count = unique_label.numel()
            foreground_ratio = foreground_count / total_count * 100.0
            background_ratio = background_count / total_count * 100.0
            print(f"类别 {class_id} 分类结果（消融实验版本 - 无中立点）:")
            print(f"  前景点: {foreground_count} ({foreground_ratio:.2f}%)")
            print(f"  背景点: {background_count} ({background_ratio:.2f}%)")
            print(f"  总点数: {total_count}")
            
            # 验证处理结果
            zero_points_as_background = (unique_label == 0) & both_zero_mask
            zero_points_as_foreground = (unique_label == 1) & both_zero_mask
            print(f"原本计数为0的点中，被分类为背景的: {zero_points_as_background.sum().item()}, 被分类为前景的: {zero_points_as_foreground.sum().item()}")
            
            # 保存消融实验信息
            ablation_info = {
                'foreground_count': foreground_count,
                'background_count': background_count,
                'total_count': total_count,
                'foreground_ratio': foreground_ratio,
                'background_ratio': background_ratio,
                'both_zero_count': both_zero_count,
                'ablation_type': 'no_neutral_points'
            }
            ablation_save_path = os.path.join(stats_counts_path, f"class_id_{class_id:03d}_ablation_info.pth")
            torch.save(ablation_info, ablation_save_path)
            print(f"消融实验信息已保存到: {ablation_save_path}")
            
            # 渲染结果（只渲染前景点和背景点）
            early_stop = False
            for obj_idx in range(2):  # 只渲染2个类别：0(背景), 1(前景)
                    
                obj_used_mask = (unique_label == obj_idx)
                # 为每个class_id创建单独的目录
                class_dir = os.path.join(stats_counts_path, f"class_id_{class_id:03d}")
                os.makedirs(class_dir, exist_ok=True)

                # 创建render子目录（不再保存gt真值图像）
                render_obj_path = os.path.join(class_dir, f"render{obj_idx}")

                # 确保目录存在（仅渲染目录）
                os.makedirs(render_obj_path, exist_ok=True)

                for idx, view in enumerate(tqdm(views, desc="Rendering object {:03d}".format(obj_idx))):
                    render_pkg = flexirender(view, gaussians, pipeline, background, used_mask=obj_used_mask)
                    rendering = render_pkg["render"]
                    # 保存渲染图像（不保存gt真值图像）
                    save_image_with_white_background(rendering, os.path.join(render_obj_path, f"{idx:05d}.png"))

                if early_stop: 
                    break

            print(f"类别 {class_id}保存完毕")
            if not is_test:
                print(f"保存训练好的标签到 {cur_label_dir}")
                torch.save(unique_label, cur_label_dir)

def render_sets(source_path, dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, 
                start_class=0):
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        # 强制使用白色背景用于可视化
        bg_color = [1, 1, 1]  # 白色背景
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
            print("处理训练集...")
            render_set(source_path, dataset.model_path, scene.getTrainCameras(), gaussians, pipeline, background, is_test=False, start_class=start_class)

        if not skip_test:
            print("处理测试集...")
            render_set(source_path, dataset.model_path, scene.getTestCameras(), gaussians, pipeline, background, is_test=True, start_class=start_class)


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--start_class", default=0, type=int, help="从第几个类别开始匹配")

    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_sets(args.source_path, model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, 
                args.start_class)
