# TEST FILE FOR BASIC TISSUE FUNCTIONALITIES


# import packages

import tissue.main, tissue.downstream

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import anndata as ad
import os

import warnings
warnings.filterwarnings("ignore")

#################################################################################################################
# Smoke test of the basic TISSUE pipeline, run stage by stage on the example data in tests/data/.
# Each stage catches Exception (a bare "except:" would also intercept KeyboardInterrupt/SystemExit)
# and re-raises a stage-specific message with the underlying error explicitly chained via "from".
print ("Testing TISSUE data loading...")
try:
    adata, RNAseq_adata = tissue.main.load_paired_datasets("tests/data/Spatial_count.txt",
                                                           "tests/data/Locations.txt",
                                                           "tests/data/scRNA_count.txt")
except Exception as err:
    raise Exception ("Failed data loading from tests/data/ with tissue.main.load_paired_datasets()") from err

#################################################################################################################
print ("Testing TISSUE preprocessing...")
# lower-case gene names in both datasets so that they can be matched by name
adata.var_names = [x.lower() for x in adata.var_names]
RNAseq_adata.var_names = [x.lower() for x in RNAseq_adata.var_names]
try:
    tissue.main.preprocess_data(RNAseq_adata, standardize=False, normalize=True)
except Exception as err:
    raise Exception ("Failed TISSUE preprocessing. Make sure all dependencies are installed.") from err
# restrict the spatial data to genes shared with the RNAseq data
gene_names = np.intersect1d(adata.var_names, RNAseq_adata.var_names)
adata = adata[:, gene_names].copy()
# hold out one gene: keep a copy of its measured expression, then drop it from the spatial data
target_gene = "plp1"
target_expn = adata[:, target_gene].X.copy()
adata = adata[:, [gene for gene in gene_names if gene != target_gene]].copy()

#################################################################################################################
print("Testing TISSUE spatial gene expression prediction...")
try:
    tissue.main.predict_gene_expression (adata, RNAseq_adata, [target_gene],
                                         method="spage", n_folds=3, n_pv=10)
except Exception as err:
    raise Exception("TISSUE prediction failed for SpaGE at tissue.main.predict_gene_expression()") from err

#################################################################################################################
print("Testing TISSUE calibration...")
try:
    tissue.main.build_spatial_graph(adata, method="fixed_radius", n_neighbors=15)
except Exception as err:
    raise Exception ("Failed TISSUE spatial graph building at tissue.main.build_spatial_graph()") from err
try:
    tissue.main.conformalize_spatial_uncertainty(adata, "spage_predicted_expression", calib_genes=adata.var_names,
                                                 grouping_method="kmeans_gene_cell", k=4, k2=2)
except Exception as err:
    raise Exception ("Failed TISSUE cell-centric variability and calibration scores processing at tissue.main.conformalize_spatial_uncertainty()") from err
try:
    tissue.main.conformalize_prediction_interval (adata, "spage_predicted_expression", calib_genes=adata.var_names,
                                                  alpha_level=0.23, compute_wasserstein=True)
except Exception as err:
    raise Exception ("Failed TISSUE prediction interval calibration at tissue.main.conformalize_prediction_interval()") from err

#################################################################################################################
print ("Testing TISSUE multiple imputation t-test...")
# artificial two-group label: first half of the cells is 'A', the rest is 'B'
adata.obs['condition'] = ['A' if i < round(adata.shape[0]/2) else 'B' for i in range(adata.shape[0])]
try:
    tissue.downstream.multiple_imputation_testing(adata, "spage_predicted_expression",
                                                  calib_genes=adata.var_names,
                                                  condition='condition',
                                                  group1 = "A", # use None to compute for all conditions, condition vs all
                                                  group2 = "B", # use None to compute for group1 vs all
                                                  n_imputations=2)
except Exception as err:
    raise Exception ("Failed TISSUE MI t-test at tissue.downstream.multiple_imputation_testing()") from err

#################################################################################################################
print("Testing TISSUE cell filtering")
# per-cell, per-gene uncertainty = width of the prediction interval (upper bound minus lower bound)
X_uncertainty = adata.obsm["spage_predicted_expression_hi"].values - adata.obsm["spage_predicted_expression_lo"].values
try:
    keep_idxs = tissue.downstream.detect_uncertain_cells (X_uncertainty,
                                                          proportion="otsu",
                                                          stratification=adata.obs['condition'].values)
except Exception as err:
    raise Exception ("Failed TISSUE cell filtering at tissue.downstream.detect_uncertain_cells()") from err
try:
    keep_idxs = tissue.downstream.filtered_PCA (adata, # anndata object
                                                "spage", # prediction method
                                                proportion="otsu",
                                                stratification=adata.obs['condition'].values,
                                                return_keep_idxs=True)
except Exception as err:
    raise Exception ("Failed TISSUE-filtered PCA at tissue.downstream.filtered_PCA()") from err

print("TISSUE tests passed!")

# Contains functions for all downstream applications of TISSUE calibration scores and prediction intervals

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import anndata as ad
import os
import sys

#from tissue.main import build_calibration_scores, get_spatial_uncertainty_scores_from_metadata
from .main import build_calibration_scores, get_spatial_uncertainty_scores_from_metadata


def multiple_imputation_testing (adata, predicted, calib_genes, condition, test="ttest", n_imputations=100,
                                 group1=None, group2=None, symmetric=False, return_keys=False, save_mi=False):
    '''
    Entry point for TISSUE multiple imputation hypothesis testing; dispatches to the test-specific routine
    
    Parameters
    ----------
        adata [AnnData] - contains adata.obsm[predicted] corresponding to the predicted gene expression
        predicted [str] - key in adata.obsm that corresponds to predicted gene expression
        calib_genes [list or arr of str] - names of the genes in adata.var_names that are used in the calibration set
        condition [str] - key in adata.obs for which to compute the hypothesis test (not used by "spatialde")
            group1 [value] - value in adata.obs[condition] identifying the first comparison group
                             if None, will perform group vs all comparisons for all unique values in adata.obs[condition]
            group2 [value] - value in adata.obs[condition] identifying the second comparison group
                             if None, will compare against all values that are not group1
        test [str] - statistical test to use:
                        "ttest" - two-sample t-test using Rubin's rules, see multiple_imputation_ttest()
                        "wilcoxon_greater" - multiple_imputation_wilcoxon() with direction='greater'
                        "wilcoxon_less" - multiple_imputation_wilcoxon() with direction='less'
                        "spatialde" - SpatialDE test, see multiple_imputation_spatialde()
        n_imputations [int] - number of imputations to use
        symmetric [bool] - whether to have symmetric (or non-symmetric) prediction intervals
        return_keys [bool] - if True, returns the keys for which to access the results from adata
        save_mi [False or str] - multiple imputation saving (only forwarded to multiple_imputation_ttest())
        
    Returns
    -------
        Modifies adata in-place to add the statistics and test results to metadata
        Returns the list of result keys only when return_keys is True (otherwise None)
        
    Raises
    ------
        Exception - if test is not one of the recognized options
    '''
    # reject unknown tests up front
    if test not in ("ttest", "wilcoxon_less", "wilcoxon_greater", "spatialde"):
        raise Exception ("Specified test not recognized")
    
    # keyword arguments shared by all of the group-comparison tests
    comparison_kwargs = dict(n_imputations=n_imputations, group1=group1, group2=group2, symmetric=symmetric)
    
    if test == "spatialde":
        # spatially variable gene test -- no condition/group arguments
        keys = multiple_imputation_spatialde (adata, predicted, calib_genes,
                                              n_imputations=n_imputations, symmetric=symmetric)
    elif test == "ttest":
        # default option
        keys = multiple_imputation_ttest (adata, predicted, calib_genes, condition,
                                          save_mi=save_mi, **comparison_kwargs)
    else:
        # "wilcoxon_less" or "wilcoxon_greater"
        wilcoxon_direction = 'less' if test == "wilcoxon_less" else 'greater'
        keys = multiple_imputation_wilcoxon (adata, predicted, calib_genes, condition,
                                             direction=wilcoxon_direction, **comparison_kwargs)
    
    if return_keys is True:
        return keys
    return None


def multiple_imputation_spatialde (adata, predicted, calib_genes, n_imputations=100, symmetric=False):
    '''
    SpatialDE test under TISSUE multiple imputation: runs SpatialDE on each sampled imputation and
    pools the per-gene p-values across imputations with multiply_imputed_pvalue(method="licht_rubin")
    
    See multiple_imputation_testing() for details on parameters
    
    Returns
    -------
        keys_list [list of str] - keys of adata.uns under which the pooled p-values were stored
                                  (one single-row DataFrame with the columns of adata.obsm[predicted])
    '''
    import SpatialDE
    
    # uncertainties and scores stored in adata
    scores, residuals, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (adata, predicted)
    
    # calibration score sets (top 20% of scores trimmed)
    scores_flattened_dict = build_calibration_scores(adata, predicted, calib_genes, symmetric=symmetric,
                                                     include_zero_scores=True, trim_quantiles=[None, 0.8])
    
    comparison = "spatialde" # single comparison label for this test
    pvalues_by_comparison = {}
    
    for _ in range(n_imputations):
        
        # draw a new imputation
        imputed = sample_new_imputation_from_scores (G, G_stdev, groups, scores_flattened_dict, symmetric=symmetric)
        
        # divide each row by (1 + row sum), shift so the global minimum is zero, scale by 100, log1p
        row_totals = np.sum(imputed, axis=1)[:, None]
        scaled = imputed / (1 + row_totals)
        transformed = np.log1p((scaled - np.min(scaled)) * 100)
        expression_df = pd.DataFrame(transformed,
                                     columns=adata.obsm[predicted].columns,
                                     index=adata.obsm[predicted].index)
        
        results = SpatialDE.run(adata.obsm['spatial'], expression_df)
        
        # order the rows by the gene order of adata.obsm[predicted]
        results.drop_duplicates(subset = ['g'], keep = 'first', inplace = True) # workaround duplication SpatialDE bug
        results.g = results.g.astype("category")
        results.g = results.g.cat.set_categories(adata.obsm[predicted].columns)
        results = results.sort_values(["g"])
        
        # collect the p-values of this imputation
        pvalues_by_comparison.setdefault(comparison, []).append(list(results["pval"]))
    
    # pool the p-values gene by gene
    pooled = {'pvalue': {}}
    for comparison_key, per_imputation in pvalues_by_comparison.items():
        stacked = np.vstack(per_imputation) # imputations x genes
        pooled['pvalue'][comparison_key] = [multiply_imputed_pvalue (stacked[:, gene_idx], method="licht_rubin")
                                            for gene_idx in range(stacked.shape[1])]
    
    # store the pooled results in adata.uns
    prefix = predicted.split("_")[0]
    keys_list = []
    for measure, per_comparison in pooled.items():
        for comparison_key, values in per_comparison.items():
            uns_key = prefix+"_"+comparison_key+"_"+measure
            adata.uns[uns_key] = pd.DataFrame(np.array(values)[None,:],
                                              columns=adata.obsm[predicted].columns)
            keys_list.append(uns_key)
    
    return keys_list


def multiple_imputation_wilcoxon (adata, predicted, calib_genes, condition, n_imputations=100,
                                  group1=None, group2=None, symmetric=False, direction="greater"):
    '''
    Runs TISSUE multiple imputation one-sided Mann-Whitney U (Wilcoxon rank-sum) test between two groups
    of cells, pooling the per-imputation p-values with multiply_imputed_pvalue(method="licht_rubin")
    
    See multiple_imputation_testing() for details on parameters
    
    Additional Parameters
    ---------------------
        direction [str] - passed as the `alternative` argument of scipy.stats.mannwhitneyu, with the
                          first group's values given as the first sample (e.g. "greater" or "less")
    
    Returns
    -------
        keys_list [list of str] - keys of adata.uns under which the pooled p-values were stored
                                  (one single-row DataFrame per comparison, columns of adata.obsm[predicted])
    '''
    from scipy.stats import mannwhitneyu
    
    # get uncertainties and scores from saved adata
    scores, residuals, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (adata, predicted)
    
    ### Building calibration sets for scores
    
    scores_flattened_dict = build_calibration_scores(adata, predicted, calib_genes, symmetric=symmetric,
                                                     include_zero_scores=True, trim_quantiles=[None, 0.8]) # trim top 20% scores
    
    # cast condition to str
    condition = str(condition)
    labels = adata.obs[condition]
    
    ### Comparisons to run: (key, mask of first group, mask of second group)
    # the masks do not depend on the imputation, so they are built once here
    if group1 is None: # each group vs all other cells
        comparisons = [(str(g1)+"_all", labels == g1, labels != g1) for g1 in np.unique(labels)]
    elif group2 is None: # group1 vs all other cells
        comparisons = [(str(group1)+"_all", labels == group1, labels != group1)]
    else: # group1 vs group2
        comparisons = [(str(group1)+"_"+str(group2), labels == group1, labels == group2)]
    # masks are only used to index numpy arrays
    comparisons = [(key, np.asarray(g1_bool), np.asarray(g2_bool)) for key, g1_bool, g2_bool in comparisons]
    
    ### Multiple imputation

    # init dictionary to hold results
    stat_dict = {}
    stat_dict["pvalue"] = {}
    
    for m in range(n_imputations):
        
        # generate new imputation
        new_G = sample_new_imputation_from_scores (G, G_stdev, groups, scores_flattened_dict, symmetric=symmetric)
        
        for key, g1_bool, g2_bool in comparisons:
            
            # get Mann-Whitney U p-values, one per gene (column)
            pval = []
            for ci in range(new_G.shape[1]):
                u,p = mannwhitneyu(new_G[g1_bool,ci], new_G[g2_bool,ci], alternative=direction)
                pval.append(p)
            
            # list for a key is created on first use, so no keys exist if n_imputations is 0
            stat_dict["pvalue"].setdefault(key, []).append(pval)

    # pool statistics
    pooled_results_dict = {}
    pooled_results_dict['pvalue'] = {}
    # for each test grouping
    for key, pval_lists in stat_dict['pvalue'].items():
        pval_arr = np.vstack(pval_lists) # imputations x genes
        # for each gene, get mi pvalue
        pooled_results_dict['pvalue'][key] = [multiply_imputed_pvalue (pval_arr[:,ci], method="licht_rubin")
                                              for ci in range(pval_arr.shape[1])]
     
    # add stats to adata
    prefix = predicted.split("_")[0]
    keys_list = []
    for key_measure, measure_results in pooled_results_dict.items():
        for key_comparison, values in measure_results.items():
            uns_key = prefix+"_"+key_comparison+"_"+key_measure
            adata.uns[uns_key] = pd.DataFrame(np.array(values)[None,:],
                                              columns=adata.obsm[predicted].columns)
            keys_list.append(uns_key)
    
    return(keys_list)


def multiply_imputed_pvalue (pvalues, method="licht_rubin"):
    '''
    Computes a multiply imputed p-value from a list of p-values according to Licht-Rubin procedure or median procedure
    
    Parameters
    ----------
        pvalues [array-like] - array of p-values from multiple imputation tests (NaN entries are ignored)
        method [str] - which method for p-value calculation to use: "licht_rubin" or "median"
                        "licht_rubin" - p-values are mapped to z-scores with the standard normal quantile
                                        function, combined as mean(z) / sqrt(1 + var(z)), and mapped back
                                        with the standard normal CDF
                        "median" - median of the p-values
        
    Returns
    -------
        mi_pvalue [float] - p-value modified for multiple imputation
        
    Raises
    ------
        Exception - if method is not recognized
        
    See reference for technical details: https://stefvanbuuren.name/fimd/sec-multiparameter.html#sec:chi
    '''
    from scipy.stats import norm
    
    if method == "licht_rubin":
        p = np.asarray(pvalues, dtype=float)
        # p-values of exactly 0 or 1 map to -inf/+inf on the z-scale, which turns the pooled
        # mean/variance (and hence the result) into NaN; replace them by the smallest positive
        # normal float and the largest float below 1 (all other values, including NaN, are untouched)
        p = np.where(p == 0.0, np.finfo(float).tiny, p)
        p = np.where(p == 1.0, np.nextafter(1.0, 0.0), p)
        z = norm.ppf(p)  # transform to z-scale
        num = np.nanmean(z)
        den = np.sqrt(1 + np.nanvar(z))
        mi_pvalue = norm.cdf( num / den) # average and transform back
    
    elif method == "median":
        mi_pvalue = np.nanmedian(pvalues)
    
    else:
        raise Exception ("method for multiply_imputed_pvalue() not recognized")

    return(mi_pvalue)



