import torch
import torch.nn as nn
import numpy as np


def hv_gn(model, X, y, v, device):
    """Return the Gauss-Newton Hessian-vector product for a squared-error loss.

    Computes ``(2 / n) * sum_i g_i * (g_i . v)`` where ``g_i`` is the
    flattened gradient of the model output for sample ``i`` with respect to
    all model parameters.

    Args:
        model: Module whose output for a single sample has exactly one
            element (``torch.autograd.grad`` is called on it without
            ``grad_outputs``).
        X: Inputs; samples are taken along dimension 0.
        y: Unused; kept so the signature matches ``hv_r``.
        v: Flat vector with one entry per model parameter.
        device: Device on which the accumulator is allocated.

    Returns:
        Flat tensor shaped like ``v``. It carries no autograd graph.
    """
    params = list(model.parameters())
    hv = torch.zeros_like(v, device=device)
    n = X.shape[0]
    for i in range(n):
        f_i = model(X[i:i+1])
        # Only first-order gradients are needed here. Requesting a
        # differentiable graph would keep every sample's graph alive inside
        # `hv` and make memory grow linearly with the batch size.
        grad_i = torch.autograd.grad(f_i, params)
        grad_i = torch.cat([g.contiguous().view(-1) for g in grad_i])
        hv += grad_i * torch.dot(grad_i, v)
    hv /= n
    return 2 * hv


def hv_r(model, X, y, v, device):
    """Return the residual (second-order) part of the squared-error Hessian times ``v``.

    Computes ``(2 / n) * sum_i r_i * (d^2 f_i / d theta^2) v`` with
    ``r_i = f_i - y_i``, where ``f_i`` is the model output for sample ``i``.

    Args:
        model: Module whose output for a single sample has exactly one
            element.
        X: Inputs; samples are taken along dimension 0.
        y: Targets, indexed the same way as ``X``.
        v: Flat vector with one entry per model parameter.
        device: Device on which the accumulator is allocated.

    Returns:
        Flat tensor shaped like ``v``. It carries no autograd graph.
    """
    params = list(model.parameters())
    hv = torch.zeros_like(v, device=device)
    n = X.shape[0]
    for i in range(n):
        f_i = model(X[i:i+1])
        # The residual is a plain scalar weight; detaching it prevents every
        # sample's forward graph from being retained through `hv`.
        r_i = (f_i - y[i:i+1]).squeeze().detach()
        grad_i = torch.autograd.grad(f_i, params, create_graph=True)
        grad_i = torch.cat([g.contiguous().view(-1) for g in grad_i])
        if not grad_i.requires_grad:
            # The output is linear in every parameter, so its Hessian is
            # zero and this sample contributes nothing.
            continue
        # Parameters that enter the output only linearly (e.g. an output
        # bias) do not appear in the graph of `grad_i`; without
        # allow_unused=True autograd raises for them.
        hv_i = torch.autograd.grad(grad_i, params, grad_outputs=v, allow_unused=True)
        hv_i = torch.cat([
            (torch.zeros_like(p) if h is None else h).contiguous().view(-1)
            for h, p in zip(hv_i, params)
        ])
        hv += r_i * hv_i
    hv /= n
    return 2 * hv

#############################################

#############################################

