import torch


def compute_dZ_dT(X, Y, dX_dT, dY_dT=None):
    """
    Compute the derivative of Z = X @ Y with respect to T via the product rule.

    The result is ``dX_dT @ Y + X @ dY_dT``. A derivative passed as None is
    treated as zero, i.e. that factor is not considered a function of T.

    Parameters:
    - X: Matrix X, a 2D PyTorch tensor with dimensions (m, n).
    - Y: Matrix Y, a 2D PyTorch tensor with dimensions (n, p).
    - dX_dT: The derivative of X with respect to T, or None. Its product with Y
      is accumulated in place into an (m, p) buffer, so ``dX_dT @ Y`` must fit
      that shape (e.g. dX_dT of shape (m, n)).
    - dY_dT: The derivative of Y with respect to T, or None. ``X @ dY_dT`` must
      likewise fit the (m, p) buffer (e.g. dY_dT of shape (n, p)).

    Returns:
    - dZ_dT: An (m, p) PyTorch tensor. Its dtype and device follow dX_dT when
      given, otherwise dY_dT, otherwise X (in which case it is all zeros).
    """
    # Unpacking into two names also enforces that X and Y are 2D.
    m, _ = X.shape
    _, p = Y.shape

    # Choose the tensor whose dtype/device the output follows. Falling back to
    # X keeps the "nothing depends on T" case well defined instead of failing
    # on unbound locals.
    if dX_dT is not None:
        reference = dX_dT
    elif dY_dT is not None:
        reference = dY_dT
    else:
        reference = X

    dZ_dT = torch.zeros((m, p), dtype=reference.dtype, device=reference.device)

    # Product rule: d(XY) = dX @ Y + X @ dY, skipping absent terms.
    if dX_dT is not None:
        dZ_dT += torch.matmul(dX_dT, Y)
    if dY_dT is not None:
        dZ_dT += torch.matmul(X, dY_dT)

    return dZ_dT

def compute_dZ_dT_for_loop(X, Y, dX_dT, dY_dT=None):
    """
    Loop-based computation of the derivative of Z = X @ Y with respect to T.

    An intermediate 4D buffer ``J[i, j, a, b]`` is filled entry by entry and
    then summed over its last two axes, so the returned value equals
    ``dX_dT @ Y + X @ dY_dT``. A derivative passed as None is treated as zero.

    Parameters:
    - X: Matrix X, a 2D PyTorch tensor with dimensions (m, n).
    - Y: Matrix Y, a 2D PyTorch tensor with dimensions (n, p).
    - dX_dT: The derivative of X with respect to T, indexed here as a 2D tensor
      ``dX_dT[k, l]`` alongside ``X[k, l]`` (i.e. shape (m, n)), or None.
    - dY_dT: The derivative of Y with respect to T, indexed here as a 2D tensor
      ``dY_dT[k, l]`` alongside ``Y[k, l]`` (i.e. shape (n, p)), or None.

    Returns:
    - dZ_dT: An (m, p) PyTorch tensor. Its dtype and device follow dX_dT when
      given, otherwise dY_dT, otherwise X (in which case it is all zeros).
    """
    # Unpacking into two names also enforces that X and Y are 2D.
    m, n = X.shape
    n, p = Y.shape

    # Neither factor depends on T: the derivative is identically zero.
    if dX_dT is None and dY_dT is None:
        return torch.zeros((m, p), dtype=X.dtype, device=X.device)

    # Size the two trailing axes of the buffer so that both index patterns
    # used below stay in range.
    if dY_dT is None:
        max_rows, max_cols = dX_dT.size(0), dX_dT.size(1)
        reference = dX_dT
    elif dX_dT is None:
        max_rows, max_cols = dY_dT.size(0), dY_dT.size(1)
        reference = dY_dT
    else:
        max_rows = max(dX_dT.size(0), dY_dT.size(0))
        max_cols = max(dX_dT.size(1), dY_dT.size(1))
        reference = dX_dT

    dZ_dT = torch.zeros(
        (m, p, max_rows, max_cols), dtype=reference.dtype, device=reference.device
    )

    # Contribution of X: entry X[k, l] only influences row k of Z, with
    # sensitivity Y[l, :].
    if dX_dT is not None:
        for k in range(dX_dT.size(0)):
            for l in range(dX_dT.size(1)):
                dZ_dT[k, :, k, l] += dX_dT[k, l] * Y[l, :]

    # Contribution of Y: entry Y[k, l] only influences column l of Z, with
    # sensitivity X[:, k].
    if dY_dT is not None:
        for k in range(dY_dT.size(0)):
            for l in range(dY_dT.size(1)):
                dZ_dT[:, l, k, l] += X[:, k] * dY_dT[k, l]

    # Collapse the two trailing axes to obtain the (m, p) result.
    return dZ_dT.sum(-1).sum(-1)

