import time
import warnings
import anndata
import Cell_BLAST as cb
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import argparse
import os
# Install a blanket "ignore" filter: every warning category is suppressed for
# the rest of the run, including deprecation notices from the libraries used.
warnings.simplefilter("ignore")

# Pin the seed stored in Cell_BLAST's global config.
# NOTE(review): presumably this makes training repeatable -- confirm against
# the Cell_BLAST documentation.
cb.config.RANDOM_SEED = 0

def parse_args(argv=None):
    """Parse the command-line options of the Cell BLAST training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        argparse.Namespace: has the attributes ``input_adata``,
        ``query_index``, ``target_index``, ``output_data`` and
        ``output_dir``. Every option is optional and is ``None`` when it
        is not supplied.
    """
    parser = argparse.ArgumentParser(description='cellblast model training')
    parser.add_argument('--input_adata', default=None, type=str, help='Input file path')
    parser.add_argument('--query_index', default=None, help='query index')
    parser.add_argument('--target_index', default=None, help='target index')
    parser.add_argument('--output_data', default=None,
                        help='Output path for the latent embedding array')
    parser.add_argument('--output_dir', default=None,
                        help='Output directory (created if missing)')

    return parser.parse_args(argv)

args = parse_args()
# --output_dir is optional (its default is None) and os.makedirs(None) raises
# TypeError, so only create the directory when one was actually supplied.
if args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)

adata = anndata.read_h5ad(args.input_adata)
# Re-label the variables with the values of the `feature_name` column of
# adata.var; these labels may contain duplicates.
adata.var_names = list(adata.var.feature_name)

# Collapse columns that share a feature name by summing them.
# `DataFrame.groupby(..., axis=1)` is deprecated since pandas 2.1 and removed
# in pandas 3, so group the transposed frame by its index labels instead and
# transpose back; the result (rows = cells, columns = sorted unique feature
# names) is the same.
temp_df = adata.to_df()
temp_df = temp_df.T.groupby(level=0).sum().T

# Rebuild the AnnData from the collapsed matrix. Only the matrix and its
# row/column labels are passed on; the obs/var annotations of the input file
# are not copied over.
adata = sc.AnnData(temp_df)
adata.obs_names = temp_df.index
adata.var_names = temp_df.columns

# Choose the cells the model is trained on:
#   --query_index  -> every cell whose position is NOT listed in the .npy file
#   --target_index -> exactly the positions listed in the text file
#   neither        -> all cells
if args.query_index is not None:
    query_indexes = np.load(args.query_index)
    all_index = np.arange(len(adata))
    # Vectorised membership test. The previous per-element `i not in array`
    # check rescanned the whole query array for every cell, and produced a
    # float-typed empty index when no training cell was left.
    train_index = all_index[~np.isin(all_index, query_indexes)]
    training_adata = adata[train_index]
elif args.target_index is not None:
    # ndmin=1 keeps a one-line index file as a 1-D array instead of the 0-d
    # array np.loadtxt would otherwise return.
    target_indexs = np.loadtxt(args.target_index, dtype=int, ndmin=1)
    training_adata = adata[target_indexs]
else:
    training_adata = adata

# Hyper-parameters forwarded unchanged to the DIRECTi fit call.
directi_params = dict(
    latent_dim=10,
    cat_dim=20,
    epoch=20,
    learning_rate=0.001,
    batch_size=20,
)
# Fit on the selected training subset only.
model = cb.directi.fit_DIRECTi(training_adata, **directi_params)

# Run inference on the full dataset (not just the training subset), keep the
# result on the AnnData under "X_latent", and write that same array to disk.
# np.save appends ".npy" to the path when it has no such extension.
adata.obsm["X_latent"] = model.inference(adata)
embeddings = adata.obsm["X_latent"]
np.save(args.output_data, embeddings)

