import torch


def distance_to_implicit_manifold(point, manifold, opt_steps=100, manifold_weight=1e10):
    """Estimate the distance from each sample in `point` to an implicit manifold.

    The manifold is treated as the zero set of ``manifold.mdf``. Starting from
    `point`, L-BFGS minimises the penalty objective

        manifold_weight * ||mdf(x)||^2 + ||x - point||^2   (averaged over the batch)

    so a large `manifold_weight` drives ``x`` onto the zero set while the
    proximity term keeps it close to `point`.

    Args:
        point: Batch of query points. Dim 0 is the batch dimension; all other
            dims are reduced when measuring distance.
        manifold: Object exposing ``mdf(x)``. Its output is reduced with a
            vector norm over dim 1 (presumably shape ``(batch, k)`` -- confirm).
        opt_steps: Number of ``LBFGS.step`` calls (each step may evaluate the
            objective several times).
        manifold_weight: Penalty weight on the manifold constraint.

    Returns:
        Tensor of shape ``(batch,)`` with the Euclidean distance between each
        point and its estimated nearest manifold point. The result stays
        differentiable with respect to `point`; the nearest point itself is
        treated as a constant.
    """
    # Optimise against a detached copy of the target. The closure runs many
    # times, and backpropagating into `point`'s own graph on every evaluation
    # would leak gradients into the caller's graph (or fail once its saved
    # buffers have been freed).
    target = point.detach()
    nearest_point = target.clone().requires_grad_(True)
    opt = torch.optim.LBFGS([nearest_point], line_search_fn='strong_wolfe')
    norm_dims = tuple(range(1, point.ndim))

    def closure():
        manifold_loss = torch.linalg.vector_norm(manifold.mdf(nearest_point), dim=1).square().mean()
        proximity_loss = torch.linalg.vector_norm(nearest_point - target, dim=norm_dims).square().mean()
        loss = manifold_weight * manifold_loss + proximity_loss
        # Differentiate w.r.t. `nearest_point` only. `loss.backward()` would
        # also accumulate gradients into any trainable parameters of `manifold`.
        (grad,) = torch.autograd.grad(loss, nearest_point)
        nearest_point.grad = grad
        return loss

    # Optimize `nearest_point` to become the nearest point on manifold
    for _ in range(opt_steps):
        opt.step(closure)

    nearest_point = nearest_point.detach()
    return torch.linalg.vector_norm(nearest_point - point, dim=norm_dims)


def distance_to_pushforward_manifold(point, autoencoder, opt_steps=100):
    """Estimate the distance from each sample in `point` to a pushforward manifold.

    The manifold is the image of ``autoencoder.decoder``. The latent code is
    initialised with ``autoencoder.encoder(point)`` and then refined with
    L-BFGS to minimise ``||decoder(z) - point||^2`` (averaged over the batch).

    Args:
        point: Batch of query points. Dim 0 is the batch dimension; all other
            dims are reduced when measuring distance.
        autoencoder: Object exposing ``encoder`` and ``decoder`` callables.
        opt_steps: Number of ``LBFGS.step`` calls (each step may evaluate the
            objective several times).

    Returns:
        Tensor of shape ``(batch,)`` with the Euclidean distance between each
        point and its reconstruction from the optimised latent. It is computed
        under ``torch.no_grad()``, so it carries no autograd graph.
    """
    # Fixed target: keeps the repeated closure evaluations from backpropagating
    # into (or failing on) the caller's graph for `point`.
    target = point.detach()
    # The initial latent needs no graph through the encoder's parameters.
    with torch.no_grad():
        nearest_latent = autoencoder.encoder(target).clone()
    nearest_latent.requires_grad_(True)
    opt = torch.optim.LBFGS([nearest_latent], line_search_fn='strong_wolfe')
    norm_dims = tuple(range(1, point.ndim))

    def closure():
        loss = torch.linalg.vector_norm(
            autoencoder.decoder(nearest_latent) - target, dim=norm_dims).square().mean()
        # Differentiate w.r.t. the latent only. `loss.backward()` would also
        # accumulate gradients into the decoder's trainable parameters.
        (grad,) = torch.autograd.grad(loss, nearest_latent)
        nearest_latent.grad = grad
        return loss

    # Optimize `nearest_latent` to correspond to the nearest point on manifold
    for _ in range(opt_steps):
        opt.step(closure)

    nearest_latent = nearest_latent.detach()

    with torch.no_grad():
        return torch.linalg.vector_norm(autoencoder.decoder(nearest_latent) - point,
                                        dim=norm_dims)