def multiple_imputation_ttest (adata, predicted, calib_genes, condition, n_imputations=100,
                               group1=None, group2=None, symmetric=False, save_mi=False):
    '''
    Runs TISSUE multiple imputation two-sample t-test using Rubin's rules
    
    See multiple_imputation_testing() for details on parameters
    
    Additional Parameters
    ---------------------
        save_mi [False or str] - if not False, directory path in which the stacked matrix of all imputed
                                 gene expression matrices (imputations stacked along the third axis) is
                                 saved as "{predicted}.npy" -- NOTE: this requires large memory
    
    Returns
    -------
        keys_list [list of str] - keys of adata.uns under which the pooled results were stored
                                  (one single-row DataFrame per measure and comparison, columns of adata.obsm[predicted])
    '''

    # get uncertainties and scores from saved adata
    scores, residuals, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (adata, predicted)
    
    ### Building calibration sets for scores
    
    scores_flattened_dict = build_calibration_scores(adata, predicted, calib_genes, symmetric=symmetric,
                                                     include_zero_scores=True, trim_quantiles=[None, 0.8]) # trim top 20% scores
    
    # cast condition to str
    condition = str(condition)
    labels = adata.obs[condition]
    
    ### Comparisons to run: (key, mask of first group, mask of second group)
    # the masks do not depend on the imputation, so they are built once here
    if group1 is None: # each group vs all other cells
        comparisons = [(str(g1)+"_all", labels == g1, labels != g1) for g1 in np.unique(labels)]
    elif group2 is None: # group1 vs all other cells
        comparisons = [(str(group1)+"_all", labels == group1, labels != group1)]
    else: # group1 vs group2
        comparisons = [(str(group1)+"_"+str(group2), labels == group1, labels == group2)]
    
    ### Multiple imputation

    # init dictionary to hold results (for independent two-sample t-test)
    stat_dict = {}
    stat_dict["mean_difference"] = {}
    stat_dict["standard_deviation"] = {}
    
    new_G_list = [] # for saving multiple imputations
    
    for m in range(n_imputations):
        
        # generate new imputation
        new_G = sample_new_imputation_from_scores (G, G_stdev, groups, scores_flattened_dict, symmetric=symmetric)
        if save_mi is not False:
            new_G_list.append(new_G)
    
        # calculate statistics for the imputation using approach from Palmer & Peer, 2016
        for key, g1_bool, g2_bool in comparisons:
            mean_diff, pooled_sd = get_ttest_stats(new_G, g1_bool, g2_bool) # get ttest stats
            # lists for a key are created on first use, so no keys exist if n_imputations is 0
            stat_dict["mean_difference"].setdefault(key, []).append(mean_diff)
            stat_dict["standard_deviation"].setdefault(key, []).append(pooled_sd)

    # pool statistics and perform t-test
    pooled_results_dict = pool_multiple_stats(stat_dict)
     
    # add stats to adata
    prefix = predicted.split("_")[0]
    keys_list = []
    for key_measure, measure_results in pooled_results_dict.items():
        for key_comparison, values in measure_results.items():
            uns_key = prefix+"_"+key_comparison+"_"+key_measure
            adata.uns[uns_key] = pd.DataFrame(values[None,:],
                                              columns=adata.obsm[predicted].columns)
            keys_list.append(uns_key)
    
    # save multiple imputations
    if save_mi is not False:
        # stack all imputations and save
        stacked_mi = np.dstack(new_G_list)
        np.save(os.path.join(save_mi,f"{predicted}.npy"), stacked_mi)
    
    return(keys_list)


def _load_gene_signatures (sig_dirpath, available_genes):
    '''
    Reads all gene signatures stored as top-level subdirectories of sig_dirpath
    
    Parameters
    ----------
        sig_dirpath [str] - directory whose subdirectories each contain "genes.txt" and optionally "coefficients.txt"
        available_genes [pandas Index or other container of str] - gene names that can be used; signature genes
                                                                    (lower-cased) not contained in it are dropped
    
    Returns
    -------
        sig_names [list of str] - names (subdirectory names) of the signatures with at least one available gene
        sig_genes [list of arr] - for each kept signature, the array of its available gene names
        sig_coefs [list of arr] - for each kept signature, the coefficients matching sig_genes (ones if no file)
    '''
    sig_names, sig_genes, sig_coefs = [], [], []
    
    for sigdir in next(os.walk(sig_dirpath))[1]: # iterate all top-level signature directories
        # read in genes (one per row, lower-cased)
        with open(os.path.join(sig_dirpath,sigdir,"genes.txt")) as f:
            genes = np.array([line.rstrip().lower() for line in f])
        # load coefficients (if any)
        coef_path = os.path.join(sig_dirpath,sigdir,"coefficients.txt")
        if os.path.isfile(coef_path):
            # np.loadtxt returns a 0-d array for a single-value file, which cannot be indexed below
            coefs = np.atleast_1d(np.loadtxt(coef_path))
        else:
            coefs = np.ones(len(genes))
        # subset into shared genes
        shared_gene_idxs = [ii for ii, gene in enumerate(genes) if gene in available_genes]
        genes = genes[shared_gene_idxs]
        coefs = coefs[shared_gene_idxs]
        # keep only non-empty signatures
        if len(genes) > 0:
            sig_names.append(sigdir)
            sig_genes.append(genes)
            sig_coefs.append(coefs)
    
    return sig_names, sig_genes, sig_coefs


def multiple_imputation_gene_signature (sig_dirpath, adata, predicted, calib_genes, condition, n_imputations=100,
                                 group1=None, group2=None, symmetric=False, return_keys=False, load_mi=False):
    '''
    Uses multiple imputation with the score distributions to perform hypothesis testing on gene signatures
    
    Parameters
    ----------
        sig_dirpath [str] - path to the directory containing the gene signatures organized as:
                            sig_dirpath/
                                {name of signature 1}/
                                {name of signature N}/
                                    genes.txt - text file with each row being a gene name
                                    coefficients.txt - optional text file with each row being a float weight for corresponding gene
        adata [AnnData] - contains adata.obsm[predicted] corresponding to the predicted gene expression
        predicted [str] - key in adata.obsm that corresponds to predicted gene expression
        calib_genes [list or arr of str] - names of the genes in adata.var_names that are used in the calibration set
        condition [str] - key in adata.obs for which to compute the hypothesis test
            group1 [value] - value in adata.obs[condition] identifying the first comparison group
                             if None, will perform group vs all comparisons for all unique values in adata.obs[condition]
            group2 [value] - value in adata.obs[condition] identifying the second comparison group
                             if None, will compare against all values that are not group1
        n_imputations [int] - number of imputations to use (must be at least 1)
        symmetric [bool] - whether to have symmetric (or non-symmetric) prediction intervals
        return_keys [bool] - whether to return the keys for which to access the results from adata
        load_mi [bool] - if not False, loads the "{predicted}.npy" stacked matrix of multiple imputations found at
                         sig_dirpath (its m-th slice along the third axis is used as the m-th imputation)
                         instead of sampling new imputations
        
    Returns
    -------
        Modifies adata in-place to add the statistics and test results to metadata (adata.uns) and the
        signature values averaged over imputations to adata.obsm[predicted+"_gene_signatures"]
        Optionally returns the keys to access the results from adata
        
    Raises
    ------
        ValueError - if n_imputations is smaller than 1 or no signature shares a gene with adata.obsm[predicted]
    '''
    #####################################################################
    # T-test (default) - this is the only option currently for signatures
    #####################################################################
    
    if n_imputations < 1:
        raise ValueError ("n_imputations must be at least 1 for multiple_imputation_gene_signature()")
    
    if load_mi is False:
        # get uncertainties and scores from saved adata
        scores, residuals, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (adata, predicted)
        
        ### Building calibration sets for scores
        
        scores_flattened_dict = build_calibration_scores(adata, predicted, calib_genes, symmetric=symmetric,
                                                         include_zero_scores=True, trim_quantiles=[None, 0.8]) # trim top 20% scores
    else: # load in saved multiple imputations
        mi_path = os.path.join(sig_dirpath,f"{predicted}.npy") # path to saved multiple imputations
        mi_stacked = np.load(mi_path)
    
    # read the signature files once -- they do not depend on the imputation
    gene_columns = adata.obsm[predicted].columns
    sig_names, sig_genes, sig_coefs = _load_gene_signatures (sig_dirpath, gene_columns)
    if len(sig_names) == 0:
        raise ValueError ("No gene signature in sig_dirpath shares genes with adata.obsm[predicted]")
    
    # cast condition to str
    condition = str(condition)
    labels = adata.obs[condition]
    
    ### Comparisons to run: (key, mask of first group, mask of second group)
    # the masks do not depend on the imputation, so they are built once here
    if group1 is None: # each group vs all other cells
        comparisons = [(str(g1)+"_all", labels == g1, labels != g1) for g1 in np.unique(labels)]
    elif group2 is None: # group1 vs all other cells
        comparisons = [(str(group1)+"_all", labels == group1, labels != group1)]
    else: # group1 vs group2
        comparisons = [(str(group1)+"_"+str(group2), labels == group1, labels == group2)]
    
    ### Multiple imputation

    # init dictionary to hold results (for independent two-sample t-test)
    stat_dict = {}
    stat_dict["mean_difference"] = {}
    stat_dict["standard_deviation"] = {}
    
    mean_imputed_sigs = None # running average of imputed gene signatures
    
    for m in range(n_imputations):
        
        # generate new imputation
        if load_mi is False:
            new_G = sample_new_imputation_from_scores (G, G_stdev, groups, scores_flattened_dict, symmetric=symmetric)
        else:
            new_G = mi_stacked[:,:,m].copy() # take the m-th multiple imputation
        
        # compute all signatures as coefficient-weighted sums over their genes (NaN treated as zero)
        new_G_df = pd.DataFrame(new_G, columns = gene_columns)
        imputed_sigs = np.vstack([np.nansum(new_G_df[genes].values*coefs, axis=1)
                                  for genes, coefs in zip(sig_genes, sig_coefs)]).T
        
        # keep running average of imputed gene signatures
        if mean_imputed_sigs is None:
            mean_imputed_sigs = imputed_sigs * 1/n_imputations
        else:
            mean_imputed_sigs += imputed_sigs * 1/n_imputations
        
        # calculate statistics for the imputation using approach from Palmer & Peer, 2016
        for key, g1_bool, g2_bool in comparisons:
            mean_diff, pooled_sd = get_ttest_stats(imputed_sigs, g1_bool, g2_bool) # get ttest stats
            stat_dict["mean_difference"].setdefault(key, []).append(mean_diff)
            stat_dict["standard_deviation"].setdefault(key, []).append(pooled_sd)

    # pool statistics and perform t-test
    pooled_results_dict = pool_multiple_stats(stat_dict)
     
    # add stats to adata
    prefix = predicted.split("_")[0]
    keys_list = []
    for key_measure, measure_results in pooled_results_dict.items():
        for key_comparison, values in measure_results.items():
            uns_key = prefix+"_"+key_comparison+"_"+key_measure
            adata.uns[uns_key] = pd.DataFrame(values[None,:],
                                              columns=sig_names)
            keys_list.append(uns_key)
                    
    # add gene sigs to adata
    adata.obsm[predicted+"_gene_signatures"] = pd.DataFrame(mean_imputed_sigs, columns=sig_names, index=adata.obs_names)
    
    if return_keys is True:
        
        return(keys_list)



def sample_new_imputation_from_scores (G, G_stdev, groups, scores_flattened_dict, symmetric=False):
    '''
    Draws one new imputation by perturbing G with randomly sampled calibration scores
    
    Parameters
    ----------
        G, G_stdev, groups - outputs of get_spatial_uncertainty_scores_from_metadata()
        scores_flattened_dict - output of build_calibration_scores(); looked up with str(group) keys
        symmetric [bool] - if True, each dictionary entry is a single pool of scores and the sign is drawn uniformly from {-1, 1};
                           otherwise each entry is indexed as [0] (lower scores, sign -1) and [1] (upper scores, sign +1)
    
    See multiple_imputation_testing() for more details of arguments
    
    Returns
    -------
        new_G - array of the new sampled predicted gene expression (same dimensions as G)
                entries whose group is NaN keep a zero perturbation and therefore equal G
    '''
    sampled_score_arr = np.zeros(G.shape) # sampled calibration score per entry
    sign_arr = np.zeros(G.shape) # +1/-1 direction per entry

    # groups actually present (NaN labels are skipped) and how many entries each has
    labelled = groups[~np.isnan(groups)]
    group_ids, group_sizes = np.unique(labelled, return_counts=True)

    for group, count in zip(group_ids, group_sizes):
        group_key = str(group)

        if symmetric is True:
            pool = scores_flattened_dict[group_key]
            if len(pool) < 100: # too few scores in this group -> use the full set stored under "nan"
                pool = scores_flattened_dict[str(np.nan)]
            sampled_scores = np.random.choice(pool, count, replace=True) # sample scores with replacement
            add_sub = np.random.choice([-1,1], count, replace=True) # sample direction
        else:
            entry = scores_flattened_dict[group_key]
            lo_pool = entry[0]
            hi_pool = entry[1]
            if (len(lo_pool) < 100) or (len(hi_pool) < 100): # too few scores -> use the full set stored under "nan"
                fallback = scores_flattened_dict[str(np.nan)]
                lo_pool = fallback[0]
                hi_pool = fallback[1]
            combined_pool = np.concatenate((lo_pool, hi_pool))
            directions = np.concatenate(([-1]*len(lo_pool), [1]*len(hi_pool)))
            # one index draw selects both the score and its direction
            picks = np.random.choice(np.arange(len(combined_pool)), count, replace=True)
            sampled_scores = combined_pool[picks]
            add_sub = directions[picks]

        # write the draws into the positions belonging to this group
        in_group = (groups == group)
        sampled_score_arr[in_group] = sampled_scores
        sign_arr[in_group] = add_sub

    # perturb the prediction by (sign * score * stdev)
    new_G = G + sign_arr*(sampled_score_arr*G_stdev)

    return (new_G)


def get_ttest_stats(G, g1_bool, g2_bool):
    '''
    Computes mean_diff and pooled SD for each column of G independently (NaN values are ignored)
    
    Parameters
    ----------
        G [array] - 2D array with columns as genes and rows as cells
        g1_bool [bool array] - 1D array with length equal to number of rows in G; labels group1
        g2_bool [bool array] - 1D array with length equal to number of rows in G; labels group2
        
    Returns
    -------
        mean_diff - mean difference for t-test (group1 mean minus group2 mean)
        pooled_sd - pooled standard deviation for t-test, i.e. sqrt(1/n1 + 1/n2) * sp where
                    sp^2 = (SS1 + SS2) / (n1 + n2 - 2) and SSk is the sum of squared deviations in group k
    '''
    G1 = G[g1_bool,:]
    G2 = G[g2_bool,:]
    
    mean_diff = np.nanmean(G1, axis=0) - np.nanmean(G2, axis=0)
    
    # number of non-missing observations per column in each group
    n1 = np.count_nonzero(~np.isnan(G1), axis=0)
    n2 = np.count_nonzero(~np.isnan(G2), axis=0)
    
    # sum of squared deviations = n * population variance = (n-1) * unbiased sample variance
    # (using the population variance with an (n-1) weight would underestimate the pooled variance;
    #  this form also stays finite when a group has a single observation)
    ss1 = n1 * np.nanvar(G1, axis=0)
    ss2 = n2 * np.nanvar(G2, axis=0)
    
    sp = np.sqrt( (ss1 + ss2) / (n1+n2-2) )
    pooled_sd = np.sqrt(1/n1 + 1/n2) * sp
    
    return(mean_diff, pooled_sd)


def two_sample_ttest (G, g1_bool, g2_bool):
    '''
    Computes two-sample t-test for unequal sample sizes using get_ttest_stats()
    
    Parameters
    ----------
        G [array] - 2D array with columns as genes and rows as cells
        g1_bool [bool array] - 1D array with length equal to number of rows in G; labels group1
        g2_bool [bool array] - 1D array with length equal to number of rows in G; labels group2
        
    Returns
    -------
        tt - t-statistic (one value per column of G)
        pp - two-sided p-value (one value per column of G)
    '''
    from scipy import stats
    
    # t-statistic = mean difference / pooled standard deviation
    mean_diff, pooled_sd = get_ttest_stats(G, g1_bool, g2_bool)
    tt = mean_diff/pooled_sd
    
    # degrees of freedom from the non-missing counts in each group
    n1 = np.count_nonzero(~np.isnan(G[g1_bool,:]), axis=0)
    n2 = np.count_nonzero(~np.isnan(G[g2_bool,:]), axis=0)
    dof = n1+n2-2
    
    # two-sided p-value from the upper tail; the survival function is used instead of
    # 1 - cdf because the subtraction rounds to exactly 0 once the tail drops below ~1e-16
    pp = 2*stats.t.sf(np.abs(tt), dof)
    
    return(tt, pp)


def pool_multiple_stats(stat_dict):
    '''
    Pool stats across multiple imputations for t-test
    
    Parameters
    ----------
        stat_dict [dict] - dictionary containing statistical testing results (generated in multiple_imputation_ttest())
                         - stat_dict["mean_difference"][key] and stat_dict["standard_deviation"][key] are lists with
                           one array per imputation; at least two imputations are needed because the
                           between-draw variance divides by (number of imputations - 1)
        
    Returns
    -------
        results_dict [dict] - dictionary containing the pooled statistics from using Rubin's rules
                            - keys: "tstat", "pvalue", "varw", "varb", "poolmean", each mapping comparison key -> array
    '''
    from scipy import stats
    
    # init results_dict (insertion order is kept: tstat, pvalue, then intermediate stats)
    results_dict = {}
    results_dict["tstat"] = {}
    results_dict["pvalue"] = {}
    
    results_dict["varw"] = {}
    results_dict["varb"] = {}
    results_dict["poolmean"] = {}
    
    for key in stat_dict["mean_difference"].keys():
        
        d = len(stat_dict["mean_difference"][key]) # number of imputations
        
        mean_diffs = np.vstack(stat_dict["mean_difference"][key]) # imputations x features
        std_devs = np.vstack(stat_dict["standard_deviation"][key])
        
        # compute pooled terms
        pooled_mean = np.mean(mean_diffs, axis=0)
        var_w = np.mean(std_devs**2, axis=0) # within-draw sample variance
        var_b = 1/(d-1) * np.sum((mean_diffs-pooled_mean)**2, axis=0) # between-draw sample variance
        var_MI = var_w + (1+1/d)*var_b # multiple imputation variance
        
        test_stat = pooled_mean / np.sqrt(var_MI) # pooled t statistic
        
        # compute pvalue from T distribution
        dof = (d-1)*(1+(d*var_w)/((d+1)*var_b))**2 # degrees of freedom for T distribution
        # survival function instead of 1 - cdf: the subtraction rounds to exactly 0 for very
        # significant results, whereas sf keeps the small tail probability
        pval = 2*stats.t.sf(np.abs(test_stat), dof)
        
        # Add test statistic and pvalue
        results_dict["tstat"][key] = test_stat
        results_dict["pvalue"][key] = pval
        
        # Add intermediate stats (for debugging, etc)
        results_dict["varw"][key] = var_w
        results_dict["varb"][key] = var_b
        results_dict["poolmean"][key] = pooled_mean
    
    return(results_dict)



