import random
import numpy as np





def simulate_missing_data(X_true, perc,eps=1e-5, Nan=False,return_normalized=False, seed=50):
    """Randomly hide a percentage of the entries of a fully observed matrix.

    ``perc`` percent of the ``d1 * d2`` entries (rounded to the nearest whole
    count) are drawn without replacement and overwritten in a copy of
    ``X_true``. Entries are addressed through column-major linear indices,
    i.e. ``row = index % d1`` and ``col = index // d1``.

    Parameters
    ----------
    X_true
        Two-dimensional NumPy array holding the complete data. It is never
        modified; the very same object is handed back in the result.
    perc
        Percentage of entries to remove. ``random.sample`` raises
        ``ValueError`` if the resulting count is negative or exceeds the
        number of entries.
    eps
        Constant added to every column standard deviation before dividing,
        so constant columns do not divide by zero.
    Nan
        If true, removed entries are set to ``np.nan``; otherwise to ``0``.
        A non-floating ``X_true`` is promoted to ``float64`` in ``X_avail``
        when NaN is requested, because integer arrays cannot store NaN.
    return_normalized
        If true, additionally return column-standardized copies of the
        available and of the true matrix together with the statistics used.
    seed
        Passed to ``random.seed``. NOTE: this reseeds the global ``random``
        module generator as a side effect.

    Returns
    -------
    tuple
        ``(X_avail, X_true, OmegaRow, OmegaCol, rmIndMatrixRow,
        rmIndMatrixCol)`` when ``return_normalized`` is false, otherwise
        ``(X_avail, X_true, X_norm, X_true_norm, OmegaRow, OmegaCol,
        mu_O_list, std_O_list, rmIndMatrixRow, rmIndMatrixCol)``.

        ``OmegaRow``/``OmegaCol`` are the coordinates of the entries that
        stay available, ``rmIndMatrixRow``/``rmIndMatrixCol`` those of the
        removed entries. ``mu_O_list``/``std_O_list`` hold, per column, the
        mean and standard deviation of the available entries. A column
        without any available entry yields NaN statistics (NumPy emits a
        ``RuntimeWarning``). Removed entries of ``X_norm`` carry the same
        fill value as in ``X_avail``.
    """
    d1, d2 = X_true.shape
    n_entries = d1 * d2

    # Draw the removed entries as column-major linear indices.
    random.seed(seed)
    rmN = int(np.round(n_entries * perc / 100, 0))  # number of missing data
    rmInd = random.sample(range(n_entries), rmN)

    # An explicit integer dtype is required: np.array([]) defaults to float64,
    # and float arrays (even empty ones) are rejected as fancy indices.
    rm_linear = np.array(rmInd, dtype=int)
    rmIndMatrixRow = rm_linear % d1
    rmIndMatrixCol = rm_linear // d1

    # Complementary set of entries that remain available.
    Omega = np.array(list(set(range(n_entries)) - set(rmInd)), dtype=int)
    OmegaRow = Omega % d1
    OmegaCol = Omega // d1

    # Floating dtype used wherever NaN or standardized values must be stored;
    # an already inexact input dtype is kept as is.
    if np.issubdtype(X_true.dtype, np.inexact):
        float_dtype = X_true.dtype
    else:
        float_dtype = np.float64

    fill_value = np.nan if Nan else 0

    # X_avail is the available matrix, X_true is the full true matrix.
    X_avail = X_true.copy()
    if Nan:
        X_avail = X_avail.astype(float_dtype, copy=False)
    X_avail[rmIndMatrixRow, rmIndMatrixCol] = fill_value

    if not return_normalized:
        return X_avail, X_true, OmegaRow, OmegaCol, rmIndMatrixRow, rmIndMatrixCol

    # Work on floating copies so standardized values are not truncated when
    # the input has an integer dtype.
    X_norm = X_avail.copy().astype(float_dtype, copy=False)
    X_true_norm = X_true.copy().astype(float_dtype, copy=False)
    mu_O_list = []   # column means of the available entries
    std_O_list = []  # column standard deviations of the available entries

    for i_col in range(X_norm.shape[-1]):
        # Row indices of the available entries in this column.
        qq = OmegaRow[OmegaCol == i_col]
        available = X_norm[qq, i_col]
        col_mean = np.mean(available)
        col_std = np.std(available)
        X_norm[:, i_col] = (X_norm[:, i_col] - col_mean) / (col_std + eps)
        X_true_norm[:, i_col] = (X_true_norm[:, i_col] - col_mean) / (col_std + eps)
        mu_O_list.append(col_mean)
        std_O_list.append(col_std)

    # Standardizing whole columns also shifted the placeholders of the removed
    # entries; restore them so missing data stays 0 (or NaN) in X_norm.
    X_norm[rmIndMatrixRow, rmIndMatrixCol] = fill_value

    return (X_avail, X_true, X_norm, X_true_norm, OmegaRow, OmegaCol,
            mu_O_list, std_O_list, rmIndMatrixRow, rmIndMatrixCol)







