R"""


python -i local_scripts/misc1/pynmfk_dev001.py

mpirun -n 4 python local_scripts/misc1/pynmfk_dev001.py
mpirun --mca mpi_common_cuda_verbose 100 -n 4 python local_scripts/misc1/pynmfk_dev001.py

/fruitbasket/users/m/openmpi/bin/mpirun --mca mpi_common_cuda_verbose 100 -n 4 python local_scripts/misc1/pynmfk_dev001.py
/fruitbasket/users/m/openmpi/bin/mpirun --mca opal_cuda_verbose 10 -n 4 python local_scripts/misc1/pynmfk_dev001.py


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/misc1/pynmfk_dev001.py

"""

# import sys
# import pyDNMFk.config as config
# config.init(0)
# if True:
#     from pyDNMFk.pyDNMFk import *
#     from pyDNMFk.data_io import *
#     from pyDNMFk.dist_comm import *
#     from scipy.io import loadmat
#     from mpi4py import MPI

# comm = MPI.COMM_WORLD
# args = parse()


# '''parameters initialization block'''


# # Data Read here
# args.fpath = '/fruitbasket/users/m/other_code/pyDNMFk/data/'
# args.fname = 'wtsi'
# args.ftype = 'mat'
# args.precision = np.float32

# # Distributed Comm config block
# # p_r, p_c = 96, 21
# p_r, p_c = 4, 1

# # NMF config block
# args.norm = 'kl'
# args.method = 'mu'
# args.init = 'nnsvd'
# args.itr = 5000
# args.verbose = True

# # #Cluster config block
# # args.start_k = 1
# # args.end_k = 1
# # args.sill_thr = 0.9

# args.k = 256

# # Data Write
# args.results_path = '/fruitbasket/users/m/tmp/'


# '''Parameters prep block'''


# comms = MPI_comm(comm, p_r, p_c)
# comm1 = comms.comm
# rank = comm.rank
# size = comm.size
# args.size, args.rank, args.comm, args.p_r, args.p_c = size, rank, comms, p_r, p_c
# args.row_comm, args.col_comm, args.comm1 = comms.cart_1d_row(), comms.cart_1d_column(), comm1
# A_ij = data_read(args).read().astype(args.precision)

# # A_ij = np.random.uniform(size=[10_000, 20_000]).astype(args.precision)
# A_ij = np.random.uniform(size=[1000, 2000]).astype(args.precision)

# pynmf = PyNMF(A_ij, factors=None, params=args)
# nopt = pynmf.fit()
# # print('Estimated k with NMFk is ', nopt)
import time
import torch
from torchnmf.nmf import NMF
import numpy as np

# Benchmark: time torchnmf's NMF.fit on a random dense matrix.
# Results (V, model, device, elapsed) are deliberately left as module-level
# globals so they can be inspected afterwards when run via `python -i`.

n_components = 512  # NMF rank passed to NMF(..., rank=...)
n_rows, n_cols = 10_000, 20_000  # shape of the random input matrix V

# Fall back to CPU instead of crashing when no CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Draw U[0, 1) float32 values directly on the target device.
# np.random.uniform would first build a float64 host array (~1.6 GB at this
# size), then a float32 copy, then need a host->device transfer.
V = torch.rand(n_rows, n_cols, dtype=torch.float32, device=device)
model = NMF(V.shape, rank=n_components).to(device)

# CUDA kernel launches are asynchronous: synchronize so that setup work is not
# billed to fit(), and so the clock stops only after fit()'s GPU work is done.
if device.type == "cuda":
    torch.cuda.synchronize()
start = time.perf_counter()
model.fit(V)
if device.type == "cuda":
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start

print(
    f"NMF fit (rank={n_components}, V={tuple(V.shape)}, device={device}): "
    f"{elapsed:.3f} s"
)
# To move fitted factors back to NumPy, call .detach().cpu().numpy() on the
# model's factor tensors (check torchnmf's NMF for the attribute names).
