import numpy as np
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ==================================================================================
# Procedures for generating the precision matrix and data
# ==================================================================================
def proj_precision_mat(Theta, nz_idx, lower_weight = 0.1, lower_eig = 0.2):
    """
    Apply one projection step to a candidate precision matrix.

    The step consists of three stages:

    1. Eigendecompose ``Theta`` with ``np.linalg.eigh``, raise every
       eigenvalue to at least ``lower_eig`` and rebuild the matrix.
    2. Keep only the entries addressed by ``nz_idx``; each kept entry
       retains its sign while its magnitude is raised to at least
       ``lower_weight``. All other entries are set to zero.
    3. Set the diagonal to 1.

    Stages 2 and 3 modify entries after the eigenvalue floor is applied,
    so the returned matrix is not guaranteed to still satisfy that floor.

    Parameters
    ----------
    Theta : ndarray
        Square matrix passed to ``np.linalg.eigh``.
    nz_idx : tuple of index arrays
        Positions (rows, cols) of the entries allowed to be nonzero.
    lower_weight : float
        Minimum absolute value enforced on the entries at ``nz_idx``.
    lower_eig : float
        Minimum eigenvalue enforced before the matrix is rebuilt.

    Returns
    -------
    ndarray
        New matrix with the same shape as ``Theta``; ``Theta`` itself is
        not modified.
    """
    # Stage 1: floor the spectrum and reassemble V * diag(lambda) * V^T.
    values, vectors = np.linalg.eigh(Theta)
    floored = np.maximum(values, lower_eig)
    rebuilt = vectors @ np.diag(floored) @ vectors.T

    # Stage 2: threshold the magnitudes on the allowed support only.
    # An entry that is exactly zero stays zero because np.sign(0) == 0.
    selected = rebuilt[nz_idx]
    magnitudes = np.maximum(np.abs(selected), lower_weight)

    projected = np.zeros_like(rebuilt)
    projected[nz_idx] = np.sign(selected) * magnitudes

    # Stage 3: unit diagonal.
    np.fill_diagonal(projected, 1.0)
    return projected

def generate_precision(p, Skel, lower_weight = 0.1, lower_eig = 0.2, random_state = 2025, max_iter = 1000, tol = 1e-3):
    """
    Generate a random precision matrix supported on a given skeleton.

    Edge weights are drawn uniformly from [-1, 1) for every nonzero entry
    in the lower triangle of ``Skel``, mirrored to make the matrix
    symmetric, and the diagonal is set to 1. The matrix is then passed
    repeatedly through ``proj_precision_mat`` until two successive iterates
    differ by less than ``tol`` in maximum absolute entry, or ``max_iter``
    iterations have been performed. No error or warning is raised if the
    iteration limit is reached first.

    Parameters
    ----------
    p : int
        Dimension of the (p, p) precision matrix.
    Skel : ndarray
        Skeleton (graph structure); nonzero lower-triangle entries mark
        edges. Nonzero diagonal entries are also picked up as "edges" and
        consume a random draw, but their value is overwritten by the unit
        diagonal.
    lower_weight : float
        Forwarded to ``proj_precision_mat``.
    lower_eig : float
        Forwarded to ``proj_precision_mat``.
    random_state : int or None
        Seed for the local random generator.
    max_iter : int
        Maximum number of projection iterations.
    tol : float
        Convergence threshold on the max absolute change between iterates.

    Returns
    -------
    ndarray
        The (p, p) matrix obtained after the last projection step.
    """
    # Use a local generator instead of np.random.seed: it produces the same
    # stream for a given seed but does not reseed the global NumPy RNG.
    rng = np.random.RandomState(random_state)

    # Lower-triangle nonzero positions of the skeleton define the edges.
    rows, cols = np.nonzero(np.tril(Skel))
    Theta = np.zeros((p, p))
    Theta[rows, cols] = rng.uniform(low = -1, high = 1, size = len(rows))

    # Symmetric support: each edge (i, j) together with its mirror (j, i).
    nz_idx = np.concatenate((rows, cols)), np.concatenate((cols, rows))

    # Make symmetric and set the diagonal to 1.
    Theta = np.tril(Theta) + np.tril(Theta).T
    np.fill_diagonal(Theta, 1.0)

    # Repeated projection until the iterate stops changing.
    for _ in range(max_iter):
        Theta_new = proj_precision_mat(Theta, nz_idx, lower_weight = lower_weight, lower_eig = lower_eig)
        delta = np.max(np.abs(Theta - Theta_new))
        Theta = Theta_new
        if delta < tol:
            break

    return np.real(Theta)

def generate_data(p, n, Sigma, std=True, random_state=2025):
    """
    Draw samples from a zero-mean multivariate normal distribution.

    Parameters
    ----------
    p : int
        Length of the zero mean vector; must match the dimension of
        ``Sigma``.
    n : int
        Number of samples, passed as ``size`` to ``multivariate_normal``.
    Sigma : ndarray
        Covariance matrix of the distribution.
    std : bool
        If True, centre each column and divide it by its standard
        deviation (``np.std`` default, i.e. ``ddof=0``). Columns whose
        standard deviation is exactly zero are only centred, so they come
        out as zeros rather than NaN.
    random_state : int or None
        Seed for the local ``np.random.RandomState``.

    Returns
    -------
    ndarray
        The drawn samples, standardized column-wise when ``std`` is True.
    """
    rs = np.random.RandomState(random_state)
    X = rs.multivariate_normal(mean=np.zeros(p), cov=Sigma, size=n)  # draw samples
    if std:
        centered = X - np.mean(X, axis=0)
        scale = np.std(X, axis=0)
        # Guard against 0/0 for constant columns (e.g. a single sample):
        # dividing by 1 leaves the centred zeros untouched.
        scale = np.where(scale == 0, 1.0, scale)
        X = centered / scale
    return X