# 按类别风格重绘渲染脚本 - 每个类别单独渲染一套视角图片
import numpy as np
import json
import os
from tqdm import tqdm
from gaussian_renderer import flexirender
from scene import Scene
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
from PIL import Image
import torch
import random
import colorsys

def generate_distinct_colors(n_colors, style="vibrant"):
    """生成更加吸引人的鲜艳颜色"""
    
    if style == "vibrant":
        # 鲜艳颜色方案
        color_scheme = [
            (1.0, 0.0, 0.0),    # 纯红色
            (0.0, 1.0, 0.0),    # 纯绿色
            (0.0, 0.0, 1.0),    # 纯蓝色
            (1.0, 1.0, 0.0),    # 纯黄色
            (1.0, 0.0, 1.0),    # 纯洋红
            (0.0, 1.0, 1.0),    # 纯青色
            (1.0, 0.5, 0.0),    # 橙色
            (0.5, 0.0, 1.0),    # 紫色
            (0.0, 0.5, 1.0),    # 天蓝色
            (1.0, 0.0, 0.5),    # 粉红色
        ]
    elif style == "neon":
        # 霓虹灯效果颜色
        color_scheme = [
            (1.0, 0.2, 0.2),    # 霓虹红
            (0.2, 1.0, 0.2),    # 霓虹绿
            (0.2, 0.2, 1.0),    # 霓虹蓝
            (1.0, 1.0, 0.2),    # 霓虹黄
            (1.0, 0.2, 1.0),    # 霓虹粉
            (0.2, 1.0, 1.0),    # 霓虹青
            (1.0, 0.6, 0.2),    # 霓虹橙
            (0.6, 0.2, 1.0),    # 霓虹紫
            (0.2, 0.6, 1.0),    # 霓虹天蓝
            (1.0, 0.2, 0.6),    # 霓虹玫红
        ]
    else:  # pastel
        # 柔和粉彩色
        color_scheme = [
            (1.0, 0.7, 0.7),    # 粉红
            (0.7, 1.0, 0.7),    # 粉绿
            (0.7, 0.7, 1.0),    # 粉蓝
            (1.0, 1.0, 0.7),    # 粉黄
            (1.0, 0.7, 1.0),    # 粉紫
            (0.7, 1.0, 1.0),    # 粉青
            (1.0, 0.8, 0.7),    # 粉橙
            (0.8, 0.7, 1.0),    # 淡紫
            (0.7, 0.8, 1.0),    # 淡蓝
            (1.0, 0.7, 0.8),    # 淡粉
        ]
    
    colors = []
    for i in range(n_colors):
        if i < len(color_scheme):
            # 使用预定义的颜色
            colors.append(color_scheme[i])
        else:
            # 如果类别数超过预定义颜色，使用HSV生成
            hue = i / n_colors
            if style == "vibrant":
                saturation = 0.9 + 0.1 * random.random()
                value = 0.8 + 0.2 * random.random()
            elif style == "neon":
                saturation = 0.8 + 0.2 * random.random()
                value = 0.9 + 0.1 * random.random()
            else:  # pastel
                saturation = 0.3 + 0.4 * random.random()
                value = 0.8 + 0.2 * random.random()
            rgb = colorsys.hsv_to_rgb(hue, saturation, value)
            colors.append(rgb)
    
    return colors

def apply_single_class_color(gaussians, class_label, target_color, background_color=[1, 1, 1]):
    """只对指定类别的点应用颜色，其他点保持原样"""
    print(f"应用类别颜色: RGB{target_color}")
    
    # 获取DC特征（RGB颜色）
    dc_features = gaussians._features_dc.squeeze(1)  # [N, 3]
    num_gaussians = dc_features.shape[0]
    
    print(f"处理 {num_gaussians} 个高斯点...")
    
    # 保存原始颜色用于恢复
    original_colors = dc_features.clone()
    
    # 只对指定类别的点应用颜色
    class_mask = (class_label == 1)  # 该类别的点
    class_indices = torch.where(class_mask)[0]
    
    print(f"类别包含 {len(class_indices)} 个点")
    
    for idx in class_indices:
        # 应用类别颜色
        target_color_tensor = torch.tensor(target_color, dtype=torch.float32, device=dc_features.device)
        
        # 获取原始颜色
        original_color = original_colors[idx, :3]
        original_luminance = 0.299 * original_color[0] + 0.587 * original_color[1] + 0.114 * original_color[2]
        
        # 增强颜色策略：让颜色更加突出
        if original_luminance > 0.1:  # 如果原始点不是太暗
            # 保持一定亮度但增强颜色对比度
            enhanced_color = target_color_tensor * 0.9 + 0.1  # 提高基础亮度
            styled_color = torch.clamp(enhanced_color, 0, 1)
        else:
            # 对于暗点，直接使用鲜艳颜色
            styled_color = target_color_tensor
        
        dc_features[idx, :3] = styled_color
    
    # 更新DC特征
    gaussians._features_dc = dc_features.unsqueeze(1)
    print("类别颜色应用完成")
    
    return original_colors

def restore_original_colors(gaussians, original_colors):
    """恢复原始颜色"""
    gaussians._features_dc = original_colors.unsqueeze(1)

def load_class_labels(model_path, total_categories):
    """加载所有类别的标签"""
    stats_counts_path = os.path.join(model_path, "mid_result")
    all_labels = {}
    
    for class_id in range(total_categories):
        label_path = os.path.join(stats_counts_path, f"class_id_{class_id:03d}_total_categories_{total_categories:03d}_label.pth")
        
        if os.path.exists(label_path):
            print(f"加载类别 {class_id} 的标签")
            label = torch.load(label_path)
            all_labels[class_id] = label
        else:
            print(f"警告: 类别 {class_id} 的标签文件不存在")
    
    return all_labels

