import torch
import torch.nn as nn

# 1. Gauss–Newton term: H_GN = (1/n) * sum_i [grad f(x_i)] [grad f(x_i)]^T
def hv_gn(model, X, y, v, device):
    """Apply the Gauss–Newton part of the MSE Hessian to the vector ``v``.

    Computes ``(2 / n) * sum_i g_i * (g_i . v)`` where ``g_i`` is the flattened
    gradient of the model output for sample ``i`` with respect to all model
    parameters. The factor 2 comes from differentiating the squared error.

    Args:
        model: Module whose output for a single sample has exactly one element
            (``torch.autograd.grad`` is called on it without ``grad_outputs``).
        X: Batch of inputs; samples are taken along dimension 0.
        y: Unused; kept so the signature matches ``hv_r``.
        v: Flat vector with one entry per model parameter element.
        device: Device on which the accumulator is allocated.

    Returns:
        Flat tensor shaped like ``v``. All zeros for an empty batch.
    """
    params = list(model.parameters())
    hv = torch.zeros_like(v, device=device)
    n = X.shape[0]
    if n == 0:
        # Avoid the 0 / 0 = NaN that the averaging below would produce.
        return hv
    for i in range(n):
        f_i = model(X[i:i+1])
        # Only first-order gradients are needed, so no double-backward graph is
        # built; detaching keeps the accumulator free of per-sample graphs.
        grad_i = torch.autograd.grad(f_i, params)
        grad_i = torch.cat([g.contiguous().view(-1) for g in grad_i]).detach()
        hv += grad_i * torch.dot(grad_i, v)
    hv /= n
    return 2 * hv  # factor 2 from the derivative of the squared error (MSELoss)

# 2. Residual term: H_R = (1/n) * sum_i (f(x_i)-y_i) * Hessian(f(x_i))
def hv_r(model, X, y, v, device):
    """Apply the residual part of the MSE Hessian to the vector ``v``.

    Computes ``(2 / n) * sum_i r_i * (Hessian(f(x_i)) @ v)`` with
    ``r_i = f(x_i) - y_i``, using one double-backward pass per sample.

    Args:
        model: Module whose output for a single sample has exactly one element.
        X: Batch of inputs; samples are taken along dimension 0.
        y: Targets, sliced the same way as ``X``.
        v: Flat vector with one entry per model parameter element.
        device: Device on which the accumulator is allocated.

    Returns:
        Flat tensor shaped like ``v``. All zeros for an empty batch.
    """
    params = list(model.parameters())

    def _flatten(pieces):
        # Parameters that autograd reports as unused contribute zeros.
        return torch.cat([
            (torch.zeros_like(p) if g is None else g).contiguous().view(-1)
            for g, p in zip(pieces, params)
        ])

    hv = torch.zeros_like(v, device=device)
    n = X.shape[0]
    if n == 0:
        # Avoid the 0 / 0 = NaN that the averaging below would produce.
        return hv
    for i in range(n):
        f_i = model(X[i:i+1])
        # The residual is only a weight here; detaching it keeps the forward
        # graph of every sample from being attached to the accumulator.
        r_i = (f_i - y[i:i+1]).squeeze().detach()
        grad_i = _flatten(torch.autograd.grad(
            f_i, params, create_graph=True, allow_unused=True))
        if not grad_i.requires_grad:
            # The output is linear in every parameter: its Hessian is zero.
            continue
        # Parameters entering the output only additively (e.g. a final bias)
        # do not appear in the gradient graph, hence allow_unused=True.
        hv_i = _flatten(torch.autograd.grad(
            grad_i, params, grad_outputs=v, allow_unused=True))
        hv += r_i * hv_i
    hv /= n
    return 2 * hv  # factor 2 from the derivative of the squared error

#############################################
# ： Hessian– hv_fn
#############################################

