#!/usr/bin/env python3

import os
import torch
import numpy as np
import json
from argparse import ArgumentParser

def load_point_cloud_data(scene_name, model_path, output_dir, ply_file_path=None):
    """
    加载点云数据和分割结果
    
    Args:
        scene_name: 场景名称
        model_path: FlexiGaussian模型路径
        output_dir: 输出目录
        ply_file_path: 指定的PLY文件路径（可选）
        
    Returns:
        points: 原始点云坐标 [N, 3]
        segmentation_results: 分割结果字典
    """
    print(f"加载场景 {scene_name} 的数据...")
    
    # 1. 尝试加载原始PLY文件
    if ply_file_path and os.path.exists(ply_file_path):
        # 使用指定的PLY文件路径
        ply_file = ply_file_path
        print(f"使用指定的PLY文件: {ply_file}")
    else:
        # 尝试从模型路径加载
        ply_file = os.path.join(model_path, f"{scene_name}.ply")
        if os.path.exists(ply_file):
            print(f"从模型路径加载PLY文件: {ply_file}")
        else:
            print(f"错误: 未找到PLY文件: {ply_file}")
            print("请确保场景的点云数据文件存在且路径正确")
            print("点云文件应该位于模型路径下，命名为: {scene_name}.ply")
            print("\n可能的解决方案:")
            print("1. 检查场景名称是否正确")
            print("2. 确认模型路径下是否存在对应的PLY文件")
            print("3. 如果PLY文件在其他位置，请使用 --ply_file 参数指定完整路径")
            print("4. 确保PLY文件包含正确的点云坐标数据")
            raise FileNotFoundError(f"场景 {scene_name} 的点云数据文件不存在: {ply_file}")
    
    # 2. 检测PLY文件格式并加载点云数据
    points = load_ply_file(ply_file)
    
    # 3. 加载分割结果
    results_dir = os.path.join(output_dir, f"{scene_name}_clip_results")
    json_file = os.path.join(results_dir, f"{scene_name}_class_matching_results.json")
    
    if not os.path.exists(json_file):
        raise FileNotFoundError(f"未找到分割结果文件: {json_file}")
    
    with open(json_file, 'r', encoding='utf-8') as f:
        segmentation_results = json.load(f)
    
    print(f"处理 {len(segmentation_results)} 个类别的分割结果...")
    
    # 4. 验证数据一致性
    if len(points) == 0:
        raise ValueError(f"点云数据为空，无法进行分割可视化")
    
    print(f"数据验证完成:")
    print(f"  - 点云坐标数量: {len(points)}")
    print(f"  - 分割结果类别数: {len(segmentation_results)}")
    
    return points, segmentation_results

def load_ply_file(ply_file_path):
    """
    加载PLY文件（支持ASCII和二进制格式）
    
    Args:
        ply_file_path: PLY文件路径
        
    Returns:
        points: 点云坐标 [N, 3]
    """
    print(f"正在加载PLY文件: {ply_file_path}")
    
    # 首先尝试使用plyfile库（如果可用）
    try:
        from plyfile import PlyData
        print("使用plyfile库加载PLY文件...")
        ply_data = PlyData.read(ply_file_path)
        vertex_data = ply_data['vertex'].data
        
        # 提取坐标
        points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
        print(f"成功使用plyfile加载点云，形状: {points.shape}")
        return points
        
    except ImportError:
        print("plyfile库不可用，尝试手动解析...")
    except Exception as e:
        print(f"plyfile加载失败: {e}")
        print("尝试手动解析...")
    
    # 手动解析PLY文件
    try:
        # 检测文件格式
        with open(ply_file_path, 'rb') as f:
            header = f.read(1024).decode('ascii', errors='ignore')
        
        if 'format ascii' in header:
            print("检测到ASCII格式PLY文件")
            return load_ascii_ply(ply_file_path)
        elif 'format binary' in header:
            print("检测到二进制格式PLY文件")
            print("警告: 二进制PLY文件需要plyfile库支持")
            print("请安装plyfile库: pip install plyfile")
            raise ValueError("二进制PLY文件需要plyfile库支持")
        else:
            print("无法确定PLY文件格式")
            raise ValueError("无法确定PLY文件格式")
            
    except Exception as e:
        print(f"手动解析PLY文件失败: {e}")
        raise

