import torch
import numpy as np
def hessian_vector_product(model, criterion, vector, data_in, data_out, create_graph=True, specified_layers=None):
    """
    Compute the Hessian-vector product of the loss with respect to model parameters.

    The loss is ``criterion(model(data_in), data_out)``. The product is obtained
    by double backpropagation: the gradient is computed with a graph attached,
    contracted with ``vector``, and differentiated a second time.

    Args:
        model: Module called as ``model(data_in)``; its parameters with
            ``requires_grad=True`` are the differentiation variables.
        criterion: Callable invoked as ``criterion(outputs, data_out)``.
        vector: Iterable of tensors, one per selected parameter and in the same
            order, each broadcastable against the matching gradient.
        data_in: Input passed to the model.
        data_out: Target passed to the criterion.
        create_graph: Forwarded as ``retain_graph`` to the second backward pass.
            The first pass always builds a graph, because the second
            differentiation is impossible without one.
        specified_layers: Optional iterable of substrings; when truthy, only
            parameters whose name contains at least one of them are used.

    Returns:
        Tuple of tensors (one per selected parameter) holding the product.

    Raises:
        ValueError: If no parameter is selected, or if ``vector`` does not
            contain exactly one tensor per selected parameter.
    """
    # Forward pass
    outputs = model(data_in)
    loss = criterion(outputs, data_out)

    if specified_layers:
        params = [
            p for name, p in model.named_parameters()
            if p.requires_grad and any(layer in name for layer in specified_layers)
        ]
    else:
        params = [p for p in model.parameters() if p.requires_grad]

    # Check if there are any parameters to compute gradients
    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")

    # zip() would silently drop unmatched entries and return a product against
    # a truncated vector, so a length mismatch is rejected explicitly.
    vector = list(vector)
    if len(vector) != len(params):
        raise ValueError(
            f"vector has {len(vector)} tensors but {len(params)} parameters were selected."
        )

    # First backward pass: the graph must be kept so the gradient itself can
    # be differentiated below.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Compute grad-vector product
    grad_vector_product = sum(torch.sum(g * v) for g, v in zip(grads, vector))

    # Second backward pass
    hvp = torch.autograd.grad(grad_vector_product, params, retain_graph=create_graph)
    return hvp

def flatten_and_concat(tensor_list):
    """Flatten each tensor in ``tensor_list`` and join them into one 1-D tensor.

    The flattened pieces are concatenated in iteration order.
    """
    flat_pieces = []
    for tensor in tensor_list:
        flat_pieces.append(tensor.flatten())
    return torch.cat(flat_pieces)
def reshape_like(vector, tensor_list):
    """Split a flat tensor into pieces shaped like the tensors in ``tensor_list``.

    Consecutive slices of ``vector`` are taken in iteration order, each with as
    many elements as the corresponding reference tensor, and reshaped to that
    tensor's shape.

    Args:
        vector: 1-D tensor to split.
        tensor_list: Iterable of reference tensors (a generator is accepted;
            it is consumed once).

    Returns:
        List of tensors, one per reference tensor. Elements of ``vector``
        beyond the total reference size are ignored.
    """
    pieces = []
    offset = 0
    for reference in tensor_list:
        # numel() is always a Python int; np.prod of an empty shape is the
        # float 1.0, which is not a valid slice bound for 0-dim tensors.
        count = reference.numel()
        pieces.append(vector[offset:offset + count].reshape(reference.shape))
        offset += count
    return pieces


