import os
import torch
import numpy as np
import json
from argparse import ArgumentParser
import glob

def _to_json_safe(value):
    """Recursively convert tensors and NumPy values into plain Python objects.

    Dicts, lists and tuples are walked so nested tensors/arrays are converted
    too; every other value is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy().tolist()
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    if isinstance(value, dict):
        return {key: _to_json_safe(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_to_json_safe(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_json_safe(item) for item in value)
    return value


def load_pth_file(pth_file_path):
    """
    Load point cloud labels from a PTH file.

    Two layouts are recognised:
      * a dict holding the labels under the ``'point_indices'`` key; every
        other entry is returned as metadata, or
      * a bare tensor, which is taken to be the labels themselves.

    Args:
        pth_file_path: PTH file path

    Returns:
        point_labels: NumPy array of labels (0=background, 1=foreground), or
            None when no labels could be extracted.
        metadata: Metadata dictionary with tensors / NumPy values converted to
            plain Python objects. When the dict has no ``'point_indices'`` key
            the raw loaded dict is returned instead; when loading raises, both
            return values are None.
    """
    try:
        # NOTE(security): torch.load unpickles the file; only use it on PTH
        # files from a trusted source.
        data = torch.load(pth_file_path, map_location='cpu')

        print(f"Successfully loaded PTH file: {pth_file_path}")

        point_labels = None
        metadata = {}

        # Check data structure type
        if isinstance(data, dict):
            if 'point_indices' not in data:
                print("Warning: point_indices not found in PTH file")
                return None, data

            point_labels = data['point_indices']
            if isinstance(point_labels, torch.Tensor):
                point_labels = point_labels.detach().cpu().numpy()
            else:
                # Lists/tuples must become arrays as well, otherwise element-wise
                # comparisons such as `labels == 1` do not produce a mask.
                point_labels = np.asarray(point_labels)
            print(f"  Point cloud label count: {len(point_labels)}")

            # Everything except the labels is metadata; make it JSON-friendly.
            for key, value in data.items():
                if key != 'point_indices':
                    metadata[key] = _to_json_safe(value)

        elif isinstance(data, torch.Tensor):
            # If directly a tensor, assume the entire tensor is point cloud labels
            point_labels = data.detach().cpu().numpy()
            print(f"  Point cloud label count: {len(point_labels)} (direct tensor)")
            metadata['data_type'] = 'direct_tensor'
            metadata['tensor_shape'] = list(data.shape)
            metadata['tensor_dtype'] = str(data.dtype)

            # Check if tensor is empty or all zeros
            if len(point_labels) == 0:
                print("Warning: Tensor is empty")
                return None, metadata
            if np.all(point_labels == 0):
                # Not fatal: an all-background result is still converted.
                print("Warning: Tensor is all zeros, may have no valid foreground points")

        else:
            # Other types
            print(f"Warning: Unknown data type {type(data)}")
            return None, {'data_type': str(type(data)), 'data': str(data)}

        print(f"  Metadata: {list(metadata.keys())}")

        return point_labels, metadata

    except Exception as e:
        # Best-effort loader: report the problem and let the caller skip the file.
        print(f"Failed to load PTH file: {e}")
        import traceback
        traceback.print_exc()
        return None, None

def load_point_cloud_from_ply(model_path, scene_name=None):
    """
    Load original point cloud coordinates from a PLY file.

    With ``scene_name`` the file ``<model_path>/<scene_name>.ply`` is used.
    Otherwise the alphabetically first ``*.ply`` directly under ``model_path``
    is chosen, falling back to the first one found while walking the
    subdirectories (also in sorted order, so the choice is reproducible).

    The ``plyfile`` package is used when it can be imported; otherwise the
    file is parsed with ``load_ply_ascii``.

    Args:
        model_path: Model path
        scene_name: Scene name (optional)

    Returns:
        points: Point cloud coordinates [N, 3], or None when no usable PLY
            file could be found or read (the reason is printed).
    """
    try:
        if scene_name:
            ply_file = os.path.join(model_path, f"{scene_name}.ply")
        else:
            # glob/os.walk return entries in arbitrary order; sort so the same
            # file is picked on every run and platform.
            candidates = sorted(glob.glob(os.path.join(model_path, "*.ply")))
            if not candidates:
                for root, dirs, files in os.walk(model_path):
                    dirs.sort()  # in-place: fixes the traversal order
                    nested = sorted(name for name in files if name.endswith('.ply'))
                    if nested:
                        candidates = [os.path.join(root, nested[0])]
                        break
            if not candidates:
                raise FileNotFoundError(f"No PLY files found in model path {model_path}")
            ply_file = candidates[0]

        if not os.path.exists(ply_file):
            raise FileNotFoundError(f"PLY文件不存在: {ply_file}")

        print(f"Loading PLY file: {ply_file}")

        # Only the import decides which reader is used; errors raised while
        # reading are reported by the outer handler.
        try:
            from plyfile import PlyData
        except ImportError:
            PlyData = None

        if PlyData is not None:
            print("Loading PLY file using plyfile library...")
            vertex_data = PlyData.read(ply_file)['vertex'].data
            # Extract coordinates
            points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
        else:
            print("plyfile library not available, trying manual ASCII parsing...")
            points = load_ply_ascii(ply_file)
            if points is None:
                raise ValueError(f"manual ASCII parsing failed for {ply_file}")

        if len(points) == 0:
            # min()/max() below are undefined for an empty cloud.
            raise ValueError(f"PLY file contains no vertices: {ply_file}")

        print(f"Successfully loaded point cloud: {len(points)} points")
        print(f"Point cloud coordinate range: X[{points[:, 0].min():.3f}, {points[:, 0].max():.3f}]")
        print(f"              Y[{points[:, 1].min():.3f}, {points[:, 1].max():.3f}]")
        print(f"              Z[{points[:, 2].min():.3f}, {points[:, 2].max():.3f}]")

        return points

    except Exception as e:
        # Best-effort loader: callers treat None as "skip this file".
        print(f"Failed to load PLY file: {e}")
        return None

def load_ply_ascii(ply_file_path):
    """
    Manually parse an ASCII-format PLY file.

    The header is scanned for the vertex count and for the positions of the
    ``x``/``y``/``z`` properties inside the vertex element. If those three
    names are not all declared (or the vertex element has a list property, so
    columns cannot be mapped one-to-one) the first three columns are used.
    Vertex rows are assumed to start right after ``end_header``. Rows with too
    few columns are skipped.

    Args:
        ply_file_path: PLY file path

    Returns:
        points: float32 point cloud coordinates [N, 3] (shape (0, 3) when the
            file has no vertices), or None when the file cannot be parsed
            (non-ASCII format, missing ``end_header``, unreadable values, ...).
    """
    try:
        with open(ply_file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # --- Parse the header ---
        header_end = None
        num_vertices = 0
        in_vertex_element = False
        vertex_props = []
        has_list_prop = False

        for i, line in enumerate(lines):
            tokens = line.split()
            if not tokens:
                continue
            keyword = tokens[0]
            if keyword == 'format':
                if len(tokens) < 2 or tokens[1] != 'ascii':
                    raise ValueError(f"not an ASCII PLY file: {line.strip()!r}")
            elif keyword == 'element':
                in_vertex_element = len(tokens) >= 3 and tokens[1] == 'vertex'
                if in_vertex_element:
                    num_vertices = int(tokens[-1])
            elif keyword == 'property' and in_vertex_element:
                if len(tokens) > 1 and tokens[1] == 'list':
                    has_list_prop = True
                vertex_props.append(tokens[-1])
            elif keyword == 'end_header':
                header_end = i + 1
                break

        if header_end is None:
            raise ValueError("PLY header has no end_header line")

        # Map x/y/z to their column positions; default to the first three.
        if not has_list_prop and all(axis in vertex_props for axis in ('x', 'y', 'z')):
            columns = [vertex_props.index(axis) for axis in ('x', 'y', 'z')]
        else:
            columns = [0, 1, 2]
        min_columns = max(columns) + 1

        # --- Parse the vertex rows ---
        points = []
        for line in lines[header_end:header_end + num_vertices]:
            values = line.split()
            if len(values) >= min_columns:
                points.append([float(values[c]) for c in columns])

        # reshape keeps the documented [N, 3] layout even when N == 0.
        return np.array(points, dtype=np.float32).reshape(-1, 3)

    except Exception as e:
        print(f"手动解析PLY文件失败: {e}")
        return None

def create_ply_from_point_cloud(point_labels, point_cloud, output_path,
                               foreground_color=(255, 0, 0), background_color=(128, 128, 128)):
    """
    Write an ASCII PLY file whose vertices are colored by their label.

    Points labelled 1 get ``foreground_color``; every other point gets
    ``background_color``. Nothing is written when the inputs are missing,
    empty, mismatched in length, or the coordinates are not [N, >=3].

    Args:
        point_labels: Per-point labels (0=background, 1=foreground); any
            array-like (list, NumPy array, CPU tensor) is accepted
        point_cloud: Original point cloud coordinates [N, 3]
        output_path: Output PLY file path
        foreground_color: Foreground point color [R, G, B] (default red)
        background_color: Background point color [R, G, B] (default gray)

    Returns:
        True when the PLY file was written, False when it was skipped.
    """
    if point_labels is None or len(point_labels) == 0:
        print("警告: 没有点云标签，跳过PLY文件生成")
        return False

    if point_cloud is None or len(point_cloud) == 0:
        print("警告: 没有点云坐标，跳过PLY文件生成")
        return False

    # Normalise to arrays so plain lists work with the element-wise
    # comparisons below; labels of shape [N, 1] are flattened to [N].
    labels = np.asarray(point_labels).reshape(-1)
    cloud = np.asarray(point_cloud)

    # Make sure the number of labels matches the number of points.
    if len(labels) != len(cloud):
        print(f"错误: 标签数量 ({len(labels)}) 与点云数量 ({len(cloud)}) 不匹配")
        return False

    if cloud.ndim != 2 or cloud.shape[1] < 3:
        print(f"错误: 点云坐标形状应为 [N, 3]，实际为 {cloud.shape}")
        return False

    # Format the colors before opening the file so a malformed color fails
    # without leaving a truncated PLY behind.
    fg_text = f"{foreground_color[0]} {foreground_color[1]} {foreground_color[2]}"
    bg_text = f"{background_color[0]} {background_color[1]} {background_color[2]}"

    is_foreground = labels == 1

    print(f"生成 {len(cloud)} 个点的PLY文件")
    print(f"前景点数量: {is_foreground.sum()}")
    print(f"背景点数量: {(labels == 0).sum()}")

    header = (
        "ply\n"
        "format ascii 1.0\n"
        f"element vertex {len(cloud)}\n"
        "property float x\n"
        "property float y\n"
        "property float z\n"
        "property uchar red\n"
        "property uchar green\n"
        "property uchar blue\n"
        "end_header\n"
    )

    # Write the PLY file directly (ASCII format).
    with open(output_path, 'w') as f:
        f.write(header)
        f.writelines(
            f"{row[0]:.6f} {row[1]:.6f} {row[2]:.6f} {fg_text if fg else bg_text}\n"
            for row, fg in zip(cloud, is_foreground)
        )

    print(f"PLY文件已保存到: {output_path}")
    return True

def _json_fallback(value):
    """``json.dump`` hook for values the encoder cannot handle natively.

    Array-likes exposing ``tolist()`` (NumPy arrays/scalars, tensors) become
    lists; anything else is stored as its string form so that writing the
    purely informational metadata never aborts a conversion.
    """
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    return str(value)


def process_single_pth(pth_file_path, model_path, output_dir, foreground_color, background_color, scene_name=None):
    """
    Convert a single PTH file into a colored PLY plus a metadata JSON file.

    Outputs are ``<basename>_converted.ply`` and ``<basename>_metadata.json``
    inside ``output_dir`` (created if missing).

    Args:
        pth_file_path: PTH file path
        model_path: Model path searched for the original PLY point cloud
        output_dir: Output directory
        foreground_color: Foreground point color
        background_color: Background point color
        scene_name: Scene name (optional)

    Returns:
        True when the file was converted, False when it was skipped because
        the labels or the point cloud could not be loaded, or because their
        sizes are empty or do not match.
    """
    print(f"\n处理文件: {pth_file_path}")

    # Load the PTH file
    point_labels, metadata = load_pth_file(pth_file_path)
    if point_labels is None:
        print(f"跳过文件 {pth_file_path} (加载失败)")
        return False

    # Load the point cloud
    point_cloud = load_point_cloud_from_ply(model_path, scene_name)
    if point_cloud is None:
        print(f"跳过文件 {pth_file_path} (无法加载点云)")
        return False

    # The PLY writer silently skips empty or mismatched inputs; check here so
    # such a file is reported as failed instead of being counted as converted.
    n_labels, n_points = len(point_labels), len(point_cloud)
    if n_labels == 0 or n_points == 0 or n_labels != n_points:
        print(f"跳过文件 {pth_file_path} (标签数量 {n_labels} 与点云数量 {n_points} 不匹配或为空)")
        return False

    os.makedirs(output_dir, exist_ok=True)

    # Build the output file names
    pth_basename = os.path.splitext(os.path.basename(pth_file_path))[0]
    output_path = os.path.join(output_dir, f"{pth_basename}_converted.ply")
    metadata_file = os.path.join(output_dir, f"{pth_basename}_metadata.json")

    # Create the PLY file
    create_ply_from_point_cloud(point_labels, point_cloud, output_path, foreground_color, background_color)

    # Save the metadata
    with open(metadata_file, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False, default=_json_fallback)

    print(f"元数据已保存到: {metadata_file}")
    return True

def _parse_rgb(text):
    """Parse an ``R,G,B`` string into a list of three ints.

    Raises:
        ValueError: if a component is not an integer, there are not exactly
            three components, or a component is outside 0..255 (the PLY
            color properties are written as ``uchar``).
    """
    channels = [int(part) for part in text.split(',')]
    if len(channels) != 3 or any(not 0 <= channel <= 255 for channel in channels):
        raise ValueError(f"invalid RGB color: {text!r}")
    return channels


def main():
    """
    Command-line entry point: convert every matching PTH file to a PLY file.

    Parses the arguments, converts each file in ``--input_dir`` matching
    ``--file_pattern`` (in sorted order), and writes ``conversion_summary.txt``
    to ``--output_dir`` listing converted and failed files separately.
    Returns early (after printing a message) on an invalid color or when no
    input file matches.
    """
    parser = ArgumentParser(description="Convert PTH files to PLY files using original point cloud")
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Directory containing PTH files")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to model directory containing PLY files")
    parser.add_argument("--output_dir", type=str, default="./converted_ply_files",
                        help="Output directory for converted PLY files")
    parser.add_argument("--foreground_color", type=str, default="255,0,0",
                        help="Foreground point color (R,G,B), default: red")
    parser.add_argument("--background_color", type=str, default="150,150,150",
                        help="Background point color (R,G,B), default: gray")
    parser.add_argument("--scene_name", type=str, default=None,
                        help="Scene name for PLY file (optional)")
    parser.add_argument("--file_pattern", type=str, default="*.pth",
                        help="File pattern to match PTH files, default: *.pth")

    args = parser.parse_args()

    # Parse and validate the color arguments before touching the file system.
    try:
        fg_color = _parse_rgb(args.foreground_color)
        bg_color = _parse_rgb(args.background_color)
    except ValueError:
        print("错误: 颜色格式应为 R,G,B (例如: 255,0,0)")
        return

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    scene_label = args.scene_name if args.scene_name else '自动检测'

    print("=" * 60)
    print("PTH到PLY转换器 (使用原始点云)")
    print("=" * 60)
    print(f"输入目录: {args.input_dir}")
    print(f"模型路径: {args.model_path}")
    print(f"输出目录: {args.output_dir}")
    print(f"文件模式: {args.file_pattern}")
    print(f"场景名称: {scene_label}")
    print(f"前景颜色: RGB{fg_color}")
    print(f"背景颜色: RGB{bg_color}")
    print()

    # Find all PTH files; sorted because glob returns them in arbitrary order.
    search_pattern = os.path.join(args.input_dir, args.file_pattern)
    pth_files = sorted(glob.glob(search_pattern))

    if not pth_files:
        print(f"在目录 {args.input_dir} 中未找到匹配 {args.file_pattern} 的文件")
        return

    print(f"找到 {len(pth_files)} 个PTH文件:")
    for f in pth_files:
        print(f"  {f}")
    print()

    # Process each PTH file, remembering which ones succeeded so the summary
    # does not present failed files as converted.
    converted, failed = [], []
    for pth_file in pth_files:
        ok = process_single_pth(pth_file, args.model_path, args.output_dir,
                                fg_color, bg_color, args.scene_name)
        (converted if ok else failed).append(pth_file)
    success_count = len(converted)

    # Write the summary file
    summary_file = os.path.join(args.output_dir, "conversion_summary.txt")
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("PTH到PLY转换总结 (使用原始点云)\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"输入目录: {args.input_dir}\n")
        f.write(f"模型路径: {args.model_path}\n")
        f.write(f"输出目录: {args.output_dir}\n")
        f.write(f"文件模式: {args.file_pattern}\n")
        f.write(f"场景名称: {scene_label}\n")
        f.write(f"总文件数: {len(pth_files)}\n")
        f.write(f"成功转换: {success_count}\n")
        f.write(f"失败数量: {len(failed)}\n\n")
        f.write(f"前景颜色: RGB{fg_color}\n")
        f.write(f"背景颜色: RGB{bg_color}\n\n")
        f.write("转换的文件:\n")
        for pth_file in converted:
            pth_basename = os.path.splitext(os.path.basename(pth_file))[0]
            f.write(f"  {pth_file} -> {pth_basename}_converted.ply\n")
        if failed:
            f.write("\n失败的文件:\n")
            for pth_file in failed:
                f.write(f"  {pth_file}\n")

    print("\n" + "=" * 60)
    print("转换完成！")
    print(f"总文件数: {len(pth_files)}")
    print(f"成功转换: {success_count}")
    print(f"失败数量: {len(failed)}")
    print(f"输出目录: {args.output_dir}")
    print(f"总结文件: {summary_file}")
    print("=" * 60)

# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