def power_iteration_total(model, X, y, vt=None, num_iters=50, eps=1e-8, tol=1e-6, device='cpu', v_initial=None, diag_initial=None):
    """Run power iteration on the MSE-loss Hessian and, optionally, a diagonally scaled Hessian.

    The model is switched to eval mode and the loss is always ``nn.MSELoss``.
    Each iteration multiplies the current direction by the Hessian and
    renormalises it; the reported eigenvalue estimate is the norm of that
    product. When ``vt`` is given, a second iteration runs in lockstep on
    ``H v / (sqrt(vt) + eps)``.

    Args:
        model: Module evaluated on ``X``.
        X, y: Inputs and targets; both are moved to ``device``.
        vt: Optional flat tensor with one entry per parameter used for the
            diagonal scaling. ``None`` disables the second iteration.
        num_iters: Maximum number of iterations.
        eps: Added to ``sqrt(vt)`` before dividing.
        tol: Iteration stops once every tracked estimate changes by less
            than this amount between two consecutive iterations.
        device: Device used for data and work vectors.
        v_initial: Optional starting vector for the plain iteration.
        diag_initial: Optional starting vector for the scaled iteration.

    Returns:
        ``((eig_H, v_H), (eig_scaled, v_scaled))``; the second pair is
        ``(None, None)`` when ``vt`` is ``None``. Vectors are detached.
    """
    model.eval()
    loss_fn = nn.MSELoss()

    X = X.to(device)
    y = y.to(device)
    if vt is not None:
        vt = vt.to(device)

    loss = loss_fn(model(X), y)

    # Differentiable first-order gradient, flattened once and reused for
    # every Hessian-vector product below.
    first_order = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    grads = torch.cat([g.contiguous().view(-1) for g in first_order]).to(device)

    def _unit_start(initial, fallback):
        # Prefer the caller-supplied start vector; otherwise use the random one.
        chosen = fallback if initial is None else initial.detach().clone().to(device)
        return chosen / torch.norm(chosen)

    def _hessian_times(direction):
        pieces = torch.autograd.grad(grads, model.parameters(), grad_outputs=direction, retain_graph=True)
        return torch.cat([p.contiguous().view(-1) for p in pieces]).to(device)

    use_scaled = vt is not None

    # The random fallback is drawn even when an initial vector is supplied,
    # so the random-number stream advances the same way in both cases.
    v_H = _unit_start(v_initial, torch.randn_like(grads, device=device))
    v_diagH = None
    if use_scaled:
        v_diagH = _unit_start(diag_initial, torch.randn_like(grads, device=device))

    max_eigen_H = 0.0
    max_eigen_diagvH = 0.0 if use_scaled else None
    previous = []

    for _ in range(num_iters):
        Hv = _hessian_times(v_H)
        Hv_norm = torch.norm(Hv)
        if Hv_norm > 0:
            v_H = Hv / Hv_norm
            max_eigen_H = Hv_norm.item()

        if use_scaled:
            scaled = _hessian_times(v_diagH) / (torch.sqrt(vt) + eps)
            scaled_norm = torch.norm(scaled)
            if scaled_norm > 0:
                v_diagH = scaled / scaled_norm
                max_eigen_diagvH = scaled_norm.item()

        current = [max_eigen_H, max_eigen_diagvH] if use_scaled else [max_eigen_H]

        # Stop once every tracked estimate has stabilised.
        if previous and all(abs(now - before) < tol for now, before in zip(current, previous)):
            break
        previous = current

    plain_result = (max_eigen_H, v_H.detach())
    if use_scaled:
        scaled_result = (max_eigen_diagvH, v_diagH.detach())
    else:
        scaled_result = (None, None)

    return plain_result, scaled_result

def hessian_vector_product(model, criterion, X, y, vector, create_graph=True, device='cpu', specified_layers=None):
    """Compute the Hessian-vector product of the loss with respect to selected parameters.

    The model is switched to eval mode and moved to ``device``.

    Args:
        model: Module evaluated on ``X``.
        criterion: Loss callable ``criterion(prediction, target)``. ``None``
            falls back to ``nn.MSELoss()``.
        X, y: Inputs and targets; both are moved to ``device``.
        vector: Iterable of tensors, one per selected parameter and in the
            same order, giving the direction to multiply by.
        create_graph: When true, the autograd graph is retained after the
            second backward pass; when false it is freed.
        device: Device for the data and the model.
        specified_layers: Optional iterable of substrings; only parameters
            whose name contains one of them are used. Falsy selects all
            parameters. Parameters with ``requires_grad=False`` are always
            skipped.

    Returns:
        Tuple of tensors, one per selected parameter.

    Raises:
        ValueError: If no parameter is selected.
    """
    model.eval()
    # Honour the caller's loss; previously it was unconditionally replaced.
    if criterion is None:
        criterion = nn.MSELoss()

    X, y = X.to(device), y.to(device)
    model.to(device)

    y_pred = model(X)
    loss = criterion(y_pred, y)
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]
    # Check if there are any parameters to compute gradients
    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")
    # First backward pass. The gradient must stay differentiable no matter
    # what `create_graph` says, otherwise the second pass cannot run.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Compute grad-vector product
    grad_vector_product = sum(torch.sum(g * v) for g, v in zip(grads, vector))
    # Second backward pass
    hvp = torch.autograd.grad(grad_vector_product, params, retain_graph=create_graph)
    return hvp

