import numpy as np
import matplotlib.pyplot as plt
import torch

def getSparse(sparse_matrix, sparse=False):
    """Convert a SciPy sparse matrix into a float32 PyTorch tensor.

    Args:
        sparse_matrix: A SciPy sparse matrix/array in any format that
            provides ``tocoo()`` and ``toarray()`` (CSR, CSC, COO, ...).
        sparse: If True, return a sparse COO tensor built with
            ``torch.sparse_coo_tensor``; otherwise return a dense tensor.

    Returns:
        A float32 ``torch.Tensor`` with the same shape as ``sparse_matrix``.
        The sparse result is not coalesced; call ``.coalesce()`` if a
        canonical (sorted, duplicate-free) layout is required.
    """
    if sparse:
        # COO stores explicit (row, col) coordinates for every stored value,
        # which is exactly what torch.sparse_coo_tensor expects. The CSR
        # ``indices``/``indptr`` arrays are NOT coordinates (and generally
        # differ in length), so they must not be stacked directly.
        coo = sparse_matrix.tocoo()
        # SciPy uses int32 index arrays; torch sparse indices must be int64.
        indices = torch.from_numpy(
            np.vstack((coo.row, coo.col)).astype(np.int64)
        )
        values = torch.tensor(coo.data, dtype=torch.float32)
        # torch.sparse.FloatTensor is a deprecated legacy constructor;
        # sparse_coo_tensor is the supported API.
        return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

    # Dense path: materialise the full matrix, then cast to float32.
    dense_array = sparse_matrix.toarray()
    return torch.from_numpy(dense_array).float()