#!/usr/bin/env python3
"""
场景类别位置交换器
支持选择两个类别并交换它们的位置
"""

import os
import sys
import json
import argparse
import torch
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import numpy as np

def save_image(image_tensor, save_path, background_color=[1, 1, 1]):
    """保存图像"""
    try:
        # 确保图像在0-1范围内
        image_tensor = torch.clamp(image_tensor, 0, 1)
        
        # 转换为numpy数组
        image_np = image_tensor.detach().cpu().numpy()
        
        # 检查图像形状
        if len(image_np.shape) == 3:
            if image_np.shape[0] == 1:  # 如果第一维是1，需要去掉
                image_np = image_np.squeeze(0)
            elif image_np.shape[2] == 1:  # 如果最后一维是1，需要去掉
                image_np = image_np.squeeze(-1)
        
        # 确保是3通道RGB图像
        if len(image_np.shape) == 2:  # 灰度图像
            image_np = np.stack([image_np, image_np, image_np], axis=-1)
        elif image_np.shape[-1] == 4:  # RGBA图像
            rgb = image_np[:, :, :3]
            alpha = image_np[:, :, 3:4]
            bg_color_np = np.array(background_color)
            result = rgb * alpha + (1 - alpha) * bg_color_np
            result = np.clip(result, 0, 1)
            image_np = result
        elif image_np.shape[-1] == 1:  # 单通道图像
            image_np = np.concatenate([image_np, image_np, image_np], axis=-1)
        
        # 确保数据类型正确
        if image_np.dtype != np.float32 and image_np.dtype != np.float64:
            image_np = image_np.astype(np.float32)
        
        # 转换为0-255范围
        image_np = (image_np * 255).astype(np.uint8)
        
        # 保存为PIL图像
        image_pil = Image.fromarray(image_np)
        image_pil.save(save_path)
        print(f"图像已保存: {save_path}")
        
    except Exception as e:
        print(f"保存图像失败: {e}")
        print(f"图像张量形状: {image_tensor.shape}")
        print(f"图像张量数据类型: {image_tensor.dtype}")
        print(f"图像张量值范围: [{image_tensor.min().item():.3f}, {image_tensor.max().item():.3f}]")
        # 尝试使用torchvision保存
        try:
            import torchvision
            torchvision.utils.save_image(image_tensor, save_path)
            print(f"使用torchvision保存图像: {save_path}")
        except Exception as e2:
            print(f"torchvision保存也失败: {e2}")

def apply_transformation(xyz, translation, rotation=None, scale=None):
    """应用3D变换到点云"""
    # 复制坐标
    transformed_xyz = xyz.clone()
    
    # 计算物体质心
    centroid = torch.mean(transformed_xyz, dim=0)
    
    # 应用平移（相对于质心）
    if translation is not None:
        transformed_xyz = transformed_xyz + torch.tensor(translation, device=xyz.device, dtype=xyz.dtype)
    
    # 应用旋转（目前只支持Z轴旋转）
    if rotation is not None:
        if isinstance(rotation, (int, float)):
            # 假设是Z轴旋转角度（弧度）
            cos_theta = torch.cos(torch.tensor(rotation, device=xyz.device, dtype=xyz.dtype))
            sin_theta = torch.sin(torch.tensor(rotation, device=xyz.device, dtype=xyz.dtype))
            
            # Z轴旋转矩阵
            x = transformed_xyz[:, 0]
            y = transformed_xyz[:, 1]
            z = transformed_xyz[:, 2]
            
            transformed_xyz[:, 0] = x * cos_theta - y * sin_theta
            transformed_xyz[:, 1] = x * sin_theta + y * cos_theta
            transformed_xyz[:, 2] = z
    
    # 应用缩放（以质心为中心）
    if scale is not None:
        if isinstance(scale, (int, float)):
            scale = [scale, scale, scale]
        scale_tensor = torch.tensor(scale, device=xyz.device, dtype=xyz.dtype)
        
        # 以质心为中心进行缩放
        transformed_xyz = (transformed_xyz - centroid) * scale_tensor + centroid
    
    return transformed_xyz