def hessian_vector_product_with_diagonal(model, criterion, X, y, vector, precond=None, create_graph=True, device='cpu', specified_layers=None):
    """Compute a Hessian-vector product and scale it element-wise by ``precond``.

    The model is switched to eval mode. Unlike the data, the model itself is
    not moved to ``device`` here.

    Args:
        model: Module evaluated on ``X``.
        criterion: Loss callable ``criterion(prediction, target)``. ``None``
            falls back to ``nn.MSELoss()``.
        X, y: Inputs and targets; both are moved to ``device``.
        vector: Iterable of tensors, one per selected parameter, giving the
            direction to multiply by.
        precond: Iterable of per-parameter factors that the product is
            multiplied by. ``None`` returns the unscaled product.
        create_graph: When true, the autograd graph is retained after the
            second backward pass; when false it is freed.
        device: Device for the data.
        specified_layers: Optional iterable of substrings; only parameters
            whose name contains one of them are used. Falsy selects all
            parameters. Parameters with ``requires_grad=False`` are always
            skipped.

    Returns:
        List of tensors, one per selected parameter.

    Raises:
        ValueError: If no parameter is selected.
    """
    model.eval()
    # Honour the caller's loss; previously it was unconditionally replaced.
    if criterion is None:
        criterion = nn.MSELoss()

    X, y = X.to(device), y.to(device)

    y_pred = model(X)
    loss = criterion(y_pred, y)
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]
    # Check if there are any parameters to compute gradients
    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")
    # First backward pass. The gradient must stay differentiable no matter
    # what `create_graph` says, otherwise the second pass cannot run.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # Compute grad-vector product
    grad_vector_product = sum(torch.sum(g * v) for g, v in zip(grads, vector))
    # Second backward pass
    hvp = torch.autograd.grad(grad_vector_product, params, retain_graph=create_graph)
    # Without a preconditioner there is nothing to scale by.
    if precond is None:
        return list(hvp)
    # Apply diagonal matrix to vector
    return [h * d for h, d in zip(hvp, precond)]


def compute_rayleigh_quotient(hvp, v):
    """Return ``<hvp, v> / (<v, v> + 1e-12)`` for per-parameter tensor lists.

    ``hvp`` and ``v`` are iterated in lockstep; the small constant keeps the
    division finite when ``v`` is all zeros.
    """
    v_dot_hv = sum((h_part * v_part).sum() for h_part, v_part in zip(hvp, v))
    v_dot_v = sum((v_part * v_part).sum() for v_part in v)
    return v_dot_hv / (v_dot_v + 1e-12)

def compute_rayleigh_quotient_update(hvp, v, grads):
    """Return ``<hvp, v> / (<grads, v> + 1e-12)`` for per-parameter tensor lists.

    Same numerator as ``compute_rayleigh_quotient``, but normalised by the
    inner product of ``grads`` with ``v`` instead of ``<v, v>``.
    """
    v_dot_hv = sum((h_part * v_part).sum() for h_part, v_part in zip(hvp, v))
    v_dot_g = sum((g_part * v_part).sum() for g_part, v_part in zip(grads, v))
    return v_dot_hv / (v_dot_g + 1e-12)

# def compute_grad_hessian_lambad(model, criterion, X, y, vector, optimizer, precond=None, create_graph=True, device='cpu', specified_layers=None):
#     preconditioner = compute_adam_preconditioner(optimizer)
#     if precond is not None:
#         hvp = hessian_vector_product_with_diagonal(model, criterion, X, y, vector, precond=precond, create_graph=create_graph, device=device, specified_layers=specified_layers)
#     else:
#         hvp = hessian_vector_product(model, criterion, X, y, vector, create_graph=create_graph, device=device, specified_layers=specified_layers)
    
#     eigenvalue = compute_rayleigh_quotient(hvp, grads)

def compute_grad_hessian_lambda(loss, params, gradients, vt=None, eps=1e-8):
    """Return the Rayleigh quotient ``g^T H g / g^T g`` of the loss Hessian along the gradient.

    The Hessian-vector product is computed directly with autograd. If the
    supplied ``gradients`` are differentiable they are reused; otherwise the
    gradient is recomputed from ``loss``, whose graph must then still be
    alive.

    Args:
        loss: Scalar loss tensor.
        params: Iterable of parameters the gradients refer to.
        gradients: Iterable of gradient tensors, one per parameter, or
            ``None``.
        vt: Optional iterable of per-parameter tensors; ``H g`` is divided
            element-wise by them.
            NOTE(review): other helpers in this file divide by
            ``sqrt(vt) + eps``; here ``vt`` is used as-is -- confirm that
            callers pass the final scaling.
        eps: Currently unused; kept for interface compatibility.

    Returns:
        ``None`` if ``gradients`` is ``None``; ``0.0`` if the squared
        gradient norm is below ``1e-12``; otherwise the quotient as a float.
    """
    if gradients is None:
        return None

    params = list(params)
    gradients = list(gradients)

    grad_flat = torch.cat([g.detach().contiguous().view(-1) for g in gradients])

    grad_norm_sq = torch.dot(grad_flat, grad_flat)
    if grad_norm_sq.item() < 1e-12:
        return 0.0

    # A second derivative needs a differentiable first derivative.
    if all(g.requires_grad for g in gradients):
        diff_grads = gradients
    else:
        diff_grads = torch.autograd.grad(loss, params, create_graph=True)

    direction = [g.detach() for g in gradients]
    H_grad = torch.autograd.grad(
        diff_grads, params, grad_outputs=direction,
        retain_graph=True, allow_unused=True)
    # Parameters that do not influence the gradient have a zero Hessian row.
    H_grad = [torch.zeros_like(p) if h is None else h for h, p in zip(H_grad, params)]

    if vt is not None:
        H_grad = [hvi / vt_i for hvi, vt_i in zip(H_grad, vt)]

    H_grad_flat = torch.cat([h.contiguous().view(-1) for h in H_grad])

    hessian_value = torch.dot(grad_flat, H_grad_flat) / grad_norm_sq

    return hessian_value.item()