def save_image(image_tensor, save_path, background_color=[1, 1, 1]):
    """保存图像"""
    image_tensor = torch.clamp(image_tensor, 0, 1)
    image_np = image_tensor.cpu().numpy()
    image_np = np.transpose(image_np, (1, 2, 0))
    
    if image_np.shape[2] == 4:  # RGBA
        rgb = image_np[:, :, :3]
        alpha = image_np[:, :, 3:4]
        bg_color_np = np.array(background_color)
        result = rgb * alpha + (1 - alpha) * bg_color_np
        result = np.clip(result, 0, 1)
        image_np = result
    
    image_pil = Image.fromarray((image_np * 255).astype(np.uint8))
    image_pil.save(save_path)

def render_single_class(views, gaussians, pipeline, background, class_id, class_label, class_color, output_dir):
    """渲染单个类别的风格化场景"""
    print(f"\n开始渲染类别 {class_id}...")
    
    # 保存原始颜色
    original_colors = gaussians._features_dc.squeeze(1).clone()
    
    try:
        # 应用该类别的颜色
        apply_single_class_color(gaussians, class_label, class_color, background)
        
        # 创建该类别的输出目录
        class_dir = os.path.join(output_dir, f"class_{class_id:03d}")
        os.makedirs(class_dir, exist_ok=True)
        
        # 渲染每个视图
        for idx, view in enumerate(tqdm(views, desc=f"渲染类别 {class_id}")):
            # 渲染风格化场景
            render_pkg = flexirender(view, gaussians, pipeline, background)
            styled_rendering = render_pkg["render"]
            
            # 保存风格化渲染结果
            render_path = os.path.join(class_dir, f"{idx:05d}_styled.png")
            save_image(styled_rendering, render_path, background)
            
            # 同时保存原始图像作为对比
            original_image = view.original_image[0:3, :, :]
            original_path = os.path.join(class_dir, f"{idx:05d}_original.png")
            save_image(original_image, original_path, background)
        
        print(f"类别 {class_id} 渲染完成，结果保存在: {class_dir}")
        
    finally:
        # 恢复原始颜色
        print(f"恢复类别 {class_id} 的原始颜色...")
        restore_original_colors(gaussians, original_colors)

def main():
    # 设置命令行参数
    parser = ArgumentParser(description="按类别风格重绘渲染脚本")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--output_dir", default="styled_output", type=str, help="输出目录")
    parser.add_argument("--background_color", default="white", type=str, 
                       choices=["white", "black", "gray"], help="背景颜色")
    parser.add_argument("--color_style", default="vibrant", type=str,
                       choices=["vibrant", "neon", "pastel"], help="颜色风格")
    parser.add_argument("--random_seed", default=42, type=int, help="随机种子")
    
    args = get_combined_args(parser)
    print("按类别风格重绘渲染 " + args.model_path)
    
    # 设置随机种子
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    
    # 初始化系统状态
    safe_state(args.quiet)
    
    # 设置背景颜色
    if args.background_color == "white":
        background_color = [1, 1, 1]
    elif args.background_color == "black":
        background_color = [0, 0, 0]
    else:  # gray
        background_color = [0.5, 0.5, 0.5]
    
    background = torch.tensor(background_color, dtype=torch.float32, device="cuda")
    
    with torch.no_grad():
        # 加载高斯模型和场景
        gaussians = GaussianModel(model.extract(args).sh_degree)
        scene = Scene(model.extract(args), gaussians, load_iteration=args.iteration, shuffle=False)
        
        # 读取info.json获取类别总数
        masks_dir = os.path.join(args.source_path, "masks")
        info_json_path = os.path.join(masks_dir, "info.json")
        
        if not os.path.exists(info_json_path):
            print(f"错误: {info_json_path} 文件未找到")
            return
        
        with open(info_json_path, "r") as f:
            mask_info = json.load(f)
        
        total_categories = mask_info["total_categories"]
        print(f"总类别数: {total_categories}")
        
        # 生成类别颜色
        print(f"生成{args.color_style}风格的颜色...")
        class_colors = generate_distinct_colors(total_categories, args.color_style)
        for i, color in enumerate(class_colors):
            print(f"类别 {i}: RGB{color}")
        
        # 加载类别标签
        print("加载类别标签...")
        all_labels = load_class_labels(args.model_path, total_categories)
        
        if not all_labels:
            print("错误: 没有找到任何类别标签文件")
            return
        
        # 创建输出目录
        output_dir = os.path.join(args.model_path, args.output_dir)
        os.makedirs(output_dir, exist_ok=True)
        
        # 获取所有相机视图
        views = scene.getTrainCameras()
        print(f"渲染 {len(views)} 个视图...")
        
        # 为每个类别单独渲染
        for class_id in range(total_categories):
            if class_id in all_labels:
                class_label = all_labels[class_id]
                class_color = class_colors[class_id]
                
                # 统计该类别的点数
                class_count = class_label.sum().item()
                print(f"\n类别 {class_id}: {class_count} 个点")
                
                if class_count > 0:
                    # 渲染该类别
                    render_single_class(views, gaussians, pipeline.extract(args), background, 
                                      class_id, class_label, class_color, output_dir)
                else:
                    print(f"类别 {class_id} 没有点，跳过")
            else:
                print(f"类别 {class_id} 标签不存在，跳过")
        
        print(f"\n所有类别渲染完成，结果保存在: {output_dir}")

if __name__ == "__main__":
    main() 