# 简单场景编辑器 - 支持复制类别到指定位置和场景合成渲染

import os
import json
import torch
import torch.nn as nn
import argparse
from tqdm import tqdm
import numpy as np
from PIL import Image

class SimpleGaussianModel:
    """简化的高斯模型类，兼容真实的高斯模型接口"""
    def __init__(self, sh_degree=3):
        self.max_sh_degree = sh_degree
        self.active_sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.source_path = ""
    
    def load_from_gaussian_model(self, gaussian_model):
        """从真实的高斯模型加载数据"""
        if hasattr(gaussian_model, '_xyz'):
            self._xyz = gaussian_model._xyz.clone()
        if hasattr(gaussian_model, '_features_dc'):
            self._features_dc = gaussian_model._features_dc.clone()
        if hasattr(gaussian_model, '_features_rest'):
            self._features_rest = gaussian_model._features_rest.clone()
        if hasattr(gaussian_model, '_scaling'):
            self._scaling = gaussian_model._scaling.clone()
        if hasattr(gaussian_model, '_rotation'):
            self._rotation = gaussian_model._rotation.clone()
        if hasattr(gaussian_model, '_opacity'):
            self._opacity = gaussian_model._opacity.clone()
        
        # 复制其他属性
        if hasattr(gaussian_model, 'max_sh_degree'):
            self.max_sh_degree = gaussian_model.max_sh_degree
        if hasattr(gaussian_model, 'active_sh_degree'):
            self.active_sh_degree = gaussian_model.active_sh_degree
    
    def load_from_model(self, model_data):
        """从模型数据加载"""
        self._xyz = model_data.get('xyz', torch.empty(0))
        self._features_dc = model_data.get('features_dc', torch.empty(0))
        self._features_rest = model_data.get('features_rest', torch.empty(0))
        self._scaling = model_data.get('scaling', torch.empty(0))
        self._rotation = model_data.get('rotation', torch.empty(0))
        self._opacity = model_data.get('opacity', torch.empty(0))
    
    @property
    def get_xyz(self):
        return self._xyz
    
    @property
    def get_features_dc(self):
        return self._features_dc
    
    @property
    def get_features_rest(self):
        return self._features_rest
    
    @property
    def get_features(self):
        return torch.cat((self._features_dc, self._features_rest), dim=1)
    
    @property
    def get_scaling(self):
        return self._scaling
    
    @property
    def get_rotation(self):
        return self._rotation
    
    @property
    def get_opacity(self):
        return self._opacity

def save_image(image_tensor, save_path, background_color=[1, 1, 1]):
    """保存图像到文件"""
    # 确保图像值在[0,1]范围内
    image_tensor = torch.clamp(image_tensor, 0, 1)
    
    # 转换为numpy数组
    image_np = image_tensor.detach().cpu().numpy()
    
    # 如果是4D张量，取第一个
    if len(image_np.shape) == 4:
        image_np = image_np[0]
    
    # 调整维度顺序 (C, H, W) -> (H, W, C)
    if len(image_np.shape) == 3:
        image_np = np.transpose(image_np, (1, 2, 0))
    
    # 处理RGBA图像
    if image_np.shape[2] == 4:  # RGBA
        rgb = image_np[:, :, :3]
        alpha = image_np[:, :, 3:4]
        bg_color_np = np.array(background_color)
        result = rgb * alpha + (1 - alpha) * bg_color_np
        result = np.clip(result, 0, 1)
        image_np = result
    
    # 转换为0-255范围
    image_np = (image_np * 255).astype(np.uint8)
    
    # 保存为PIL图像
    image_pil = Image.fromarray(image_np)
    image_pil.save(save_path)
    print(f"图像已保存: {save_path}")

