import torch
from typing import Tuple, List
import trimesh
import numpy as np
import open3d as o3d
import torch.nn.functional as F
import torch.nn as nn
import os
from collections import OrderedDict
from scipy.spatial import cKDTree

def get_mesh_query_points(
    faces: torch.Tensor,
    vertices: torch.Tensor,
    n_samples: int = 1024,
    noise_std: float = 0.1,
    distance_thresh: float = 1. / 256.,
    device='cpu',
    quantize_bits: int = 10,
    using_nerf: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a mix of on-surface and near-surface points from a mesh and label them.

    ``n_samples`` points are drawn on the mesh surface. The first ``n_samples // 2``
    are kept as-is. The rest are jittered with isotropic Gaussian noise of std
    ``noise_std``. A point is labelled 1 when its unsigned distance to the surface is
    below ``distance_thresh``, and 0 otherwise.

    Labels are computed from the coordinates that are actually returned:
    - using_nerf=True: the raw float query points.
    - using_nerf=False: the points are quantized to ``quantize_bits`` bits per axis,
      relative to the bounding box of this batch of query points. Labels are measured
      on the dequantized float coordinates.

    NOTE: the quantization grid is fitted to the query points of each call, so grid
    indices from different calls are not in a common frame.

    Args:
        faces: [F, 3] face indices (tensor or array-like).
        vertices: [V, 3] vertex positions (tensor or array-like).
        n_samples: Total number of query points.
        noise_std: Std of the Gaussian jitter for near-surface points.
        distance_thresh: Distance below which a point is labelled 1.
        device: Device for the returned tensors.
        quantize_bits: Bits per axis when ``using_nerf`` is False.
        using_nerf: Return float coordinates instead of quantized ones.

    Returns:
        (coords, labels):
        - coords: [n_samples, 3] float32 points if ``using_nerf``, else int64 indices in
          [0, 2**quantize_bits - 1].
        - labels: [n_samples] int64 tensor of 0/1.
    """
    # Accept tensors (any device, with or without grad) as well as numpy arrays.
    verts_np = vertices.detach().cpu().numpy() if isinstance(vertices, torch.Tensor) else np.asarray(vertices)
    faces_np = faces.detach().cpu().numpy() if isinstance(faces, torch.Tensor) else np.asarray(faces)
    mesh = trimesh.Trimesh(vertices=verts_np.astype(np.float32), faces=faces_np)

    all_pts, _ = trimesh.sample.sample_surface(mesh, n_samples)
    all_pts = torch.tensor(all_pts, dtype=torch.float32, device=device)

    n_surface = n_samples // 2
    surface_pts = all_pts[:n_surface]
    near_surface_pts = all_pts[n_surface:]
    near_surface_pts = near_surface_pts + torch.randn_like(near_surface_pts) * noise_std

    query_pts = torch.cat([surface_pts, near_surface_pts], dim=0)

    if using_nerf:
        dist_coords = query_pts
    else:
        levels = 2 ** quantize_bits - 1
        min_vals = query_pts.min(dim=0, keepdim=True)[0]
        max_vals = query_pts.max(dim=0, keepdim=True)[0]
        range_vals = (max_vals - min_vals).clamp(min=1e-6)
        normalized = (query_pts - min_vals) / range_vals
        # Round to the nearest level. Truncating with .long() biased every coordinate
        # down by up to one step and let only the per-axis maximum reach the top level.
        quantized = (normalized * levels).round().long()
        # Labels are measured on the dequantized floats so they match the returned grid.
        dist_coords = quantized.float() / levels * (max_vals - min_vals) + min_vals

    _, dists, _ = mesh.nearest.on_surface(dist_coords.cpu().numpy().astype(np.float32))
    dists = torch.tensor(dists, dtype=torch.float32, device=device)
    labels = (dists < distance_thresh).long()

    if using_nerf:
        return query_pts, labels
    return quantized, labels

def get_mesh_query_points_udf(
    faces: torch.Tensor,
    vertices: torch.Tensor,
    n_samples: int = 1024,
    noise_std: float = 0.1,
    udf_max: float = 1. / 256., 
    device='cpu',
    quantize_bits: int = 10,
    using_nerf: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a mix of on-surface and near-surface points and return truncated UDF labels.

    ``n_samples`` points are drawn on the mesh surface. The first ``n_samples // 2``
    are kept as-is. The rest are jittered with isotropic Gaussian noise of std
    ``noise_std``. Each label is the unsigned distance to the surface, clamped at
    ``udf_max``:
        [0, udf_max] for points within udf_max
        udf_max for points farther than udf_max

    Distances are measured on the raw float points (``using_nerf=True``). Otherwise
    they are measured on the dequantized coordinates of the ``quantize_bits``-bit grid,
    which is fitted to the bounding box of this batch of query points.

    Args:
        faces: [F, 3] face indices (tensor or array-like).
        vertices: [V, 3] vertex positions (tensor or array-like).
        n_samples: Total number of query points.
        noise_std: Std of the Gaussian jitter for near-surface points.
        udf_max: Truncation value for the distance labels.
        device: Device for the returned tensors.
        quantize_bits: Bits per axis when ``using_nerf`` is False.
        using_nerf: Return float coordinates instead of quantized ones.

    Returns:
        (coords, udf_labels):
        - coords: [n_samples, 3] float32 points if ``using_nerf``, else int64 indices in
          [0, 2**quantize_bits - 1].
        - udf_labels: [n_samples] float32 distances clamped to at most ``udf_max``.
    """
    # trimesh needs host numpy data: passing a CUDA tensor or one that requires grad
    # directly fails, so convert explicitly (numpy inputs are passed through).
    verts_np = vertices.detach().cpu().numpy() if isinstance(vertices, torch.Tensor) else np.asarray(vertices)
    faces_np = faces.detach().cpu().numpy() if isinstance(faces, torch.Tensor) else np.asarray(faces)
    mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np)

    all_pts, _ = trimesh.sample.sample_surface(mesh, n_samples)
    all_pts = torch.tensor(all_pts, dtype=torch.float32, device=device)

    n_surface = n_samples // 2
    surface_pts = all_pts[:n_surface]
    near_surface_pts = all_pts[n_surface:]
    near_surface_pts = near_surface_pts + torch.randn_like(near_surface_pts) * noise_std

    query_pts = torch.cat([surface_pts, near_surface_pts], dim=0)

    if using_nerf:
        dist_coords = query_pts
    else:
        levels = 2 ** quantize_bits - 1
        min_vals = query_pts.min(dim=0, keepdim=True)[0]
        max_vals = query_pts.max(dim=0, keepdim=True)[0]
        range_vals = (max_vals - min_vals).clamp(min=1e-6)
        normalized = (query_pts - min_vals) / range_vals
        # Round to the nearest level (truncation biased coordinates down by up to one step).
        quantized = (normalized * levels).round().long()
        dist_coords = quantized.float() / levels * (max_vals - min_vals) + min_vals
        final_coords = torch.clamp(quantized, 0, levels)

    _, dists, _ = mesh.nearest.on_surface(dist_coords.cpu().numpy().astype(np.float32))
    dists = torch.tensor(dists, dtype=torch.float32, device=device)

    udf_labels = torch.clamp(dists, max=udf_max)

    if using_nerf:
        return query_pts, udf_labels
    return final_coords, udf_labels


