import os
import scipy
import scprep
import inspect
import numpy as np
import scanpy as sc
from scipy.sparse import csr_matrix
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F


def r_function(filename, args="sce"):
    """Wrap the R script ``filename`` as a callable ``scprep.run.RFunction``.

    ``filename`` is resolved relative to the directory of the module that
    called this function, not the current working directory.

    Args:
        filename: Path to an R script; must end with ``.R``.
        args: Argument specification forwarded to ``RFunction`` as ``args``.

    Returns:
        The ``RFunction`` object, with the resolved script path stored on
        its ``__r_file__`` attribute.
    """
    assert filename.endswith(".R")

    # Look one frame up the stack to find out which file the caller lives in.
    caller_info = inspect.getframeinfo(inspect.currentframe().f_back)
    caller_dir = os.path.dirname(caller_info.filename)
    script_path = os.path.join(caller_dir, filename)

    with open(script_path, "r") as script:
        script_source = script.read()

    wrapped = scprep.run.RFunction(setup="", args=args, body=script_source)
    wrapped.__r_file__ = script_path
    return wrapped

def _run_r_with_retries(r_fun, adata, script_name, max_retries=10):
    """Call an R-backed function, retrying when the embedded R runtime errors."""
    import rpy2.rinterface_lib.embedded

    result = None
    attempts = 0
    # Like the loops this replaces, keep calling until a non-None result
    # comes back; only R runtime errors count towards the retry budget.
    while result is None:
        try:
            result = r_fun(adata)
        except rpy2.rinterface_lib.embedded.RRuntimeError:  # pragma: no cover
            if attempts >= max_retries:
                raise
            attempts += 1
            print(f"{script_name} failed (attempt {attempts})")
    return result


def _normalize_for_smoothing(X, normtype, reverse_norm_order):
    """Library-size normalise and transform X; return (X, libsize, denorm_fn)."""
    if normtype == "sqrt":
        norm_fn, denorm_fn = np.sqrt, np.square
    elif normtype == "log":
        norm_fn, denorm_fn = np.log1p, np.expm1
    else:
        raise NotImplementedError(f"Unsupported normtype {normtype!r}")

    if reverse_norm_order:
        # inexplicably, this sometimes performs better
        X = scprep.utils.matrix_transform(X, norm_fn)
        X, libsize = scprep.normalize.library_size_normalize(
            X, rescale=1, return_library_size=True
        )
    else:
        X, libsize = scprep.normalize.library_size_normalize(
            X, rescale=1, return_library_size=True
        )
        X = scprep.utils.matrix_transform(X, norm_fn)
    return X, libsize, denorm_fn


def _dca_denoise(adata,
                 adata_test,
                 mode='denoise',
                 ae_type='nb-conddisp',
                 normalize_per_cell=True,
                 scale=True,
                 log1p=True,
                 hidden_size=(64, 32, 64),  # network args
                 hidden_dropout=0.,
                 batchnorm=True,
                 activation='relu',
                 init='glorot_uniform',
                 network_kwds=None,
                 epochs=300,                # training args
                 reduce_lr=10,
                 early_stop=15,
                 batch_size=32,
                 optimizer='RMSprop',
                 learning_rate=None,
                 random_state=0,
                 threads=None,
                 verbose=False,
                 training_kwds=None,
                 return_model=False,
                 return_info=False,
                 copy=True,
                 check_counts=True):
    """Train a DCA autoencoder on ``adata`` and predict on ``adata_test``."""
    from dca.io import read_dataset, normalize
    from dca.train import train
    from dca.network import AE_types

    # NOTE(review): `random_state` is accepted but never used to seed
    # anything here, so the caller's seed does not reach DCA -- confirm
    # whether seeding should be added.

    # this creates adata.raw with raw counts and copies adata if copy==True
    adata = read_dataset(adata,
                         transpose=False,
                         test_split=False,
                         copy=copy,
                         check_counts=check_counts)
    adata_test = read_dataset(adata_test,
                              transpose=False,
                              test_split=False,
                              copy=copy,
                              check_counts=check_counts)

    adata = normalize(adata,
                      filter_min_counts=False,  # no filtering, keep cell and gene idxs same
                      size_factors=normalize_per_cell,
                      normalize_input=scale,
                      logtrans_input=log1p)
    adata_test = normalize(adata_test,
                           filter_min_counts=False,  # no filtering, keep cell and gene idxs same
                           size_factors=normalize_per_cell,
                           normalize_input=scale,
                           logtrans_input=log1p)

    network_kwds = {**(network_kwds or {}),
                    'hidden_size': hidden_size,
                    'hidden_dropout': hidden_dropout,
                    'batchnorm': batchnorm,
                    'activation': activation,
                    'init': init}

    from tensorflow.python.framework.ops import disable_eager_execution
    disable_eager_execution()

    input_size = output_size = adata.n_vars
    net = AE_types[ae_type](input_size=input_size,
                            output_size=output_size,
                            **network_kwds)
    net.save()
    net.build()

    training_kwds = {**(training_kwds or {}),
                     'epochs': epochs,
                     'reduce_lr': reduce_lr,
                     'early_stop': early_stop,
                     'batch_size': batch_size,
                     'optimizer': optimizer,
                     'verbose': verbose,
                     'threads': threads,
                     'learning_rate': learning_rate}

    train(adata[adata.obs.dca_split == 'train'], net, **training_kwds)
    res = net.predict(adata_test, mode, return_info, copy)
    adata_test = res if copy else adata_test

    if return_model:
        return (adata_test, net) if copy else net
    return adata_test if copy else None