def power_iteration_total(model, X, y, vt, num_iters=50, eps=1e-8, tol=1e-6, device='cpu', v_initial=None, diag_initial=None, compute_grad_hessian_grad=False):
    """Estimate dominant eigenpairs of the MSE-loss Hessian by power iteration.

    Two iterations run side by side on the Hessian ``H`` of
    ``nn.MSELoss()(model(X), y)`` with respect to the model parameters:

    * plain ``H``;
    * ``diag(1 / (sqrt(vt) + eps)) @ H`` (only when ``vt`` is given).

    Each eigenvalue estimate is the norm of the (scaled) Hessian-vector product
    of the current unit vector, so it is a magnitude, not a signed value.

    Note: the model is switched to eval mode and left in that mode.

    Args:
        model: Module being analysed.
        X, y: Inputs and targets; moved to ``device``.
        vt: Flat tensor with one entry per parameter element, or ``None`` to
            skip the preconditioned iteration.
        num_iters: Maximum number of iterations.
        eps: Constant added to ``sqrt(vt)`` before dividing.
        tol: Stop once every tracked estimate changes by less than this
            between two consecutive iterations.
        device: Device on which the computation runs.
        v_initial: Optional flat start vector for the ``H`` iteration.
        diag_initial: Optional flat start vector for the preconditioned
            iteration.
        compute_grad_hessian_grad: If true, also return ``g^T H g`` for the
            unit-normalised loss gradient ``g``.

    Returns:
        ``(result_H, result_diag)`` or, when ``compute_grad_hessian_grad`` is
        true, ``(result_H, result_diag, grad_hessian_grad)``. ``result_H`` is
        ``(eigenvalue, vector)``; ``result_diag`` is the same for the
        preconditioned iteration or ``(None, None)`` when ``vt`` is ``None``.
    """
    model.eval()
    criterion = nn.MSELoss()

    X, y = X.to(device), y.to(device)
    vt = vt.to(device) if vt is not None else None

    params = list(model.parameters())
    loss = criterion(model(X), y)

    # Flattened first-order gradient with its graph kept, so it can be
    # differentiated again for every Hessian-vector product.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    grads = torch.cat([g.contiguous().view(-1) for g in grads]).to(device)

    def _hvp(vec):
        """Return the flattened Hessian-vector product ``H @ vec``."""
        pieces = torch.autograd.grad(grads, params, grad_outputs=vec, retain_graph=True)
        return torch.cat([p.contiguous().view(-1) for p in pieces]).to(device)

    def _init_vector(initial, default):
        """Return a unit-norm start vector, preferring the caller's one."""
        start = default if initial is None else initial.detach().clone().to(device)
        norm = torch.norm(start)
        if not torch.isfinite(norm) or norm == 0:
            # A degenerate start vector cannot be normalised; use the random one.
            start = default
            norm = torch.norm(start)
        # The eigenvalue estimate is ||H v||, which is only meaningful for a
        # unit-norm v, so caller-supplied vectors are normalised as well.
        return start / norm

    v_H = _init_vector(v_initial, torch.randn_like(grads, device=device))

    compute_diag = vt is not None
    v_diagH = None
    if compute_diag:
        v_diagH = _init_vector(diag_initial, torch.randn_like(grads, device=device))

    # Computed before the loop so it is defined even when num_iters == 0.
    grad_hessian_grad = None
    if compute_grad_hessian_grad:
        grad_norm = torch.norm(grads)
        if grad_norm > 0:
            unit_grad = (grads / grad_norm).detach()
            grad_hessian_grad = torch.dot(_hvp(unit_grad), unit_grad).item()
        else:
            # No direction to probe at a stationary point.
            grad_hessian_grad = 0.0

    max_eigen_H = 0.0
    max_eigen_diagvH = 0.0 if compute_diag else None
    prev_values = []

    for _ in range(num_iters):
        Hv_H = _hvp(v_H)
        hv_h_norm = torch.norm(Hv_H)
        if hv_h_norm > 0:
            v_H = Hv_H / hv_h_norm
            max_eigen_H = hv_h_norm.item()

        if compute_diag:
            Hv_diagH_scaled = _hvp(v_diagH) / (torch.sqrt(vt) + eps)
            hv_diagh_norm = torch.norm(Hv_diagH_scaled)
            if hv_diagh_norm > 0:
                v_diagH = Hv_diagH_scaled / hv_diagh_norm
                max_eigen_diagvH = hv_diagh_norm.item()

        # Converged once every tracked estimate has stopped moving.
        current_values = [max_eigen_H]
        if compute_diag:
            current_values.append(max_eigen_diagvH)

        if prev_values and all(abs(c - p) < tol for c, p in zip(current_values, prev_values)):
            break
        prev_values = current_values

    result_H = (max_eigen_H, v_H.detach())
    result_diag = (max_eigen_diagvH, v_diagH.detach()) if compute_diag else (None, None)

    if compute_grad_hessian_grad:
        return result_H, result_diag, grad_hessian_grad
    return result_H, result_diag