def quantize_vertices(vertices: torch.Tensor, res: int, return_transform: bool = False):
    """
    Map float vertices of any extent onto integer voxel indices in [0, res - 1].

    The vertices are centred on their axis-aligned bounding-box centre and divided by
    the longest box side (plus 1e-8). This places them in [-0.5, 0.5] while keeping
    the aspect ratio. They are then shifted to [0, 1], scaled by ``res``, floored and
    clamped.

    Args:
        vertices (torch.Tensor): Float coordinates of shape [N, 3].
        res (int): Grid resolution per axis.
        return_transform (bool): When True, also return the normalization parameters.

    Returns:
        torch.Tensor: int32 indices of shape [N, 3] with values in [0, res - 1].
        (optional) dict: ``{'center': torch.Tensor, 'scale': float}`` such that
                         ``(vertices - center) / scale`` is the [-0.5, 0.5] normalization.
    """
    lo = vertices.min(dim=0, keepdim=True).values
    hi = vertices.max(dim=0, keepdim=True).values

    center = (lo + hi) / 2.0
    # The epsilon keeps a degenerate (single-point) input from dividing by zero.
    scale = (hi - lo).max() + 1e-8

    unit = (vertices - center) / scale
    grid = ((unit + 0.5) * res).floor()
    quantized = grid.clamp(0, res - 1).int()

    if not return_transform:
        return quantized
    return quantized, {'center': center.squeeze(0), 'scale': scale.item()}


import numpy as np
from numba import njit

@njit(cache=True)
def bresenham_3d_array(p1, p2):
    """
    Walk an integer 3D Bresenham line between two points (numba-compiled).

    Both endpoints are first rounded to the nearest integer with ``np.round``
    (NumPy rounds exact halves to even). The axis with the largest absolute
    extent drives the walk and advances by one voxel per step. Each of the two
    other axes advances whenever its error term is positive.

    Args:
        p1: start point with three coordinates. Presumably a 1-D numpy array
            (float or int); TODO confirm at call sites.
        p2: end point, same layout as ``p1``.

    Returns:
        int64 array of shape [max(dx, dy, dz) + 1, 3] holding the visited voxel
        coordinates in traversal order. It starts at the rounded ``p1`` and ends
        at the rounded ``p2``.
    """
    # Snap both endpoints to integer voxel coordinates.
    x1, y1, z1 = np.round(p1).astype(np.int64)
    x2, y2, z2 = np.round(p2).astype(np.int64)

    # Per-axis extents and step directions. When two coordinates are equal the step is
    # -1 but the extent is 0, so that axis's error term never turns positive and it never moves.
    dx, dy, dz = abs(x2 - x1), abs(y2 - y1), abs(z2 - z1)
    xs, ys, zs = (1 if x2 > x1 else -1), (1 if y2 > y1 else -1), (1 if z2 > z1 else -1)

    # One voxel per step along the driving axis, both endpoints included.
    npts = max(dx, dy, dz) + 1
    points = np.zeros((npts, 3), dtype=np.int64)

    i = 0
    if dx >= dy and dx >= dz:
        # x drives. Each iteration records the current voxel first, then advances;
        # the updates made after the final record are discarded.
        err_1, err_2 = 2*dy - dx, 2*dz - dx
        for _ in range(dx+1):
            points[i, 0], points[i, 1], points[i, 2] = x1, y1, z1
            i += 1
            if err_1 > 0:
                y1 += ys; err_1 -= 2*dx
            if err_2 > 0:
                z1 += zs; err_2 -= 2*dx
            err_1 += 2*dy
            err_2 += 2*dz
            x1 += xs
    elif dy >= dx and dy >= dz:  # y drives; x and z follow via their error terms.
        err_1, err_2 = 2*dx - dy, 2*dz - dy
        for _ in range(dy+1):
            points[i, 0], points[i, 1], points[i, 2] = x1, y1, z1
            i += 1
            if err_1 > 0:
                x1 += xs; err_1 -= 2*dy
            if err_2 > 0:
                z1 += zs; err_2 -= 2*dy
            err_1 += 2*dx
            err_2 += 2*dz
            y1 += ys
    else:  # z drives; x and y follow via their error terms.
        err_1, err_2 = 2*dx - dz, 2*dy - dz
        for _ in range(dz+1):
            points[i, 0], points[i, 1], points[i, 2] = x1, y1, z1
            i += 1
            if err_1 > 0:
                x1 += xs; err_1 -= 2*dz
            if err_2 > 0:
                y1 += ys; err_2 -= 2*dz
            err_1 += 2*dx
            err_2 += 2*dy
            z1 += zs

    # Every branch runs exactly npts iterations, so i == npts here; the slice is defensive.
    return points[:i]

