from typing import Tuple
import torch
from geoopt.manifolds.symmetric_positive_definite import SymmetricPositiveDefinite
from geoopt.linalg import batch_linalg as lalg
from manifolds.metrics import Metric, MetricType


class SPDManifold(SymmetricPositiveDefinite):
    """Manifold of symmetric positive definite (SPD) matrices.

    Extends geoopt's ``SymmetricPositiveDefinite`` with identity-based exp/log
    maps, an identity-based "addition", a Log-Cholesky squared distance and a
    near-identity random initialiser.

    :param dims: side length ``n`` of the ``n x n`` matrices.
    :param ndim: stored as ``self.ndim``; not read by the methods defined here
        (the geoopt base class may use it).
    :param metric: ``MetricType`` member; its ``.value`` and ``dims`` are passed
        to ``Metric.get`` and the result is stored as ``self.metric``.
    """

    def __init__(self, dims=2, ndim=2, metric=MetricType.RIEMANNIAN):
        super().__init__()
        self.name = 'SPD'
        self.dims = dims
        self.ndim = ndim
        self.metric = Metric.get(metric.value, self.dims)

    def sqdist(self, a: torch.Tensor, b: torch.Tensor, c=None) -> torch.Tensor:
        r"""Squared Log-Cholesky distance between SPD matrices ``a`` and ``b``.

        With lower Cholesky factors ``a = L L^T`` and ``b = K K^T`` this returns
        :math:`\|\lfloor L \rfloor - \lfloor K \rfloor\|_F^2 + \|\log D(L) - \log D(K)\|_2^2`,
        where :math:`\lfloor\cdot\rfloor` is the strictly lower triangle and
        :math:`D(\cdot)` is the diagonal.

        NOTE(review): the distance is always Log-Cholesky; ``self.metric`` is
        not consulted here.

        :param a: ``(..., n, n)`` tensor of SPD matrices.
        :param b: ``(..., n, n)`` tensor of SPD matrices, broadcastable with ``a``.
        :param c: unused; kept for interface compatibility.
        :return: tensor of squared distances with the last two dims reduced.
        """
        l_a = self._safe_cholesky(a)
        l_b = self._safe_cholesky(b)

        # Euclidean part over the strictly-lower-triangular entries.
        lower_diff = torch.tril(l_a, diagonal=-1) - torch.tril(l_b, diagonal=-1)
        off_diag = torch.sum(lower_diff ** 2, dim=(-2, -1))

        # Log part over the diagonal; Cholesky factors have a positive diagonal,
        # so the log is well defined.
        log_diag_diff = (torch.log(torch.diagonal(l_a, dim1=-2, dim2=-1))
                         - torch.log(torch.diagonal(l_b, dim1=-2, dim2=-1)))
        return off_diag + torch.sum(log_diag_diff ** 2, dim=-1)

    def _safe_cholesky(self, x: torch.Tensor) -> torch.Tensor:
        """Return the lower Cholesky factor of ``x``, retrying once after clamping.

        ``torch.linalg.cholesky`` raises (``torch.linalg.LinAlgError``, a
        ``RuntimeError`` subclass) when a matrix is not numerically positive
        definite. In that case the whole input is passed through
        :meth:`clamping` and factorised again; a second failure propagates.
        """
        try:
            return torch.linalg.cholesky(x, upper=False)
        except RuntimeError:
            return torch.linalg.cholesky(self.clamping(x), upper=False)

    @staticmethod
    def clamping(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Make matrices positive definite by flooring their eigenvalues at ``eps``.

        :param x: ``(..., n, n)`` tensor; ``torch.linalg.eigh`` reads only its
            lower triangle, so ``x`` is treated as symmetric.
        :param eps: smallest eigenvalue kept (default ``1e-6``).
        :return: ``V diag(max(lambda, eps)) V^T`` with the same shape as ``x``.
        """
        eigvals, eigvecs = torch.linalg.eigh(x)
        eigvals = torch.clamp(eigvals, min=eps)
        return eigvecs @ torch.diag_embed(eigvals) @ eigvecs.transpose(-2, -1)

    @staticmethod
    def expmap_id(x: torch.Tensor) -> torch.Tensor:
        r"""Exponential map with the identity as basepoint, :math:`\operatorname{Exp}_{Id}(x)`.

        Computes the matrix exponential of the symmetric matrix ``x`` using
        geoopt's ``sym_funcm``.

        :param x: b x n x n torch.Tensor, symmetric matrix in the tangent space at the identity
        :return: b x n x n torch.Tensor point on the SPD manifold
        """
        return lalg.sym_funcm(x, torch.exp)

    @staticmethod
    def logmap_id(y: torch.Tensor) -> torch.Tensor:
        r"""Logarithmic map with the identity as basepoint, :math:`\operatorname{Log}_{Id}(y)`.

        Computes the matrix logarithm of the SPD matrix ``y`` using geoopt's
        ``sym_funcm``.

        :param y: b x n x n torch.Tensor point on the SPD manifold
        :return: b x n x n torch.Tensor symmetric matrix in the tangent space at the identity
        """
        return lalg.sym_funcm(y, torch.log)

    @staticmethod
    def addition_id(a: torch.Tensor, b: torch.Tensor):
        r"""Perform addition using the identity as basepoint.

        The addition on SPD using the identity as basepoint is
        :math:`A \oplus_{Id} B = \sqrt{A} B \sqrt{A}`.

        :param a: b x n x n torch.Tensor points in the SPD manifold
        :param b: b x n x n torch.Tensor points in the SPD manifold
        :return: b x n x n torch.Tensor points in the SPD manifold
        """
        sqrt_a = lalg.sym_sqrtm(a)
        return sqrt_a @ b @ sqrt_a

    @staticmethod
    def addition_id_from_sqrt(sqrt_a: torch.Tensor, b: torch.Tensor):
        r"""Perform addition using the identity as basepoint, given a precomputed square root.

        Assumes ``sqrt_a = sqrt(A)``, so no square root is taken here.

        The addition on SPD using the identity as basepoint is
        :math:`A \oplus_{Id} B = \sqrt{A} B \sqrt{A}`.

        :param sqrt_a: b x n x n torch.Tensor, square root of points in the SPD manifold
        :param b: b x n x n torch.Tensor points in the SPD manifold
        :return: b x n x n torch.Tensor points in the SPD manifold
        """
        return sqrt_a @ b @ sqrt_a

    def random(self, *size, dtype=None, device=None, **kwargs) -> torch.Tensor:
        """Sample a batch of matrices near the identity.

        Returns ``I + sym(E)``, where the entries of ``E`` are Gaussian with
        standard deviation ``(to - from_) / 2``. The result is therefore not
        uniform on ``[from_, to]``, and it is positive definite only with high
        probability when that scale is small.

        :param size: only ``size[0]`` (the batch size) is used; the matrix side
            length comes from ``self.dims``.
        :param dtype: dtype of the returned tensor (the identity now follows it too).
        :param device: device of the returned tensor.
        :param kwargs: optional ``from_`` (default ``-0.001``) and ``to``
            (default ``0.001``); other keys are ignored.
        :return: ``(size[0], dims, dims)`` tensor.
        """
        from_ = kwargs.get("from_", -0.001)
        to = kwargs.get("to", 0.001)
        init_eps = (to - from_) / 2
        dims = self.dims
        perturbation = torch.randn((size[0], dims, dims), dtype=dtype, device=device) * init_eps
        perturbation = lalg.sym(perturbation)
        # Match the perturbation's dtype/device; broadcasting over the batch
        # avoids materialising a repeated identity.
        identity = torch.eye(dims, dtype=perturbation.dtype, device=perturbation.device)
        return identity + perturbation

    def extra_repr(self) -> str:
        return f"metric={type(self.metric).__name__}"