def load_ascii_ply(ply_file_path):
    """
    手动解析ASCII格式的PLY文件
    
    Args:
        ply_file_path: PLY文件路径
        
    Returns:
        points: 点云坐标 [N, 3]
    """
    print("手动解析ASCII PLY文件...")
    
    points = []
    with open(ply_file_path, 'r', encoding='utf-8', errors='ignore') as f:
        lines = f.readlines()
    
    # 解析PLY文件头
    header_end = 0
    num_vertices = 0
    
    for i, line in enumerate(lines):
        line = line.strip()
        if line == 'end_header':
            header_end = i
            break
        elif line.startswith('element vertex'):
            num_vertices = int(line.split()[-1])
    
    print(f"PLY文件头信息: 顶点数 = {num_vertices}")
    
    # 读取点云数据
    for line in lines[header_end + 1:]:
        line = line.strip()
        if line:
            parts = line.split()
            if len(parts) >= 3:
                try:
                    x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
                    points.append([x, y, z])
                except ValueError:
                    continue
    
    points = np.array(points)
    print(f"成功解析ASCII PLY文件，点云形状: {points.shape}")
    
    if len(points) != num_vertices:
        print(f"警告: 解析的点数 {len(points)} 与头部声明的顶点数 {num_vertices} 不匹配")
    
    return points

def create_colored_ply_for_class(points, class_point_indices, class_name, scene_name, output_dir):
    """
    为单个类别创建带颜色的PLY文件
    
    Args:
        points: 点云坐标 [N, 3]
        class_point_indices: 该类别的点索引列表
        class_name: 类别名称
        scene_name: 场景名称
        output_dir: 输出目录
    """
    print(f"为类别 '{class_name}' 创建PLY文件...")
    
    # 创建分割掩码
    segmentation_mask = np.zeros(len(points), dtype=bool)
    
    if isinstance(class_point_indices, list):
        class_point_indices = np.array(class_point_indices)
    
    # 确保索引在有效范围内
    valid_indices = class_point_indices[class_point_indices < len(points)]
    if len(valid_indices) > 0:
        segmentation_mask[valid_indices] = True
        print(f"  类别 {class_name}: {len(valid_indices)} 个点")
    else:
        print(f"  类别 {class_name}: 0 个点（跳过）")
        return None
    
    # 创建颜色数组
    colors = np.zeros((len(points), 3), dtype=np.uint8)
    
    # 所有点云默认显示为白色 (255, 255, 255)
    colors[:] = [255, 255, 255]
    
    # 被标注的点显示为红色 (255, 0, 0)
    colors[segmentation_mask] = [255, 0, 0]
    
    # 保存PLY文件
    safe_class_name = class_name.replace(' ', '_').replace('/', '_')
    output_ply = os.path.join(output_dir, f"{scene_name}_{safe_class_name}_segmentation.ply")
    
    with open(output_ply, 'w') as f:
        # 写入PLY文件头
        f.write("ply\n")
        f.write("format ascii 1.0\n")
        f.write(f"element vertex {len(points)}\n")
        f.write("property float x\n")
        f.write("property float y\n")
        f.write("property float z\n")
        f.write("property uchar red\n")
        f.write("property uchar green\n")
        f.write("property uchar blue\n")
        f.write("end_header\n")
        
        # 写入点云数据
        for i in range(len(points)):
            x, y, z = points[i]
            r, g, b = colors[i]
            f.write(f"{x:.6f} {y:.6f} {z:.6f} {r} {g} {b}\n")
    
    print(f"  PLY文件已保存到: {output_ply}")
    
    # 统计信息
    red_points = np.sum(segmentation_mask)
    white_points = len(points) - red_points
    
    print(f"  颜色统计: 红色点 {red_points} 个, 白色点 {white_points} 个")
    
    return output_ply