def get_voxel_lines_batched(p1s, p2s):
    """
    Rasterize many line segments into voxels at once by dense sampling.

    Segment ``e`` runs from ``p1s[e]`` to ``p2s[e]``. It is sampled at
    ``int(2 * length) + 1`` evenly spaced points (endpoints included), so consecutive
    samples are at most half a voxel apart. Each sample is assigned to voxel
    ``floor(sample)``, whose center is ``voxel + 0.5``. Duplicate (voxel, segment)
    pairs are removed, so a voxel touched by several segments appears once per
    segment. Because this is point sampling, a voxel that a segment only clips
    briefly (e.g. near a corner) may be missed.

    Args:
        p1s: Segment start points, float tensor of shape [E, 3].
        p2s: Segment end points, float tensor of shape [E, 3].

    Returns:
        Three equally long lists, one entry per unique (voxel, segment) pair, ordered by
        ``torch.unique``'s lexicographic sort of (x, y, z, segment index):
        - voxel coordinates as tuples of three numpy integers,
        - (p1, p2) endpoints of the owning segment, as CPU tensors,
        - error vectors (voxel center minus the closest point on the segment), as numpy arrays.
        All three lists are empty when no segments are given.
    """
    num_edges = p1s.shape[0]
    if num_edges == 0:
        return [], [], []

    device = p1s.device
    diffs = p2s - p1s
    lengths = torch.norm(diffs, dim=1)
    lengths_sq = (diffs ** 2).sum(dim=1)

    # ~2 samples per unit length; at least 1 sample (the start point) for degenerate segments.
    num_steps = (lengths * 2).long() + 1
    total_samples = int(num_steps.sum().item())

    # Flattened sample index -> owning segment, and position of the sample within it.
    edge_ids = torch.repeat_interleave(torch.arange(num_edges, device=device), num_steps)
    first_sample = torch.cumsum(num_steps, dim=0) - num_steps
    local_ids = torch.arange(total_samples, device=device) - first_sample[edge_ids]

    # t in [0, 1]; the clamp keeps single-sample segments at t = 0.
    t_vals = local_ids.float() / torch.clamp(num_steps[edge_ids].float() - 1, min=1.0)
    samples = p1s[edge_ids] + t_vals.unsqueeze(1) * diffs[edge_ids]

    voxel_ids = torch.floor(samples).long()
    pairs = torch.unique(torch.cat([voxel_ids, edge_ids.unsqueeze(1)], dim=1), dim=0, sorted=True)
    voxels = pairs[:, :3]
    owners = pairs[:, 3]

    # Closest point on each owning segment to the voxel center.
    centers = voxels.float() + 0.5
    seg_start = p1s[owners]
    seg_dir = diffs[owners]
    t_proj = ((centers - seg_start) * seg_dir).sum(dim=1) / torch.clamp(lengths_sq[owners], min=1e-12)
    closest = seg_start + torch.clamp(t_proj, 0.0, 1.0).unsqueeze(1) * seg_dir
    errors_np = (centers - closest).cpu().numpy()

    voxel_list = [tuple(v) for v in voxels.cpu().numpy()]
    endpoint_list = list(zip(seg_start.cpu(), p2s[owners].cpu()))
    error_list = list(errors_np)
    return voxel_list, endpoint_list, error_list