def compute_d2Z_dTT(X, Y, dX_dT=None, dY_dT=None, d2X_dTT=None, d2Y_dTT=None):
    """
    Compute the second derivative of Z(T) = X(T) @ Y(T) with respect to T.

    Applies the product rule twice:
    ``d2X_dTT @ Y + X @ d2Y_dTT + 2 * dX_dT @ dY_dT``.
    Any term whose inputs are None is treated as zero; the mixed term is only
    added when both first derivatives are given.

    Parameters:
    - X: Matrix X, a 2D-indexable PyTorch tensor with dimensions (m, n).
    - Y: Matrix Y, a 2D-indexable PyTorch tensor with dimensions (n, p).
    - dX_dT, dY_dT: First derivatives of X and Y with respect to T, or None.
    - d2X_dTT, d2Y_dTT: Second derivatives of X and Y with respect to T, or None.
      Every matmul product is accumulated in place into an (m, p) buffer, so
      each product must fit that shape.

    Returns:
    - d2Z_dTT: An (m, p) PyTorch tensor with X's dtype and device.
    """
    m, p = X.shape[0], Y.shape[1]

    # The accumulator must exist before the in-place additions below.
    d2Z_dTT = torch.zeros((m, p), dtype=X.dtype, device=X.device)

    if d2X_dTT is not None:
        d2Z_dTT += torch.matmul(d2X_dTT, Y)
    if d2Y_dTT is not None:
        d2Z_dTT += torch.matmul(X, d2Y_dTT)
    if dX_dT is not None and dY_dT is not None:
        # The two symmetric cross terms dX @ dY coincide when both
        # derivatives are taken with respect to the same T.
        d2Z_dTT += 2 * torch.matmul(dX_dT, dY_dT)

    return d2Z_dTT


def compute_d2Z_dT1T2(X, Y, X_eps1=None, X_eps2=None, Y_eps1=None, Y_eps2=None, d2X_dTT=None, d2Y_dTT=None):
    """
    Compute the mixed second derivative of Z = X @ Y along two directions.

    The result is
    ``d2X_dTT @ Y + X @ d2Y_dTT + X_eps1 @ Y_eps2 + X_eps2 @ Y_eps1``,
    where every term that involves a None input is treated as zero.

    Parameters:
    - X: Matrix X, a PyTorch tensor with dimensions (m, n).
    - Y: Matrix Y, a PyTorch tensor with dimensions (n, p).
    - X_eps1, X_eps2: presumably the first derivatives of X along directions
      1 and 2 (TODO confirm against callers), or None.
    - Y_eps1, Y_eps2: presumably the first derivatives of Y along directions
      1 and 2 (TODO confirm against callers), or None.
    - d2X_dTT, d2Y_dTT: Second derivatives of X and Y, or None.
      Every matmul product is accumulated in place into an (m, p) buffer, so
      each product must fit that shape.

    Returns:
    - An (m, p) PyTorch tensor with X's dtype and device.
    """
    m, p = X.shape[0], Y.shape[1]

    d2Z_dTT = torch.zeros((m, p), dtype=X.dtype, device=X.device)

    if d2X_dTT is not None:
        d2Z_dTT += torch.matmul(d2X_dTT, Y)
    if d2Y_dTT is not None:
        d2Z_dTT += torch.matmul(X, d2Y_dTT)

    # Each cross term is guarded by exactly the operands it uses, so a missing
    # derivative zeroes only its own term instead of crashing the matmul or
    # silently discarding the other cross term.
    if X_eps1 is not None and Y_eps2 is not None:
        d2Z_dTT += torch.matmul(X_eps1, Y_eps2)
    if X_eps2 is not None and Y_eps1 is not None:
        d2Z_dTT += torch.matmul(X_eps2, Y_eps1)

    return d2Z_dTT



