import torch
import numpy as np
import logging
from scipy.special import loggamma


def latlon_to_xyz(data):
    """
    Map (latitude, longitude) pairs given in degrees to Cartesian points on the unit sphere S^2.

    :data: numpy array or torch tensor indexed as data[:, 0] / data[:, 1];
        data[:, 0] is latitude in [-90, 90], data[:, 1] is longitude in [-180, 180].
    :return: array/tensor (same library as the input) with one row (x, y, z) per point,
        i.e. shape (N, 3).

    Internally the polar angle theta = (90 - latitude) in radians lies in [0, pi]
    and the azimuth phi = longitude in radians lies in [-pi, pi].
    """
    polar = (90 - data[:, 0]) * np.pi / 180
    azimuth = (data[:, 1]) * np.pi / 180

    # Dispatch on the converted angle so numpy and torch inputs both work.
    use_torch = isinstance(polar, torch.Tensor)
    backend = torch if use_torch else np
    join = torch.cat if use_torch else np.concatenate

    sin_polar = backend.sin(polar)
    components = (
        sin_polar * backend.cos(azimuth),
        sin_polar * backend.sin(azimuth),
        backend.cos(polar),
    )
    return join([comp.reshape(-1, 1) for comp in components], 1)


def xyz_to_latlon(points):
    """
    Convert Cartesian points on the unit sphere S^2 to (latitude, longitude) in degrees.

    :points: numpy array or torch tensor whose columns 0, 1, 2 are x, y, z.
    :return: tuple (lat, lon) of 1-D arrays/tensors (same library as the input);
        lat in [-90, 90], lon in [-180, 180].

    z is clamped to [-1, 1] before arccos, so points that drift slightly off the
    sphere through floating-point error (e.g. z = 1 + 1e-7) map to +/-90 degrees
    latitude instead of NaN. NaN values in z are reported with a warning.
    """
    x = points[:, 0]
    y = points[:, 1]
    z = points[:, 2]

    if isinstance(points, torch.Tensor):
        acos, atan2 = torch.acos, torch.atan2
        has_nan = bool(torch.isnan(z).any())
        z = torch.clamp(z, -1, 1)
    else:
        acos, atan2 = np.arccos, np.arctan2
        has_nan = bool(np.isnan(z).any())
        z = np.clip(z, -1, 1)

    # Checked on the unclamped z so genuine NaN inputs are still reported.
    if has_nan:
        logging.warning("NaN values detected in z coordinates")

    theta = acos(z)
    phi = atan2(y, x)
    lat, lon = 90 - theta * 180 / np.pi, phi * 180 / np.pi
    return lat, lon


