# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor

from cdnp.model.flow_matching.utils.manifolds import Manifold


class Sphere(Manifold):
    """Represents the unit hypersphere embedded in :math:`R^D`.

    Points and tangent vectors are stored along the last tensor dimension;
    every method reduces over ``dim=-1`` and broadcasts over leading dimensions.
    """

    # Per-dtype threshold below which a norm/distance is treated as zero.
    # Only float32 and float64 are listed; other dtypes raise ``KeyError``.
    EPS = {torch.float32: 1e-4, torch.float64: 1e-7}

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        """Exponential map: move from ``x`` along the tangent vector ``u``.

        Computes ``x * cos(|u|) + u * sin(|u|) / |u|``. When ``|u|`` is not
        larger than ``EPS`` for the dtype, falls back to the retraction
        ``projx(x + u)`` instead.

        Args:
            x: base point(s); presumably unit-norm along the last dim (not checked).
            u: tangent vector(s) at ``x``; presumably orthogonal to ``x`` (not checked).

        Returns:
            Tensor with the broadcast shape of ``x`` and ``u``.
        """
        eps = self.EPS[u.dtype]
        norm_u = u.norm(dim=-1, keepdim=True)
        # Clamp the denominator: torch.where evaluates both branches, and a
        # 0/0 in the unselected branch would otherwise yield NaN gradients
        # for ``u`` even though the forward value is taken from ``retr``.
        # Forward values are unchanged because ``exp`` is only selected
        # where ``norm_u > eps``.
        safe_norm = norm_u.clamp_min(eps)
        exp = x * torch.cos(norm_u) + u * torch.sin(norm_u) / safe_norm
        retr = self.projx(x + u)
        cond = norm_u > eps

        return torch.where(cond, exp, retr)

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        """Logarithmic map: tangent vector at ``x`` pointing toward ``y``.

        Projects ``y - x`` onto the tangent space at ``x`` and rescales it to
        have length ``dist(x, y)``. When the distance is not larger than
        ``EPS`` for the dtype, the projected difference is returned unscaled.

        Args:
            x: base point(s).
            y: target point(s).

        Returns:
            Tangent vector(s) at ``x``.
        """
        eps = self.EPS[x.dtype]
        u = self.proju(x, y - x)
        dist = self.dist(x, y, keepdim=True)
        cond = dist.gt(eps)
        # clamp_min keeps the division finite when the projected difference
        # vanishes (e.g. when ``y`` is parallel to ``x``).
        result = torch.where(
            cond,
            u * dist / u.norm(dim=-1, keepdim=True).clamp_min(eps),
            u,
        )
        return result

    def projx(self, x: Tensor) -> Tensor:
        """Project ``x`` onto the sphere by dividing by its norm along the last dim.

        A zero vector has no defined projection; the division then produces
        non-finite values.
        """
        return x / x.norm(dim=-1, keepdim=True)

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        Subtracts from ``u`` its component along ``x``; this is an orthogonal
        projection when ``x`` has unit norm.
        """
        return u - (x * u).sum(dim=-1, keepdim=True) * x

    def dist(self, x: Tensor, y: Tensor, *, keepdim: bool = False) -> Tensor:
        """Geodesic distance between ``x`` and ``y``: ``acos(<x, y>)``.

        Args:
            x: first point(s).
            y: second point(s).
            keepdim: if True, keep the reduced last dimension with size 1.

        Returns:
            Angle(s) in ``[0, pi]``.
        """
        inner = (x * y).sum(-1, keepdim=keepdim)
        # Round-off can push the inner product of (near-)identical or
        # (near-)antipodal unit vectors slightly outside [-1, 1], where acos
        # returns NaN. Clamp to the valid domain first.
        return torch.acos(inner.clamp(-1.0, 1.0))