def generate_distinct_colors(n_colors, style="vibrant"):
    """生成不同的颜色"""
    if style == "vibrant":
        # 预定义的鲜艳颜色
        colors = [
            [255, 0, 0],      # 红色
            [0, 255, 0],      # 绿色
            [0, 0, 255],      # 蓝色
            [255, 255, 0],    # 黄色
            [255, 0, 255],    # 洋红
            [0, 255, 255],    # 青色
            [255, 128, 0],    # 橙色
            [128, 0, 255],    # 紫色
        ]
    else:
        # HSV 生成
        colors = []
        for i in range(n_colors):
            hue = i / n_colors
            saturation = 0.8
            value = 0.9
            # HSV to RGB
            h = int(hue * 6)
            f = hue * 6 - h
            p = value * (1 - saturation)
            q = value * (1 - f * saturation)
            t = value * (1 - (1 - f) * saturation)
            
            if h == 0:
                r, g, b = value, t, p
            elif h == 1:
                r, g, b = q, value, p
            elif h == 2:
                r, g, b = p, value, t
            elif h == 3:
                r, g, b = p, q, value
            elif h == 4:
                r, g, b = t, p, value
            else:
                r, g, b = value, p, q
            
            colors.append([int(r * 255), int(g * 255), int(b * 255)])
    
    return colors[:n_colors]

def apply_transformation(xyz, translation, rotation=None, scale=None):
    """应用3D变换到点云"""
    # 复制坐标
    transformed_xyz = xyz.clone()
    
    # 计算物体质心
    centroid = torch.mean(transformed_xyz, dim=0)
    
    # 应用平移（相对于质心）
    if translation is not None:
        transformed_xyz = transformed_xyz + torch.tensor(translation, device=xyz.device, dtype=xyz.dtype)
    
    # 应用旋转（目前只支持Z轴旋转）
    if rotation is not None:
        if isinstance(rotation, (int, float)):
            # 假设是Z轴旋转角度（弧度）
            cos_theta = torch.cos(torch.tensor(rotation, device=xyz.device, dtype=xyz.dtype))
            sin_theta = torch.sin(torch.tensor(rotation, device=xyz.device, dtype=xyz.dtype))
            
            # Z轴旋转矩阵
            x = transformed_xyz[:, 0]
            y = transformed_xyz[:, 1]
            z = transformed_xyz[:, 2]
            
            transformed_xyz[:, 0] = x * cos_theta - y * sin_theta
            transformed_xyz[:, 1] = x * sin_theta + y * cos_theta
            transformed_xyz[:, 2] = z
    
    # 应用缩放（以质心为中心）
    if scale is not None:
        if isinstance(scale, (int, float)):
            scale = [scale, scale, scale]
        scale_tensor = torch.tensor(scale, device=xyz.device, dtype=xyz.dtype)
        
        # 以质心为中心进行缩放
        transformed_xyz = (transformed_xyz - centroid) * scale_tensor + centroid
    
    return transformed_xyz

def _load_class_indices_and_centroid(gaussians, class_label, verbose=True):
    """加载指定类别的点索引与质心。返回 (indices, centroid) 或 (None, None)。"""
    import glob
    # 构造查找路径（支持多个根目录：model_path 与 source_path）
    roots = []
    if hasattr(gaussians, 'label_search_roots'):
        roots.extend(list(dict.fromkeys(gaussians.label_search_roots)))
    else:
        roots.append(getattr(gaussians, 'source_path', '.'))
    pattern = f"class_id_{class_label:03d}_total_categories_*_label.pth"
    label_files = []
    for root in roots:
        mid_result_path = os.path.join(root, "mid_result")
        stats_counts_path = os.path.join(root, "stats_counts")
        label_files = glob.glob(os.path.join(mid_result_path, pattern))
        if not label_files:
            label_files = glob.glob(os.path.join(stats_counts_path, pattern))
        if label_files:
            break
    if not label_files:
        if verbose:
            print(f"警告: 找不到类别 {class_label} 的标签文件，查找根目录: {roots}")
        return None, None
    label_file = label_files[0]
    try:
        unique_label = torch.load(label_file, map_location=gaussians._xyz.device)
    except Exception as e:
        if verbose:
            print(f"加载类别 {class_label} 标签失败: {e}")
        return None, None
    # 面向多种编码的前景判定
    class_mask = unique_label == 1
    indices = torch.where(class_mask)[0]
    if len(indices) == 0:
        unique_values = torch.unique(unique_label)
        if 2 in unique_values:
            indices = torch.where(unique_label == 2)[0]
        if len(indices) == 0:
            indices = torch.where(unique_label > 0)[0]
    if len(indices) == 0:
        if verbose:
            print(f"警告: 类别 {class_label} 没有找到任何点")
        return None, None
    xyz_sel = gaussians._xyz[indices]
    centroid = torch.mean(xyz_sel, dim=0)
    return indices, centroid

