import torch
import torch.nn.functional as F
from typing import Optional, List


def _row_normalized_distances(joints: torch.Tensor, height: float) -> torch.Tensor:
    """Return pairwise joint distances with every row rescaled to sum to ~1.

    ``joints`` has shape (..., num_joints, D); the result has shape
    (..., num_joints, num_joints).
    """
    pairwise = torch.cdist(joints, joints, p=2)
    # Scale by the character height first (the factor 100 presumably converts
    # metres to centimetres -- TODO confirm), then by each row's own sum.
    # Because of the second step the height cancels out except for its
    # interaction with the 1e-8 epsilon.
    scaled = pairwise / (height * 100.0)
    return scaled / (scaled.sum(dim=-1, keepdim=True) + 1e-8)


def calculate_semantic_loss(
    joints1: torch.Tensor,
    joints2: torch.Tensor,
    height1: float = 1.75,
    height2: float = 1.75,
    attention_list: Optional[List[int]] = None,
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compare the relative joint-to-joint distance structure of two motions.

    For every frame a (num_joints x num_joints) Euclidean distance matrix is
    built for each motion, divided by ``height * 100`` and then row-normalised.
    The squared difference of the two matrices is summed over frames and
    columns, weighted per joint, divided by the number of valid frames of the
    sample (at least 1) and finally averaged over the batch.

    Args:
        joints1: Joint positions of the first motion,
            shape (batch, seq_len, num_joints, D).
        joints2: Joint positions of the second motion; must agree with
            ``joints1`` on (batch, seq_len, num_joints).
        height1: Height used to scale the distances of ``joints1``.
        height2: Height used to scale the distances of ``joints2``.
        attention_list: Joint indices to emphasise; defaults to
            ``[15, 16, 19, 20]``. See the NOTE(review) in the body: with the
            current weights this argument does not change the result.
        mask: Optional frame-validity mask with ``batch * seq_len`` elements
            (1 = valid, 0 = ignored). Defaults to all frames valid.

    Returns:
        A scalar tensor holding the batch-averaged loss. An empty batch
        yields a zero tensor.

    Raises:
        ValueError: If ``joints1`` is not 4-dimensional or the leading three
            dimensions of ``joints1`` and ``joints2`` differ.
    """
    if joints1.dim() != 4:
        raise ValueError(
            "joints1 must have shape (batch, seq_len, num_joints, D), "
            f"got {tuple(joints1.shape)}"
        )
    if joints2.dim() != 4 or joints2.shape[:3] != joints1.shape[:3]:
        raise ValueError(
            "joints1 and joints2 must agree on (batch, seq_len, num_joints), "
            f"got {tuple(joints1.shape)} and {tuple(joints2.shape)}"
        )

    device = joints1.device
    batch_size, seq_len, num_joints, _ = joints1.shape

    if attention_list is None:
        attention_list = [15, 16, 19, 20]

    joint_weights = torch.ones(num_joints, device=device)
    if attention_list:
        # NOTE(review): every weight starts at 1.0 and the attended joints are
        # also set to 1.0, so attention_list currently has no effect on the
        # loss (it only raises IndexError for out-of-range indices). Confirm
        # the intended weighting before changing it.
        joint_weights[attention_list] = 1.0

    if batch_size == 0:
        # Keep the documented return type instead of falling back to a float.
        return joints1.new_zeros(())

    if mask is None:
        mask = torch.ones(batch_size, seq_len, device=device)
    frame_mask = mask.reshape(batch_size, seq_len)

    structure_gap = (
        _row_normalized_distances(joints1, height1)
        - _row_normalized_distances(joints2, height2)
    )
    # The mask is applied before squaring, exactly as a 0/1 mask expects.
    masked_sq = (frame_mask[:, :, None, None] * structure_gap) ** 2

    per_joint = masked_sq.sum(dim=(1, 3))  # (batch, num_joints)
    per_sample = (per_joint * joint_weights).sum(dim=1)  # (batch,)

    # Guard against fully masked samples dividing by zero.
    valid_frames = torch.clamp(frame_mask.sum(dim=1), min=1)
    return (per_sample / valid_frames).mean()


def calculate_jitter_loss(
    joints: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    order: int = 2
) -> torch.Tensor:
    """Penalise temporal jitter via finite differences of joint positions.

    The joint trajectories are differenced ``order`` times along the frame
    axis, the Euclidean norm of each resulting vector is taken, entries whose
    difference window touches a masked frame are zeroed, and the result is
    averaged.

    Args:
        joints: Joint positions, shape (batch, seq_len, num_joints, D).
        mask: Optional frame-validity mask with ``batch * seq_len`` elements
            (1 = valid, 0 = ignored). Defaults to all frames valid.
        order: Number of times the frame-wise difference is applied
            (1 = velocity-like, 2 = acceleration-like, ...). Values <= 0
            leave the positions undifferenced.

    Returns:
        A scalar tensor with the averaged loss. When there is nothing to
        average (empty batch, or ``seq_len <= order``) the result is 0
        rather than NaN.
    """
    device = joints.device
    batch_size, seq_len, _, _ = joints.shape

    if mask is None:
        mask = torch.ones(batch_size, seq_len, device=device)
    window_mask = mask.reshape(batch_size, seq_len)

    diff = joints
    for _ in range(order):
        diff = diff[:, 1:] - diff[:, :-1]
        # A difference is valid only if both frames it spans are valid.
        window_mask = window_mask[:, 1:] * window_mask[:, :-1]

    magnitude = torch.norm(diff, p=2, dim=-1)  # (batch, seq_len - order, num_joints)
    masked = magnitude * window_mask.unsqueeze(-1)

    # NOTE(review): the average runs over all entries, masked ones included,
    # so padding lowers the value; calculate_semantic_loss instead divides by
    # the number of valid frames. Kept as-is to preserve existing loss scales.
    # Dividing by max(numel, 1) avoids the NaN that a mean over an empty
    # tensor would produce.
    return masked.sum() / max(masked.numel(), 1)


def semantic_loss() -> torch.Tensor:
    """Placeholder entry point that always raises ``NotImplementedError``.

    The error message points callers to ``calculate_semantic_loss``.
    """
    hint = (
        "semantic_loss() should be implemented in optimization context. "
        "Use calculate_semantic_loss(joints1, joints2, ...) instead."
    )
    raise NotImplementedError(hint)


def jitter_loss() -> torch.Tensor:
    """Placeholder entry point that always raises ``NotImplementedError``.

    The error message points callers to ``calculate_jitter_loss``.
    """
    hint = (
        "jitter_loss() should be implemented in optimization context. "
        "Use calculate_jitter_loss(joints, ...) instead."
    )
    raise NotImplementedError(hint)



if __name__ == '__main__':
    # Smoke run: random motions, one partially padded sample, print each loss.
    n_batch, n_frames, n_joints = 2, 50, 22
    motion_shape = (n_batch, n_frames, n_joints, 3)

    motion_a = torch.randn(motion_shape)
    motion_b = torch.randn(motion_shape)

    # Every frame is valid except the last five frames of sample 0.
    frame_mask = torch.ones(n_batch, n_frames)
    frame_mask[0, 45:] = 0

    semantic_value = calculate_semantic_loss(motion_a, motion_b, mask=frame_mask)
    print(f"semantic loss: {semantic_value.item():.6f}")

    # Jitter with the default difference order.
    jitter_value = calculate_jitter_loss(motion_a, mask=frame_mask)
    print(f"jitter loss: {jitter_value.item():.6f}")

    # Evaluate every order first, then report them in ascending order.
    jitter_by_order = {
        k: calculate_jitter_loss(motion_a, mask=frame_mask, order=k)
        for k in (1, 2, 3)
    }
    for k, value in jitter_by_order.items():
        print(f"order-{k} jitter loss: {value.item():.6f}")