import torch
import logging


def dihedral(x) -> torch.Tensor:
    """Return the signed dihedral angle defined by four consecutive points.

    Args:
        x: Tensor (or array-like convertible with ``torch.tensor``) whose
            last two axes are ``(point, coordinate)``. Points 0..3 along the
            second-to-last axis are used; the coordinate axis must have
            length 3 (required by the cross product). Any number of leading
            batch axes, including none, is accepted.

    Returns:
        Tensor of angles in radians in ``[-pi, pi]`` (the range of
        ``atan2``), with the leading batch axes of ``x`` followed by a
        trailing axis of length 1. For ``(B, 4, 3)`` input the result is
        ``(B, 1)``.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x).float()

    # Bond vectors between consecutive points.
    r12 = x[..., 1, :] - x[..., 0, :]
    r23 = x[..., 2, :] - x[..., 1, :]
    r34 = x[..., 3, :] - x[..., 2, :]

    # Normals of the planes spanned by (r12, r23) and (r23, r34).
    n1 = torch.linalg.cross(r12, r23, dim=-1)
    n2 = torch.linalg.cross(r23, r34, dim=-1)

    # Both terms carry the same positive factor |n1| * |n2|, so atan2
    # recovers the angle without normalising the normals explicitly.
    cos_phi = (n1 * n2).sum(dim=-1, keepdim=True)
    sin_phi = (n1 * r34).sum(dim=-1, keepdim=True) * torch.linalg.vector_norm(
        r23, dim=-1, keepdim=True
    )

    return torch.atan2(sin_phi, cos_phi)


class Manifold_MD:
    """Level set of configurations whose phi dihedral equals -70 degrees.

    Configurations are batched tensors indexed as ``x[batch, atom, coord]``.
    The manifold is the zero set of :meth:`constrain_fn`; the projection
    methods use the gradient of that function as the normal direction.
    Methods that raise ``NotImplementedError`` are not provided for this
    manifold.
    """

    # Target value of the phi dihedral in radians (-70 degrees).
    PHI_TARGET_RAD = -70 / 180 * torch.pi

    def __init__(self):
        # NOTE(review): presumably 10 atoms x 3 coordinates -- confirm with callers.
        self.out_dim = 30

    def angle_phi(self, x: torch.Tensor) -> torch.Tensor:
        """Return the dihedral of atoms 1, 3, 4, 6 for each configuration in ``x``."""
        atom_indices = [1, 3, 4, 6]
        return dihedral(x[:, atom_indices, :])

    def angle_psi(self, x: torch.Tensor) -> torch.Tensor:
        """Return the dihedral of atoms 3, 4, 6, 8 for each configuration in ``x``."""
        atom_indices = [3, 4, 6, 8]
        return dihedral(x[:, atom_indices, :])

    def constrain_fn(self, samples: torch.Tensor) -> torch.Tensor:
        """Return ``phi(samples) - PHI_TARGET_RAD``; zero exactly on the manifold.

        The difference is not wrapped back into ``[-pi, pi]``.
        """
        return self.angle_phi(samples) - self.PHI_TARGET_RAD

    @torch.enable_grad()
    def constrain_grad_fn(self, samples: torch.Tensor) -> torch.Tensor:
        """Return the gradient of :meth:`constrain_fn` with respect to ``samples``.

        Gradients are enabled locally, so this also works when called from a
        ``torch.no_grad()`` context. The result has the shape of ``samples``
        and is detached from any autograd graph.
        """
        # Differentiate w.r.t. a detached alias so the caller's tensor is not
        # switched to requires_grad=True as a side effect.
        leaf = samples.detach().requires_grad_(True)
        # The result is returned detached, so no higher-order graph is needed.
        (gradients,) = torch.autograd.grad(
            outputs=self.constrain_fn(leaf).sum(),
            inputs=leaf,
        )
        return gradients.detach()

    def project_onto_tangent_space(self, y: torch.Tensor, base_point: torch.Tensor) -> torch.Tensor:
        """Remove from ``y`` its component along the constraint gradient at ``base_point``.

        Inner products are taken over axes 1 and 2 jointly, one per batch entry.
        """
        norm_vec = self.constrain_grad_fn(base_point)
        coeff = torch.sum(y * norm_vec, dim=(1, 2)) / (norm_vec**2).sum(dim=(1, 2))
        return y - coeff.reshape(-1, 1, 1) * norm_vec

    def project_onto_manifold(self, y):
        raise NotImplementedError

    @torch.no_grad()
    def project_onto_manifold_with_base(self, y, base_point, threshold=1e-5, n_iters=10, **kwargs):
        """Project ``base_point + y`` onto the manifold along the normal at ``base_point``.

        Here, ``y`` is a tangent vector. Newton's method is used per batch
        entry to find ``mu`` such that
        ``xi(y + base_point - mu * grad xi(base_point)) = 0``,
        where ``xi`` is :meth:`constrain_fn`.

        Args:
            y: Tangent vectors, same shape as ``base_point``.
            base_point: Points at which the normal direction is evaluated.
            threshold: Entries with ``|xi| < threshold`` stop iterating.
            n_iters: Maximum number of Newton iterations.
            **kwargs: ``keep_quiet`` (default ``True``); when false, a
                summary is written with ``logging.info``.

        Returns:
            A pair ``(points, converged)``. ``points`` has the shape of
            ``y``; entries that did not converge (or became non-finite) are
            replaced by the corresponding ``base_point``. ``converged`` has
            shape ``(batch,)`` with the dtype and device of ``y``: 1 for
            converged entries and 0 otherwise.
        """
        keep_quiet = kwargs.get("keep_quiet", True)

        grad_vec = self.constrain_grad_fn(base_point)
        mu = torch.zeros(y.shape[0], 1, 1).to(y)
        # Batch indices that still violate the constraint.
        active_idx = torch.arange(y.shape[0], dtype=torch.int64, device=y.device)

        step = 0  # stays defined for the log message even when n_iters == 0
        for step in range(n_iters):
            trial = y[active_idx] + base_point[active_idx] - grad_vec[active_idx] * mu[active_idx]
            value = self.constrain_fn(trial)
            bad = (value.abs() >= threshold).squeeze(dim=1)
            if not bad.any():
                break
            active_idx = active_idx[bad]
            # d xi / d mu along the search line (note the minus sign in `trial`).
            slope = -(self.constrain_grad_fn(trial[bad]) * grad_vec[active_idx]).sum(dim=(1, 2))
            mu[active_idx] = mu[active_idx] - value[bad].reshape(-1, 1, 1) / slope.reshape(-1, 1, 1)

        projected_pt = y + base_point - grad_vec * mu
        # Squeeze only the trailing axis so a batch of one keeps shape (1,).
        value = self.constrain_fn(projected_pt).abs().squeeze(dim=1)

        non_converged_flag = (value > threshold) | (~torch.isfinite(value))
        non_converged_num = non_converged_flag.sum()

        # Restore the previous states if not converged.
        projected_pt[non_converged_flag] = base_point[non_converged_flag]

        if not keep_quiet:
            logging.info(f'total steps: {step}, max_error: {value.max():.3e}, {non_converged_num} states not converged!')
        return projected_pt.detach(), torch.logical_not(non_converged_flag).to(y)

    def uniform_sample(self, sample_num):
        raise NotImplementedError

    def exp(self, y, base_point):
        raise NotImplementedError

    def log_volume(self):
        raise NotImplementedError




