"""
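OAT (Optimal Acceleration Transport) distance implementation.

A state is a ``(position, velocity)`` pair of tensors. The squared OAT
distance between states ``(x0, v0)`` and ``(x1, v1)`` is

    12 * ||x1 - x0 - (v0 + v1) / 2||^2 + ||v1 - v0||^2

with the norms taken over the last tensor dimension. ``oat_distance``
compares states elementwise; ``efficient_batch_oat_distance`` compares
every source state against every target state.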
"""

import torch
import torch.nn.functional as F
from typing import Tuple, Optional

# Positive offset added under the square root whenever a non-squared distance
# is requested: d/dx sqrt(x) is infinite at x == 0, so sqrt(x + EPS_SQRT)
# keeps the gradient finite when the squared distance is exactly zero.
EPS_SQRT: float = 1e-8

def oat_distance(z0: Tuple[torch.Tensor, torch.Tensor], 
                 z1: Tuple[torch.Tensor, torch.Tensor],
                 squared: bool = True) -> torch.Tensor:
    """Compute the OAT distance between two (position, velocity) states.

    The squared distance is
    ``12 * ||x1 - x0 - (v0 + v1) / 2||^2 + ||v1 - v0||^2``, with both norms
    reduced over the last dimension. Leading dimensions are combined by
    ordinary elementwise broadcasting, so states are compared pairwise
    (not all-to-all; see ``efficient_batch_oat_distance`` for that).

    Args:
        z0: ``(x0, v0)`` source position and velocity.
        z1: ``(x1, v1)`` target position and velocity. If ``x1`` differs from
            ``x0`` in device or dtype, both ``x1`` and ``v1`` are converted to
            ``x0``'s device and dtype. ``v0`` is used as given.
        squared: If True (default) return the squared distance. Otherwise
            return ``sqrt(d^2 + EPS_SQRT)``, so a zero squared distance maps
            to ``sqrt(EPS_SQRT)`` rather than exactly 0.

    Returns:
        Tensor of the broadcast input shape with the last dimension reduced.
    """
    src_pos, src_vel = z0
    dst_pos, dst_vel = z1

    mismatch = src_pos.device != dst_pos.device or src_pos.dtype != dst_pos.dtype
    if mismatch:
        # Bring the target state onto the source position's device/dtype.
        dst_pos = dst_pos.to(src_pos.device, src_pos.dtype)
        dst_vel = dst_vel.to(src_pos.device, src_pos.dtype)

    # Deviation of the displacement from motion at the mean of both velocities.
    residual = (dst_pos - src_pos) - 0.5 * (dst_vel + src_vel)

    curvature = torch.square(residual).sum(dim=-1)
    impulse = torch.square(dst_vel - src_vel).sum(dim=-1)

    dist_sq = 12.0 * curvature + impulse
    return dist_sq if squared else torch.sqrt(dist_sq + EPS_SQRT)


def efficient_batch_oat_distance(z0_batch: Tuple[torch.Tensor, torch.Tensor],
                                z1_batch: Tuple[torch.Tensor, torch.Tensor],
                                squared: bool = True) -> torch.Tensor:
    """Compute all-pairs OAT distances between two sets of states.

    Uses the rearrangement

        x1 - x0 - (v0 + v1) / 2 == (x1 - v1 / 2) - (x0 + v0 / 2)

    so the curvature term becomes a plain pairwise squared Euclidean distance
    between ``A = x0 + v0/2`` and ``B = x1 - v1/2``. Both terms can then be
    evaluated with ``torch.cdist`` instead of materialising every pair.
    Entry ``[i, j]`` matches the pointwise OAT formula for source state ``i``
    and target state ``j`` (up to floating-point rounding).

    Args:
        z0_batch: ``(x0, v0)`` source positions and velocities, each
            presumably of shape ``(N, d)`` (``torch.cdist`` also accepts
            leading batch dimensions).
        z1_batch: ``(x1, v1)`` target positions and velocities, presumably
            of shape ``(M, d)``.
        squared: If True (default) return squared distances. Otherwise return
            ``sqrt(d^2 + EPS_SQRT)``, so a zero squared distance maps to
            ``sqrt(EPS_SQRT)`` rather than exactly 0.

    Returns:
        Pairwise distances of shape ``(N, M)`` (for 2-D inputs), on ``x0``'s
        device with ``x0``'s dtype.
    """
    x0, v0 = z0_batch
    x1, v1 = z1_batch

    # torch.cdist needs both operands on the same device with the same dtype.
    # Align every tensor to x0, matching how oat_distance aligns its second
    # state; .to() returns the tensor unchanged when nothing needs converting.
    device, dtype = x0.device, x0.dtype
    v0 = v0.to(device=device, dtype=dtype)
    x1 = x1.to(device=device, dtype=dtype)
    v1 = v1.to(device=device, dtype=dtype)

    # Impulse term: ||v1 - v0||^2 for every (source, target) pair.
    impulse_terms = torch.cdist(v0, v1, p=2.0).pow(2)

    # Curvature term: 12 * ||B - A||^2 with A = x0 + v0/2 and B = x1 - v1/2.
    # NOTE(review): cdist's default compute_mode may switch to a matmul-based
    # path for larger inputs, which is less accurate for near-zero distances.
    a_src = x0 + 0.5 * v0
    b_tgt = x1 - 0.5 * v1
    curvature_terms = 12.0 * torch.cdist(a_src, b_tgt, p=2.0).pow(2)

    # Both terms are squares of non-negative cdist outputs, so the sum is
    # already >= 0 and no clamping is required.
    oat_dist_sq = curvature_terms + impulse_terms

    if squared:
        return oat_dist_sq
    return torch.sqrt(oat_dist_sq + EPS_SQRT)

