import torch
import torch.nn.functional as F
import numpy as np
import math

def dummy_cost(*inputs, **kwargs):
    """
    Placeholder cost that is identically zero for every batch element.

    Useful for testing, or wherever a cost callable is required but no cost
    should actually be applied.

    Args:
        inputs: One or more inputs. The first must be a tensor with at least
            3 dimensions (the error message names the layout (B, H, d)); its
            leading dimension sets the output length and its dtype/device
            set those of the output.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        A tensor of zeros with shape (B,), where B = inputs[0].shape[0].
        When gradient recording is enabled, the result is connected to the
        autograd graph of every tensor input that requires grad, so
        differentiating it with respect to those inputs yields zeros instead
        of an "unused tensor" error. If no input can be attached, a fresh
        leaf tensor with requires_grad=True is returned.

    Raises:
        ValueError: If no inputs are given, or the first input has fewer
            than 3 dimensions.
    """
    if not inputs:
        raise ValueError("No input tensors provided to dummy_cost.")
    x: torch.Tensor = inputs[0]
    if x.dim() < 3:
        raise ValueError(f"Input tensor x must be at least 3D (B, H, d). Got {x.dim()}D.")

    zero_cost = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)

    if torch.is_grad_enabled():
        for t in inputs:
            if isinstance(t, torch.Tensor) and t.requires_grad:
                # The sum of an empty slice is exactly 0 even when t holds
                # NaN/inf (unlike t * 0), yet it still depends on t in the
                # autograd graph, so the gradient w.r.t. t is all zeros.
                zero_cost = zero_cost + t.flatten()[:0].sum().to(zero_cost)

    if not zero_cost.requires_grad:
        # Nothing could be attached: fall back to a leaf tensor that requires
        # grad so that calling .backward() on the cost still works.
        zero_cost.requires_grad_(True)
    return zero_cost



def sum_log_l2_cost(
    *inputs: torch.Tensor,
    alpha: float,  # strictly positive offset inside the log
    **kwargs,
) -> torch.Tensor:
    """
    Pairwise negative-log distance penalty, evaluated for all pairs at once.

    For every unordered pair (i, j) of inputs with i < j, and every position
    along the second-to-last axis, the term

        -log(alpha + ||x_i - x_j||_2)

    is computed, with the norm taken over the last axis. The terms are
    summed over pairs and then over the second-to-last axis.

    Args:
        inputs: Two or more tensors of identical shape with at least three
            dimensions (documented layout: (B, H, d)).
        alpha: Constant > 0 added to the distance inside the log, so that a
            zero distance does not produce log(0).
        **kwargs: Accepted and ignored.

    Returns:
        Tensor with the last two axes of the inputs reduced away, i.e.
        shape (B,) for (B, H, d) inputs.

    Raises:
        AssertionError: If fewer than two inputs are given, alpha is not
            positive, or the first input has fewer than three dimensions.
        ValueError: If a later input's shape differs from the first one's.
    """
    assert len(inputs) >= 2, "At least two input tensors are required."
    assert alpha > 0, "Alpha must be > 0."
    reference = inputs[0]
    assert reference.dim() >= 3, f"Input tensor must be at least 3D (B, H, d). Got {reference.dim()}D."

    # Every remaining input must match the first one exactly.
    if any(other.shape != reference.shape or other.dim() < 3 for other in inputs[1:]):
        raise ValueError("Expected inputs of shape (..., H, d).")

    num_inputs = len(inputs)
    # New axis of length N just before the feature axis: (..., H, N, d)
    stacked = torch.stack(inputs, dim=-2)

    # Row 0 / row 1 hold the first / second member of each pair with i < j.
    pair_index = torch.triu_indices(
        num_inputs, num_inputs, offset=1, device=stacked.device
    )
    first, second = pair_index[0], pair_index[1]

    # (..., H, P, d) differences, P = N * (N - 1) // 2
    differences = stacked[..., first, :] - stacked[..., second, :]
    distances = torch.linalg.norm(differences, dim=-1)  # (..., H, P)

    penalties = -torch.log(alpha + distances)  # (..., H, P)
    summed_over_pairs = penalties.sum(dim=-1)  # (..., H)
    return summed_over_pairs.sum(dim=-1)  # (...,)