def _magic_denoise(adata, solver, normtype="sqrt", reverse_norm_order=False, **kwargs):
    """Denoise ``adata.X`` with MAGIC and map the result back to count scale."""
    from magic import MAGIC

    X, libsize, denorm_fn = _normalize_for_smoothing(
        adata.X, normtype, reverse_norm_order
    )
    Y = MAGIC(solver=solver, **kwargs, verbose=False).fit_transform(
        X, genes="all_genes"
    )
    # transform back into original space
    Y = scprep.utils.matrix_transform(Y, denorm_fn)
    return scprep.utils.matrix_vector_elementwise_multiply(Y, libsize, axis=0)


def _alra_denoise(adata, r_alra, normtype="log", reverse_norm_order=False):
    """Denoise ``adata.X`` with the ALRA R script and map back to count scale."""
    X, libsize, denorm_fn = _normalize_for_smoothing(
        adata.X, normtype, reverse_norm_order
    )
    # Store the normalised matrix in obsm before handing adata to the R script.
    adata.obsm["train_norm"] = X.tocsr()
    # r_alra takes sparse array, returns dense array
    Y = _run_r_with_retries(r_alra, adata, "alra.R")
    # transform back into original space
    Y = scprep.utils.matrix_transform(Y, denorm_fn)
    return scprep.utils.matrix_vector_elementwise_multiply(Y, libsize, axis=0)