def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):
    """Return the gradient of ``outputs`` w.r.t. ``inputs`` as one flat tensor.

    Args:
        outputs: Tensor to differentiate (a scalar unless ``grad_outputs``
            is supplied).
        inputs: A single tensor or an iterable of tensors.
        grad_outputs, retain_graph, create_graph: Forwarded to
            ``torch.autograd.grad``.

    Returns:
        1-D tensor concatenating the per-input gradients in input order.
        Inputs that ``outputs`` does not depend on contribute zeros.
    """
    wrt = [inputs] if torch.is_tensor(inputs) else list(inputs)
    raw = torch.autograd.grad(outputs, wrt, grad_outputs,
                              allow_unused=True,
                              retain_graph=retain_graph,
                              create_graph=create_graph)
    flat_parts = []
    for grad_part, reference in zip(raw, wrt):
        if grad_part is None:
            # Unused input: substitute a zero gradient of matching shape.
            grad_part = torch.zeros_like(reference)
        flat_parts.append(grad_part.contiguous().view(-1))
    return torch.cat(flat_parts)
    
    
def hessian(output, inputs, out=None, allow_unused=False, create_graph=False):
    '''Compute the dense Hessian of output with respect to inputs.

    The matrix is built one scalar parameter at a time: for every element of
    every input, the first-order gradient entry is differentiated again via
    ``gradient`` and the result is written into the matching row and mirrored
    into the matching column.

    Args:
        output: Tensor to differentiate twice (a scalar).
        inputs: A single tensor or an iterable of tensors.
        out: Optional ``(n, n)`` tensor, where ``n`` is the total number of
            elements in ``inputs``. Values are accumulated into it in place
            with ``add_``; when ``None`` a zero matrix is created with
            ``output.new_zeros``.
        allow_unused: Forwarded to the first-order ``torch.autograd.grad``
            call; an input that yields no gradient is treated as zeros.
        create_graph: Forwarded to the second-order ``gradient`` call.

    Returns:
        The ``(n, n)`` tensor ``out``.
    '''
    '''output: a scalar'''
    '''inputs: a list of tensors'''
    '''## l = loss(net(X), y)  A=hessian(l, net.parameters())'''
    #     assert output.ndimension() == 0
    if torch.is_tensor(inputs):
        inputs = [inputs]
    else:
        inputs = list(inputs)

    # Total number of scalar parameters == side length of the Hessian.
    n = sum(p.numel() for p in inputs)
    if out is None:
        out = output.new_zeros(n, n)

    # ai is the global row/column index of the scalar currently processed.
    ai = 0
    for i, inp in enumerate(inputs):
        # First-order gradient for this input, kept differentiable.
        [grad] = torch.autograd.grad(
        output, inp, create_graph=True, allow_unused=allow_unused)
        grad = torch.zeros_like(inp) if grad is None else grad
        grad = grad.contiguous().view(-1)
        for j in range(inp.numel()):
            if grad[j].requires_grad:
                # Differentiate only w.r.t. this input and the ones after it,
                # then drop the first j entries: what remains is row ai from
                # the diagonal onward.
                row = gradient(
                grad[j], inputs[i:], retain_graph=True, create_graph=create_graph)[j:]
            else:
                # A non-differentiable entry has an all-zero Hessian row.
                row = grad[j].new_zeros(sum(x.numel() for x in inputs[i:]) - j)

            out[ai, ai:].add_(row.type_as(out))  # ai's row
            if ai + 1 < n:
                # Mirror the strictly-upper part into the column below the diagonal.
                out[ai + 1:, ai].add_(row[1:].type_as(out))  # ai's column
            del row
            ai += 1
        del grad
    return out