def sum_log_l2_cost_loop(
    *inputs: torch.Tensor,
    alpha: float,  # positive constant that keeps the log finite at zero distance
    **kwargs,
) -> torch.Tensor:
    """
    Loop-based negative-log penalty on pairwise L2 distances.

    For every unordered pair (i, j) of inputs with i < j, and every position
    along the second-to-last axis, the cost term is

        -log(alpha + ||x_i - x_j||_2)

    where the (non-squared) L2 norm is taken over the last axis. The term
    grows large as two inputs approach each other and decreases slowly as
    they move apart. This is the same quantity that sum_log_l2_cost
    computes, evaluated here with explicit Python loops over the pairs
    (results may differ by floating-point summation order).

    Args:
        inputs: Two or more tensors of identical shape with at least three
            dimensions (documented layout: (B, H, d)).
        alpha: Constant > 0 added to the distance inside the log to avoid
            the singularity at zero distance.
        **kwargs: Accepted and ignored.

    Returns:
        Tensor with the last two axes of the inputs reduced away, i.e.
        shape (B,) for (B, H, d) inputs: the total penalty summed over all
        pairs and over the second-to-last axis.

    Raises:
        AssertionError: If fewer than two inputs are given, alpha is not
            positive, or the first input has fewer than three dimensions.
        ValueError: If a later input's shape differs from the first one's.
    """
    assert len(inputs) >= 2, "At least two input tensors are required for sum_log_l2_cost."
    assert alpha > 0, "Alpha must be a positive constant to avoid singularity in log."
    x = inputs[0]
    assert x.dim() >= 3, f"Input tensor x must be at least 3D (B, H, d). Got {x.dim()}D."
    for y in inputs[1:]:
        if x.shape != y.shape or y.dim() < 3:
            raise ValueError("Expected inputs of shape (..., H, d).")

    # Accumulate the per-step penalty (..., H) over every pair i < j, in the
    # order (0, 1), (0, 2), ..., (1, 2), ...
    per_step_costs = 0.0
    for i, first in enumerate(inputs[:-1]):
        for second in inputs[i + 1:]:
            distance = torch.linalg.norm(first - second, dim=-1)
            per_step_costs = per_step_costs - torch.log(alpha + distance)

    # Sum over the second-to-last (time) axis -> (...,)
    return per_step_costs.sum(dim=-1)



def hinge_sqr_l2_cost(
    *inputs: torch.Tensor,
    active_range: float = 1.0,  # distance below which a pair is penalized
    **kwargs,
) -> torch.Tensor:
    """
    Squared-hinge penalty on pairwise L2 distances, computed for all pairs
    at once.

    For every unordered pair (i, j) of inputs with i < j, and every position
    along the second-to-last axis, the term

        relu(active_range - ||x_i - x_j||_2) ** 2

    is computed, with the norm taken over the last axis. Pairs whose
    distance is at least active_range contribute zero. The terms are summed
    over pairs and then over the second-to-last axis.

    Args:
        inputs: Two or more tensors of identical shape with at least three
            dimensions (documented layout: (B, H, d)).
        active_range: Positive radius inside which the penalty is active.
        **kwargs: Accepted and ignored.

    Returns:
        Tensor with the last two axes of the inputs reduced away, i.e.
        shape (B,) for (B, H, d) inputs.

    Raises:
        AssertionError: If fewer than two inputs are given, active_range is
            not positive, or the first input has fewer than three dimensions.
        ValueError: If a later input's shape differs from the first one's.
    """
    assert len(inputs) >= 2, "At least two input tensors are required."
    assert active_range > 0, "active_range must be > 0."
    reference = inputs[0]
    assert reference.dim() >= 3, f"Input tensor must be at least 3D (B, H, d). Got {reference.dim()}D."

    # Every remaining input must match the first one exactly.
    if any(other.shape != reference.shape or other.dim() < 3 for other in inputs[1:]):
        raise ValueError("Expected inputs of shape (..., H, d).")

    count = len(inputs)
    stacked = torch.stack(inputs, dim=-2)  # (..., H, N, d)

    # Index vectors of the first / second member of each pair with i < j.
    rows, cols = torch.triu_indices(
        count, count, offset=1, device=stacked.device
    ).unbind(0)

    separation = torch.linalg.norm(
        stacked[..., rows, :] - stacked[..., cols, :], dim=-1
    )  # (..., H, P)

    # Only the shortfall below active_range is penalized, quadratically.
    shortfall = F.relu(active_range - separation)  # (..., H, P)
    return shortfall.pow(2).sum(dim=-1).sum(dim=-1)  # (...,)



# Lookup table from a cost's string name to the callable implementing it.
cost_registry = dict(
    dummy=dummy_cost,
    sum_log_l2=sum_log_l2_cost,
    hinge_sqr_l2=hinge_sqr_l2_cost,
)