def delete_class_points_inplace(gaussian_model, class_label, verbose=True):
    """就地删除指定类别的点。返回删除的点数量。"""
    indices, _ = _load_class_indices_and_centroid(gaussian_model, class_label, verbose=verbose)
    if indices is None or len(indices) == 0:
        if verbose:
            print(f"删除跳过: 类别 {class_label} 未找到要删除的点")
        return 0
    num_total = gaussian_model._xyz.shape[0]
    mask = torch.ones(num_total, dtype=torch.bool, device=gaussian_model._xyz.device)
    mask[indices] = False
    num_delete = int((~mask).sum().item())
    if verbose:
        print(f"准备删除类别 {class_label} 的 {num_delete} 个点（总数 {num_total}）")
    with torch.no_grad():
        for attr_name in dir(gaussian_model):
            if not attr_name.startswith('_'):
                continue
            attr_value = getattr(gaussian_model, attr_name, None)
            if isinstance(attr_value, torch.Tensor) and attr_value.shape[0] == num_total:
                kept = attr_value[mask]
                setattr(gaussian_model, attr_name, nn.Parameter(kept))
        # 处理可选张量
        if hasattr(gaussian_model, 'max_radii2D') and isinstance(gaussian_model.max_radii2D, torch.Tensor):
            if gaussian_model.max_radii2D.shape[0] == num_total:
                gaussian_model.max_radii2D = gaussian_model.max_radii2D[mask]
    if verbose:
        print(f"已删除类别 {class_label} 的点，当前总点数: {gaussian_model._xyz.shape[0]}")
    return num_delete