def get_voxel_line(
    p1: torch.Tensor, 
    p2: torch.Tensor,
    mode: str = 'cpu' 
) -> Tuple[
    List[Tuple[int, int, int]], 
    List[Tuple[torch.Tensor, torch.Tensor]], 
    List[np.ndarray]
]:
    """
    Voxelize a single segment p1 -> p2 and measure how far each voxel center lies from it.

    Both modes use the same voxel convention: index ``i`` covers ``[i, i + 1)`` and has
    center ``i + 0.5``.

    - mode='cuda': runs with torch on whatever device ``p1`` lives on (despite the name).
      The segment is sampled at ``int(2 * length) + 1`` evenly spaced points (endpoints
      included) and the unique ``floor`` voxels are kept, in lexicographically sorted order.
    - any other mode: integer 3D Bresenham walk (``bresenham_3d_array``) between the
      floored endpoints, in traversal order.

    Args:
        p1: Start point, tensor of shape [3].
        p2: End point, tensor of shape [3].
        mode: 'cuda' for the sampling path, anything else for Bresenham.

    Returns:
        (voxel_coords, endpoint_pairs, error_vectors), equally long lists:
        - voxel_coords: one tuple of three numpy integers per voxel.
        - endpoint_pairs: the input ``(p1, p2)`` tensors, repeated once per voxel.
        - error_vectors: voxel center minus the closest point on the segment (numpy arrays).
    """
    if mode == 'cuda':
        diff = p2 - p1
        length = torch.norm(diff)
        length_sq = torch.dot(diff, diff)

        # ~2 samples per unit length keeps consecutive samples <= 0.5 voxel apart.
        num_steps = int(length.item() * 2) + 1
        t_vals = torch.linspace(0, 1, steps=num_steps, device=p1.device)
        samples = p1.unsqueeze(0) + t_vals.unsqueeze(1) * diff.unsqueeze(0)

        # With dim=0, torch.unique returns the rows sorted lexicographically.
        voxels = torch.unique(torch.floor(samples).long(), dim=0)
        if voxels.shape[0] == 0:
            return [], [], []

        centers = voxels.float() + 0.5
        if length_sq < 1e-12:
            # Degenerate segment: every voxel's closest point is p1 itself.
            closest = p1.unsqueeze(0).expand(voxels.shape[0], -1)
        else:
            t_proj = torch.clamp(torch.matmul(centers - p1, diff) / length_sq, 0.0, 1.0)
            closest = p1 + t_proj.unsqueeze(1) * diff

        voxels_np = voxels.cpu().numpy()
        errors_np = (centers - closest).cpu().numpy()
    else:
        p1_np = p1.cpu().numpy()
        p2_np = p2.cpu().numpy()

        # Snap endpoints with floor (not bresenham's internal round) so the chosen voxel
        # matches the [i, i + 1) / center i + 0.5 convention used for the error vectors
        # below and by the 'cuda' path. Rounding put fractional endpoints such as 0.7 into
        # voxel 1 while measuring errors against center 1.5.
        voxels_np = bresenham_3d_array(np.floor(p1_np), np.floor(p2_np))
        if voxels_np.shape[0] == 0:
            return [], [], []

        centers = voxels_np.astype(np.float64) + 0.5
        diff = p2_np - p1_np
        length_sq = np.dot(diff, diff)
        if length_sq < 1e-12:
            closest = np.tile(p1_np, (voxels_np.shape[0], 1))
        else:
            t = np.clip(np.dot(centers - p1_np, diff) / length_sq, 0.0, 1.0)
            closest = p1_np + t[:, np.newaxis] * diff
        errors_np = centers - closest

    voxel_coords = [tuple(coord) for coord in voxels_np]
    endpoint_pairs = [(p1, p2)] * len(voxel_coords)
    error_vectors = list(errors_np)
    return voxel_coords, endpoint_pairs, error_vectors


class AdaptiveFocalLoss(nn.Module):
    """
    Binary focal loss whose weight ``alpha`` is recomputed from each batch's label balance.

    ``alpha = min(#negatives / (#positives + 1e-6), max_alpha)`` is counted over all
    elements of ``targets``. The returned loss is ``mean(alpha * (1 - p_t) ** gamma * BCE)``,
    with ``p_t = exp(-BCE)``.

    NOTE(review): ``alpha`` multiplies every element (positives and negatives alike), so it
    rescales the whole loss rather than re-weighting the positive class. Confirm this is
    intended.
    """
    def __init__(self, gamma=2.0, max_alpha=10.0):
        """
        Args:
            gamma: Focusing exponent applied to (1 - p_t).
            max_alpha: Upper bound on the adaptive weight (reached e.g. with no positives).
        """
        super().__init__()
        self.gamma = gamma
        self.max_alpha = max_alpha

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Raw logits, any shape.
            targets: Float 0/1 tensor with the same shape as ``inputs``.

        Returns:
            Scalar loss tensor.
        """
        pos = targets.sum()
        # numel(), not len(): len() only counts the first dimension, which made `neg`
        # wrong (even negative) for targets with more than one dimension.
        neg = targets.numel() - pos
        alpha = torch.clamp(neg / (pos + 1e-6), max=self.max_alpha)
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce)
        return (alpha * (1 - pt) ** self.gamma * bce).mean()
    
class DiceLoss(nn.Module):
    """
    Soft Dice loss on logits for binary (0/1) targets.

    Logits go through a sigmoid and both tensors are flattened together. The loss is
    ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``. Because the whole
    batch is flattened at once, this is a single global Dice score, not a per-sample
    average.
    """
    def __init__(self, smooth=1.0):
        """
        Initializes the DiceLoss.

        Args:
            smooth (float): Added to numerator and denominator to prevent division by zero.
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """
        Computes the Dice Loss.

        Args:
            inputs (torch.Tensor): Raw logits, e.g. [B, C, H, W, D] or [B, N].
            targets (torch.Tensor): Ground truth of 0s and 1s with the same number of elements.

        Returns:
            torch.Tensor: Scalar ``1 - dice_score``.
        """
        # reshape (not view): view() raises on non-contiguous tensors such as
        # permuted/transposed network outputs or sliced targets.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        total = probs.sum() + targets.sum()

        dice_score = (2. * intersection + self.smooth) / (total + self.smooth)
        return 1 - dice_score

