"""
将高效存储的点云文件转换为标准PLY格式

用法：
    # 转换单个文件
    python convert_npz_to_ply.py /path/to/000000_merged.npz --output output.ply
    
    # 批量转换整个目录
    python convert_npz_to_ply.py /path/to/point_clouds/ --format npz
"""

import argparse
import sys
from pathlib import Path
import numpy as np
import open3d as o3d
import torch
import gzip
import pickle

# Put this script's own directory at the front of sys.path so that modules
# located next to it can be imported.
sys.path.insert(0, str(Path(__file__).parent))


def _as_float32_array(values) -> np.ndarray:
    """Convert a torch tensor or array-like value into a float32 numpy array."""
    if isinstance(values, torch.Tensor):
        # detach()/cpu() keep the conversion working for tensors that track
        # gradients or live on an accelerator; the cast happens before
        # .numpy() because numpy has no bfloat16 dtype.
        return values.detach().cpu().to(torch.float32).numpy()
    return np.asarray(values).astype(np.float32)


def load_point_cloud_efficient(input_path: Path, format_type: str = "auto"):
    """Load a point cloud stored in one of the supported compact formats.

    Args:
        input_path: File to read.
        format_type: One of "auto", "npz", "npz_bf16", "npz_fp16", "bf16",
            "bfloat16" or "ply". "auto" picks the format from the file
            suffix (".npz" -> npz, ".gz" -> bf16, ".ply" -> ply) and falls
            back to "npz" for any other suffix.

    Returns:
        A ``(points, colors)`` tuple of numpy arrays. The npz and bf16
        branches cast both arrays to float32; the ply branch returns the
        arrays exactly as Open3D exposes them.

    Raises:
        ValueError: If ``format_type`` is not one of the supported names.
    """
    input_path = Path(input_path)

    if format_type == "auto":
        # Path.suffix only holds the last extension, so "cloud.pkl.gz" is
        # seen as ".gz"; a separate ".pkl.gz" entry could never match.
        suffix_formats = {".npz": "npz", ".gz": "bf16", ".ply": "ply"}
        format_type = suffix_formats.get(input_path.suffix, "npz")

    if format_type in ("npz", "npz_bf16", "npz_fp16"):
        # All three names refer to a plain .npz archive with 'points' and
        # 'colors' entries. The context manager closes the archive's file
        # handle once both arrays have been materialised.
        with np.load(str(input_path)) as data:
            points = data['points'].astype(np.float32)
            colors = data['colors'].astype(np.float32)
        return points, colors

    if format_type in ("bf16", "bfloat16"):
        # Legacy format: a gzip-compressed pickle of a dict.
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this branch on files from a trusted source.
        with gzip.open(str(input_path), 'rb') as f:
            data = pickle.load(f)

        # Each entry is converted on its own, so a payload mixing torch
        # tensors and numpy arrays is handled as well.
        return _as_float32_array(data['points']), _as_float32_array(data['colors'])

    if format_type == "ply":
        pcd = o3d.io.read_point_cloud(str(input_path))
        points = np.asarray(pcd.points)
        colors = np.asarray(pcd.colors)
        return points, colors

    raise ValueError(f"不支持的存储格式: {format_type}")


def save_ply_file(points: np.ndarray, colors: np.ndarray, output_path: Path):
    """Write points and their colors to ``output_path`` using Open3D.

    Args:
        points: Point coordinates, handed to ``o3d.utility.Vector3dVector``.
        colors: Per-point colors, handed to ``o3d.utility.Vector3dVector``.
        output_path: Destination file path.

    Raises:
        OSError: If Open3D reports that the file could not be written.
    """
    pcd = o3d.geometry.PointCloud()
    # Vector3dVector is documented as taking float64 arrays; converting up
    # front keeps the float32 inputs produced by the loaders on that path.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(points, dtype=np.float64))
    pcd.colors = o3d.utility.Vector3dVector(np.asarray(colors, dtype=np.float64))
    # write_point_cloud signals failure through its boolean return value
    # rather than an exception, so the result has to be checked explicitly.
    if not o3d.io.write_point_cloud(str(output_path), pcd):
        raise OSError(f"写入PLY文件失败: {output_path}")


