# Sample loss functions `f`: an abstract base class and shared utilities.
import torch
from abc import abstractmethod, ABC


class f(torch.nn.Module, ABC):
    """Abstract base for a scalar loss computed from an embedding tensor.

    Subclasses implement :meth:`forward`. This base class adds a power-method
    estimate of the Lipschitz constant of the loss gradient.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, embedding):
        """Return the loss for ``embedding``.

        :meth:`estimate_lipschitz_constant` differentiates the result twice
        with respect to ``embedding`` without ``grad_outputs``, so it must be
        a scalar tensor.
        """

    def estimate_lipschitz_constant(
        self, embedding, num_iterations=1000, tol=1e-6, margin=0.05
    ):
        """Estimate the Lipschitz constant of the gradient of the loss.

        Runs the power method on Hessian-vector products of ``self.forward``
        at ``embedding``. This approximates the largest-magnitude eigenvalue
        of the (symmetric) Hessian, i.e. its spectral norm. The Hessian norm
        at a single point is only a lower bound on the Lipschitz constant, so
        ``margin`` is added to the estimate.

        Args:
            embedding (torch.Tensor): Point at which the Hessian is evaluated
                (e.g. n x d embeddings). It is not modified, and its autograd
                state is left untouched.
            num_iterations (int): Maximum number of power-method iterations.
            tol (float): Stop once successive eigenvalue-magnitude estimates
                differ by less than this.
            margin (float): Safety margin added to the estimate.

        Returns:
            float: Estimated largest |eigenvalue| of the Hessian plus
            ``margin``. A loss whose gradient does not depend on ``embedding``
            (zero Hessian), or ``num_iterations=0``, yields ``margin``.
        """
        # Differentiate w.r.t. a detached leaf so the caller's tensor keeps its
        # autograd state and any floating dtype/device is accepted.
        embedding = torch.as_tensor(embedding).detach().requires_grad_(True)

        v = torch.rand_like(embedding)
        v = v / v.norm()

        # The loss and its gradient do not depend on v. Build the
        # differentiable gradient once so every iteration applies the same
        # operator; each iteration only adds a Hessian-vector product.
        loss = self.forward(embedding)
        (grad_loss,) = torch.autograd.grad(loss, embedding, create_graph=True)

        def hessian_vector_product(vec):
            # A gradient without a graph (e.g. a linear loss) means a zero
            # Hessian. allow_unused covers gradients that depend on other
            # tensors (e.g. parameters) but not on the embedding itself.
            if not grad_loss.requires_grad:
                return torch.zeros_like(embedding)
            (hvp,) = torch.autograd.grad(
                (grad_loss * vec).sum(),
                embedding,
                retain_graph=True,
                allow_unused=True,
            )
            return torch.zeros_like(embedding) if hvp is None else hvp

        largest_eigenvalue = 0.0
        for _ in range(num_iterations):
            hvp = hessian_vector_product(v)

            # Rayleigh quotient. The power method may lock onto a negative
            # eigenvalue, so track (and compare) magnitudes consistently.
            rayleigh = torch.dot(v.flatten(), hvp.flatten()) / torch.dot(
                v.flatten(), v.flatten()
            )
            estimate = abs(rayleigh.item())
            converged = abs(estimate - largest_eigenvalue) < tol
            largest_eigenvalue = estimate

            hvp_norm = hvp.norm()
            # A zero HVP means v lies in the Hessian's null space;
            # normalizing it would produce NaNs.
            if converged or hvp_norm == 0:
                break
            v = (hvp / hvp_norm).detach()

        print(f"Estimated largest eigenvalue: {largest_eigenvalue:.4f}")
        return largest_eigenvalue + margin