def lanczos(model, criterion, data_in, data_out, n, m=30):
    """Run up to ``m`` Lanczos iterations on the loss Hessian.

    The Hessian is applied implicitly through ``hessian_vector_product``, which
    differentiates with respect to the parameters that have
    ``requires_grad=True``. No re-orthogonalisation is performed.

    Args:
        model: Module whose trainable parameters define the Hessian.
        criterion: Loss callable forwarded to ``hessian_vector_product``.
        data_in: Input forwarded to ``hessian_vector_product``.
        data_out: Target forwarded to ``hessian_vector_product``.
        n: Dimension of the Krylov vectors; must equal the total number of
            elements in the trainable parameters.
        m: Maximum number of iterations.

    Returns:
        ``(alpha, beta)``: diagonal and off-diagonal entries of the tridiagonal
        Lanczos matrix. They have lengths ``m`` and ``m - 1``, or ``j + 1`` and
        ``j`` when the residual norm falls below ``1e-8`` at iteration ``j``.

    Raises:
        ValueError: If ``m < 1``, if the model has no trainable parameters, or
            if ``n`` does not match the trainable parameter count.
    """
    if m < 1:
        raise ValueError("m must be at least 1.")

    # Use exactly the parameters hessian_vector_product differentiates, so the
    # flat vector and the returned product have matching layout and size.
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("Model has no trainable parameters.")
    n_trainable = sum(p.numel() for p in params)
    if n != n_trainable:
        raise ValueError(
            f"n ({n}) must equal the number of trainable parameter elements ({n_trainable})."
        )

    def hvp_func(v):
        v_list = reshape_like(v, params)
        # Compute Hessian-vector product
        hvp = hessian_vector_product(model, criterion, v_list, data_in, data_out)
        return flatten_and_concat(hvp).detach()

    # Match the parameters' dtype as well as device; torch.dot requires both
    # operands to share a dtype, so a float32 start vector breaks float64 models.
    device = params[0].device
    dtype = params[0].dtype

    q = torch.randn(n, device=device, dtype=dtype)
    q = q / torch.norm(q)

    alpha = torch.zeros(m, device=device, dtype=dtype)
    beta = torch.zeros(m - 1, device=device, dtype=dtype)

    q_old = torch.zeros(n, device=device, dtype=dtype)
    for j in range(m):
        model.zero_grad()
        v = hvp_func(q)
        alpha[j] = torch.dot(q, v)

        # Three-term recurrence: remove the components along q and q_old.
        v = v - alpha[j] * q
        if j > 0:
            v = v - beta[j - 1] * q_old

        if j < m - 1:
            beta[j] = torch.norm(v)
            if beta[j] < 1e-8:
                # Residual vanished: return the truncated coefficients.
                return alpha[:j + 1], beta[:j]
            q_old = q
            q = v / beta[j]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return alpha, beta

def estimate_eigenvalues_lanczos(model, criterion, data_in, data_out, m=30, return_smallest=False, return_eigenvectors=False):
    """
    Estimate extreme eigenvalues of the loss Hessian using the Lanczos algorithm.

    Runs ``lanczos`` for up to ``m`` iterations, assembles the symmetric
    tridiagonal matrix ``T`` from the returned coefficients and
    eigendecomposes it with NumPy.

    Args:
        model: Module whose trainable parameters define the Hessian.
        criterion: Loss callable forwarded to ``lanczos``.
        data_in: Input forwarded to ``lanczos``.
        data_out: Target forwarded to ``lanczos``.
        m: Maximum number of Lanczos iterations.
        return_smallest: Also return the smallest eigenvalue of ``T``.
        return_eigenvectors: Also return eigenvectors of ``T``.

    Returns:
        Depending on the flags:
        ``largest``; ``(largest, smallest)``; ``(largest, largest_vec)``; or
        ``(largest, largest_vec, smallest, smallest_vec)``.
        The vectors are eigenvectors of ``T`` (length equal to the number of
        completed iterations), not vectors in parameter space.

    Raises:
        ValueError: If the model has no trainable parameters.
    """
    # Count only trainable elements: hessian_vector_product differentiates
    # with respect to requires_grad parameters, so that is the Hessian's size.
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if n == 0:
        raise ValueError("Model has no trainable parameters.")

    alpha, beta = lanczos(model, criterion, data_in, data_out, n, m)

    diagonal = alpha.detach().cpu().numpy()
    off_diagonal = beta.detach().cpu().numpy()
    T = np.diag(diagonal) + np.diag(off_diagonal, k=1) + np.diag(off_diagonal, k=-1)

    if not return_eigenvectors:
        # eigvalsh returns eigenvalues in ascending order.
        eigenvalues = np.linalg.eigvalsh(T)
        if return_smallest:
            return eigenvalues[-1], eigenvalues[0]
        return eigenvalues[-1]

    # eigh returns ascending eigenvalues with eigenvectors as columns.
    eigenvalues, eigenvectors = np.linalg.eigh(T)
    largest_eig = eigenvalues[-1]
    largest_eigvec = eigenvectors[:, -1]
    if return_smallest:
        smallest_eig = eigenvalues[0]
        smallest_eigvec = eigenvectors[:, 0]
        return largest_eig, largest_eigvec, smallest_eig, smallest_eigvec
    return largest_eig, largest_eigvec