import scanpy as sc
import scvi
import os
import argparse
import tempfile
import pickle
import numpy as np

def parse_args(argv=None):
    """Parse command-line options for scVI model training.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) reads the real
            command line, so existing ``parse_args()`` calls are unchanged.

    Returns:
        argparse.Namespace with attribute ``input_adata`` (str): path of the
        AnnData file to train on.

    Raises:
        SystemExit: if ``--input_adata`` is missing or the arguments are
            otherwise invalid (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser(description='scvi model training')
    # Required: the script cannot proceed without an input file, so let
    # argparse report a clear usage error instead of failing later in sc.read.
    parser.add_argument('--input_adata', required=True, type=str, help='Input file path')
    return parser.parse_args(argv)

# --- Script entry: train an scVI model on the input AnnData and save it. ---
args = parse_args()

# Fix the global random seed used by scvi-tools for reproducibility.
scvi.settings.seed = 0
# Default figure size for any scanpy plots made in this process.
sc.set_figure_params(figsize=(4, 4))

# Where the trained model (and, via save_anndata=True, the AnnData) is written.
model_dir = "downstream_analysis/scvi_model/"

adata_path = args.input_adata
if adata_path is None:
    # Report a clear error instead of letting sc.read(None) fail obscurely.
    raise SystemExit("error: --input_adata is required")
if os.path.exists(model_dir):
    # model.save() below uses scvi-tools' default overwrite=False and raises on
    # an existing directory -- but only after training has finished. Fail
    # before the expensive work instead.
    raise SystemExit(
        f"error: output directory {model_dir!r} already exists; "
        "remove it or choose another location"
    )

adata = sc.read(adata_path)

# Snapshot the current X into a 'counts' layer; scVI is set up to read it.
# NOTE(review): assumes adata.X holds raw (un-normalized) counts, as the
# negative-binomial likelihood ("nb") models counts -- confirm upstream.
adata.layers['counts'] = adata.X.copy()
# adata.obs['donor_id'] is used as the batch covariate.
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="donor_id")
# 2 hidden layers, 30-dimensional latent space, negative-binomial likelihood.
model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb")

# Train with scvi-tools' default training settings.
model.train()
model.save(model_dir, save_anndata=True)
