import torch
import numpy as np
import logging
from scipy.special import loggamma

# Added as DAMPING_CONST * I to the constraint Gram matrix before it is inverted in
# the adding_correction_* methods; the current value 0 leaves that matrix unregularized.
DAMPING_CONST = 0  # Damping constant for numerical stability in matrix inversion

class Manifold_SOn:
    def __init__(self, dim):
        """Store the matrix size and precompute constraint bookkeeping.

        Args:
            dim: side length n of the n x n matrices handled by this object.

        Attributes set here:
            mat_dim: n.
            out_dim: n * n, the length of a flattened matrix.
            inner_dim: n * (n - 1) / 2 as an int.
            triu_indices: (2, m) row/column indices of the upper triangle
                (diagonal included), one column per scalar constraint.
            constraint_num: m, the number of such index pairs.
            grad_coeff_matrix / grad_coeff_matrix_normalized: coefficient
                stacks produced by ``get_grad_coeff_matrix``.
        """
        n = dim
        self.mat_dim = n
        self.out_dim = n * n
        self.inner_dim = n * (n - 1) // 2

        upper = torch.triu_indices(n, n)
        self.triu_indices = upper
        self.constraint_num = upper.shape[1]

        # Both variants are cached because constrain_grad_fn picks one per call.
        self.grad_coeff_matrix_normalized = self.get_grad_coeff_matrix(normalized=True)
        self.grad_coeff_matrix = self.get_grad_coeff_matrix(normalized=False)

    def get_grad_coeff_matrix(self, normalized):
        """Build the (m, n, n) stack of constraint-gradient coefficient matrices.

        Slice ``k`` belongs to the k-th upper-triangular index pair (i, j) in
        ``self.triu_indices`` and is non-zero only at (i, j) and (j, i).

        Args:
            normalized: if true, diagonal pairs get a single 1 and off-diagonal
                pairs get sqrt(0.5) at both positions; otherwise 1 is added at
                (i, j) and again at (j, i), so diagonal pairs end up with 2.

        Returns:
            Tensor of shape (constraint_num, mat_dim, mat_dim).
        """
        m = self.constraint_num
        rows, cols = self.triu_indices[0], self.triu_indices[1]
        slot = torch.arange(m)
        coeff = torch.zeros(m, self.mat_dim, self.mat_dim)

        if normalized:
            off_diag = float(np.sqrt(0.5))
            weights = torch.where(rows == cols, torch.ones(m), torch.full((m,), off_diag))
            # For diagonal pairs both assignments hit the same cell with the same value.
            coeff[slot, rows, cols] = weights
            coeff[slot, cols, rows] = weights
        else:
            # Two separate accumulations: a diagonal pair is touched twice and becomes 2.
            coeff[slot, rows, cols] += 1.0
            coeff[slot, cols, rows] += 1.0
        return coeff

    def constrain_fn(self, samples):
        """Evaluate the orthogonality residuals of a batch of matrices.

        ``samples`` is reshaped to (B, n, n); for each matrix X the
        upper-triangular entries (diagonal included) of ``X X^T - I`` are
        returned in the order given by ``self.triu_indices``.

        Returns:
            Tensor of shape (B, m) with m = number of upper-triangular entries.
        """
        n = self.mat_dim
        mats = samples.reshape(-1, n, n)
        identity = torch.eye(n, n).to(mats)
        residual = torch.bmm(mats, mats.transpose(1, 2)) - identity
        rows = self.triu_indices[0, :]
        cols = self.triu_indices[1, :]
        return residual[:, rows, cols]

    def constrain_grad_fn(self, samples, normalized=False):
        """Return the gradient of every scalar constraint for each sample.

        Each (n, n) coefficient matrix is left-multiplied onto each sample
        matrix and the result is flattened over the last two axes.

        Args:
            samples: batch reshaped here to (B, n, n).
            normalized: choose the normalized coefficient stack instead of the
                plain one.

        Returns:
            Tensor of shape (B, m, n * n).
        """
        coeff = self.grad_coeff_matrix_normalized if normalized else self.grad_coeff_matrix
        # Extra axis lets matmul broadcast the (m, n, n) stack against (B, 1, n, n).
        mats = samples.reshape(-1, self.mat_dim, self.mat_dim).unsqueeze(1)
        products = torch.matmul(coeff.to(mats), mats)
        return products.flatten(start_dim=-2)

    def project_onto_tangent_space(self, y, base_point):
        """
        P_X(U) = (U-XU^TX)/2
        """
        y = y.reshape(-1, self.mat_dim, self.mat_dim)
        base_point = base_point.reshape(-1, self.mat_dim, self.mat_dim)
        out = (y - torch.bmm(base_point, torch.bmm(torch.transpose(y, dim0=1, dim1=2), base_point))) * 0.5
        return out.reshape(-1, self.out_dim)

    def constrain_mean_curvature_fn(self, samples):
        """Return the mean-curvature term H(X) = -(n - 1) / 2 * X per sample.

        Args:
            samples: batch of matrices, flattened or not; reshaped here to
                (B, n, n).

        Returns:
            Tensor of shape (B, n * n).

        Uses ``reshape`` rather than ``view`` so that non-contiguous inputs
        (for example a transposed batch) are accepted, matching the other
        methods of this class.
        """
        mats = samples.reshape(-1, self.mat_dim, self.mat_dim)
        H = -(self.mat_dim - 1) * mats / 2
        return H.reshape(-1, self.out_dim)

    def constrain_hessian_momentum_contraction(self, samples, momentums):
        """Contract each constraint Hessian with the momentum on both sides.

        The constraints are the upper-triangular entries of (X X^T - I), so
        their Hessians are independent of X. With P the momentum reshaped to
        (B, n, n):
          - i == j:  p^T Hess(h_ii) p = 2 * ||P[i, :]||^2
          - i <  j:  p^T Hess(h_ij) p = 2 * <P[i, :], P[j, :]>
        Both cases are the (i, j) entry of 2 * P P^T.

        Args:
            samples: not read by this method; kept for interface parity.
            momentums: tensor whose leading axis is the batch; reshaped to
                (B, n, n).

        Returns:
            Tensor of shape (B, m), m = number of upper-triangular entries.
        """
        n = self.mat_dim
        batch = momentums.shape[0]
        mom = momentums.reshape(batch, n, n)
        # gram[b, i, j] is the inner product of rows i and j of mom[b].
        gram = torch.bmm(mom, mom.transpose(1, 2))
        rows, cols = self.triu_indices
        return 2.0 * gram[:, rows, cols]


    def adding_correction_decaying(self, y, base_point, delta_t, alpha, sigma_sq):
        """Take a step from ``base_point`` with constraint-decay and curvature drift.

        Returns
            base_point + y + (decay + H) * sigma_sq * |delta_t| / 2
        where ``decay = -alpha * grad_h^T (grad_h grad_h^T + DAMPING_CONST I)^{-1} h``
        uses the non-normalized constraints and ``H`` is the value of
        ``constrain_mean_curvature_fn``.

        Args:
            y: increment added to ``base_point``; expected (B, dim).
            base_point: current points; expected (B, dim) with dim = n * n.
            delta_t: time step, expected (B, 1); only its absolute value is
                used (the original note says this is for the reverse process).
            alpha: scalar weight of the decaying term.
            sigma_sq: scalar multiplying the drift.

        Returns:
            Tensor with the broadcast shape of the operands, (B, dim) for the
            expected inputs.
        """
        # Current constraint violation and its gradient.
        h_val = self.constrain_fn(base_point)        # (B, m)
        h_grad = self.constrain_grad_fn(base_point)  # (B, m, dim)

        # Gram matrix in constraint space; torch.linalg.inv is already batched.
        damping = DAMPING_CONST * torch.eye(h_grad.shape[1]).to(h_grad.device)
        gram = torch.bmm(h_grad, h_grad.transpose(1, 2)) + damping  # (B, m, m)
        G_inv = torch.linalg.inv(gram)                                # (B, m, m)

        mean_curvature = self.constrain_mean_curvature_fn(base_point)  # (B, dim)

        # Drop only the trailing singleton axis: a bare squeeze() would also
        # remove the batch axis when B == 1 (or the dim axis when dim == 1).
        pulled_back = torch.bmm(torch.bmm(h_grad.transpose(1, 2), G_inv), h_val.unsqueeze(-1))
        decaying_term = -alpha * pulled_back.squeeze(-1)               # (B, dim)

        drift = decaying_term + mean_curvature
        return base_point + y + drift * sigma_sq * torch.abs(delta_t) / 2

    def adding_correction_decaying_momentum(self, y, base_point, base_momentum, delta_t, alpha, sigma_sq, mass):
        """Add the constraint-decay correction to a momentum update.

        Per the original notes this generalizes a single-constraint sphere
        routine (not shown in this file) to several constraints, using the
        metric induced by the constraint gradients.

        Args:
            y: (B, dim) update vector in ambient coordinates.
            base_point: (B, dim) current point (flattened n x n).
            base_momentum: (B, dim) ambient momentum (flattened n x n).
            delta_t: (B, 1) time step; used with its sign (no absolute value).
            alpha, sigma_sq, mass: scalars.

        Returns:
            ``y + decaying_vec * sigma_sq * delta_t / 8`` where
            ``decaying_vec = -mass * grad_h^T G^{-1} tau``.
        """
        constraint_val = self.constrain_fn(base_point)        # (B, m)
        constraint_grad = self.constrain_grad_fn(base_point)  # (B, m, dim)
        grad_t = constraint_grad.transpose(1, 2)              # (B, dim, m)

        # Damped Gram matrix in constraint space and its inverse.
        num_constraints = constraint_grad.shape[1]
        eye_m = torch.eye(num_constraints, device=constraint_grad.device, dtype=constraint_grad.dtype)
        gram = torch.bmm(constraint_grad, grad_t) + DAMPING_CONST * eye_m   # (B, m, m)
        gram_inv = torch.vmap(torch.linalg.inv)(gram)                       # (B, m, m)

        # <grad h_i, p> for every constraint i.
        grad_dot_p = torch.bmm(constraint_grad, base_momentum.unsqueeze(-1)).squeeze(-1)  # (B, m)

        # p^T Hess(h_i) p for every constraint i.
        hess_pp = self.constrain_hessian_momentum_contraction(base_point, base_momentum)  # (B, m)

        # Per-constraint combination, summed in the same order as before.
        kinetic_part = hess_pp / (mass ** 2)
        cross_part = (2.0 * alpha / mass) * grad_dot_p
        decay_part = (alpha ** 2) * constraint_val
        tau = kinetic_part + cross_part + decay_part                        # (B, m)

        # Back to ambient space through grad_h^T G^{-1}.
        multipliers = torch.bmm(gram_inv, tau.unsqueeze(-1)).squeeze(-1)    # (B, m)
        decaying_vec = - mass * torch.bmm(grad_t, multipliers.unsqueeze(-1)).squeeze(-1)  # (B, dim)

        # delta_t keeps its sign here, unlike the position-space correction.
        return y + decaying_vec * sigma_sq * delta_t / 8.0

    @torch.no_grad()
    def force_project_SO(self, y):
        """Map each matrix to a rotation via its SVD (Procrustes-style projection).

        Args:
            y: batch reshaped here to (B, n, n); (B, n * n) or (B, n, n).

        Returns:
            (B, n * n) tensor. Each matrix is U diag(1, ..., 1, s) V^T with
            s = +1 when det(U V^T) >= 0 and -1 otherwise, so the result has
            determinant +1.
        """
        n = self.mat_dim
        mats = y.reshape(-1, n, n)                                           # (B, n, n)
        left, _, right_t = torch.linalg.svd(mats, full_matrices=False)

        # Sign of the determinant of the orthogonal polar factor U V^T.
        polar_det = torch.det(left @ right_t)                                # (B,)
        plus_one = torch.ones_like(polar_det)
        last_sign = torch.where(polar_det >= 0, plus_one, -plus_one)

        # Flip only the last singular direction where a reflection was found.
        signs = torch.ones(mats.shape[0], n, device=left.device, dtype=left.dtype)
        signs[:, -1] = last_sign
        rotation = left @ torch.diag_embed(signs) @ right_t
        return rotation.reshape(-1, self.out_dim)


    # @torch.no_grad()
    # def project_onto_manifold_with_base(self, y, base_point, threshold=1e-6, n_iters=10, **kwargs):

    #     keep_quiet = kwargs["keep_quiet"] if "keep_quiet" in kwargs else True
        
    #     grad_vec = self.constrain_grad_fn(base_point, normalized=False)
    #     mu = torch.zeros((y.shape[0], self.out_dim-self.inner_dim)).to(y)
    #     active_idx = torch.arange(0, y.shape[0], dtype=torch.int64).to(y.device)

    #     for i in range(n_iters):
    #         temp = y[active_idx,:] + base_point[active_idx,:] - torch.einsum('ijk,ij->ik', grad_vec[active_idx,:], mu[active_idx,:])
    #         value = self.constrain_fn(temp)
    #         bad_idx = (value.norm(dim=1, keepdim=True) >= threshold).squeeze(dim=1)
    #         if bad_idx.sum() == 0:
    #             break
    #         active_idx = active_idx[bad_idx]

    #         mu_grad = - torch.bmm(self.constrain_grad_fn(temp[bad_idx,:], normalized=False),
    #                               grad_vec[bad_idx,:].transpose(1, 2))
    #         mu[active_idx,:] = mu[active_idx,:] - torch.linalg.solve(mu_grad, value[bad_idx, :])
        
    #     projected_pt = y + base_point - torch.einsum('ijk,ij->ik', grad_vec, mu)
    #     value = self.constrain_fn(projected_pt).abs().squeeze()

    #     non_converged_flag = torch.any((value > threshold) | (~torch.isfinite(value)), dim=1)
    #     non_converged_num = non_converged_flag.sum()

    #     projected_pt[non_converged_flag, :] = base_point[non_converged_flag, :]

    #     if not keep_quiet:
    #         logging.info(f'total steps: {i}, max_error: {value.max():.3e}, {non_converged_num} states not converged!')
    #     return projected_pt.detach(), torch.logical_not(non_converged_flag).to(y)


    @torch.no_grad()
    def project_onto_manifold_with_base(self, y, base_point, threshold=1e-6, n_iters=10, **kwargs):
        """Project ``base_point + y`` back onto the constraint set by Newton iteration.

        Looks for multipliers ``mu`` such that
        ``base_point + y - grad_h(base_point)^T mu`` satisfies the constraints,
        updating only the samples whose residual norm is still >= ``threshold``.

        Args:
            y: (B, dim) displacement from ``base_point``.
            base_point: (B, dim) starting points; their constraint gradients
                define the correction directions.
            threshold: convergence tolerance on the constraint residual.
            n_iters: maximum number of Newton iterations (0 is allowed).
            **kwargs: ``keep_quiet`` (default True); when False a summary line
                is logged through ``logging.info``.

        Returns:
            (projected_pt, converged): ``projected_pt`` is (B, dim), with
            non-converged or non-finite samples replaced by their
            ``base_point``; ``converged`` is a (B,) mask in the dtype/device of
            ``y`` holding 1 for converged samples and 0 otherwise.
        """
        keep_quiet = kwargs.get("keep_quiet", True)

        # Gradients at the base point are fixed for the whole iteration.
        grad_vec_full = self.constrain_grad_fn(base_point, normalized=False)

        # One multiplier per constraint and per sample.
        mu = torch.zeros((y.shape[0], self.constraint_num)).to(y)

        # Global indices of the samples that still need Newton updates.
        active_idx = torch.arange(y.shape[0], dtype=torch.int64, device=y.device)

        # Explicit counter so the log line is well defined even when n_iters == 0.
        steps = 0
        for i in range(n_iters):
            steps = i + 1
            if active_idx.shape[0] == 0:
                break

            active_grad_vec = grad_vec_full[active_idx, :]
            active_mu = mu[active_idx, :]

            # Candidate point and its constraint residual for the active subset.
            temp = y[active_idx, :] + base_point[active_idx, :] \
                - torch.einsum('ijk,ij->ik', active_grad_vec, active_mu)
            value = self.constrain_fn(temp)

            bad_mask = value.norm(dim=1) >= threshold
            if not torch.any(bad_mask):
                break

            # Positions (within the active subset) of samples that have not converged.
            bad_local = torch.where(bad_mask)[0]
            grad_vec_bad = active_grad_vec[bad_local, :]

            # Jacobian of the residual with respect to mu at the candidate point.
            mu_grad = -torch.bmm(self.constrain_grad_fn(temp[bad_local, :], normalized=False),
                                 grad_vec_bad.transpose(1, 2))
            delta_mu = torch.linalg.solve(mu_grad, value[bad_local, :])

            # Newton update written back at the samples' global positions.
            bad_global = active_idx[bad_local]
            mu[bad_global, :] = active_mu[bad_local, :] - delta_mu

            # Only the non-converged samples stay active.
            active_idx = bad_global

        # Final point from the final multipliers, re-checked for every sample.
        projected_pt = y + base_point - torch.einsum('ijk,ij->ik', grad_vec_full, mu)
        value = self.constrain_fn(projected_pt).abs()

        non_converged_flag = torch.any((value > threshold) | (~torch.isfinite(value)), dim=1)
        non_converged_num = non_converged_flag.sum()

        # Fallback: samples that failed to converge stay at their base point.
        projected_pt[non_converged_flag, :] = base_point[non_converged_flag, :]

        if not keep_quiet:
            # max() of an empty tensor raises, so report 0 for an empty batch.
            max_error = value.max().item() if value.numel() > 0 else 0.0
            logging.info(f'total steps: {steps}, max_error: {max_error:.3e}, {non_converged_num} states not converged!')

        return projected_pt.detach(), torch.logical_not(non_converged_flag).to(y)

    def uniform_sample(self, sample_num):
        """Draw ``sample_num`` random matrices with positive determinant via QR.

        Each round draws ``sample_num`` Gaussian matrices, drops those with
        |det| <= 1e-4, takes the Q factor of a QR decomposition with its
        columns re-signed by the signs of R's diagonal, and keeps only the Q
        with positive determinant (the "correct component"). Rounds repeat
        until enough matrices are collected.

        Returns:
            Tensor of shape (sample_num, n * n).
        """
        n = self.mat_dim
        accepted = []
        collected = 0
        while collected < sample_num:
            gaussian = torch.randn(sample_num, n, n)
            usable = torch.where(torch.linalg.det(gaussian).abs() > 1e-4)[0]
            Q, R = torch.linalg.qr(gaussian[usable], mode="complete")
            # Re-sign Q's columns using the signs of R's diagonal.
            col_signs = torch.diag_embed(R.diagonal(dim1=-2, dim2=-1).sign())
            Q = torch.bmm(Q, col_signs)
            keep = torch.where(torch.linalg.det(Q) > 0)[0]
            chunk = Q[keep]
            accepted.append(chunk)
            collected += chunk.shape[0]

        # With sample_num <= 0 nothing was drawn; fall back to an empty tensor.
        stacked = torch.cat(accepted, dim=0) if accepted else torch.tensor([])
        return stacked[:sample_num].reshape(-1, self.out_dim)

    def exp_from_identity(self, y):
        """Apply the matrix exponential to each matrix in the batch.

        Args:
            y: batch reshaped here to (B, n, n).

        Returns:
            Tensor of shape (B, n * n) holding exp(y[b]) flattened.
        """
        n = self.mat_dim
        generators = y.reshape(-1, n, n)
        exponentials = torch.matrix_exp(generators)
        return exponentials.reshape(-1, self.out_dim)

    def exp(self, y, base_point):
        """Exponential map at ``base_point``: X exp(X^T Y).

        Args:
            y: vectors at ``base_point``, reshaped here to (B, n, n).
            base_point: batch of base matrices X, reshaped to (B, n, n).

        Returns:
            Tensor of shape (B, n * n).
        """
        n = self.mat_dim
        tangent = y.reshape(-1, n, n)
        X = base_point.reshape(-1, n, n)
        # Pull the vector back to the identity by left-multiplying with X^T.
        at_identity = torch.bmm(X.transpose(1, 2), tangent)
        group_elem = self.exp_from_identity(at_identity).reshape(-1, n, n)
        # Push the result forward again by left-multiplying with X.
        return torch.bmm(X, group_elem).reshape(-1, self.out_dim)

    def log_volume(self):
        """Return the log of the volume of SO(n), n = ``self.mat_dim``.

        Uses
            Vol(SO(n)) = 2^(n-1) * pi^((n-1)(n+2)/4) / prod_{k=2..n} Gamma(k/2)
        (see https://arxiv.org/pdf/math-ph/0210033.pdf), so the log-gamma
        terms are subtracted. This agrees with the closed forms kept for
        n = 2 (2*pi) and n = 3 (8*pi^2), and gives 16*pi^4 for n = 4.

        Returns:
            A NumPy float for n in {2, 3}; for other n a NumPy array of shape
            (1,), the shape the general branch has always produced.
        """
        n = self.mat_dim
        if n == 2:
            return np.log(2) + np.log(np.pi)
        elif n == 3:
            return np.log(8) + 2 * np.log(np.pi)
        else:
            out = (n - 1) * np.log(2)
            out += ((n - 1) * (n + 2) / 4) * np.log(np.pi)
            k = np.expand_dims(np.arange(2, n + 1), axis=-1)
            # The gamma product is in the denominator, hence the minus sign.
            out -= np.sum(loggamma(k / 2), axis=0)
            return out