def copy_class_to_position(gaussians, class_label, translation, rotation=None, scale=None, show_info_only=False):
    """复制指定类别的点到新位置"""
    # 加载类别标签 - 修复路径构建逻辑
    # 标签文件直接在 model_path 目录下
    
    # 首先尝试 mid_result 目录
    mid_result_path = os.path.join(gaussians.source_path, "mid_result")
    label_file = os.path.join(mid_result_path, f"class_id_{class_label:03d}_total_categories_*_label.pth")
    
    # 查找匹配的标签文件
    import glob
    label_files = glob.glob(label_file)
    
    print(f"尝试查找标签文件: {label_file}")
    print(f"找到的文件数量: {len(label_files)}")
    
    # 如果没找到，尝试 stats_counts 目录
    if not label_files:
        stats_counts_path = os.path.join(gaussians.source_path, "stats_counts")
        label_file = os.path.join(stats_counts_path, f"class_id_{class_label:03d}_total_categories_*_label.pth")
        label_files = glob.glob(label_file)
    
    if not label_files:
        print(f"警告: 找不到类别 {class_label} 的标签文件")
        print(f"尝试的路径: {mid_result_path} 和 {stats_counts_path}")
        print(f"尝试的文件模式: class_id_{class_label:03d}_total_categories_*_label.pth")
        print(f"当前 source_path: {gaussians.source_path}")
        
        # 检查有哪些类别文件存在
        print("检查可用的类别文件...")
        all_label_files = glob.glob(os.path.join(mid_result_path, "class_id_*_total_categories_*_label.pth"))
        if all_label_files:
            print("找到的类别文件:")
            for f in all_label_files:
                print(f"  {os.path.basename(f)}")
        else:
            print("在 mid_result 目录中没有找到任何类别标签文件")
            
        return None
    
    label_file = label_files[0]
    print(f"加载类别 {class_label} 的标签: {label_file}")
    
    try:
        unique_label = torch.load(label_file, map_location=gaussians._xyz.device)
        print(f"加载类别 {class_label} 的标签，形状: {unique_label.shape}")
        print(f"标签数据类型: {unique_label.dtype}")
        print(f"标签值范围: {unique_label.min().item()} - {unique_label.max().item()}")
        print(f"标签中1的数量: {(unique_label == 1).sum().item()}")
        print(f"标签中0的数量: {(unique_label == 0).sum().item()}")
    except Exception as e:
        print(f"加载类别 {class_label} 标签失败: {e}")
        return None
    
    # 获取属于该类别的点 - 支持多种标签值
    # 首先尝试值为1的点
    class_mask = unique_label == 1
    class_indices = torch.where(class_mask)[0]
    
    # 如果没有值为1的点，尝试其他可能的值
    if len(class_indices) == 0:
        print(f"没有找到值为1的点，尝试其他值...")
        unique_values = torch.unique(unique_label)
        print(f"标签中的唯一值: {unique_values}")
        
        # 尝试值为2的点（有些数据集使用2表示前景）
        if 2 in unique_values:
            class_mask = unique_label == 2
            class_indices = torch.where(class_mask)[0]
            print(f"尝试值为2的点，找到 {len(class_indices)} 个")
        
        # 如果还是没有，尝试非零值
        if len(class_indices) == 0:
            class_mask = unique_label > 0
            class_indices = torch.where(class_mask)[0]
            print(f"尝试所有非零值，找到 {len(class_indices)} 个")
    
    if len(class_indices) == 0:
        print(f"警告: 类别 {class_label} 没有找到任何点")
        print(f"尝试检查其他可能的值...")
        unique_values = torch.unique(unique_label)
        print(f"标签中的唯一值: {unique_values}")
        return None
    
    print(f"复制 {len(class_indices)} 个点")
    print(f"原始点云总数: {gaussians._xyz.shape[0]}")
    print(f"选中的点索引范围: {class_indices.min().item()} - {class_indices.max().item()}")
    
    # 显示物体的位置信息
    original_xyz = gaussians._xyz[class_indices]
    centroid = torch.mean(original_xyz, dim=0)
    min_coords = torch.min(original_xyz, dim=0)[0]
    max_coords = torch.max(original_xyz, dim=0)[0]
    
    print(f"\n=== 类别 {class_label} 的位置信息 ===")
    print(f"质心位置: X={centroid[0].item():.3f}, Y={centroid[1].item():.3f}, Z={centroid[2].item():.3f}")
    print(f"边界框: X[{min_coords[0].item():.3f}, {max_coords[0].item():.3f}], "
          f"Y[{min_coords[1].item():.3f}, {max_coords[1].item():.3f}], "
          f"Z[{min_coords[2].item():.3f}, {max_coords[2].item():.3f}]")
    print(f"物体尺寸: X={max_coords[0].item()-min_coords[0].item():.3f}, "
          f"Y={max_coords[1].item()-min_coords[1].item():.3f}, "
          f"Z={max_coords[2].item()-min_coords[2].item():.3f}")
    
    if show_info_only:
        print("仅显示信息，不进行复制")
        return None
    
    # 复制所有必要的属性
    copied_points = {}
    valid_attributes = 0
    
    for attr_name in dir(gaussians):
        if attr_name.startswith('_') and hasattr(gaussians, attr_name):
            attr_value = getattr(gaussians, attr_name)
            if isinstance(attr_value, torch.Tensor):
                # 检查张量是否为空
                if attr_value.numel() == 0:
                    print(f"警告: 属性 {attr_name} 是空张量，跳过")
                    continue
                
                # 检查索引是否超出范围
                if len(class_indices) > 0 and max(class_indices) >= attr_value.shape[0]:
                    print(f"警告: 类别索引超出属性 {attr_name} 的范围，跳过")
                    continue
                
                # 获取该类别的点
                if attr_name == '_xyz':
                    # 应用变换
                    original_xyz = attr_value[class_indices]
                    print(f"原始点云位置范围: X[{original_xyz[:, 0].min().item():.3f}, {original_xyz[:, 0].max().item():.3f}], "
                          f"Y[{original_xyz[:, 1].min().item():.3f}, {original_xyz[:, 1].max().item():.3f}], "
                          f"Z[{original_xyz[:, 2].min().item():.3f}, {original_xyz[:, 2].max().item():.3f}]")
                    transformed_xyz = apply_transformation(original_xyz, translation, rotation, scale)
                    print(f"变换后点云位置范围: X[{transformed_xyz[:, 0].min().item():.3f}, {transformed_xyz[:, 0].max().item():.3f}], "
                          f"Y[{transformed_xyz[:, 1].min().item():.3f}, {transformed_xyz[:, 1].max().item():.3f}], "
                          f"Z[{transformed_xyz[:, 2].min().item():.3f}, {transformed_xyz[:, 2].max().item():.3f}]")
                    copied_points[attr_name] = transformed_xyz
                    valid_attributes += 1
                else:
                    # 直接复制其他属性
                    copied_points[attr_name] = attr_value[class_indices].clone()
                    valid_attributes += 1
    
    # 检查是否有有效的属性被复制
    if valid_attributes == 0:
        print(f"错误: 没有成功复制任何属性")
        return None
    
    print(f"成功复制类别 {class_label} 到位置 {translation}，复制了 {valid_attributes} 个属性")
    return copied_points