############
def compute_d2Z_dTT_accurate(X, Y, dX_dT=None, dY_dT=None, d2X_dTT=None, d2Y_dTT=None):
    """
    Element-by-element Hessian of Z(T) = X(T) @ Y(T), summed over all T indices.

    For every pair of T entries (k1, l1) and (k2, l2) the entry
    ``H[i, j, k1, l1, k2, l2]`` accumulates
    - ``sum_r d2X_dTT[i, r, k1, l1, k2, l2] * Y[r, j]``
    - ``sum_r X[i, r] * d2Y_dTT[r, j, k1, l1, k2, l2]``
    - ``sum_r dX_dT[i, r, k1, l1] * dY_dT[r, j, k2, l2]
              + dX_dT[i, r, k2, l2] * dY_dT[r, j, k1, l1]``
    and the four T axes are summed out before returning. Terms whose inputs
    are None are skipped, as are T indices a given tensor does not cover.

    Parameters:
    - X: PyTorch tensor indexed as X[i, r], dimensions (m, n).
    - Y: PyTorch tensor indexed as Y[r, j], dimensions (n, p).
    - dX_dT: 4D tensor indexed as [i, r, k, l], or None.
    - dY_dT: 4D tensor indexed as [r, j, k, l], or None.
    - d2X_dTT: 6D tensor indexed as [i, r, k1, l1, k2, l2], or None. Axes 4 and
      5 are assumed to have the same sizes as axes 2 and 3 -- TODO confirm.
    - d2Y_dTT: 6D tensor indexed as [r, j, k1, l1, k2, l2], or None, with the
      same assumption on its trailing axes.

    Returns:
    - An (m, p) PyTorch tensor with X's dtype and device.
    """
    m, n, p = X.shape[0], X.shape[1], Y.shape[1]

    # T extents covered by each input (0 when the input is absent).
    t_rows_xx, t_cols_xx = (d2X_dTT.shape[2], d2X_dTT.shape[3]) if d2X_dTT is not None else (0, 0)
    t_rows_yy, t_cols_yy = (d2Y_dTT.shape[2], d2Y_dTT.shape[3]) if d2Y_dTT is not None else (0, 0)
    t_rows_x, t_cols_x = (dX_dT.shape[2], dX_dT.shape[3]) if dX_dT is not None else (0, 0)
    t_rows_y, t_cols_y = (dY_dT.shape[2], dY_dT.shape[3]) if dY_dT is not None else (0, 0)

    # The Hessian buffer spans the largest T extent among all inputs.
    t_rows = max(t_rows_x, t_rows_y, t_rows_xx, t_rows_yy)
    t_cols = max(t_cols_x, t_cols_y, t_cols_xx, t_cols_yy)

    # The mixed term needs both first derivatives at both T indices.
    k_mix = min(t_rows_x, t_rows_y)
    l_mix = min(t_cols_x, t_cols_y)
    has_mixed = dX_dT is not None and dY_dT is not None

    d2Z_dTT = torch.zeros((m, p, t_rows, t_cols, t_rows, t_cols), dtype=X.dtype, device=X.device)

    for k1 in range(t_rows):
        for l1 in range(t_cols):
            for k2 in range(t_rows):
                for l2 in range(t_cols):
                    k_hi, l_hi = max(k1, k2), max(l1, l2)
                    # Each term is bounds-checked against the tensor it
                    # actually indexes: second-derivative terms against the
                    # second-derivative extents, not the first-derivative ones.
                    use_xx = d2X_dTT is not None and k_hi < t_rows_xx and l_hi < t_cols_xx
                    use_yy = d2Y_dTT is not None and k_hi < t_rows_yy and l_hi < t_cols_yy
                    use_mix = has_mixed and k_hi < k_mix and l_hi < l_mix
                    if not (use_xx or use_yy or use_mix):
                        continue

                    for i in range(m):
                        for j in range(p):
                            # Direct second-order effects of X on Z
                            if use_xx:
                                for r in range(n):
                                    d2Z_dTT[i, j, k1, l1, k2, l2] += d2X_dTT[i, r, k1, l1, k2, l2] * Y[r, j]
                            # Direct second-order effects of Y on Z
                            if use_yy:
                                for r in range(n):
                                    d2Z_dTT[i, j, k1, l1, k2, l2] += X[i, r] * d2Y_dTT[r, j, k1, l1, k2, l2]
                            # Mixed effects from first-order derivatives of X and Y
                            if use_mix:
                                for r in range(n):
                                    d2Z_dTT[i, j, k1, l1, k2, l2] += (
                                        dX_dT[i, r, k1, l1] * dY_dT[r, j, k2, l2]
                                        + dX_dT[i, r, k2, l2] * dY_dT[r, j, k1, l1]
                                    )

    # Sum out the four T axes, leaving an (m, p) tensor.
    return d2Z_dTT.sum(-1).sum(-1).sum(-1).sum(-1)