def create_visualization_summary(points, segmentation_results, scene_name, output_dir, generated_files):
    """
    创建可视化总结信息
    
    Args:
        points: 点云坐标 [N, 3]
        segmentation_results: 分割结果字典
        scene_name: 场景名称
        output_dir: 输出目录
        generated_files: 生成的文件列表
    """
    print(f"创建可视化总结...")
    
    # 创建总结文件
    summary_file = os.path.join(output_dir, f"{scene_name}_visualization_summary.txt")
    
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write(f"点云可视化总结 - {scene_name}\n")
        f.write("=" * 50 + "\n\n")
        
        f.write(f"点云总数: {len(points)}\n")
        f.write(f"处理类别数: {len(segmentation_results)}\n\n")
        
        f.write("各类别统计:\n")
        total_labeled_points = 0
        for class_id, result in segmentation_results.items():
            class_name = result['class_name']
            point_indices = result['point_indices']
            
            if isinstance(point_indices, list):
                point_indices = np.array(point_indices)
            
            valid_indices = point_indices[point_indices < len(points)]
            labeled_count = len(valid_indices)
            total_labeled_points += labeled_count
            
            f.write(f"  {class_name}: {labeled_count} 个点\n")
        
        f.write(f"\n总标注点数: {total_labeled_points}\n")
        f.write(f"标注比例: {total_labeled_points/len(points)*100:.2f}%\n\n")
        
        f.write("坐标范围:\n")
        f.write(f"  X: {points[:, 0].min():.3f} ~ {points[:, 0].max():.3f}\n")
        f.write(f"  Y: {points[:, 1].min():.3f} ~ {points[:, 1].max():.3f}\n")
        f.write(f"  Z: {points[:, 2].min():.3f} ~ {points[:, 2].max():.3f}\n\n")
        
        f.write("颜色说明:\n")
        f.write("  - 红色 (255, 0, 0): 被CLIP分割标注的点\n")
        f.write("  - 白色 (255, 255, 255): 未被标注的点\n\n")
        
        f.write("生成的文件:\n")
        for file_path in generated_files:
            if file_path:
                f.write(f"  - {os.path.basename(file_path)}\n")
    
    print(f"可视化总结已保存到: {summary_file}")

def main():
    parser = ArgumentParser(description="将FlexiGaussian分割结果转换为带颜色的PLY文件（每个类别一个文件）")
    parser.add_argument("--scene_name", type=str, required=True,
                        help="场景名称")
    parser.add_argument("--model_path", type=str, required=True,
                        help="FlexiGaussian模型路径")
    parser.add_argument("--output_dir", type=str, default="./evaluation_results",
                        help="输出目录")
    parser.add_argument("--ply_file", type=str, default=None,
                        help="原始PLY文件路径（如果不在model_path下）")
    
    args = parser.parse_args()
    
    try:
        # 1. 加载数据
        points, segmentation_results = load_point_cloud_data(
            args.scene_name, args.model_path, args.output_dir, args.ply_file
        )
        
        # 2. 验证分割结果与点云的一致性
        print(f"\n验证数据一致性...")
        max_index = 0
        for class_id, result in segmentation_results.items():
            point_indices = result['point_indices']
            if isinstance(point_indices, list):
                point_indices = np.array(point_indices)
            if len(point_indices) > 0:
                max_index = max(max_index, point_indices.max())
        
        if max_index >= len(points):
            raise ValueError(f"数据不一致错误: 分割结果中的最大索引 {max_index} 超出了点云数量 {len(points)}")
        
        print(f"数据一致性验证通过: 最大索引 {max_index} < 点云数量 {len(points)}")
        
        # 3. 为每个类别创建PLY文件
        generated_files = []
        total_labeled_points = 0
        
        for class_id, result in segmentation_results.items():
            class_name = result['class_name']
            point_indices = result['point_indices']
            
            # 创建该类别的PLY文件
            output_ply = create_colored_ply_for_class(
                points, point_indices, class_name, args.scene_name, args.output_dir
            )
            
            if output_ply:
                generated_files.append(output_ply)
                # 统计该类别的点数
                if isinstance(point_indices, list):
                    point_indices = np.array(point_indices)
                valid_indices = point_indices[point_indices < len(points)]
                total_labeled_points += len(valid_indices)
        
        # 4. 创建可视化总结
        create_visualization_summary(
            points, segmentation_results, args.scene_name, args.output_dir, generated_files
        )
        
        print(f"\n=== 可视化转换完成 ===")
        print(f"场景: {args.scene_name}")
        print(f"生成PLY文件数: {len(generated_files)}")
        print(f"总标注点数: {total_labeled_points}")
        print(f"点云总数: {len(points)}")
        print(f"标注比例: {total_labeled_points/len(points)*100:.2f}%")
        
    except Exception as e:
        print(f"可视化转换过程中出现错误: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()