# def compute_hessian(model, criterion, X, y):
#     model.zero_grad()
#     loss = criterion(model(X), y)
#     grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    
#     hessian = []
#     for grad in grads:
#         grad = grad.view(-1) 
#         hessian_row = []
#         for g in grad:
#             grad2 = torch.autograd.grad(g, model.parameters(), retain_graph=True)
#             hessian_row.append(torch.cat([g2.view(-1) for g2 in grad2]))
#         hessian.append(torch.stack(hessian_row))
    
#     return torch.cat(hessian, dim=0)



def flatten_and_concat(tensor_list):
    """Flatten every tensor in ``tensor_list`` and join them into one 1-D tensor."""
    flat_parts = []
    for tensor in tensor_list:
        flat_parts.append(tensor.flatten())
    return torch.cat(flat_parts)
def reshape_like(vector, tensor_list):
    """Split a flat ``vector`` into views shaped like the tensors in ``tensor_list``.

    Consecutive slices of ``vector`` are taken in order, one per reference
    tensor, each with as many elements as that tensor. Entries of ``vector``
    beyond the total are ignored.

    Args:
        vector: 1-D tensor holding the concatenated values.
        tensor_list: Iterable of reference tensors (only shapes are used).

    Returns:
        List of views into ``vector``, one per reference tensor.
    """
    pieces = []
    offset = 0
    for reference in tensor_list:
        # numel() is always a Python int, including 1 for 0-dim tensors.
        # np.prod of an empty shape is the float 1.0, which is not a valid
        # slice bound.
        count = reference.numel()
        pieces.append(vector[offset:offset + count].view(reference.shape))
        offset += count
    return pieces


def lanczos(model, criterion, data_in, data_out, n, m, precond=None, specified_layers=None):
    """Run ``m`` Lanczos steps on the loss Hessian and return the tridiagonal coefficients.

    Args:
        model: Module whose loss Hessian is probed.
        criterion: Loss callable forwarded to the Hessian-vector helpers.
        data_in, data_out: Inputs and targets forwarded to the helpers.
        n: Length of the flat work vectors; it has to equal the total number
            of elements of the selected trainable parameters.
        m: Number of Lanczos steps.
        precond: Optional per-parameter scaling; when given,
            ``hessian_vector_product_with_diagonal`` is used.
            NOTE(review): a diagonally scaled Hessian is in general not
            symmetric, which the three-term recurrence assumes -- confirm
            this is intended.
        specified_layers: Optional iterable of name substrings selecting
            parameters.

    Returns:
        ``(alpha, beta)``: diagonal and off-diagonal entries of the Lanczos
        matrix. Both are truncated if the recurrence breaks down early
        (off-diagonal entry below ``1e-8``).
    """
    device = next(model.parameters()).device

    def hvp_func(v):
        # Select parameters exactly as the Hessian-vector helpers do, so the
        # reshaped pieces line up one-to-one with their gradients.
        if specified_layers:
            params = [p for name, p in model.named_parameters()
                      if any(layer in name for layer in specified_layers) and p.requires_grad]
        else:
            params = [p for p in model.parameters() if p.requires_grad]
        v_list = reshape_like(v, params)
        # Compute Hessian-vector product. The helpers expect
        # (model, criterion, X, y, vector, ...); the data is kept on the
        # model's device instead of the helpers' default 'cpu'.
        if precond is None:
            hvp = hessian_vector_product(
                model, criterion, data_in, data_out, v_list,
                device=device, specified_layers=specified_layers)
        else:
            hvp = hessian_vector_product_with_diagonal(
                model, criterion, data_in, data_out, v_list, precond,
                device=device, specified_layers=specified_layers)
        return flatten_and_concat(hvp)

    q = torch.randn(n, device=device)
    q = q / torch.norm(q)

    alpha = torch.zeros(m, device=q.device)
    beta = torch.zeros(m-1, device=q.device)

    q_old = torch.zeros(n, device=q.device)
    for j in range(m):
        model.zero_grad()
        v = hvp_func(q)
        alpha[j] = torch.dot(q, v)

        # Three-term recurrence: remove the components along the current and
        # the previous Lanczos vector.
        v = v - alpha[j]*q - (beta[j-1]*q_old if j > 0 else 0)

        if j < m-1:
            beta[j] = torch.norm(v)
            if beta[j] < 1e-8:
                # Breakdown: the Krylov space is exhausted.
                return alpha[:j+1], beta[:j]
            q_old = q
            q = v / beta[j]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return alpha, beta

