"""
SDF Mesh Generation utilities - adapted from SIREN
"""
import logging
import os
import time
import numpy as np
import plyfile
import skimage.measure
import torch


def create_mesh(decoder, filename=None, N=256, max_batch=64**3, level=0):
    """
    Create a mesh from an SDF decoder by sampling it on a dense grid.

    The decoder is evaluated on an N x N x N grid spanning [-0.5, 0.5] on
    every axis. The resulting SDF volume is averaged with its mirror image
    along axis 1 (y), and the ``level`` iso-surface is then extracted with
    ``convert_sdf_samples_to_ply``.

    Args:
        decoder: torch module called on a (B, 3) batch of coordinates; its
            output must squeeze to B SDF values. It is put into eval mode and
            run on the device of its first parameter.
        filename: Output path without the ``.ply`` extension, or None to skip
            writing a file. A missing parent directory is created.
        N: Grid resolution per axis (N^3 samples). N must be at least 2,
            because the voxel size divides by N - 1.
        max_batch: Maximum number of points passed to the decoder per call.
        level: Iso-surface level to extract (0 for the surface).

    Returns:
        Whatever ``convert_sdf_samples_to_ply`` returns for the sampled volume.
    """
    start = time.time()
    ply_filename = filename
    if filename is not None:
        # os.path.dirname("mesh") == "" and os.makedirs("") raises, so only
        # create a directory when the path actually contains one.
        parent_dir = os.path.dirname(filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    decoder.eval()

    # Run inference on the same device as the decoder's parameters.
    device = next(decoder.parameters()).device

    # The grid spans [-0.5, 0.5] on every axis with N samples per axis.
    voxel_origin = [-0.5] * 3
    voxel_size = -2 * voxel_origin[0] / (N - 1)

    num_samples = N**3
    overall_index = torch.arange(0, num_samples, 1, dtype=torch.long)
    # Columns 0-2 hold xyz coordinates; column 3 receives the SDF value.
    samples = torch.zeros(num_samples, 4)

    # Decompose the flat index as index = i0 * N^2 + i1 * N + i2 so that the
    # reshape(N, N, N) below lines up with (x, y, z). Floor division is
    # required: "/" is true division in modern PyTorch and would produce
    # fractional grid indices.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode="floor") % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode="floor") % N

    # Convert grid indices to world coordinates.
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[0]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[2]

    # Evaluate the decoder in batches to bound memory use.
    head = 0
    print(f"Generating mesh with {N}^3 = {num_samples:,} samples...")

    while head < num_samples:
        tail = min(head + max_batch, num_samples)
        sample_subset = samples[head:tail, 0:3].to(device)

        with torch.no_grad():
            # squeeze() turns (B, 1) into (B,), but a one-sample batch
            # collapses to a 0-d tensor, so restore the batch dimension.
            sdf_vals = decoder(sample_subset).squeeze()
            if sdf_vals.dim() == 0:
                sdf_vals = sdf_vals.unsqueeze(0)
            samples[head:tail, 3] = sdf_vals.detach().cpu()

        head += max_batch
        if head % (max_batch * 10) == 0:
            print(f"  Progress: {head:,}/{num_samples:,} ({100*head/num_samples:.1f}%)")

    sdf_values = samples[:, 3].reshape(N, N, N)

    # Enforce mirror symmetry along the y axis (axis 1) by averaging the
    # volume with its reflection (typical for cars).
    sdf_values = (sdf_values + torch.flip(sdf_values, [1])) / 2.0

    end = time.time()
    print(f"SDF sampling took: {end - start:.2f}s")

    return convert_sdf_samples_to_ply(
        sdf_values.data.cpu(),
        voxel_origin,
        voxel_size,
        None if ply_filename is None else ply_filename + ".ply",
        level,
    )


def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    level
):
    """Convert an SDF volume to a triangle mesh and optionally save it as PLY.

    Args:
        pytorch_3d_sdf_tensor: torch tensor of SDF samples on a regular 3-D
            grid. It is detached and copied to the CPU before conversion, so
            GPU tensors and tensors that require grad are accepted.
        voxel_grid_origin: World coordinates (x, y, z) that grid index
            (0, 0, 0) maps to.
        voxel_size: Edge length of one voxel; used as the marching-cubes
            spacing on every axis.
        ply_filename_out: Output PLY path, or None to skip writing a file.
        level: Iso-surface level passed to marching cubes.

    Returns:
        Tuple ``(mesh_points, faces, pytorch_3d_sdf_tensor)``:
        - ``mesh_points``: vertex positions in world coordinates, shape (V, 3).
        - ``faces``: triangle vertex indices, shape (F, 3).
        - ``pytorch_3d_sdf_tensor``: the input tensor, unchanged.

        If marching cubes raises, both arrays are empty with shape (0, 3).
        In that case, when ``ply_filename_out`` is given, a placeholder PLY
        with zero vertices and faces is written.
    """
    start_time = time.time()
    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.detach().cpu().numpy()

    try:
        verts, faces, _normals, _values = skimage.measure.marching_cubes(
            numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3
        )
    except Exception as e:
        # Deliberate best effort: any marching-cubes failure (for example,
        # no surface crossing ``level``) is reported and an empty mesh is
        # returned instead of propagating the error to the caller.
        print(f"Error in marching cubes: {e}")
        if ply_filename_out:
            with open(ply_filename_out, 'w') as f:
                f.write("ply\nformat ascii 1.0\nelement vertex 0\nelement face 0\nend_header\n")
        # Faces are integer indices on the success path; keep the dtype kind.
        return np.zeros((0, 3)), np.zeros((0, 3), dtype=np.int32), pytorch_3d_sdf_tensor
    print(f"Generated mesh: {len(verts)} vertices, {len(faces)} faces")

    # marching_cubes already scaled the vertices by ``spacing``; shift them
    # so that grid index (0, 0, 0) lands on ``voxel_grid_origin``.
    mesh_points = verts + np.asarray(voxel_grid_origin, dtype=verts.dtype)

    if ply_filename_out is not None:
        # Fill the structured arrays column-wise instead of one row at a time.
        verts_tuple = np.empty(
            len(mesh_points), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]
        )
        verts_tuple["x"] = mesh_points[:, 0]
        verts_tuple["y"] = mesh_points[:, 1]
        verts_tuple["z"] = mesh_points[:, 2]

        faces_tuple = np.empty(len(faces), dtype=[("vertex_indices", "i4", (3,))])
        faces_tuple["vertex_indices"] = faces

        el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
        el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

        ply_data = plyfile.PlyData([el_verts, el_faces])
        ply_data.write(ply_filename_out)
        print(f"Saved mesh: {ply_filename_out}")

    print(f"Mesh conversion took: {time.time() - start_time:.2f}s")
    return mesh_points, faces, pytorch_3d_sdf_tensor
