import argparse
import os
import time
import warnings

import numpy as np
import scanpy as sc
from scimilarity import CellAnnotation, align_dataset
from scimilarity.utils import lognorm_counts
# Suppress every warning for the remainder of the run.
warnings.filterwarnings('ignore')

# Command-line interface: three mandatory string options.
parser = argparse.ArgumentParser(description='Process single-cell data.')
for option, description in (
    ('--input_adata', 'Path to input .h5ad file'),
    ('--output_adata', 'Path to output .npy file'),
    ('--model_path', 'Path to scimilarity model'),
):
    parser.add_argument(option, type=str, required=True, help=description)
args = parser.parse_args()

# Read the input file given on the command line.
adams = sc.read(args.input_adata)
# Build the model directory with os.path.join so --model_path works with or
# without a trailing separator; plain string concatenation produced a wrong
# path (e.g. '/datamodels/...') whenever the trailing slash was omitted.
annotation_path = os.path.join(args.model_path, 'models/annotation_model_v1')
ca = CellAnnotation(model_path=annotation_path)
# Overwrite the obs column labels with a fixed list of 26 names.
# NOTE(review): the assignment is positional, so the input is assumed to carry
# exactly these columns in this order — confirm against the .h5ad producer.
obs_column_names = [
    'soma_joinid',
    'dataset_id',
    'assay',
    'assay_ontology_term_id',
    'celltype_raw',
    'cell_type_ontology_term_id',
    'development_stage',
    'development_stage_ontology_term_id',
    'disease',
    'disease_ontology_term_id',
    'donor_id',
    'is_primary_data',
    'self_reported_ethnicity',
    'self_reported_ethnicity_ontology_term_id',
    'sex',
    'sex_ontology_term_id',
    'suspension_type',
    'tissue',
    'tissue_ontology_term_id',
    'tissue_general',
    'tissue_general_ontology_term_id',
    'raw_sum',
    'nnz',
    'raw_mean_nnz',
    'raw_variance_nnz',
    'n_measured_vars',
]
adams.obs.columns = obs_column_names

# Re-key the genes by the values of the 'feature_name' column instead of the
# existing var index. Taking the whole column at once avoids one label lookup
# per gene and stays correct even if the old index contains duplicate labels.
adams.var.index = adams.var['feature_name'].tolist()

# Align the dataset to the gene order held by the loaded model.
adams = align_dataset(adams, ca.gene_order)
# Expose the current matrix as the 'counts' layer before normalising.
# NOTE(review): presumably lognorm_counts reads raw counts from this layer —
# confirm against the scimilarity documentation.
adams.layers['counts'] = adams.X
adams = lognorm_counts(adams)

print('Getting embeddings')
# perf_counter is a monotonic, high-resolution clock, so the measured interval
# cannot be distorted by system clock adjustments (unlike time.time()).
start_time = time.perf_counter()
X_scimilarity = np.array(ca.get_embeddings(adams.X))
execution_time = time.perf_counter() - start_time
print(f"It took {execution_time} seconds to get embeddings to {adams.shape[0]} cells")

print('Saving the embeddings in npy format')
# np.save appends '.npy' to the path if it does not already end with it.
np.save(args.output_adata, X_scimilarity)

