import torch

from src.features.features import (
    nose_tail_object_coords,
    closest_object_coords,
    sinusoidal_embedding,
    angles_to_planes,
    spherical_harmonics,
    compute_did
)


# Name -> builder lookup tables, filled by ``register_node`` / ``register_edge``.
_node_registry, _edge_registry = {}, {}


def _register(registry, name):
    """Return a decorator that stores its function in ``registry`` under ``name``.

    The decorated function is returned unchanged. Registering a name that is
    already present replaces the previous entry.
    """
    def _wrap(fn):
        registry[name] = fn
        return fn
    return _wrap


def _lookup(registry, name, kind):
    """Return ``registry[name]`` or raise a ``KeyError`` naming the known builders.

    Args:
        registry: one of the module-level builder dicts.
        name: builder name to look up.
        kind: human-readable registry label used in the error message.

    Raises:
        KeyError: if ``name`` is not registered; the message lists what is.
    """
    try:
        return registry[name]
    except KeyError:
        available = ", ".join(repr(k) for k in sorted(registry, key=str)) or "none"
        raise KeyError(
            f"unknown {kind} builder {name!r}; registered: {available}"
        ) from None


def register_node(name):
    """Decorator: register a node-feature builder under ``name``."""
    return _register(_node_registry, name)


def register_edge(name):
    """Decorator: register an edge-feature builder under ``name``."""
    return _register(_edge_registry, name)


def get_node_builder(name):
    """Return the node builder registered as ``name`` (``KeyError`` if unknown)."""
    return _lookup(_node_registry, name, "node")


def get_edge_builder(name):
    """Return the edge builder registered as ``name`` (``KeyError`` if unknown)."""
    return _lookup(_edge_registry, name, "edge")


@register_node("basic_nodes")
def basic_nodes(pos, tags, dist, sca, ctx):
    """Concatenate the minimal per-node features along dim 1.

    Columns are ``pos``, ``tags``, ``dist`` and then ``sca`` if it holds any
    elements. ``ctx`` is accepted so the signature matches the other node
    builders; it is not read here.
    """
    parts = (pos, tags, dist, sca) if sca.numel() else (pos, tags, dist)
    return torch.cat(parts, dim=1)


@register_node("rich_nodes")
def rich_nodes(pos, tags, dist, sca, ctx):
    """Build the extended per-node feature matrix.

    Column blocks, concatenated along dim 1 in this order:
    ``pos``, ``tags``, ``dist``; nose and tail object coordinates; closest
    object coordinates; sinusoidal embeddings of ``pos`` and ``dist``; the
    ``angles_to_planes`` output; spherical harmonics (``l_max=4``, first
    column dropped) of the first two angle columns; the ``compute_did``
    features; and finally ``sca`` when it is non-empty.

    Args:
        pos: node-position tensor; its ``.device`` is forwarded to
            ``closest_object_coords``.
        tags: per-node tensor, concatenated as-is.
        dist: per-node tensor, concatenated as-is and also embedded.
        sca: optional extra features; skipped when ``sca.numel() == 0``.
        ctx: mapping that must provide ``"obj_ids"``.

    Returns:
        The concatenated feature tensor.
    """
    obj_ids = ctx["obj_ids"]
    feats = [pos, tags, dist]

    nose, tail = nose_tail_object_coords(pos, obj_ids)
    feats += [nose, tail]

    close, _ = closest_object_coords(pos, obj_ids, device=pos.device)
    feats.append(close)

    feats.append(sinusoidal_embedding(pos, num_basis=6, max_coord=6, spacing=1e-3))
    feats.append(sinusoidal_embedding(dist, num_basis=6, max_coord=6, spacing=1e-3))

    ang = angles_to_planes(pos)
    feats.append(ang)
    # Only the first two angle columns get harmonic features; column 0 of each
    # expansion is dropped (presumably the constant l=0 term - TODO confirm).
    sph = [spherical_harmonics(ang[:, j], l_max=4)[:, 1:] for j in range(2)]
    feats.append(torch.cat(sph, dim=1))

    # compute_did's dim 0 indexes the output columns (did[j] -> column j), so
    # swap dims 0 and 1. This equals the former per-row unsqueeze(1) + cat,
    # without the Python loop, and no longer fails when compute_did returns
    # zero rows (torch.cat of an empty list raises).
    feats.append(compute_did(pos, obj_ids).transpose(0, 1))

    if sca.numel() > 0:
        feats.append(sca)
    return torch.cat(feats, dim=1)


def _gaussian_edge_weights(pos, src, dst, ctx):
    """Return ``(w, diffs)`` for the edges ``src -> dst``.

    ``diffs = pos[src] - pos[dst]`` and
    ``w = exp(-0.5 * ||diffs||^2 / (10 * ctx["args"].bandwidth))``, with ``w``
    keeping a trailing size-1 column dim. Shared by both edge builders so
    their kernels cannot drift apart.
    """
    diffs = pos[src] - pos[dst]
    sqd = (diffs * diffs).sum(dim=1, keepdim=True)
    # NOTE(review): the configured bandwidth is scaled by 10 before use; the
    # reason is not visible here - confirm against the config definition.
    bw = ctx["args"].bandwidth * 10.0
    w = torch.exp(-0.5 * sqd / bw)
    return w, diffs


@register_edge("basic_edges")
def gauss_diffs(pos, src, dst, ctx):
    """Edge features: Gaussian weight followed by the raw position difference.

    Returns ``torch.cat([w, diffs], dim=1)`` (see ``_gaussian_edge_weights``).
    """
    w, diffs = _gaussian_edge_weights(pos, src, dst, ctx)
    return torch.cat([w, diffs], dim=1)


@register_edge("rich_edges")
def rich_edges(pos, src, dst, ctx):
    """Extended edge features.

    Column blocks, in order: the Gaussian weight ``w``; a sinusoidal embedding
    of ``w``; the raw position difference ``diffs``; and spherical harmonics
    (``l_max=4``, first column dropped) of the first two
    ``angles_to_planes(diffs)`` columns.
    """
    w, diffs = _gaussian_edge_weights(pos, src, dst, ctx)
    sc_e = sinusoidal_embedding(w, num_basis=6, max_coord=6, spacing=1e-3)
    ang_e = angles_to_planes(diffs)
    sph_e = torch.cat(
        [spherical_harmonics(ang_e[:, j], l_max=4)[:, 1:] for j in range(2)],
        dim=1,
    )
    return torch.cat([w, sc_e, diffs, sph_e], dim=1)
