import math
import torch
import numpy as np


def nose_tail_object_coords(node_positions, object_ids):
    """Express every node relative to the object's nose and tail.

    The "nose" is the object node with the smallest value in column 0 and
    the "tail" is the object node with the largest value in column 0.

    Args:
        node_positions: Positions of all nodes, indexable as ``[N, D]``.
        object_ids: Index (or mask) selecting the object's nodes from
            ``node_positions``.

    Returns:
        A pair ``(nose_displacement, tail_displacement)``; each has the same
        shape as ``node_positions`` and holds ``node_positions`` minus the
        nose / tail position.
    """
    obj = node_positions[object_ids]
    first_axis = obj[:, 0]

    # Reference points: extreme object nodes along the first axis.
    nose = obj[first_axis.argmin()]
    tail = obj[first_axis.argmax()]

    return node_positions - nose, node_positions - tail


def closest_object_coords(node_positions, object_ids, device='cpu'):
    """Displacement of every node to its closest object node.

    Args:
        node_positions: ``[N, D]`` tensor with the positions of all nodes.
        object_ids: Index (or mask) selecting the object's nodes from
            ``node_positions``.
        device: Device on which the pairwise distances are evaluated.

    Returns:
        A pair ``(closest_displacement, closest_object_ids)``, both on CPU:
        ``closest_displacement`` is ``[N, D]`` (node position minus the
        position of its closest object node) and ``closest_object_ids`` is
        ``[N]``, indexing into ``node_positions[object_ids]`` (i.e. the
        position within the object subset, not the global node index).
    """
    simulation_positions = node_positions.to(device)
    object_positions = node_positions[object_ids].to(device)

    # The argmin is not differentiable, so skip graph construction for the
    # (large) [M, N, D] pairwise-difference tensor.
    with torch.no_grad():
        pairwise = simulation_positions.unsqueeze(0) - object_positions.unsqueeze(1)
        closest_on_device = pairwise.norm(dim=-1).argmin(dim=0)

    # Subtract while both operands still live on `device`; mixing a CPU
    # tensor with a tensor on another device raises a device-mismatch error.
    closest_displacement = (
        simulation_positions - object_positions[closest_on_device]
    ).cpu()

    return closest_displacement, closest_on_device.cpu()


def angles_to_planes(coords):
    """Return four ``atan2`` angles per point, one per sign combination.

    Args:
        coords: Tensor indexable as ``[N, >=2]``; column 0 is used as ``x``
            and column 1 as ``y``.

    Returns:
        ``[N, 4]`` tensor whose columns are, in order,
        ``atan2(y, x)``, ``atan2(y, -x)``, ``atan2(-y, -x)``, ``atan2(-y, x)``.
    """
    x = coords[:, 0]
    y = coords[:, 1]
    neg_x = -x
    neg_y = -y

    columns = (
        torch.atan2(y, x),
        torch.atan2(y, neg_x),
        torch.atan2(neg_y, neg_x),
        torch.atan2(neg_y, x),
    )
    return torch.stack(columns, dim=1)


def sinusoidal_embedding(x: torch.Tensor,
                         num_basis: int = 8,
                         max_coord: float = 2.0,
                         spacing: float = 1.0) -> torch.Tensor:
    """
    Sin/Cos positional embedding (like Transformer PE).

    The coordinates are divided by ``spacing`` and multiplied by
    ``num_basis`` geometrically spaced frequencies that start at 1 and end at
    ``pi / (4 * max_coord / spacing)``. With ``num_basis == 1`` only the unit
    frequency is used.

    Input x: [*, 1] or [*, D] tensor of coords/distances.
    Returns: [*, 2 * num_basis * D] flattened embedding; for every input
    dimension the ``num_basis`` sines come first, then the ``num_basis``
    cosines.
    """
    # Normalize and compute frequencies
    x = x / spacing
    max_seq = max_coord / spacing
    if num_basis > 1:
        exponents = -math.log(max_seq * 4 / math.pi) / (num_basis - 1)
    else:
        # A single basis has no range to span; avoid dividing by zero and
        # fall back to the unit frequency.
        exponents = 0.0
    div_term = torch.exp(torch.arange(num_basis, device=x.device) * exponents)
    # Shape: [*, D, num_basis]
    emb = x.unsqueeze(-1) * div_term
    # Concat and flatten the last two dims: -> [*, D * 2 * num_basis]
    return torch.cat([emb.sin(), emb.cos()], dim=-1).flatten(-2, -1)


def spherical_harmonics(angle: torch.Tensor,
                           l_max: int = 3) -> torch.Tensor:
    """
    Legendre polynomials P_l(cos(angle)) for l = 0..l_max.

    The polynomials are evaluated in torch with Bonnet's recurrence
    ``(l + 1) P_{l+1}(x) = (2l + 1) x P_l(x) - l P_{l-1}(x)``, so the result
    stays on the input's device and remains differentiable w.r.t. ``angle``.

    angle: tensor of angles (e.g. theta = atan2(y, x)), any shape.
    l_max: highest polynomial degree; must be >= 0.
    Returns: tensor of shape ``angle.shape + (l_max + 1,)`` holding
    [P_0(cosθ), P_1(cosθ), ..., P_l_max(cosθ)] along the last dimension
    (so a ``[*, 1]`` input yields ``[*, 1, l_max + 1]``).
    Raises: ValueError if ``l_max`` is negative.
    """
    if l_max < 0:
        raise ValueError("l_max must be non-negative")

    cos_t = torch.cos(angle)

    # Seed the recurrence with P_0 = 1 and P_1 = x.
    polys = [torch.ones_like(cos_t)]
    if l_max >= 1:
        polys.append(cos_t)
    for degree in range(1, l_max):
        nxt = ((2 * degree + 1) * cos_t * polys[degree]
               - degree * polys[degree - 1]) / (degree + 1)
        polys.append(nxt)

    return torch.stack(polys, dim=-1)  # angle.shape + (l_max + 1,)


