import scib
import scanpy as sc

def integration_baseline(adata, batch_key, celltype_key, model='scvi', hvg=None, seed=10, **kwargs):
    """Integrate ``adata`` with one scib baseline method and score the result.

    Args:
        adata: AnnData with raw counts in ``adata.layers['counts']`` (read by
            the HVG step and by the count-based models). It is modified in
            place before being handed to scib: HVG subsetting, and
            ``adata.X`` is overwritten with the counts for count-based models.
        batch_key: ``adata.obs`` column holding batch labels.
        celltype_key: ``adata.obs`` column holding cell-type labels; passed to
            the scANVI variants and to the evaluation.
        model: one of 'scvi', 'scanvi', 'scanvi_test', 'trvae', 'bbknn',
            'combat', 'desc', 'harmony', 'mnn', 'saucie', 'scanorama'.
            For 'scanvi_test', only cells with ``obs['split'] == 'test'``
            are scored.
        hvg: when not None, first subset to highly variable genes
            (``flavor='seurat_v3'`` on the counts layer). An integer is used
            as the number of genes to keep. Any other value (e.g. ``True``)
            keeps scanpy's default ``n_top_genes``.
        seed: random seed forwarded to the scVI / scANVI wrappers only.
        **kwargs: forwarded to ``scib.ig.desc`` only.

    Returns:
        The score dict produced by ``integration_eval``.

    Raises:
        NotImplementedError: if ``model`` is not one of the names above. This
            is checked before ``adata`` is modified.
    """
    import numbers

    # Runners are lambdas so no scib attribute is looked up until a valid
    # model has been selected.
    runners = {
        'scvi': lambda ad: scib.ig.scvi(ad, batch_key, seed=seed),
        'scanvi': lambda ad: scib.ig.scanvi(ad, batch_key, celltype_key, seed=seed),
        'scanvi_test': lambda ad: scib.ig.scanvi_test(ad, batch_key, celltype_key, seed=seed),
        'trvae': lambda ad: scib.ig.trvaep(ad, batch_key),
        'bbknn': lambda ad: scib.ig.bbknn(ad, batch_key),
        'combat': lambda ad: scib.ig.combat(ad, batch_key),
        'desc': lambda ad: scib.ig.desc(ad, batch_key, use_gpu=True, **kwargs),
        'harmony': lambda ad: scib.ig.harmony(ad, batch_key),
        'mnn': lambda ad: scib.ig.mnn(ad, batch_key),
        'saucie': lambda ad: scib.ig.saucie(ad, batch_key),
        'scanorama': lambda ad: scib.ig.scanorama(ad, batch_key),
    }
    # Models whose wrappers are given the raw counts in ``adata.X``.
    counts_models = {'scvi', 'scanvi', 'scanvi_test', 'trvae'}

    # Fail fast: an unknown model used to be rejected only after the in-place
    # HVG subsetting had already run. The isinstance guard also keeps
    # unhashable inputs on the NotImplementedError path.
    runner = runners.get(model) if isinstance(model, str) else None
    if runner is None:
        raise NotImplementedError(f"Unsupported model {model}")

    if hvg is not None:
        hvg_kwargs = {}
        if isinstance(hvg, numbers.Integral) and not isinstance(hvg, bool):
            # The requested gene count was previously ignored.
            hvg_kwargs['n_top_genes'] = int(hvg)
        sc.pp.highly_variable_genes(adata, layer='counts', flavor='seurat_v3', subset=True, **hvg_kwargs)

    if model in counts_models:
        adata.X = adata.layers['counts'].copy()
    adata = runner(adata)

    if model == 'scanvi_test':
        # Score held-out cells only. ``.copy()`` materialises the view, since
        # integration_eval writes a neighbour graph into the object.
        adata = adata[adata.obs['split'] == 'test'].copy()

    return integration_eval(adata, batch_key, celltype_key)

def integration_eval(adata, batch_key, celltype_key, use_rep='X_emb'):
    """Score an integrated AnnData with NMI, ARI and graph connectivity.

    Builds a neighbour graph on ``use_rep``, then calls
    ``scib.metrics.metrics`` with only the NMI, ARI and graph-connectivity
    metrics enabled.

    Args:
        adata: AnnData holding the integrated result. ``sc.pp.neighbors``
            writes its graph into this object.
        batch_key: ``adata.obs`` column holding batch labels.
        celltype_key: ``adata.obs`` column holding cell-type labels.
        use_rep: ``adata.obsm`` key of the embedding to use. When ``None``,
            ``sc.pp.neighbors`` picks its own default representation and the
            metrics are computed with ``embed='X_pca'``.

    Returns:
        dict with keys ``'nmi'``, ``'ari'`` and ``'graph_conn'``. Each value
        is the first entry of the matching row of the metrics table.
    """
    # NOTE: a GPU variant (method='rapids') was experimented with here.
    # The raw ``use_rep`` (possibly None) goes to neighbors; the None ->
    # 'X_pca' mapping is applied only afterwards, for the metric call.
    sc.pp.neighbors(adata, use_rep=use_rep)
    embed = use_rep if use_rep is not None else 'X_pca'

    metrics_df = scib.metrics.metrics(
        adata,
        adata,  # the same object is passed for both AnnData arguments
        batch_key,
        celltype_key,
        embed=embed,
        cluster_key="cluster",
        organism='human',
        ari_=True,
        nmi_=True,
        graph_conn_=True,
    )

    row_for_score = {
        'nmi': 'NMI_cluster/label',
        'ari': 'ARI_cluster/label',
        'graph_conn': 'graph_conn',
    }
    return {name: metrics_df.loc[row].values[0] for name, row in row_for_score.items()}