import torch
import numpy as np
import logging
from scipy.special import loggamma


class Manifold_SOn:
    """The rotation group SO(n), embedded in R^{n*n} as flattened matrices.

    Points are handled as batches of row-major flattened ``n x n`` matrices
    of shape ``(batch, n * n)``.  The manifold is described as the zero set
    of the constraints ``(X X^T - I)[i, j] = 0`` for ``i <= j`` (the upper
    triangle, diagonal included).

    Attributes:
        mat_dim: matrix size ``n``.
        out_dim: ambient dimension ``n * n``.
        inner_dim: ``n * (n - 1) / 2``.
        triu_indices: ``(2, constraint_num)`` index tensor of the upper
            triangle of an ``n x n`` matrix.
        constraint_num: number of constraints, ``n * (n + 1) / 2``.
        grad_coeff_matrix: ``(constraint_num, n, n)`` matrices ``M_c`` such
            that the gradient of constraint ``c`` at ``X`` is ``M_c X``.
        grad_coeff_matrix_normalized: same, with diagonal constraints using
            coefficient 1 and off-diagonal ones ``sqrt(1/2)``.
    """

    def __init__(self, dim):
        self.mat_dim = dim
        self.out_dim = self.mat_dim * self.mat_dim
        # n * (n - 1) is always even, so integer division is exact.
        self.inner_dim = self.mat_dim * (self.mat_dim - 1) // 2

        self.triu_indices = torch.triu_indices(self.mat_dim, self.mat_dim)
        self.constraint_num = self.triu_indices.shape[1]
        self.grad_coeff_matrix_normalized = self.get_grad_coeff_matrix(normalized=True)
        self.grad_coeff_matrix = self.get_grad_coeff_matrix(normalized=False)

    def get_grad_coeff_matrix(self, normalized):
        """Build the ``(constraint_num, n, n)`` gradient coefficient matrices.

        For the constraint attached to the index pair ``(i, j)``, ``i <= j``:

        * ``normalized=False``: ``E_ij + E_ji`` (so ``2 * E_ii`` on the
          diagonal), which is the exact derivative of ``(X X^T)[i, j]``.
        * ``normalized=True``: ``E_ii`` on the diagonal and
          ``sqrt(1/2) * (E_ij + E_ji)`` off the diagonal.
        """
        rows, cols = self.triu_indices[0], self.triu_indices[1]
        which = torch.arange(self.constraint_num)
        mat = torch.zeros(self.constraint_num, self.mat_dim, self.mat_dim)

        if normalized:
            alpha = float(np.sqrt(0.5))
            on_diag = rows == cols
            off_diag = ~on_diag
            mat[which[on_diag], rows[on_diag], cols[on_diag]] = 1.0
            mat[which[off_diag], rows[off_diag], cols[off_diag]] = alpha
            mat[which[off_diag], cols[off_diag], rows[off_diag]] = alpha
        else:
            # Two separate accumulations so that diagonal entries end up as 2.
            mat[which, rows, cols] += 1.0
            mat[which, cols, rows] += 1.0
        return mat

    def constrain_fn(self, samples):
        """Return the constraint values ``(X X^T - I)[i, j]``, ``i <= j``.

        Args:
            samples: tensor reshapeable to ``(batch, n, n)``.

        Returns:
            Tensor of shape ``(batch, constraint_num)``; all zeros exactly
            when every matrix in the batch is orthogonal.
        """
        mats = samples.reshape(-1, self.mat_dim, self.mat_dim)
        identity = torch.eye(self.mat_dim, self.mat_dim).to(mats)
        residual = torch.bmm(mats, mats.transpose(1, 2)) - identity
        return residual[:, self.triu_indices[0, :], self.triu_indices[1, :]]

    def constrain_grad_fn(self, samples, normalized=False):
        """Return the constraint gradients with respect to the matrix entries.

        Args:
            samples: tensor reshapeable to ``(batch, n, n)``.
            normalized: choose the normalized coefficient matrices (see
                :meth:`get_grad_coeff_matrix`).

        Returns:
            Tensor of shape ``(batch, constraint_num, n * n)``.
        """
        mats = samples.reshape(-1, self.mat_dim, self.mat_dim).unsqueeze(1)
        coeff = self.grad_coeff_matrix_normalized if normalized else self.grad_coeff_matrix
        # matmul broadcasts (constraint_num, n, n) against (batch, 1, n, n).
        return torch.matmul(coeff.to(mats), mats).flatten(start_dim=-2)

    def project_onto_tangent_space(self, y, base_point):
        """Project ambient vectors onto the tangent space at ``base_point``.

        Computes ``P_X(U) = (U - X U^T X) / 2`` for each batch element.

        Returns:
            Tensor of shape ``(batch, n * n)``.
        """
        vec = y.reshape(-1, self.mat_dim, self.mat_dim)
        base = base_point.reshape(-1, self.mat_dim, self.mat_dim)
        correction = torch.bmm(base, torch.bmm(vec.transpose(1, 2), base))
        return ((vec - correction) * 0.5).reshape(-1, self.out_dim)

    def project_onto_manifold(self, y):
        """Project matrices onto SO(n) through their SVD.

        For ``X = U D V^T`` the result is ``U diag(1, ..., 1, s) V^T`` with
        ``s = sign(det(U V^T))``.  When ``det(U V^T) > 0`` this is the plain
        polar factor ``U V^T``; otherwise the direction belonging to the
        smallest singular value is flipped so that the result has
        determinant +1 (``U V^T`` alone would only be guaranteed to lie in
        O(n)).

        Returns:
            Tensor of shape ``(batch, n * n)``.
        """
        mats = y.reshape(-1, self.mat_dim, self.mat_dim)
        U, S, Vh = torch.linalg.svd(mats)
        det_sign = torch.sign(torch.linalg.det(torch.bmm(U, Vh)))
        # Singular values come sorted in descending order, so the last column
        # of U is the cheapest one to reflect.
        column_scale = torch.ones_like(S)
        column_scale[:, -1] = det_sign
        rotation = torch.bmm(U * column_scale.unsqueeze(1), Vh)
        return rotation.reshape(-1, self.out_dim)

    @torch.no_grad()
    def project_onto_manifold_with_base(self, y, base_point, threshold=1e-6, n_iters=10, **kwargs):
        """Pull ``base_point + y`` back onto the manifold along the normal space at ``base_point``.

        Looks for multipliers ``mu`` such that
        ``base_point + y - G(base_point)^T mu`` satisfies the constraints,
        where ``G`` is the unnormalized constraint gradient, using Newton
        iterations on ``mu``.  Batch elements leave the active set as soon
        as their constraint norm drops below ``threshold``.

        Args:
            y: displacement, shape ``(batch, n * n)``.
            base_point: points on the manifold, shape ``(batch, n * n)``.
            threshold: convergence tolerance on the constraint values.
            n_iters: maximum number of Newton iterations.
            **kwargs: ``keep_quiet`` (default ``True``); when false a summary
                is written with ``logging.info``.

        Returns:
            ``(projected, converged)``: ``projected`` has shape
            ``(batch, n * n)``, with non-converged rows replaced by the
            corresponding ``base_point`` row; ``converged`` has shape
            ``(batch,)`` in the dtype/device of ``y`` (1 = converged).
        """
        keep_quiet = kwargs.get("keep_quiet", True)

        grad_vec = self.constrain_grad_fn(base_point, normalized=False)
        mu = torch.zeros((y.shape[0], self.out_dim - self.inner_dim)).to(y)
        active_idx = torch.arange(0, y.shape[0], dtype=torch.int64).to(y.device)

        last_step = 0  # keeps the log message valid even when n_iters == 0
        for last_step in range(n_iters):
            temp = (y[active_idx, :] + base_point[active_idx, :]
                    - torch.einsum('ijk,ij->ik', grad_vec[active_idx, :], mu[active_idx, :]))
            value = self.constrain_fn(temp)
            # Mask over the *current* active subset, not over the full batch.
            bad_idx = value.norm(dim=1) >= threshold
            if not bad_idx.any():
                break
            active_idx = active_idx[bad_idx]
            # Jacobian of the constraint w.r.t. mu.  grad_vec spans the whole
            # batch, so it must be indexed with the global indices
            # (active_idx); the subset-relative mask only applies to temp/value.
            mu_grad = -torch.bmm(self.constrain_grad_fn(temp[bad_idx, :], normalized=False),
                                 grad_vec[active_idx, :].transpose(1, 2))
            mu[active_idx, :] = mu[active_idx, :] - torch.linalg.solve(mu_grad, value[bad_idx, :])

        projected_pt = y + base_point - torch.einsum('ijk,ij->ik', grad_vec, mu)
        # Keep the (batch, constraint_num) shape: squeezing would drop the
        # batch axis for a single sample and break the reduction below.
        value = self.constrain_fn(projected_pt).abs()

        non_converged_flag = torch.any((value > threshold) | (~torch.isfinite(value)), dim=1)
        non_converged_num = non_converged_flag.sum()

        # Fall back to the base point wherever Newton failed.
        projected_pt[non_converged_flag, :] = base_point[non_converged_flag, :]

        if not keep_quiet:
            logging.info(f'total steps: {last_step}, max_error: {value.max():.3e}, {non_converged_num} states not converged!')
        return projected_pt.detach(), torch.logical_not(non_converged_flag).to(y)

    def uniform_sample(self, sample_num):
        """Draw ``sample_num`` rotation matrices from Gaussian matrices via QR.

        Each Gaussian matrix is QR-factorised, the columns of ``Q`` are
        multiplied by the signs of the diagonal of ``R``, and only factors
        with positive determinant are kept, so that every sample lies in
        the component of O(n) with determinant +1.

        Returns:
            Tensor of shape ``(sample_num, n * n)``.
        """
        chunks = []
        collected = 0
        while collected < sample_num:
            Z = torch.randn(sample_num, self.mat_dim, self.mat_dim)
            # Discard (near-)singular draws before factorising.
            Z = Z[torch.linalg.det(Z).abs() > 1e-4]
            Q, R = torch.linalg.qr(Z, mode="complete")
            signs = torch.diag_embed(R.diagonal(dim1=-2, dim2=-1).sign())
            Q = torch.bmm(Q, signs)
            Q = Q[torch.linalg.det(Q) > 0]
            chunks.append(Q)
            collected += Q.shape[0]

        if not chunks:
            return torch.empty(0, self.out_dim)
        return torch.cat(chunks, dim=0)[:sample_num].reshape(sample_num, self.out_dim)

    def exp_from_identity(self, y):
        """Return the matrix exponential of each matrix in ``y``, flattened to ``(batch, n * n)``."""
        mats = y.reshape(-1, self.mat_dim, self.mat_dim)
        return torch.matrix_exp(mats).reshape(-1, self.out_dim)

    def exp(self, y, base_point):
        """Return ``X exp(X^T Y)`` for base points ``X`` and vectors ``Y``.

        Returns:
            Tensor of shape ``(batch, n * n)``.
        """
        vec = y.reshape(-1, self.mat_dim, self.mat_dim)
        base = base_point.reshape(-1, self.mat_dim, self.mat_dim)
        lie_algebra = torch.bmm(base.transpose(1, 2), vec)
        rotation = self.exp_from_identity(lie_algebra).reshape(-1, self.mat_dim, self.mat_dim)
        return torch.bmm(base, rotation).reshape(-1, self.out_dim)

    def log_volume(self):
        """Return the log of the volume of SO(n) as a scalar.

        Uses (see https://arxiv.org/pdf/math-ph/0210033.pdf)

            Vol(SO(n)) = 2^(n-1) * pi^((n-1)(n+2)/4) / prod_{k=2}^{n} Gamma(k/2)

        which evaluates to ``2 * pi`` for n = 2 and ``8 * pi^2`` for n = 3.
        The Gamma factors are in the denominator, so their log is subtracted.
        """
        n = self.mat_dim
        k = np.arange(2, n + 1)
        out = (n - 1) * np.log(2) + ((n - 1) * (n + 2) / 4) * np.log(np.pi)
        return out - np.sum(loggamma(k / 2))



