# Sample g functions and structure
import torch
from abc import abstractmethod, ABC

# Base class for g functions, with helpers to test the properties g is expected to have
# (g is differentiable, grad(g) is Lipschitz, g is concave)
class g(torch.nn.Module, ABC):
    """Abstract base class for a differentiable function ``g``.

    Subclasses implement :meth:`forward`. The remaining methods are numerical
    helpers built on top of ``forward``: a derivative via autograd, a
    grid-based monotonicity check, and a power-method estimate of the
    Lipschitz constant of the gradient of ``g``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, embedding):
        """Evaluate g on ``embedding``. Must be implemented by subclasses."""

    def get_derivative(self, xs):
        """
        Compute the derivative of g with respect to the singular values.

        The derivative is taken at a detached copy of ``xs``, so the caller's
        tensor (and its ``requires_grad`` flag) is left untouched. If
        ``forward`` returns a non-scalar tensor, the gradient of the sum of
        its entries is returned, which for an elementwise ``forward`` is the
        elementwise derivative.

        Parameters:
            xs (torch.Tensor): Singular values of the embedding
                (floating-point tensor).

        Returns:
            torch.Tensor: Derivative of g with respect to the singular values,
                with the same shape as ``xs``.
        """
        # enable_grad makes this work even when called inside torch.no_grad()
        with torch.enable_grad():
            point = xs.detach().clone().requires_grad_(True)
            value = self.forward(point)
            # grad_outputs of ones is the implicit default for scalar outputs
            # and extends the call to non-scalar outputs.
            grad = torch.autograd.grad(
                value, point, grad_outputs=torch.ones_like(value)
            )[0]

        return grad

    def check_monotonicity(self, start=0, end=100, step=5e-2, tol=1e-6):
        """
        Check if g is monotonically increasing.

        g is evaluated on the grid ``torch.arange(start, end, step)`` and on
        the same grid shifted by ``step``; the two results are compared.

        Parameters:
            start (float): Start of the range (must be >= 0).
            end (float): End of the range (must be > start).
            step (float): Step size for generating points in the range.
            tol (float): Tolerance for identifying near-zero differences.

        Raises:
            TypeError: If start or end is not an int or float.
            ValueError: If the range is invalid, or if the function is not
                monotonically increasing at any point.

        Returns:
            torch.Tensor: A tensor of booleans indicating whether
                f(x + step) >= f(x) - tol held. Because any violation raises,
                every entry of a returned tensor is True.
        """
        # Ensure start and end are scalar values
        if not isinstance(start, (int, float)) or not isinstance(end, (int, float)):
            raise TypeError("Start and end must be scalar values (int or float).")

        # Ensure the range is valid (start == 0 is allowed; it is the default)
        if start < 0 or start >= end:
            raise ValueError("Start must be >= 0 and end must be > start.")

        # Generate the range of x values as a torch tensor
        x_values = torch.arange(start, end, step)

        # Compute g(x) and g(x + step)
        g_x = self.forward(x_values)
        g_x_next = self.forward(x_values + step)

        # Check for monotonic increase
        is_increasing = (g_x_next >= g_x - tol)

        # Raise an error if any point violates monotonicity
        if not torch.all(is_increasing):
            raise ValueError(f"The function {type(self).__name__} is not monotonically increasing over the specified range. Please modify it to be so.")

        return is_increasing

    def estimate_lipschitz_constant(self, embedding, num_iterations=1000, tol=1e-6):
        """
        Estimate the Lipschitz constant of the gradient of g at ``embedding``.

        The Power Method is run on the Hessian of ``forward`` at ``embedding``
        using Hessian-vector products. It converges to the eigenvalue of
        largest magnitude; the magnitude of that eigenvalue is a lower bound
        on the Lipschitz constant of the gradient. The returned value is that
        magnitude plus 1.

        The computation uses a detached copy of ``embedding``, so the caller's
        tensor is not modified. ``forward`` must return a scalar.

        Args:
            embedding (torch.Tensor): The input embeddings (n x d),
                floating-point.
            num_iterations (int): Maximum number of iterations for the Power Method.
            tol (float): Convergence tolerance for eigenvalue approximation.

        Returns:
            float: ``abs(dominant Hessian eigenvalue estimate) + 1``. If the
                Hessian is identically zero (e.g. a linear g) or no iteration
                is run, the eigenvalue estimate is 0 and 1.0 is returned.
        """
        # Initialize a random unit vector with the same shape as the embedding
        v = torch.rand_like(embedding)
        v = v / v.norm()

        largest_eigenvalue = 0.0

        with torch.enable_grad():
            point = embedding.detach().clone().requires_grad_(True)

            # The gradient does not depend on v, so build its graph once and
            # reuse it for every Hessian-vector product.
            loss = self.forward(point)
            grad_loss = torch.autograd.grad(loss, point, create_graph=True)[0]

            # A gradient with no graph is constant: the Hessian is zero.
            if not grad_loss.requires_grad:
                return abs(largest_eigenvalue) + 1

            for _ in range(num_iterations):
                # Hessian-vector product: d/dx <grad g(x), v>. The elementwise
                # product works for any embedding shape.
                hvp = torch.autograd.grad(
                    (grad_loss * v).sum(),
                    point,
                    retain_graph=True,
                    allow_unused=True,
                )[0]
                if hvp is None:
                    # The gradient does not depend on the input: zero Hessian.
                    break

                # Rayleigh quotient as an approximation of the eigenvalue
                new_eigenvalue = (
                    torch.dot(v.flatten(), hvp.flatten())
                    / torch.dot(v.flatten(), v.flatten())
                ).item()

                converged = abs(new_eigenvalue - largest_eigenvalue) < tol
                largest_eigenvalue = new_eigenvalue

                hvp_norm = hvp.norm()
                # A zero product cannot be normalized; stop instead of
                # propagating NaNs.
                if converged or hvp_norm == 0:
                    break

                # Detach so the iterate is a constant in the next product
                v = (hvp / hvp_norm).detach()

        # The Lipschitz constant is a magnitude; the dominant eigenvalue is
        # negative for a concave g.
        return abs(largest_eigenvalue) + 1