def compute_sdf(positions, surface_mask):
    """
    Distance from every node to its nearest boundary node.

    positions: [N,2] torch.Tensor (all nodes)
    surface_mask: [N] torch.BoolTensor (True if boundary/surface node)
    Returns: [N] tensor of shortest Euclidean distances to the boundary
    nodes (unsigned; boundary nodes themselves get 0).
    """
    surface_nodes = positions[surface_mask]
    pairwise = torch.cdist(positions, surface_nodes)  # [N, Nb]
    nearest = pairwise.min(dim=1)
    return nearest.values


def compute_sv(positions, surface_mask):
    """
    Shortest vector between every node and the boundary.

    positions: [N,2] torch.Tensor (all nodes)
    surface_mask: [N] boolean mask selecting the boundary nodes
    Returns: [N,2] tensor, node position minus the position of its closest
    boundary node.
    """
    surface_nodes = positions[surface_mask]
    # Index of the nearest boundary node for each node: [N]
    nearest = torch.argmin(torch.cdist(positions, surface_nodes), dim=1)
    return positions - surface_nodes[nearest]


def compute_did(positions, surface_mask, theta_rot=math.pi/8, theta_seg=math.pi/4, inf=4.0):
    """
    Mean distance to the boundary nodes, resolved per angular segment.

    Around every node the full circle is covered by
    ``ceil(2*pi / theta_rot)`` segments centred at ``k * theta_rot`` with
    angular width ``theta_seg``. For each segment the (clamped) distances to
    all boundary nodes whose direction falls inside the segment are averaged.

    positions: [N,2] torch.Tensor (all nodes)
    surface_mask: [N] boolean mask selecting the boundary nodes
    theta_rot: angular step between consecutive segment centres (radians)
    theta_seg: angular width of one segment (radians)
    inf: upper clamp for individual distances and the value reported for
        segments that contain no boundary node
    Returns: [S, N] tensor with S = number of segments.
    """
    boundary = positions[surface_mask]                      # [M, 2]

    # Displacement from every node i to every boundary node j.
    disp = boundary.unsqueeze(0) - positions.unsqueeze(1)   # [N, M, 2]
    dist_ij = disp.norm(dim=2).clamp_max(inf)               # [N, M]
    theta_ij = torch.atan2(disp[..., 1], disp[..., 0])      # [N, M]

    num_segments = int(math.ceil(2 * math.pi / theta_rot))
    seg_cens = torch.arange(num_segments, device=positions.device) * theta_rot  # [S]

    # Angular offset to each segment centre, wrapped into [-pi, pi).
    theta_diff = theta_ij.unsqueeze(-1) - seg_cens                    # [N, M, S]
    theta_mod = (theta_diff + math.pi) % (2 * math.pi) - math.pi
    mask = (theta_mod.abs() <= (theta_seg / 2)).float()               # [N, M, S]

    # Sum of distances and number of boundary nodes per segment.
    sum_dist = (dist_ij.unsqueeze(-1) * mask).sum(dim=1)    # [N, S]
    count = mask.sum(dim=1)                                 # [N, S]

    # Clamp only the divisor: the raw count is still needed to detect empty
    # segments, which must report `inf` rather than a mean of 0.
    mean_dist = sum_dist / count.clamp_min(1.0)
    mean_dist = mean_dist.masked_fill(count == 0, inf)

    # return [S, N]
    return mean_dist.transpose(0, 1)


def gaussian_fourier_features(x, num_frequencies=8, scale=5.0, gaussian=True):
    """
    Fourier feature encoding of coordinates.

    x: [N, D] coordinates
    num_frequencies: number of frequencies F
    scale: standard deviation of the random projection (gaussian mode only)
    gaussian: if True, project with a random matrix ``randn(F, D) * scale``
        that is re-sampled on every call (seed torch's RNG for reproducible
        features); if False, use the fixed frequencies ``2**k * pi``
        (k = 0..F-1) applied to every coordinate separately.
    Returns: encoded features, all sines followed by all cosines:
        [N, 2*F] in gaussian mode, [N, 2*F*D] otherwise.
    """
    D = x.shape[-1]
    if gaussian:
        B = torch.randn(num_frequencies, D, device=x.device) * scale  # [F, D]
    else:
        # Regular frequencies (as in NeRF)
        freqs = 2 ** torch.arange(num_frequencies, device=x.device) * math.pi  # [F]
        # eye(D) tiled F times has F*D rows (frequency-major), so each
        # frequency must be repeated D times to line up with its rows.
        row_freqs = freqs.repeat_interleave(D)[:, None]                  # [F*D, 1]
        B = torch.eye(D, device=x.device).repeat(num_frequencies, 1) * row_freqs
    x_proj = (2.0 * torch.pi * x.float()) @ B.T  # [N, rows(B)]
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