def denoising_baseline(adata, dataset='', model='dca', hvg=None, seed=10, *kwargs):
    """Run one baseline denoising model and score it with ``denoising_eval``.

    Args:
        adata: AnnData object with raw counts in ``layers['counts']``. For
            every dataset except ``'HLCAsubset'`` it must also provide
            ``layers['train_mask']`` and ``layers['test_mask']``. It is
            modified in place: ``adata.X`` is overwritten with the counts
            and, when ``hvg`` is given, genes are subset to the selection.
        dataset: Dataset name; only ``'HLCAsubset'`` is treated specially.
        model: One of ``'dca'``, ``'magic'``, ``'alra'``, ``'saver'``,
            ``'saver-x'``, ``'scvi'``. ``'scgnn2'`` and ``'deepimpute'`` are
            recognised but not run here.
        hvg: Number of highly variable genes to keep (passed as
            ``n_top_genes``); ``None`` skips gene selection.
        seed: Seed applied to scVI settings and forwarded to the DCA helper.
        *kwargs: Extra positional arguments; accepted and ignored.
            NOTE(review): probably meant to be ``**kwargs`` -- kept as-is so
            the signature stays unchanged.

    Returns:
        The metrics dict from ``denoising_eval``, or ``None`` for
        ``'scgnn2'`` / ``'deepimpute'``.

    Raises:
        NotImplementedError: For an unsupported ``model``, or for
            ``dataset='HLCAsubset'`` with any model that needs evaluation
            (no evaluation target or test mask is defined for that path).
    """
    if hvg is not None:
        # seurat_v3 needs an explicit gene count; previously `hvg` only acted
        # as an on/off switch and its value was ignored.
        sc.pp.highly_variable_genes(
            adata, layer='counts', flavor='seurat_v3', n_top_genes=hvg, subset=True
        )

    adata.X = adata.layers['counts'].copy()
    if dataset == 'HLCAsubset':
        input_adata = adata[adata.obs.split != 'test']
        # No target matrix or test mask is defined for this split-based path.
        target_mat = test_mask = None
    else:
        input_adata = adata.copy()
        train_mask = adata.layers['train_mask']
        test_mask = adata.layers['test_mask']

        counts = adata.layers['counts']
        # `.A` was removed from SciPy sparse matrices; `.toarray()` works on
        # all versions. Dense layers are passed through as arrays.
        data_mat = counts.toarray() if scipy.sparse.issparse(counts) else np.asarray(counts)
        # Models only see the entries selected by the training mask.
        input_adata.X = csr_matrix(data_mat * train_mask.astype(float))
        target_mat = data_mat.copy()

    if model in ('scgnn2', 'deepimpute'):  # too long, keep it in dance
        return None

    if target_mat is None:
        # Fail before training rather than with an UnboundLocalError after it.
        raise NotImplementedError(
            f"Evaluation is not implemented for dataset {dataset!r}: "
            "no target matrix or test mask is defined"
        )

    if model == 'dca':
        pred_adata = _dca_denoise(
            input_adata, input_adata, mode='denoise', ae_type='nb-conddisp',
            normalize_per_cell=True, scale=True, log1p=True,
            hidden_size=(64, 32, 64), hidden_dropout=0., batchnorm=True,
            activation='relu', init='glorot_uniform', epochs=300, reduce_lr=10,
            early_stop=15, batch_size=32, random_state=seed
        )
        pred_mat = pred_adata.X
        if pred_mat.shape[0] != target_mat.shape[0]:
            import pandas as pd

            # Re-align predictions to the input cells by index; cells missing
            # from the prediction become NaN rows (NaN + 0 stays NaN).
            zeros_df = pd.DataFrame(0, index=input_adata.obs.index, columns=input_adata.var.index)
            pred_df = pred_adata.to_df()
            pred_mat = (pred_df + zeros_df).to_numpy()

            # Persist the surviving cell index; the 'saver' branch reads it.
            pred_idx_file = './data/PBMC1K_999_pred_idx.npy'
            if not os.path.exists(pred_idx_file):
                np.save(pred_idx_file, pred_df.index.to_numpy())

    elif model == 'magic':
        pred_mat = _magic_denoise(input_adata, solver="exact", normtype="sqrt")

    elif model == 'alra':
        _r_alra = r_function("./scripts/alra.R")
        pred_mat = _alra_denoise(input_adata, _r_alra, normtype="log")

    elif model == 'saver':
        _r_saver = r_function("./scripts/saver.R")
        pred_mat = _run_r_with_retries(_r_saver, input_adata, "saver.R").T

        if pred_mat.shape[0] != target_mat.shape[0]:
            import pandas as pd

            # Uses the cell index written by the 'dca' branch to place the
            # predicted rows; cells without a prediction become NaN rows.
            pred_idx_file = './data/PBMC1K_999_pred_idx.npy'
            pred_idx = np.load(pred_idx_file, allow_pickle=True)
            zeros_df = pd.DataFrame(0, index=input_adata.obs.index, columns=input_adata.var.index)
            pred_df = pd.DataFrame(pred_mat, index=pred_idx, columns=input_adata.var.index)
            pred_mat = (pred_df + zeros_df).to_numpy()

    elif model == 'saver-x':
        _r_saverx = r_function("./scripts/saver-x.R")
        pred_mat = _run_r_with_retries(_r_saverx, input_adata, "saver-x.R")

    elif model == 'scvi':
        import scvi as scvi_
        from scvi.model import SCVI

        scvi_.settings.seed = seed

        # Defaults from SCVI github tutorials scanpy_pbmc3k and harmonization
        n_latent = 30
        n_hidden = 128
        n_layers = 2

        # copying to not return values added to adata during setup_anndata
        net_adata = input_adata.copy()
        SCVI.setup_anndata(net_adata, batch_key='batch')

        vae = SCVI(
            net_adata,
            gene_likelihood="nb",
            n_layers=n_layers,
            n_latent=n_latent,
            n_hidden=n_hidden,
        )
        vae.train(train_size=1.0)
        pred_mat = vae.get_normalized_expression(library_size=1e4).values
        adata_pred = sc.AnnData(pred_mat)
        sc.pp.log1p(adata_pred)
        pred_mat = adata_pred.X

    else:
        raise NotImplementedError(f"Unsupported model {model}")

    # Every model's output is re-normalised inside the evaluation.
    norm_flag = True
    return denoising_eval(target_mat, pred_mat, test_mask, norm_flag)