def weighted_PCA(adata, imp_method, pca_method="wpca", weighting="inverse_norm_pi_width", quantile_cutoff=None,
                 n_components=15, replace_inf=None, binarize=0.2, binarize_ratio=10, log_transform=False,
                 scale=True, tag="", return_weights=False,):
    '''
    Reduces the predicted expression matrix with (weighted) PCA from the "wpca" package: https://github.com/jakevdp/wpca
    
    Parameters
    ----------
        adata [AnnData] - should be the AnnData after running conformalize_prediction_interval()
                        - must include in obsm: {imp_method}_predicted_expression,
                                                {imp_method}_predicted_expression_lo,
                                                {imp_method}_predicted_expression_hi
        imp_method [str] - specifies which imputation method to return PCA for (e.g. 'knn', 'spage', 'tangram')
        pca_method [str] - "wpca" for WPCA, "empca" for EMPCA, "pca" for unweighted PCA,
                           "gwpca" for PCA on the data multiplied by the per-gene mean weight
        weighting [str] - "uniform" (all weights equal to one; no postprocessing)
                          "inverse_pi_width" (weights are 1/(prediction interval width))
                          "inverse_norm_pi_width" (1/(prediction interval width), divided by its per-gene mean)
                          "inverse_residual" (weights are 1/|predicted - adata.X|)
                          "inverse_norm_residual" (1/|predicted - adata.X|, divided by its per-gene mean)
        quantile_cutoff [None or float] - quantile (between 0 and 1) for which to set a ceiling for the weights
        n_components [int] - number of principal components
        replace_inf [None, str, float] - what to replace non-finite weights with (after all other weight transforms)
        binarize [bool, float or int] - weight binarization setting forwarded to postprocess_weights()
        binarize_ratio [int or float] - how much to "upweight" values greater than the binarized threshold
        log_transform [bool] - whether to log1p transform weights
        scale [bool] - whether to scale data with StandardScaler() before running the PCA
        tag [str] - additional tag to append to the obsm key for storing the PCs
        return_weights [bool] - whether to return the weights that were used
     
    Returns
    -------
        Stores the result in adata.obsm["{imp_method}_predicted_expression_PC{n_components}_{tag}"]
        Optionally returns the array of weights (for "gwpca" this is the per-gene mean weight)
    
    Refer to postprocess_weights() for order for weight calculations
    '''
    from wpca import PCA, WPCA, EMPCA
    
    predicted = f"{imp_method}_predicted_expression"
    
    # gene names/order follow the columns of the prediction table
    genes = adata.obsm[predicted].columns
    
    pi_width_modes = ("inverse_pi_width", "inverse_norm_pi_width")
    residual_modes = ("inverse_residual", "inverse_norm_residual")
    normalized_modes = ("inverse_norm_pi_width", "inverse_norm_residual")
    
    # determine weights
    if weighting == "uniform":
        weights = np.ones(adata.obsm[predicted].shape)
    elif (weighting in pi_width_modes) or (weighting in residual_modes):
        if weighting in pi_width_modes:
            upper = adata.obsm[predicted+'_hi'][genes].values
            lower = adata.obsm[predicted+'_lo'][genes].values
            weights = 1/(upper-lower)
        else:
            weights = 1/np.abs(adata.obsm[predicted][genes].values - np.array(adata[:,genes].X))
        if weighting in normalized_modes: # rescale by the mean weight of each gene
            weights = weights / np.nanmean(weights, axis=0)
        weights = postprocess_weights(weights, quantile_cutoff, replace_inf, binarize, binarize_ratio, log_transform)
    else:
        raise Exception("weighting not recognized")
    
    # scaling
    X = adata.obsm[predicted].values
    if scale is True:
        X = StandardScaler().fit_transform(X)
    
    # run the requested decomposition
    if pca_method == "wpca":
        reducer = WPCA(n_components=n_components)
        X_red = reducer.fit_transform(X, weights=weights)
    elif pca_method == "empca":
        reducer = EMPCA(n_components=n_components)
        X_red = reducer.fit_transform(X, weights=weights)
    elif pca_method == "pca":
        reducer = PCA(n_components=n_components)
        X_red = reducer.fit_transform(X)
    elif pca_method == "gwpca": # gene-weighted PCA: collapse weights to one value per gene
        weights = np.nanmean(weights, axis=0)
        reducer = PCA(n_components=n_components)
        X_red = reducer.fit_transform(X * weights)
    else:
        raise Exception("pca_method not recognized")
        
    # add PCs to adata
    adata.obsm[predicted+f"_PC{n_components}_{tag}"] = X_red

    if return_weights is True:
        return(weights)


def postprocess_weights(weights, quantile_cutoff, replace_inf, binarize, binarize_ratio, log_transform):
    '''
    Method for postprocessing weights (filter with cutoff, replace inf, etc) for weighted_PCA()
    
    Steps are applied in this order: quantile ceiling -> log1p -> binarization -> replacement of non-finite values
    
    Parameters
    ----------
        weights [array] - array of weights; it is copied, so the input array is left untouched
        quantile_cutoff [None or float] - if not None, finite weights at/above this quantile are set to the quantile value
        replace_inf [None, str, float, int] - "max", "min", "mean" or "median" replace non-finite weights with that
                                              statistic of the finite weights; a number is used directly; otherwise no replacement
        binarize [bool, float, int] - True uses an Otsu threshold (skimage) on the finite weights as cutoff,
                                      False skips binarization, a number is used as the quantile defining the cutoff
        binarize_ratio [int or float] - finite weights below the cutoff are set to 1/binarize_ratio (those at/above to 1)
        log_transform [bool] - whether to log1p transform the weights
        
    Returns
    -------
        weights [array] - postprocessed copy of the weights
    
    Refer to weighted_PCA() for details on arguments
    '''
    # work on a float copy so the caller's array is never modified in place
    weights = np.array(weights, dtype=float)
    
    # cutoff weights
    if quantile_cutoff is not None:
        cutoff = np.nanquantile(weights, quantile_cutoff)
        weights[np.isfinite(weights) & (weights >= cutoff)] = cutoff
    
    # log-transform
    if log_transform is True:
        weights = np.log1p(weights)
    
    # determine binarization cutoff (None means no binarization)
    binarize_cutoff = None
    if binarize is True:
        from skimage.filters import threshold_otsu
        binarize_cutoff = threshold_otsu(weights[np.isfinite(weights)])
    elif binarize is False:
        pass
    elif isinstance(binarize, float) or isinstance(binarize, int):
        binarize_cutoff = np.nanquantile(weights, binarize)
    
    # binarize weights
    if binarize_cutoff is not None:
        # both masks must be computed BEFORE assigning: once the high weights are set to 1 they
        # would otherwise be re-classified as "below cutoff" whenever the cutoff is larger than 1
        finite = np.isfinite(weights)
        high = finite & (weights >= binarize_cutoff)
        low = finite & (weights < binarize_cutoff)
        weights[high] = 1
        weights[low] = 1/binarize_ratio
        
    # deal with infs (from division by zero)
    nonfinite = ~np.isfinite(weights)
    if replace_inf == "max":
        weights[nonfinite] = np.nanmax(weights[~nonfinite])
    elif replace_inf == "min":
        weights[nonfinite] = np.nanmin(weights[~nonfinite])
    elif replace_inf == "mean":
        weights[nonfinite] = np.nanmean(weights[~nonfinite])
    elif replace_inf == "median":
        weights[nonfinite] = np.nanmedian(weights[~nonfinite])
    elif isinstance(replace_inf, float) or isinstance(replace_inf, int):
        weights[nonfinite] = replace_inf
    
    return(weights)


def filtered_PCA(adata, imp_method, proportion=0.05, stratification=None, n_components=15, scale=True, normalize=False,
                 tag="", return_keep_idxs=False):
    '''
    Fits PCA on the cells retained by the TISSUE cell filtering approach and projects all cells onto it
    
    Parameters
    ----------
        adata [AnnData] - should be the AnnData after running conformalize_prediction_interval()
                        - must include in obsm: {imp_method}_predicted_expression,
                                                {imp_method}_predicted_expression_lo,
                                                {imp_method}_predicted_expression_hi
        imp_method [str] - specifies which imputation method to return PCA for (e.g. 'knn', 'spage', 'tangram')
        proportion [float] - between 0 and 1; proportion of most uncertain cells to drop
        stratification [None or 1d numpy array] - array of values to stratify the drop by
                                                - same length as number of rows in X
                                                - if None, no stratification
        n_components [int] - number of principal components
        scale [bool] - whether to scale data with StandardScaler() (fit on the retained cells) before running PCA
        normalize [bool] - whether to divide prediction interval width by (1 + absolute predicted expression value)
        tag [str] - additional tag to append to the obsm key for storing the PCs
        return_keep_idxs [bool] - whether to return the keep_idxs for filtering
    
    Returns
    -------
        Stores the PCs of all cells in adata.obsm["{imp_method}_predicted_expression_PC{n_components}_{tag}"]
        Stores the PCs of the retained cells in adata.uns["{imp_method}_predicted_expression_PC{n_components}_filtered_{tag}"]
        Optionally returns the indices corresponding to the observations to keep after filtering
    '''    
    predicted = f"{imp_method}_predicted_expression"
    
    # predicted expression for every cell
    X_all = adata.obsm[predicted].values.copy()
    
    # uncertainty = prediction interval width (optionally relative to the prediction magnitude)
    upper = adata.obsm[f'{predicted}_hi'].values
    lower = adata.obsm[f'{predicted}_lo'].values
    pi_width = upper - lower
    if normalize is True:
        pi_width = pi_width / (1+np.abs(adata.obsm[f'{predicted}'].values))
    
    # select the cells to retain
    keep_idxs = detect_uncertain_cells(pi_width, proportion=proportion, stratification=stratification)
    X_kept = X_all[keep_idxs,:].copy()
    
    # scaling parameters come from the retained cells only and are applied to both matrices
    if scale is True:
        scaler = StandardScaler().fit(X_kept)
        X_all = scaler.transform(X_all)
        X_kept = scaler.transform(X_kept)
    
    # PCA is fit on the retained cells; all cells are then projected
    pca = PCA(n_components=n_components).fit(X_kept)
    pcs_all = pca.transform(X_all)
    pcs_kept = pca.transform(X_kept)
        
    # add PCs to adata
    adata.obsm[f"{predicted}_PC{n_components}_{tag}"] = pcs_all
    adata.uns[f"{predicted}_PC{n_components}_filtered_{tag}"] = pcs_kept
    
    if return_keep_idxs is True:
        return (keep_idxs)



def _uncertain_drop_positions (cell_scores, proportion):
    '''
    Helper for detect_uncertain_cells(): returns positions (within cell_scores) of the cells to drop
    
    proportion is either a number (fraction of highest-scoring cells to drop, rounded up)
    or "otsu" (drop cells scoring above the Otsu threshold from skimage)
    '''
    if (isinstance(proportion, float)) or (isinstance(proportion, int)):
        n_drop = int(np.ceil(proportion*len(cell_scores))) # number of cells to drop
        return np.argsort(cell_scores)[::-1][:n_drop] # positions of the highest scores
    elif proportion == "otsu":
        from skimage.filters import threshold_otsu
        cutoff = threshold_otsu(cell_scores)
        return np.flatnonzero(cell_scores > cutoff)
    else:
        raise Exception("proportion specified not valid")


def detect_uncertain_cells (X, proportion=0.05, stratification=None):
    '''
    Method for dropping a portion of the most uncertain cells from the input. 
    
    Each cell is scored by the average of its per-gene z-scored uncertainties. NaN z-scores
    (e.g. from a gene with zero variance) are ignored in that average, with or without stratification.
    
    Parameters
    ----------
        X [2d numpy array] - array of uncertainty values 
        proportion [float] - between 0 and 1; proportion of most uncertain cells to drop
                           - can also be "otsu" to drop cells scoring above an Otsu threshold
        stratification [None or 1d numpy array] - array of values to stratify the drop by
                                                - same length as number of rows in X
                                                - if None, no stratification
        
    Returns
    -------
        keep_idxs [list] - row indices (ascending) remaining after dropping most uncertain cells
    '''
    from scipy.stats import zscore
    
    if stratification is not None: # drop cells within each strata independently
    
        stratification = np.asarray(stratification)
        row_positions = np.arange(X.shape[0])
        drop_chunks = []
        
        for strata in np.unique(stratification):
            
            in_strata = (stratification == strata)
            
            # gene z-scores are computed within the strata, then averaged per cell
            cell_scores = np.nanmean(zscore(X[in_strata,:], axis=0), axis=1)
            
            # map positions within the strata back to rows of X
            strata_drop = _uncertain_drop_positions(cell_scores, proportion)
            drop_chunks.append(row_positions[in_strata][strata_drop])
            
        drop_idxs = np.concatenate(drop_chunks)
    
    else:
        
        # nanmean (not mean) so that one NaN z-score column does not turn every cell score into NaN
        cell_scores = np.nanmean(zscore(X, axis=0), axis=1)
        drop_idxs = _uncertain_drop_positions(cell_scores, proportion)
    
    # return keep indices (determined as indices not in drop indices); a set keeps the lookup O(1)
    dropped = set(int(i) for i in drop_idxs)
    keep_idxs = [i for i in range(X.shape[0]) if i not in dropped]
    
    return (keep_idxs)

# Contains compound functions for generating results for experiments with TISSUE
# These are unlikely to be used for general applications but were used in our development/testing of TISSUE

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import squidpy as sq
import anndata as ad
import warnings
import os
import gc

#from tissue.utils import large_save, large_load
from .utils import large_save, large_load
from .main import load_spatial_data, conformalize_prediction_interval, get_spatial_uncertainty_scores_from_metadata
from .downstream import multiple_imputation_testing