# from torch.autograd.functional import vhp
# def grad_hessian_grad(model, X, y):
#     # 
#     # 
#     def loss_fn(*params):
#         # 
#         with torch.no_grad():
#             for p, new_p in zip(model.parameters(), params):
#                 p.copy_(new_p)
#         # 
#         y_pred = model(X)
#         criterion = nn.MSELoss()
#         loss = criterion(y_pred, y)
#         return loss

#     # 
#     params = tuple(model.parameters())

#     # 
#     with torch.enable_grad():
#         loss = loss_fn(*params)
#         grad = torch.autograd.grad(loss, params, create_graph=True)
#         grad_vec = torch.cat([g.view(-1) for g in grad])

#     # 
#     unit_grad = grad_vec / torch.norm(grad_vec)
#     # 
#     def split_grad(grad_vec, params):
#         idx = 0
#         split_grads = []
#         for p in params:
#             numel = p.numel()
#             split_grads.append(grad_vec[idx:idx + numel].view_as(p))
#             idx += numel
#         return tuple(split_grads)

#     unit_grad_split = split_grad(unit_grad, params)

#     #  Hessian 
#     _, hvp_vec = vhp(loss_fn, params, unit_grad_split)

#     #  Hessian 
#     hvp_vec_cat = torch.cat([h.view(-1) for h in hvp_vec])

#     #  Hessian 
#     hessian_eigenvalue = torch.dot(hvp_vec_cat, unit_grad)
#     print("Hessian :", hessian_eigenvalue.item())
#     return hessian_eigenvalue.item()

# ##  hessian vector product 
# from torch.autograd.functional import vhp

# def power_iteration_total(model, X, y, vt=None, num_iters=50, eps=1e-8, tol=1e-6, device='cpu', v_initial=None, diag_initial=None):
#     model.eval()
#     criterion = nn.MSELoss()
#     X, y = X.to(device), y.to(device)
#     vt = vt.to(device) if vt is not None else None

#     # 
#     def get_flat_params():
#         return torch.cat([p.data.view(-1) for p in model.parameters()]).to(device)

#     def loss_func(flat_params):
#         idx = 0
#         for p in model.parameters():
#             p_size = p.numel()
#             p.data = flat_params[idx:idx+p_size].view_as(p).data
#             idx += p_size
#         y_pred = model(X)
#         return criterion(y_pred, y)

#     # 
#     flat_params = get_flat_params().detach().requires_grad_(True)
#     loss = loss_func(flat_params)

#     # diag
#     compute_diag = vt is not None

#     # 
#     if v_initial is not None:
#         v_H = v_initial.clone().detach().to(device)
#     else:
#         v_H = torch.randn_like(flat_params, device=device)
#     v_H = v_H / torch.norm(v_H)

#     v_diagH = None
#     if compute_diag:
#         if diag_initial is not None:
#             v_diagH = diag_initial.clone().detach().to(device)
#         else:
#             v_diagH = torch.randn_like(flat_params, device=device)
#         v_diagH = v_diagH / torch.norm(v_diagH)

#     max_eigen_H = 0.0
#     max_eigen_diagvH = 0.0 if compute_diag else None
#     prev_values = None

