import numpy as np
import gc
import json
import os
from tqdm import tqdm
from gaussian_renderer import render,flexirender
from scene import Scene
import os
import cv2 
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
from gaussian_renderer import GaussianModel
from PIL import Image
import torch
import torch.nn.functional as F



def save_mask(mask, save_path):
    """Threshold ``mask`` at 0.5 and write it to ``save_path`` as an 8-bit image.

    Values strictly greater than 0.5 become 255, everything else becomes 0.
    Singleton dimensions are squeezed away before the array is handed to PIL.
    """
    binary = (mask > 0.5).float().cpu().numpy()
    pixels = (np.squeeze(binary) * 255).astype(np.uint8)
    Image.fromarray(pixels).save(save_path)

def save_image_with_white_background(image_tensor, save_path):
    """Save a channel-first image tensor as an 8-bit image file.

    Args:
        image_tensor: Tensor laid out as (channels, height, width). Values are
            clamped to [0, 1] before conversion. Four-channel input is treated
            as RGBA and alpha-composited over a white background; one-channel
            input is written as a grayscale image.
        save_path: Destination path handed to ``PIL.Image.Image.save``.
    """
    # detach() so tensors that still track gradients can be converted to NumPy.
    image_tensor = torch.clamp(image_tensor.detach(), 0, 1)

    # PIL expects channel-last arrays.
    image_np = np.transpose(image_tensor.cpu().numpy(), (1, 2, 0))
    channels = image_np.shape[2]

    if channels == 4:
        rgb = image_np[:, :, :3]
        alpha = image_np[:, :, 3:4]

        # Standard "over" compositing against a white (1, 1, 1) background.
        result = rgb * alpha + (1 - alpha) * np.array([1, 1, 1])
        image_np = np.clip(result, 0, 1)
    elif channels == 1:
        # Image.fromarray rejects (H, W, 1) arrays; drop the channel axis so
        # the image is written as 8-bit grayscale instead of raising.
        image_np = image_np[:, :, 0]

    image_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    image_pil.save(save_path)

def _read_total_categories(info_json_path):
    """Read and validate ``total_categories`` from the mask ``info.json``.

    Returns the value when it is a positive integer. On any failure (missing
    file, invalid JSON, missing key, invalid value, or any other error) a
    diagnostic is printed and ``None`` is returned.
    """
    print(f"Trying to read total categories from {info_json_path}...")

    try:
        if not os.path.exists(info_json_path):
            raise FileNotFoundError(f"Error: {info_json_path} file not found.")

        with open(info_json_path, "r") as f:
            mask_info = json.load(f)

        if "total_categories" not in mask_info:
            raise KeyError("Error: 'total_categories' key not found in info.json file.")

        total_categories = mask_info["total_categories"]

        if not isinstance(total_categories, int) or total_categories <= 0:
            raise ValueError(f"Error: 'total_categories' value ({total_categories}) is invalid. It must be a positive integer.")

        print(f"Successfully read total categories from info.json: {total_categories}. Class IDs initialized from 0 to {total_categories - 1}.")
        print(f"all_class_ids: {list(range(min(total_categories, 10)))}... (first 10)")
        return total_categories

    except FileNotFoundError as e:
        print(e)
        print("Please ensure info.json file exists in the mask directory, or fall back to scanning filenames to get categories.")
    except json.JSONDecodeError as e:
        print(f"Error: Failed to parse {info_json_path} file: {e}")
        print("Please ensure info.json is a valid JSON file.")
    except (KeyError, ValueError) as e:
        print(e)
        print(f"Please check the content and format of {info_json_path} file.")
    except Exception as e:
        # Best-effort boundary: any other failure is reported and the stage is skipped.
        print(f"Unknown error occurred while reading or processing {info_json_path}: {e}")
    return None


def _load_binary_mask(mask_path):
    """Load one mask image as a {0, 1} float tensor, or ``None`` if unusable.

    ``None`` is returned (after printing a message where the original code did)
    when the file does not exist, cannot be read, or the thresholded mask is
    entirely empty or entirely full. The tensor is moved to CUDA when available.
    """
    if not os.path.exists(mask_path):
        print(f"错误: 掩码文件 {mask_path} 不存在。")
        return None

    try:
        # Context manager so the underlying file handle is released promptly.
        with Image.open(mask_path) as mask_img:
            mask_np = np.array(mask_img)

        # Multi-channel masks: only the first channel is used.
        if mask_np.ndim == 3 and mask_np.shape[2] > 1:
            mask_np = mask_np[..., 0]

        gt_mask = torch.from_numpy(mask_np).to(torch.float32) / 255.0
        gt_mask = (gt_mask > 0.5).float()

        if torch.cuda.is_available():
            gt_mask = gt_mask.cuda()

        print(f"成功加载掩码: {mask_path}, 形状: {gt_mask.shape}")

        # A mask that is all background or all foreground is skipped.
        mask_sum = gt_mask.sum()
        if mask_sum == 0 or mask_sum == gt_mask.numel():
            return None
        return gt_mask

    except Exception as e:
        print(f"读取掩码文件出错: {mask_path}, 错误: {e}")
        return None