def group_conformalize_from_intermediate(dataset_name, methods, symmetric, alpha_levels,
                                         save_alpha=[0.05], savedir="SCPI", type_dataset="DataUpload"):
    '''
    Function for taking intermediate fold predictions and running group conformalization for all different alpha values
    
    Returns a results dictionary with calibration quality (res_dict) and the AnnData with CI for all folds at each alpha in save_alpha
    
    Parameters
    ----------
        dataset_name [str] - name of folder in DataUpload/ (or in additional_data/ when type_dataset is not "DataUpload")
        methods [list of str] - list of method keys to use for prediction_sets
        symmetric [bool] - whether to use symmetric prediction intervals
        alpha_levels [array] - alpha levels to calibrate over (used as 1-alpha_levels, so must support array arithmetic)
        save_alpha [list of float] - alphas to save prediction intervals into adata.obsm
                                   - the default list is only iterated over, never modified, in this function
        savedir [str] - folder where the intermediate results are saved (independent folds)
                      - reads {savedir}/{dataset_name}_intermediate/{method}/folds.npy and the fold{i} entries next to it
        type_dataset [str] - "DataUpload" reads the text files under DataUpload/{dataset_name}/;
                             any other value reads additional_data/{dataset_name}/spatial.h5ad
        
    Returns
    -------
        res_dict [dict] - dictionary of calibration statistics / coverage statistics across the alpha levels
                        - res_dict[method] has keys '1-alpha', 'calibration', 'test' and 'ind_gene_results'
                          ('ind_gene_results'[gene] holds '1-alpha' and a list 'test' with one coverage value per alpha level)
        adata [AnnData] - anndata with calibration results added to metadata
                        - obsm gets {method}_predicted_expression plus _lo_{X} / _hi_{X} tables, X = round((1-alpha)*100)
    '''
    # read in spatial data
    if type_dataset == "DataUpload":
        # metadata file is optional
        if os.path.isfile("DataUpload/"+dataset_name+"/Metadata.txt"):
            adata = load_spatial_data("DataUpload/"+dataset_name+"/Spatial_count.txt",
                                      "DataUpload/"+dataset_name+"/Locations.txt",
                                       spatial_metadata = "DataUpload/"+dataset_name+"/Metadata.txt")
        else:
            adata = load_spatial_data("DataUpload/"+dataset_name+"/Spatial_count.txt",
                                      "DataUpload/"+dataset_name+"/Locations.txt")
    else:
        adata = sc.read_h5ad(os.path.join("additional_data",dataset_name,"spatial.h5ad"))
    # gene names are lower-cased before matching against the fold gene names
    adata.var_names = [x.lower() for x in adata.var_names]
    
    # results dict
    res_dict = {}
    
    for method in methods:

        res_dict[method] = {}
        res_dict[method]['ind_gene_results'] = {}

        calibration_weight = 0 # for computing weighted average
        test_weight = 0

        dirpath = savedir+"/"+dataset_name+"_intermediate/"+method
        
        # folds.npy holds one collection of gene names per fold (same directory as dirpath)
        folds = np.load(os.path.join(savedir+"/"+dataset_name+"_intermediate/"+method,"folds.npy"), allow_pickle=True)

        # subset spatial data into shared genes
        # NOTE(review): adata is re-bound here, so with several methods each subset is taken from the previous one
        gene_names = np.concatenate(folds)
        adata = adata[:, gene_names]

        for i, fold in enumerate(folds):

            # load adata within fold
            sub_adata = large_load(os.path.join(dirpath, "fold"+str(i)))
            target_genes = list(fold)

            # subset data
            # test genes = genes of this fold; calibration genes = all remaining genes
            predicted = method+"_predicted_expression"
            test_genes = target_genes.copy()
            calib_genes = [gene for gene in gene_names if gene not in test_genes]
            # column positions of each gene in the prediction table (first match)
            test_idxs = [np.where(sub_adata.obsm[predicted].columns==gene)[0][0] for gene in test_genes]
            calib_idxs = [np.where(sub_adata.obsm[predicted].columns==gene)[0][0] for gene in calib_genes]

            # get uncertainties and scores from saved adata
            # (scores and residuals are not used further in this function)
            scores, residuals, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (sub_adata, predicted)

            # init dict for individual gene results
            for g in test_genes:
                if g not in res_dict[method]['ind_gene_results'].keys():
                    res_dict[method]['ind_gene_results'][g] = {}
                    res_dict[method]['ind_gene_results'][g]['1-alpha'] = 1-alpha_levels
                    res_dict[method]['ind_gene_results'][g]['test'] = []

            # iterate over different alphas for conformalization
            test_perc = []
            calib_perc = []

            for alpha_level in alpha_levels:
                # conformalize on a copy so sub_adata itself is unchanged by this loop
                sub_adatac = sub_adata.copy()
                conformalize_prediction_interval (sub_adatac, predicted, calib_genes, alpha_level=alpha_level,
                                                  symmetric=symmetric, return_scores_dict=False)
                
                # (lower bounds, upper bounds) of the prediction intervals
                prediction_sets = (sub_adatac.obsm[predicted+"_lo"].values, sub_adatac.obsm[predicted+"_hi"].values)
                
                # coverage = fraction of entries with measured value strictly inside (lo, hi),
                # restricted to entries where G, G_stdev and the measured value are all nonzero
                test_perc.append(np.nanmean(((adata[:,test_genes].X>prediction_sets[0][:,test_idxs]) & (adata[:,test_genes].X<prediction_sets[1][:,test_idxs]))[(G[:,test_idxs]!=0)&(G_stdev[:,test_idxs]!=0)&(adata[:,test_genes].X!=0)]))
                calib_perc.append(np.nanmean(((adata[:,calib_genes].X>prediction_sets[0][:,calib_idxs]) & (adata[:,calib_genes].X<prediction_sets[1][:,calib_idxs]))[(G[:,calib_idxs]!=0)&(G_stdev[:,calib_idxs]!=0)&(adata[:,calib_genes].X!=0)]))

                # Compute individual calibration curves for each gene
                for ti, tg in zip(test_idxs, test_genes):
                    # sanity check on the column lookup; `raise Warning` raises an exception (it does not just warn)
                    if sub_adatac.obsm[predicted].columns[ti] != tg:
                        raise Warning ("ti not equal to tg: "+str(adata.var_names[ti])+" != "+str(tg))
                    ind_test_ci = np.nanmean(((adata[:,tg].X>prediction_sets[0][:,ti]) & (adata[:,tg].X<prediction_sets[1][:,ti]))[(G[:,ti]!=0)&(G_stdev[:,ti]!=0)&(adata[:,tg].X!=0)])
                    res_dict[method]['ind_gene_results'][tg]['test'].append(ind_test_ci)
                    
                # free the per-alpha copies before the next iteration
                # NOTE(review): `del ind_test_ci` assumes the gene loop above ran at least once (non-empty fold) -- confirm
                del sub_adatac
                del prediction_sets
                del ind_test_ci
                gc.collect()

            # weighted average
            # per-fold coverages are accumulated weighted by the number of genes they were computed on
            if i == 0:
                calibration_ci = np.array(calib_perc) * len(calib_genes)
                calibration_weight += len(calib_genes)
                test_ci = np.array(test_perc) * len(test_genes)
                test_weight += len(test_genes)
            else:
                calibration_ci += np.array(calib_perc) * len(calib_genes)
                calibration_weight += len(calib_genes)
                test_ci += np.array(test_perc) * len(test_genes)
                test_weight += len(test_genes)
                
            # Add new predictions
            # sub_adata is conformalized in place for each alpha to save; only the fold's own genes are copied out
            for si, s_alpha in enumerate(save_alpha):
                conformalize_prediction_interval (sub_adata, predicted, calib_genes, alpha_level=s_alpha,
                                                  symmetric=symmetric, return_scores_dict=False)
                
                if i == 0:
                    # first fold creates the tables in adata.obsm
                    if si == 0: # to avoid overwriting multiple times
                        adata.obsm[predicted] = pd.DataFrame(sub_adata.obsm[predicted][fold].values,
                                                          columns=fold,
                                                          index=adata.obs_names)
                    adata.obsm[predicted+f"_lo_{round((1-s_alpha)*100)}"] = pd.DataFrame(sub_adata.obsm[predicted+"_lo"][fold].values,
                                                      columns=fold,
                                                      index=adata.obs_names)
                    adata.obsm[predicted+f"_hi_{round((1-s_alpha)*100)}"] = pd.DataFrame(sub_adata.obsm[predicted+"_hi"][fold].values,
                                                      columns=fold,
                                                      index=adata.obs_names)
                else:
                    # later folds add their gene columns to the existing tables
                    # NOTE(review): these are assignments on the object returned by adata.obsm[...]; whether they
                    # write through to adata depends on anndata/pandas copy behavior -- confirm
                    if si == 0: # to avoid overwriting multiple times
                        adata.obsm[predicted][fold] = sub_adata.obsm[predicted][fold].values.copy()
                    adata.obsm[predicted+f"_lo_{round((1-s_alpha)*100)}"][fold] = sub_adata.obsm[predicted+"_lo"][fold].values.copy()
                    adata.obsm[predicted+f"_hi_{round((1-s_alpha)*100)}"][fold] = sub_adata.obsm[predicted+"_hi"][fold].values.copy()
                
            del sub_adata
            gc.collect()

        # add results
        # turn the accumulated sums into gene-weighted averages across folds
        calibration_ci = calibration_ci / calibration_weight
        test_ci = test_ci / test_weight

        res_dict[method]['1-alpha'] = 1-alpha_levels
        res_dict[method]['calibration'] = calibration_ci
        res_dict[method]['test'] = test_ci
        
    return(res_dict, adata)


def measure_calibration_error (res_dict, key, method="average"):
    '''
    Scores the calibration results from the res_dict object (dictionary output of group_conformalize_from_intermediate())
    
    The score is the area (trapezoidal rule over 1-alpha) between the observed coverage and the ideal coverage 1-alpha
    
    Parameters
    ----------
        res_dict [python dict]
        key [str] - key to access for scoring (i.e. the model name)
        method [str] - "gene" reports the mean (ignoring NaN) of the per-gene scores computed from
                       res_dict[key]['ind_gene_results']; any other value (default "average") scores
                       the overall res_dict[key]['test'] curve
        
    Returns
    -------
        score [float] - score for calibration error (lower is better)
    '''        
    # np.trapz is deprecated in NumPy 2.0 in favor of np.trapezoid; use whichever is available
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    
    if method == "gene":    
        auc_diffs = []
            
        for gene_results in res_dict[key]['ind_gene_results'].values():
            expected = np.asarray(gene_results['1-alpha'])
            diff = np.abs(np.asarray(gene_results['test']) - expected)
            auc_diffs.append(integrate(y=diff, x=expected))
                
        # abs() because the integral is negative when 1-alpha is given in decreasing order
        score = np.nanmean(np.abs(auc_diffs))
        
    else:
        expected = np.asarray(res_dict[key]['1-alpha'])
        diff = np.abs(np.asarray(res_dict[key]['test']) - expected)
        score = np.abs(integrate(y=diff, x=expected))
    
    return (score)


def group_multiple_imputation_testing_from_intermediate(dataset_name, methods, symmetric, condition, n_imputations=100,
                                                        group1=None, group2=None, savedir="SCPI", test="ttest"):
    '''
    Function for taking intermediate fold predictions and running multiple imputation t-tests
    
    Returns AnnData object with all test results saved in adata.var
    
    Parameters
    ----------
        dataset_name [str] - name of folder in DataUpload/
        methods [list of str] - list of method keys to use for prediction_sets
        symmetric [bool] - whether to use symmetric prediction intervals
        condition [str] - key in adata.obs to use for testing
        n_imputations [int] - number of multiple imputations
        group1 [None or str] - value in condition to use for group1 (if None, then will get results for all unique values)
        group2 [None or str] - value in condition to use for group2 (if None, then will use all other values as group2)
        savedir [str] - folder where the intermediate results are saved (independent folds)
                      - reads {savedir}/{dataset_name}_intermediate/{method}/folds.npy and the fold{i} entries next to it
        test [str] - passed through unchanged as the `test` argument of multiple_imputation_testing()
        
    Returns
    -------
        adata [AnnData] - anndata with testing results added to metadata
                        - adata.var gets one column per key returned by multiple_imputation_testing()
                        - adata.obsm[{method}_predicted_expression] holds the predictions of each fold's own genes
    ''' 
    # read in spatial data
    # metadata file is optional; data always come from DataUpload/{dataset_name}/
    if os.path.isfile("DataUpload/"+dataset_name+"/Metadata.txt"):
        adata = load_spatial_data("DataUpload/"+dataset_name+"/Spatial_count.txt",
                                  "DataUpload/"+dataset_name+"/Locations.txt",
                                   spatial_metadata = "DataUpload/"+dataset_name+"/Metadata.txt")
    else:
        adata = load_spatial_data("DataUpload/"+dataset_name+"/Spatial_count.txt",
                                  "DataUpload/"+dataset_name+"/Locations.txt")
    # gene names are lower-cased before matching against the fold gene names
    adata.var_names = [x.lower() for x in adata.var_names]
    
    for method in methods:

        dirpath = savedir+"/"+dataset_name+"_intermediate/"+method
        
        # folds.npy holds one collection of gene names per fold (same directory as dirpath)
        folds = np.load(os.path.join(savedir+"/"+dataset_name+"_intermediate/"+method,"folds.npy"), allow_pickle=True)

        # subset spatial data into shared genes
        # NOTE(review): adata is re-bound here, so with several methods each subset is taken from the previous one
        gene_names = np.concatenate(folds)
        adata = adata[:, gene_names]

        for i, fold in enumerate(folds):

            # load adata within fold
            sub_adata = large_load(os.path.join(dirpath, "fold"+str(i)))
            target_genes = list(fold)

            # subset data
            # test genes = genes of this fold; calibration genes = all remaining genes
            predicted = method+"_predicted_expression"
            test_genes = target_genes.copy()
            calib_genes = [gene for gene in gene_names if gene not in test_genes]
            # NOTE: test_idxs and calib_idxs are not used further in this function; computing them
            # still raises IndexError if a gene is missing from the prediction table
            test_idxs = [np.where(sub_adata.obsm[predicted].columns==gene)[0][0] for gene in test_genes]
            calib_idxs = [np.where(sub_adata.obsm[predicted].columns==gene)[0][0] for gene in calib_genes]

            # run multiple imputation test
            # keys_list names the sub_adata.uns entries that hold the results
            keys_list = multiple_imputation_testing (sub_adata, predicted, calib_genes, condition, n_imputations=n_imputations,
                                                     group1=group1, group2=group2, symmetric=symmetric, return_keys=True, test=test)
            
            if i == 0:
                # first fold creates zero-filled result columns and the prediction table
                for key in keys_list:
                    adata.var[key] = np.zeros(adata.shape[1])
                    
                adata.obsm[predicted] = pd.DataFrame(sub_adata.obsm[predicted][fold].values,
                                                  columns=fold,
                                                  index=adata.obs_names)
            for key in keys_list:
                # copy this fold's results into the rows of adata.var belonging to its genes (first match per gene)
                # NOTE(review): this is a chained assignment (adata.var[key][...] = ...) with integer positions;
                # whether it writes through to adata.var depends on the pandas version / copy-on-write mode -- confirm
                adata.var[key][[np.where(adata.var_names==gene)[0][0] for gene in fold]] = sub_adata.uns[key][fold].values.flatten()
                # the prediction columns do not depend on key, so this line repeats the same assignment for every key
                adata.obsm[predicted][fold] = sub_adata.obsm[predicted][fold].values.copy()
                
    return(adata)

# Contains main functions for core TISSUE pipeline: computing cell-centric variability and calibrated prediction intervals

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import squidpy as sq
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.model_selection import KFold, StratifiedKFold
import anndata as ad
import warnings
import os


def load_paired_datasets (spatial_counts, spatial_loc, RNAseq_counts, spatial_metadata = None,
                          min_cell_prevalence_spatial = 0.0, min_cell_prevalence_RNAseq = 0.01,
                          min_gene_prevalence_spatial = 0.0, min_gene_prevalence_RNAseq = 0.0):
    '''
    Loads a paired spatial / scRNAseq dataset stored in the text format specified by Li et al. (2022)
        See: https://drive.google.com/drive/folders/1pHmE9cg_tMcouV1LFJFtbyBJNp7oQo9J
    
    The spatial files are handled by load_spatial_data() and the RNAseq file by load_rnaseq_data().
    
    Parameters
    ----------
        spatial_counts [str] - path to spatial counts file; rows are cells
        spatial_loc [str] - path to spatial locations file; rows are cells
        RNAseq_counts [str] - path to RNAseq counts file; rows are genes
        spatial_metadata [None or str] - if not None, then path to spatial metadata file (will be read into spatial_adata.obs)
        min_cell_prevalence_spatial [float between 0 and 1] - minimum prevalence among cells to include gene in spatial anndata object, default=0
        min_cell_prevalence_RNAseq [float between 0 and 1] - minimum prevalence among cells to include gene in RNAseq anndata object, default=0.01
        min_gene_prevalence_spatial [float between 0 and 1] - minimum prevalence among genes to include cell in spatial anndata object, default=0
        min_gene_prevalence_RNAseq [float between 0 and 1] - minimum prevalence among genes to include cell in RNAseq anndata object, default=0
    
    Returns
    -------
        spatial_adata, RNAseq_adata - AnnData objects with counts and location (if applicable) in metadata
    '''
    # prevalence thresholds, grouped per modality
    spatial_filters = {"min_cell_prevalence_spatial": min_cell_prevalence_spatial,
                       "min_gene_prevalence_spatial": min_gene_prevalence_spatial}
    rnaseq_filters = {"min_cell_prevalence_RNAseq": min_cell_prevalence_RNAseq,
                      "min_gene_prevalence_RNAseq": min_gene_prevalence_RNAseq}
    
    # spatial modality first, then the RNAseq reference
    spatial_adata = load_spatial_data(spatial_counts, spatial_loc,
                                      spatial_metadata=spatial_metadata, **spatial_filters)
    RNAseq_adata = load_rnaseq_data(RNAseq_counts, **rnaseq_filters)

    return spatial_adata, RNAseq_adata


def load_spatial_data (spatial_counts, spatial_loc, spatial_metadata=None,
                       min_cell_prevalence_spatial = 0.0, min_gene_prevalence_spatial = 0.0):
    '''
    Loads in spatial data from text files.
    
    See load_paired_datasets() for details on arguments
    
    Returns
    -------
        spatial_adata [AnnData] - filtered counts with coordinates in .obsm["spatial"] and, when
            spatial_metadata is given, the metadata table in .obs
    '''
    # read in spatial counts (tab-separated)
    df = pd.read_csv(spatial_counts,header=0,sep="\t")
    
    # filter lowly expressed genes (columns)
    cells_prevalence = np.mean(df.values>0, axis=0)
    df = df.loc[:,cells_prevalence > min_cell_prevalence_spatial]
    
    # filter sparse cells (rows); the same mask is reused below so that the
    # locations and metadata rows stay aligned with the retained cells
    genes_prevalence = np.mean(df.values>0, axis=1)
    keep_cells = genes_prevalence > min_gene_prevalence_spatial
    df = df.loc[keep_cells,:]
    
    # create AnnData
    spatial_adata = ad.AnnData(X=df, dtype='float64')
    spatial_adata.obs_names = df.index.values
    spatial_adata.obs_names = spatial_adata.obs_names.astype(str)
    spatial_adata.var_names = df.columns
    del df
    
    # add spatial locations
    # sep=r"\s+" is the supported spelling of the deprecated delim_whitespace=True
    locations = pd.read_csv(spatial_loc,header=0,sep=r"\s+")
    spatial_adata.obsm["spatial"] = locations.loc[keep_cells, :].values
    
    # add metadata
    if spatial_metadata is not None:
        metadata_df = pd.read_csv(spatial_metadata)
        metadata_df = metadata_df.loc[keep_cells, :]
        metadata_df.index = spatial_adata.obs_names
        spatial_adata.obs = metadata_df
    
    # remove genes with nan values
    spatial_adata = spatial_adata[:,np.isnan(spatial_adata.X).sum(axis=0)==0].copy()
    
    # make unique obs_names and var_names
    spatial_adata.obs_names_make_unique()
    spatial_adata.var_names_make_unique()
    
    return (spatial_adata)