def simulate_missing_data_movielens(movie_rec, perc,eps=1e-5, Nan=False,return_normalized=False, seed=50):
    """Build a matrix from (col id, row id, value) records and hide a share of them.

    Each row of ``movie_rec`` places ``movie_rec[i, 2]`` at matrix position
    ``(movie_rec[i, 1] - 1, movie_rec[i, 0] - 1)``; positions that never
    occur stay ``0``. ``perc`` percent of the *records* (rounded to the
    nearest whole count) are then drawn without replacement and overwritten
    in a copy of that matrix.

    Parameters
    ----------
    movie_rec
        Two-dimensional NumPy array with at least three columns: column 0 is
        the 1-based matrix column id, column 1 the 1-based matrix row id and
        column 2 the value (presumably user id, movie id and rating in the
        MovieLens layout - confirm against the caller).
    perc
        Percentage of records to remove. ``random.sample`` raises
        ``ValueError`` if the resulting count is negative or exceeds the
        number of records.
    eps
        Constant added to every column standard deviation before dividing.
    Nan
        If true, removed entries are set to ``np.nan``; otherwise to ``0``.
    return_normalized
        If true, additionally return column-standardized copies of the
        available and of the true matrix together with the statistics used.
    seed
        Passed to ``random.seed``. NOTE: this reseeds the global ``random``
        module generator as a side effect.

    Returns
    -------
    tuple
        ``(X_avail, X_true, OmegaRow, OmegaCol, rmIndMatrixRow,
        rmIndMatrixCol)`` when ``return_normalized`` is false, otherwise
        ``(X_avail, X_true, X_norm, X_true_norm, OmegaRow, OmegaCol,
        mu_O_list, std_O_list, rmIndMatrixRow, rmIndMatrixCol)``.

        ``OmegaRow``/``OmegaCol`` are the coordinates of the records that
        stay available, ``rmIndMatrixRow``/``rmIndMatrixCol`` those of the
        removed records. Column statistics use the available records only; a
        column without any yields NaN statistics (NumPy emits a
        ``RuntimeWarning``). Positions absent from ``movie_rec`` belong to
        neither index set: they are standardized with their column but are
        not reset afterwards, unlike the removed records.

    Raises
    ------
    ValueError
        If ``movie_rec`` is empty or contains an id smaller than 1.
    """
    d1, d2 = int(movie_rec[:,1].max()), int(movie_rec[:,0].max())  # extract dimensions

    # Zero-based coordinates of every record.
    fullIndRows = movie_rec[:,1].astype(int) - 1
    fullIndCols = movie_rec[:,0].astype(int) - 1
    # Negative indices would silently wrap around in fancy indexing.
    if fullIndRows.min() < 0 or fullIndCols.min() < 0:
        raise ValueError("movie_rec ids must be 1-based (found an id smaller than 1)")
    n_records = len(fullIndRows)

    # Form the true matrix.
    X_true = np.zeros((d1, d2))
    X_true[fullIndRows, fullIndCols] = movie_rec[:,2]

    # Randomly choose which records go missing.
    rmN = int(np.round(n_records * (perc) / 100))  # number of missing data
    random.seed(seed)
    removed_records = random.sample(range(n_records), rmN)

    # An explicit integer dtype is required: np.array([]) defaults to float64,
    # and float arrays (even empty ones) are rejected as fancy indices.
    removed_idx = np.array(removed_records, dtype=int)
    rmIndMatrixRow = fullIndRows[removed_idx]
    rmIndMatrixCol = fullIndCols[removed_idx]

    # The complementary remaining set of records.
    kept_idx = np.array(list(set(range(n_records)) - set(removed_records)), dtype=int)
    OmegaRow = fullIndRows[kept_idx]
    OmegaCol = fullIndCols[kept_idx]

    fill_value = np.nan if Nan else 0

    # X_avail is the available matrix, X_true is the full true matrix.
    X_avail = X_true.copy()
    X_avail[rmIndMatrixRow, rmIndMatrixCol] = fill_value

    if not return_normalized:
        return X_avail, X_true, OmegaRow, OmegaCol, rmIndMatrixRow, rmIndMatrixCol

    X_norm = X_avail.copy()
    X_true_norm = X_true.copy()
    mu_O_list = []   # column means of the available entries
    std_O_list = []  # column standard deviations of the available entries

    for i_col in range(X_norm.shape[-1]):
        # Row indices of the available entries in this column.
        qq = OmegaRow[OmegaCol == i_col]
        available = X_norm[qq, i_col]
        col_mean = np.mean(available)
        col_std = np.std(available)
        X_norm[:, i_col] = (X_norm[:, i_col] - col_mean) / (col_std + eps)
        X_true_norm[:, i_col] = (X_true_norm[:, i_col] - col_mean) / (col_std + eps)
        mu_O_list.append(col_mean)
        std_O_list.append(col_std)

    # Standardizing whole columns also shifted the placeholders of the removed
    # entries; restore them so missing data stays 0 (or NaN) in X_norm.
    X_norm[rmIndMatrixRow, rmIndMatrixCol] = fill_value

    return (X_avail, X_true, X_norm, X_true_norm, OmegaRow, OmegaCol,
            mu_O_list, std_O_list, rmIndMatrixRow, rmIndMatrixCol)