def merge_points_inplace_real_model(gaussian_model, copied_points_list):
    """将复制的点原位合并到真实 GaussianModel 中，用于真实渲染。"""
    if not copied_points_list:
        print("没有需要合并的点")
        return
    with torch.no_grad():
        # 收集要拼接的张量
        concat_map = {}
        for copied in copied_points_list:
            for k, v in copied.items():
                if not isinstance(v, torch.Tensor):
                    continue
                # 确保 dtype 和 device 与目标一致
                base_tensor = getattr(gaussian_model, k, None)
                device = gaussian_model._xyz.device
                if isinstance(base_tensor, torch.Tensor):
                    device = base_tensor.device
                    v = v.to(device=device, dtype=base_tensor.dtype)
                else:
                    v = v.to(device=device)
                concat_map.setdefault(k, []).append(v)

        def cat_or_keep(name, base):
            tensors = concat_map.get(name, [])
            if not tensors:
                return base
            return torch.cat([base, *tensors], dim=0)

        gaussian_model._xyz = nn.Parameter(cat_or_keep('_xyz', gaussian_model._xyz))
        if hasattr(gaussian_model, '_features_dc'):
            gaussian_model._features_dc = nn.Parameter(cat_or_keep('_features_dc', gaussian_model._features_dc))
        if hasattr(gaussian_model, '_features_rest'):
            gaussian_model._features_rest = nn.Parameter(cat_or_keep('_features_rest', gaussian_model._features_rest))
        if hasattr(gaussian_model, '_scaling'):
            gaussian_model._scaling = nn.Parameter(cat_or_keep('_scaling', gaussian_model._scaling))
        if hasattr(gaussian_model, '_rotation'):
            gaussian_model._rotation = nn.Parameter(cat_or_keep('_rotation', gaussian_model._rotation))
        if hasattr(gaussian_model, '_opacity'):
            gaussian_model._opacity = nn.Parameter(cat_or_keep('_opacity', gaussian_model._opacity))
        # 可选: 扩展占位张量
        if hasattr(gaussian_model, 'max_radii2D') and isinstance(gaussian_model.max_radii2D, torch.Tensor):
            if gaussian_model.max_radii2D.numel() == gaussian_model._xyz.shape[0] - sum(v.shape[0] for v in concat_map.get('_xyz', [])):
                gaussian_model.max_radii2D = torch.cat([
                    gaussian_model.max_radii2D,
                    torch.zeros((sum(v.shape[0] for v in concat_map.get('_xyz', []))), device=gaussian_model._xyz.device)
                ], dim=0)
        print("已将复制点合并到真实模型中：",
              gaussian_model._xyz.shape[0])
        print(f"合并后点云位置范围: X[{gaussian_model._xyz[:, 0].min().item():.3f}, {gaussian_model._xyz[:, 0].max().item():.3f}], "
              f"Y[{gaussian_model._xyz[:, 1].min().item():.3f}, {gaussian_model._xyz[:, 1].max().item():.3f}], "
              f"Z[{gaussian_model._xyz[:, 2].min().item():.3f}, {gaussian_model._xyz[:, 2].max().item():.3f}]")

def create_simple_camera_view(width=800, height=600, fov=60):
    """创建简单的相机视图"""
    class SimpleCamera:
        def __init__(self, width, height, fov):
            self.image_width = width
            self.image_height = height
            self.FoVx = fov
            self.FoVy = fov
            self.world_view_transform = torch.eye(4)
            self.full_proj_transform = torch.eye(4)
            self.camera_center = torch.tensor([0.0, 0.0, 0.0])
    
    return SimpleCamera(width, height, fov)