#     for _ in range(num_iters):
#         # Hv_H
#         _, Hv_H = vhp(loss_func, flat_params, v_H)
#         Hv_H_norm = torch.norm(Hv_H)
#         if Hv_H_norm > 0:
#             v_H = Hv_H / Hv_H_norm
#             max_eigen_H = Hv_H_norm.item()

#         # Hv_diagH（）
#         if compute_diag:
#             _, Hv_diagH = vhp(loss_func, flat_params, v_diagH)
#             Hv_diagH_scaled = Hv_diagH / (torch.sqrt(vt) + eps)
#             scaled_norm = torch.norm(Hv_diagH_scaled)
#             if scaled_norm > 0:
#                 v_diagH = Hv_diagH_scaled / scaled_norm
#                 max_eigen_diagvH = scaled_norm.item()

#         # 
#         current_values = [max_eigen_H]
#         if compute_diag:
#             current_values.append(max_eigen_diagvH)

#         if prev_values is not None:
#             diffs = [abs(current - prev) for current, prev in zip(current_values, prev_values)]
#             if all(d < tol for d in diffs):
#                 break
#         prev_values = current_values

#     # 
#     result_H = (max_eigen_H, v_H)
#     result_diag = (max_eigen_diagvH, v_diagH) if compute_diag else (None, None)
    
#     return result_H, result_diag

# Helpers for building the full (dense) Hessian
def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False):
    """Return the gradient of ``outputs`` w.r.t. ``inputs`` as one flat tensor.

    Args:
        outputs: Tensor to differentiate (a scalar unless ``grad_outputs`` is
            supplied).
        inputs: A single tensor or an iterable of tensors.
        grad_outputs: Forwarded to ``torch.autograd.grad``.
        retain_graph: Forwarded to ``torch.autograd.grad``.
        create_graph: Forwarded to ``torch.autograd.grad``.

    Returns:
        The per-input gradients, flattened and concatenated in input order.
        Inputs that do not take part in the graph contribute zeros.
    """
    wrt = [inputs] if isinstance(inputs, torch.Tensor) else list(inputs)
    raw = torch.autograd.grad(
        outputs,
        wrt,
        grad_outputs,
        allow_unused=True,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )
    flat_parts = []
    for piece, ref in zip(raw, wrt):
        if piece is None:
            # Unused input: its gradient is identically zero.
            piece = torch.zeros_like(ref)
        flat_parts.append(piece.contiguous().view(-1))
    return torch.cat(flat_parts)
    
    
def hessian(output, inputs, out=None, allow_unused=False, create_graph=False):
    '''Compute the dense Hessian of output with respect to inputs.

    Args:
        output: Tensor to differentiate twice (intended to be a scalar).
        inputs: A single tensor or an iterable of tensors.
        out: Optional (n, n) tensor that results are accumulated into with
            in-place addition; a zero matrix is allocated when omitted.
            n is the total number of elements over all inputs.
        allow_unused: Forwarded to the first-order torch.autograd.grad call;
            a missing gradient is treated as zeros.
        create_graph: Forwarded to the second-order gradient call.

    Returns:
        The (n, n) matrix out, with rows and columns ordered by flattening
        the inputs in order.

    Example:
        l = loss(net(X), y)
        A = hessian(l, net.parameters())
    '''
    '''output: a scalar'''
    '''inputs: a list of tensors'''
    '''## : l = loss(net(X), y)  A=hessian(l, net.parameters())'''
    #     assert output.ndimension() == 0
    if torch.is_tensor(inputs):
        inputs = [inputs]
    else:
        inputs = list(inputs)

    n = sum(p.numel() for p in inputs)
    if out is None:
        out = output.new_zeros(n, n)

    # ai is the flat row/column index of the element currently being processed.
    ai = 0
    for i, inp in enumerate(inputs):
        # First-order gradient for this input; the graph is always kept so each
        # element can be differentiated again below.
        [grad] = torch.autograd.grad(
        output, inp, create_graph=True, allow_unused=allow_unused)
        grad = torch.zeros_like(inp) if grad is None else grad
        grad = grad.contiguous().view(-1)
        for j in range(inp.numel()):
            if grad[j].requires_grad:
                # Second derivatives of grad[j] w.r.t. this and all later inputs;
                # dropping the first j entries leaves the diagonal element onward.
                row = gradient(
                grad[j], inputs[i:], retain_graph=True, create_graph=create_graph)[j:]
            else:
                # grad[j] does not depend on the inputs: this part of the row is zero.
                row = grad[j].new_zeros(sum(x.numel() for x in inputs[i:]) - j)

            # Fill the row from the diagonal rightwards, then mirror the
            # off-diagonal part into the column below the diagonal.
            out[ai, ai:].add_(row.type_as(out))  # ai's row
            if ai + 1 < n:
                out[ai + 1:, ai].add_(row[1:].type_as(out))  # ai's column
            del row
            ai += 1
        del grad
    return out


