import torch
import numpy as np
import time
import torch.nn.functional as F
import torch.nn as nn

def hessian_vector_product(args, model, criterion, vector, data_in, data_out, create_graph=True, specified_layers=None):
    """Compute the Hessian-vector product H @ vector of the last-position loss.

    The loss is ``criterion`` applied to the logits at the last sequence position
    (``outputs`` reshaped to ``(args.batch_size, args.seq_len, args.vocab_size)``)
    against ``data_out[:, -1]``. H @ vector is obtained by double backpropagation:
    the gradient of the scalar ``grad(loss) . vector`` w.r.t. the parameters.

    Args:
        args: object exposing ``batch_size``, ``seq_len`` and ``vocab_size``.
        model: module whose forward returns an ``(outputs, _)`` pair.
        criterion: callable ``criterion(logits, targets)`` returning a scalar loss.
        vector: sequence of tensors, one per selected parameter, in selection order
            and with the same shapes as those parameters.
        data_in: model input.
        data_out: targets; only the last column ``data_out[:, -1]`` is used.
        create_graph: if True (default) the autograd graph is retained after the
            second backward pass. The first backward pass always builds the graph,
            because the second derivative is taken through it.
        specified_layers: optional list of substrings; only trainable parameters whose
            name contains one of them are used. ``None``/empty selects all trainable
            parameters.

    Returns:
        Tuple of tensors, one per selected parameter, holding H @ vector.

    Raises:
        ValueError: if no trainable parameter is selected, or ``vector`` does not
            provide exactly one tensor of matching shape per selected parameter.
    """
    # Select the parameters to differentiate: trainable ones, optionally restricted
    # to those whose name contains one of the ``specified_layers`` substrings.
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]

    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")

    # zip() would silently truncate a mismatched vector and g * v could broadcast a
    # wrongly shaped entry, so validate the pairing up front.
    vector = list(vector)
    if len(vector) != len(params):
        raise ValueError(f"vector has {len(vector)} entries but {len(params)} parameters were selected.")
    for i, (p, v) in enumerate(zip(params, vector)):
        if v.shape != p.shape:
            raise ValueError(f"vector[{i}] has shape {tuple(v.shape)}, expected {tuple(p.shape)}.")

    # Forward pass; the loss only uses the logits at the last sequence position.
    outputs, _ = model(data_in)
    logits = outputs.view(args.batch_size, args.seq_len, args.vocab_size)[:, -1, :]
    loss = criterion(logits, data_out[:, -1].view(-1))

    # First backward pass. The graph is always built: the second pass differentiates
    # through these gradients, and without it autograd.grad below would fail.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Scalar g . v; its gradient w.r.t. the parameters is H v.
    grad_vector_product = sum(torch.sum(g * v) for g, v in zip(grads, vector))

    # Second backward pass.
    hvp = torch.autograd.grad(grad_vector_product, params, retain_graph=create_graph)

    return hvp

def hessian_vector_product_with_diagonal(args, model, criterion, vector, data_in, data_out, diagonal_matrix, create_graph=True, specified_layers=None):
    """Compute H @ (D @ vector), where D is a diagonal matrix given per parameter.

    ``diagonal_matrix`` holds D's diagonal as one entry per selected parameter. Each
    entry multiplies the matching ``vector`` entry elementwise before the Hessian is
    applied. The loss is ``criterion`` applied to the logits at the last sequence
    position (``outputs`` reshaped to ``(args.batch_size, args.seq_len,
    args.vocab_size)``) against ``data_out[:, -1]``.

    Args:
        args: object exposing ``batch_size``, ``seq_len`` and ``vocab_size``.
        model: module whose forward returns an ``(outputs, _)`` pair.
        criterion: callable ``criterion(logits, targets)`` returning a scalar loss.
        vector: sequence of tensors, one per selected parameter, shaped like them.
        data_in: model input.
        data_out: targets; only the last column ``data_out[:, -1]`` is used.
        diagonal_matrix: sequence with one diagonal entry (tensor or scalar) per
            selected parameter.
        create_graph: if True (default) the autograd graph is retained after the
            second backward pass. The first backward pass always builds the graph,
            because the second derivative is taken through it.
        specified_layers: optional list of name substrings restricting the trainable
            parameters. ``None``/empty selects all trainable parameters.

    Returns:
        Tuple of tensors, one per selected parameter, holding H @ (D @ vector).

    Raises:
        ValueError: if no trainable parameter is selected, or ``vector`` /
            ``diagonal_matrix`` do not have one entry per selected parameter.
    """
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]

    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")

    # zip() would silently drop trailing entries, pairing the wrong tensors.
    vector = list(vector)
    diagonal_matrix = list(diagonal_matrix)
    if len(vector) != len(params):
        raise ValueError(f"vector has {len(vector)} entries but {len(params)} parameters were selected.")
    if len(diagonal_matrix) != len(params):
        raise ValueError(f"diagonal_matrix has {len(diagonal_matrix)} entries but {len(params)} parameters were selected.")
    for i, (p, v) in enumerate(zip(params, vector)):
        if v.shape != p.shape:
            raise ValueError(f"vector[{i}] has shape {tuple(v.shape)}, expected {tuple(p.shape)}.")

    # Forward pass; the loss only uses the logits at the last sequence position.
    outputs, _ = model(data_in)
    logits = outputs.view(args.batch_size, args.seq_len, args.vocab_size)[:, -1, :]
    loss = criterion(logits, data_out[:, -1].view(-1))

    # First backward pass; the graph is always needed for the second derivative.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Apply the diagonal matrix to the vector.
    scaled_vector = [v * d for v, d in zip(vector, diagonal_matrix)]

    # Scalar g . (D v); its gradient w.r.t. the parameters is H (D v).
    grad_vector_product = sum(torch.sum(g * v) for g, v in zip(grads, scaled_vector))

    # Second backward pass.
    hvp = torch.autograd.grad(grad_vector_product, params, retain_graph=create_graph)

    return hvp


def power_iteration(args, model, criterion, data_in, data_out, max_iter=500, tolerance=1e-6, precond=None, specified_layers=None):
    """Estimate the largest-magnitude eigenvalue of the Hessian by power iteration.

    Starting from a random unit vector v, the function repeatedly applies the operator
    A and renormalizes. A is the Hessian, or v -> H(precond * v) when ``precond`` is
    given. The eigenvalue estimate at each step is vᵀ(Av) for the current unit vector
    v. Iteration stops when the relative change of that estimate drops below
    ``tolerance``, or after ``max_iter`` steps.

    Args:
        args, model, criterion, data_in, data_out: forwarded unchanged to the
            Hessian-vector-product helpers.
        max_iter: maximum number of iterations (must be >= 1).
        tolerance: relative-change threshold for convergence.
        precond: optional per-parameter diagonal preconditioner.
        specified_layers: optional list of name substrings restricting the parameters.

    Returns:
        (eigenvalue, vector): the last eigenvalue estimate (0-d tensor) and the last
        normalized iterate (list of tensors, one per selected parameter).

    Raises:
        ValueError: if ``max_iter < 1`` or no trainable parameter is selected.
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be >= 1, got {max_iter}")

    # Select exactly the parameters the HVP helpers differentiate (trainable only),
    # so the vector entries line up with their gradients.
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]

    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")

    # Random starting vector, normalized to unit length.
    vector = [torch.randn_like(p) for p in params]
    with torch.no_grad():
        vector_norm = torch.sqrt(sum(torch.sum(v ** 2) for v in vector))
        vector = [v / vector_norm for v in vector]

    prev_eigenvalue = None
    new_eigenvalue = None
    for iteration in range(max_iter):
        model.zero_grad()

        if precond is None:
            hvp = hessian_vector_product(args, model, criterion, vector, data_in, data_out, specified_layers=specified_layers)
        else:
            hvp = hessian_vector_product_with_diagonal(args, model, criterion, vector, data_in, data_out, precond, specified_layers=specified_layers)

        # Rayleigh-quotient-style estimate with the current unit vector, then renormalize.
        with torch.no_grad():
            new_eigenvalue = sum(torch.sum(h * v) for h, v in zip(hvp, vector))
            vector_norm = torch.sqrt(sum(torch.sum(h ** 2) for h in hvp))
            vector = [h / vector_norm for h in hvp]

        # Convergence test on the relative change of the estimate.
        if prev_eigenvalue is not None:
            relative_error = abs((new_eigenvalue - prev_eigenvalue) / prev_eigenvalue)
            if relative_error < tolerance:
                print(f"Converged after {iteration + 1} iterations.")
                break

        prev_eigenvalue = new_eigenvalue

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    else:
        # Only reached when the loop ran to completion without a convergence break.
        print(f"Maximum iterations ({max_iter}) reached without convergence.")

    return new_eigenvalue, vector

def flatten_and_concat(tensor_list):
    """Flatten each tensor in ``tensor_list`` to 1-D and join them end to end."""
    flat_pieces = list(map(torch.flatten, tensor_list))
    return torch.cat(flat_pieces)

def reshape_like(vector, tensor_list):
    """Split a flat 1-D ``vector`` into consecutive views shaped like ``tensor_list``.

    Args:
        vector: 1-D tensor whose leading elements are consumed in order.
        tensor_list: iterable of objects with a ``.shape`` (e.g. parameters); it is
            iterated once, so a generator is accepted.

    Returns:
        List of views into ``vector``, one per entry of ``tensor_list``, with the
        matching shapes. Elements of ``vector`` beyond the total size are ignored.
    """
    pieces = []
    offset = 0
    for reference in tensor_list:
        shape = reference.shape
        # int(...) matters for 0-d shapes: np.prod(()) is the float 1.0, which is not
        # a valid slice index.
        count = int(np.prod(shape))
        pieces.append(vector[offset:offset + count].view(shape))
        offset += count
    return pieces

def lanczos(args, model, criterion, data_in, data_out, n, m, precond=None, specified_layers=None):
    """Run up to ``m`` steps of the Lanczos iteration on the Hessian operator.

    The operator is the Hessian of the loss used by the HVP helpers, or
    v -> H(precond * v) when ``precond`` is given. It acts on a flat vector of
    length ``n``, which is split into per-parameter tensors for each product.

    Args:
        args, model, criterion, data_in, data_out: forwarded to the HVP helpers.
        n: total number of elements of the selected trainable parameters.
        m: maximum number of Lanczos steps (must be >= 1).
        precond: optional per-parameter diagonal preconditioner.
        specified_layers: optional list of name substrings restricting the parameters.

    Returns:
        (alpha, beta): 1-D tensors with the diagonal (length k) and off-diagonal
        (length k - 1) of the tridiagonal matrix T. k == m unless a residual norm
        fell below 1e-8, in which case the iteration stops early.

    Raises:
        ValueError: if ``m < 1``, no trainable parameter is selected, or ``n`` does
            not equal the number of selected parameter elements.

    NOTE(review): this is the plain three-term recurrence without
    reorthogonalization. In floating point the basis loses orthogonality, and T may
    contain spurious duplicates of converged eigenvalues.
    NOTE(review): with ``precond`` the operator is H·diag(precond), which is not
    symmetric in general. Lanczos assumes a symmetric operator.
    """
    if m < 1:
        raise ValueError(f"m must be >= 1, got {m}")

    # Trainable parameters selected exactly as the HVP helpers select them, so the
    # flat vector is split onto the same tensors that are differentiated.
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]

    if not params:
        raise ValueError("No parameters found for gradient computation. Check your specified_layers or model parameters.")

    total = sum(p.numel() for p in params)
    if n != total:
        raise ValueError(f"n={n} does not match the {total} trainable parameter elements selected for the Hessian.")

    def hvp_func(v):
        # Flat vector -> per-parameter tensors -> Hessian-vector product -> flat vector.
        v_list = reshape_like(v, params)
        if precond is None:
            hvp = hessian_vector_product(args, model, criterion, v_list, data_in, data_out, specified_layers=specified_layers)
        else:
            hvp = hessian_vector_product_with_diagonal(args, model, criterion, v_list, data_in, data_out, precond, specified_layers=specified_layers)
        return flatten_and_concat(hvp)

    device = params[0].device
    # At least the default float dtype, and float64 for double-precision models, so that
    # torch.dot(q, v) does not mix dtypes.
    dtype = torch.promote_types(params[0].dtype, torch.get_default_dtype())

    q = torch.randn(n, device=device, dtype=dtype)
    q = q / torch.norm(q)

    alpha = torch.zeros(m, device=device, dtype=dtype)
    beta = torch.zeros(m - 1, device=device, dtype=dtype)

    q_old = torch.zeros(n, device=device, dtype=dtype)
    for j in range(m):
        model.zero_grad()
        v = hvp_func(q)
        alpha[j] = torch.dot(q, v)

        # Three-term recurrence: remove the components along q_j and q_{j-1}.
        v = v - alpha[j] * q
        if j > 0:
            v = v - beta[j - 1] * q_old

        if j < m - 1:
            beta[j] = torch.norm(v)
            # A (near-)zero residual means an invariant subspace was found; T is complete.
            if beta[j] < 1e-8:
                return alpha[:j + 1], beta[:j]
            q_old = q
            q = v / beta[j]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return alpha, beta

def estimate_largest_eigenvalue_lanczos(args, model, criterion, data_in, data_out, m=30, precond=None, specified_layers=None):
    """Estimate the largest eigenvalue of the (optionally preconditioned) Hessian via Lanczos.

    Runs ``lanczos`` for up to ``m`` steps and assembles the symmetric tridiagonal
    matrix T. T's diagonal is ``alpha``, and ``beta`` forms both its off-diagonals.
    The function returns T's largest eigenvalue.

    Args:
        args, model, criterion, data_in, data_out: forwarded to ``lanczos``.
        m: maximum number of Lanczos steps.
        precond: optional per-parameter diagonal preconditioner.
        specified_layers: optional list of name substrings restricting the parameters.

    Returns:
        The largest eigenvalue of T (NumPy floating-point scalar).
    """
    # Count only trainable parameters, matching the selection the HVP helpers differentiate.
    if specified_layers:
        n = sum(p.numel() for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad)
    else:
        n = sum(p.numel() for p in model.parameters() if p.requires_grad)

    alpha, beta = lanczos(args, model, criterion, data_in, data_out, n, m, precond, specified_layers)

    alpha_np = alpha.cpu().numpy()
    beta_np = beta.cpu().numpy()
    T = np.diag(alpha_np) + np.diag(beta_np, k=1) + np.diag(beta_np, k=-1)

    # eigvalsh returns eigenvalues in ascending order, so the last one is the largest.
    eigenvalues = np.linalg.eigvalsh(T)

    return eigenvalues[-1]


def hutchinson_hessian_diagonal(args, model, criterion, data_in, data_out, num_samples=100, specified_layers=None):
    """Estimate the diagonal of the Hessian with Hutchinson's method.

    For Rademacher probe vectors r (entries ±1), E[r ⊙ (H r)] = diag(H). The
    estimate is the average of r ⊙ (H r) over ``num_samples`` probes.

    Args:
        args, model, criterion, data_in, data_out: forwarded to ``hessian_vector_product``.
        num_samples: number of probe vectors (must be >= 1).
        specified_layers: optional list of name substrings restricting the parameters.

    Returns:
        List of tensors, one per selected trainable parameter and shaped like it,
        holding the estimated Hessian diagonal.

    Raises:
        ValueError: if ``num_samples < 1``. The HVP helper raises ValueError if no
            trainable parameter is selected.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # A list (not a generator) is required: the parameters are iterated once to
    # allocate the accumulators and again for every probe vector. Only trainable
    # parameters are kept, matching hessian_vector_product's selection.
    if specified_layers:
        params = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]
    else:
        params = [p for p in model.parameters() if p.requires_grad]

    diag_estimate = [torch.zeros_like(p) for p in params]

    for _ in range(num_samples):
        # Rademacher probe with each parameter's own dtype and device.
        rand_vec = [torch.randint_like(p, 0, 2) * 2 - 1 for p in params]

        hvp = hessian_vector_product(args, model, criterion, rand_vec, data_in, data_out, specified_layers=specified_layers)

        for d, h, r in zip(diag_estimate, hvp, rand_vec):
            d.add_(h * r)

    diag_estimate = [d / num_samples for d in diag_estimate]

    return diag_estimate

def hutchinson_hessian_estimate_for_layers(args, model, criterion, data_in, data_out, specified_layers=None, num_samples=100):
    """
    Estimate the trace of the Hessian for specified layers using Hutchinson's estimator.

    Uses standard-normal probe vectors v, for which E[vᵀ H v] = tr(H), and averages
    vᵀ H v over ``num_samples`` probes.

    Args:
    - args: object exposing batch_size, seq_len and vocab_size (forwarded to the HVP)
    - model: the neural network model
    - criterion: the loss function
    - data_in: input data
    - data_out: target data
    - specified_layers: list of parameter-name substrings to restrict the estimate to
      (if None, use all trainable parameters)
    - num_samples: number of random probe vectors (must be >= 1)

    Returns:
    - estimated trace of the Hessian for the selected parameters, as a Python float

    Raises:
    - ValueError: if num_samples < 1 or no trainable parameter matches
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Only trainable parameters, matching the selection hessian_vector_product uses,
    # so each probe entry lines up with its gradient.
    if specified_layers is None:
        parameters = [p for p in model.parameters() if p.requires_grad]
    else:
        parameters = [p for name, p in model.named_parameters() if any(layer in name for layer in specified_layers) and p.requires_grad]

    if not parameters:
        raise ValueError("No parameters found for the specified layers.")

    trace_estimate = 0
    for _ in range(num_samples):
        v = [torch.randn_like(p) for p in parameters]
        hvp = hessian_vector_product(args, model, criterion, v, data_in, data_out, specified_layers=specified_layers)
        # vᵀ H v summed across all selected parameter tensors.
        trace_estimate += sum((h * v_i).sum().item() for h, v_i in zip(hvp, v))

    trace_estimate /= num_samples

    return trace_estimate




def gauss_newton_bartlett(args, model, data_in):
    """Compute the Gauss-Newton-Bartlett (GNB) curvature statistic on one mini-batch.

    The function samples labels ŷ from the model's own softmax distribution at the
    last sequence position. It then takes the gradient ĝ of the mean cross-entropy
    of those logits against ŷ and returns ``args.batch_size * ||ĝ||²``, summed over
    all trainable parameters.

    Args:
        args: object exposing ``batch_size`` (used as the scaling factor B).
        model: module whose forward returns an ``(outputs, _)`` pair. ``outputs`` is
            indexed as ``(batch, seq, vocab)`` logits.
        data_in: one mini-batch of inputs; moved to the model's device.

    Returns:
        The GNB estimate as a Python float.

    Raises:
        ValueError: if the model has no trainable parameters.
    """
    device = next(model.parameters()).device

    data_in = data_in.to(device)

    # Logits at the last sequence position.
    outputs, _ = model(data_in)
    last_logits = outputs[:, -1, :]

    # Sample y_hat from the model's predictive distribution. Sampling is not
    # differentiable, so it is done outside the autograd graph.
    with torch.no_grad():
        probs = F.softmax(last_logits, dim=-1)
        # squeeze(-1) keeps a (batch,) target even when the batch has one row; a bare
        # squeeze() would produce a 0-d tensor that cross_entropy rejects.
        y_hat = torch.multinomial(probs, num_samples=1).squeeze(-1)

    loss = F.cross_entropy(last_logits, y_hat, reduction='mean')

    # Frozen parameters cannot be differentiated; only the gradient values are needed,
    # so no higher-order graph is built.
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("No trainable parameters found for the GNB estimate.")
    g_hat = torch.autograd.grad(loss, params)

    # Sum on-device and convert once, instead of synchronizing per parameter.
    squared_norm = sum(torch.sum(g * g) for g in g_hat)
    gnb_estimate = (args.batch_size * squared_norm).item()

    return gnb_estimate

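# Usage example:
# model = YourModel()  # forward must return (logits, _) with logits indexed as (batch, seq, vocab)
# data_in = ...        # one mini-batch of inputs; targets are not needed
# gnb_estimate = gauss_newton_bartlett(args, model, data_in)  # args.batch_size is the scaling factor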