def render_scene_views(scene, pipeline, background, output_dir, split='train'):
    """使用真实相机与渲染器渲染当前场景的视图。"""
    from gaussian_renderer import render
    render_dir = os.path.join(output_dir, f"edited_scene_{split}")
    os.makedirs(render_dir, exist_ok=True)
    views = scene.getTrainCameras() if split == 'train' else scene.getTestCameras()
    print(f"使用 {split} 视图渲染，共 {len(views)} 个…")
    for idx, view in enumerate(tqdm(views, desc=f"Rendering {split}")):
        try:
            render_pkg = render(view, scene.gaussians, pipeline, background)
            rendering = render_pkg["render"]
            save_image(rendering, os.path.join(render_dir, f"{idx:05d}.png"))
        except Exception as e:
            print(f"视图 {idx} 渲染失败: {e}")
            continue

def render_simple_fallback(gaussians, render_dir, num_views):
    """简单的渲染回退方案"""
    print("使用简单渲染回退方案...")
    
    for idx in tqdm(range(num_views), desc="简单渲染"):
        try:
            # 检查是否有有效的点云数据
            if not hasattr(gaussians, '_xyz') or gaussians._xyz.numel() == 0:
                print(f"错误: 视图 {idx} 没有有效的点云数据")
                continue
                
            xyz = gaussians._xyz.cpu().numpy()
            print(f"视图 {idx}: 处理 {len(xyz)} 个点")
            
            # 创建简单的2D投影图像
            img_size = 800
            img = np.ones((img_size, img_size, 3), dtype=np.uint8) * 255
            
            # 简单的正交投影
            x_proj = ((xyz[:, 0] + 5) / 10 * img_size).astype(int)
            y_proj = ((xyz[:, 2] + 5) / 10 * img_size).astype(int)
            
            # 过滤有效坐标
            valid_mask = (x_proj >= 0) & (x_proj < img_size) & (y_proj >= 0) & (y_proj < img_size)
            x_proj = x_proj[valid_mask]
            y_proj = y_proj[valid_mask]
            
            print(f"视图 {idx}: 有效投影点 {len(x_proj)} 个")
            
            # 绘制点
            for i in range(len(x_proj)):
                img[y_proj[i], x_proj[i]] = [255, 0, 0]  # 红色点
            
            # 保存图像
            render_path = os.path.join(render_dir, f"{idx:05d}.png")
            Image.fromarray(img).save(render_path)
            print(f"视图 {idx} 图像已保存: {render_path}")
            
        except Exception as e:
            print(f"简单渲染视图 {idx} 失败: {e}")
            continue