def compute_hessian_vector_product(loss, params, vector):
    """Return the Hessian of ``loss`` w.r.t. ``params`` applied to ``vector``.

    Uses the double-backward trick: differentiate ``sum_k <grad_k, vector_k>``
    with respect to the parameters.

    Args:
        loss: Scalar tensor with an autograd graph. The graph is retained, so
            the function can be called repeatedly with the same ``loss``.
        params: Iterable of tensors (a generator such as
            ``model.parameters()`` is accepted).
        vector: Iterable of tensors paired element-wise with ``params``.

    Returns:
        List of detached tensors, one per parameter. Parameters whose Hessian
        rows are structurally zero (for example a parameter the loss depends
        on only linearly) yield zero tensors.
    """
    # Materialise once: params is consumed by two separate autograd calls.
    params = list(params)

    grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)

    dot_product = None
    for g, v in zip(grads, vector):
        if g is None:
            # The loss does not depend on this parameter at all.
            continue
        term = (g * v).sum()
        dot_product = term if dot_product is None else dot_product + term

    if dot_product is None or not dot_product.requires_grad:
        # The gradient is constant in every parameter, so the Hessian is zero.
        return [torch.zeros_like(p) for p in params]

    # retain_graph keeps the graph of `loss` usable for subsequent calls.
    hvps = torch.autograd.grad(dot_product, params, retain_graph=True, allow_unused=True)
    return [
        torch.zeros_like(p) if hvp is None else hvp.detach()
        for hvp, p in zip(hvps, params)
    ]

# Read vt (Adam's second-moment estimate) from the optimizer state
def get_vt_hat_from_optimizer(optimizer, params):
    """Extract a bias-corrected second-moment estimate for each parameter.

    For every parameter with optimizer state this returns
    ``(1 - beta1**t) * exp_avg_sq / (1 - beta2**t)``, where ``t`` is the stored
    step count and the betas are those of the parameter's own param group.

    NOTE(review): the textbook Adam ``v_hat`` is ``exp_avg_sq / (1 - beta2**t)``
    without the ``(1 - beta1**t)`` factor; the extra factor is kept as written
    in the original code — confirm it is intended.

    Args:
        optimizer: Optimizer whose param groups define ``betas`` and whose
            state holds ``exp_avg_sq`` and ``step`` (as Adam does).
        params: Iterable of parameters to look up.

    Returns:
        List of tensors in the order of ``params``. Parameters without state,
        or whose step count is 0, yield zeros.
    """
    # Each param group may carry its own betas, so resolve them per parameter.
    betas_by_param = {}
    for group in optimizer.param_groups:
        for p in group['params']:
            betas_by_param[id(p)] = group['betas']
    default_betas = optimizer.param_groups[0]['betas']

    vt_hat = []
    for param in params:
        if param not in optimizer.state:
            vt_hat.append(torch.zeros_like(param))
            continue

        state = optimizer.state[param]
        vt = state['exp_avg_sq'].clone()
        step = state['step']
        # `step` may be stored as a tensor; a Python number keeps
        # 1 - beta**t from being evaluated in low precision.
        t = step.item() if torch.is_tensor(step) else step
        if t == 0:
            vt_hat.append(torch.zeros_like(param))
            continue

        betas = betas_by_param.get(id(param), default_betas)
        bias_correction1 = 1 - betas[0] ** t
        bias_correction2 = 1 - betas[1] ** t
        vt_hat.append(bias_correction1 * vt / bias_correction2)
    return vt_hat


