import torch
from typing import Optional
from torchmetrics import Metric


def calculate_jitter_metric(
    motion: torch.Tensor,
    fps: float = 20.0,
    order: int = 3,
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Compute a masked finite-difference jitter score for joint trajectories.

    The sequence is differenced ``order`` times along the time axis, the
    Euclidean norm of each resulting per-joint vector is scaled by
    ``fps ** order`` (turning per-frame differences into per-second units),
    and the result is averaged over the *valid* entries of each sequence and
    then over the sequences of the batch.

    Args:
        motion: Joint positions shaped ``(batch, seq_len, num_joints, 3)`` or
            ``(seq_len, num_joints, 3)`` for a single sequence.
        fps: Frame rate used for time normalisation.
        order: Number of successive finite differences (3 gives jerk).
        mask: Optional per-frame validity mask shaped ``(batch, seq_len)`` (or
            ``(seq_len,)`` together with an unbatched ``motion``); non-zero
            marks a valid frame. When omitted, every frame is valid.

    Returns:
        A 0-dim tensor with the mean jitter. Sequences with no valid
        difference frame (fully masked, or ``seq_len <= order``) are left out
        of the batch average; if no sequence contributes, the result is 0
        rather than NaN.

    Raises:
        ValueError: If ``motion`` is neither 3- nor 4-dimensional.
    """
    if motion.dim() == 3:
        # Promote a single sequence (and its mask) to a batch of one.
        motion = motion.unsqueeze(0)
        if mask is not None and mask.dim() == 1:
            mask = mask.unsqueeze(0)
    if motion.dim() != 4:
        raise ValueError(
            "motion must have shape (batch, seq_len, num_joints, 3) or "
            f"(seq_len, num_joints, 3), got {tuple(motion.shape)}"
        )

    batch_size, seq_len, num_joints, _ = motion.shape

    if mask is None:
        valid = torch.ones(batch_size, seq_len, device=motion.device)
    else:
        # Keep the caller's dtype, but make sure both operands share a device.
        valid = mask.to(motion.device)

    diff = motion
    for _ in range(order):
        diff = diff[:, 1:] - diff[:, :-1]
        # A difference frame is valid only if both frames it spans are valid.
        valid = valid[:, 1:] * valid[:, :-1]

    diff_norm = torch.norm(diff, p=2, dim=-1)  # (batch, seq_len-order, num_joints)
    scaled = diff_norm * (fps ** order)

    # Masked mean per sequence: divide by the number of valid entries, not by
    # the padded length, so padding does not dilute the score.
    per_seq_sum = (scaled * valid.unsqueeze(-1)).sum(dim=(1, 2))
    per_seq_count = valid.sum(dim=1) * num_joints
    has_valid = per_seq_count > 0
    safe_count = torch.where(
        has_valid, per_seq_count, torch.ones_like(per_seq_count)
    )
    per_seq_jitter = per_seq_sum / safe_count

    # Average only over sequences that contributed at least one valid entry;
    # the clamp keeps the all-invalid case at 0 instead of dividing by zero.
    num_valid_seqs = has_valid.sum().clamp(min=1)
    return per_seq_jitter.sum() / num_valid_seqs


class JitterMetric(Metric):
    """Accumulate jitter of generated and reference joint sequences.

    Each ``update`` scores one batch with ``calculate_jitter_metric`` for both
    the generated and the reference joints, and adds the batch score weighted
    by the batch size to a running sum. ``compute`` returns the sums divided
    by the total number of samples seen.
    """

    def __init__(self, fps: float = 20.0, order: int = 3) -> None:
        """Create the metric.

        Args:
            fps: Frame rate forwarded to ``calculate_jitter_metric``.
            order: Finite-difference order forwarded to
                ``calculate_jitter_metric``.
        """
        super().__init__()

        self.fps = fps
        self.order = order

        # Running sums are reduced with "sum" across processes.
        self.add_state("jitter_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("gt_jitter_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        # Names of the values reported by ``compute``.
        self.metrics = ["Jitter", "gt_Jitter"]

    def update(self, joints_rst: torch.Tensor, joints_ref: torch.Tensor, lengths: list[int]) -> None:
        """Add one batch of generated and reference joints to the running sums.

        Args:
            joints_rst: Generated joints; dimension 0 is the sample axis and
                dimension 1 the frame axis.
            joints_ref: Reference joints, scored with the same frame mask as
                ``joints_rst``.
            lengths: For each sample, the number of leading frames that are
                valid; the remaining frames are masked out.
        """
        num_samples = joints_rst.shape[0]
        if num_samples == 0:
            # Nothing to accumulate; also avoids scoring an empty batch.
            return
        self.count += num_samples

        # Per-frame validity mask: 1 for the first ``length`` frames of a sample.
        device = joints_rst.device
        max_len = joints_rst.shape[1]
        mask = torch.zeros(num_samples, max_len, device=device)
        for i, length in enumerate(lengths):
            mask[i, :length] = 1.0

        jitter_value = calculate_jitter_metric(
            joints_rst,
            fps=self.fps,
            order=self.order,
            mask=mask
        )

        gt_jitter_value = calculate_jitter_metric(
            joints_ref,
            fps=self.fps,
            order=self.order,
            mask=mask
        )

        # The states may live on a different device than the inputs.
        jitter_value = jitter_value.to(self.jitter_sum.device)
        gt_jitter_value = gt_jitter_value.to(self.gt_jitter_sum.device)

        # Weight the batch score by the batch size so that ``compute`` yields
        # a per-sample average across batches of different sizes.
        self.jitter_sum += jitter_value * num_samples
        self.gt_jitter_sum += gt_jitter_value * num_samples

    def compute(self) -> dict:
        """Return the per-sample average jitter for generated and reference joints.

        Returns:
            A dict with keys ``"Jitter"`` and ``"gt_Jitter"``. When no sample
            has been accumulated both values are zero tensors with the same
            device and dtype as the metric states.
        """
        count = self.count.item()
        if count == 0:
            # Match the device/dtype of the non-empty result.
            return {
                "Jitter": torch.zeros_like(self.jitter_sum),
                "gt_Jitter": torch.zeros_like(self.gt_jitter_sum),
            }

        avg_jitter = self.jitter_sum / count
        avg_gt_jitter = self.gt_jitter_sum / count
        return {"Jitter": avg_jitter, "gt_Jitter": avg_gt_jitter}



def compute_jitter_score(
    motion: torch.Tensor,
    fps: float = 20.0,
    order: int = 3,
    mask: Optional[torch.Tensor] = None
) -> float:
    """Return the jitter of ``motion`` as a plain Python float.

    Thin convenience wrapper around ``calculate_jitter_metric`` that runs
    without gradient tracking.

    Args:
        motion: Joint positions, forwarded unchanged.
        fps: Frame rate, forwarded unchanged.
        order: Finite-difference order, forwarded unchanged.
        mask: Optional frame-validity mask, forwarded unchanged.

    Returns:
        The jitter score as a ``float``.
    """
    with torch.no_grad():
        jitter_value = calculate_jitter_metric(motion, fps=fps, order=order, mask=mask)
    # float() accepts both a 0-dim tensor and a plain Python number, so an
    # empty batch (for which the helper may return a bare float) does not fail.
    return float(jitter_value)



if __name__ == '__main__':
    # Small self-check: a random walk should score lower than the same walk
    # with extra per-frame noise added.
    n_seqs, n_frames, n_joints = 2, 50, 22
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Random walk: each frame is the previous one plus a small step.
    smooth = torch.randn(n_seqs, n_frames, n_joints, 3, device=dev)
    for t in range(1, n_frames):
        step = 0.1 * torch.randn_like(smooth[:, t])
        smooth[:, t] = smooth[:, t - 1] + step

    # Same walk with independent noise on every frame.
    jittery = smooth + 0.5 * torch.randn_like(smooth)

    # Mark the last five frames of the first sequence as invalid.
    valid_frames = torch.ones(n_seqs, n_frames, device=dev)
    valid_frames[0, 45:] = 0

    score_smooth = compute_jitter_score(smooth, mask=valid_frames)
    score_jittery = compute_jitter_score(jittery, mask=valid_frames)

    print(f"jitter score of smooth sequence: {score_smooth:.6f}")
    print(f"jitter score of jittery sequence: {score_jittery:.6f}")
    print(f"jittery sequence should be larger: {score_jittery > score_smooth}")

    # Compare difference orders on the noisy sequence.
    for diff_order in (1, 2, 3):
        score = compute_jitter_score(jittery, order=diff_order, mask=valid_frames)
        print(f"order {diff_order} jitter score: {score:.6f}")

    # Exercise the metric class, using each sequence as its own reference
    # and treating every frame as valid.
    jitter_metric = JitterMetric(fps=20.0, order=3).to(dev)
    full_lengths = [n_frames] * n_seqs

    jitter_metric.update(smooth, smooth.clone(), full_lengths)
    jitter_metric.update(jittery, jittery.clone(), list(full_lengths))
    summary = jitter_metric.compute()
    print("Results using JitterMetric:")
    print(f"  generated motion jitter: {summary['Jitter'].item():.6f}")
    print(f"  gt motion jitter: {summary['gt_Jitter'].item():.6f}")