def create_default_edits(total_categories):
    """创建默认的编辑配置 - 只处理1个类别"""
    edits = []
    
    # 只处理第一个类别（类别0）
    if total_categories > 0:
        edits.append({
            "class_id": 0,
            "action": "copy",
            "translation": [2.0, 0.0, 1.5],  # 向右移动2.0，向上移动1.5
            "rotation": 0.0,                   # 不旋转
            "scale": [2.0, 2.0, 2.0],         # 放大2倍
            "description": "复制类别0到右侧，向上移动，放大2倍"
        })
    
    return edits

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="场景编辑器 - 复制类别到指定位置")
    
    # 基本参数
    parser.add_argument("--source_path", type=str, required=True, help="数据源路径")
    parser.add_argument("--model_path", type=str, required=True, help="模型路径")
    parser.add_argument("--iteration", type=int, required=True, help="模型迭代次数")
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    
    # 编辑配置
    parser.add_argument("--edit_config", type=str, default=None, help="编辑配置文件路径")
    parser.add_argument("--class_id", type=int, default=0, help="要编辑的类别ID（默认0）")
    parser.add_argument("--translation", type=float, nargs=3, default=[2.0, 0.0, 1.5], 
                       help="移动位置 [x, y, z]（默认: 2.0 0.0 1.5）")
    parser.add_argument("--rotation", type=float, default=0.0, help="旋转角度（弧度，默认0.0）")
    parser.add_argument("--scale", type=float, nargs=3, default=[2.0, 2.0, 2.0], 
                       help="缩放比例 [sx, sy, sz]（默认: 2.0 2.0 2.0）")
    parser.add_argument("--show_info", action="store_true",
                       help="仅显示类别位置信息，不进行复制和渲染")
    parser.add_argument("--target_class_id", type=int, default=None,
                       help="将 --class_id 的对象复制到目标类别质心位置（自动计算平移向量）")
    parser.add_argument("--move", action="store_true",
                       help="启用移动模式：在目标类位置放置源类物体前，先删除目标类别点；默认保留源类原位")
    parser.add_argument("--remove_source", action="store_true",
                       help="配合 --move 使用：除在目标位置放置外，还删除源类别原位置的点（效果为真正的移动而非复制）")
    
    # 获取参数
    args = parser.parse_args()
    
    print("=" * 60)
    print("场景编辑器 - 复制类别到指定位置")
    print("=" * 60)
    
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    # 加载真实的高斯模型
    print(f"加载真实的高斯模型: {args.model_path}")
    try:
        # 导入必要的模块
        import sys
        sys.path.append('.')
        
        from scene import Scene, GaussianModel
        from arguments import ModelParams, PipelineParams
        
        # 创建模型参数
        parser = argparse.ArgumentParser()
        model_params = ModelParams(parser)
        model_params.model_path = args.model_path
        model_params.source_path = args.source_path
        model_params.images = "images"  # 默认图像目录
        model_params.eval = False       # 非评估模式
        model_params.white_background = False  # 非白色背景
        model_params.depths = ""        # 深度目录
        model_params.resolution = -1    # 分辨率
        model_params.data_device = "cuda"  # 数据设备
        
        # 创建高斯模型
        print("创建高斯模型...")
        # 实例化真实的 GaussianModel 并传入 Scene
        gaussian_model = GaussianModel(model_params.sh_degree)
        print("加载场景...")
        scene = Scene(model_params, gaussian_model, load_iteration=args.iteration, shuffle=False)
        
        # 记录模型路径以便读取标签
        gaussian_model.source_path = args.model_path
        # 同时记录可搜索的标签根目录（模型目录与数据源目录）
        gaussian_model.label_search_roots = [args.model_path, args.source_path]
        
        print(f"成功加载真实的高斯模型")
        print(f"高斯点数量: {gaussian_model._xyz.shape[0] if hasattr(gaussian_model, '_xyz') else 0}")
        
        # 验证数据完整性
        if gaussian_model._xyz.numel() == 0:
            print("错误: 高斯模型中没有有效的点云数据")
            return
            
    except ImportError as e:
        print(f"导入模块失败: {e}")
        print("请确保以下条件满足：")
        print("1. 在正确的目录中运行脚本")
        print("2. 已安装所有必要的依赖")
        print("3. 模型文件完整且未损坏")
        return
    except Exception as e:
        print(f"加载高斯模型失败: {e}")
        print("请确保以下条件满足：")
        print("1. 模型路径正确且包含检查点文件")
        print("2. 检查点文件存在且可读")
        print("3. 模型文件完整且未损坏")
        import traceback
        traceback.print_exc()
        return
    
    # 创建输出目录
    os.makedirs(args.output_dir, exist_ok=True)
    
    # 当指定了目标类别时，自动计算从源类到目标类的平移向量
    auto_translation = None
    pending_delete_target = False
    pending_delete_source = False
    if args.target_class_id is not None:
        print(f"自动计算平移向量: 从类别 {args.class_id} 到类别 {args.target_class_id}")
        # 需要已加载的 gaussian_model
        src_idx, src_centroid = _load_class_indices_and_centroid(gaussian_model, args.class_id, verbose=True)
        tgt_idx, tgt_centroid = _load_class_indices_and_centroid(gaussian_model, args.target_class_id, verbose=True)
        if src_centroid is None or tgt_centroid is None:
            print("错误: 无法计算质心，检查标签文件是否存在且有效。")
            return
        delta = (tgt_centroid - src_centroid)
        auto_translation = delta.detach().cpu().tolist()
        print(f"源类质心: {src_centroid.detach().cpu().numpy()}")
        print(f"目标质心: {tgt_centroid.detach().cpu().numpy()}")
        print(f"自动平移向量: {auto_translation}")
        # 若用户未显式设置缩放（仍为默认），自动改为保持原尺寸
        if isinstance(args.scale, list) and args.scale == [2.0, 2.0, 2.0]:
            args.scale = [1.0, 1.0, 1.0]
            print("检测到 target_class_id，自动将缩放设置为 [1,1,1] 以保持大小不变。")
        # 移动模式下，延后删除到复制成功之后再执行，避免只删不添
        if args.move:
            pending_delete_target = True
            pending_delete_source = bool(args.remove_source)
            print("移动模式开启：将于复制成功后删除目标/源类别点…")
        # 若仅查看信息
        if args.show_info:
            print("仅显示信息模式：已输出自动平移向量，不执行复制与渲染。")
            return
    
    # 确定编辑配置
    edit_config = getattr(args, 'edit_config', None)
    if edit_config and os.path.exists(edit_config):
        print(f"从配置文件加载编辑配置: {edit_config}")
        with open(edit_config, 'r', encoding='utf-8') as f:
            edits = json.load(f)
    else:
        print("使用命令行参数创建编辑配置")
        edits = [{
            "class_id": args.class_id,
            "action": "copy",
            "translation": auto_translation if auto_translation is not None else args.translation,
            "rotation": args.rotation,
            "scale": args.scale,
            "description": f"复制类别{args.class_id}到位置{auto_translation if auto_translation is not None else args.translation}"
        }]
    
    print(f"\n编辑配置:")
    for i, edit in enumerate(edits):
        print(f"  {i+1}. {edit['description']}")
    
    # 应用编辑：基于标签复制点并合并至真实模型
    copied_points_list = []
    for edit in edits:
        print(f"\n处理编辑: {edit['action']} 类别 {edit['class_id']}")
        
        if edit['action'] == 'copy':
            copied_points = copy_class_to_position(
                gaussian_model, 
                edit['class_id'], 
                edit['translation'], 
                edit.get('rotation'), 
                edit.get('scale'),
                show_info_only=args.show_info
            )
            if copied_points:
                # 显示复制点数与范围，便于确认
                if '_xyz' in copied_points:
                    xyz_new = copied_points['_xyz']
                    print(f"已准备复制 {xyz_new.shape[0]} 个点，位置范围: "
                          f"X[{xyz_new[:,0].min().item():.3f}, {xyz_new[:,0].max().item():.3f}], "
                          f"Y[{xyz_new[:,1].min().item():.3f}, {xyz_new[:,1].max().item():.3f}], "
                          f"Z[{xyz_new[:,2].min().item():.3f}, {xyz_new[:,2].max().item():.3f}]")
            if copied_points and not args.show_info:
                copied_points_list.append(copied_points)
    
    # 合并到真实模型并渲染
    if args.show_info:
        print("\n仅显示信息模式，跳过复制和渲染")
    elif copied_points_list:
        # 在合并之前执行必要的删除，确保不会出现“只删不添”
        if args.target_class_id is not None and pending_delete_target:
            print("删除目标类别点（移动模式）…")
            delete_class_points_inplace(gaussian_model, args.target_class_id, verbose=True)
        if args.target_class_id is not None and pending_delete_source:
            print("删除源类别原位置点（真正移动）…")
            delete_class_points_inplace(gaussian_model, args.class_id, verbose=True)
        print(f"\n合并高斯点到真实模型…")
        merge_points_inplace_real_model(gaussian_model, copied_points_list)
        # 构建渲染管线与背景
        pp = PipelineParams(argparse.ArgumentParser())
        pp.convert_SHs_python = False
        pp.compute_cov3D_python = False
        pp.debug = False
        bg_color = [1, 1, 1] if not model_params.white_background else [1, 1, 1]
        background = torch.tensor(bg_color, dtype=torch.float32, device=gaussian_model._xyz.device)
        print(f"\n开始渲染…")
        render_scene_views(scene, pp, background, args.output_dir, split='train')
    else:
        print("没有成功复制的点，跳过删除与渲染，以避免误删。")
    
    print("\n场景编辑完成！")

if __name__ == "__main__":
    main() 