def iterate_and_modify_tensors(tensor_list):
    """
    Yield one-hot versions of ``tensor_list``, one per scalar element.

    Elements are enumerated tensor by tensor in list order and, within a
    tensor, with the last dimension varying fastest. For each element a fresh
    list of zeroed clones of the input tensors is yielded in which only that
    single element is set to 1. The input tensors are not modified.
    """
    element_count = sum(t.numel() for t in tensor_list)

    for flat_position in range(element_count):
        # Fresh zero-filled copies for this step.
        blanks = [source.clone().zero_() for source in tensor_list]

        remaining = flat_position
        for blank in blanks:
            size = blank.numel()
            if remaining >= size:
                # The target element lives in a later tensor.
                remaining -= size
                continue

            # Convert the flat offset into a multi-dimensional index by
            # peeling off dimensions from the last one to the first.
            coords = []
            for extent in blank.shape[::-1]:
                remaining, offset = divmod(remaining, extent)
                coords.insert(0, offset)

            blank[tuple(coords)] = 1
            break

        yield blanks

def P_proj(v1, h):
    """Return the outer product of the flattened ``v1`` with itself, scaled by ``h``."""
    column = v1.view(-1, 1)
    row = v1.view(1, -1)
    return (column @ row) * h

def P_proj_control(v1, h, c = -1.):
    """
    Outer-product projection of ``h`` with an additive control term.

    With ``x`` the column vector of ``v1`` and ``D = v1.shape[0]``, returns
    ``x xᵀ * h + c * (x xᵀ * ‖x‖² - (D + 2) * I)``.

    Parameters:
    - v1: PyTorch tensor, viewed as a column vector of length D.
    - h: Scalar (or broadcastable tensor) multiplying the outer product.
    - c: Weight of the control term. A good value for c seems to be -D.

    Returns:
    - A (D, D) PyTorch tensor on ``v1``'s device.
    """
    D = v1.shape[0]
    x = v1.view(-1, 1)
    # Build the constant on v1's device so the subtraction below does not
    # mix devices when v1 is not on the default device.
    # NOTE(review): presumably a control variate -- for standard-normal x,
    # E[x xᵀ ‖x‖²] = (D + 2) I, so the bracketed term has zero mean. Confirm.
    C = torch.eye(D, device=v1.device) * (D + 2)
    outer = x @ x.t()
    squared_norm = (x.t() @ x).item()
    P_tilde = outer * h + c * (outer * squared_norm - C)
    return P_tilde

def P_proj_control_diag(v1, h, c = -1.):
    """
    Element-wise (diagonal-only) counterpart of the controlled projection.

    With ``x = v1`` and ``D = v1.shape[0]``, returns
    ``x² * h + c * (x² * (xᵀ x) - (D + 2))`` computed element-wise.

    Parameters:
    - v1: PyTorch tensor of length D (used as a 1D vector -- TODO confirm
      callers never pass a 2D column).
    - h: Scalar (or broadcastable tensor) multiplying the squared entries.
    - c: Weight of the control term. A good value for c seems to be -D.

    Returns:
    - A PyTorch tensor holding the D diagonal values, on ``v1``'s device.
    """
    D = v1.shape[0]
    # Build the constant on v1's device so the subtraction below does not
    # mix devices when v1 is not on the default device.
    C = torch.ones(D, device=v1.device) * (D + 2)
    x = v1
    squared = x ** 2
    squared_norm = (x.t() @ x).item()
    P_tilde_diag = squared * h + c * (squared * squared_norm - C)
    return P_tilde_diag