class Manifold_Sphere: # dim = 2, out_dim = 3
    def __init__(self, dim):
        """
        Unit sphere S^dim embedded in R^(dim + 1).

        Args:
            dim: intrinsic dimension n of the sphere.
        """
        self.inner_dim = dim                # intrinsic (manifold) dimension n
        self.out_dim = self.inner_dim + 1   # ambient dimension n + 1

    def constrain_fn(self, samples):
        return samples.norm(dim=1, keepdim=True) - 1

    def constrain_grad_fn(self, samples):
        return samples/samples.norm(dim=1, keepdim=True)
    
    def constrain_mean_curvature_fn(self, samples):
        # Computing the mean curvature term (\mathcal{H}(x)) of the sphere 
        return -self.inner_dim * samples / (samples.norm(dim=1, keepdim=True)**2)
    
    def constrain_hessian_momentum_contraction(self, samples, momentums):
        # Computes p^T Hess_h(x) p
        # For the sphere, h(x) = ||x|| - 1
        # The Hessian of h(x) is Hess_h(x) = (I - (x*x^T)/||x||^2) / ||x||
        # The contraction p^T Hess_h(x) p = (||p||^2 - (p^T x)^2 / ||x||^2) / ||x||
        
        norm_x = samples.norm(dim=1, keepdim=True)
        p_dot_x = torch.sum(momentums * samples, dim=1, keepdim=True)
        norm_p_sq = torch.sum(momentums**2, dim=1, keepdim=True)
        
        hessian_contraction = (norm_p_sq - p_dot_x**2 / norm_x**2) / norm_x
        return hessian_contraction

    def project_onto_tangent_space(self, y, base_point, **kwargs):
        coeff = torch.sum(y * base_point, dim=1, keepdim=True) / (base_point**2).sum(dim=1, keepdim=True)
        return y - coeff * base_point

    def project_onto_manifold(self, y):
        return y / y.norm(dim=1, keepdim=True)
    
    def adding_correction_decaying(self, y, base_point, delta_t, alpha, sigma_sq):
        # y.shape = [bsz, dim]
        # base_point.shape = [bsz, dim]
        # delta_t: shape [bsz, 1]
        # alpha: float

        # Compute current violation of h:
        h_val = self.constrain_fn(base_point) # shape [bsz, 1]
        h_grad = self.constrain_grad_fn(base_point) # shape [bsz, dim]
        # Skip computing G^{-1} since G = 1 for sphere.

        # Compute mean_curvature H
        mean_curvature = self.constrain_mean_curvature_fn(base_point) # shape [bsz, dim]
        mean_curvature = torch.zeros_like(mean_curvature) 

        # Compute the decaying term
        decaying_term = - alpha * h_grad * h_val.reshape(-1, 1) # shape [bsz, dim]

        return base_point + y + (decaying_term + mean_curvature) * sigma_sq * torch.abs(delta_t) # take absolute value of delta_t (for reverse process) 

    def adding_correction_decaying_momentum(self, y, base_point, base_momentum, delta_t, alpha, sigma_sq, mass = 1.0):
        # y : tangent vector shape [bsz, dim]

        h_val = self.constrain_fn(base_point)  # shape [bsz, 1]
        h_grad = self.constrain_grad_fn(base_point)  # shape [bsz, dim]
        # Skip computing G^{-1} since G = 1 for sphere.

        term_1 = 1 / (mass ** 2) * self.constrain_hessian_momentum_contraction(base_point, base_momentum)
        term_2 = 2 * alpha / mass * torch.sum(h_grad * base_momentum, dim=-1, keepdim=True)
        term_3 = (alpha ** 2) * h_val
        decaying_term = - mass * torch.sum(h_grad * (term_1 + term_2 + term_3), dim=-1, keepdim=True)

        return base_point + y + decaying_term * (sigma_sq * delta_t) ** 2  # (Remove delta_t, for reverse process)

    def project_onto_manifold_with_base(self, y, base_point):
        """
        Proj(x+v/(1-|v|^2)^(1/2))
        """
        if (y.norm(dim=1) > 1).any():
            bad_idx = torch.where(y.norm(dim=1) > 1)[0]
            logging.info(f'Warning: index {bad_idx.detach().cpu()} of v can not be projected! The max norm of v: {y.norm(dim=1).max():.4f}.')
            converged_flag =(y.norm(dim=1) < 1)
            y[bad_idx, :] = y[bad_idx, :] * 0.99 / y[bad_idx, :].norm(dim=1).max()
        else:
            converged_flag = torch.ones(y.shape[0], dtype=torch.bool)

        temp = base_point + y/torch.sqrt(1-(y**2).sum(dim=1, keepdim=True))
        return temp / temp.norm(dim=1, keepdim=True), converged_flag.to(y)

    def uniform_sample(self, sample_num): # sample_num is the number of points to sample
        point = torch.randn((sample_num, self.out_dim)) # generate random points in R^(dim+1)
        return point / (point.norm(dim=1, keepdim=True) + 1e-6) # normalize the points to lie on the sphere

    def exp(self, y, base_point):
        norm = y.norm(dim=1, keepdim=True)
        return torch.cos(norm) * base_point + torch.sin(norm) * y / norm

    def log_volume(self, sample = None):
        # sample: [bsz, dim+1] (optional) for OLLA
        """log area of n-sphere https://en.wikipedia.org/wiki/N-sphere#Closed_forms"""
        half_dim = (self.inner_dim + 1) / 2
        return np.log(2) + half_dim * np.log(np.pi) - loggamma(half_dim)


    # ------------------------------------------------------------------
    # 1.  orthonormal basis U_x  ∈  R^{(n+1)×n}
    # ------------------------------------------------------------------
    def tangent_basis(self, base_point: torch.Tensor) -> torch.Tensor:
        """
        Args
        ----
        base_point : (B, out_dim) tensor of unit-norm points on S^n

        Returns
        -------
        U_x : (B, out_dim, inner_dim) tensor whose columns form an ONB of T_x S^n
        """
        # ensure unit length (just in case)
        base_point = base_point / base_point.norm(dim=1, keepdim=True)

        B, d = base_point.shape                     # d = n+1
        I = torch.eye(d, device=base_point.device)
        # Build A = [x | I] in a batched way
        A = I.unsqueeze(0).repeat(B, 1, 1).clone()
        A[:, :, 0] = base_point                    # replace first column by x
        Q, _ = torch.linalg.qr(A)                  # batched QR
        return Q[:, :, 1:]                         # drop the x-column

    # ------------------------------------------------------------------
    # 2.  ξ  (local coords)  →  ambient tangent vector
    # ------------------------------------------------------------------
    def local_to_ambient(self,
                         xi: torch.Tensor,         # (B, n)
                         base_point: torch.Tensor  # (B, n+1)
                         ) -> torch.Tensor:        # (B, n+1)
        """
        v = U_x ξ   with U_x from tangent_basis
        """
        U = self.tangent_basis(base_point)                 # (B, n+1, n)
        return torch.bmm(U, xi.unsqueeze(-1)).squeeze(-1)  # (B, n+1)
