import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scipy.stats
from sklearn import linear_model
from sklearn.metrics import r2_score, matthews_corrcoef, f1_score, precision_score
from sklearn import preprocessing
from sklearn import decomposition
import itertools as it

def square(x):
    """Return ``x`` raised to the second power.

    Used as the default inverse of the ``np.sqrt`` transform in
    ``SimDetection.__init__``.
    """
    return pow(x, 2)

class SimDetection(object):
    """
    Class to simulate model calibration and detection probabilities
    
    ...

    Attributes
    ----------
    y_test: np array
        the empirical mutation count for each window
    y_mean: float
        the mean mutation count across windows
    y_std: float
        the standard deviation of mutation counts across windows
    var_test: np array
        the empirical variance for each window
    var_mean: float
        the mean variance across windows
    var_std: float
        the standard deviation of variance across windows
    tx: fnc
        transform to apply when simulating correlated RVs (e.g. to keep RVs non-negative)
    inv_tx: fnc
        inverse transform to convert from simulation space to mutation space

    Methods
    -------
    simulate_batch_normal
        Simulates model calibration and detection power using normal approximations (fast)
    simulate_batch_nb
        Simulates model calibration and detection power using exact negative binomial model (slow)
    """
        
    def __init__(self, y_test, var_test, tx=np.sqrt, inv_tx=square):
        """
        Parameters
        ----------
        y_test: np array
            the empirical mutation count for each window
        var_test: np array
            the empirical variance for each window
        tx: fnc
            transform to apply when simulated correlated RVs (e.g. to keep RVs non-negative)
            (DO NOT CHANGE UNLESS YOU UNDERSTAND WHAT THS DOES!!)
        inv_tx: fnc
            inverse transform to convert from simulation space to mutation space
            (DO NOT CHANGE UNLESS YOU UNDERSTAND WHAT THIS DOES!!)
        """
        self.tx = tx
        self.inv_tx = inv_tx
        
        self.y_test = y_test
        self.y_mean = np.mean(tx(y_test))
        self.y_std = np.std(tx(y_test))
        
        self.var_test = var_test
        self.var_mean = np.mean(tx(self.var_test))
        self.var_std = np.std(tx(self.var_test))
        
    def simulate_batch_normal(self, p2_mean, p2_var, pdriver, nsamp, nsim=1000000, alpha=0.05, alpha_bounds=(0.01, 0.10), debug=False, seed=None):
        """ Simulates model calibration and detection probabilities using normal approximations

        Different levels of information can be returned by changing the value of debug

        Parameters
        ----------
        p2_mean: float [0-1] 
            desired R2 of mean accuracy
        p2_std: float [0-1]
            desired R2 of variance accuracy
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        nsim: int (optional)
            number of simulations to perform
        alpha: float [0-1] (optional)
            desired false-positive rate
        alpha_bounds: tuple (optional)
            acceptable FDR interval to consider a model well-calibrates
        debug: Bool (optional)
            return all calculated values?
        seed: int (optional)
            seed for random number generator

        Returns
        -------
        Pr_miscal: float
            probability the model will have an improperly controlled false-discovery rate
        tp_rate: float
            probability of detecting a driver mutation given the model controls FDR
        r2_mean: float
            the achieved R2 accuracy between true mean and simulated mean
        r2_var: float
            the achieved R2 accuracy between true variance and simulated variance
        """
        if seed:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        mean_driver = pdriver*nsamp
        var_driver = nsamp * pdriver * (1-pdriver)
        
        y_true = self.y_test[ix_lst]
        y_mean = self.y_mean
        y_std = self.y_std
        
        var_true = self.var_test[ix_lst]
        var_mean = self.var_mean
        var_std = self.var_std

        ## Simulate mean and variance predictions with specified correlation to the truth
        y = self.tx(y_true)
        y_pred = self.inv_tx(scipy.stats.norm.rvs(size=n, loc=y_mean + np.sqrt(p2_mean)*(y-y_mean), 
                                                             scale=np.sqrt(1-p2_mean)*y_std
                                           )
                            )

        s = self.tx(var_true)
        var_pred = self.inv_tx(scipy.stats.norm.rvs(size=n, loc=var_mean + np.sqrt(p2_var)*(s-var_mean), 
                                                               scale=np.sqrt(1-p2_var)*var_std
                                             )
                              )

        ## Calculate achieved R2 for predicted mean and variance 
        r2_mean = scipy.stats.pearsonr(y_true, y_pred)[0]**2
        r2_var  = scipy.stats.pearsonr(var_true, var_pred)[0]**2

        ## FDR thresholds for the true and predicted models
        xt_true = scipy.stats.norm.ppf(1-alpha, y_true, np.sqrt(var_true))
        xt_pred = scipy.stats.norm.ppf(1-alpha, y_pred, np.sqrt(var_pred))

        ## True-positive and false-positive rates from true model
        tp_true = scipy.stats.norm.sf(xt_true, y_true+mean_driver, np.sqrt(var_true+var_driver))
        fp_true = alpha
        # ppv_true = 1 - alpha
        # f1_true = 2 * tp_true * ppv_true / (tp_true + ppv_true)

        ## True-positive and false-positive rates from predicted model
        tp_pred = scipy.stats.norm.sf(xt_pred, y_true+mean_driver, np.sqrt(var_true+var_driver))
        fp_pred = scipy.stats.norm.sf(xt_pred, y_true, np.sqrt(var_true))
        # ppv_pred = 1 - fp_pred

        ## Miscalibration and detection probabilities
        mask_cal = (fp_pred < alpha_bounds[1])
        Pr_miscal = 1 - np.sum(mask_cal) / fp_pred.size
        tp_rate = tp_pred[mask_cal] / tp_true[mask_cal]  ## Relative to ground truth
        # f1_lst = 2 * tp_pred[mask_cal] * ppv_pred[mask_cal] / (tp_pred[mask_cal] + ppv_pred[mask_cal])
            
        if debug:
            return y_true, var_true, xt_true, tp_true, fp_true, \
                   y_pred, var_pred, xt_pred, tp_pred, fp_pred, \
                   r2_mean, r2_var, Pr_miscal, tp_rate, mask_cal
        else:
            return Pr_miscal, np.mean(tp_rate), r2_mean, r2_var

    def simulate_batch_normal_from_model(self, y_pred, var_pred, pdriver, nsamp, nsim=1000000, alpha=0.05, alpha_bounds=(0.01, 0.10), debug=False, seed=None):
        """ Simulates model calibration and detection probabilities using normal approximations

        Different levels of information can be returned by changing the value of debug

        Parameters
        ----------
        p2_mean: float [0-1] 
            desired R2 of mean accuracy
        p2_std: float [0-1]
            desired R2 of variance accuracy
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        nsim: int (optional)
            number of simulations to perform
        alpha: float [0-1] (optional)
            desired false-positive rate
        alpha_bounds: tuple (optional)
            acceptable FDR interval to consider a model well-calibrates
        debug: Bool (optional)
            return all calculated values?
        seed: int (optional)
            seed for random number generator

        Returns
        -------
        Pr_miscal: float
            probability the model will have an improperly controlled false-discovery rate
        tp_rate: float
            probability of detecting a driver mutation given the model controls FDR
        r2_mean: float
            the achieved R2 accuracy between true mean and simulated mean
        r2_var: float
            the achieved R2 accuracy between true variance and simulated variance
        """
        if seed:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        mean_driver = pdriver*nsamp
        var_driver = nsamp * pdriver * (1-pdriver)
        
        y_true = self.y_test[ix_lst]
        y_mean = self.y_mean
        y_std = self.y_std
        
        var_true = self.var_test[ix_lst]
        var_mean = self.var_mean
        var_std = self.var_std

        ## Simulate mean and variance predictions with specified correlation to the truth
        y_pred = y_pred[ix_lst]
        var_pred = var_pred[ix_lst]

        ## Calculate achieved R2 for predicted mean and variance 
        r2_mean = scipy.stats.pearsonr(y_true, y_pred)[0]**2
        r2_var  = scipy.stats.pearsonr(var_true, var_pred)[0]**2

        ## FDR thresholds for the true and predicted models
        xt_true = scipy.stats.norm.ppf(1-alpha/2, y_true, np.sqrt(var_true))
        xt_pred_u = scipy.stats.norm.ppf(1-alpha/2, y_pred, np.sqrt(var_pred))
        xt_pred_l = scipy.stats.norm.ppf(alpha/2, y_pred, np.sqrt(var_pred))

        ## True-positive and false-positive rates from true model
        tp_true = scipy.stats.norm.sf(xt_true, y_true+mean_driver, np.sqrt(var_true+var_driver))
        fp_true = alpha
        # ppv_true = 1 - alpha
        # f1_true = 2 * tp_true * ppv_true / (tp_true + ppv_true)

        ## True-positive and false-positive rates from predicted model
        tp_pred = scipy.stats.norm.sf(xt_pred_u, y_true+mean_driver, np.sqrt(var_true+var_driver))
        fp_pred = scipy.stats.norm.sf(xt_pred_u, y_true, np.sqrt(var_true)) + \
                  scipy.stats.norm.cdf(xt_pred_l, y_true, np.sqrt(var_true))
        # ppv_pred = 1 - fp_pred

        ## Miscalibration and detection probabilities
        mask_cal = (fp_pred < alpha_bounds[1])
        Pr_miscal = 1 - np.sum(mask_cal) / fp_pred.size
        tp_rate = tp_pred[mask_cal] / tp_true[mask_cal]  ## Relative to ground truth
        # f1_lst = 2 * tp_pred[mask_cal] * ppv_pred[mask_cal] / (tp_pred[mask_cal] + ppv_pred[mask_cal])
            
        if debug:
            return y_true, var_true, xt_true, tp_true, fp_true, \
                   y_pred, var_pred, xt_pred, tp_pred, fp_pred, \
                   r2_mean, r2_var, Pr_miscal, tp_rate, mask_cal
        else:
            tp_rate[np.isinf(tp_rate)] = 1
            return Pr_miscal, np.mean(tp_rate), r2_mean, r2_var
    
    def simulate_batch_nb(self, p2_mean, p2_std, pdriver, nsamp, nsim=1000000, n_MC=250, alpha=0.05, alpha_bounds=(0.01, 0.10), debug=False, seed=None):
        """ Simulates model calibration and detection probabilities using exact negative binomial distributions.

        Different levels of information can be returned by changing the value of debug

        Parameters
        ----------
        p2_mean: float [0-1] 
            desired R2 of mean accuracy
        p2_std: float [0-1]
            desired R2 of variance accuracy
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        nsim: int (optional)
            number of simulations to perform
        n_MC: int (optional)
            number of monte carlo simulations to perform per window
        alpha: float [0-1] (optional)
            desired false-positive rate
        alpha_bounds: tuple (optional)
            acceptable FDR interval to consider a model well-calibrates
        debug: Bool (optional)
            return all calculated values?
        seed: int (optional)
            seed for random number generator

        Returns
        -------
        Pr_miscal: float
            probability the model will have an improperly controlled false-discovery rate
        tp_rate: float
            probability of detecting a driver mutation given the model controls FDR
        r2_mean: float
            the achieved R2 accuracy between true mean and simulated mean
        r2_var: float
            the achieved R2 accuracy between true variance and simulated variance
        """
        if seed:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        mean_driver = pdriver*nsamp
        var_driver = nsamp * pdriver * (1-pdriver)
        
        y_true = self.y_test[ix_lst]
        y_mean = self.y_mean
        y_std = self.y_std
        
        var_true = self.var_test[ix_lst]
        var_mean = self.var_mean
        var_std = self.var_std
        
        alpha_true = y_true**2 / var_true
        theta_true = var_true / y_true
        p_true = 1 / (theta_true + 1)
        
        ## Simulate parameters
        y = self.tx(y_true)
        y_pred = self.inv_tx(scipy.stats.norm.rvs(size=n, loc=y_mean + np.sqrt(p2_mean)*(y-y_mean), 
                                                             scale=np.sqrt(1-p2_mean)*y_std
                                           )
                            )

        s = self.tx(var_true)
        var_pred = self.inv_tx(scipy.stats.norm.rvs(size=n, loc=var_mean + np.sqrt(p2_std)*(s-var_mean), 
                                                               scale=np.sqrt(1-p2_std)*var_std
                                             )
                              )

        alpha_pred = y_pred**2 / var_pred
        theta_pred = var_pred / y_pred
        p_pred = 1 / (theta_pred + 1)

        ## Calculate achieved R2 for predicted mean and variance 
        r2_mean = scipy.stats.pearsonr(y_true, y_pred)[0]**2
        r2_var  = scipy.stats.pearsonr(var_true, var_pred)[0]**2
        
        ## FDR thresholds for the true and predicted models
        xt_true = scipy.stats.nbinom.ppf(1-alpha, alpha_true, p_true)
        xt_pred = scipy.stats.nbinom.ppf(1-alpha, alpha_pred, p_pred)
        
        ## Simulate background and anomaly mutations
        xb_sim = scipy.stats.nbinom.rvs(alpha_true, p_true, size=(n_MC, n))
        xa_sim = scipy.stats.binom.rvs(nsamp, pdriver, size=(n_MC, n))
        x_sim = xb_sim + xa_sim
        
        ## True-positive and false-positive rates from true model
        tp_true = np.sum(x_sim > xt_true, axis=0) / n_MC
        fp_true = np.sum(xb_sim > xt_true, axis=0) / n_MC
        
        ## True-positive and false-positive rates from true model
        tp_pred = np.sum(x_sim > xt_pred, axis=0) / n_MC
        fp_pred = np.sum(xb_sim > xt_pred, axis=0) / n_MC 
        # ppv_pred = 1 - fp_pred
        
        ## Miscalibration and detection probabilities
        mask_cal = (fp_pred < alpha_bounds[1])
        Pr_miscal = 1 - np.sum(mask_cal) / fp_pred.size
        tp_rate = tp_pred[mask_cal] / tp_true[mask_cal]  ## Relative to ground truth
        # f1_lst = 2 * tp_pred[mask_cal] * ppv_pred[mask_cal] / (tp_pred[mask_cal] + ppv_pred[mask_cal])

        if debug:
            return y_true, var_true, xt_true, tp_true, fp_true, \
                   y_pred, var_pred, xt_pred, tp_pred, fp_pred, \
                   r2_mean, r2_var, Pr_miscal, tp_rate, mask_cal
        else:
            return Pr_miscal, np.mean(tp_rate), r2_mean, r2_var

    def simulate_batch_nb_from_model(self, y_pred, var_pred, pdriver, nsamp, nsim=1000000, n_MC=250, alpha=0.05, alpha_bounds=(0.01, 0.10), debug=False, seed=None):
        """ Simulates model calibration and detection probabilities using exact negative binomial distributions.

        Different levels of information can be returned by changing the value of debug

        Parameters
        ----------
        p2_mean: float [0-1] 
            desired R2 of mean accuracy
        p2_std: float [0-1]
            desired R2 of variance accuracy
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        nsim: int (optional)
            number of simulations to perform
        n_MC: int (optional)
            number of monte carlo simulations to perform per window
        alpha: float [0-1] (optional)
            desired false-positive rate
        alpha_bounds: tuple (optional)
            acceptable FDR interval to consider a model well-calibrates
        debug: Bool (optional)
            return all calculated values?
        seed: int (optional)
            seed for random number generator

        Returns
        -------
        Pr_miscal: float
            probability the model will have an improperly controlled false-discovery rate
        tp_rate: float
            probability of detecting a driver mutation given the model controls FDR
        r2_mean: float
            the achieved R2 accuracy between true mean and simulated mean
        r2_var: float
            the achieved R2 accuracy between true variance and simulated variance
        """
        if seed:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        mean_driver = pdriver*nsamp
        var_driver = nsamp * pdriver * (1-pdriver)
        
        y_true = self.y_test[ix_lst]
        y_mean = self.y_mean
        y_std = self.y_std
        
        var_true = self.var_test[ix_lst]
        var_mean = self.var_mean
        var_std = self.var_std
        
        alpha_true = y_true**2 / var_true
        theta_true = var_true / y_true
        p_true = 1 / (theta_true + 1)
        
        ## predicted parameters
        y_pred = y_pred[ix_lst]
        var_pred = var_pred[ix_lst]
        alpha_pred = y_pred**2 / var_pred
        theta_pred = var_pred / y_pred
        p_pred = 1 / (theta_pred + 1)

        ## Calculate achieved R2 for predicted mean and variance 
        r2_mean = scipy.stats.pearsonr(y_true, y_pred)[0]**2
        r2_var  = scipy.stats.pearsonr(var_true, var_pred)[0]**2
        
        ## FDR thresholds for the true and predicted models
        xt_true_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_true, p_true)
        xt_true_l = scipy.stats.nbinom.ppf(alpha/2, alpha_true, p_true)

        xt_pred_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_pred, p_pred)
        xt_pred_l = scipy.stats.nbinom.ppf(alpha/2, alpha_pred, p_pred)
        
        ## Simulate background and anomaly mutations
        xb_sim = scipy.stats.nbinom.rvs(alpha_true, p_true, size=(n_MC, n))
        xa_sim = scipy.stats.binom.rvs(nsamp, pdriver, size=(n_MC, n))
        x_sim = xb_sim + xa_sim
        
        ## True-positive and false-positive rates from true model
        tp_true = np.sum((x_sim > xt_true_u) | (x_sim < xt_true_l), axis=0) / n_MC
        fp_true = np.sum((xb_sim > xt_true_u) | (xb_sim < xt_true_l) , axis=0) / n_MC
        
        ## True-positive and false-positive rates from true model
        tp_pred = np.sum((x_sim > xt_pred_u) | (x_sim < xt_pred_l), axis=0) / n_MC
        fp_pred = np.sum((xb_sim > xt_pred_u) | (xb_sim < xt_pred_l), axis=0) / n_MC 
        # ppv_pred = 1 - fp_pred
        
        ## Miscalibration and detection probabilities
        mask_cal = (fp_pred < alpha_bounds[1])
        Pr_miscal = 1 - np.sum(mask_cal) / fp_pred.size

        mask_tp = mask_cal & (tp_true > 0)
        tp_rate = tp_pred[mask_tp] / tp_true[mask_tp]  ## Relative to ground truth
        # f1_lst = 2 * tp_pred[mask_cal] * ppv_pred[mask_cal] / (tp_pred[mask_cal] + ppv_pred[mask_cal])

        if debug:
            return Pr_miscal, tp_rate, r2_mean, r2_var, fp_true, fp_pred
            # return y_true, var_true, xt_true, tp_true, fp_true, \
            #        y_pred, var_pred, xt_pred, tp_pred, fp_pred, \
            #        r2_mean, r2_var, Pr_miscal, tp_rate, mask_cal
        else:
            tp_rate[np.isinf(tp_rate)] = 1
            tp_rate[np.isnan(tp_rate)] = 0
            return Pr_miscal, np.median(tp_rate), r2_mean, r2_var

    def simulate_batch_nb_from_model2(self, y_pred, var_pred, pdriver, nsamp, p_tilde=1, pz=0.01, nsim=1000000, alpha=0.05, debug=False, seed=None):
        """ Simulates driver detection with a mixture of driver and non-driver windows using negative binomial distributions.

        Windows are resampled with replacement from the stored test data and the
        supplied predictions are resampled with the same indices. One background
        count is drawn per window from the true negative binomial model, and a
        Bernoulli(pz) indicator decides whether a Binomial(nsamp, pdriver) driver
        count is added. Windows are then called using thresholds from the true and
        from the predicted model, and confusion counts, F1 and MCC are reported.

        Parameters
        ----------
        y_pred: np array
            predicted mean for each window, indexed like the stored y_test
        var_pred: np array
            predicted variance for each window, indexed like the stored var_test
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        p_tilde: float (optional)
            multiplier applied to the variance/mean ratio when deriving the
            negative binomial success probability (true and predicted models)
        pz: float [0-1] (optional)
            probability that a simulated window receives driver mutations
        nsim: int (optional)
            number of simulations to perform
        alpha: float [0-1] (optional)
            total tail probability of the thresholds (alpha/2 in each tail)
        debug: Bool (optional)
            currently unused
        seed: int (optional)
            seed for numpy's global random number generator (0 is a valid seed);
            None leaves the generator state untouched

        Returns
        -------
        tp_true: int
            driver windows above the true-model upper threshold
        fn_true: int
            driver windows at or below the true-model upper threshold
        fp_true: int
            non-driver windows outside the true-model [lower, upper] thresholds
        f1_true: float
            F1 score of the two-sided true-model calls against the driver indicator
        mcc_true: float
            Matthews correlation of the two-sided true-model calls
        tp_pred: int
            driver windows above the predicted-model upper threshold
        fn_pred: int
            driver windows at or below the predicted-model upper threshold
        fp_pred: int
            non-driver windows above the predicted-model upper threshold
        f1_pred: float
            F1 score of the two-sided predicted-model calls against the driver indicator
        mcc_pred: float
            Matthews correlation of the two-sided predicted-model calls
        """
        # Compare against None: a plain truthiness test would silently ignore seed=0
        if seed is not None:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        y_true = self.y_test[ix_lst]
        var_true = self.var_test[ix_lst]

        ## Negative binomial parameters (shape alpha, success probability p) of the true model
        alpha_true = y_true**2 / var_true
        theta_true = var_true / y_true
        p_true = 1 / (p_tilde*theta_true + 1)

        ## predicted parameters, resampled with the same window indices
        y_pred = y_pred[ix_lst]
        var_pred = var_pred[ix_lst]
        alpha_pred = y_pred**2 / var_pred
        theta_pred = var_pred / y_pred
        p_pred = 1 / (p_tilde*theta_pred + 1)

        # Diagnostic: fires when the predicted probabilities are identical to the true ones
        if np.all(p_pred == p_true):
            print('WTF!!!')

        ## Two-sided thresholds for the true and predicted models
        xt_true_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_true, p_true)
        xt_true_l = scipy.stats.nbinom.ppf(alpha/2, alpha_true, p_true)

        xt_pred_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_pred, p_pred)
        xt_pred_l = scipy.stats.nbinom.ppf(alpha/2, alpha_pred, p_pred)

        ## Simulate background and anomaly mutations (one draw per window)
        xb_sim = scipy.stats.nbinom.rvs(alpha_true, p_true, size=n)
        xa_sim = scipy.stats.binom.rvs(nsamp, pdriver, size=n)
        z_sim = scipy.stats.bernoulli.rvs(pz, size=n)
        x_sim = xb_sim + z_sim*xa_sim

        is_driver = (z_sim == 1)
        is_background = (z_sim == 0)

        ## Confusion counts from true model
        # NOTE(review): fp_true counts both tails while fp_pred below counts only
        # the upper tail; kept as in the original, confirm which is intended.
        tp_true = np.sum(x_sim[is_driver] > xt_true_u[is_driver])
        fn_true = np.sum(x_sim[is_driver] <= xt_true_u[is_driver])
        fp_true = np.sum((xb_sim[is_background] > xt_true_u[is_background]) |
                         (xb_sim[is_background] < xt_true_l[is_background]))

        # Two-sided calls, scored against the simulated driver indicator
        zhat_true = np.zeros(n)
        zhat_true[(x_sim > xt_true_u) | (x_sim < xt_true_l)] = 1
        f1_true = f1_score(z_sim, zhat_true)
        mcc_true = matthews_corrcoef(z_sim, zhat_true)

        ## Confusion counts from predicted model
        tp_pred = np.sum(x_sim[is_driver] > xt_pred_u[is_driver])
        fp_pred = np.sum(xb_sim[is_background] > xt_pred_u[is_background])
        fn_pred = np.sum(x_sim[is_driver] <= xt_pred_u[is_driver])

        zhat_pred = np.zeros(n)
        zhat_pred[(x_sim > xt_pred_u) | (x_sim < xt_pred_l)] = 1
        f1_pred = f1_score(z_sim, zhat_pred)
        mcc_pred = matthews_corrcoef(z_sim, zhat_pred)

        return tp_true, fn_true, fp_true, f1_true, mcc_true, tp_pred, fn_pred, fp_pred, f1_pred, mcc_pred

    def simulate_batch_nb2(self, y_pred, var_pred, pdriver, nsamp, p_tilde=1, pz=0.01, nsim=1000000, alpha=0.05, debug=False, seed=None):
        """ Simulates model calibration and detection probabilities using exact negative binomial distributions.

        Different levels of information can be returned by changing the value of debug

        Parameters
        ----------
        p2_mean: float [0-1] 
            desired R2 of mean accuracy
        p2_std: float [0-1]
            desired R2 of variance accuracy
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        nsim: int (optional)
            number of simulations to perform
        n_MC: int (optional)
            number of monte carlo simulations to perform per window
        alpha: float [0-1] (optional)
            desired false-positive rate
        alpha_bounds: tuple (optional)
            acceptable FDR interval to consider a model well-calibrates
        debug: Bool (optional)
            return all calculated values?
        seed: int (optional)
            seed for random number generator

        Returns
        -------
        Pr_miscal: float
            probability the model will have an improperly controlled false-discovery rate
        tp_rate: float
            probability of detecting a driver mutation given the model controls FDR
        r2_mean: float
            the achieved R2 accuracy between true mean and simulated mean
        r2_var: float
            the achieved R2 accuracy between true variance and simulated variance
        """
        if seed:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        mean_driver = pdriver*nsamp
        var_driver = nsamp * pdriver * (1-pdriver)
        
        y_true = self.y_test[ix_lst]
        y_mean = self.y_mean
        y_std = self.y_std
        
        var_true = self.var_test[ix_lst]
        var_mean = self.var_mean
        var_std = self.var_std
        
        alpha_true = y_true**2 / var_true
        theta_true = var_true / y_true
        p_true = 1 / (p_tilde*theta_true + 1)
        
        ## predicted parameters
        # y_pred = y_pred[ix_lst]
        # var_pred = var_pred[ix_lst]

        y = self.tx(y_true)
        y_pred = self.inv_tx(scipy.stats.norm.rvs(size=n, loc=y_mean + np.sqrt(p2_mean)*(y-y_mean), 
                                                             scale=np.sqrt(1-p2_mean)*y_std
                                           )
                            )

        s = self.tx(var_true)
        var_pred = self.inv_tx(scipy.stats.norm.rvs(size=n, loc=var_mean + np.sqrt(p2_std)*(s-var_mean), 
                                                               scale=np.sqrt(1-p2_std)*var_std
                                             )
                              )

        alpha_pred = y_pred**2 / var_pred
        theta_pred = var_pred / y_pred
        p_pred = 1 / (p_tilde*theta_pred + 1)

        if np.all(p_pred == p_true):
            print('WTF!!!')

        ## Calculate achieved R2 for predicted mean and variance 
        r2_mean = scipy.stats.pearsonr(y_true, y_pred)[0]**2
        r2_var  = scipy.stats.pearsonr(var_true, var_pred)[0]**2

        # print(r2_mean)
        # print(r2_var)
        
        ## FDR thresholds for the true and predicted models
        # xt_true_u = scipy.stats.nbinom.ppf(1-alpha, alpha_true, p_true)
        # xt_true_l = scipy.stats.nbinom.ppf(alpha, alpha_true, p_true)
        xt_true_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_true, p_true)
        xt_true_l = scipy.stats.nbinom.ppf(alpha/2, alpha_true, p_true)

        # xt_pred_u = scipy.stats.nbinom.ppf(1-alpha, alpha_pred, p_pred)
        # xt_pred_l = scipy.stats.nbinom.ppf(alpha, alpha_pred, p_pred)
        xt_pred_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_pred, p_pred)
        xt_pred_l = scipy.stats.nbinom.ppf(alpha/2, alpha_pred, p_pred)
        
        ## Simulate background and anomaly mutations
        xb_sim = scipy.stats.nbinom.rvs(alpha_true, p_true, size=n)
        xa_sim = scipy.stats.binom.rvs(nsamp, pdriver, size=n)
        z_sim = scipy.stats.bernoulli.rvs(pz, size=n)
        x_sim = xb_sim + z_sim*xa_sim
        
        ## True-positive and false-positive rates from true model
        tp_true = np.sum((x_sim[z_sim==1] > xt_true_u[z_sim==1]))
        fn_true = np.sum((x_sim[z_sim==1] <= xt_true_u[z_sim==1]))
        fp_true = np.sum((xb_sim[z_sim==0] > xt_true_u[z_sim==0])|(xb_sim[z_sim==0] < xt_true_l[z_sim==0]))
        tn_true = np.sum((xb_sim[z_sim==0] <= xt_true_u[z_sim==0]))

        zhat_true = np.zeros(n)
        # zhat_true[(x_sim > xt_true_u)] = 1
        zhat_true[(x_sim > xt_true_u) | (x_sim < xt_true_l)] = 1

        f1_true = 2*tp_true / (2*tp_true + fn_true + fp_true)
        # mcc_true = self._calc_mcc(tp_true, tn_true, fp_true, fn_true)
        f1_true = f1_score(z_sim, zhat_true)
        mcc_true = matthews_corrcoef(z_sim, zhat_true)
        ppv_true = precision_score(z_sim, zhat_true)

        # tpr_true = tp_true / (tp_true + fn_true)
        # ppv_true = tp_true / (tp_true + fp_true)
        # f1_true = 2 * tpr_true * ppv_true / (
        # tp_true = np.sum((x_sim > xt_true_u) | (x_sim < xt_true_l), axis=0)
        # fp_true = np.sum((xb_sim > xt_true_u) | (xb_sim < xt_true_l) , axis=0)
        
        ## True-positive and false-positive rates from true model
        tp_pred = np.sum((x_sim[z_sim==1] > xt_pred_u[z_sim==1]))
        fp_pred = np.sum((xb_sim[z_sim==0] > xt_pred_u[z_sim==0]))
        fn_pred = np.sum((x_sim[z_sim==1] <= xt_pred_u[z_sim==1]))
        tn_pred = np.sum((xb_sim[z_sim==0] <= xt_pred_u[z_sim==0]))

        zhat_pred = np.zeros(n)
        # zhat_pred[(x_sim > xt_pred_u)] = 1
        zhat_pred[(x_sim > xt_pred_u) | (x_sim < xt_pred_l)] = 1

        f1_pred = 2*tp_pred / (2*tp_pred + fn_pred + fp_pred)
        f1_pred = f1_score(z_sim, zhat_pred)
        mcc_pred = matthews_corrcoef(z_sim, zhat_pred)
        ppv_pred = precision_score(z_sim, zhat_pred)
        # mcc_pred = self._calc_mcc(tp_pred, tn_pred, fp_pred, fn_pred)
        # ppv_pred = 1 - fp_pred

        tp_true_rate = tp_true / (tp_true + fn_true)
        fp_true_rate = fp_true / (fp_true + tn_true)
        # print(fp_true_rate)

        tp_pred_rate = tp_pred / (tp_pred + fn_pred)
        fp_pred_rate = fp_pred / (fp_pred + tn_pred)

        return tp_true, fn_true, fp_true, f1_true, mcc_true, tp_pred, fn_pred, fp_pred, f1_pred, mcc_pred
        # return tp_true_rate, fp_true_rate, tp_pred_rate, fp_pred_rate
        
        ## Miscalibration and detection probabilities
        # mask_cal = (fp_pred < alpha_bounds[1])
        # Pr_miscal = 1 - np.sum(mask_cal) / fp_pred.size

        # mask_tp = mask_cal & (tp_true > 0)
        # tp_rate = tp_pred[mask_tp] / tp_true[mask_tp]  ## Relative to ground truth
        # f1_lst = 2 * tp_pred[mask_cal] * ppv_pred[mask_cal] / (tp_pred[mask_cal] + ppv_pred[mask_cal])

        # if debug:
        #     return Pr_miscal, tp_rate, r2_mean, r2_var, fp_true, fp_pred
        #     # return y_true, var_true, xt_true, tp_true, fp_true, \
        #     #        y_pred, var_pred, xt_pred, tp_pred, fp_pred, \
        #     #        r2_mean, r2_var, Pr_miscal, tp_rate, mask_cal
        # else:
        #     tp_rate[np.isinf(tp_rate)] = 1
        #     tp_rate[np.isnan(tp_rate)] = 0
        #     return Pr_miscal, np.median(tp_rate), r2_mean, r2_var

    def simulate_batch_nb_with_bins(self, y_pred, var_pred, pdriver, nsamp, bins=50, nsim=1000000, alpha=0.05, debug=False, seed=None, pz=0.01):
        """ Simulates driver detection with a supplied predicted model using exact negative binomial distributions.

        Windows are resampled with replacement from the empirical data. Background
        counts are drawn from a negative binomial parameterised by the empirical
        (true) mean and variance, driver counts are added to a random subset of
        windows, and windows are then called using thresholds derived from both
        the true parameters and the supplied predicted parameters.

        Parameters
        ----------
        y_pred: np array
            predicted mean for each window (indexed like self.y_test)
        var_pred: np array
            predicted variance for each window (indexed like self.var_test)
        pdriver: float [0-1]
            probability of a driver mutation in an individual
        nsamp: int
            number of tumor samples from which to simulate driver mutations
        bins: int (optional)
            the variance-to-mean ratio is scaled by 1/bins when computing the
            negative binomial success probability
        nsim: int (optional)
            number of windows to simulate
        alpha: float [0-1] (optional)
            two-sided significance level; thresholds are the alpha/2 and
            1-alpha/2 quantiles of the negative binomial
        debug: Bool (optional)
            currently unused; kept for interface compatibility
        seed: int (optional)
            seed for the numpy global random number generator (0 is a valid seed)
        pz: float [0-1] (optional)
            probability that a simulated window carries driver mutations

        Returns
        -------
        tp_true, fn_true, fp_true: int
            counts obtained with the true model. These use ONLY the upper
            threshold (a window is called when its count is >= the threshold).
        f1_true, mcc_true: float
            F1 score and Matthews correlation of the true model's calls. These
            use the two-sided calls (count >= upper or count <= lower threshold).
        tp_pred, fn_pred, fp_pred: int
            same counts as above, obtained with the predicted model
        f1_pred, mcc_pred: float
            same scores as above, obtained with the predicted model
        """
        # `seed=0` is a legitimate seed, so compare against None instead of truthiness
        if seed is not None:
            np.random.seed(seed)

        n = nsim
        ix_lst = np.random.choice(len(self.y_test), size=n, replace=True)

        p_tild = 1. / bins

        ## true parameters
        y_true = self.y_test[ix_lst]
        var_true = self.var_test[ix_lst]
        alpha_true = y_true**2 / var_true
        theta_true = var_true / y_true
        p_true = 1 / (p_tild * theta_true + 1)

        ## predicted parameters
        y_pred = y_pred[ix_lst]
        var_pred = var_pred[ix_lst]
        alpha_pred = y_pred**2 / var_pred
        theta_pred = var_pred / y_pred
        p_pred = 1 / (p_tild * theta_pred + 1)

        ## Two-sided thresholds for the true and predicted models
        xt_true_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_true, p_true)
        xt_true_l = scipy.stats.nbinom.ppf(alpha/2, alpha_true, p_true)

        xt_pred_u = scipy.stats.nbinom.ppf(1-alpha/2, alpha_pred, p_pred)
        xt_pred_l = scipy.stats.nbinom.ppf(alpha/2, alpha_pred, p_pred)

        ## Simulate background and anomaly mutations
        ## (draw order is kept fixed so seeded runs stay reproducible)
        xb_sim = scipy.stats.nbinom.rvs(alpha_true, p_true, size=n)
        xa_sim = scipy.stats.binom.rvs(nsamp, pdriver, size=n)
        z_sim = scipy.stats.bernoulli.rvs(pz, size=n)
        x_sim = xb_sim + z_sim*xa_sim

        is_driver = (z_sim == 1)
        is_background = (z_sim == 0)

        ## Upper-threshold counts from the true model
        tp_true = np.sum(x_sim[is_driver] >= xt_true_u[is_driver])
        fn_true = np.sum(x_sim[is_driver] < xt_true_u[is_driver])
        fp_true = np.sum(xb_sim[is_background] >= xt_true_u[is_background])

        ## Two-sided calls from the true model
        zhat_true = np.zeros(n)
        zhat_true[(x_sim >= xt_true_u) | (x_sim <= xt_true_l)] = 1

        f1_true = f1_score(z_sim, zhat_true)
        mcc_true = matthews_corrcoef(z_sim, zhat_true)

        ## Upper-threshold counts from the predicted model
        tp_pred = np.sum(x_sim[is_driver] >= xt_pred_u[is_driver])
        fp_pred = np.sum(xb_sim[is_background] >= xt_pred_u[is_background])
        fn_pred = np.sum(x_sim[is_driver] < xt_pred_u[is_driver])

        ## Two-sided calls from the predicted model
        zhat_pred = np.zeros(n)
        zhat_pred[(x_sim >= xt_pred_u) | (x_sim <= xt_pred_l)] = 1

        f1_pred = f1_score(z_sim, zhat_pred)
        mcc_pred = matthews_corrcoef(z_sim, zhat_pred)

        return tp_true, fn_true, fp_true, f1_true, mcc_true, tp_pred, fn_pred, fp_pred, f1_pred, mcc_pred

    def _calc_mcc(self, tp, tn, fp, fn):
        num = tp * tn - fp * fn
        denom = np.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn))

        if np.isnan(denom):
            print(tp, tn, fp, fn)

        return num / denom
