import os
import torch
import numpy as np
import json
from argparse import ArgumentParser
from plyfile import PlyData, PlyElement

def load_ply_file(ply_file_path):
    """
    加载PLY文件，支持ASCII和二进制格式
    
    Args:
        ply_file_path: PLY文件路径
        
    Returns:
        points: 点云坐标 [N, 3]
        colors: 点云颜色 [N, 3] (如果存在)
        other_data: 其他属性数据字典
    """
    try:
        # 尝试使用plyfile库加载
        ply_data = PlyData.read(ply_file_path)
        vertex_data = ply_data['vertex'].data
        
        # 提取坐标
        points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
        
        # 提取颜色（如果存在）
        colors = None
        if 'red' in vertex_data.dtype.names and 'green' in vertex_data.dtype.names and 'blue' in vertex_data.dtype.names:
            colors = np.vstack([vertex_data['red'], vertex_data['green'], vertex_data['blue']]).T
        
        # 提取其他属性
        other_data = {}
        for name in vertex_data.dtype.names:
            if name not in ['x', 'y', 'z', 'red', 'green', 'blue']:
                other_data[name] = vertex_data[name]
        
        print(f"成功加载PLY文件: {ply_file_path}")
        print(f"  点云数量: {len(points)}")
        print(f"  颜色信息: {'是' if colors is not None else '否'}")
        print(f"  其他属性: {list(other_data.keys())}")
        
        return points, colors, other_data
        
    except Exception as e:
        print(f"使用plyfile加载PLY失败: {e}")
        print("尝试手动解析ASCII格式...")
        
        # 手动解析ASCII格式
        try:
            with open(ply_file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            
            # 解析头部
            header_end = 0
            num_vertices = 0
            has_color = False
            
            for i, line in enumerate(lines):
                line = line.strip()
                if line.startswith('element vertex'):
                    num_vertices = int(line.split()[-1])
                elif line.startswith('property') and ('red' in line or 'green' in line or 'blue' in line):
                    has_color = True
                elif line == 'end_header':
                    header_end = i + 1
                    break
            
            # 解析点云数据
            points = []
            colors = []
            
            for line in lines[header_end:header_end + num_vertices]:
                values = line.strip().split()
                if len(values) >= 3:
                    x, y, z = float(values[0]), float(values[1]), float(values[2])
                    points.append([x, y, z])
                    
                    if has_color and len(values) >= 6:
                        r, g, b = int(values[3]), int(values[4]), int(values[5])
                        colors.append([r, g, b])
            
            points = np.array(points)
            colors = np.array(colors) if colors else None
            
            print(f"手动解析成功: {len(points)} 个点")
            return points, colors, {}
            
        except Exception as e2:
            print(f"手动解析也失败: {e2}")
            return None, None, None

def load_pth_file(pth_file_path):
    """
    加载PTH文件中的点云索引
    
    Args:
        pth_file_path: PTH文件路径
        
    Returns:
        point_indices: 点云索引列表
        metadata: 元数据字典
    """
    try:
        data = torch.load(pth_file_path, map_location='cpu')
        
        print(f"成功加载PTH文件: {pth_file_path}")
        
        # 提取点云索引
        if 'point_indices' in data:
            point_indices = data['point_indices']
            if isinstance(point_indices, torch.Tensor):
                point_indices = point_indices.cpu().numpy()
            print(f"  点云索引数量: {len(point_indices)}")
        else:
            print("警告: PTH文件中未找到point_indices")
            return None, data
        
        # 提取元数据
        metadata = {}
        for key, value in data.items():
            if key != 'point_indices':
                if isinstance(value, torch.Tensor):
                    metadata[key] = value.cpu().numpy().tolist()
                else:
                    metadata[key] = value
        
        print(f"  元数据: {list(metadata.keys())}")
        
        return point_indices, metadata
        
    except Exception as e:
        print(f"加载PTH文件失败: {e}")
        return None, None

def create_colored_ply(points, colors, point_indices, output_path, 
                       foreground_color=[255, 0, 0], background_color=[255, 255, 255]):
    """
    创建带颜色的PLY文件
    
    Args:
        points: 原始点云坐标 [N, 3]
        colors: 原始点云颜色 [N, 3] (可选)
        point_indices: 前景点索引
        output_path: 输出PLY文件路径
        foreground_color: 前景点颜色 [R, G, B]
        background_color: 背景点颜色 [R, G, B]
    """
    # 创建新的颜色数组
    new_colors = np.full((len(points), 3), background_color, dtype=np.uint8)
    
    # 设置前景点颜色
    if point_indices is not None and len(point_indices) > 0:
        # 确保索引在有效范围内
        valid_indices = point_indices[point_indices < len(points)]
        new_colors[valid_indices] = foreground_color
        print(f"设置 {len(valid_indices)} 个前景点为红色")
    
    # 创建PLY数据
    vertex_data = []
    for i in range(len(points)):
        vertex = (points[i][0], points[i][1], points[i][2], 
                 new_colors[i][0], new_colors[i][1], new_colors[i][2])
        vertex_data.append(vertex)
    
    # 定义数据类型
    vertex_dtype = [
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')
    ]
    
    # 创建结构化数组
    vertex_array = np.array(vertex_data, dtype=vertex_dtype)
    
    # 创建PLY元素
    vertex_element = PlyElement.describe(vertex_array, 'vertex')
    
    # 创建PLY数据并保存
    ply_data = PlyData([vertex_element], text=True)
    ply_data.write(output_path)
    
    print(f"可视化PLY文件已保存到: {output_path}")

def main():
    parser = ArgumentParser(description="Visualize PTH point cloud data with PLY file")
    parser.add_argument("--pth_file", type=str, required=True,
                        help="Path to PTH file containing point indices")
    parser.add_argument("--ply_file", type=str, required=True,
                        help="Path to original PLY point cloud file")
    parser.add_argument("--output_dir", type=str, default="./evaluation_results",
                        help="Output directory for visualization results")
    parser.add_argument("--foreground_color", type=str, default="255,0,0",
                        help="Foreground point color (R,G,B), default: red")
    parser.add_argument("--background_color", type=str, default="255,255,255",
                        help="Background point color (R,G,B), default: white")
    
    args = parser.parse_args()
    
    # 解析颜色参数
    try:
        fg_color = [int(x) for x in args.foreground_color.split(',')]
        bg_color = [int(x) for x in args.background_color.split(',')]
    except ValueError:
        print("错误: 颜色格式应为 R,G,B (例如: 255,0,0)")
        return
    
    # 创建输出目录
    os.makedirs(args.output_dir, exist_ok=True)
    
    print("=" * 60)
    print("PTH点云可视化脚本")
    print("=" * 60)
    print(f"PTH文件: {args.pth_file}")
    print(f"PLY文件: {args.ply_file}")
    print(f"前景颜色: RGB{fg_color}")
    print(f"背景颜色: RGB{bg_color}")
    print()
    
    # 1. 加载PLY文件
    print("步骤1: 加载PLY点云文件...")
    points, colors, other_data = load_ply_file(args.ply_file)
    if points is None:
        print("错误: 无法加载PLY文件")
        return
    
    # 2. 加载PTH文件
    print("\n步骤2: 加载PTH文件...")
    point_indices, metadata = load_pth_file(args.pth_file)
    if point_indices is None:
        print("错误: 无法加载PTH文件")
        return
    
    # 3. 生成输出文件名
    pth_basename = os.path.splitext(os.path.basename(args.pth_file))[0]
    ply_basename = os.path.splitext(os.path.basename(args.ply_file))[0]
    output_filename = f"{ply_basename}_{pth_basename}_visualization.ply"
    output_path = os.path.join(args.output_dir, output_filename)
    
    # 4. 创建可视化PLY文件
    print(f"\n步骤3: 创建可视化PLY文件...")
    create_colored_ply(points, colors, point_indices, output_path, fg_color, bg_color)
    
    # 5. 保存元数据
    metadata_file = os.path.join(args.output_dir, f"{pth_basename}_metadata.json")
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    
    print(f"元数据已保存到: {metadata_file}")
    
    # 6. 生成总结信息
    summary_file = os.path.join(args.output_dir, f"{pth_basename}_visualization_summary.txt")
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("PTH点云可视化总结\n")
        f.write("=" * 40 + "\n\n")
        f.write(f"PTH文件: {args.pth_file}\n")
        f.write(f"PLY文件: {args.ply_file}\n")
        f.write(f"输出文件: {output_path}\n\n")
        f.write(f"原始点云数量: {len(points)}\n")
        f.write(f"前景点数量: {len(point_indices) if point_indices is not None else 0}\n")
        f.write(f"前景颜色: RGB{fg_color}\n")
        f.write(f"背景颜色: RGB{bg_color}\n\n")
        f.write("元数据信息:\n")
        for key, value in metadata.items():
            f.write(f"  {key}: {value}\n")
    
    print(f"总结信息已保存到: {summary_file}")
    
    print("\n" + "=" * 60)
    print("可视化完成！")
    print(f"主要输出文件: {output_path}")
    print(f"元数据文件: {metadata_file}")
    print(f"总结文件: {summary_file}")
    print("=" * 60)

if __name__ == "__main__":
    main()
