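# Adapted from geoopt's Stiefel manifold, with the index order swapped: points
# here have orthonormal rows (X X^T = I) instead of orthonormal columns.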

import torch
from geoopt.manifolds.base import Manifold
from geoopt import linalg
from geoopt.utils import size2shape
from geoopt.tensor import ManifoldTensor


# Names exported by ``from <module> import *``.
__all__ = ["StiefelT", "EuclideanStiefelT"]


# Shared docstring fragment describing the manifold: real m x n matrices with
# orthonormal rows (X X^T = I), which requires n >= m. It is interpolated via
# ``str.format`` into the ``__doc__`` of the manifold classes defined below.
_stiefel_doc = r"""
    Manifold induced by the following matrix constraint:

    .. math::

        X X^\top = I\\
        X \in \mathrm{R}^{m\times n}\\
        n \ge m
"""


class StiefelT(Manifold):
    __doc__ = r"""
    {}
    See Also
    --------
    :class:`EuclideanStiefelT`
    """.format(
        _stiefel_doc
    )
    ndim = 2

    def __new__(cls):
        # Instantiating the base ``StiefelT`` directly yields the
        # Euclidean-metric implementation; subclasses construct themselves.
        if cls is StiefelT:
            return super().__new__(EuclideanStiefelT)
        else:
            return super().__new__(cls)

    def _check_shape(
        self, shape: torch.Size, name: str
    ) -> tuple[bool, str | None] | bool:
        """Check that ``shape`` describes a square or wide matrix.

        Rows can only be orthonormal when there are no more rows than columns,
        so ``shape[-2] <= shape[-1]`` is required on top of the base checks.

        Returns
        -------
        tuple
            ``(True, None)`` if the shape is valid, otherwise ``(False, reason)``.
        """
        ok, reason = super()._check_shape(shape, name)
        if not ok:
            return False, reason
        if shape[-2] > shape[-1]:
            return (
                False,
                "`{}` should have shape[-2] <= shape[-1], got {} </= {}".format(
                    name, shape[-2], shape[-1]
                ),
            )
        return True, None

    def _check_point_on_manifold(
        self, x: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> tuple[bool, str | None] | bool:
        """Check the row-orthonormality constraint ``X X^T = I`` within tolerance.

        Returns
        -------
        tuple
            ``(True, None)`` if the constraint holds, otherwise ``(False, reason)``.
        """
        xxt = x @ x.transpose(-1, -2)
        # Subtract the identity in place on the diagonal instead of
        # materialising an identity matrix (less memory).
        idx = torch.arange(x.shape[-2], device=x.device)
        xxt[..., idx, idx] -= 1
        ok = torch.allclose(xxt, xxt.new_zeros(1), atol=atol, rtol=rtol)
        if not ok:
            return False, "`X X^T != I` with atol={}, rtol={}".format(atol, rtol)
        return True, None

    def _check_vector_on_tangent(
        self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5
    ) -> tuple[bool, str | None] | bool:
        """Check that ``u`` lies in the tangent space at ``x``.

        Differentiating ``X X^T = I`` gives the tangent condition
        ``U X^T + X U^T = 0``.

        Returns
        -------
        tuple
            ``(True, None)`` if the condition holds, otherwise ``(False, reason)``.
        """
        diff = u @ x.transpose(-1, -2) + x @ u.transpose(-1, -2)
        ok = torch.allclose(diff, diff.new_zeros(1), atol=atol, rtol=rtol)
        if not ok:
            return False, "`u x^T + x u^T != 0` with atol={}, rtol={}".format(
                atol, rtol
            )
        return True, None

    def projx(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` onto the manifold via the orthonormal factor of its thin SVD.

        The SVD is taken of ``x^T`` (column convention) and the product of its
        left and right factors is transposed back to rows.
        """
        # NOTE(review): assumes geoopt's ``linalg.svd`` returns V^H as its third
        # output (torch.linalg.svd convention), so U @ V^H is formed directly.
        U, _, V = linalg.svd(x.transpose(-1, -2), full_matrices=False)
        return torch.einsum("...ik,...kj->...ij", U, V).transpose(-1, -2)

    def random_naive(self, *size, dtype=None, device=None) -> torch.Tensor:
        """
        Naive approach to get random matrix on Stiefel manifold.

        A helper function to sample a random point on the Stiefel manifold.
        The measure is non-uniform for this method, but fast to compute.

        Parameters
        ----------
        size : shape
            the desired output shape
        dtype : torch.dtype
            desired dtype
        device : torch.device
            desired device

        Returns
        -------
        ManifoldTensor
            random point on Stiefel manifold
        """
        self._assert_check_shape(size2shape(*size), "x")
        tens = torch.randn(*size, device=device, dtype=dtype)
        # QR of the transpose gives orthonormal columns; transposing back
        # yields orthonormal rows as required by this manifold.
        q = linalg.qr(tens.transpose(-1, -2))[0]
        return ManifoldTensor(q.transpose(-1, -2), manifold=self)

    random = random_naive

    def origin(self, *size, dtype=None, device=None, seed=42) -> torch.Tensor:
        """
        Identity matrix point origin.

        Parameters
        ----------
        size : shape
            the desired shape
        device : torch.device
            the desired device
        dtype : torch.dtype
            the desired dtype
        seed : int
            ignored

        Returns
        -------
        ManifoldTensor
            zeros with ones on the leading diagonal (``[I, 0]`` per matrix)
        """
        self._assert_check_shape(size2shape(*size), "x")
        eye = torch.zeros(*size, dtype=dtype, device=device)
        # shape[-2] <= shape[-1] is guaranteed by the shape check above.
        idx = torch.arange(eye.shape[-2], device=eye.device)
        eye[..., idx, idx] += 1
        return ManifoldTensor(eye, manifold=self)


class EuclideanStiefelT(StiefelT):
    __doc__ = r"""Stiefel Manifold with Euclidean inner product

    {}
    """.format(
        _stiefel_doc
    )

    name = "StiefelT(euclidean)"
    reversible = False

    def proju(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Project ``u`` onto the tangent space at ``x``.

        The tangent space at ``x`` is ``{U : U X^T + X U^T = 0}``. The
        projection removes the symmetric part of ``X U^T``:
        ``U - sym(X U^T) X``.
        """
        return u - linalg.sym(x @ u.transpose(-1, -2)) @ x

    # With the Euclidean metric the Riemannian gradient is the tangent
    # projection of the Euclidean gradient.
    egrad2rgrad = proju

    def transp(self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Vector transport of ``v`` to ``y`` by projection onto its tangent space."""
        return self.proju(y, v)

    def inner(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor | None = None, *, keepdim=False
    ) -> torch.Tensor:
        """Euclidean (Frobenius) inner product ``<u, v>``; ``v`` defaults to ``u``."""
        if v is None:
            v = u
        return (u * v).sum([-1, -2], keepdim=keepdim)

    def retr(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """QR-based retraction: orthonormalise the rows of ``x + u``.

        Each resulting row is sign-flipped so the matching diagonal entry of the
        triangular QR factor becomes non-negative, removing QR's sign ambiguity.
        """
        q, r = linalg.qr((x + u).transpose(-1, -2))
        # sign(d) + 0.5, then sign again, maps {-1, 0, 1} -> {-1, 1, 1}, so a
        # zero diagonal entry leaves its column unflipped.
        unflip = linalg.extract_diag(r).sign().add(0.5).sign()
        # Out of place: ``q`` is a QR output that autograd may save for the
        # backward pass, so modifying it in place would break gradients here.
        q = q * unflip[..., None, :]
        return q.transpose(-1, -2)

    def expmap(self, x: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
        """Exponential map (geodesic step) from ``x`` along tangent vector ``u``.

        Uses the closed form of Edelman, Arias & Smith (1998), stated for the
        column-orthonormal convention ``Y^T Y = I``::

            Y = [Y0, H] expm([[A, -S], [I, A]]) [expm(-A); 0],
            A = Y0^T H,  S = H^T H,

        applied to ``Y0 = x^T`` and ``H = u^T``. The result is transposed back to
        the row-orthonormal convention used by this manifold.
        """
        # With Y0 = x^T and H = u^T: A = Y0^T H = x u^T and S = H^T H = u u^T.
        xut = x @ u.transpose(-1, -2)
        uut = u @ u.transpose(-1, -2)
        eye = torch.zeros_like(uut)
        idx = torch.arange(uut.shape[-2], device=uut.device)
        eye[..., idx, idx] += 1
        logw = linalg.block_matrix(((xut, -uut), (eye, xut)))
        w = linalg.expm(logw)
        # [expm(-A); 0] stacked along rows: shape (..., 2m, m).
        z = torch.cat((linalg.expm(-xut), torch.zeros_like(uut)), dim=-2)
        # [Y0, H] = [x^T, u^T] placed side by side: shape (..., n, 2m).
        xu_cols = torch.cat((x.transpose(-1, -2), u.transpose(-1, -2)), dim=-1)
        y = xu_cols @ w @ z
        return y.transpose(-1, -2)


