import numpy as np
import torch
import torch.nn.functional as F
from functorch import vjp
from tqdm import tqdm


class ImplicitManifold:
    """A manifold represented implicitly as the zero set of some smooth map

    Args:
        mdf: a manifold-defining function (the manifold is given by M = {x: mdf(x) = 0}).
            It is moved with ``.to`` and switched with ``.train()``/``.eval()``, so it is
            expected to be an ``nn.Module``-like object.
        device: the device on which the computations will be performed (the mdf will be moved
            to this device)
    """
    def __init__(self, mdf, device):
        self.mdf = mdf.to(device)  # manifold-defining function, resident on `device`
        self.device = device

    def train(self, optim, dataloader, *, epochs, mu=1., eta=1.):
        """Fit ``mdf`` so that the training data lies on its zero set.

        Every batch minimises

            mean(mdf(x)**2) + mu * mean(relu(eta - ||v^T J(x)|| / ||v||)**2)

        where ``J(x)`` is the Jacobian of ``mdf`` at ``x`` and ``v`` is a fresh standard
        Gaussian cotangent drawn per batch. The second term penalises the Jacobian
        shrinking a random ``v`` below ``eta``. Presumably this keeps ``mdf`` from
        degenerating (e.g. to the constant zero map).

        Args:
            optim: optimizer stepping the parameters of ``mdf`` (assumed; not checked here).
            dataloader: iterable of batches. Each batch is either a tensor or a
                2-element ``(data, label)`` tuple/list whose label is ignored.
            epochs: number of passes over ``dataloader``.
            mu: weight of the Jacobian regulariser.
            eta: lower target for ``||v^T J|| / ||v||``.

        ``mdf`` is left in eval mode on return, including when training raises.
        """
        self.mdf.train()
        try:
            pbar = tqdm(range(epochs))
            for epoch in pbar:
                losses = []
                rank_regs = []

                for batch in dataloader:
                    # Labelled datasets yield (data, label) pairs; labels are unused.
                    if isinstance(batch, (tuple, list)) and len(batch) == 2:
                        batch, _ = batch

                    optim.zero_grad()

                    batch = batch.to(self.device)
                    out, vjp_fn = vjp(self.mdf, batch)

                    # Cotangent with the output's exact shape, dtype and device: no
                    # `codom_dim` attribute needed, no dtype mismatch for non-float32
                    # models, and no per-batch host->device copy.
                    v = torch.randn_like(out)
                    vec_jac_prod = vjp_fn(v)[0]

                    f_x = out.square().mean()
                    vjp_norm = torch.linalg.vector_norm(vec_jac_prod.flatten(start_dim=1), dim=1)
                    v_norm = torch.linalg.vector_norm(v.flatten(start_dim=1), dim=1)
                    rank_reg = F.relu(eta - vjp_norm / v_norm).square().mean()

                    loss = f_x + mu * rank_reg
                    loss.backward()
                    optim.step()

                    losses.append(loss.item())
                    rank_regs.append(rank_reg.item())

                # An empty dataloader would make np.mean([]) warn; report nan explicitly.
                mean_loss = np.mean(losses) if losses else float("nan")
                mean_reg = np.mean(rank_regs) if rank_regs else float("nan")
                pbar.set_description(
                    f"[E{epoch:3d}] loss: {mean_loss:4.10f}, reg: {mean_reg:4.10f}")
        finally:
            # Never leave the model in training mode, even if an epoch fails.
            self.mdf.eval()

    def project(self, x, opt_steps=100):
        """Projects `x` onto the manifold.

        Runs L-BFGS (strong-Wolfe line search) initialised at ``x`` on the objective
        ``mean(mdf(x)**2)`` and returns the result. The objective has no distance term.
        The result is therefore a point where ``mdf`` is approximately zero, reached by
        descending from ``x``, and not necessarily the nearest point of the manifold.

        Args:
            x: starting point(s). The tensor is not modified; a copy is optimised on
                ``self.device``.
            opt_steps: number of ``LBFGS.step`` calls. Each call may run several inner
                iterations (PyTorch's default ``max_iter`` is 20).

        Returns:
            A detached tensor on ``self.device`` with the same shape as ``x``.
        """
        x = x.detach().to(self.device).clone().requires_grad_(True)
        zero_opt = torch.optim.LBFGS([x], line_search_fn='strong_wolfe')

        def closure():
            zero_opt.zero_grad()
            loss = self.mdf(x).square().mean()
            # Differentiate w.r.t. `x` only. A plain backward() would also accumulate
            # gradients into the mdf's parameters, polluting their `.grad` and wasting work.
            loss.backward(inputs=[x])
            return loss

        for _ in range(opt_steps):
            zero_opt.step(closure)

        return x.detach()