def scnormalize(x, target_sum=1e4, eps=1e-8):
    """Scale each row of ``x`` to sum to ``target_sum``, then apply log1p.

    Args:
        x: Tensor whose dimension 1 is summed to obtain the per-row total.
        target_sum: Total each row is rescaled to before the log transform.
        eps: Added to the row totals to avoid division by zero.

    Returns:
        The log1p-transformed, rescaled tensor.
    """
    row_totals = x.sum(1, keepdim=True) + eps
    rescaled = x * target_sum / row_totals
    return torch.log1p(rescaled)

def invervse_scnormalize(x, library_size=1e4, eps=1e-8):
    """Undo a log1p transform and rescale each row to sum to ``library_size``.

    Uses ``torch.expm1`` rather than ``torch.exp(x) - 1``: for values close
    to zero the subtraction cancels to exactly 0 in float32, whereas
    ``expm1`` keeps full precision.

    Args:
        x: Tensor of log1p-transformed values; dimension 1 is summed to
            obtain the per-row total.
        library_size: Total each row is rescaled to.
        eps: Added to the row totals to avoid division by zero.

    Returns:
        The back-transformed, rescaled tensor.
    """
    x = torch.expm1(x)
    x = x * library_size / (x.sum(1, keepdim=True) + eps)
    return x

def invervse_lib_norm(x, library_size=1e4, eps=1e-8):
    """Rescale each row of ``x`` so that it sums to ``library_size``.

    Args:
        x: Tensor whose dimension 1 is summed to obtain the per-row total.
        library_size: Total each row is rescaled to.
        eps: Added to the row totals to avoid division by zero.

    Returns:
        The rescaled tensor.
    """
    row_totals = x.sum(1, keepdim=True) + eps
    return x * library_size / row_totals

def as_tensor(x, assert_type: bool = False):
    """Convert a numpy array to a torch tensor; pass other inputs through.

    Args:
        x: Value to convert.
        assert_type: When true, reject anything that is neither a numpy
            array nor a torch tensor.

    Returns:
        A tensor created with ``torch.from_numpy`` for numpy input,
        otherwise ``x`` unchanged.

    Raises:
        TypeError: If ``assert_type`` is true and ``x`` is neither a numpy
            array nor a torch tensor.
    """
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x)
    if assert_type and not isinstance(x, torch.Tensor):
        raise TypeError(f"Expecting tensor or numpy array, got, {type(x)}")
    return x

def masked_rmse(pred, true, mask):
    """Root-mean-squared error over the entries selected by ``mask``.

    Both inputs are multiplied by ``mask`` so unselected entries contribute
    zero error; the summed squared error is divided by ``mask.sum()``.
    """
    n_selected = mask.sum()
    squared_error_total = F.mse_loss(pred * mask, true * mask, reduction='sum')
    return torch.sqrt(squared_error_total / n_selected)


def masked_stdz(x, mask):
    """Centre and scale each row of ``x`` using only its masked entries.

    Unmasked entries are zeroed. Each row is centred on the mean of its
    masked entries and divided by the root of its summed squared
    deviations, so a row with no variation yields NaN (division by zero).
    """
    # Clamp to 1 so rows with no masked entries do not divide by zero here.
    n_masked = mask.sum(1, keepdim=True).clamp(1)
    kept = x * mask
    row_mean = kept.sum(1, keepdim=True) / n_masked
    centered = kept - row_mean * mask
    # NOTE: multiplied by the factor of sqrt of N
    scale = centered.pow(2).sum(1, keepdim=True).sqrt()
    return centered / scale