# def power_iteration(loss, params, max_iter=50, tol=1e-3, initial_vector=None):
#     '''
#      vt ， Hessian ， vt  eps
#     '''
#     if initial_vector is None:
#         # 
#         v = [torch.randn_like(p) for p in params]
#     else:
#         # 
#         v = initial_vector
    
#     # 
#     v_flat = torch.cat([vi.flatten() for vi in v])
#     v_norm = torch.norm(v_flat).item()
#     if v_norm == 0:
#         v = [torch.randn_like(vi) for vi in v]
#         v_flat = torch.cat([vi.flatten() for vi in v])
#         v_norm = torch.norm(v_flat).item()
#     v = [vi / v_norm for vi in v]
    
#     prev_eigenvalue = None
#     eigenvalue = None
#     for _ in range(max_iter):
#         # Hessian
#         Hv = compute_hessian_vector_product(loss, params, v)
        
#         # （Rayleigh）
#         Hv_flat = torch.cat([hvi.flatten() for hvi in Hv])
#         v_flat = torch.cat([vi.flatten() for vi in v])
#         eigenvalue = torch.dot(v_flat, Hv_flat).item()
        
#         # 
#         Hv_norm = torch.norm(Hv_flat).item()
#         if Hv_norm == 0:
#             break
#         v = [hvi / Hv_norm for hvi in Hv]
        
#         # 
#         if prev_eigenvalue is not None and abs(eigenvalue - prev_eigenvalue) < tol:
#             break
#         prev_eigenvalue = eigenvalue
#     return (eigenvalue, v)

def power_iteration(loss, params, vt=None, eps=1e-8, max_iter=50, tol=1e-3, initial_vector=None):
    """Power iteration on the Hessian of ``loss``, optionally preconditioned.

    Without ``vt`` the iteration runs on the Hessian ``H`` itself. With ``vt``
    every Hessian-vector product is divided element-wise by
    ``sqrt(vt) + eps``, i.e. the iteration runs on
    ``diag(1 / (sqrt(vt) + eps)) @ H``.

    Args:
        loss: Scalar tensor with an autograd graph.
        params: Iterable of parameter tensors (a generator is accepted).
        vt: Optional iterable of tensors paired element-wise with ``params``.
        eps: Constant added to ``sqrt(vt)`` before dividing.
        max_iter: Maximum number of iterations.
        tol: Stop once the Rayleigh quotient changes by less than this between
            two consecutive iterations.
        initial_vector: Optional list of start tensors, one per parameter. A
            random start is used when omitted or when its norm is zero.

    Returns:
        ``(eigenvalue, v)`` where ``eigenvalue`` is the last Rayleigh quotient
        (``None`` if no iteration ran) and ``v`` is a list of tensors holding
        the current unit-norm vector.
    """
    # Materialise once: both sequences are traversed on every iteration, which
    # would exhaust a generator such as model.parameters() after the first use.
    params = list(params)
    scales = None if vt is None else [torch.sqrt(vt_com) + eps for vt_com in vt]

    if initial_vector is None:
        v = [torch.randn_like(p) for p in params]
    else:
        v = list(initial_vector)

    # Normalise the start vector; re-draw it if it is exactly zero.
    v_flat = torch.cat([vi.flatten() for vi in v])
    v_norm = torch.norm(v_flat).item()
    if v_norm == 0:
        v = [torch.randn_like(vi) for vi in v]
        v_flat = torch.cat([vi.flatten() for vi in v])
        v_norm = torch.norm(v_flat).item()
    v = [vi / v_norm for vi in v]

    prev_eigenvalue = None
    eigenvalue = None
    for _ in range(max_iter):
        Hv = compute_hessian_vector_product(loss, params, v)
        if scales is None:
            Hv_scaled = Hv
        else:
            Hv_scaled = [hvi / scale for hvi, scale in zip(Hv, scales)]

        # Rayleigh quotient of the current unit vector.
        Hv_scaled_flat = torch.cat([hvi.flatten() for hvi in Hv_scaled])
        v_flat = torch.cat([vi.flatten() for vi in v])
        eigenvalue = torch.dot(v_flat, Hv_scaled_flat).item()

        Hv_scaled_norm = torch.norm(Hv_scaled_flat).item()
        if Hv_scaled_norm == 0:
            # The product vanished; keep the last vector and stop.
            break
        v = [hvi / Hv_scaled_norm for hvi in Hv_scaled]

        if prev_eigenvalue is not None and abs(eigenvalue - prev_eigenvalue) < tol:
            break
        prev_eigenvalue = eigenvalue
    return (eigenvalue, v)