def render_set(dataset_path, model_path, views, gaussians, pipeline, background, is_test=False, start_class=0):
    """Assign a per-class label to every Gaussian and render the resulting split.

    For each class id in ``range(total_categories)`` (read from
    ``<dataset_path>/masks/info.json``) that is ``>= start_class``:

    * Training mode (``is_test=False``): every view whose mask
      ``<class_id>_<image_name>.png`` exists and is neither empty nor full is
      passed to ``flexirender``; the returned ``used_count`` rows are
      accumulated (row 0 is treated as background, row 1 as foreground).
      Points whose binary view-entropy exceeds ``tau_h`` are candidates for the
      "neutral" label (2); the remaining points get label 0 or 1 by arg-max.
      Renders of label 0 and label 1 are written under
      ``<model_path>/flexirun_result/class_id_XXX/render{0,1}`` and the label
      tensor is saved next to them.
    * Test mode (``is_test=True``): the saved label file is loaded and the
      ``merged_*`` output directories are created.
      NOTE(review): nothing is rendered in this branch -- confirm whether
      test-set rendering is intentionally omitted.

    Args:
        dataset_path: Dataset root containing the ``masks`` directory.
        model_path: Output root; results go to ``<model_path>/flexirun_result``.
        views: Iterable of camera objects exposing ``image_name``.
        gaussians: Gaussian model forwarded to ``flexirender``.
        pipeline: Pipeline parameters forwarded to ``flexirender``.
        background: Background colour tensor forwarded to ``flexirender``.
        is_test: Select the test-mode branch described above.
        start_class: Class ids below this value are skipped.

    Returns:
        None. Returns early (after printing a diagnostic) when ``info.json``
        cannot be read or validated.
    """
    print("[Stage 1] Start flexigaussian matching (simplified version - no post-processing) ...")

    masks_dir = os.path.join(dataset_path, "masks")
    print(f"Mask directory: {masks_dir}")

    info_json_path = os.path.join(masks_dir, "info.json")
    total_categories = _read_total_categories(info_json_path)
    if total_categories is None:
        return

    sorted_views = sorted(views, key=lambda camera: camera.image_name)
    print(f"Number of cameras sorted by image name: {len(sorted_views)}")

    torch.cuda.empty_cache()

    stats_counts_path = os.path.join(model_path, "flexirun_result")
    os.makedirs(stats_counts_path, exist_ok=True)

    for class_id in range(total_categories):
        if class_id < start_class:
            continue
        print(f"处理类别 {class_id}")
        cur_label_dir = os.path.join(
            stats_counts_path,
            f"class_id_{class_id:03d}_total_categories_{total_categories:03d}_label.pth",
        )
        class_dir = os.path.join(stats_counts_path, f"class_id_{class_id:03d}")

        if is_test:
            print(f"开始处理类别 {class_id} 的测试集")
            if not os.path.exists(cur_label_dir):
                print(f"警告：测试集缺少训练好的标签文件 {cur_label_dir}")
                continue

            print(f"测试集：加载训练好的标签 {cur_label_dir}")
            # The labels are loaded (which also validates the file) but are not
            # consumed below; only the output directories are prepared.
            torch.load(cur_label_dir).cuda()

            for obj_idx in range(2):
                for prefix in ("merged_render", "merged_gt", "merged_mask"):
                    os.makedirs(os.path.join(class_dir, f"{prefix}{obj_idx}"), exist_ok=True)

            print(f"类别 {class_id} 渲染完成")
            continue

        # ------------------------------------------------------------------
        # Training mode: accumulate per-Gaussian usage counts over all views.
        # ------------------------------------------------------------------
        print(f"开始处理类别 {class_id} 的训练集")
        all_counts = None
        n_add = None
        n_sub = None
        for idx, view in enumerate(sorted_views):
            print(f"处理相机 {idx+1} / {len(sorted_views)}")

            mask_path = os.path.join(masks_dir, f"{class_id}_{view.image_name}.png")
            gt_mask = _load_binary_mask(mask_path)
            if gt_mask is None:
                continue

            render_pkg = flexirender(view, gaussians, pipeline, background, gt_mask=gt_mask, obj_num=1)
            used_count = render_pkg["used_count"]
            if all_counts is None:
                all_counts = torch.zeros_like(used_count)
                n_add = torch.zeros(used_count.size(1), device=used_count.device)
                n_sub = torch.zeros(used_count.size(1), device=used_count.device)

            all_counts += used_count
            # Number of views in which each point contributed to row 1 / row 0.
            n_add += (used_count[1] != 0).float()
            n_sub += (used_count[0] != 0).float()

        if all_counts is None:
            # No usable view for this class: nothing to label, render, or save.
            continue

        print(f"类别 {class_id}匹配完毕")

        print("开始中立点检测...")

        # Binary entropy of the per-point foreground/background view frequencies.
        total_views = n_add + n_sub

        epsilon = 1e-8
        total_views_safe = total_views + epsilon

        p_foreground = n_add / total_views_safe
        p_background = n_sub / total_views_safe

        semantic_entropy = torch.zeros_like(total_views)

        foreground_mask = p_foreground > epsilon
        semantic_entropy[foreground_mask] -= p_foreground[foreground_mask] * torch.log2(p_foreground[foreground_mask] + epsilon)

        background_mask = p_background > epsilon
        semantic_entropy[background_mask] -= p_background[background_mask] * torch.log2(p_background[background_mask] + epsilon)

        print(f"语义熵统计信息:")
        print(f"  总点数: {semantic_entropy.numel()}")
        print(f"  最小值: {semantic_entropy.min().item():.6f}")
        print(f"  最大值: {semantic_entropy.max().item():.6f}")
        print(f"  平均值: {semantic_entropy.mean().item():.6f}")
        print(f"  中位数: {semantic_entropy.median().item():.6f}")

        sorted_entropy, _ = torch.sort(semantic_entropy, descending=True)
        for percent_label, fraction in (("0.5", 0.005), ("1", 0.01), ("5", 0.05)):
            cutoff_value = sorted_entropy[int(fraction * len(sorted_entropy))].item()
            print(f"  前{percent_label}%的语义熵值: {cutoff_value:.6f}")

        thresholds = [0.5, 0.7, 0.8, 0.9, 0.95, 0.98, 0.99]
        print(f"不同阈值下的候选点数量:")
        for thresh in thresholds:
            count = (semantic_entropy > thresh).sum().item()
            percentage = count / semantic_entropy.numel() * 100
            print(f"  τ_h = {thresh}: {count} 点 ({percentage:.3f}%)")

        tau_h = 0.99
        candidate_mask = semantic_entropy > tau_h
        candidate_count = candidate_mask.sum().item()
        print(f"语义冲突候选点数量: {candidate_count}")

        # Defaults for the branch without opacity information, so the
        # verification prints further down never hit undefined names.
        tau_alpha = None
        solid_points_mask = torch.zeros_like(candidate_mask)
        solid_points_count = 0

        if hasattr(gaussians, '_opacity'):
            # NOTE(review): `_opacity` is read directly. If GaussianModel stores
            # opacity pre-activation, this threshold is applied to the raw
            # parameter rather than to alpha -- confirm whether the activated
            # opacity was intended.
            alpha_values = gaussians._opacity.squeeze()
            tau_alpha = 0.01

            solid_points_mask = (alpha_values > tau_alpha) & candidate_mask
            solid_points_count = solid_points_mask.sum().item()
            print(f"保持原分类的实体点数量: {solid_points_count}")

            neutral_points_mask = (alpha_values <= tau_alpha) & candidate_mask
            neutral_points_count = neutral_points_mask.sum().item()
            print(f"确认的中立点数量: {neutral_points_count}")

            print(f"实体点将保持原来的分类，不进行重新分配")

            if neutral_points_count > 0:
                # Sentinel value; these points are relabelled to 2 below.
                all_counts[0, neutral_points_mask] = -1.0
                all_counts[1, neutral_points_mask] = -1.0
                print(f"已将 {neutral_points_count} 个中立点从语义监督中屏蔽")
        else:
            print("警告：无法获取高斯点的不透明度信息，跳过几何过滤")
            neutral_points_mask = candidate_mask
            neutral_points_count = neutral_points_mask.sum().item()
            print(f"仅基于语义熵的中立点数量: {neutral_points_count}")

        neutral_info = {
            'neutral_points_mask': neutral_points_mask.cpu(),
            'neutral_points_count': neutral_points_count,
            'semantic_entropy': semantic_entropy.cpu(),
            'candidate_count': candidate_count,
            'tau_h': tau_h,
            'tau_alpha': tau_alpha,
        }
        neutral_save_path = os.path.join(stats_counts_path, f"class_id_{class_id:03d}_neutral_points.pth")
        torch.save(neutral_info, neutral_save_path)
        print(f"中立点信息已保存到: {neutral_save_path}")

        both_zero_mask = (all_counts[0, :] == 0) & (all_counts[1, :] == 0)
        both_zero_count = both_zero_mask.sum().item()
        print(f"前景和背景计数都为0的点数量: {both_zero_count}")

        # Points never counted in any view are forced to background (label 0).
        # Both rows must be indexed with the boolean mask: indexing with the
        # integer count would overwrite one unrelated column instead.
        all_counts[0, both_zero_mask] = 1000.0
        all_counts[1, both_zero_mask] = 0.0

        non_zero_mask = ~both_zero_mask & ~neutral_points_mask
        if non_zero_mask.sum() > 0:
            all_counts[:, non_zero_mask] = torch.nn.functional.normalize(all_counts[:, non_zero_mask], dim=0)

        # Arg-max over the two rows: 0 = background, 1 = foreground.
        unique_label = all_counts.max(dim=0)[1]

        if neutral_points_count > 0:
            unique_label[neutral_points_mask] = 2

        foreground_count = (unique_label == 1).sum().item()
        background_count = (unique_label == 0).sum().item()
        neutral_count = (unique_label == 2).sum().item()
        total_count = unique_label.numel()
        foreground_ratio = foreground_count / total_count * 100.0
        background_ratio = background_count / total_count * 100.0
        neutral_ratio = neutral_count / total_count * 100.0
        print(f"类别 {class_id} 分类结果:")
        print(f"  前景点: {foreground_count} ({foreground_ratio:.2f}%)")
        print(f"  背景点: {background_count} ({background_ratio:.2f}%)")
        print(f"  中立点: {neutral_count} ({neutral_ratio:.2f}%)")
        print(f"  总点数: {total_count}")

        if neutral_points_count > 0:
            actual_neutral_count = (unique_label == 2).sum().item()
            print(f"中立点验证: 期望 {neutral_points_count} 个，实际标记为中立的有 {actual_neutral_count} 个")

            if solid_points_count > 0:
                solid_points_foreground = ((unique_label == 1) & solid_points_mask).sum().item()
                solid_points_background = ((unique_label == 0) & solid_points_mask).sum().item()
                print(f"实体点分类验证: 前景 {solid_points_foreground} 个，背景 {solid_points_background} 个，总计 {solid_points_count} 个")

        zero_points_as_background = (unique_label == 0) & both_zero_mask
        zero_points_as_foreground = (unique_label == 1) & both_zero_mask
        print(f"原本计数为0的点中，被分类为背景的: {zero_points_as_background.sum().item()}, 被分类为前景的: {zero_points_as_foreground.sum().item()}")

        # Render the background (0) and foreground (1) subsets for every view.
        # Note: this iterates `views` in the caller's order, not `sorted_views`.
        os.makedirs(class_dir, exist_ok=True)
        for obj_idx in range(2):
            obj_used_mask = (unique_label == obj_idx)
            render_obj_path = os.path.join(class_dir, f"render{obj_idx}")
            os.makedirs(render_obj_path, exist_ok=True)

            for idx, view in enumerate(tqdm(views, desc="Rendering object {:03d}".format(obj_idx))):
                render_pkg = flexirender(view, gaussians, pipeline, background, used_mask=obj_used_mask)
                save_image_with_white_background(render_pkg["render"], os.path.join(render_obj_path, f"{idx:05d}.png"))

        print(f"类别 {class_id}保存完毕")
        # Only reachable in training mode (the test branch always `continue`s).
        print(f"保存训练好的标签到 {cur_label_dir}")
        torch.save(unique_label, cur_label_dir)