def load_rnaseq_data (RNAseq_counts, min_cell_prevalence_RNAseq = 0.0, min_gene_prevalence_RNAseq = 0.0):
    '''
    Reads a scRNAseq counts text file (genes x cells) into a cells x genes AnnData object.
    
    See load_paired_datasets() for details on arguments
    '''
    # counts table is tab-separated with gene names in the first column
    counts = pd.read_csv(RNAseq_counts, sep="\t", header=0, index_col=0)
    
    # drop lowly expressed genes -- rows are genes because the table is gene x cell
    gene_mask = np.mean(counts>0, axis=1) > min_cell_prevalence_RNAseq
    counts = counts.loc[gene_mask,:]
    del gene_mask
    
    # drop sparse cells (columns)
    cell_mask = np.mean(counts>0, axis=0) > min_gene_prevalence_RNAseq
    counts = counts.loc[:,cell_mask]
    del cell_mask
    
    # transpose to cell x gene and wrap in AnnData
    cells_by_genes = counts.T
    RNAseq_adata = ad.AnnData(X=cells_by_genes, dtype='float64')
    RNAseq_adata.obs_names = cells_by_genes.index.values
    RNAseq_adata.var_names = cells_by_genes.columns
    del counts, cells_by_genes
    
    # keep only genes without any nan entries
    nan_free_genes = np.isnan(RNAseq_adata.X).sum(axis=0)==0
    RNAseq_adata = RNAseq_adata[:,nan_free_genes].copy()
    
    # de-duplicate cell and gene names
    RNAseq_adata.obs_names_make_unique()
    RNAseq_adata.var_names_make_unique()
    
    return RNAseq_adata



def preprocess_data (adata, standardize=False, normalize=False):
    '''
    In-place preprocessing of adata, applied in this order:
        1. sc.pp.normalize_total() followed by sc.pp.log1p(), only if normalize is True
        2. Not recommended: per-gene standardization (subtract mean, divide by standard deviation),
           only if standardize is True
    
    Parameters
    ----------
        adata [AnnData] - object whose .X is preprocessed
        standardize [Boolean] - whether to standardize genes; default is False
        normalize [Boolean] - whether to normalize data; default is False (based on finding by Li et al., 2022)
    
    Returns
    -------
        Nothing; adata is modified in-place
    
    NOTE: Both flags are compared with "is True", so with the default settings (and with truthy
          values other than True) this method does nothing to adata
    '''
    # total-count normalization followed by log1p
    if normalize is True:
        sc.pp.normalize_total(adata)
        sc.pp.log1p(adata)
    
    # z-score every gene (column) of the expression matrix
    if standardize is True:
        centered = adata.X - np.mean(adata.X, axis=0)
        gene_sd = np.std(adata.X, axis=0)
        adata.X = np.divide(centered, gene_sd)


def _default_neighbor_radius (adata):
    '''
    Returns the default pruning radius: Q3 + 1.5*IQR of the nonzero entries of adata.obsp['spatial_distances']
    '''
    distances = adata.obsp['spatial_distances']
    if isinstance(distances, np.ndarray): # numpy array
        dists = distances[np.nonzero(distances)] # get nonzero array
    else: # sparse matrix
        distances.eliminate_zeros() # remove hard-set zeros (in place)
        dists = distances.data # get non-zero values in sparse matrix
    q1, q3 = np.percentile(dists, [25, 75])
    return q3 + 1.5*(q3 - q1)


def _prune_edges_by_radius (adata, radius):
    '''
    Zeroes every edge longer than radius in both adata.obsp['spatial_connectivities'] and adata.obsp['spatial_distances']
    '''
    # the mask must be taken before the distances themselves are overwritten
    too_far = adata.obsp['spatial_distances'] > radius
    adata.obsp['spatial_connectivities'][too_far] = 0
    adata.obsp['spatial_distances'][too_far] = 0


def build_spatial_graph (adata, method="fixed_radius", spatial="spatial", radius=None, n_neighbors=20, set_diag=True):
    '''
    Builds a spatial graph from AnnData according to specifications. Uses Squidpy implementations for building spatial graphs.
    
    Parameters
    ----------
        adata [AnnData] - spatial data, must include adata.obsm[spatial]
        method [str]:
            - "radius" (all cells within radius are neighbors)
            - "delaunay" (triangulation)
            - "delaunay_radius" (triangulation with pruning by max radius)
            - "fixed" (the k-nearest cells are neighbors determined by n_neighbors)
            - "fixed_radius" (knn by n_neighbors with pruning by max radius; DEFAULT)
        spatial [str] - key in adata.obsm to retrieve spatial coordinates (passed to Squidpy as spatial_key)
        radius [None or float/int] - radius around cell centers for which to detect neighbor cells; defaults to Q3+1.5*IQR of delaunay (or fixed for fixed_radius) neighbor distances
        n_neighbors [None or int] - number of neighbors to get for each cell (if method is "fixed" or "fixed_radius"); defaults to 20
        set_diag [True or False] - whether to have diagonal of 1 in adjacency (before normalization); False is identical to theory and True is more robust; defaults to True
    
    Returns
    -------
        Modifies adata in-place
    
    Raises
    ------
        Exception - if method is not one of the options listed above
    '''
    # delaunay graph
    if method == "delaunay": # triangulation only
        sq.gr.spatial_neighbors(adata, spatial_key=spatial, delaunay=True, coord_type="generic", set_diag=set_diag)
    
    # neighborhoods determined by fixed radius
    elif method == "radius":
        if radius is None: # default radius is derived from a throwaway delaunay triangulation
            sq.gr.spatial_neighbors(adata, spatial_key=spatial, delaunay=True, coord_type="generic")
            radius = _default_neighbor_radius(adata)
        # build graph
        sq.gr.spatial_neighbors(adata, spatial_key=spatial, radius=radius, coord_type="generic", set_diag=set_diag)
    
    # delaunay / fixed-size graph with removal of outlier edges with distance > radius
    elif method in ("delaunay_radius", "fixed_radius"):
        # build initial graph
        if method == "delaunay_radius":
            sq.gr.spatial_neighbors(adata, spatial_key=spatial, delaunay=True, coord_type="generic", set_diag=set_diag)
        else:
            sq.gr.spatial_neighbors(adata, spatial_key=spatial, n_neighs=n_neighbors, coord_type="generic", set_diag=set_diag)
        if radius is None: # compute default radius as 75th percentile + 1.5*IQR
            radius = _default_neighbor_radius(adata)
        # prune edges by radius
        _prune_edges_by_radius(adata, radius)
            
    # fixed neighborhood size
    elif method == "fixed":
        sq.gr.spatial_neighbors(adata, spatial_key=spatial, n_neighs=n_neighbors, coord_type="generic", set_diag=set_diag)
            
    else:
        raise Exception ("method not recognized")


def load_spatial_graph(adata, npz_filepath, add_identity=True):
    '''
    Loads a scipy sparse adjacency matrix saved at npz_filepath and stores it in adata.obsp["spatial_connectivities"]
    
    Parameters
    ----------
        adata - object whose .obsp receives the adjacency matrix
        npz_filepath [str] - path to the .npz file holding the sparse adjacency matrix
        add_identity [bool] - whether to add a diagonal of 1's to ensure compatibility with TISSUE (i.e. every cell is connected to itself)
    
    Returns
    -------
        Modifies adata in-place
    
    If graph is weighted, then you should set weight="spatial_connectivities" in downstream TISSUE calls for cell-centric variability calculation
    '''
    from scipy import sparse
    
    adjacency = sparse.load_npz(npz_filepath)
    
    # optionally put ones on the diagonal
    if add_identity is True:
        n_cells = adjacency.shape[0]
        adjacency += sparse.identity(n_cells)

    adata.obsp["spatial_connectivities"] = adjacency
    
    print("If graph is weighted, then you should set weight='spatial_connectivities' in downstream call of conformalize_spatial_uncertainty()")
    

def _impute_with_method (method, spatial_adata, RNAseq_adata, genes_to_predict, **kwargs):
    '''
    Runs the imputation backend named by method and returns its DataFrame of predicted expression
    '''
    if method == "knn":
        return knn_impute(spatial_adata,RNAseq_adata,genes_to_predict=genes_to_predict,**kwargs)
    elif method == "spage":
        return spage_impute(spatial_adata,RNAseq_adata,genes_to_predict=genes_to_predict,**kwargs)
    elif method == "gimvi":
        return gimvi_impute(spatial_adata,RNAseq_adata,genes_to_predict=genes_to_predict,**kwargs)
    elif method == "tangram":
        return tangram_impute(spatial_adata,RNAseq_adata,genes_to_predict=genes_to_predict,**kwargs)
    raise Exception ("method not recognized")


def predict_gene_expression (spatial_adata, RNAseq_adata,
                             target_genes, conf_genes=None,
                             method="spage", n_folds=None, random_seed=444, **kwargs):
    '''
    Leverages one of several methods to predict spatial gene expression from a paired spatial and scRNAseq dataset
    
    Parameters
    ----------
        spatial_adata [AnnData] = spatial data
        RNAseq_adata [AnnData] = RNAseq data, RNAseq_adata.var_names should be superset of spatial_adata.var_names
        target_genes [list of str] = genes to predict spatial expression for; must be a subset of RNAseq_adata.var_names
        conf_genes [list of str] = genes in spatial_adata.var_names to use for confidence measures; Default is to use all genes in spatial_adata.var_names
        method [str] = baseline imputation method
            "knn" (uses average of k-nearest neighbors in RNAseq data on Harmony joint space)
            "spage" (SpaGE imputation by Abdelaal et al., 2020)
            "gimvi" (gimVI imputation)
            "tangram" (Tangram cell positioning by Biancalani et al., 2021)
        n_folds [None or int] = number of cv folds to use for conf_genes; None keeps each gene in its own fold; values above the number of conf_genes raise a warning and are clamped to it
        random_seed [int] = used to seed the shuffling of conf_genes into folds (defaults to 444); note this seeds the global numpy RNG
        **kwargs = forwarded unchanged to the selected imputation function
    
    Returns
    -------
        Adds to spatial_adata:
            - spatial_adata.obsm[method+"_predicted_expression"] [DataFrame]: predicted expression with one row per
              cell; columns are target_genes followed by the conf_genes (cross-validated predictions)
            - spatial_adata.uns["conf_genes_used"], spatial_adata.uns["target_genes_used"]: the gene lists used
    
    Raises
    ------
        Exception - if method is not recognized or no suitable conf_genes remain
    '''
    # fail fast on an unknown method, before any input is modified
    if method not in ("knn", "spage", "gimvi", "tangram"):
        raise Exception ("method not recognized")
    
    # change all genes to lower
    target_genes = [t.lower() for t in target_genes]
    spatial_adata.var_names = [v.lower() for v in spatial_adata.var_names]
    RNAseq_adata.var_names = [v.lower() for v in RNAseq_adata.var_names]
    
    # drop duplicates if any (happens in Dataset14)
    # NOTE(review): each branch rebinds the local name to a copy, so when duplicates exist
    # the results written below land on that copy rather than on the caller's object
    if RNAseq_adata.var_names.duplicated().sum() > 0:
        RNAseq_adata = RNAseq_adata[:,~RNAseq_adata.var_names.duplicated()].copy()
    if spatial_adata.var_names.duplicated().sum() > 0:
        spatial_adata = spatial_adata[:,~spatial_adata.var_names.duplicated()].copy()
    
    # raise warning if any target_genes in spatial data already
    if any(x in target_genes for x in spatial_adata.var_names):
        warnings.warn("Some target_genes are already measured in the spatial_adata object!")
    
    # First pass over all genes using specified method
    predicted_expression_target = _impute_with_method(method, spatial_adata, RNAseq_adata, target_genes, **kwargs)
        
    # Second pass over conf_genes using specified method using cross-validation
    
    if conf_genes is None:
        conf_genes = list(spatial_adata.var_names)
    conf_genes = [c.lower() for c in conf_genes]
    conf_genes_unique = [c for c in conf_genes if c not in target_genes] # removes any conf_genes also in target_genes
    if len(conf_genes_unique) < len(conf_genes):
        print("Found "+str(len(conf_genes)-len(conf_genes_unique))+" duplicate conf_gene in target_genes.")
    conf_genes_RNA = [c for c in conf_genes_unique if c in RNAseq_adata.var_names] # remove any conf genes not in RNAseq
    if len(conf_genes_RNA) < len(conf_genes_unique):
        print("Found "+str(len(conf_genes_unique)-len(conf_genes_RNA))+" conf_gene not in RNAseq_adata.")
    conf_genes = conf_genes_RNA
    
    # raise error if no conf_genes
    if len(conf_genes) == 0:
        raise Exception ("No suitable conf_genes specified!")
    
    # create folds if needed
    if n_folds is None:
        n_folds = len(conf_genes)
    elif n_folds > len(conf_genes):
        # warn (rather than raise) so that the clamp below actually takes effect
        warnings.warn("n_folds in predict_gene_expression() is greater than length of conf_genes...")
        n_folds = len(conf_genes)

    np.random.seed(random_seed)
    np.random.shuffle(conf_genes)
    folds = np.array_split(conf_genes, n_folds)
    
    # run prediction on each fold, holding the fold's genes out of the spatial data
    for gi, fold in enumerate(folds):
        held_out_adata = spatial_adata[:,~spatial_adata.var_names.isin(fold)]
        loo_expression = _impute_with_method(method, held_out_adata, RNAseq_adata,
                                             list(fold)+target_genes, **kwargs)
    
        # Update 
        if gi == 0:
            predicted_expression_conf = loo_expression.copy()
        else:
            # row-position key so that rows are matched by position; summing skips the NaNs
            # that concat introduces for genes absent from one of the two frames
            predicted_expression_conf['index'] = range(predicted_expression_conf.shape[0])
            loo_expression['index'] = range(loo_expression.shape[0])
            predicted_expression_conf = pd.concat((predicted_expression_conf,loo_expression)).groupby(by="index").sum().reset_index().drop(columns=['index'])
    
    # Take average of target_genes, which were predicted (and summed) once per fold
    # (later overwritten by "all genes"-predicted)
    predicted_expression_conf[target_genes] = predicted_expression_conf[target_genes]/(len(folds))
    
    # Update spatial_adata
    predicted_expression_target.index = spatial_adata.obs_names
    predicted_expression_conf.index = spatial_adata.obs_names

    # gets predictions for target genes followed by conf genes
    predicted_expression_target[conf_genes] = predicted_expression_conf[conf_genes].copy()
    spatial_adata.obsm[method+"_predicted_expression"] = predicted_expression_target
    
    spatial_adata.uns["conf_genes_used"] = conf_genes
    spatial_adata.uns["target_genes_used"] = target_genes


def knn_impute (spatial_adata, RNAseq_adata, genes_to_predict, n_neighbors, **kwargs):
    '''
    Runs basic kNN imputation using Harmony subspace
    
    Each spatial cell's prediction is the unweighted mean expression of its n_neighbors nearest
    RNAseq cells (ties at the k-th distance are all included), with distances measured on at most
    the first 30 Harmony-corrected principal components.
    
    See predict_gene_expression() for details on arguments
    
    Returns
    -------
        predicted_expression [DataFrame] - one row per spatial cell, columns are genes_to_predict
    '''
    from scanpy.external.pp import harmony_integrate
    from scipy.spatial.distance import cdist
    
    # combine anndatas
    intersection = np.intersect1d(spatial_adata.var_names, RNAseq_adata.var_names)
    subRNA = RNAseq_adata[:, intersection]
    subspatial = spatial_adata[:, intersection]
    joint_adata = ad.AnnData(X=np.vstack((subRNA.X,subspatial.X)), dtype='float32')
    joint_adata.obs_names = np.concatenate((subRNA.obs_names.values,subspatial.obs_names.values))
    joint_adata.var_names = subspatial.var_names.values
    joint_adata.obs["batch"] = ["rna"]*len(subRNA.obs_names.values)+["spatial"]*len(spatial_adata.obs_names.values)
    
    # run Harmony
    sc.tl.pca(joint_adata)
    harmony_integrate(joint_adata, 'batch', verbose=False)
    
    # pairwise distances (spatial cells x RNAseq cells) on the leading Harmony components
    embedding = joint_adata.obsm['X_pca_harmony']
    n_dims = min(30, embedding.shape[1])
    is_spatial = (joint_adata.obs["batch"] == "spatial").values
    is_rna = (joint_adata.obs["batch"] == "rna").values
    knn_mat = cdist(embedding[is_spatial][:, :n_dims], embedding[is_rna][:, :n_dims])
    
    # kNN imputation: a boolean neighbor mask is used instead of overwriting distances with 0/1,
    # so that a neighbor at distance exactly 0 is still counted and row sums can never be zero
    k_dist_threshold = np.sort(knn_mat)[:, n_neighbors-1]
    neighbor_mask = knn_mat <= k_dist_threshold[:,np.newaxis]
    knn_weights = neighbor_mask / neighbor_mask.sum(axis=1, keepdims=True)
    predicted_expression = knn_weights @ RNAseq_adata.X
    
    predicted_expression = pd.DataFrame(predicted_expression, columns=RNAseq_adata.var_names.values)
    predicted_expression = predicted_expression[genes_to_predict]
    
    return(predicted_expression)
    
    