def estimate_eigenvalues_lanczos(model, criterion, data_in, data_out, m=30, 
                               precond=None, specified_layers=None, 
                               return_smallest=False, return_eigenvectors=False):
    """
    Estimate extreme Hessian eigenvalues from the Lanczos tridiagonal matrix.

    Runs ``lanczos`` for ``m`` steps, assembles the symmetric tridiagonal
    matrix ``T`` from the returned coefficients and takes its largest (and
    optionally smallest) eigenvalue.

    Returns, depending on the flags:
        - ``largest``
        - ``(largest, smallest)`` with ``return_smallest``
        - ``(largest, largest_vec)`` with ``return_eigenvectors``
        - ``(largest, largest_vec, smallest, smallest_vec)`` with both

    The returned vectors are eigenvectors of ``T`` (one entry per Lanczos
    step), not vectors in parameter space.
    """
    if specified_layers:
        selected = (
            p for name, p in model.named_parameters()
            if any(layer in name for layer in specified_layers)
        )
    else:
        selected = model.parameters()
    n = sum(p.numel() for p in selected)

    alpha, beta = lanczos(model, criterion, data_in, data_out, n, m, precond, specified_layers)

    main_diag = alpha.cpu().numpy()
    off_diag = beta.cpu().numpy()
    T = np.diag(main_diag) + np.diag(off_diag, k=1) + np.diag(off_diag, k=-1)

    if not return_eigenvectors:
        spectrum = np.linalg.eigvalsh(T)
        if return_smallest:
            return spectrum[-1], spectrum[0]
        return spectrum[-1]

    # eigh returns eigenvalues in ascending order with matching columns.
    spectrum, vectors = np.linalg.eigh(T)
    largest_eig, largest_eigvec = spectrum[-1], vectors[:, -1]
    if not return_smallest:
        return largest_eig, largest_eigvec
    smallest_eig, smallest_eigvec = spectrum[0], vectors[:, 0]
    return largest_eig, largest_eigvec, smallest_eig, smallest_eigvec
        
def hessian_vector_product_flat(model, criterion, x, y, vector, vt=None, eps=1e-8):
    """Return the Hessian-vector product of the loss as one flat tensor.

    Args:
        model: Module evaluated on ``x``; its stored gradients are cleared.
        criterion: Loss callable ``criterion(prediction, target)``.
        x, y: Inputs and targets.
        vector: Flat direction with one entry per model parameter.
        vt: Optional iterable of per-parameter tensors; when given the
            product is divided element-wise by ``sqrt(vt) + eps``.
        eps: Added to ``sqrt(vt)`` before dividing.

    Returns:
        1-D tensor with one entry per model parameter.
    """
    def _flatten(tensors):
        return torch.cat([t.contiguous().view(-1) for t in tensors])

    model.zero_grad()
    loss = criterion(model(x), y)

    # Differentiable gradient, projected onto the direction and
    # differentiated once more.
    first_order = torch.autograd.grad(loss, model.parameters(), create_graph=True)
    directional = torch.dot(_flatten(first_order), vector)
    second_order = torch.autograd.grad(directional, model.parameters(), retain_graph=True)
    product = _flatten(second_order)

    if vt is None:
        return product
    return product / (torch.sqrt(_flatten(vt)) + eps)