def load_class_points(gaussians, class_label):
    """加载指定类别的点"""
    # 加载类别标签
    mid_result_path = os.path.join(gaussians.source_path, "mid_result")
    label_file = os.path.join(mid_result_path, f"class_id_{class_label:03d}_total_categories_*_label.pth")
    
    # 查找匹配的标签文件
    import glob
    label_files = glob.glob(label_file)
    
    print(f"尝试查找标签文件: {label_file}")
    print(f"找到的文件数量: {len(label_files)}")
    
    # 如果没找到，尝试 stats_counts 目录
    if not label_files:
        stats_counts_path = os.path.join(gaussians.source_path, "stats_counts")
        label_file = os.path.join(stats_counts_path, f"class_id_{class_label:03d}_total_categories_*_label.pth")
        label_files = glob.glob(label_file)
    
    if not label_files:
        print(f"警告: 找不到类别 {class_label} 的标签文件")
        print(f"尝试的路径: {mid_result_path} 和 {stats_counts_path}")
        print(f"尝试的文件模式: class_id_{class_label:03d}_total_categories_*_label.pth")
        print(f"当前 source_path: {gaussians.source_path}")
        
        # 检查有哪些类别文件存在
        print("检查可用的类别文件...")
        all_label_files = glob.glob(os.path.join(mid_result_path, "class_id_*_total_categories_*_label.pth"))
        if all_label_files:
            print("找到的类别文件:")
            for f in all_label_files:
                print(f"  {os.path.basename(f)}")
        else:
            print("在 mid_result 目录中没有找到任何类别标签文件")
            
        return None
    
    label_file = label_files[0]
    print(f"加载类别 {class_label} 的标签: {label_file}")
    
    try:
        unique_label = torch.load(label_file, map_location=gaussians._xyz.device)
        print(f"加载类别 {class_label} 的标签，形状: {unique_label.shape}")
        print(f"标签数据类型: {unique_label.dtype}")
        print(f"标签值范围: {unique_label.min().item()} - {unique_label.max().item()}")
        print(f"标签中1的数量: {(unique_label == 1).sum().item()}")
        print(f"标签中0的数量: {(unique_label == 0).sum().item()}")
    except Exception as e:
        print(f"加载类别 {class_label} 标签失败: {e}")
        return None
    
    # 获取属于该类别的点 - 支持多种标签值
    # 首先尝试值为1的点
    class_mask = unique_label == 1
    class_indices = torch.where(class_mask)[0]
    
    # 如果没有值为1的点，尝试其他可能的值
    if len(class_indices) == 0:
        print(f"没有找到值为1的点，尝试其他值...")
        unique_values = torch.unique(unique_label)
        print(f"标签中的唯一值: {unique_values}")
        
        # 尝试值为2的点（有些数据集使用2表示前景）
        if 2 in unique_values:
            class_mask = unique_label == 2
            class_indices = torch.where(class_mask)[0]
            print(f"尝试值为2的点，找到 {len(class_indices)} 个")
        
        # 如果还是没有，尝试非零值
        if len(class_indices) == 0:
            class_mask = unique_label > 0
            class_indices = torch.where(class_mask)[0]
            print(f"尝试所有非零值，找到 {len(class_indices)} 个")
    
    if len(class_indices) == 0:
        print(f"警告: 类别 {class_label} 没有找到任何点")
        print(f"尝试检查其他可能的值...")
        unique_values = torch.unique(unique_label)
        print(f"标签中的唯一值: {unique_values}")
        return None
    
    print(f"找到类别 {class_label} 的 {len(class_indices)} 个点")
    print(f"原始点云总数: {gaussians._xyz.shape[0]}")
    print(f"选中的点索引范围: {class_indices.min().item()} - {class_indices.max().item()}")
    
    # 显示物体的位置信息
    original_xyz = gaussians._xyz[class_indices]
    centroid = torch.mean(original_xyz, dim=0)
    min_coords = torch.min(original_xyz, dim=0)[0]
    max_coords = torch.max(original_xyz, dim=0)[0]
    
    print(f"\n=== 类别 {class_label} 的位置信息 ===")
    print(f"质心位置: X={centroid[0].item():.3f}, Y={centroid[1].item():.3f}, Z={centroid[2].item():.3f}")
    print(f"边界框: X[{min_coords[0].item():.3f}, {max_coords[0].item():.3f}], "
          f"Y[{min_coords[1].item():.3f}, {max_coords[1].item():.3f}], "
          f"Z[{min_coords[2].item():.3f}, {max_coords[2].item():.3f}]")
    print(f"物体尺寸: X={max_coords[0].item()-min_coords[0].item():.3f}, "
          f"Y={max_coords[1].item()-min_coords[1].item():.3f}, "
          f"Z={max_coords[2].item()-min_coords[2].item():.3f}")
    
    # 复制所有必要的属性
    copied_points = {}
    valid_attributes = 0
    
    for attr_name in dir(gaussians):
        if attr_name.startswith('_') and hasattr(gaussians, attr_name):
            attr_value = getattr(gaussians, attr_name)
            if isinstance(attr_value, torch.Tensor):
                # 检查张量是否为空
                if attr_value.numel() == 0:
                    print(f"警告: 属性 {attr_name} 是空张量，跳过")
                    continue
                
                # 检查索引是否超出范围
                if len(class_indices) > 0 and max(class_indices) >= attr_value.shape[0]:
                    print(f"警告: 类别索引超出属性 {attr_name} 的范围，跳过")
                    continue
                
                # 获取该类别的点
                copied_points[attr_name] = attr_value[class_indices].clone()
                valid_attributes += 1
    
    # 检查是否有有效的属性被复制
    if valid_attributes == 0:
        print(f"错误: 没有成功复制任何属性")
        return None
    
    print(f"成功复制类别 {class_label}，复制了 {valid_attributes} 个属性")
    
    # 添加位置信息到返回结果
    copied_points['centroid'] = centroid
    copied_points['indices'] = class_indices
    
    return copied_points

