"""
DrAction: Differentiable Rendering of Actions
==============================================
A differentiable skeleton renderer based on 3D Gaussian Splatting for MLLM integration.

Core Components:
1. Gaussian Primitive Representation - each joint/bone modeled as 3D Gaussians
2. Linear Blend Skinning (LBS) - deform canonical Gaussians with skeletal motion  
3. Neural Feature Modulator (NFM) - pose-conditioned appearance modeling
4. Differentiable Rasterization - enables end-to-end gradient flow from MLLM
"""

import math
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================================
# Kinect-25 Skeleton Topology (bone connectivity)
# ============================================================================
# Bone list as (joint_a, joint_b) pairs. Joint ids in this table are 1-based;
# the helpers in this file subtract 1 before indexing joint tensors.
PAIRS: List[Tuple[int, int]] = [
    (1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), (8, 7),
    (9, 21), (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14),
    (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), (22, 23), (23, 8),
    (24, 25), (25, 12)
]


# ============================================================================
# Quaternion Operations (w, x, y, z format)
# ============================================================================
def _normalize_quat(q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale ``q`` along its last axis to (approximately) unit length.

    ``eps`` is added to the norm before dividing, so an all-zero input
    yields zeros rather than NaN.
    """
    length = q.norm(dim=-1, keepdim=True)
    denom = length + eps
    return q / denom


def _quat_to_mat(q: torch.Tensor) -> torch.Tensor:
    """Turn quaternions ``(..., 4)`` in (w, x, y, z) order into ``(..., 3, 3)`` matrices.

    The input is normalized first, so non-unit quaternions are accepted.
    """
    unit = _normalize_quat(q)
    w, x, y, z = unit.unbind(-1)

    # Squared and mixed products reused by several matrix entries.
    ww, xx, yy, zz = w * w, x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    row0 = (ww + xx - yy - zz, 2 * (xy - wz), 2 * (xz + wy))
    row1 = (2 * (xy + wz), ww - xx + yy - zz, 2 * (yz - wx))
    row2 = (2 * (xz - wy), 2 * (yz + wx), ww - xx - yy + zz)

    # Stack the nine entries row-major, then fold the last axis into 3x3.
    flat = torch.stack(row0 + row1 + row2, dim=-1)
    return flat.reshape(unit.shape[:-1] + (3, 3))


def _mat_to_quat(R: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Turn ``(..., 3, 3)`` rotation matrices into (w, x, y, z) quaternions.

    Four branches are used, selected per matrix: the trace branch when the
    trace is positive, otherwise the branch of the largest diagonal entry.
    ``eps`` is the lower clamp applied under each square root. The result is
    normalized before being returned.
    """
    r00, r01, r02 = R[..., 0, 0], R[..., 0, 1], R[..., 0, 2]
    r10, r11, r12 = R[..., 1, 0], R[..., 1, 1], R[..., 1, 2]
    r20, r21, r22 = R[..., 2, 0], R[..., 2, 1], R[..., 2, 2]
    tr = r00 + r11 + r22

    # Mutually exclusive branch masks; together they cover every element.
    by_trace = tr > 0
    by_x = (~by_trace) & (r00 >= r11) & (r00 >= r22)
    by_y = (~by_trace) & (~by_x) & (r11 >= r22)
    by_z = (~by_trace) & (~by_x) & (~by_y)

    w_out = torch.zeros_like(tr)
    x_out = torch.zeros_like(tr)
    y_out = torch.zeros_like(tr)
    z_out = torch.zeros_like(tr)

    # Branch 0: positive trace, w is the dominant component.
    m = by_trace
    s = torch.sqrt(torch.clamp(tr[m] + 1.0, min=eps)) * 2
    w_out[m] = 0.25 * s
    x_out[m] = (r21[m] - r12[m]) / s
    y_out[m] = (r02[m] - r20[m]) / s
    z_out[m] = (r10[m] - r01[m]) / s

    # Branch 1: r00 is the largest diagonal entry, x dominates.
    m = by_x
    s = torch.sqrt(torch.clamp(1.0 + r00[m] - r11[m] - r22[m], min=eps)) * 2
    w_out[m] = (r21[m] - r12[m]) / s
    x_out[m] = 0.25 * s
    y_out[m] = (r01[m] + r10[m]) / s
    z_out[m] = (r02[m] + r20[m]) / s

    # Branch 2: r11 is the largest diagonal entry, y dominates.
    m = by_y
    s = torch.sqrt(torch.clamp(1.0 + r11[m] - r00[m] - r22[m], min=eps)) * 2
    w_out[m] = (r02[m] - r20[m]) / s
    x_out[m] = (r01[m] + r10[m]) / s
    y_out[m] = 0.25 * s
    z_out[m] = (r12[m] + r21[m]) / s

    # Branch 3: r22 is the largest diagonal entry, z dominates.
    m = by_z
    s = torch.sqrt(torch.clamp(1.0 + r22[m] - r00[m] - r11[m], min=eps)) * 2
    w_out[m] = (r10[m] - r01[m]) / s
    x_out[m] = (r02[m] + r20[m]) / s
    y_out[m] = (r12[m] + r21[m]) / s
    z_out[m] = 0.25 * s

    quat = torch.stack([w_out, x_out, y_out, z_out], dim=-1)
    return _normalize_quat(quat)


def _project_to_rotation(R: torch.Tensor) -> torch.Tensor:
    """Project matrices ``(..., 3, 3)`` onto SO(3) via SVD.

    Computes ``U @ diag(1, 1, s) @ Vh`` where ``s`` is -1 for matrices whose
    plain ``U @ Vh`` has negative determinant (a reflection) and +1 otherwise.

    The sign correction is applied out of place. Mutating ``Vh`` in place
    (as an earlier version did) invalidates the tensors that the SVD node
    saved for backward and makes ``backward()`` raise whenever the input
    requires grad and a reflection is present.

    Args:
        R: Arbitrary real matrices of shape ``(..., 3, 3)``.

    Returns:
        Rotation matrices of the same shape, with determinant +1.
    """
    U, _, Vh = torch.linalg.svd(R)

    # Only the sign of the determinant is needed; keep it out of the graph.
    with torch.no_grad():
        det = torch.det(U @ Vh)

    ones = torch.ones_like(det)
    last_sign = torch.where(det < 0, -ones, ones)
    signs = torch.stack([ones, ones, last_sign], dim=-1)  # (..., 3)

    # Scaling the columns of U by `signs` equals U @ diag(signs); this flips
    # the last singular direction exactly where a reflection was detected.
    return (U * signs.unsqueeze(-2)) @ Vh


def _compose_T(R: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Assemble homogeneous 4x4 transforms from rotations and translations.

    ``R`` is ``(..., 3, 3)`` and ``t`` is assigned into the last column of the
    top three rows; the bottom row is ``[0, 0, 0, 1]``. The result takes its
    device and dtype from ``R``.
    """
    batch_shape = R.shape[:-2]
    T = torch.zeros(batch_shape + (4, 4), device=R.device, dtype=R.dtype)
    T[..., 3, 3] = 1.0
    T[..., :3, :3] = R
    T[..., :3, 3] = t
    return T


# ============================================================================
# DrAction: Differentiable Skeleton Renderer
# ============================================================================
class DifferentiableSkeletonRenderer(nn.Module):
    """
    Differentiable skeleton renderer using 3D Gaussian Splatting.
    
    Key features:
    - Models human pose as deformable 3D Gaussian primitives
    - Uses Linear Blend Skinning (LBS) for pose-dependent deformation
    - Neural Feature Modulator (NFM) for adaptive appearance
    - Fully differentiable for end-to-end MLLM optimization
    """
    
    def __init__(
        self,
        num_gaussians: int,
        num_joints: int,
        feature_dim: int,
        H: int,
        W: int,
        temporal_stride: int = 4,
        use_temporal_gru: bool = False,
        enable_nfm: bool = False
    ):
        """
        Build the renderer's buffers, parameters and sub-networks.

        Args:
            num_gaussians: Number of Gaussian primitives.
            num_joints: Number of skeleton joints.
            feature_dim: Width of the per-Gaussian appearance feature.
            H: Output image height.
            W: Output image width.
            temporal_stride: Frame offset used for finite-difference
                velocities; values below 1 are raised to 1.
            use_temporal_gru: Create the GRU; only takes effect when
                ``enable_nfm`` is also true.
            enable_nfm: Create the Neural Feature Modulator MLP.
        """
        super().__init__()
        self.num_gaussians = num_gaussians
        self.num_joints = num_joints
        self.feature_dim = feature_dim
        self.H = H
        self.W = W
        self.temporal_stride = int(max(1, temporal_stride))
        self.enable_nfm = bool(enable_nfm)

        # Hidden width shared by the appearance head and the NFM.
        hidden_width = max(64, feature_dim * 2)

        # ---- Canonical (rest pose) Gaussian state, stored as buffers ----
        identity_quat = torch.tensor([1, 0, 0, 0.], dtype=torch.float32)
        self.register_buffer('canonical_joints', torch.zeros(num_joints, 3))
        self.register_buffer('canonical_means', torch.zeros(num_gaussians, 3))
        self.register_buffer('canonical_scales', torch.ones(num_gaussians, 3) * 0.02)
        self.register_buffer('canonical_quats', identity_quat.repeat(num_gaussians, 1))
        self.register_buffer('canonical_opacities', torch.ones(num_gaussians, 1))

        # ---- Learnable per-Gaussian appearance feature ----
        self.canonical_features = nn.Parameter(torch.randn(num_gaussians, feature_dim) * 0.01)
        # LBS weight logits are a buffer (not trained); softmax is applied in forward().
        self.register_buffer('lbs_weights_logits', torch.zeros(num_gaussians, num_joints))

        # ---- Appearance head: feature -> 4 values (RGB + alpha logits) ----
        self.appearance_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_width, 4),
        )
        self._init_weights(self.appearance_head)

        # ---- Neural Feature Modulator (optional) ----
        if not self.enable_nfm:
            self.nfm = None
        else:
            nfm_inputs = 10  # agg_pos(3) + agg_vel(3) + base_app(4)
            self.nfm = nn.Sequential(
                nn.Linear(nfm_inputs, hidden_width),
                nn.ReLU(inplace=False),
                nn.Linear(hidden_width, 5),  # delta_rgb(3) + delta_alpha(1) + saliency_gate(1)
            )
            self._init_weights(self.nfm)

        # ---- Temporal GRU (optional, requires the NFM) ----
        wants_gru = use_temporal_gru and self.enable_nfm
        self.temporal_gru = (
            nn.GRU(input_size=10, hidden_size=10, num_layers=1, batch_first=True)
            if wants_gru else None
        )

        # Logit of the blend weight between learned color and depth color.
        self.depth_mix_logit = nn.Parameter(torch.tensor(-0.5, dtype=torch.float32))

        # Bookkeeping flags set by the set_* methods, plus the GRU hidden state.
        self._canonical_joints_initialized = False
        self._canonical_gaussians_initialized = False
        self._lbs_initialized = False
        self._h_gru = None

    def _init_weights(self, module):
        """Re-draw every ``nn.Linear`` inside ``module`` from small Gaussians.

        Weights use std 0.01 and biases (when present) use std 0.1, both
        centred at zero.
        """
        linear_layers = (layer for layer in module.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            layer.weight.data.normal_(mean=0.0, std=0.01)
            if layer.bias is None:
                continue
            layer.bias.data.normal_(mean=0.0, std=0.1)

    # ====== Setup Methods ======
    def set_canonical_joints(self, joints_xyz: torch.Tensor):
        """Copy rest-pose joint positions into the ``canonical_joints`` buffer.

        Args:
            joints_xyz: Tensor that can be copied into the
                ``(num_joints, 3)`` buffer.
        """
        # Copy under no_grad so the buffer stays detached even when the
        # source tensor carries autograd history (same as set_lbs_weights_logits).
        with torch.no_grad():
            self.canonical_joints.copy_(joints_xyz)
        self._canonical_joints_initialized = True

    def set_canonical_means(self, means_xyz: torch.Tensor):
        """Copy rest-pose Gaussian centers into the ``canonical_means`` buffer.

        Args:
            means_xyz: Tensor that can be copied into the
                ``(num_gaussians, 3)`` buffer.
        """
        # Copy under no_grad so the buffer stays detached even when the
        # source tensor carries autograd history (same as set_lbs_weights_logits).
        with torch.no_grad():
            self.canonical_means.copy_(means_xyz)
        self._canonical_gaussians_initialized = True

    def set_lbs_weights_logits(self, logits: torch.Tensor):
        """Overwrite the fixed (untrained) LBS weight logits buffer with ``logits``."""
        target = self.lbs_weights_logits
        with torch.no_grad():
            target.copy_(logits)
        self._lbs_initialized = True

    def reset_temporal_state(self, batch_size: int, K_total: int, device: torch.device):
        """Start a new sequence by clearing the GRU hidden state.

        With a GRU present the state becomes zeros of shape
        ``(1, batch_size, hidden_size)`` on ``device``; without one it is
        ``None``. ``K_total`` is accepted but not used here.
        """
        gru = self.temporal_gru
        if gru is None:
            self._h_gru = None
            return
        self._h_gru = torch.zeros(1, batch_size, gru.hidden_size, device=device)

    # ====== Core LBS Transform Methods ======
    def compute_joint_transforms(
        self, 
        joints_now: torch.Tensor, 
        orients_now: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Build one SE(3) matrix per joint, mapping canonical pose to current pose.

        The translation is the joint's displacement from its canonical
        position. The rotation comes from ``orients_now`` when given,
        otherwise it is the identity.

        Args:
            joints_now: Current joint positions (B, J, 3)
            orients_now: Optional joint orientations as quaternions (B, J, 4)

        Returns:
            Joint transforms (B, J, 4, 4)
        """
        B, J = joints_now.shape[:2]
        offset = joints_now - self.canonical_joints.unsqueeze(0)

        # Without orientations every joint gets an identity rotation.
        if orients_now is None:
            eye = torch.eye(3, device=joints_now.device, dtype=joints_now.dtype)
            return _compose_T(eye.expand(B, J, 3, 3), offset)

        # Near-zero-norm or NaN quaternions are replaced by the identity quaternion.
        norms = orients_now.norm(dim=-1, keepdim=True)
        bad = (norms < 1e-6) | torch.isnan(orients_now).any(dim=-1, keepdim=True)
        if bad.any():
            unit = torch.tensor([1.0, 0.0, 0.0, 0.0], device=orients_now.device, dtype=orients_now.dtype)
            orients_now = torch.where(
                bad.expand_as(orients_now),
                unit.view(1, 1, 4).expand_as(orients_now),
                orients_now,
            )

        return _compose_T(_quat_to_mat(orients_now), offset)

    def blend_transforms(
        self, 
        joint_T: torch.Tensor, 
        lbs_w: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mix per-joint transforms into per-Gaussian transforms with LBS weights.

        Translations are combined as a plain weighted sum. Rotations are
        summed with the same weights and the result is projected back onto
        SO(3):
            t_blend = Σ w_j * t_j
            R_blend = Project_SO3(Σ w_j * R_j)

        Args:
            joint_T: Per-joint transforms (B, J, 4, 4)
            lbs_w: Blend weights (K, J), softmax normalized

        Returns:
            per_gauss_T: Per-Gaussian transforms (B, K, 4, 4)
            R_blend: Blended rotations (B, K, 3, 3)
            t_blend: Blended translations (B, K, 3)
        """
        batch = joint_T.shape[0]
        num_gauss = lbs_w.shape[0]
        rot_j = joint_T[..., :3, :3]
        trans_j = joint_T[..., :3, 3]

        # Rotation: linear blend, then snap the (generally non-orthogonal) sum to SO(3).
        R_lin = torch.einsum('kj, bjmn -> bkmn', lbs_w, rot_j)
        R_blend = _project_to_rotation(R_lin.reshape(-1, 3, 3)).reshape(batch, num_gauss, 3, 3)

        # Translation: linear blend only.
        t_blend = torch.einsum('kj, bjc -> bkc', lbs_w, trans_j)

        per_gauss_T = _compose_T(R_blend, t_blend)
        return per_gauss_T, R_blend, t_blend

    def transform_gaussians(
        self,
        means0: torch.Tensor,
        scales0: torch.Tensor,
        quats0: torch.Tensor,
        R_blend: torch.Tensor,
        t_blend: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Carry canonical Gaussians into the current pose.

        With R_total = R_blend * R_canonical:
            μ' = R_blend * μ_canonical + t_blend
            Σ' = R_total * diag(s²) * R_total^T

        The canonical inputs are cast to the dtype of ``R_blend`` and
        broadcast over its batch dimension.

        Returns:
            means: Transformed centers (B, K, 3)
            covs: Transformed covariances (B, K, 3, 3)
            quats: Transformed orientations (B, K, 4)
            scales: Scales (B, K, 3)
        """
        batch, num_gauss = R_blend.shape[:2]
        dtype = R_blend.dtype

        # Broadcast the canonical quantities over the batch.
        mu0 = means0.to(dtype).unsqueeze(0).expand(batch, num_gauss, 3)
        rot0 = _quat_to_mat(quats0.to(dtype)).unsqueeze(0).expand(batch, num_gauss, 3, 3)
        scales = scales0.to(dtype).unsqueeze(0).expand(batch, num_gauss, 3)

        rot_total = R_blend @ rot0
        means = (R_blend @ mu0.unsqueeze(-1)).squeeze(-1) + t_blend

        variances = torch.diag_embed(scales * scales)
        covs = rot_total @ variances @ rot_total.transpose(-1, -2)
        quats = _mat_to_quat(rot_total.reshape(-1, 3, 3)).reshape(batch, num_gauss, 4)

        return means, covs, quats, scales

    # ====== Differentiable Rasterization ======
    def _vectorized_rasterize(
        self,
        means3D: torch.Tensor,
        cov3D: torch.Tensor,
        colors: torch.Tensor,
        opacities: torch.Tensor,
        K: torch.Tensor,
        w2c: torch.Tensor,
        H: int,
        W: int,
        chunk_k: int = 32
    ) -> torch.Tensor:
        """
        Differentiable rasterization via front-to-back alpha compositing.

        Projects 3D Gaussians to screen space and computes pixel colors:
            I(x,y) = Σ_k C_k * α'_k * Π_{j<k}(1 - α'_j)

        The 2D footprint of each Gaussian is J W Σ Wᵀ Jᵀ, where W is the
        rotation part of ``w2c`` and J is the Jacobian of the perspective
        projection evaluated at the Gaussian's camera-space center.

        Args:
            means3D: Gaussian centers in world space (B, K, 3)
            cov3D: 3D covariances in world space (B, K, 3, 3)
            colors: Per-Gaussian colors (B, K, C); C is 3 for RGB
            opacities: Opacity values (B, K)
            K: Camera intrinsics (B, 3, 3)
            w2c: World-to-camera transform (B, 4, 4)
            H: Output height in pixels
            W: Output width in pixels
            chunk_k: Number of Gaussians composited per vectorized step;
                trades peak memory for fewer Python iterations. Values
                below 1 are treated as 1.

        Returns:
            Rendered image (B, H, W, C)
        """
        device, dtype = means3D.device, means3D.dtype
        B, Knum, _ = means3D.shape
        C = colors.shape[-1]
        chunk_k = max(1, int(chunk_k))

        # Project centers to camera coordinates.
        means_h = torch.cat([means3D, torch.ones_like(means3D[..., :1])], dim=-1)
        w2c_exp = w2c.unsqueeze(1).expand(B, Knum, -1, -1)
        cam = (w2c_exp @ means_h.unsqueeze(-1)).squeeze(-1)
        # NOTE(review): depth is clamped rather than culled, so Gaussians at or
        # behind the camera plane are still drawn — confirm this is intended.
        Xc, Yc, Zc = cam[..., 0], cam[..., 1], torch.clamp(cam[..., 2], min=1e-4)

        # Project to screen coordinates (image v axis is flipped w.r.t. camera Y).
        fx, fy = K[:, 0, 0].unsqueeze(-1), K[:, 1, 1].unsqueeze(-1)
        cx, cy = K[:, 0, 2].unsqueeze(-1), K[:, 1, 2].unsqueeze(-1)
        u = fx * (Xc / Zc) + cx
        v = -fy * (Yc / Zc) + cy

        # Jacobian of (u, v) with respect to camera-space (X, Y, Z).
        zeros = torch.zeros_like(Zc)
        J = torch.stack([
            torch.stack([fx / Zc, zeros, -fx * (Xc / (Zc * Zc))], dim=-1),
            torch.stack([zeros, -fy / Zc, fy * (Yc / (Zc * Zc))], dim=-1)
        ], dim=-2)

        # The Jacobian lives in camera space, so the world-space covariance
        # must first be rotated into the camera frame.
        R_cam = w2c[:, :3, :3].unsqueeze(1)
        cov_cam = R_cam @ cov3D @ R_cam.transpose(-1, -2)
        Sigma2D = J @ cov_cam @ J.transpose(-1, -2)
        # Small diagonal term keeps the 2x2 covariance invertible.
        Sigma2D = Sigma2D + 1e-5 * torch.eye(2, device=device, dtype=Sigma2D.dtype).view(1, 1, 2, 2)
        inv2D = torch.linalg.inv(Sigma2D)

        # Depth sort (front to back).
        zsort_idx = torch.argsort(Zc, dim=1, descending=False)
        u = torch.gather(u, 1, zsort_idx)
        v = torch.gather(v, 1, zsort_idx)
        colors = torch.gather(colors, 1, zsort_idx.unsqueeze(-1).expand(-1, -1, C))
        opacities = torch.gather(opacities, 1, zsort_idx)
        inv2D = torch.gather(inv2D, 1, zsort_idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 2, 2))

        # Pixel grid, shaped (1, 1, H, W) to broadcast against (B, k, 1, 1).
        ys = torch.arange(0, H, device=device, dtype=dtype)
        xs = torch.arange(0, W, device=device, dtype=dtype)
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
        grid_x, grid_y = grid_x[None, None], grid_y[None, None]

        # Alpha compositing state: accumulated color and remaining transmittance.
        out = torch.zeros(B, H, W, C, device=device, dtype=dtype)
        trans = torch.ones(B, H, W, device=device, dtype=dtype)

        for start in range(0, Knum, chunk_k):
            end = min(Knum, start + chunk_k)
            inv_c = inv2D[:, start:end]

            dx = grid_x - u[:, start:end, None, None]
            dy = grid_y - v[:, start:end, None, None]
            a = inv_c[:, :, 0, 0][:, :, None, None]
            bc = (inv_c[:, :, 0, 1] + inv_c[:, :, 1, 0])[:, :, None, None]
            d = inv_c[:, :, 1, 1][:, :, None, None]

            # Squared Mahalanobis distance of every pixel to every Gaussian: (B, k, H, W).
            dist = a * dx * dx + bc * dx * dy + d * dy * dy
            w = torch.exp(-0.5 * dist).clamp(0.0, 1.0)
            alpha = 1.0 - torch.exp(-opacities[:, start:end, None, None] * w)

            # Transmittance in front of each Gaussian inside the chunk is the
            # exclusive cumulative product of (1 - alpha) along the sorted axis.
            keep = torch.cumprod(1.0 - alpha, dim=1)
            prior = torch.cat([torch.ones_like(keep[:, :1]), keep[:, :-1]], dim=1)
            weight = trans.unsqueeze(1) * prior * alpha

            out = out + torch.einsum('bkhw,bkc->bhwc', weight, colors[:, start:end])
            trans = trans * keep[:, -1]

        return out

    # ====== Forward Pass ======
    def forward(
        self,
        poses: torch.Tensor,
        K: torch.Tensor,
        w2c: torch.Tensor,
        vels: Optional[torch.Tensor] = None,
        orients: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Render skeleton sequence into video frames.

        Every frame is rendered independently: each person's Gaussians are
        posed with LBS, colored, concatenated across persons and rasterized.
        All inputs are cast to float32. Extra keyword arguments are accepted
        and ignored.

        Args:
            poses: Joint positions (T, P, J, 3)
            K: Camera intrinsics (T, 3, 3)
            w2c: World-to-camera transforms (T, 4, 4)
            vels: Pre-computed velocities (T, P, J, 3); when omitted they are
                forward differences over ``temporal_stride`` frames, clamped
                at the last frame
            orients: Optional joint orientations (T, P, J, 4)

        Returns:
            Rendered video (1, T, H, W, 3) in [0, 1]. An input with no frames
            or no persons yields an all-zero video of that shape.
        """
        device = poses.device
        poses = poses.to(device=device, dtype=torch.float32)
        K = K.to(device=device, dtype=torch.float32)
        w2c = w2c.to(device=device, dtype=torch.float32)
        if vels is not None:
            vels = vels.to(device=device, dtype=torch.float32)
        if orients is not None:
            orients = orients.to(device=device, dtype=torch.float32)

        T_len, P = poses.shape[0], poses.shape[1]

        # Nothing to draw: without this guard the concatenations below would
        # be handed empty lists and raise.
        if T_len == 0 or P == 0:
            return torch.zeros(1, T_len, self.H, self.W, 3, device=device, dtype=torch.float32)

        K_total = self.canonical_means.shape[0]
        lbs_w = torch.softmax(self.lbs_weights_logits, dim=-1).to(dtype=poses.dtype)

        # Compute velocities if not provided
        if vels is None:
            stride = self.temporal_stride
            idx_fut = torch.clamp(torch.arange(T_len, device=device) + stride, max=T_len - 1)
            vels = poses.index_select(0, idx_fut) - poses

        # Base appearance from learnable features (kept in full precision).
        with torch.cuda.amp.autocast(enabled=False):
            base_rgba = self.appearance_head(self.canonical_features.float()).unsqueeze(0)
            base_rgba = torch.nan_to_num(base_rgba, nan=0.0)
            base_rgb = base_rgba[..., 0:3].sigmoid()
            base_alpha = base_rgba[..., 3:4].sigmoid()

        # Blend factor between the learned color and the depth visualization.
        lam = torch.sigmoid(self.depth_mix_logit).to(poses.dtype)

        frames_list: List[torch.Tensor] = []
        self.reset_temporal_state(batch_size=1, K_total=K_total, device=device)

        for t in range(T_len):
            per_means, per_colors, per_opac, per_covs = [], [], [], []

            for p in range(P):
                joints_now = poses[t, p].unsqueeze(0)
                # A person whose joints are all exactly zero is drawn fully transparent.
                valid = (joints_now.abs().sum() > 0).float()
                orients_now = orients[t, p].unsqueeze(0) if orients is not None else None

                # LBS transform
                joint_T = self.compute_joint_transforms(joints_now, orients_now)
                _, R_blend, t_blend = self.blend_transforms(joint_T, lbs_w)
                means, covs, _, _ = self.transform_gaussians(
                    self.canonical_means, self.canonical_scales, self.canonical_quats, R_blend, t_blend
                )

                # Depth-based coloring
                means_h = torch.cat([means, torch.ones_like(means[..., :1])], dim=-1)
                cam = (w2c[t:t+1].unsqueeze(1) @ means_h.unsqueeze(-1)).squeeze(-1)
                Zc = torch.clamp(cam[..., 2], min=1e-4)
                depth_rgb = _depth_to_rgb(Zc)

                # Neural Feature Modulator (if enabled)
                if self.enable_nfm and self.nfm is not None:
                    vel_now = vels[t, p].unsqueeze(0)
                    agg_pos = torch.einsum('kj, bjc -> bkc', lbs_w, joints_now)
                    agg_vel = torch.einsum('kj, bjc -> bkc', lbs_w, vel_now)
                    mod_in = torch.cat([agg_pos, agg_vel, base_rgba], dim=-1)

                    with torch.cuda.amp.autocast(enabled=False):
                        mod_in_eff = mod_in.float()
                        if self.temporal_gru is not None:
                            # The Gaussians form the GRU's sequence axis; the hidden
                            # state is carried across persons and frames of this call.
                            x_seq = mod_in_eff.reshape(1, -1, mod_in_eff.shape[-1])
                            x_seq, self._h_gru = self.temporal_gru(x_seq, self._h_gru)
                            mod_in_eff = x_seq.reshape(1, -1, mod_in_eff.shape[-1])
                        deltas = torch.nan_to_num(self.nfm(mod_in_eff), nan=0.0).to(mod_in.dtype)

                    delta_rgb, delta_alpha = deltas[..., 0:3], deltas[..., 3:4]
                    saliency = deltas[..., 4:5].sigmoid()
                    learned_rgb = (base_rgb.to(delta_rgb.dtype) + delta_rgb).clamp(0.0, 1.0)
                    # NOTE(review): base_alpha has already been through a sigmoid,
                    # so this applies a second one — confirm this is intended.
                    learned_alpha = (base_alpha.to(delta_alpha.dtype) + delta_alpha).sigmoid() * saliency
                else:
                    learned_rgb = base_rgb.to(depth_rgb.dtype)
                    learned_alpha = base_alpha.to(depth_rgb.dtype)

                final_color = (1.0 - lam) * learned_rgb + lam * depth_rgb
                opac = learned_alpha.squeeze(-1) * valid

                per_means.append(means)
                per_colors.append(final_color)
                per_opac.append(opac)
                per_covs.append(covs)

            # Concatenate all persons along the Gaussian axis.
            means_cat = torch.cat(per_means, dim=1)
            colors_cat = torch.nan_to_num(torch.cat(per_colors, dim=1), nan=0.0).clamp(0.0, 1.0)
            opac_cat = torch.nan_to_num(torch.cat(per_opac, dim=1), nan=0.0).clamp(0.0, None)
            cov3D = torch.cat(per_covs, dim=1)

            # Rasterize
            frame_t = self._vectorized_rasterize(
                means3D=means_cat, cov3D=cov3D, colors=colors_cat, opacities=opac_cat,
                K=K[t:t+1], w2c=w2c[t:t+1], H=self.H, W=self.W
            )
            frames_list.append(torch.nan_to_num(frame_t, nan=0.0))

        frames_bt = torch.cat(frames_list, dim=0)
        return frames_bt.reshape(1, T_len, self.H, self.W, 3).contiguous().clamp(0.0, 1.0)


# ============================================================================
# Helper Functions
# ============================================================================
def _depth_to_rgb(z: torch.Tensor) -> torch.Tensor:
    """Map depths to a three-channel color ramp.

    Depth is min-max normalized along dim 1. Red rises with normalized
    depth, blue falls with it, and green peaks at the midpoint. The output
    appends a channel axis of size 3 and is clamped to [0, 1].
    """
    near = z.amin(dim=1, keepdim=True)
    far = z.amax(dim=1, keepdim=True)
    # The small constant guards against a zero depth range.
    depth01 = (z - near) / (far - near + 1e-6)

    red = depth01
    green = 1.0 - torch.abs(depth01 - 0.5) * 2.0
    blue = 1.0 - depth01

    rgb = torch.stack((red, green, blue), dim=-1)
    return rgb.clamp(0.0, 1.0)


def _build_line_samples(
    canonical_joints: torch.Tensor, 
    num_line_samples: int
) -> Tuple[torch.Tensor, List[Tuple[int, int, float]]]:
    """
    Build Gaussian samples along bone edges.

    For each bone (a, b) in ``PAIRS``, places ``num_line_samples`` points
    uniformly spaced strictly between the two joints (the end points
    themselves are excluded).

    Args:
        canonical_joints: Joint positions indexed by 0-based joint id.
        num_line_samples: Samples per bone; zero or less yields no samples.

    Returns:
        samples: Sampled positions (N, 3), on the device and in the dtype of
            ``canonical_joints`` (also when N is 0)
        sample_defs: List of (joint_a_idx, joint_b_idx, interpolation_alpha)
    """
    sample_defs: List[Tuple[int, int, float]] = []

    # The empty result must carry the input's dtype as well as its device,
    # so later concatenation with joint-derived tensors does not mix dtypes.
    if num_line_samples <= 0:
        empty = torch.empty(0, 3, device=canonical_joints.device, dtype=canonical_joints.dtype)
        return empty, sample_defs

    samples = []
    for a, b in PAIRS:
        a_idx, b_idx = a - 1, b - 1  # PAIRS holds 1-based joint ids
        start, end = canonical_joints[a_idx], canonical_joints[b_idx]
        for s in range(1, num_line_samples + 1):
            alpha = s / (num_line_samples + 1)
            samples.append((1.0 - alpha) * start + alpha * end)
            sample_defs.append((a_idx, b_idx, alpha))

    return torch.stack(samples, dim=0), sample_defs


def _make_lbs_logits_for_samples(
    num_joints: int,
    sample_defs: List[Tuple[int, int, float]],
    joint_focus_logit: float = 10.0,
    other_logit: float = -10.0
) -> torch.Tensor:
    """
    Build the LBS logit matrix for bone samples, one row per sample.

    Every entry starts at ``other_logit``. The two end joints of a sample's
    bone receive ``log(1 - alpha) + joint_focus_logit`` and
    ``log(alpha) + joint_focus_logit`` respectively, with the log argument
    floored at 1e-6.

    Returns:
        Logits of shape (len(sample_defs), num_joints).
    """
    if not sample_defs:
        return torch.empty(0, num_joints)

    logits = torch.full((len(sample_defs), num_joints), other_logit)
    for i, (a_idx, b_idx, alpha) in enumerate(sample_defs):
        weight_a = max(1e-6, 1.0 - alpha)
        weight_b = max(1e-6, alpha)
        logits[i, a_idx] = math.log(weight_a) + joint_focus_logit
        logits[i, b_idx] = math.log(weight_b) + joint_focus_logit
    return logits


def _compute_adaptive_scales(
    canonical_joints: torch.Tensor,
    sample_defs: List[Tuple[int, int, float]],
    num_joints: int,
    base_joint_scale: float = 0.030,
    base_line_scale: float = 0.020,
    min_joint_scale: float = 0.010,
    max_joint_scale: float = 0.035,
    min_line_scale: float = 0.006,
    max_line_scale: float = 0.030,
    gamma: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute adaptive Gaussian scales based on bone lengths.

    Bone lengths are measured on ``canonical_joints`` for every bone in
    ``PAIRS`` and divided by the longest bone. A joint's scale uses the
    median length of its adjacent bones (``torch.median``, i.e. the lower
    median for an even count); a line sample's scale uses the length of its
    own bone. Each ratio is raised to ``gamma``, multiplied by the base
    scale and clamped to the given [min, max] range.

    If the longest bone has zero (or NaN) length — for example when the
    canonical joints are still all zeros — the normalizer falls back to 1.0
    instead of raising ``ZeroDivisionError``.

    Returns:
        joint_scales: Isotropic scales (num_joints, 3)
        line_scales: Isotropic scales (len(sample_defs), 3)
    """
    device, dtype = canonical_joints.device, canonical_joints.dtype

    # Compute bone lengths (PAIRS holds 1-based joint ids).
    pairs_idx = [(a - 1, b - 1) for a, b in PAIRS]
    lengths = [(canonical_joints[a] - canonical_joints[b]).norm().item() for a, b in pairs_idx]
    Lmax = max(lengths) if lengths else 1.0
    if not Lmax > 0.0:
        # Degenerate skeleton: avoid dividing by zero below.
        Lmax = 1.0
    length_by_pair = dict(zip(pairs_idx, lengths))

    # Lengths of the bones touching each joint.
    joints_adj: List[List[float]] = [[] for _ in range(num_joints)]
    for (a, b), L in zip(pairs_idx, lengths):
        joints_adj[a].append(L)
        joints_adj[b].append(L)

    # Joint scales (based on median adjacent bone length)
    joint_scales = []
    for adjacent in joints_adj:
        if adjacent:
            rj = float(torch.tensor(adjacent).median().item()) / Lmax
        else:
            rj = 0.5  # joint without any bone: use a mid-range ratio
        sj = max(min_joint_scale, min(max_joint_scale, base_joint_scale * (rj ** gamma)))
        joint_scales.append([sj, sj, sj])

    # Line sample scales (based on parent bone length)
    line_scales = []
    for a_idx, b_idx, _ in sample_defs:
        L = length_by_pair.get((a_idx, b_idx), Lmax)
        r = L / Lmax
        sl = max(min_line_scale, min(max_line_scale, base_line_scale * (r ** gamma)))
        line_scales.append([sl, sl, sl])

    joint_tensor = torch.tensor(joint_scales, dtype=dtype, device=device)
    if line_scales:
        line_tensor = torch.tensor(line_scales, dtype=dtype, device=device)
    else:
        line_tensor = torch.empty(0, 3, dtype=dtype, device=device)
    return joint_tensor, line_tensor


# ============================================================================
# NTU Skeleton Parser
# ============================================================================
def parse_ntu_skeleton_file(file_path: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Parse NTU RGB+D skeleton file.

    The file is read as one whitespace-separated token stream: a frame
    count, then per frame a body count, then per body 10 metadata values, a
    joint count and 12 values per joint. Of each joint's values, indices
    0-2 are kept as the position and indices 7-10 as the orientation
    quaternion (presumably in w, x, y, z order — confirm against the
    dataset specification).

    Only the first two bodies of a frame are kept. Missing bodies and
    missing joints are padded with zero positions, zero metadata and
    identity quaternions; joints beyond the 25th are dropped.

    Returns:
        poses: Joint positions (T, 2, 25, 3)
        metas: Metadata (T, 2, 10)
        orients: Joint orientations as quaternions (T, 2, 25, 4)

        The shapes hold for a file with zero frames as well (T == 0).
    """
    with open(file_path, 'r') as f:
        tokens = f.read().split()
    idx = 0

    def next_int() -> int:
        nonlocal idx
        val = int(tokens[idx])
        idx += 1
        return val

    def next_floats(n: int) -> List[float]:
        nonlocal idx
        vals = list(map(float, tokens[idx:idx + n]))
        idx += n
        return vals

    T = next_int()
    poses, metas, orients = [], [], []

    for _ in range(T):
        bodies = next_int()
        frame_joints, frame_quats, frame_meta = [], [], []

        for _ in range(bodies):
            # Every body must be consumed to keep the token cursor aligned,
            # even when it is discarded below.
            meta10 = next_floats(10)
            J = next_int()
            joints_b, quats_b = [], []
            for __ in range(J):
                vals = next_floats(12)
                joints_b.append([vals[0], vals[1], vals[2]])
                quats_b.append([vals[7], vals[8], vals[9], vals[10]])

            # Pad/truncate to 25 joints
            while len(joints_b) < 25:
                joints_b.append([0.0, 0.0, 0.0])
                quats_b.append([1.0, 0.0, 0.0, 0.0])
            joints_b, quats_b = joints_b[:25], quats_b[:25]

            if len(frame_joints) < 2:
                frame_joints.append(joints_b)
                frame_quats.append(quats_b)
                frame_meta.append(meta10)

        # Pad to 2 bodies
        while len(frame_joints) < 2:
            frame_joints.append([[0.0, 0.0, 0.0] for _ in range(25)])
            frame_quats.append([[1.0, 0.0, 0.0, 0.0] for _ in range(25)])
            frame_meta.append([0.0 for _ in range(10)])

        poses.append(frame_joints)
        metas.append(frame_meta)
        orients.append(frame_quats)

    # Explicit shapes: torch.tensor([]) alone would collapse a zero-frame
    # file to shape (0,) instead of the documented layout.
    n_frames = len(poses)
    return (torch.tensor(poses, dtype=torch.float32).reshape(n_frames, 2, 25, 3),
            torch.tensor(metas, dtype=torch.float32).reshape(n_frames, 2, 10),
            torch.tensor(orients, dtype=torch.float32).reshape(n_frames, 2, 25, 4))