class AsymmetricFocalLoss(nn.Module):
    ''' Asymmetric focal loss on logits (memory-lean variant: reuses attributes and
    favours in-place operations).

    With p = sigmoid(x):
      positive term: y * log(p) * (1 - p) ** gamma_pos
      negative term: (1 - y) * log(1 - p_m) * p_m ** gamma_neg, where p_m = max(p - clip, 0)
    The negated sum is divided by the number of positives (sum of y).
    '''
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        """
        Args:
            gamma_neg: Focusing exponent for negative targets.
            gamma_pos: Focusing exponent for positive targets.
            clip: Probability margin subtracted for negatives (None or <= 0 disables it).
            eps: Lower clamp before taking logs.
            disable_torch_grad_focal_loss: Compute the focusing weight without autograd.
        """
        super(AsymmetricFocalLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # Intermediate tensors are stored on the module and overwritten on every call.
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector), same shape as x

        Returns
        -------
        Scalar loss normalized by the number of positives. For a batch without positives
        the value is 0, but it stays connected to the graph so ``.backward()`` works.
        """
        self.targets = y
        self.anti_targets = 1 - y

        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Probability shifting for negatives: 1 - p becomes min(1 - p + clip, 1).
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Use set_grad_enabled as a context manager so the caller's grad mode is
            # restored afterwards. Forcing it back to True unconditionally re-enabled
            # autograd inside a caller's torch.no_grad() block.
            grad_mode = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(grad_mode):
                self.xs_pos = self.xs_pos * self.targets
                self.xs_neg = self.xs_neg * self.anti_targets
                self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            self.loss *= self.asymmetric_w

        final_loss = -self.loss.sum()

        num_positives = y.sum()
        if num_positives == 0:
            # A fresh torch.tensor(0.0) has no grad_fn and made .backward() crash on
            # all-negative batches; multiplying by zero keeps the value 0 and the graph intact.
            return final_loss * 0.0
        return final_loss / num_positives

class FocalLoss(nn.Module):
    """
    Binary focal loss on logits.

    Per element: ``alpha_t * max(1 - p_t, 1e-6) ** gamma * BCE``. Here ``p_t`` is the
    predicted probability of the true class, and ``alpha_t`` is ``alpha`` where the target
    equals 1 and ``1 - alpha`` elsewhere.
    """
    def __init__(self, gamma: float = 2.0, alpha: float = 0.25, reduction: str = 'mean'):
        """
        Args:
            gamma: Focusing exponent applied to (1 - p_t).
            alpha: Weight for positive targets; other targets get 1 - alpha.
            reduction: 'mean' or 'sum'; any other value returns the per-element loss.
        """
        super().__init__()
        self.gamma, self.alpha, self.reduction = gamma, alpha, reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """
        Args:
            inputs: Raw logits [N].
            targets: Binary ground truth [N] (float).

        Returns:
            The reduced loss, or the [N] per-element loss for unrecognized reductions.
        """
        p = torch.sigmoid(inputs)
        is_pos = targets == 1
        p_t = torch.where(is_pos, p, 1 - p)
        a_t = torch.where(is_pos, self.alpha, 1 - self.alpha)

        ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # The clamp keeps the modulating factor away from exactly zero.
        modulating = (1 - p_t).clamp(min=1e-6) ** self.gamma
        per_elem = a_t * modulating * ce

        if self.reduction == 'sum':
            return per_elem.sum()
        if self.reduction == 'mean':
            return per_elem.mean()
        return per_elem


def load_pretrained(checkpoint_path, vertex_encoder, vae, query_decoder=None, edge_predictor=None, optimizer=None):
    """
    Restore model (and optionally optimizer) state from a checkpoint, tolerating architecture drift.

    Each component is read from its own checkpoint key: vertex_encoder <- 'encoder',
    vae <- 'vae', query_decoder <- 'decoder', edge_predictor <- 'edge_predictor',
    optimizer <- 'optimizer'. Tensors whose name is absent from the model, or whose shape
    differs, are skipped with a log line. The remainder is loaded with ``strict=False``.
    Tensors are mapped onto the device of ``vertex_encoder``'s first parameter.

    Args:
        checkpoint_path (str): Path to the .pt or .pth checkpoint file.
        vertex_encoder (torch.nn.Module): The vertex encoder model (always loaded).
        vae (torch.nn.Module): The VAE model (always loaded).
        query_decoder (torch.nn.Module, optional): Loaded only when not None.
        edge_predictor (torch.nn.Module, optional): Loaded only when not None.
        optimizer (torch.optim.Optimizer, optional): Loaded when given and present in the
            checkpoint; a ValueError from ``load_state_dict`` is reported, not raised.

    Returns:
        dict: ``{'epoch': ..., 'best_loss': ...}`` taken from the checkpoint (defaults 0 and
        inf). The same defaults are returned when the file does not exist.
    """
    if not os.path.exists(checkpoint_path):
        print(f"[INFO] Checkpoint not found at '{checkpoint_path}'. Models will start from scratch.")
        return {'epoch': 0, 'best_loss': float('inf')}

    print(f"Loading pretrained models from: {os.path.basename(checkpoint_path)}")
    target_device = next(vertex_encoder.parameters()).device
    ckpt = torch.load(checkpoint_path, map_location=target_device)

    def _examples(keys):
        # Show at most three key names, with an ellipsis when there are more.
        tail = '...' if len(keys) > 3 else ''
        return f"       Examples: {', '.join(keys[:3])}{tail}"

    def _load_and_log(model, model_name, state_dict):
        # Load only name- and shape-compatible tensors, then summarize the outcome.
        if state_dict is None:
            print(f"[WARN] No state found for '{model_name}' in checkpoint. It will remain initialized from scratch.")
            return

        model_state = model.state_dict()
        compatible = {}
        for key, value in state_dict.items():
            if key not in model_state or value.shape != model_state[key].shape:
                print(f"[INFO] Skip loading {key}: checkpoint {tuple(value.shape)} "
                      f"!= model {tuple(model_state.get(key, torch.empty(0)).shape)}")
                continue
            compatible[key] = value

        missing, unexpected = model.load_state_dict(compatible, strict=False)
        print(f"--- Loading status for '{model_name}' ---")
        if not (missing or unexpected):
            print("Success: All weights loaded perfectly.")
            return
        if missing:
            print(f"[INFO] {len(missing)} keys were missing in the checkpoint (new layers in your model):")
            print(_examples(missing))
        if unexpected:
            print(f"[WARN] {len(unexpected)} keys from the checkpoint were not found in the model (old layers):")
            print(_examples(unexpected))

    _load_and_log(vertex_encoder, 'VertexEncoder', ckpt.get('encoder'))
    _load_and_log(vae, 'VoxelVAE', ckpt.get('vae'))
    optional_parts = (
        (query_decoder, 'QueryPointDecoder', 'decoder'),
        (edge_predictor, 'EdgePredictor', 'edge_predictor'),
    )
    for module, label, key in optional_parts:
        if module is not None:
            _load_and_log(module, label, ckpt.get(key))

    if optimizer is not None and 'optimizer' in ckpt:
        try:
            optimizer.load_state_dict(ckpt['optimizer'])
        except ValueError as e:
            print(f"\n[WARN] Could not load optimizer state. It will be reset. Error: {e}")
        else:
            print("\n[INFO] Successfully loaded optimizer state.")

    print(f"\nSuccessfully processed checkpoint. Original epoch was {ckpt.get('epoch', 'unknown')}.")
    return {
        'epoch': ckpt.get('epoch', 0),
        'best_loss': ckpt.get('best_loss', float('inf')),
    }

def load_pretrained_woself(
    checkpoint_path,
    vae,
    vertex_encoder=None,
    voxel_encoder=None,
    edge_encoder=None,
    query_decoder=None,
    active_encoder=None,
    connection_head=None,
    optimizer=None,
    ema_model=None,
):
    """
    Restore a checkpoint into whichever components are supplied, skipping tensors whose
    name or shape no longer matches the model.

    Components and the checkpoint keys they are read from, in loading order:
    vae <- 'vae', query_decoder <- 'query_decoder', edge_encoder <- 'edge_encoder',
    active_encoder <- 'active_encoder', connection_head <- 'connection_head',
    voxel_encoder <- 'voxel_encoder'. ``vae`` is always loaded; the others only when not
    None. Tensors are mapped onto the device of ``vae``'s first parameter.

    NOTE(review): ``vertex_encoder`` is accepted but never loaded by this function.
    Confirm that is intended.

    Args:
        checkpoint_path (str): Path to the .pt or .pth checkpoint file.
        vae (torch.nn.Module): The VAE model.
        vertex_encoder (torch.nn.Module, optional): Accepted but unused (see note).
        voxel_encoder, edge_encoder, query_decoder, active_encoder, connection_head
            (torch.nn.Module, optional): Additional components to restore.
        optimizer (torch.optim.Optimizer, optional): Loaded from 'optimizer' when present;
            a ValueError from ``load_state_dict`` is reported instead of raised.
        ema_model (torch.nn.Module, optional): Loaded strictly from 'ema_state_dict' and
            moved to the vae's device; any exception is reported instead of raised.

    Returns:
        dict: ``{'epoch': ..., 'best_loss': ...}``. 'best_loss' falls back to the
        checkpoint's 'loss', then inf. ``{'epoch': 0, 'best_loss': inf}`` is returned when
        the file is missing or cannot be loaded.
    """
    if not os.path.exists(checkpoint_path):
        print(f"[INFO] Checkpoint not found at '{checkpoint_path}'. Models will start from scratch.")
        return {'epoch': 0, 'best_loss': float('inf')}

    print(f"Loading pretrained models from: {os.path.basename(checkpoint_path)}")

    try:
        target_device = next(iter(vae.parameters())).device
        ckpt = torch.load(checkpoint_path, map_location=target_device)
    except Exception as e:
        # Deliberate best-effort: an unreadable checkpoint means training starts fresh.
        print(f"[ERROR] Failed to load checkpoint {checkpoint_path}. Error: {e}")
        return {'epoch': 0, 'best_loss': float('inf')}

    def _preview(keys):
        # At most three key names, with an ellipsis when there are more.
        tail = '...' if len(keys) > 3 else ''
        return f"{', '.join(keys[:3])}{tail}"

    def _load_and_log(model, model_name, state_dict):
        # Load only name- and shape-compatible tensors, then summarize the outcome.
        if state_dict is None:
            print(f"[WARN] No state found for '{model_name}' in checkpoint. It will remain initialized from scratch.")
            return
        model_state = model.state_dict()
        compatible = {}
        for key, value in state_dict.items():
            if key not in model_state or value.shape != model_state[key].shape:
                print(f"[INFO] Skip loading {key}: checkpoint {tuple(value.shape)} "
                      f"!= model {tuple(model_state.get(key, torch.empty(0)).shape)}")
                continue
            compatible[key] = value
        missing, unexpected = model.load_state_dict(compatible, strict=False)
        print(f"--- Loading status for '{model_name}' ---")
        if not (missing or unexpected):
            print("Success: All weights loaded perfectly.")
            return
        if missing:
            print(f"[INFO] {len(missing)} keys missing (new layers in model): {_preview(missing)}")
        if unexpected:
            print(f"[WARN] {len(unexpected)} unexpected keys (old layers removed): {_preview(unexpected)}")

    _load_and_log(vae, 'VoxelVAE', ckpt.get('vae'))
    optional_parts = (
        (query_decoder, 'QueryPointDecoder', 'query_decoder'),
        (edge_encoder, 'EdgeEncoder', 'edge_encoder'),
        (active_encoder, 'ActiveEncoder', 'active_encoder'),
        (connection_head, 'connection_head', 'connection_head'),
        (voxel_encoder, 'VoxelEncoder', 'voxel_encoder'),
    )
    for module, label, key in optional_parts:
        if module is not None:
            _load_and_log(module, label, ckpt.get(key))

    if optimizer is not None and 'optimizer' in ckpt:
        try:
            optimizer.load_state_dict(ckpt['optimizer'])
        except ValueError as e:
            print(f"\n[WARN] Could not load optimizer state. It will be reset. Error: {e}")
        else:
            print("\n[INFO] Successfully loaded optimizer state.")

    if ema_model is not None:
        if 'ema_state_dict' not in ckpt:
            print("\n[INFO] No 'ema_state_dict' found in checkpoint. EMA will start fresh.")
        else:
            try:
                ema_model.load_state_dict(ckpt['ema_state_dict'])
                ema_model.to(target_device)
                print("\n[INFO] Successfully loaded EMA state.")
            except Exception as e:
                print(f"\n[WARN] Failed to load EMA state (Shapes mismatch or other error). EMA will start fresh. Error: {e}")

    best_loss = ckpt.get('best_loss', ckpt.get('loss', float('inf')))
    print(f"\nSuccessfully processed checkpoint. Original epoch was {ckpt.get('epoch', 'unknown')}, best loss {best_loss:.4f}.")
    return {
        'epoch': ckpt.get('epoch', 0),
        'best_loss': best_loss,
    }


def fast_isin(target_coords: torch.Tensor, query_coords: torch.Tensor, resolution: int) -> torch.Tensor:
    """
    Row-wise membership test: for each row of ``target_coords``, report whether the same
    row occurs anywhere in ``query_coords``.

    Each row is collapsed to a scalar key with mixed-radix weights
    ``[resolution**(D-1), ..., resolution, 1]``. Keys are collision-free as long as every
    column after the first lies in ``[0, resolution)`` and the keys fit in int64. The
    first column (e.g. a batch index) may be any non-negative integer. For D = 3 and
    D = 4 this gives the same membership result as the previous hard-coded weights; other
    column counts previously failed with an UnboundLocalError.

    Args:
        target_coords: [N, D] integer coordinates to test.
        query_coords: [M, D] integer coordinates forming the lookup set (same D, same device).
        resolution: Radix of the per-column range.

    Returns:
        [N] bool tensor on ``target_coords``' device.
    """
    device = target_coords.device
    num_cols = target_coords.shape[-1]
    weight = torch.tensor([resolution ** p for p in range(num_cols - 1, -1, -1)], device=device)

    target_hash = (target_coords * weight).sum(dim=-1)
    query_hash = (query_coords * weight).sum(dim=-1)

    matches = torch.zeros_like(target_hash, dtype=torch.bool)
    if query_hash.numel() == 0:
        return matches

    # torch.unique returns its output sorted by default, so no extra sort is needed
    # before the binary search.
    query_sorted = torch.unique(query_hash)
    idx = torch.searchsorted(query_sorted, target_hash)
    in_bounds = idx < query_sorted.numel()
    matches[in_bounds] = query_sorted[idx[in_bounds]] == target_hash[in_bounds]
    return matches


def get_mesh_edge_query_points(
    faces: torch.Tensor,
    vertices: torch.Tensor,
    n_samples: int = 1024,
    noise_std: float = 0.1,
    distance_thresh: float = 1. / 256.,
    device='cpu',
    quantize_bits: int = 10,
    using_nerf: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a mix of on-edge and near-edge points from a mesh and label them by edge distance.

    - ``n_samples // 2`` points are drawn uniformly along the mesh's unique edges. Each
      edge is chosen with probability proportional to its length, or uniformly if the
      total edge length is below 1e-6.
    - The remaining ``n_samples - n_samples // 2`` points are copies of on-edge points
      perturbed by isotropic Gaussian noise of std ``noise_std``. When ``n_samples`` is
      odd there is one more near-edge point than on-edge points; the extra source point
      is drawn with replacement.

    A point is labelled 1 when its distance to the nearest edge segment is below
    ``distance_thresh``, else 0. The label therefore depends on distance, not on which
    group the point came from. With ``using_nerf=False``, distances are measured on the
    dequantized grid coordinates. The grid is fitted to the bounding box of this batch
    of query points.

    NOTE: point-to-edge distances are computed densely as an [n_samples, n_edges, 3]
    tensor, so memory grows with the product of both counts.

    Args:
        faces: Tensor of mesh faces [N_faces, 3].
        vertices: Tensor of mesh vertices [N_vertices, 3].
        n_samples: Total number of points to sample.
        noise_std: Standard deviation for near-edge point noise.
        distance_thresh: Threshold to classify a point as "on-edge".
        device: The torch device to use.
        quantize_bits: Bit depth for quantization (if not using_nerf).
        using_nerf: If True, returns float coordinates. If False, returns quantized coordinates.

    Returns:
        A tuple of (query_points, labels).
        - query_points: [n_samples, 3]; float32 points if ``using_nerf``, else int64 indices
          in [0, 2**quantize_bits - 1].
        - labels: [n_samples] int64 tensor, 1 if within ``distance_thresh`` of an edge, else 0.
    """
    vertices_np = vertices.cpu().numpy().astype(np.float32)
    faces_np = faces.cpu().numpy()
    mesh = trimesh.Trimesh(vertices=vertices_np, faces=faces_np)

    edges = mesh.edges_unique
    edge_vertices = mesh.vertices[edges]  # [E, 2, 3]: start/end vertex of each unique edge

    edge_lengths = np.linalg.norm(edge_vertices[:, 0, :] - edge_vertices[:, 1, :], axis=1)
    total_length = np.sum(edge_lengths)
    # p=None makes np.random.choice uniform, used when every edge is (nearly) degenerate.
    edge_probabilities = None if total_length < 1e-6 else edge_lengths / total_length

    n_edge = n_samples // 2
    n_near = n_samples - n_edge

    sampled_edge_indices = np.random.choice(len(edges), size=n_edge, p=edge_probabilities)
    sampled_edge_verts = torch.from_numpy(edge_vertices[sampled_edge_indices]).to(device, dtype=torch.float32)
    v_starts, v_ends = sampled_edge_verts[:, 0], sampled_edge_verts[:, 1]

    # Uniform position along each chosen edge.
    t = torch.rand(n_edge, 1, device=device)
    on_edge_pts = (1 - t) * v_starts + t * v_ends

    # Choose which on-edge points get jittered. For odd n_samples, n_near == n_edge + 1,
    # and a permutation alone yields one index too few (previously a shape-mismatch crash
    # against `noise`), so top up with indices drawn with replacement.
    indices_for_noise = torch.randperm(n_edge)[:n_near]
    if 0 < n_edge < n_near:
        extra = torch.randint(n_edge, (n_near - n_edge,))
        indices_for_noise = torch.cat([indices_for_noise, extra])
    noise = torch.randn(n_near, 3, device=device) * noise_std
    near_edge_pts = on_edge_pts[indices_for_noise] + noise

    query_pts = torch.cat([on_edge_pts, near_edge_pts], dim=0)

    if using_nerf:
        dist_coords = query_pts
    else:
        min_vals = query_pts.min(dim=0, keepdim=True)[0]
        max_vals = query_pts.max(dim=0, keepdim=True)[0]
        range_vals = (max_vals - min_vals).clamp(min=1e-6)

        normalized = (query_pts - min_vals) / range_vals
        quantized = (normalized * (2**quantize_bits - 1)).round().long()

        dequantized = quantized.float() / (2**quantize_bits - 1)
        dist_coords = dequantized * range_vals + min_vals
        final_coords = quantized

    edge_starts_t = torch.from_numpy(edge_vertices[:, 0]).to(device, dtype=torch.float32)
    edge_ends_t = torch.from_numpy(edge_vertices[:, 1]).to(device, dtype=torch.float32)

    # Broadcast points [N, 1, 3] against edges [1, E, 3].
    P = dist_coords.unsqueeze(1)
    A = edge_starts_t.unsqueeze(0)
    B = edge_ends_t.unsqueeze(0)

    ab = B - A
    ap = P - A

    # Parameter of the orthogonal projection onto each edge's line, clamped to the segment.
    t_dist = (ap * ab).sum(-1) / ((ab * ab).sum(-1) + 1e-8)
    t_clamped = t_dist.clamp(0, 1)
    projection = A + t_clamped.unsqueeze(-1) * ab

    dists = (P - projection).norm(dim=-1)  # [N, E]
    min_dists, _ = dists.min(dim=1)

    labels = (min_dists < distance_thresh).long()

    if using_nerf:
        return query_pts, labels
    return final_coords, labels