def lanczos_eigen(model, criterion, x, y, k=30, vt=None, eps=1e-8, dtype=torch.float32, return_smallest=False, return_eigenvectors=False):
    """Estimate extreme eigenvalues of the loss Hessian with fully re-orthogonalised Lanczos.

    Args:
        model, criterion, x, y: Forwarded to ``hessian_vector_product_flat``.
        k: Maximum number of Lanczos steps.
        vt, eps: Optional diagonal scaling forwarded to
            ``hessian_vector_product_flat``.
        dtype: dtype of the random start vector.
        return_smallest: Also return the smallest eigenvalue.
        return_eigenvectors: Also return unit-norm vectors in parameter
            space, rebuilt from the Lanczos basis.

    Returns, depending on the flags:
        - ``largest``
        - ``(largest, smallest)``
        - ``(largest, largest_vec)``
        - ``(largest, largest_vec, smallest, smallest_vec)``
    """
    params = list(model.parameters())
    device = params[0].device
    dim = sum(p.numel() for p in params)
    start = torch.randn(dim, device=device, dtype=dtype)
    start /= torch.norm(start)

    diag_entries = []        # alpha_i
    offdiag_entries = [0.0]  # beta_i, with a leading placeholder
    basis = [start]

    for step in range(k):
        current = basis[-1]
        w = hessian_vector_product_flat(model, criterion, x, y, current, vt=vt, eps=eps)
        if step > 0:
            w = w - offdiag_entries[-1] * basis[-2]

        a = torch.dot(w, current)
        diag_entries.append(a.item())
        w = w - a * current

        # Full re-orthogonalisation against every basis vector so far.
        for b in basis:
            w -= torch.dot(w, b) * b

        residual = torch.norm(w).item()
        offdiag_entries.append(residual)

        if residual < 1e-12:
            # Breakdown: shrink the problem to the steps completed.
            k = len(diag_entries)
            break

        basis.append(w / residual)

    # Assemble the symmetric tridiagonal matrix.
    T = np.zeros((k, k))
    main = np.arange(k)
    T[main, main] = diag_entries
    lower = np.arange(1, k)
    T[lower, lower - 1] = offdiag_entries[1:k]
    T[lower - 1, lower] = offdiag_entries[1:k]

    if not return_eigenvectors:
        eigenvalues = np.linalg.eigvalsh(T)
        if return_smallest:
            return np.max(eigenvalues), np.min(eigenvalues)
        return np.max(eigenvalues)

    eigenvalues, eigenvectors = np.linalg.eigh(T)

    def _to_parameter_space(coefficients):
        # Linear combination of the first k basis vectors, normalised.
        combined = torch.zeros_like(basis[0])
        for weight, b in zip(coefficients, basis):
            combined += weight * b
        return combined / torch.norm(combined)

    top = np.argmax(eigenvalues)
    top_vec = _to_parameter_space(eigenvectors[:, top])
    if not return_smallest:
        return eigenvalues[top], top_vec

    bottom = np.argmin(eigenvalues)
    bottom_vec = _to_parameter_space(eigenvectors[:, bottom])
    return eigenvalues[top], top_vec, eigenvalues[bottom], bottom_vec
        