def spage_impute (spatial_adata, RNAseq_adata, genes_to_predict, **kwargs):
    '''
    Predicts expression of genes_to_predict with SpaGE; extra keyword arguments are forwarded to SpaGE()
    
    See predict_gene_expression() for details on arguments
    '''
    from .SpaGE.main import SpaGE
    
    def _genes_by_cells_frame (adata):
        # dense (genes x cells) DataFrame indexed by gene name; non-ndarray X is densified via toarray()
        transposed = adata.X.T
        if not isinstance(adata.X, np.ndarray):
            transposed = transposed.toarray()
        frame = pd.DataFrame(transposed)
        frame.index = adata.var_names.values
        return frame
    
    # transform both modalities into the SpaGE input data format
    spatial_data = _genes_by_cells_frame(spatial_adata)
    RNAseq_data = _genes_by_cells_frame(RNAseq_adata)
    
    # SpaGE is called with the frames transposed back (gene names as columns)
    return SpaGE(spatial_data.T, RNAseq_data.T, genes_to_predict=genes_to_predict, **kwargs)


def tangram_impute (spatial_adata, RNAseq_adata, genes_to_predict, **kwargs):
    '''
    Predicts expression of genes_to_predict with Tangram in cluster mode, using Leiden clusters
    computed here on the highly variable genes of the RNAseq data
    
    Side effects: writes RNAseq_adata.obs['leiden'] and passes both AnnData objects through tg.pp_adatas()
    
    See predict_gene_expression() for details on arguments
    '''
    import torch
    from torch.nn.functional import softmax, cosine_similarity, sigmoid
    import tangram as tg
    
    cluster_key = 'leiden'
    
    # flag highly variable genes on a scratch copy so RNAseq_adata.var is left alone
    hvg_probe = RNAseq_adata.copy()
    sc.pp.highly_variable_genes(hvg_probe)
    clustering_adata = RNAseq_adata[:, hvg_probe.var.highly_variable].copy()
    
    # scale -> PCA -> neighbors -> Leiden on the highly variable subset
    sc.pp.scale(clustering_adata, max_value=10)
    sc.tl.pca(clustering_adata)
    sc.pp.neighbors(clustering_adata)
    sc.tl.leiden(clustering_adata, resolution = 0.5)
    
    # carry the cluster labels back onto the full RNAseq object
    RNAseq_adata.obs[cluster_key] = clustering_adata.obs.leiden
    del clustering_adata
    tg.pp_adatas(RNAseq_adata, spatial_adata) # genes=None default using all genes shared between two data
    
    # map clusters to space, then project genes onto the spatial cells
    ad_map = tg.map_cells_to_space(RNAseq_adata, spatial_adata, mode='clusters', cluster_label=cluster_key, density_prior='rna_count_based', verbose=False)
    ad_ge = tg.project_genes(ad_map, RNAseq_adata, cluster_label=cluster_key)
    projected = ad_ge[:,genes_to_predict]
    
    return pd.DataFrame(projected.X, index=projected.obs_names, columns=projected.var_names)


def gimvi_impute (spatial_adata, RNAseq_adata, genes_to_predict, **kwargs):
    '''
    Run gimVI gene imputation
    
    Cells whose total expression is not greater than 1 are excluded from model training; excluded
    spatial cells receive a prediction of 0 for every gene. Extra keyword arguments are forwarded
    to the GIMVI constructor.
    
    See predict_gene_expression() for details on arguments
    
    Returns
    -------
        predicted_expression [DataFrame] - one row per spatial cell (original order), columns are genes_to_predict
    '''
    from scvi.external import GIMVI
    
    # preprocessing of data
    spatial_adata = spatial_adata[:, spatial_adata.var_names.isin(RNAseq_adata.var_names)].copy()
    rna_gene_names = list(RNAseq_adata.var_names) # built once instead of once per predicted gene
    predict_idxs = [rna_gene_names.index(gene) for gene in genes_to_predict]
    spatial_dim0 = spatial_adata.shape[0]
    
    # indices for filtering out zero-expression cells
    # np.asarray(...).ravel() yields a 1d mask for dense arrays and for sparse matrices alike
    # (a sparse row sum is otherwise a 2d (n,1) matrix, which cannot be used as a cell mask)
    filtered_cells_spatial = np.asarray(spatial_adata.X.sum(axis=1)).ravel() > 1
    filtered_cells_RNAseq = np.asarray(RNAseq_adata.X.sum(axis=1)).ravel() > 1
    
    # make copies of subsets
    spatial_adata = spatial_adata[filtered_cells_spatial,:].copy()
    RNAseq_adata = RNAseq_adata[filtered_cells_RNAseq,:].copy()
    
    # setup anndata for scvi
    GIMVI.setup_anndata(spatial_adata)
    GIMVI.setup_anndata(RNAseq_adata)
    
    # train gimVI model
    model = GIMVI(RNAseq_adata, spatial_adata, generative_distributions=['nb', 'nb'], **kwargs) # 'nb' tends to be less buggy
    model.train(200)
    
    # apply trained model for imputation; the second returned element is used as the spatial imputation
    _, imputation = model.get_imputed_values(normalized=False)
    imputed = imputation[:, predict_idxs]
    
    # scatter predictions back to the full set of spatial cells (filtered-out cells stay at 0)
    predicted_expression = np.zeros((spatial_dim0, imputed.shape[1]))
    predicted_expression[filtered_cells_spatial,:] = imputed
    predicted_expression = pd.DataFrame(predicted_expression, columns=genes_to_predict)
    
    return(predicted_expression)

    
def conformalize_spatial_uncertainty (adata, predicted, calib_genes, weight='exp_cos', add_one=True,
                                      grouping_method=None, k='auto', k2='auto', n_pc=None, n_pc2=None, weight_n_pc=10):
    '''
    Generates cell-centric variability and then performs stratified grouping and conformal score calculation
    
    Parameters
    ----------
        adata - AnnData object with adata.obsm[predicted] and adata.obsp['spatial_connectivities']
        predicted [str] - string corresponding to key in adata.obsm that contains the predicted transcript expression
        calib_genes [list or np.1darray] - strings corresponding to the genes to use in calibration
        weight [str] - weights to use when computing spatial variability (either 'exp_cos' or 'spatial_connectivities')
        add_one [bool] - whether to add an intercept term of one to the spatial standard deviation
        grouping_method [None or str] - if None, every entry is placed in a single group (group index 0)
        weight_n_pc [None or int] - if not None, then specifies number of top principal components to use for weight calculation if weight is 'exp_cos' (default is 10)
        For grouping_method [str], k [int>0 or 'auto'], k2 [None or int>0 or 'auto'], n_pc [None or int>0], n_pc2 [None or int>0]; refer to get_grouping()
    
    Returns
    -------
        Saves the uncertainty in adata.obsm[predicted+"_uncertainty"]
        Saves the scores in adata.obsm[predicted+"_score"]
        Saves the prediction errors in adata.obsm[predicted+"_error"]
        Saves the group assignments in adata.obsm[predicted+"_groups"]
        Saves the group numbers used in adata.uns[predicted+"_kg"] and adata.uns[predicted+"_kc"]
    '''
    # get spatial uncertainty and add to annotations
    scores, residuals, G_stdev, G = get_spatial_uncertainty_scores(adata, predicted, calib_genes,
                                                                   weight=weight,
                                                                   add_one=add_one,
                                                                   weight_n_pc=weight_n_pc)
    
    adata.obsm[predicted+"_uncertainty"] = pd.DataFrame(G_stdev,
                                                        columns=adata.obsm[predicted].columns,
                                                        index=adata.obsm[predicted].index)
    adata.obsm[predicted+"_score"] = pd.DataFrame(scores,
                                                  columns=calib_genes,
                                                  index=adata.obsm[predicted].index)
    adata.obsm[predicted+"_error"] = pd.DataFrame(residuals,
                                                  columns=calib_genes,
                                                  index=adata.obsm[predicted].index)                                              
        
    # define group
    if grouping_method is None:
        # no stratification: one group for everything, so both group counts are 1
        # (without these assignments the default call fails on the unbound names below)
        groups = np.zeros(G.shape)
        k_final = 1
        k2_final = 1
    else:
        groups, k_final, k2_final = get_grouping(G, method=grouping_method, k=k, k2=k2, n_pc=n_pc, n_pc2=n_pc2)
    
    # add grouping and k-values to anndata
    adata.obsm[predicted+"_groups"] = groups
    adata.uns[predicted+"_kg"] = k_final
    adata.uns[predicted+"_kc"] = k2_final
    

def get_spatial_uncertainty_scores (adata, predicted, calib_genes, weight='exp_cos',
                                    add_one=True, weight_n_pc=None):
    '''
    Computes spatial uncertainty scores (i.e. cell-centric variability)
    
    Parameters
    ----------
        adata - AnnData object with adata.obsm[predicted] and adata.obsp['spatial_connectivities']
                (sparse; every cell must be connected to itself, e.g. graph built with set_diag=True)
        predicted [str] - string corresponding to key in adata.obsm that contains the predicted transcript expression
        calib_genes [list or np.1darray] - strings corresponding to the genes to use in calibration
        weight [str] - weights to use when computing spatial variability (either 'exp_cos' or 'spatial_connectivities')
                     - 'spatial_connectivities' will use values in adata.obsp['spatial_connectivities']
        add_one [bool] - whether to add one to the uncertainty
        weight_n_pc [None or int] - if not None, then specifies number of top principal components to use for weight calculation if weight is 'exp_cos' (default is None)
        
    Returns
    -------
        scores - spatial uncertainty scores for all calib_genes
        residuals - prediction errors matching scores dimensions
        G_stdev - spatial standard deviations measured; same shape as adata.obsm[predicted]
        G - adata.obsm[predicted].values
    
    Raises
    ------
        Exception - if weight is not recognized or 'spatial_connectivities' is missing from adata.obsp
        IndexError - if a cell is not contained in its own neighborhood
    '''
    if weight not in ["exp_cos", "spatial_connectivities"]:
        raise Exception('weight not recognized')
    
    if 'spatial_connectivities' not in adata.obsp.keys():
        raise Exception ("'spatial_connectivities' not found in adata.obsp and is required")
    
    # init prediction array and uncertainties array
    A = adata.obsp['spatial_connectivities']
    A.eliminate_zeros() # in place; only stored nonzeros count as neighbors below
    G = adata.obsm[predicted].values.copy()
    G_stdev = np.zeros_like(G)
    
    # init for exp_cos weighting
    if weight == "exp_cos":
        from sklearn.metrics.pairwise import cosine_similarity
        if weight_n_pc is not None: # perform PCA first and then compute cosine weights from PCs
            G_pca = StandardScaler().fit_transform(G)
            G_pca = PCA(n_components=weight_n_pc, random_state=444).fit_transform(G_pca)
    
    # compute cell-centric variability
    for i in range(G.shape[0]): # iterate cells
        
        # get its neighbors only
        cell_idxs = np.nonzero(A[i,:])[1]
        center_matches = np.where(cell_idxs==i)[0]
        if len(center_matches) == 0:
            raise IndexError("cell "+str(i)+" is not in its own neighborhood; "
                             "adata.obsp['spatial_connectivities'] needs a nonzero diagonal (e.g. set_diag=True)")
        c_idx = center_matches[0] # center idx in subsetted array
        
        # compute weights for cell neighbors
        if weight == "exp_cos": # use TISSUE cosine similarity weighting
            if weight_n_pc is not None: # cosine weights from PCs
                cos_weights = cosine_similarity(G_pca[i,:].reshape(1,-1), G_pca[cell_idxs,:])
            else: # compute cosine weights from gene expression
                cos_weights = cosine_similarity(G[i,:].reshape(1,-1), G[cell_idxs,:])
            weights = np.exp(cos_weights).flatten()
        
        else: # weight == "spatial_connectivities" (validated above): use preset weights
            weights = A[i,cell_idxs].toarray().flatten()
            weights[np.isnan(weights)] = 0
        
        # compute CCV for each gene from the expression of the cell and its neighbors
        nA_std = np.array([cell_centered_variability(G[cell_idxs,j], weights=weights, c_idx=c_idx)
                           for j in range(G.shape[1])])
        
        # add one if specified
        if add_one is True:
            nA_std += 1
        
        # update G_stdev with uncertainties
        G_stdev[i,:] = nA_std
    
    # compute scores based on confidence genes (prediction residuals)
    calib_idxs = [np.where(adata.obsm[predicted].columns==gene)[0][0] for gene in calib_genes]
    residuals = adata[:, calib_genes].X - adata.obsm[predicted][calib_genes].values # Y-G
    
    # suppress RuntimeWarning for division by zero only for this statement; catch_warnings
    # restores the caller's warning filters afterwards instead of resetting them to "default"
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        scores = np.abs(residuals) / G_stdev[:, calib_idxs] # scores
    
    return(scores, residuals, G_stdev, G)


def cell_centered_variability (values, weights, c_idx):
    '''
    Computes the cell-centric variability (CCV) of a neighborhood: the square root of the
    weighted mean squared deviation of the finite values from the center cell's own value.
    
    Parameters
    ----------
        values [1d arr] - expression of one gene for the center cell and its neighbors;
                          non-finite entries (nan/inf) are left out of the weighted mean
        weights [1d arr] - same length as values; weight of each entry in the weighted mean
        c_idx [int] - position in values that holds the center cell
        
    Returns
    -------
        ccv [float] - cell-centric variability
    '''
    finite_mask = np.isfinite(values)
    
    # deviations are measured from the center cell's value rather than from a mean
    center_value = values[c_idx]
    squared_deviations = (values[finite_mask] - center_value)**2
    
    return(np.sqrt(np.average(squared_deviations, weights=weights[finite_mask])))


def get_spatial_uncertainty_scores_from_metadata(adata, predicted):
    '''
    Reads back the precomputed calibration quantities stored in adata.obsm and returns
    them as independent numpy arrays (each one is a copy, so edits do not touch adata).
    
    The entries are looked up under the keys predicted+"_score", predicted+"_error",
    predicted+"_uncertainty", predicted and predicted+"_groups"; a missing key raises KeyError.
    
    Parameters
    ----------
        adata [AnnData] - object that has saved results in obsm
        predicted [str] - key for predictions in obsm
        
    Returns
    -------
        scores - array stored under predicted+"_score" (calibration scores)
        residuals - array stored under predicted+"_error" (prediction errors)
        G_stdev - array stored under predicted+"_uncertainty" (cell-centric variability)
        G - array stored under predicted (predicted expression)
        groups - array stored under predicted+"_groups" (group assignments)
    '''
    def _stored_array(suffix):
        # np.array(...) converts DataFrames; .copy() guarantees the result is detached from adata
        return np.array(adata.obsm[predicted+suffix]).copy()
    
    return(_stored_array("_score"),
           _stored_array("_error"),
           _stored_array("_uncertainty"),
           _stored_array(""),
           _stored_array("_groups"))


def get_grouping(G, method, k='auto', k2='auto', min_samples=5, n_pc=None, n_pc2=None):
    '''
    Stratifies the entries of the predicted expression matrix G (rows=cells, cols=genes)
    into groups: genes are clustered first (KMeans on the standardized columns of G), and
    then, within each gene cluster, cells are clustered (KMeans on the standardized
    sub-matrix of that gene cluster). Every (cell cluster, gene cluster) block receives
    its own integer label.
    
    Parameters
    ----------
        G [numpy matrix/array] - predicted gene expression; columns are genes
        method [str] - only 'kmeans_gene_cell' is implemented; any other value raises Exception
        k [int or 'auto'] - number of gene clusters; 'auto' selects it with get_best_k() from [2,3,4];
                  if <=1 then all genes are placed in a single cluster
        k2 [int or 'auto'] - number of cell clusters within each gene cluster; 'auto' selects it with
                  get_best_k() on the full standardized G; if <=1 then all cells share one cluster
        min_samples [int] - accepted for interface compatibility; not used by this function
        n_pc and n_pc2 [None or int] - number of PCs to reduce to before the gene (n_pc) and
                  cell (n_pc2) KMeans clustering; None skips the PCA step
        
    Returns
    -------
        groups [numpy array] - same dimension as G with values corresponding to group number
        k - the number of gene clusters that was used (resolved value if 'auto' was given)
        k2 - the number of cell clusters that was used (resolved value if 'auto' was given)
    '''
    # only the gene-then-cell stratification exists; reject everything else up front
    if method != "kmeans_gene_cell":
        raise Exception("method for get_grouping() is not recognized")
    
    # candidate cluster numbers tried by the 'auto' searches
    auto_k_candidates = [2,3,4]
    
    ### Gene grouping (genes are the samples, hence the transpose)
    gene_features = StandardScaler().fit_transform(G.T)
    if n_pc is not None:
        gene_features = PCA(n_components=n_pc, random_state=444).fit_transform(gene_features)
    if k == 'auto':
        k = get_best_k(gene_features, auto_k_candidates)
    if k > 1:
        cluster_genes = KMeans(n_clusters=k, random_state=444).fit(gene_features).labels_
    else:
        cluster_genes = np.zeros(gene_features.shape[0])
    
    ### Cell grouping
    # the 'auto' choice of k2 is made once on all genes and reused for every gene cluster
    if k2 == 'auto':
        cell_features = StandardScaler().fit_transform(G)
        if n_pc2 is not None:
            cell_features = PCA(n_components=n_pc2, random_state=444).fit_transform(cell_features)
        k2 = get_best_k(cell_features, auto_k_candidates)
    
    groups = np.full(G.shape, np.nan) # unassigned entries stay nan
    next_label = 0 # running integer label for the (cell cluster, gene cluster) blocks
    
    for gene_cluster in np.unique(cluster_genes):
        gene_mask = cluster_genes==gene_cluster
        
        if k2 > 1: # cluster the cells on this gene cluster's columns only
            block_features = StandardScaler().fit_transform(G[:, gene_mask])
            if n_pc2 is not None:
                block_features = PCA(n_components=n_pc2, random_state=444).fit_transform(block_features)
            cluster_cells = KMeans(n_clusters=k2, random_state=444).fit(block_features).labels_
        else: # one shared label for all cells
            cluster_cells = np.zeros(G.shape[0])
        
        for cell_cluster in np.unique(cluster_cells):
            groups[np.ix_(cluster_cells==cell_cluster, gene_mask)] = next_label
            next_label += 1
    
    return(groups, k, k2)