def render_sets(source_path, dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool,
                start_class=0):
    """Load the scene once and run ``render_set`` on the requested splits.

    The training split is processed first (unless ``skip_train``), then the
    test split (unless ``skip_test``). Everything runs under
    ``torch.no_grad()`` with a white background colour on the CUDA device.
    """
    # (skip flag, banner, name of the Scene camera getter, is_test)
    splits = (
        (skip_train, "处理训练集...", "getTrainCameras", False),
        (skip_test, "处理测试集...", "getTestCameras", True),
    )

    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        background = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")

        for skipped, banner, getter_name, test_split in splits:
            if skipped:
                continue
            print(banner)
            # The getter is resolved only for splits that are actually processed.
            cameras = getattr(scene, getter_name)()
            render_set(
                source_path,
                dataset.model_path,
                cameras,
                gaussians,
                pipeline,
                background,
                is_test=test_split,
                start_class=start_class,
            )


if __name__ == "__main__":
    # Command-line interface: model/pipeline parameter groups plus script flags.
    parser = ArgumentParser(description="Testing script parameters")
    model_params = ModelParams(parser, sentinel=True)
    pipeline_params = PipelineParams(parser)
    parser.add_argument("--iteration", type=int, default=-1)
    for switch in ("--skip_train", "--skip_test", "--quiet"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--start_class", type=int, default=0, help="从第几个类别开始匹配")

    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    safe_state(args.quiet)

    render_sets(
        args.source_path,
        model_params.extract(args),
        args.iteration,
        pipeline_params.extract(args),
        args.skip_train,
        args.skip_test,
        args.start_class,
    )



            


        