def masked_corr(pred, true, mask):
    """Mean per-row correlation between ``pred`` and ``true`` on masked entries.

    Each input is standardised row-wise with ``masked_stdz``; the row-wise
    sum of products is then averaged with ``torch.nanmean``, which skips
    rows whose value is NaN.
    """
    pred_z = masked_stdz(pred, mask)
    true_z = masked_stdz(true, mask)
    per_row = (pred_z * true_z).sum(1)
    return torch.nanmean(per_row)

def PearsonCorr(y_pred, y_true):
    """Mean of the row-wise Pearson correlations between two tensors.

    Rows are centred on their own mean along dimension 1. Rows whose
    correlation is NaN are skipped by ``torch.nanmean``.
    """
    true_dev = y_true - torch.mean(y_true, 1, keepdim=True)
    pred_dev = y_pred - torch.mean(y_pred, 1, keepdim=True)

    cross = torch.sum(true_dev * pred_dev, 1)
    true_norm = torch.sqrt(torch.sum(true_dev * true_dev, 1))
    pred_norm = torch.sqrt(torch.sum(pred_dev * pred_dev, 1))

    return torch.nanmean(cross / true_norm / pred_norm)

def PearsonCorr1d(y_true, y_pred):
    """Pearson correlation between two tensors, pooled over all elements.

    Both inputs are centred on their overall mean and every element is
    treated as one observation.
    """
    true_dev = y_true - torch.mean(y_true)
    pred_dev = y_pred - torch.mean(y_pred)

    cross = torch.sum(true_dev * pred_dev)
    true_norm = torch.sqrt(torch.sum(true_dev * true_dev))
    pred_norm = torch.sqrt(torch.sum(pred_dev * pred_dev))

    return torch.nanmean(cross / true_norm / pred_norm)

@torch.inference_mode()
def denoising_eval(true, pred, mask, norm_flag=True):
    """Score a denoised matrix against the ground truth.

    Args:
        true: Ground-truth matrix (numpy array or torch tensor).
        pred: Predicted matrix (numpy array or torch tensor).
        mask: Selects the entries used for the masked metrics; it is cast
            to bool.
        norm_flag: When true, both matrices are normalised to a total of
            1e4 per row and log1p-transformed with scanpy before scoring.

    Returns:
        Dict with RMSE, mean per-row correlation, pooled correlation and
        squared linregress r-value on the masked entries, plus RMSE and
        mean per-row correlation over all entries.

    Raises:
        TypeError: If an input is neither a numpy array nor a torch tensor
            once any normalisation has been applied.
    """
    if norm_flag:
        import scanpy as sc

        adata_pred = sc.AnnData(pred)
        adata_true = sc.AnnData(true)
        for adata in (adata_pred, adata_true):
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
        pred, true = adata_pred.X, adata_true.X

    true = as_tensor(true, assert_type=True)
    pred = as_tensor(pred, assert_type=True)
    mask = as_tensor(mask, assert_type=True).bool()

    # Metrics restricted to the masked entries.
    masked_pred, masked_true = pred[mask], true[mask]
    rmse_normed = masked_rmse(pred, true, mask).item()
    corr_normed = masked_corr(pred, true, mask).item()
    global_corr_normed = PearsonCorr1d(masked_pred, masked_true).item()

    # Metrics over every entry, masked or not.
    corr_normed_all = PearsonCorr(pred, true).item()
    rmse_normed_all = F.mse_loss(pred, true).sqrt().item()

    # Index 2 of the linregress result is the correlation coefficient.
    regression = scipy.stats.linregress(
        masked_pred.cpu().numpy(), masked_true.cpu().numpy()
    )
    r = regression[2]

    # The nonzero-only metrics and the all-entries r2 are disabled.
    scores = {}
    scores['denoise_rmse_normed'] = rmse_normed
    scores['denoise_corr_normed'] = corr_normed
    scores['denoise_global_corr_normed'] = global_corr_normed
    scores['denoise_global_r2_normed'] = r ** 2
    scores['denoise_rmse_normed_all'] = rmse_normed_all
    scores['denoise_corr_normed_all'] = corr_normed_all
    return scores