import torch
import numpy as np
import torch.nn.functional as F
# ==================================================================================
# Define objective functions
# ==================================================================================

def obj_F_square(X, omega, d, lamb, diag_ind):
    """
    Compute the squared residual norm ||F||² summed over the whole matrix.

    F = omega - softshrink(omega + d, lamb) at every entry, except at the
    positions (i, diag_ind[i]), where
        F = (omega - d - sqrt((omega + d)^2 + 4)) / 2,
    i.e. omega - (v + sqrt(v^2 + 4)) / 2 with v = omega + d (the prox of -log at v).

    Args:
        X: unused; kept so all objective functions share a signature.
        omega: 2-D tensor; row i is paired with column diag_ind[i].
        d: tensor with the same shape as omega.
        lamb: soft-thresholding level.
        diag_ind: one column index per row (length gives the number of rows
            that receive the log-prox update).

    Returns:
        0-dim tensor: sum of squares of every entry of F.
    """
    # Build the row index on omega's device, consistent with obj_F_square_vec / obj_accord.
    row = torch.arange(len(diag_ind), device=omega.device)
    shifted = omega + d  # computed once, reused for both the soft-threshold and the log-prox

    dd = omega - torch.nn.functional.softshrink(shifted, lambd=lamb)
    dd[row, diag_ind] = (
        omega[row, diag_ind] - d[row, diag_ind] - torch.sqrt(shifted[row, diag_ind] ** 2 + 4)
    ) / 2

    # Sum the squares directly rather than taking a norm (sqrt) and squaring it again.
    return dd.pow(2).sum()

def obj_F_square_vec(X, omega, d, lamb, diag_ind):
    """
    Evaluate ||F||² for a single vector, or row by row for a matrix.

    F = omega - softshrink(omega + d, lamb) everywhere except at the
    "diagonal" position(s), where F = (omega - d - sqrt((omega + d)^2 + 4)) / 2.

    Args:
        X: unused; kept so all objective functions share a signature.
        omega: 1-D or 2-D tensor.
        d: tensor with the same shape as omega.
        lamb: soft-thresholding level.
        diag_ind: for 1-D input, the index of the diagonal entry; for 2-D
            input, column indices paired elementwise with rows 0..omega.size(0)-1.

    Returns:
        0-dim tensor for 1-D input; 1-D tensor of per-row values for 2-D input.
    """
    shifted = omega + d
    is_vector = omega.dim() == 1

    # Index tuple selecting the diagonal position(s), shared by both input ranks.
    if is_vector:
        pos = (diag_ind,)
    else:
        pos = (torch.arange(omega.size(0), device=omega.device), diag_ind)

    diag_vals = (omega[pos] - d[pos] - torch.sqrt(shifted[pos] ** 2 + 4)) / 2

    resid = omega - torch.nn.functional.softshrink(shifted, lambd=lamb)
    resid[pos] = diag_vals
    resid.pow_(2)

    return resid.sum() if is_vector else resid.sum(dim=1)

def obj_accord(X, temp, omega, lamb, diag):
    """
    Calculate the ACCORD objective separately for each row.

    For row i (paired with column diag[i]):
        0.5 * ||temp[i]||^2 / n  -  log(omega[i, diag[i]])
            + lamb * sum_{j != diag[i]} |omega[i, j]|

    Args:
        X: data matrix; only its row count ``n`` is used (must be 2-D).
        temp: precomputed 2-D tensor whose row-wise squared norms give the fit
            term (presumably omega-rows times X^T; confirm with the caller).
        omega: 2-D tensor of current estimates.
        lamb: L1 penalty weight on entries other than (i, diag[i]).
        diag: column index of the "diagonal" entry for each row.

    Returns:
        Tensor with one objective value per row.
    """
    row = torch.arange(len(diag), device=omega.device)
    n, _ = X.shape  # unpacking keeps the requirement that X is 2-D

    # Fit term from the precomputed product, not recomputed from X here.
    loss = 0.5 * (torch.linalg.vector_norm(temp, dim=1) ** 2) / n

    # Negative log of each row's diagonal entry.
    log_term = -torch.log(omega[row, diag])

    # Zero each row's diagonal entry so only off-diagonal |omega| is penalized.
    mask = torch.ones_like(omega)
    mask[row, diag] = 0
    l1_norm = lamb * torch.sum(torch.abs(omega) * mask, dim=1)

    return loss + log_term + l1_norm

def sparse_offdiag_l1_norm(omega):
    """
    Return the sum of absolute values of the off-diagonal stored entries of a
    2-D sparse COO tensor.

    Duplicate coordinates are merged (summed) by ``coalesce`` before absolute
    values are taken.
    """
    coalesced = omega.coalesce()
    rows, cols = coalesced.indices()
    off_diagonal = coalesced.values()[rows != cols]
    return off_diagonal.abs().sum()

def sparse_l1_norm(omega):
    """
    Compute the L1 norm of ALL stored entries of a sparse tensor, diagonal included.

    Use ``sparse_offdiag_l1_norm`` when the diagonal must be excluded.
    """
    # Coalesce first: duplicate coordinates must be summed before taking |.|,
    # since |a + b| != |a| + |b| in general.
    return omega.coalesce().values().abs().sum()

def sparse_diagonal(omega):
    """
    Extract the diagonal of a 2-D sparse COO tensor as a dense 1-D tensor.

    Args:
        omega: 2-D sparse COO tensor.

    Returns:
        Dense tensor of length ``omega.size(0)`` with omega's dtype and device;
        diagonal positions with no stored entry are 0.
    """
    omega = omega.coalesce()  # merge duplicate (i, i) entries so each diagonal value is their sum
    row, col = omega.indices()
    values = omega.values()

    on_diag = row == col
    # new_zeros inherits dtype and device from values; torch.zeros without a dtype
    # would use the default dtype (float32) and silently truncate float64 input.
    diag = values.new_zeros(omega.size(0))
    diag[row[on_diag]] = values[on_diag]
    return diag

def obj_accord_sparse(X, omega, lamb, device):
    """
    Evaluate the ACCORD objective for a sparse omega, summed into one scalar.

        0.5 * ||omega @ X^T||_F^2 / n  -  sum_i log(omega_ii)
            + lamb * sum_{i != j} |omega_ij|

    Args:
        X: dense 2-D data tensor with n rows.
        omega: 2-D sparse COO tensor.
        lamb: L1 penalty weight on off-diagonal entries.
        device: device that X and omega are moved to before computing.

    Returns:
        0-dim tensor with the objective value.
    """
    n_samples, _ = X.shape
    X_dev = X.to(device)
    omega_dev = omega.to(device)

    # Sparse @ dense gives a dense matrix; its squared Frobenius norm is the fit term.
    fit_sq = torch.linalg.vector_norm(torch.sparse.mm(omega_dev, X_dev.T)) ** 2
    smooth_part = 0.5 * fit_sq / n_samples

    neg_log_diag = -torch.log(sparse_diagonal(omega_dev)).sum()
    penalty = lamb * sparse_offdiag_l1_norm(omega_dev)

    return smooth_part + neg_log_diag + penalty