def swap_class_positions(gaussians, class1_id, class2_id):
    """交换两个类别的位置 - 先删除原始位置，再在新位置放置"""
    print(f"\n开始交换类别 {class1_id} 和类别 {class2_id} 的位置...")
    
    # 加载两个类别的点
    class1_points = load_class_points(gaussians, class1_id)
    class2_points = load_class_points(gaussians, class2_id)
    
    if class1_points is None or class2_points is None:
        print("无法加载类别点，交换失败")
        return None
    
    # 获取两个类别的质心和索引
    centroid1 = class1_points['centroid']
    centroid2 = class2_points['centroid']
    indices1 = class1_points['indices']
    indices2 = class2_points['indices']
    
    print(f"\n类别 {class1_id} 质心: {centroid1}, 点数: {len(indices1)}")
    print(f"类别 {class2_id} 质心: {centroid2}, 点数: {len(indices2)}")
    
    # 计算移动向量（交换位置）
    translation1_to_2 = centroid2 - centroid1  # 类别1移动到类别2的位置
    translation2_to_1 = centroid1 - centroid2  # 类别2移动到类别1的位置
    
    print(f"类别 {class1_id} 移动到类别 {class2_id} 位置的向量: {translation1_to_2}")
    print(f"类别 {class2_id} 移动到类别 {class1_id} 位置的向量: {translation2_to_1}")
    
    # 创建交换后的点集
    swapped_points = []
    
    # 变换类别1到类别2的位置（保持原始大小）
    if '_xyz' in class1_points:
        original_xyz = class1_points['_xyz']
        print(f"类别 {class1_id} 原始位置范围: X[{original_xyz[:, 0].min().item():.3f}, {original_xyz[:, 0].max().item():.3f}], "
              f"Y[{original_xyz[:, 1].min().item():.3f}, {original_xyz[:, 1].max().item():.3f}], "
              f"Z[{original_xyz[:, 2].min().item():.3f}, {original_xyz[:, 2].max().item():.3f}]")
        
        # 只应用平移，不改变大小
        transformed_xyz = apply_transformation(original_xyz, translation1_to_2, scale=[1.0, 1.0, 1.0])
        print(f"类别 {class1_id} 变换后位置范围: X[{transformed_xyz[:, 0].min().item():.3f}, {transformed_xyz[:, 0].max().item():.3f}], "
              f"Y[{transformed_xyz[:, 1].min().item():.3f}, {transformed_xyz[:, 1].max().item():.3f}], "
              f"Z[{transformed_xyz[:, 2].min().item():.3f}, {transformed_xyz[:, 2].max().item():.3f}]")
        
        # 创建变换后的类别1点
        class1_swapped = {}
        for k, v in class1_points.items():
            if k == '_xyz':
                class1_swapped[k] = transformed_xyz
            elif k not in ['centroid', 'indices']:
                class1_swapped[k] = v.clone()
        
        swapped_points.append(class1_swapped)
    
    # 变换类别2到类别1的位置（保持原始大小）
    if '_xyz' in class2_points:
        original_xyz = class2_points['_xyz']
        print(f"类别 {class2_id} 原始位置范围: X[{original_xyz[:, 0].min().item():.3f}, {original_xyz[:, 0].max().item():.3f}], "
              f"Y[{original_xyz[:, 1].min().item():.3f}, {original_xyz[:, 1].max().item():.3f}], "
              f"Z[{original_xyz[:, 2].min().item():.3f}, {original_xyz[:, 2].max().item():.3f}]")
        
        # 只应用平移，不改变大小
        transformed_xyz = apply_transformation(original_xyz, translation2_to_1, scale=[1.0, 1.0, 1.0])
        print(f"类别 {class2_id} 变换后位置范围: X[{transformed_xyz[:, 0].min().item():.3f}, {transformed_xyz[:, 0].max().item():.3f}], "
              f"Y[{transformed_xyz[:, 1].min().item():.3f}, {transformed_xyz[:, 1].max().item():.3f}], "
              f"Z[{transformed_xyz[:, 2].min().item():.3f}, {transformed_xyz[:, 2].max().item():.3f}]")
        
        # 创建变换后的类别2点
        class2_swapped = {}
        for k, v in class2_points.items():
            if k == '_xyz':
                class2_swapped[k] = transformed_xyz
            elif k not in ['centroid', 'indices']:
                class2_swapped[k] = v.clone()
        
        swapped_points.append(class2_swapped)
    
    print(f"成功创建 {len(swapped_points)} 个交换后的点集")
    return swapped_points, indices1, indices2

