import scanpy as sc
import scvi
import os
import argparse
import tempfile
import pickle
import numpy as np


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``input_adata``, ``output_dir`` and
        ``model_save_dir`` attributes. Each is a string, or None when the
        corresponding flag was not supplied.
    """
    parser = argparse.ArgumentParser(
        description='Cell Embedding Retrieval from data',
    )
    # (flag, help text): every option is an optional string defaulting to None.
    option_table = (
        ('--input_adata', 'Input file path'),
        ('--output_dir', 'Output file directory'),
        ('--model_save_dir', 'Saved model directory'),
    )
    for flag, help_text in option_table:
        parser.add_argument(flag, default=None, type=str, help=help_text)
    return parser.parse_args()


args = parse_args()

# All three paths are mandatory. Fail with a readable message instead of the
# opaque TypeError that sc.read(None) or path building would otherwise raise.
missing_flags = [
    flag
    for flag, value in (
        ("--input_adata", args.input_adata),
        ("--output_dir", args.output_dir),
        ("--model_save_dir", args.model_save_dir),
    )
    if value is None
]
if missing_flags:
    raise SystemExit(
        "missing required argument(s): " + ", ".join(missing_flags)
    )

adata_path = args.input_adata
adata = sc.read(adata_path)

# Copy the raw matrix into a 'counts' layer. Presumably the saved model was set
# up with layer='counts'; TODO confirm against the training script.
adata.layers['counts'] = adata.X.copy()
# Pass adata explicitly so the model also loads when it was saved without its
# AnnData (scvi-tools' default). Otherwise SCVI.load(dir) looks for a saved copy.
model = scvi.model.SCVI.load(args.model_save_dir, adata=adata)
npy_output = model.get_latent_representation(adata)

# os.path.join instead of "+": plain concatenation wrote "<dir>embeds_scvi.npy"
# beside the directory whenever the path lacked a trailing slash.
os.makedirs(args.output_dir, exist_ok=True)
output_path = os.path.join(args.output_dir, "embeds_scvi.npy")
np.save(output_path, npy_output)
# Report only after the file has actually been written.
print("embedding was saved in {}".format(output_path))