def convert_file(input_path: Path, output_path: Path = None, format_type: str = "auto"):
    """Convert one stored point cloud file into a PLY file.

    Args:
        input_path: File to convert.
        output_path: Destination file. When ``None``, the input path with
            its suffix replaced by ".ply" is used.
        format_type: Input format name. "auto" derives it from the input
            suffix (".npz" -> npz, ".gz" -> bf16, ".ply" -> ply).

    Raises:
        ValueError: If ``format_type`` is "auto" and the suffix is not one
            of the recognised ones.
    """
    if format_type == "auto":
        # Suffix -> format lookup used only for auto-detection.
        formats_by_suffix = {
            ".npz": "npz",
            ".pkl.gz": "bf16",
            ".gz": "bf16",
            ".ply": "ply",
        }
        detected = formats_by_suffix.get(input_path.suffix)
        if detected is None:
            raise ValueError(f"无法自动检测文件格式: {input_path}")
        format_type = detected

    # Default destination sits next to the input, with a .ply suffix.
    target = input_path.with_suffix(".ply") if output_path is None else output_path

    print(f"正在转换: {input_path} -> {target} (格式: {format_type})")

    # Load with the resolved format, then write the PLY file.
    points, colors = load_point_cloud_efficient(input_path, format_type)
    save_ply_file(points, colors, target)

    # Report both file sizes in MB (bytes divided by 1024 * 1024).
    bytes_per_mb = 1024 * 1024
    input_size = input_path.stat().st_size / bytes_per_mb
    output_size = target.stat().st_size / bytes_per_mb
    print(f"转换完成 (输入: {input_size:.2f} MB, 输出: {output_size:.2f} MB)")


def batch_convert_directory(directory: Path, format_type: str = "npz", output_dir: Path = None):
    """Convert every matching file directly inside ``directory`` to PLY.

    Args:
        directory: Directory whose files are converted (not recursive).
        format_type: Input format name. npz variants select "*.npz" files,
            "bf16"/"bfloat16" select "*.gz" files (which includes
            "*.pkl.gz"), and "ply" selects nothing.
        output_dir: Directory for the generated files. Defaults to a
            "ply_format" subdirectory of ``directory``; created if missing.

    Raises:
        ValueError: If ``format_type`` is not a supported format name.

    A failure while converting one file is printed and does not stop the
    remaining conversions.
    """
    directory = Path(directory)

    # Select the files to convert.
    if format_type in ("npz", "npz_bf16", "npz_fp16"):
        input_files = sorted(directory.glob("*.npz"))
    elif format_type in ("bf16", "bfloat16"):
        # "*.gz" already matches every "*.pkl.gz" file. Globbing both
        # patterns listed those files twice and converted them twice.
        input_files = sorted(directory.glob("*.gz"))
    elif format_type == "ply":
        input_files = []
    else:
        raise ValueError(f"不支持的格式: {format_type}")

    if not input_files:
        print(f"在目录 {directory} 中未找到 {format_type} 文件")
        return

    print(f"找到 {len(input_files)} 个文件需要转换")

    # Resolve and create the output directory.
    if output_dir is None:
        output_dir = directory / "ply_format"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for input_file in input_files:
        # with_suffix swaps only the trailing extension, whereas
        # str.replace would rewrite every occurrence of the suffix text
        # found anywhere in the file name.
        output_file = output_dir / input_file.with_suffix(".ply").name
        try:
            convert_file(input_file, output_file, format_type)
        except Exception as e:
            # Deliberate best-effort: report the failure and keep going.
            print(f"转换失败 {input_file}: {e}")

    print(f"\n批量转换完成，输出目录: {output_dir}")


def main():
    """Command-line entry point: convert a single file or a whole directory.

    Parses the command line, then dispatches to ``batch_convert_directory``
    when the input is a directory (or ``-d`` is given) and to
    ``convert_file`` otherwise. If the input path does not exist, an error
    is printed to stderr and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="将高效存储的点云转换为标准PLY格式")
    parser.add_argument("input", type=str, help="输入文件或目录路径")
    parser.add_argument("-o", "--output", type=str, default=None,
                        help="输出文件路径（批量转换时为输出目录）")
    # "npz_fp16" is the default, so it must also be listed in choices;
    # otherwise passing the default value explicitly is rejected.
    parser.add_argument("-f", "--format", type=str, default="npz_fp16",
                        choices=["auto", "npz", "npz_fp16", "npz_bf16",
                                 "bf16", "bfloat16", "ply"],
                        help="输入文件格式（默认 npz_fp16，auto 表示按扩展名自动检测）")
    parser.add_argument("-d", "--directory", action="store_true",
                        help="批量转换整个目录")

    args = parser.parse_args()

    input_path = Path(args.input)

    if not input_path.exists():
        # Report on stderr and use a non-zero status so shells and calling
        # scripts can detect the failure.
        print(f"错误: 路径不存在: {input_path}", file=sys.stderr)
        sys.exit(1)

    if input_path.is_dir() or args.directory:
        # Batch mode: --output, when given, names the output directory.
        batch_convert_directory(input_path, args.format, args.output)
    else:
        # Single-file mode.
        convert_file(input_path, Path(args.output) if args.output else None, args.format)


# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()

