import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import scanpy as sc
import scvi
import os
import argparse
import tempfile
import pickle
import numpy as np


def parse_args():
    """Parse command-line options for LinearSCVI training and embedding export.

    Returns:
        argparse.Namespace with attributes ``input_adata``, ``pair_adata``,
        ``output_data`` and ``output_dir``.

    Exits via ``SystemExit`` (argparse usage error) when a required option is
    missing.
    """
    parser = argparse.ArgumentParser(description='scvi model training')
    # Both AnnData paths are read unconditionally by the script, so require
    # them up front instead of letting ``None`` reach ``sc.read`` (for the
    # pair file, only after the whole training run has finished).
    parser.add_argument('--input_adata', required=True, type=str,
                        help='Input file path')
    parser.add_argument('--pair_adata', required=True, type=str,
                        help='Path to the paired/query AnnData file')
    # NOTE(review): not read anywhere in the visible script; kept so existing
    # command lines that pass it keep working.
    parser.add_argument('--output_data', default=None)
    # Default to the directory the embeddings are written to, so
    # ``os.makedirs(args.output_dir)`` never receives ``None``.
    parser.add_argument('--output_dir', type=str,
                        default='linearscvi-overlap-RNA-reference',
                        help='Directory to write outputs into')
    return parser.parse_args()

def _preprocess(data):
    """Keep raw counts and log-normalize an AnnData object in place.

    Copies ``data.X`` into ``data.layers['counts']`` (the layer LinearSCVI is
    set up on below), applies ``sc.pp.normalize_total`` then ``sc.pp.log1p``
    to ``X``, and assigns the result to ``data.raw``. Returns ``data``.
    """
    # Copy before normalizing so the model keeps seeing untransformed counts.
    data.layers['counts'] = data.X.copy()
    sc.pp.normalize_total(data)
    sc.pp.log1p(data)
    data.raw = data
    return data


# Train LinearSCVI on --input_adata, then save latent representations of
# both the training data and --pair_adata as .npy files in the output dir.
args = parse_args()
# Outputs used to be written to this directory regardless of --output_dir;
# keep it as the fallback so os.makedirs never receives None.
output_dir = args.output_dir or "linearscvi-overlap-RNA-reference"
os.makedirs(output_dir, exist_ok=True)

scvi.settings.seed = 0
sc.set_figure_params(figsize=(4, 4))

# Load and preprocess both inputs before training so a bad --pair_adata path
# fails immediately rather than after model.train() has finished.
adata = _preprocess(sc.read(args.input_adata))
pair = _preprocess(sc.read(args.pair_adata))

scvi.model.LinearSCVI.setup_anndata(adata, layer="counts")
model = scvi.model.LinearSCVI(adata)
model.train()

# NOTE(review): `pair` is never passed to setup_anndata; this relies on
# scvi-tools transferring the reference setup when get_latent_representation
# sees an unregistered AnnData, which presumably requires matching var_names
# and a 'counts' layer (provided by _preprocess) — confirm.
latent_outputs = {
    "10x-Multiome-Pbmc10k-small-RNA-overlap-linearscvi.npy": adata,
    "10x-Multiome-Pbmc10k-small-ACTIVE-overlap-linearscvi.npy": pair,
}
for filename, data in latent_outputs.items():
    np.save(
        os.path.join(output_dir, filename),
        model.get_latent_representation(data),
    )