def get_best_k (X, k_list):
    '''
    Greedy search for the number of KMeans clusters with the best silhouette score.
    
    The candidates are tried in the order given and the search stops at the first
    candidate that does not strictly improve on the best score so far, so k_list
    should be in ascending order. If k_list is empty, 1 is returned.
    
    Parameters
    ----------
        X - array to perform K-means clustering on
        k_list - list of positive integers for number of clusters to use
        
    Returns
    -------
        best_k [int] - last k value that improved the silhouette score before the search stopped
    '''
    from sklearn.metrics import silhouette_score
    
    best_k, best_score = 1, -np.inf
    
    for n_clusters in k_list:
        labels = KMeans(n_clusters=n_clusters, random_state=444).fit(X).labels_
        candidate_score = silhouette_score(X, labels)
        
        # anything that is not a strict improvement (including nan) ends the search
        if not (candidate_score > best_score):
            break
        best_k, best_score = n_clusters, candidate_score
            
    return(best_k)



def conformalize_prediction_interval (adata, predicted, calib_genes, alpha_level=0.33, symmetric=True, return_scores_dict=False, compute_wasserstein=False):
    '''
    Builds conformal prediction interval sets for the predicted gene expression
    
    For every group found in adata.obsm[predicted+"_groups"], a conformal quantile (qhat) of the
    group's calibration scores is computed and the interval is set to G -/+ G_stdev*qhat. When a
    group has fewer than 100 calibration scores, the pooled calibration set (all scores, stored
    under the key str(np.nan) by build_calibration_scores) is used instead. If the quantile
    cannot be computed, qhat is nan and so are the bounds for that group.
    
    Parameters
    ----------
        adata [AnnData] - contains adata.obsm[predicted] corresponding to the predicted gene expression
        predicted [str] - key in adata.obsm that corresponds to predicted gene expression 
        calib_genes [list or arr of str] - names of the genes in adata.var_names that are used in the calibration set
        alpha_level [float] - between 0 and 1; determines the alpha level; the CI will span the (1-alpha_level) interval
                              default value is alpha_level = 0.33 corresponding to 67% CI
        symmetric [bool] - whether to report symmetric prediction intervals or non-symmetric intervals; default is True (symmetric)
        return_scores_dict [bool] - whether to return the scores dictionary
        compute_wasserstein [bool] - whether to compute the Wasserstein distance between the CCV distribution of each
                                     subgroup and the CCV distribution of the calibration set that was used for it
                                   - added to adata.obsm["{predicted}_wasserstein"]
                                   
    Returns
    -------
        Modifies adata in-place:
            adata.uns['alpha'] - the alpha_level that was used
            adata.obsm[predicted+"_lo"], adata.obsm[predicted+"_hi"] - lower and upper interval bounds
            adata.obsm[predicted+"_wasserstein"] - only if compute_wasserstein is True
        Optionally returns the scores_flattened_dict (dictionary containing calibration scores and group assignments)
    '''
    # get uncertainties and groups from saved adata (scores/residuals are read by build_calibration_scores itself)
    _, _, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (adata, predicted)
    
    ### Building calibration sets for scores
    
    scores_flattened_dict = build_calibration_scores(adata, predicted, calib_genes, symmetric=symmetric)
    pooled_key = str(np.nan) # key of the pooled (all scores) calibration set
    
    ### Building prediction intervals

    prediction_sets = (np.zeros(G.shape), np.zeros(G.shape)) # init prediction sets
    
    if compute_wasserstein is True: # set up matrix to store Wasserstein distances
        from scipy.stats import wasserstein_distance
        score_dist_wasserstein = np.ones(G.shape).astype(G.dtype)*np.nan
        # the calibration-gene mask does not depend on the group, so it is built once
        calib_idxs = [np.where(adata.obsm[predicted].columns==gene)[0][0] for gene in calib_genes]
        calib_mask = np.full(G_stdev.shape, False)
        calib_mask[:,calib_idxs] = True

    # conformalize independently within groups of genes
    for group in np.unique(groups[~np.isnan(groups)]):
        
        in_group = (groups==group)
        
        # for symmetric intervals (default)
        if symmetric is True:
            scores_flattened = scores_flattened_dict[str(group)]
            n = len(scores_flattened)
            # remember the fallback decision: n is overwritten below and must not be re-tested later
            used_pooled_set = (n < 100)
            if used_pooled_set: # if less than 100 samples, then use the pooled calibration set
                scores_flattened = scores_flattened_dict[pooled_key]
                n = len(scores_flattened)-np.isnan(scores_flattened).sum()
            try:
                qhat = np.nanquantile(scores_flattened, np.ceil((n+1)*(1-alpha_level))/n)
            except Exception: # e.g. quantile level outside [0,1] for very small or empty sets
                qhat = np.nan
            qhat_lo = qhat
            qhat_hi = qhat
        
        # for asymmetric intervals
        else:
            scores_lo_flattened = scores_flattened_dict[str(group)][0]
            scores_hi_flattened = scores_flattened_dict[str(group)][1]
            n_lo = len(scores_lo_flattened)-np.isnan(scores_lo_flattened).sum()
            n_hi = len(scores_hi_flattened)-np.isnan(scores_hi_flattened).sum()
            # remember the fallback decision: n_lo/n_hi are overwritten below
            used_pooled_set = (n_lo < 100) or (n_hi < 100)
            if used_pooled_set: # if less than 100 samples in either set, then use the pooled calibration set
                scores_lo_flattened = scores_flattened_dict[pooled_key][0]
                scores_hi_flattened = scores_flattened_dict[pooled_key][1]
                n_lo = len(scores_lo_flattened)-np.isnan(scores_lo_flattened).sum()
                n_hi = len(scores_hi_flattened)-np.isnan(scores_hi_flattened).sum()
            # compute qhat for lower and upper bounds (if either fails, both are set to nan)
            try:
                qhat_lo = np.nanquantile(scores_lo_flattened, np.ceil((n_lo+1)*(1-alpha_level))/n_lo)
                qhat_hi = np.nanquantile(scores_hi_flattened, np.ceil((n_hi+1)*(1-alpha_level))/n_hi)
            except Exception:
                qhat_lo = np.nan
                qhat_hi = np.nan
        
        # compute bounds of prediction interval
        prediction_sets[0][in_group] = (G-G_stdev*qhat_lo)[in_group] # lower bound
        prediction_sets[1][in_group] = (G+G_stdev*qhat_hi)[in_group] # upper bound
            
        # Wasserstein distances
        if compute_wasserstein is True:
            v = G_stdev[in_group&~(calib_mask)].flatten() # group CCV (non-calibration entries)
            if len(v) > 0: # skip if no observations in group
                u = G_stdev[in_group&(calib_mask)].flatten() # group-specific calibration CCV
                # compare against the pooled calibration CCV whenever the pooled scores were the
                # ones actually used: either through the <100 fallback above, or because the group
                # has no calibration entries (build_calibration_scores then stores the full score
                # set for the group, and an empty u would make wasserstein_distance fail)
                if used_pooled_set or (len(u) == 0):
                    u = G_stdev[calib_mask].flatten()
                # calculate wasserstein distance for the CCV distributions
                score_dist_wasserstein[in_group] = wasserstein_distance(u, v).astype(G.dtype)
            
    # add prediction intervals to adata
    adata.uns['alpha'] = alpha_level
    adata.obsm[predicted+"_lo"] = pd.DataFrame(prediction_sets[0],
                                               columns=adata.obsm[predicted].columns,
                                               index=adata.obsm[predicted].index)
    adata.obsm[predicted+"_hi"] = pd.DataFrame(prediction_sets[1],
                                               columns=adata.obsm[predicted].columns,
                                               index=adata.obsm[predicted].index)
    # add wasserstein distances to adata        
    if compute_wasserstein is True:
        adata.obsm[predicted+"_wasserstein"] = pd.DataFrame(score_dist_wasserstein,
                                               columns=adata.obsm[predicted].columns,
                                               index=adata.obsm[predicted].index)
    
    
    if return_scores_dict is True:
    
        return(scores_flattened_dict)
        
        
        
def build_calibration_scores (adata, predicted, calib_genes, symmetric=False, include_zero_scores=False,
                              trim_quantiles=[None,None]):
    '''
    Builds calibration score sets, one per stratified group plus one pooled set
    
    For every (non-nan) group label, the finite calibration scores belonging to that group are
    collected; a group that has no entries among the calibration genes receives the full score
    set instead. A pooled set with all scores is always added under the key str(np.nan).
    
    Parameters
    ----------
        adata [AnnData] - contains adata.obsm[predicted] corresponding to the predicted gene expression
        predicted [str] - key in adata.obsm with predicted gene expression values
        calib_genes [list or arr of str] - names of the genes in adata.var_names that are used in the calibration set
        symmetric [bool] - True: one score array per group; otherwise a (lo, hi) pair of arrays per group,
                           split by the sign of the residual (lo: negative residuals, hi: positive residuals)
        include_zero_scores [bool] - whether to keep the scores whose residual is exactly zero; False drops them,
                                     otherwise they are kept (and counted in both lo and hi for non-symmetric sets)
        trim_quantiles [list of len 2; None or float between 0 and 1] - specifies what quantile range of scores to trim to; None implies no bounds
                                     - bounds are exclusive and computed on each set before any trimming
        
    Returns
    -------
        scores_flattened_dict - dictionary containing the calibration scores for each stratified group;
                                keys are str(group) and str(np.nan) for the pooled set
    '''
    
    # get uncertainties and scores from saved adata
    scores, residuals, G_stdev, G, groups = get_spatial_uncertainty_scores_from_metadata (adata, predicted)
    
    # column positions of the calibration genes
    calib_idxs = [np.where(adata.obsm[predicted].columns==gene)[0][0] for gene in calib_genes]
    
    def _finite_flat(arr):
        # 1d copy of arr with nan/inf entries removed
        flat = arr.flatten()
        return flat[np.isfinite(flat)]
    
    def _calibration_set(score_arr, resid_arr):
        # symmetric: a single array; otherwise a (lo, hi) tuple split by residual sign
        if symmetric is True:
            if include_zero_scores is False:
                return _finite_flat(score_arr[resid_arr != 0])
            return _finite_flat(score_arr)
        if include_zero_scores is False:
            lo_mask, hi_mask = (resid_arr < 0), (resid_arr > 0)
        else:
            lo_mask, hi_mask = (resid_arr <= 0), (resid_arr >= 0)
        return (_finite_flat(score_arr[lo_mask]), _finite_flat(score_arr[hi_mask]))
    
    scores_flattened_dict = {}
    
    # group-specific sets (boolean indexing and flatten() already return fresh arrays)
    for group in np.unique(groups[~np.isnan(groups)]):
        calib_groups = groups[:, calib_idxs]
        if (not np.isnan(group)) and (group in calib_groups):
            members = (calib_groups == group)
            scores_flattened_dict[str(group)] = _calibration_set(scores[members], residuals[members])
        else: # group absent from the calibration genes: defer to the full calibration set
            scores_flattened_dict[str(group)] = _calibration_set(scores, residuals)
    
    # pooled "nan" group consisting of all scores
    scores_flattened_dict[str(np.nan)] = _calibration_set(scores, residuals)
    
    # trim all scores if specified
    lower_q, upper_q = trim_quantiles[0], trim_quantiles[1]
    
    def _trimmed(arr):
        # both bounds come from the untrimmed array; they are applied afterwards, lower first
        lower_bound = np.nanquantile(arr, lower_q) if lower_q is not None else None
        upper_bound = np.nanquantile(arr, upper_q) if upper_q is not None else None
        if lower_q is not None:
            arr = arr[arr > lower_bound]
        if upper_q is not None:
            arr = arr[arr < upper_bound]
        return arr
    
    if (lower_q is not None) or (upper_q is not None):
        for key in list(scores_flattened_dict):
            if symmetric is True:
                scores_flattened_dict[key] = _trimmed(scores_flattened_dict[key])
            else:
                lo_scores, hi_scores = scores_flattened_dict[key]
                scores_flattened_dict[key] = (_trimmed(lo_scores), _trimmed(hi_scores))
       
    return (scores_flattened_dict)

# Contains utility functions for TISSUE

import numpy as np
import pandas as pd
import anndata as ad
import os


def large_save(adata, dirpath):
    '''
    Writes an AnnData object in the "large" folder format:
        - every obsm entry is written as {dirpath}/{key}.csv (without the row index)
        - every uns entry that is a pandas DataFrame is written as {dirpath}/uns/{key}.csv (without the row index)
        - a copy of the object with obsm emptied and those uns entries removed is written as {dirpath}/adata.h5ad
    The adata passed in is not modified.
    
    Parameters
    ----------
        adata [AnnData] - AnnData object to save
        
        dirpath [str] - path to directory for where to save the h5ad and csv files; will create if not existing
        
    Returns
    -------
        Nothing; the files are written to dirpath
    '''
    # make sure the output directory is there
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)
    
    # one CSV per obsm entry
    for key, value in adata.obsm.items():
        pd.DataFrame(value).to_csv(os.path.join(dirpath, f"{key}.csv"), index=False)

    # work on a copy so the caller's object keeps its obsm/uns
    adatac = adata.copy()
    adatac.obsm = {}
    
    # one CSV per DataFrame-valued uns entry, inside the "uns" subdirectory
    uns_dir = os.path.join(dirpath,"uns")
    frame_keys = [key for key, value in adatac.uns.items() if isinstance(value, pd.DataFrame)]
    if frame_keys and not os.path.exists(uns_dir):
        os.makedirs(uns_dir)
    for key in frame_keys:
        pd.DataFrame(adatac.uns[key]).to_csv(os.path.join(uns_dir, f"{key}.csv"), index=False)
    
    # the exported uns entries are dropped from the copy before it is written
    for key in frame_keys:
        del adatac.uns[key]

    adatac.write(os.path.join(dirpath, "adata.h5ad"))



def large_load(dirpath, skipfiles=[]):
    '''
    Loads in anndata and associated pandas dataframe csv files to be added to obsm metadata and uns metadata.
    Input is the directory path to the output directory of large_save()
    
    Every file in dirpath whose extension is ".csv" becomes an obsm entry and every such file in
    {dirpath}/uns (if that directory exists) becomes an uns entry. The entry key is the file name
    with only the trailing ".csv" removed, so keys that contain dots (e.g. "a.b" saved as "a.b.csv")
    are restored unchanged.
    
    Parameters
    ----------
        dirpath [str] - path to directory for where outputs of large_save() are located
        skipfiles [list] - list of filenames (e.g. "key.csv") to exclude from anndata object;
                           applies to both the obsm and the uns files; never modified here
    
    Returns
    -------
        adata - AnnData object loaded from dirpath along with all obsm and uns key values added to metadata
    '''
    def _csv_tables(folder):
        # yields (key, DataFrame) for every non-skipped *.csv file directly inside folder
        for fn in os.listdir(folder):
            key, extension = os.path.splitext(fn)
            if (extension == ".csv") and (fn not in skipfiles):
                yield key, pd.read_csv(os.path.join(folder, fn))
    
    # read h5ad anndata object
    adata = ad.read_h5ad(os.path.join(dirpath, "adata.h5ad"))
    
    # read and load in obsm from CSV files
    for key, df in _csv_tables(dirpath):
        df.index = adata.obs_names # the CSVs were written without a row index
        adata.obsm[key] = df
            
    # read and load any uns metadata from CSV files
    uns_dir = os.path.join(dirpath,"uns")
    if os.path.isdir(uns_dir):
        for key, df in _csv_tables(uns_dir):
            adata.uns[key] = df
            
    return(adata)


def convert_adata_to_dataupload (adata, savedir):
    '''
    Exports an AnnData object as the text files of a TISSUE input directory:
        Locations.txt - adata.obsm['spatial'] as tab-separated columns 'x' and 'y' (no row index)
        Spatial_count.txt - adata.X as a tab-separated table with adata.var_names as header (no row index)
        Metadata.txt - adata.obs as a comma-separated table including its row index
    
    Parameters
    ----------
        adata - AnnData object to be saved with all metadata in adata.obs and spatial coordinates in adata.obsm['spatial']
        savedir [str] - path to existing directory to save the files for TISSUE loading
        
    Returns
    -------
        Nothing; the three files are written into savedir
        
    NOTE: You will need to independently include scRNA_count.txt in savedir for TISSUE inputs to be complete
    '''
    def _target(filename):
        return os.path.join(savedir, filename)
    
    # spatial coordinates
    pd.DataFrame(adata.obsm['spatial'], columns=['x','y']).to_csv(_target("Locations.txt"), sep="\t", index=False)
    
    # expression matrix
    pd.DataFrame(adata.X, columns=adata.var_names).to_csv(_target("Spatial_count.txt"), sep="\t", index=False)
    
    # cell metadata (written with pandas defaults, i.e. comma-separated with the index)
    pd.DataFrame(adata.obs).to_csv(_target("Metadata.txt"))

