import torch
import torch.nn as nn
from tensordict import TensorDict

from src.hyperbolic_math.src.manifolds import Hyperboloid, PoincareBall


def embed_data(
    encoder: nn.Module, data: TensorDict, manifold: PoincareBall | Hyperboloid
) -> tuple[TensorDict, TensorDict, TensorDict, TensorDict]:
    """Embed every entry of ``data`` with ``encoder`` and map it onto ``manifold``.

    For each ``key, trajectory`` pair in ``data`` the encoder output is turned into
    a tangent vector and passed to ``manifold.expmap_0``:

    * ``PoincareBall``: the encoder output itself is the tangent vector, and the
      result of ``expmap_0`` is used as both the manifold point and the Poincaré
      point.
    * ``Hyperboloid``: a zero is prepended along the last dimension of the encoder
      output to form the tangent vector; the resulting hyperboloid point is then
      converted with ``manifold.to_poincare``.

    Args:
        encoder: Module applied to each value of ``data``.
        data: TensorDict whose values are encoded one key at a time.
        manifold: Target manifold; only ``PoincareBall`` and ``Hyperboloid`` are
            supported.

    Returns:
        ``(euclidean_data, tangent_embeddings, hyp_embeddings, poincare_data)``,
        each a TensorDict with the same keys as ``data``:
        the raw encoder outputs, the tangent vectors fed to ``expmap_0``, the
        resulting points on ``manifold``, and those points in Poincaré-ball form.

    Raises:
        TypeError: If ``manifold`` is neither a ``PoincareBall`` nor a
            ``Hyperboloid``.
    """
    # Fail fast. Without this check an unsupported manifold left
    # ``poincare_embedding`` unassigned and surfaced as a confusing
    # UnboundLocalError partway through filling the outputs.
    if not isinstance(manifold, (PoincareBall, Hyperboloid)):
        raise TypeError(
            f"Unsupported manifold type {type(manifold).__name__}; "
            "expected PoincareBall or Hyperboloid."
        )

    euclidean_data = TensorDict({})
    poincare_data = TensorDict({})
    tangent_embeddings = TensorDict({})
    hyp_embeddings = TensorDict({})

    for key, trajectory in data.items():
        euclidean_embedding = encoder(trajectory)

        if isinstance(manifold, PoincareBall):
            # The encoder output is used directly as the tangent vector; the
            # expmap_0 result is already a Poincaré-ball point.
            tangent_embedding = euclidean_embedding
            hyp_embedding = manifold.expmap_0(tangent_embedding)
            poincare_embedding = hyp_embedding
        else:  # Hyperboloid (guaranteed by the check above)
            # Prepend a zero coordinate on the last axis so the vector has the
            # ambient dimension the hyperboloid's expmap_0 works in.
            # NOTE(review): presumably index 0 is the time-like coordinate, so
            # this yields a tangent vector at the origin — confirm against the
            # Hyperboloid implementation.
            tangent_embedding = torch.cat(
                [torch.zeros_like(euclidean_embedding[..., :1]), euclidean_embedding],
                dim=-1,
            )
            hyp_embedding = manifold.expmap_0(tangent_embedding)
            poincare_embedding = manifold.to_poincare(hyp_embedding)

        euclidean_data[key] = euclidean_embedding
        tangent_embeddings[key] = tangent_embedding
        hyp_embeddings[key] = hyp_embedding
        poincare_data[key] = poincare_embedding

    return euclidean_data, tangent_embeddings, hyp_embeddings, poincare_data