def power_iteration_combined(loss, params, vt=None, eps=1e-8, max_iter=50, tol=1e-3, initial_vector=None, h_hat_initial_vector=None):
    """Run power iteration on ``H`` and on the preconditioned ``H_hat`` together.

    ``H`` is the Hessian of ``loss`` with respect to ``params`` and
    ``H_hat = diag(1 / (sqrt(vt) + eps)) @ H``. When ``vt`` is ``None`` no
    scaling is applied, so the second iteration also runs on ``H``.

    Args:
        loss: Scalar tensor with an autograd graph.
        params: Iterable of parameter tensors (a generator is accepted).
        vt: Optional iterable of tensors paired element-wise with ``params``.
        eps: Constant added to ``sqrt(vt)`` before dividing.
        max_iter: Maximum number of iterations.
        tol: Stop once both Rayleigh quotients change by less than this
            between two consecutive iterations.
        initial_vector: Optional start vector (list of tensors) for ``H``.
        h_hat_initial_vector: Optional start vector for ``H_hat``.

    Returns:
        ``((eigenvalue, v), (eigenvalue_hat, v_hat))``. The eigenvalues are the
        Rayleigh quotients evaluated in the last iteration (``None`` if no
        iteration ran); the vectors are lists of tensors.
    """
    # Materialise once: both sequences are reused on every iteration.
    params = list(params)
    # vt=None is the documented default, so it must not be zipped over.
    scales = None if vt is None else [torch.sqrt(vt_com) + eps for vt_com in vt]

    def initialize_vector(initial_vector, params):
        """Return the supplied start vector, or a random one shaped like params."""
        if initial_vector is None:
            return [torch.randn_like(p) for p in params]
        return initial_vector

    def normalize_vector(v):
        """Scale v to unit Euclidean norm; re-draw it randomly if it is zero."""
        v_flat = torch.cat([vi.flatten() for vi in v])
        v_norm = torch.norm(v_flat).item()
        if v_norm == 0:
            v = [torch.randn_like(vi) for vi in v]
            v_flat = torch.cat([vi.flatten() for vi in v])
            v_norm = torch.norm(v_flat).item()
        return [vi / v_norm for vi in v]

    def compute_rayleigh_quotient(v, Hv):
        """Return the dot product of v and Hv (v is expected to be unit-norm)."""
        v_flat = torch.cat([vi.flatten() for vi in v])
        Hv_flat = torch.cat([hvi.flatten() for hvi in Hv])
        return torch.dot(v_flat, Hv_flat).item()

    v = normalize_vector(initialize_vector(initial_vector, params))
    v_hat = normalize_vector(initialize_vector(h_hat_initial_vector, params))

    prev_values = []
    eigenvalue, eigenvalue_hat = None, None

    for _ in range(max_iter):
        # Step for H.
        Hv = compute_hessian_vector_product(loss, params, v)
        eigenvalue = compute_rayleigh_quotient(v, Hv)
        v = normalize_vector(Hv)

        # Step for H_hat.
        Hv_hat = compute_hessian_vector_product(loss, params, v_hat)
        if scales is None:
            Hv_hat_scaled = Hv_hat
        else:
            Hv_hat_scaled = [hvi / scale for hvi, scale in zip(Hv_hat, scales)]
        eigenvalue_hat = compute_rayleigh_quotient(v_hat, Hv_hat_scaled)
        v_hat = normalize_vector(Hv_hat_scaled)

        # Converged once both estimates have stopped moving.
        current_values = [eigenvalue, eigenvalue_hat]
        if prev_values and all(abs(c - p) < tol for c, p in zip(current_values, prev_values)):
            break
        prev_values = current_values

    result_H = (eigenvalue, v)
    result_diag = (eigenvalue_hat, v_hat)
    return result_H, result_diag