def projection_to_Hessian(P_tilde):
    """
    Convert a projected matrix ``P_tilde`` into a Hessian estimate.

    Off-diagonal entries are halved. The diagonal ``d`` is obtained by solving
    ``(2 I + 1 1ᵀ) d = diag(P_tilde)``, where ``1 1ᵀ`` is the all-ones matrix.

    Parameters:
    - P_tilde: Square floating-point PyTorch tensor with dimensions (D, D).

    Returns:
    - A (D, D) PyTorch tensor with P_tilde's dtype and device.
    """
    D = P_tilde.shape[0]
    # torch.linalg.solve needs both operands to share dtype and device, so
    # every helper matrix is created to match P_tilde.
    like = dict(dtype=P_tilde.dtype, device=P_tilde.device)
    identity = torch.eye(D, **like)
    A = 2 * identity + torch.ones(D, D, **like)
    diag = torch.linalg.solve(A, P_tilde.diag())
    off_diagonal_mask = 1 - identity
    return off_diagonal_mask / 2. * P_tilde + diag.diag()

def projection_to_Hessian_diag(P_tilde_diag):
    """
    Diagonal-only Hessian estimate from the diagonal of a projected matrix.

    Equivalent to solving ``(2 I + 1 1ᵀ) d = P_tilde_diag`` for ``d``, using
    the closed-form inverse instead of a linear solve: the inverse has
    ``(D + 1) / (2 (D + 2))`` on its diagonal and ``-1 / (2 (D + 2))``
    everywhere else.

    Parameters:
    - P_tilde_diag: 1D PyTorch tensor of length D.

    Returns:
    - A new 1D PyTorch tensor of length D with the input's dtype and device.
    """
    D = P_tilde_diag.shape[0]

    inverse_diag_val = (D + 1.) / (2. * (D + 2.))
    # Written directly rather than as (row_sum - diag) / (D - 1), which
    # divides by zero for D == 1.
    inverse_off_diag_val = -1. / (2. * (D + 2.))

    # Row i of the inverse applied to P: the diagonal weight on P[i] plus the
    # off-diagonal weight on the sum of every other entry.
    others = P_tilde_diag.sum() - P_tilde_diag
    return P_tilde_diag * inverse_diag_val + others * inverse_off_diag_val


def hess_vector_build(N):
    '''
    Generate every ordered pair of standard basis vectors of length N.

    Pairs are produced in row-major order, so the pair built from indices
    (i, j) corresponds to Hessian entry H[i, j]. Each yielded vector is a
    freshly allocated tensor. E.g. for N = 2 the output is:

    (tensor([1., 0.]), tensor([1., 0.])) -> H[0,0]
    (tensor([1., 0.]), tensor([0., 1.])) -> H[0,1]
    (tensor([0., 1.]), tensor([1., 0.])) -> H[1,0]
    (tensor([0., 1.]), tensor([0., 1.])) -> H[1,1]
    '''
    def unit(position):
        # Length-N vector with a single 1 at `position`.
        vec = torch.zeros(N)
        vec[position] = 1.
        return vec

    yield from ((unit(i), unit(j)) for i in range(N) for j in range(N))

def jac_vector_build(N):
    '''
    Generate the standard basis vectors of length N, in index order.

    The i-th yielded vector corresponds to Jacobian entry J[i]; each one is a
    freshly allocated tensor. E.g. for N = 2 the output is:

    tensor([1., 0.]) -> J[0]
    tensor([0., 1.]) -> J[1]
    '''
    def unit(position):
        # Length-N vector with a single 1 at `position`.
        vec = torch.zeros(N)
        vec[position] = 1.
        return vec

    yield from map(unit, range(N))