def merge_points_inplace_real_model(gaussian_model, swapped_points_list, indices_to_remove):
    """将交换后的点原位合并到真实 GaussianModel 中 - 先删除原始位置，再添加新位置"""
    if not swapped_points_list:
        print("没有需要合并的点")
        return
    with torch.no_grad():
        print(f"原始模型点数: {gaussian_model._xyz.shape[0]}")
        
        # 第一步：删除原始位置的物体
        if indices_to_remove:
            # 创建保留索引的掩码
            all_indices = torch.arange(gaussian_model._xyz.shape[0], device=gaussian_model._xyz.device)
            keep_mask = torch.ones_like(all_indices, dtype=torch.bool)
            
            for indices in indices_to_remove:
                if len(indices) > 0:
                    keep_mask[indices] = False
            
            keep_indices = torch.where(keep_mask)[0]
            print(f"删除 {gaussian_model._xyz.shape[0] - len(keep_indices)} 个原始点")
            
            # 保留未被删除的点
            gaussian_model._xyz = nn.Parameter(gaussian_model._xyz[keep_indices])
            gaussian_model._features_dc = nn.Parameter(gaussian_model._features_dc[keep_indices])
            gaussian_model._features_rest = nn.Parameter(gaussian_model._features_rest[keep_indices])
            gaussian_model._scaling = nn.Parameter(gaussian_model._scaling[keep_indices])
            gaussian_model._rotation = nn.Parameter(gaussian_model._rotation[keep_indices])
            gaussian_model._opacity = nn.Parameter(gaussian_model._opacity[keep_indices])
            
            print(f"删除后模型点数: {gaussian_model._xyz.shape[0]}")
        
        # 第二步：添加交换后的物体
        if swapped_points_list:
            # 收集要拼接的张量
            concat_map = {}
            for swapped in swapped_points_list:
                for k, v in swapped.items():
                    if not isinstance(v, torch.Tensor):
                        continue
                    concat_map.setdefault(k, []).append(v.to(gaussian_model._xyz.device))

            def cat_or_keep(name, base):
                tensors = concat_map.get(name, [])
                if not tensors:
                    return base
                return torch.cat([base, *tensors], dim=0)

            gaussian_model._xyz = nn.Parameter(cat_or_keep('_xyz', gaussian_model._xyz))
            gaussian_model._features_dc = nn.Parameter(cat_or_keep('_features_dc', gaussian_model._features_dc))
            gaussian_model._features_rest = nn.Parameter(cat_or_keep('_features_rest', gaussian_model._features_rest))
            gaussian_model._scaling = nn.Parameter(cat_or_keep('_scaling', gaussian_model._scaling))
            gaussian_model._rotation = nn.Parameter(cat_or_keep('_rotation', gaussian_model._rotation))
            gaussian_model._opacity = nn.Parameter(cat_or_keep('_opacity', gaussian_model._opacity))
            
            print(f"添加交换后的点，最终模型点数: {gaussian_model._xyz.shape[0]}")
        
        print(f"最终点云位置范围: X[{gaussian_model._xyz[:, 0].min().item():.3f}, {gaussian_model._xyz[:, 0].max().item():.3f}], "
              f"Y[{gaussian_model._xyz[:, 1].min().item():.3f}, {gaussian_model._xyz[:, 1].max().item():.3f}], "
              f"Z[{gaussian_model._xyz[:, 2].min().item():.3f}, {gaussian_model._xyz[:, 2].max().item():.3f}]")