'''TISSUE (Transcript Imputation with Spatial Single-cell Uncertainty Estimation) provides tools for estimating well-calibrated uncertainty measures for gene expression predictions in single-cell spatial transcriptomics datasets and utilizing them in downstream analyses'''

__version__ = "1.0.1"

""" Dimensionality Reduction
@author: Soufiane Mourragui
This module extracts the domain-specific factors from the high-dimensional omics
dataset. Several methods are here implemented and they can be directly
called from string name in main method method. All the methods
use scikit-learn implementation.
Notes
-------
	-
	
References
-------
	[1] Pedregosa, Fabian, et al. (2011) Scikit-learn: Machine learning in Python.
	Journal of Machine Learning Research
"""

import numpy as np
from sklearn.decomposition import PCA, FastICA, FactorAnalysis, NMF, SparsePCA
from sklearn.cross_decomposition import PLSRegression


def process_dim_reduction(method='pca', n_dim=10):
    """
    Build the (unfitted) dimensionality reduction estimator named by `method`.

    Attributes
    -------
    method: str, default to 'pca'
        Name of the method, matched case-insensitively.
        Implemented: 'pca', 'ica', 'fa' (Factor Analysis),
        'nmf' (Non-negative matrix factorisation), 'sparsepca' (Sparse PCA),
        'pls' (the PLS wrapper class of this module).

    n_dim: int, default to 10
        Number of components (domain-specific factors) passed to the estimator.

    Return values
    -------
    The estimator instance for the requested method.

    Raises
    -------
    NameError if `method` is not one of the implemented names.
    """
    name = method.lower()

    if name == 'pca':
        return PCA(n_components=n_dim)

    if name == 'ica':
        print('ICA')
        return FastICA(n_components=n_dim)

    if name == 'fa':
        return FactorAnalysis(n_components=n_dim)

    if name == 'nmf':
        return NMF(n_components=n_dim)

    if name == 'sparsepca':
        return SparsePCA(n_components=n_dim, alpha=10., tol=1e-4, verbose=10, n_jobs=1)

    if name == 'pls':
        return PLS(n_components=n_dim)

    # the message reports the name exactly as the caller spelled it
    raise NameError('%s is not an implemented method'%(method))


class PLS():
    """
    Thin wrapper around PLSRegression that exposes a `components_` attribute,
    so that it can be used like the other dimensionality reduction estimators
    returned by process_dim_reduction.
    """
    def __init__(self, n_components=10):
        # wrapped estimator; every other method delegates to it
        self.clf = PLSRegression(n_components=n_components)

    def get_components_(self):
        """Return the transpose of the wrapped estimator's x_weights_."""
        return self.clf.x_weights_.T

    def set_components_(self, x):
        """Accept and discard the value: components_ is effectively read-only."""
        return None

    # read access goes to the fitted weights, write access is silently ignored
    components_ = property(fget=get_components_, fset=set_components_)

    def fit(self, X, y):
        """Fit the wrapped estimator on (X, y) and return this wrapper."""
        self.clf.fit(X, y)
        return self

    def transform(self, X):
        """Delegate to the wrapped estimator's transform."""
        return self.clf.transform(X)

    def predict(self, X):
        """Delegate to the wrapped estimator's predict."""
        return self.clf.predict(X)

""" SpaGE [1]
@author: Tamim Abdelaal
This function integrates two single-cell datasets, spatial and scRNA-seq, and 
enhance the spatial data by predicting the expression of the spatially 
unmeasured genes from the scRNA-seq data.
The integration is performed using the domain adaption method PRECISE [2]
	
References
-------
    [1] Abdelaal T., Mourragui S., Mahfouz A., Reinders M.J.T. (2020)
    SpaGE: Spatial Gene Enhancement using scRNA-seq
    [2] Mourragui S., Loog M., Reinders M.J.T., Wessels L.F.A. (2019)
    PRECISE: A domain adaptation approach to transfer predictors of drug response
    from pre-clinical models to tumors
"""

import numpy as np
import pandas as pd
import scipy.stats as st
from sklearn.neighbors import NearestNeighbors
#from tissue.SpaGE.principal_vectors import PVComputation
from .principal_vectors import PVComputation

def SpaGE(Spatial_data,RNA_data,n_pv,genes_to_predict=None):
    """
        @author: Tamim Abdelaal
        Predicts the expression of spatially unmeasured genes for the spatial
        cells from a scRNA-seq reference.
        
        Both datasets are z-scored per gene, principal vectors are computed on
        the shared genes, and only the principal vectors whose source/target
        similarity (diagonal of the cosine similarity matrix) exceeds 0.3 are
        kept. Each spatial cell is then matched to its 50 nearest scRNA-seq
        cells (cosine distance in the projected space) and the prediction is a
        weighted sum of the expression of those neighbors whose distance is < 1.
        
        Parameters
        -------
        Spatial_data : Dataframe
            Normalized Spatial data matrix (cells X genes).
        RNA_data : Dataframe
            Normalized scRNA-seq data matrix (cells X genes).
        n_pv : int
            Number of principal vectors to find from the independently computed
            principal components, and used to align both datasets. This should
            be <= number of shared genes between the two datasets.
        genes_to_predict : str array 
            list of gene names missing from the spatial data, to be predicted 
            from the scRNA-seq data. Default (None) is the set of genes
            (columns) found in the scRNA-seq data but not in the spatial data.
            
        Return
        -------
        Imp_Genes: Dataframe
            Matrix containing the predicted gene expressions for the spatial 
            cells. One row per spatial data row (cell), in the same order and
            with a default integer index; columns are genes_to_predict.
    """
    
    if genes_to_predict is None:
        genes_to_predict = np.setdiff1d(RNA_data.columns,Spatial_data.columns)
    
    # per-gene standardization of both datasets
    RNA_data_scaled = pd.DataFrame(data=st.zscore(RNA_data,axis=0),
                                   index=RNA_data.index,columns=RNA_data.columns)
    Spatial_data_scaled = pd.DataFrame(data=st.zscore(Spatial_data,axis=0),
                                       index=Spatial_data.index,columns=Spatial_data.columns)
    
    # scRNA-seq data restricted to the genes shared with the spatial data
    shared_genes = np.intersect1d(Spatial_data_scaled.columns,RNA_data_scaled.columns)
    Common_data = RNA_data_scaled[shared_genes]
    
    n_spatial_cells = Spatial_data.shape[0]
    Imp_Genes = pd.DataFrame(np.zeros((n_spatial_cells,len(genes_to_predict))),
                             columns=genes_to_predict)
    
    # principal vectors between the scRNA-seq (source) and spatial (target) data
    pv_Spatial_RNA = PVComputation(
            n_factors = n_pv,
            n_pv = n_pv,
            dim_reduction = 'pca',
            dim_reduction_target = 'pca'
    )
    pv_Spatial_RNA.fit(Common_data,Spatial_data_scaled[Common_data.columns])
    
    # keep only the leading source principal vectors that are similar enough to the target
    similarity = np.diag(pv_Spatial_RNA.cosine_similarity_matrix_)
    Effective_n_pv = np.count_nonzero(similarity > 0.3)
    S = pv_Spatial_RNA.source_components_.T[:,0:Effective_n_pv]
    
    Common_data_projected = Common_data.dot(S)
    Spatial_data_projected = Spatial_data_scaled[Common_data.columns].dot(S)
    
    # 50 nearest scRNA-seq cells of every spatial cell
    nbrs = NearestNeighbors(n_neighbors=50, algorithm='auto',
                            metric = 'cosine').fit(Common_data_projected)
    distances, indices = nbrs.kneighbors(Spatial_data_projected)
    
    # reference expression of the genes to predict (column selection done once)
    reference_expression = RNA_data[genes_to_predict]
    
    for cell in range(0,n_spatial_cells):
        
        # only neighbors with cosine distance < 1 contribute
        usable = distances[cell,:] < 1
        usable_distances = distances[cell,:][usable]
        
        weights = 1-(usable_distances)/(np.sum(usable_distances))
        weights = weights/(len(weights)-1)
        
        neighbor_rows = indices[cell,:][usable]
        Imp_Genes.iloc[cell,:] = np.dot(weights,reference_expression.iloc[neighbor_rows])
        
    return Imp_Genes


""" Principal Vectors
@author: Soufiane Mourragui
This module computes the principal vectors from two datasets, i.e.:
- perform linear dimensionality reduction independently for both dataset, resulting
in set of domain-specific factors.
- find the common factors using principal vectors [1]
This result in set of pairs of vectors. Each pair has one vector from the source and one
from the target. For each pair, a similarity score (cosine similarity) can be computed
between the principal vectors and the pairs are naturally ordered by decreasing order
of this similarity measure.
Example
-------
    Examples are given in the vignettes.
Notes
-------
	Examples are given in the vignette
	
References
-------
	[1] Golub, G.H. and Van Loan, C.F., 2012. "Matrix computations" (Vol. 3). JHU Press.
	[2] Mourragui, S., Loog, M., Reinders, M.J.T., Wessels, L.F.A. (2019)
    PRECISE: A domain adaptation approach to transfer predictors of drug response
    from pre-clinical models to tumors
"""

import numpy as np
import pandas as pd
import scipy
from pathlib import Path
from sklearn.preprocessing import normalize

#from tissue.SpaGE.dimensionality_reduction import process_dim_reduction
from .dimensionality_reduction import process_dim_reduction

class PVComputation:
    """
    Compute principal vectors (PVs) between a source and a target dataset.

    Domain-specific factors are first extracted independently from each dataset,
    then aligned pairwise through an SVD of their cross-product.

    Attributes
    -------
    n_factors: int
        Number of domain-specific factors to compute.
    n_pv: int
        Number of principal vectors.
    project_on: int or bool
        Default projection target used by ``transform``: 0 means source PVs,
        -1 means target PVs and 1 means both.
    dim_reduction_method_source: str
        Dimensionality reduction method used for source data
    dim_reduction_method_target: str
        Dimensionality reduction method used for target data
    dim_reduction_source: object
        Dimensionality reduction instance fitted on the source data.
    dim_reduction_target: object
        Dimensionality reduction instance fitted on the target data.
    source_components_ : numpy.ndarray, shape (n_pv, n_features)
        Loadings of the source principal vectors ranked by similarity to the
        target. Components are in the row.
    source_explained_variance_ratio_: numpy.ndarray, shape (n_pv)
        Explained variance of the source on each source principal vector.
    target_components_ : numpy.ndarray, shape (n_pv, n_features)
        Loadings of the target principal vectors ranked by similarity to the
        source. Components are in the row.
    target_explained_variance_ratio_: numpy.ndarray, shape (n_pv)
        Explained variance of the target on each target principal vector.
    cosine_similarity_matrix_: numpy.ndarray, shape (n_pv, n_pv)
        Scalar product between the source and the target principal vectors. Source
        principal vectors are in the rows while target's are in the columns. If
        the domain adaptation is sensible, a diagonal matrix should be obtained.
    initial_cosine_similarity_matrix_: numpy.ndarray
        Scalar product between the source and target domain-specific factors,
        before alignment. Set by ``compute_principal_vectors``.
    angles_: numpy.ndarray, shape (n_pv)
        Principal angles (in radians) between paired principal vectors. Set by
        ``compute_principal_vectors``.
    """

    def __init__(self, n_factors,n_pv,
                dim_reduction='pca',
                dim_reduction_target=None,
                project_on=0):
        """
        Parameters
        -------
        n_factors : int
            Number of domain-specific factors to extract from the data (e.g. using PCA, ICA).
        n_pv : int
            Number of principal vectors to find from the independently computed factors.
        dim_reduction : str, default to 'pca'
            Dimensionality reduction method for the source data,
            i.e. 'pca', 'ica', 'nmf', 'fa', 'sparsepca', pls'. A non-string value
            is used as-is as the reduction instance.
        dim_reduction_target : str, default to None
            Dimensionality reduction method for the target data,
            i.e. 'pca', 'ica', 'nmf', 'fa', 'sparsepca', pls'. If None, set to dim_reduction.
        project_on: int or bool, default to 0
            Where data should be projected on. 0 means source PVs, -1 means target PVs and 1 means
            both PVs.
        """
        self.n_factors = n_factors
        self.n_pv = n_pv
        # Stored so that transform() can fall back on it when called without an
        # explicit project_on argument.
        self.project_on = project_on
        self.dim_reduction_method_source = dim_reduction
        self.dim_reduction_method_target = dim_reduction_target or dim_reduction
        self.dim_reduction_source = self._process_dim_reduction(self.dim_reduction_method_source)
        self.dim_reduction_target = self._process_dim_reduction(self.dim_reduction_method_target)

        self.source_components_ = None
        self.source_explained_variance_ratio_ = None
        self.target_components_ = None
        self.target_explained_variance_ratio_ = None
        self.cosine_similarity_matrix_ = None

    def _process_dim_reduction(self, dim_reduction):
        """
        Turn a method name into a reduction instance with ``n_factors`` dimensions.
        Anything that is not a string is returned unchanged.
        """
        if isinstance(dim_reduction, str):
            return process_dim_reduction(method=dim_reduction, n_dim=self.n_factors)
        return dim_reduction

    def fit(self, X_source, X_target, y_source=None):
        """
        Compute the common factors between two set of data.
        IMPORTANT: Same genes have to be given for source and target, and in same order
        Parameters
        -------
        X_source : np.ndarray, shape (n_samples, n_genes)
            Source dataset
        X_target : np.ndarray, shape (n_samples, n_genes)
            Target dataset
        y_source : np.ndarray (optional, default to None)
            Eventual output, in case one wants to give output (for instance PLS)
        Return values
        -------
        self: returns an instance of self.
        """
        # Compute factors independently for source and target. Orthogonalize the basis
        Ps = self.dim_reduction_source.fit(X_source, y_source).components_
        Ps = scipy.linalg.orth(Ps.transpose()).transpose()

        # NOTE(review): y_source is also forwarded to the target reducer; for
        # supervised reducers this presumably requires matching sample counts -- confirm.
        Pt = self.dim_reduction_target.fit(X_target, y_source).components_
        Pt = scipy.linalg.orth(Pt.transpose()).transpose()

        # Compute the principal factors
        self.compute_principal_vectors(Ps, Pt)

        # Compute variance explained
        # NOTE(review): for plain ndarray input np.var without an axis reduces over
        # all elements, so the denominator is the global variance rather than the
        # summed per-gene variance -- confirm this is the intended normalisation.
        self.source_explained_variance_ratio_ = np.var(self.source_components_.dot(X_source.transpose()), axis=1)/\
                                                np.sum(np.var(X_source), axis=0)
        self.target_explained_variance_ratio_ = np.var(self.target_components_.dot(X_target.transpose()), axis=1)/\
                                                np.sum(np.var(X_target), axis=0)

        return self

    def compute_principal_vectors(self, source_factors, target_factors):
        """
        Compute the principal vectors between the already computed set of domain-specific
        factors, using approach presented in [1,2].
        IMPORTANT: Same genes have to be given for source and target, and in same order
        Parameters
        -------
        source_factors: np.ndarray, shape (n_components, n_genes)
            Source domain-specific factors.
        target_factors: np.ndarray, shape (n_components, n_genes)
            Target domain-specific factors.
        Return values
        -------
        self: returns an instance of self.
        """

        # Find principal vectors using SVD. np.linalg.svd returns the right
        # singular vectors already transposed, hence v is applied directly.
        u, _, v = np.linalg.svd(source_factors.dot(target_factors.transpose()))
        self.source_components_ = u.transpose().dot(source_factors)[:self.n_pv]
        self.target_components_ = v.dot(target_factors)[:self.n_pv]
        # Normalize to make sure that vectors are unitary
        self.source_components_ = normalize(self.source_components_, axis = 1)
        self.target_components_ = normalize(self.target_components_, axis = 1)

        # Compute cosine similarity matrix
        self.initial_cosine_similarity_matrix_ = source_factors.dot(target_factors.transpose())
        self.cosine_similarity_matrix_ = self.source_components_.dot(self.target_components_.transpose())

        # Compute angles. Clip first: rounding can push a cosine marginally
        # outside [-1, 1], which would make arccos return NaN.
        self.angles_ = np.arccos(np.clip(np.diag(self.cosine_similarity_matrix_), -1.0, 1.0))

        return self


    def transform(self, X, project_on=None):
        """
        Projects data onto principal vectors.
        Parameters
        -------
        X : numpy.ndarray, shape (n_samples, n_genes)
            Data to project.
        project_on: int or bool, default to None
            Where data should be projected on. 0 means source PVs, -1 means target PVs and 1 means
            both PVs. If None, set to class instance value.
        Return values
        -------
        Projected data of shape (n_samples, n_pv), or (n_samples, 2 * n_pv) when
        projecting on both (source PVs first, then target PVs).
        Raises
        -------
        ValueError
            If project_on is not one of 0, -1 or 1.
        """

        # Only fall back on the instance value when nothing was passed: 0 is a
        # legitimate choice (source PVs) and must not be treated as "missing".
        if project_on is None:
            project_on = self.project_on

        # Project on source
        if project_on == 0:
            return X.dot(self.source_components_.transpose())

        # Project on target
        elif project_on == -1:
            return X.dot(self.target_components_.transpose())

        # Project on both: stack the two (n_genes, n_pv) loading matrices side by
        # side so the product keeps n_genes as the contracted dimension.
        elif project_on == 1:
            both_components = np.concatenate(
                [self.source_components_.transpose(), self.target_components_.transpose()],
                axis=1
            )
            return X.dot(both_components)

        else:
            raise ValueError('project_on should be 0 (source), -1 (target) or 1 (both). %s not correct value'%(project_on))