def arnoldi_eigen(model, criterion, x, y, k=30, vt=None, eps=1e-8, dtype=torch.float32, return_smallest=False, return_eigenvectors=False):
    """Estimate extreme eigenvalues of the (optionally scaled) loss Hessian with Arnoldi iteration.

    Args:
        model, criterion, x, y: Forwarded to ``hessian_vector_product_flat``.
        k: Maximum number of Arnoldi steps (at least 1).
        vt: Optional diagonal scaling forwarded to
            ``hessian_vector_product_flat``.
        eps: Forwarded with ``vt``; also used as the breakdown threshold for
            the sub-diagonal entry.
        dtype: dtype of the random start vector.
        return_smallest: Also return the smallest eigenvalue (by real part).
        return_eigenvectors: Also return Ritz vectors in parameter space.
            Only the real part of the eigenvector coefficients is used.

    Returns, depending on the flags:
        - ``largest``
        - ``(largest, smallest)``
        - ``(largest, largest_vec)``
        - ``(largest, largest_vec, smallest, smallest_vec)``

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    params = list(model.parameters())
    device = params[0].device
    v0 = torch.randn(sum(p.numel() for p in params), device=device, dtype=dtype)
    v0 /= torch.norm(v0)

    H = np.zeros((k+1, k), dtype=np.float32)
    V = [v0]

    m = 0
    for j in range(k):
        w = hessian_vector_product_flat(model, criterion, x, y, V[j], vt=vt, eps=eps)
        # Gram-Schmidt against all previous basis vectors. Torch ops keep the
        # work on the tensors' own device (NumPy only handles CPU tensors).
        for i in range(j+1):
            h_ij = torch.dot(w, V[i]).item()
            H[i, j] = h_ij
            w = w - h_ij * V[i]
        m = j + 1
        h_next = torch.norm(w).item()
        H[j+1, j] = h_next
        if h_next < eps:
            break
        V.append(w / h_next)

    Hm = H[:m, :m]

    # Decompose once; eigenvectors are only requested when needed.
    if return_eigenvectors:
        eigenvalues, eigenvectors = np.linalg.eig(Hm)
    else:
        eigenvalues = np.linalg.eigvals(Hm)
        eigenvectors = None
    sorted_indices = np.argsort(eigenvalues.real)

    largest_eig = eigenvalues[sorted_indices[-1]].real
    smallest_eig = eigenvalues[sorted_indices[0]].real

    if not return_eigenvectors:
        return (largest_eig, smallest_eig) if return_smallest else largest_eig

    def _ritz_vector(column):
        # Hm is not symmetric in general, so coefficients may be complex;
        # the real part is used so the result stays a real tensor.
        coefficients = np.real(eigenvectors[:, column])
        combined = torch.zeros_like(V[0])
        for weight, basis_vec in zip(coefficients, V):
            combined += float(weight) * basis_vec
        return combined

    largest_vec = _ritz_vector(sorted_indices[-1])
    if not return_smallest:
        return largest_eig, largest_vec
    smallest_vec = _ritz_vector(sorted_indices[0])
    return largest_eig, largest_vec, smallest_eig, smallest_vec
        
        
def power_iteration(model, criterion, data_in, data_out, max_iter=500, tolerance=1e-6, precond=None, specified_layers=None, vector=None):
    """Estimate the dominant eigenpair of the (optionally scaled) loss Hessian by power iteration.

    Args:
        model: Module whose loss Hessian is probed.
        criterion: Loss callable forwarded to the Hessian-vector helpers.
        data_in, data_out: Inputs and targets forwarded to the helpers.
        max_iter: Maximum number of iterations.
        tolerance: Relative change of the eigenvalue estimate below which
            the iteration stops.
        precond: Optional per-parameter scaling; when given,
            ``hessian_vector_product_with_diagonal`` is used.
        specified_layers: Optional iterable of name substrings selecting
            parameters.
        vector: Optional list of start tensors, one per selected parameter.
            A random start is drawn when ``None``.

    Returns:
        ``(eigenvalue, vector)``: the last Rayleigh-quotient estimate (a
        0-dim tensor, or ``None`` if no iteration ran) and the normalised
        direction as a list of tensors.

    Raises:
        ValueError: If no trainable parameter is selected.
    """
    # Select parameters exactly as the Hessian-vector helpers do (trainable
    # ones only), so the direction lines up one-to-one with the gradients.
    if specified_layers:
        params = [p for name, p in model.named_parameters()
                  if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")

    # Keep everything on the model's device rather than the helpers'
    # default of 'cpu'.
    device = params[0].device

    # Initialize a random vector
    if vector is None:
        vector = [torch.randn_like(p) for p in params]
    else:
        vector = [v.clone() for v in vector]

    # Normalize the initial vector
    with torch.no_grad():
        vector_norm = torch.sqrt(sum(torch.sum(v**2) for v in vector))
        vector = [v / vector_norm for v in vector]

    prev_eigenvalue = None
    new_eigenvalue = None
    converged = False
    for iteration in range(max_iter):
        # Clear gradients
        model.zero_grad()

        # Compute Hessian-vector product
        if precond is None:
            hvp = hessian_vector_product(
                model, criterion, data_in, data_out, vector,
                device=device, specified_layers=specified_layers)
        else:
            hvp = hessian_vector_product_with_diagonal(
                model, criterion, data_in, data_out, vector, precond,
                device=device, specified_layers=specified_layers)

        # Compute the new eigenvalue estimate and normalize the resulting vector
        with torch.no_grad():
            new_eigenvalue = sum(torch.sum(h * v) for h, v in zip(hvp, vector))
            vector_norm = torch.sqrt(sum(torch.sum(h**2) for h in hvp))
            if vector_norm == 0:
                # The product vanished, so the current direction lies in the
                # null space; keep it instead of dividing by zero.
                converged = True
                break
            vector = [h / vector_norm for h in hvp]

        # Check for convergence using relative error
        if prev_eigenvalue is not None:
            change = abs(new_eigenvalue - prev_eigenvalue)
            scale = abs(prev_eigenvalue)
            # Fall back to the absolute change when the previous estimate is
            # exactly zero.
            relative_error = change / scale if scale > 0 else change
            if relative_error < tolerance:
                print(f"Converged after {iteration + 1} iterations.")
                converged = True
                break

        prev_eigenvalue = new_eigenvalue

        # Clear cache if using GPU
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Report only when the loop really ran out of iterations (convergence on
    # the final iteration used to print both messages).
    if not converged:
        print(f"Maximum iterations ({max_iter}) reached without convergence.")

    return new_eigenvalue, vector


def cosine_similarity(grads1, grads2):
    """Return the cosine similarity between two lists of tensors as a float.

    Each list is flattened and concatenated into a single vector first.
    Returns ``0.0`` when either argument is ``None`` or either vector has
    zero norm.
    """
    if grads1 is None or grads2 is None:
        return 0.0

    flat_a = torch.cat([g.ravel() for g in grads1])
    flat_b = torch.cat([g.ravel() for g in grads2])

    inner = torch.dot(flat_a, flat_b)
    lengths = (torch.norm(flat_a), torch.norm(flat_b))

    # A zero-length vector has no direction to compare.
    if any(length == 0 for length in lengths):
        return 0.0

    return (inner / (lengths[0] * lengths[1])).item()