def render_scene_views(scene, pipeline, background, output_dir, split='train'):
    """使用真实相机与渲染器渲染当前场景的视图"""
    from gaussian_renderer import render
    render_dir = os.path.join(output_dir, f"swapped_scene_{split}")
    os.makedirs(render_dir, exist_ok=True)
    views = scene.getTrainCameras() if split == 'train' else scene.getTestCameras()
    print(f"使用 {split} 视图渲染，共 {len(views)} 个…")
    
    successful_renders = 0
    failed_renders = 0
    
    for idx, view in enumerate(tqdm(views, desc=f"Rendering {split}")):
        try:
            render_pkg = render(view, scene.gaussians, pipeline, background)
            rendering = render_pkg["render"]
            
            # 检查渲染结果
            if rendering is None or rendering.numel() == 0:
                print(f"视图 {idx} 渲染结果为空")
                failed_renders += 1
                continue
                
            # 确保渲染结果是正确的形状
            if len(rendering.shape) == 4:  # (1, C, H, W)
                rendering = rendering.squeeze(0)  # 去掉batch维度
            elif len(rendering.shape) == 3 and rendering.shape[0] == 1:  # (1, H, W)
                rendering = rendering.squeeze(0)  # 去掉batch维度
                rendering = rendering.unsqueeze(0)  # 添加通道维度 (1, H, W)
                rendering = rendering.repeat(3, 1, 1)  # 转换为RGB (3, H, W)
            
            # 确保是3通道图像
            if len(rendering.shape) == 2:  # (H, W)
                rendering = rendering.unsqueeze(0).repeat(3, 1, 1)  # (3, H, W)
            elif rendering.shape[0] == 1:  # (1, H, W)
                rendering = rendering.repeat(3, 1, 1)  # (3, H, W)
            elif rendering.shape[0] == 4:  # (4, H, W) RGBA
                rendering = rendering[:3]  # 只取RGB通道
            
            # 转置为(H, W, C)格式
            if len(rendering.shape) == 3 and rendering.shape[0] in [1, 3, 4]:
                rendering = rendering.permute(1, 2, 0)
            
            save_image(rendering, os.path.join(render_dir, f"{idx:05d}.png"))
            successful_renders += 1
            
        except Exception as e:
            print(f"视图 {idx} 渲染失败: {e}")
            failed_renders += 1
            continue
    
    print(f"渲染完成: 成功 {successful_renders} 个，失败 {failed_renders} 个")

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description="场景类别位置交换器 - 交换两个类别的位置")
    
    # 基本参数
    parser.add_argument("--source_path", type=str, required=True, help="数据源路径")
    parser.add_argument("--model_path", type=str, required=True, help="模型路径")
    parser.add_argument("--iteration", type=int, required=True, help="模型迭代次数")
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    
    # 交换配置
    parser.add_argument("--class1_id", type=int, required=True, help="第一个类别ID")
    parser.add_argument("--class2_id", type=int, required=True, help="第二个类别ID")
    parser.add_argument("--show_info", action="store_true", help="仅显示类别位置信息，不进行交换和渲染")
    
    # 获取参数
    args = parser.parse_args()
    
    print("=" * 60)
    print("场景类别位置交换器 - 交换两个类别的位置")
    print("=" * 60)
    
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    
    # 加载真实的高斯模型
    print(f"加载真实的高斯模型: {args.model_path}")
    try:
        # 导入必要的模块
        import sys
        sys.path.append('.')
        
        from scene import Scene, GaussianModel
        from arguments import ModelParams, PipelineParams
        
        # 创建模型参数
        parser = argparse.ArgumentParser()
        model_params = ModelParams(parser)
        model_params.model_path = args.model_path
        model_params.source_path = args.source_path
        model_params.images = "images"  # 默认图像目录
        model_params.eval = False       # 非评估模式
        model_params.white_background = False  # 非白色背景
        model_params.depths = ""        # 深度目录
        model_params.resolution = -1    # 分辨率
        model_params.data_device = "cuda"  # 数据设备
        
        # 创建高斯模型
        print("创建高斯模型...")
        gaussian_model = GaussianModel(model_params.sh_degree)
        print("加载场景...")
        scene = Scene(model_params, gaussian_model, load_iteration=args.iteration, shuffle=False)
        
        # 记录模型路径以便读取标签
        gaussian_model.source_path = args.model_path
        
        print(f"成功加载真实的高斯模型")
        print(f"高斯点数量: {gaussian_model._xyz.shape[0] if hasattr(gaussian_model, '_xyz') else 0}")
        
        # 验证数据完整性
        if gaussian_model._xyz.numel() == 0:
            print("错误: 高斯模型中没有有效的点云数据")
            return
            
    except ImportError as e:
        print(f"导入模块失败: {e}")
        print("请确保以下条件满足：")
        print("1. 在正确的目录中运行脚本")
        print("2. 已安装所有必要的依赖")
        print("3. 模型文件完整且未损坏")
        return
    except Exception as e:
        print(f"加载高斯模型失败: {e}")
        print("请确保以下条件满足：")
        print("1. 模型路径正确且包含检查点文件")
        print("2. 检查点文件存在且可读")
        print("3. 模型文件完整且未损坏")
        import traceback
        traceback.print_exc()
        return
    
    # 创建输出目录
    os.makedirs(args.output_dir, exist_ok=True)
    
    print(f"\n交换配置:")
    print(f"  类别 {args.class1_id} <-> 类别 {args.class2_id}")
    
    # 执行位置交换
    if args.show_info:
        print("\n仅显示信息模式...")
        # 显示两个类别的信息
        load_class_points(gaussian_model, args.class1_id)
        load_class_points(gaussian_model, args.class2_id)
    else:
        result = swap_class_positions(gaussian_model, args.class1_id, args.class2_id)
        
        # 合并到真实模型并渲染
        if result is not None:
            swapped_points_list, indices1, indices2 = result
            print(f"\n执行真正的交换操作：删除原始位置，添加新位置…")
            merge_points_inplace_real_model(gaussian_model, swapped_points_list, [indices1, indices2])
            
            # 构建渲染管线与背景
            pp = PipelineParams(argparse.ArgumentParser())
            pp.convert_SHs_python = False
            pp.compute_cov3D_python = False
            pp.debug = False
            bg_color = [1, 1, 1] if not model_params.white_background else [1, 1, 1]
            background = torch.tensor(bg_color, dtype=torch.float32, device=gaussian_model._xyz.device)
            print(f"\n开始渲染…")
            render_scene_views(scene, pp, background, args.output_dir, split='train')
        else:
            print("没有成功交换的点，跳过渲染")
    
    print("\n场景类别位置交换完成！")

if __name__ == "__main__":
    main()