def compute_grad_hessian_lambda(loss, params, gradients, vt=None, eps=1e-8):
    """Rayleigh quotient of the (optionally scaled) Hessian along the gradient.

    With ``u`` the unit-normalised gradient and ``H`` the Hessian of ``loss``,
    this returns ``u . (H u)``. When ``vt`` is given, ``H u`` is first divided
    element-wise by ``sqrt(vt) + eps``.

    Args:
        loss: Scalar loss tensor.
        params: Model parameters.
        gradients: Current gradients, structured like ``params``.
        vt: Optional per-parameter tensors used for the scaling.
        eps: Constant added to ``sqrt(vt)`` before dividing.

    Returns:
        ``None`` if ``gradients`` is ``None``, ``0.0`` if the gradient norm is
        below ``1e-6``, otherwise the quotient as a Python float.
    """
    if gradients is None:
        return None

    flat_grad = torch.cat([g.flatten() for g in gradients])
    norm = torch.norm(flat_grad).item()
    if norm < 1e-6:
        # No meaningful direction to probe.
        return 0.0

    direction = [g / norm for g in gradients]

    hvp = compute_hessian_vector_product(loss, params, direction)
    if vt is not None:
        hvp = [h / (torch.sqrt(second_moment) + eps) for h, second_moment in zip(hvp, vt)]

    hvp_flat = torch.cat([h.flatten() for h in hvp])
    direction_flat = torch.cat([d.flatten() for d in direction])
    return torch.dot(direction_flat, hvp_flat).item()

def compute_grad_vt_hessian_lambda(loss, params, gradients, vt, eps=1e-8):
    """Rayleigh quotient of the vt-scaled Hessian along the gradient direction.

    With ``u`` the unit-normalised gradient and ``H`` the Hessian of ``loss``,
    this returns ``u . (H u / (sqrt(vt) + eps))``.

    Args:
        loss: Scalar loss tensor.
        params: Model parameters.
        gradients: Current gradients, structured like ``params``.
        vt: Per-parameter tensors (e.g. Adam's second-moment estimate),
            structured like ``params``.
        eps: Constant added to ``sqrt(vt)`` before dividing.

    Returns:
        ``None`` if ``gradients`` is ``None``, ``0.0`` if the gradient norm is
        below ``1e-6``, otherwise the quotient as a Python float.
    """
    if gradients is None:
        return None

    flat_grad = torch.cat([g.flatten() for g in gradients])
    norm = torch.norm(flat_grad).item()
    if norm < 1e-6:
        # No meaningful direction to probe.
        return 0.0

    direction = [g / norm for g in gradients]
    hvp = compute_hessian_vector_product(loss, params, direction)

    # The scaling is applied on the flattened vectors in one shot.
    hvp_flat = torch.cat([h.flatten() for h in hvp])
    second_moment_flat = torch.cat([m.flatten() for m in vt])
    scaled = hvp_flat / (torch.sqrt(second_moment_flat) + eps)

    direction_flat = torch.cat([d.flatten() for d in direction])
    return torch.dot(direction_flat, scaled).item()