import warnings

import anndata as ad
from tqdm import tqdm
from typing import Callable
from functools import partial
from abc import abstractmethod
from scipy.sparse import csr_matrix

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

from .util import *
from celldiff.evaluate import evaluate_annotation, integration_eval
from celldiff.util import instantiate_from_config
from celldiff.modules.distributions import Poisson, NegativeBinomial, ZeroInflatedNegativeBinomial


class DiffusionWrapper(pl.LightningModule):
    """Lightning module that holds the denoising network built from a config.

    The network is created with ``instantiate_from_config`` and stored as
    ``diffusion_model``; ``forward`` hands its arguments straight to it.
    """

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, pe_input=None, conditions=None, input_gene_list=None, mask_flag=True, mask=None):
        """Call the wrapped network positionally and return its output unchanged."""
        network_args = (x, t, pe_input, conditions, input_gene_list, mask_flag, mask)
        return self.diffusion_model(*network_args)

class BaseCellJump(pl.LightningModule):
    def __init__(
            self,
            model_config,
            scheduler_config=None,
            diff_loss="kl",
            nll_loss="both", # one of: 'x_0', 'x_t', 'both' (see calculate_loss)
            timesteps=1000,
            beta_start=0.001,
            beta_end='auto',
            beta_schedule='linear',
            logsnr_end=-12,
            signal_stat=1.,
            input_key="input",
            library_size_key="library_size",
            pe_input_key="coord",
            input_gene_list_key="input_gene_list",
            cond_key="cond",
            top_de_key='top_de_dict',
            t_sample=500,
            monitor="val/loss",
            latent_flag=True,
            recon_flag=True,
            integrate_flag=True,
            impute_flag=False,
            pert_flag=False,
            classify_flag=True,
            classifier_config=None,
            guidance_flag=False,
            guidance_config=None,
            cond_to_ignore=None,
            cond_names=None,
            test_target_sum=1e4,
            **kwargs
    ):
        """Build the denoising model, optional classifier/guidance modules and the schedule.

        Most arguments are stored on the instance under the same name. Behavior
        visible in this constructor:

        - ``model_config`` is wrapped in a ``DiffusionWrapper`` stored as ``self.model``.
        - ``scheduler_config`` is stored only when it is not None; ``self.use_scheduler``
          records whether one was given.
        - ``monitor`` is stored only when it is not None.
        - ``classify_flag=True`` requires ``classifier_config`` (asserted); the classifier
          is instantiated with ``model=self``. If ``cond_names`` is None it defaults to
          ``["batch", "celltype"]`` and a ``UserWarning`` is emitted. ``self.cond_names``
          is only set in this branch.
        - ``guidance_flag=True`` requires ``guidance_config`` (asserted).
        - ``beta_start``, ``beta_end``, ``beta_schedule``, ``timesteps``, ``logsnr_end`` and
          ``signal_stat`` are forwarded to ``register_schedule``.
        - ``**kwargs`` is accepted but not read here.
        """
        super().__init__()
        self.timesteps = timesteps
        self.diff_loss = diff_loss
        self.nll_loss = nll_loss
        self.input_key = input_key
        self.library_size_key = library_size_key
        self.pe_input_key = pe_input_key
        self.input_gene_list_key = input_gene_list_key
        self.cond_key = cond_key
        self.top_de_key = top_de_key
        self.t_sample = t_sample
        self.cond_to_ignore = cond_to_ignore
        self.test_target_sum = test_target_sum
        self.model = DiffusionWrapper(model_config)
        # The scheduler config is optional; configure_optimizers checks use_scheduler.
        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config
        if monitor is not None:
            self.monitor = monitor
        
        # Flags selecting which outputs/evaluations are produced.
        self.latent_flag = latent_flag
        self.recon_flag = recon_flag
        self.integrate_flag = integrate_flag
        self.impute_flag = impute_flag
        self.pert_flag = pert_flag

        self.classify_flag = classify_flag
        if self.classify_flag:
            assert classifier_config is not None, "Classify flag set but classifier config not passed"
            # The classifier receives this module itself through the `model` keyword.
            self.classifier = instantiate_from_config(classifier_config, model=self)

            if cond_names is None:
                cond_names = ["batch", "celltype"]
                warnings.warn(
                    f"Condition names cond_names not set, using default of {cond_names!r}. "
                    "This is needed for specifying metric names for classification. "
                    "Please specify to suppress the warning.",
                    UserWarning,
                    stacklevel=2,
                )
            self.cond_names = cond_names

        self.guidance_flag = guidance_flag
        if self.guidance_flag:
            assert guidance_config is not None, "Guidance flag set but guidance config not passed"
            self.guidance = instantiate_from_config(guidance_config)

        self.register_schedule(beta_start, beta_end, beta_schedule, timesteps, logsnr_end, signal_stat)
        # Filled in by maybe_record_conditions during the first training epoch.
        self.register_buffer("unique_conditions", None)

        # Per-batch outputs collected by test_step and consumed in on_test_epoch_end.
        self.test_step_outputs = []

    def register_schedule(self, beta_start=0.001, beta_end='auto', beta_schedule="linear", timesteps=1000,
                          logsnr_end=-12., signal_stat=1.):
        """Compute the decay schedule and register it as float32 buffers.

        The schedule comes from ``get_decay_schedule`` (requested under the name
        ``beta_<beta_schedule>``). Four buffers are registered: ``betas``,
        ``alphas``, ``alphas_prev`` (``alphas`` shifted by one step with 1
        prepended) and ``deltas`` (``alphas_prev - alphas``).

        Raises:
            AssertionError: if ``alphas`` does not have ``timesteps`` entries.
        """
        # TODO: investigate schedule
        schedule = get_decay_schedule(
            f'beta_{beta_schedule}',
            timesteps,
            signal_stat=signal_stat,
            beta_start=beta_start,
            beta_end=beta_end,
            logsnr_end=logsnr_end,
        )
        betas = schedule['betas']
        alphas = schedule['alphas']
        alphas_prev = np.append(1., alphas[:-1])
        deltas = alphas_prev - alphas
        assert alphas.shape[0] == timesteps, 'alphas have to be defined for each timestep'

        schedule_buffers = (
            ('betas', betas),
            ('alphas', alphas),
            ('alphas_prev', alphas_prev),
            ('deltas', deltas),
        )
        for buffer_name, values in schedule_buffers:
            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))

    def _get_alpha(self, x, t):
        """Return the ``alphas`` schedule entry for ``t`` as produced by ``extract``."""
        return extract(self.alphas, x, t)
    
    def _get_delta(self, x, t, asym=False):
        """Return the ``deltas`` schedule entry for ``t`` as produced by ``extract``.

        ``asym`` is accepted but currently unused.
        """
        return extract(self.deltas, x, t)
    
    def sample_lbd(self, x, t, library_size, pe_input=None, conditions=None, input_gene_list=None, mask_flag=True):
        prior_dict, q_mask = self.model(x, t, pe_input=pe_input, conditions=conditions, 
                                      input_gene_list=input_gene_list, mask_flag=mask_flag)
        lbd, prior_dist = self.sample_lbd_from_prior(prior_dict, library_size)
        return lbd, prior_dist, q_mask

    def sample(self, x, t, library_size, pe_input=None, conditions=None, input_gene_list=None, mask_flag=True):
        lbd, _, _ = self.sample_lbd(x, t, library_size, pe_input, conditions, input_gene_list, mask_flag)
        lbd = torch.clamp(lbd, max=1e8)
        samples = torch.poisson(lbd)
        return samples
    
    def sample_from_lbd(self, lbd):
        lbd = torch.clamp(lbd, max=1e8)
        samples = torch.poisson(lbd)
        return samples
    
    def q_sample(self, x_0, t, pe_input=None, conditions=None, input_gene_list=None, mask_flag=True):
        library_size = x_0.sum(1, keepdim=True)
        t_0 = torch.full((x_0.shape[0],), fill_value=0, device=x_0.device)
        lbd_0, prior_dist_0, q_mask = self.sample_lbd(x_0, t_0, library_size, pe_input, conditions, input_gene_list, mask_flag)
        rate = self._get_alpha(x_0, t) * lbd_0
        x_t = self.sample_from_lbd(rate)
        return lbd_0, prior_dist_0, x_t, q_mask
        
    def q_posterior(self, x_0, x_t, t, library_size=None, pe_input=None, conditions=None, input_gene_list=None,
                    mask_flag=True):
        library_size = x_0.sum(1, keepdim=True) if library_size is None else library_size
        lbd, _, _ = self.sample_lbd(x_0, t, library_size, pe_input, conditions, input_gene_list, mask_flag)
        rate = self._get_delta(x_0, t) * lbd
        rate = torch.clamp(rate, max=1e8)
        x_prev = torch.poisson(rate) + x_t
        return x_prev

    def scnormalize(self, x, target_sum=1e4, eps=1e-8):
        x = x * target_sum / (x.sum(1, keepdim=True) + eps)
        x = torch.log1p(x)
        return x

    def invervse_scnormalize(self, x, library_size=1e4, eps=1e-8):
        x = torch.exp(x) - 1
        x = x * library_size / (x.sum(1, keepdim=True) + eps)
        return x
    
    @torch.no_grad()
    def p_sample_step(self, x_t, t, library_size, pe_input=None, conditions=None, input_gene_list=None,
                      return_pred=False):
        _t = torch.full((x_t.shape[0],), fill_value=t, device=x_t.device)
        # if self.guidance_flag:
        #     model_out = self.guidance(self.model, sample_flag=True, x_orig=x_0, x_noised=x_t, 
        #                               t=t, pe_input=pe_input, conditions=conditions,
        #                               input_gene_list=input_gene_list)
        if self.cond_to_ignore is not None:
            assert len(self.cond_to_ignore) <= conditions.shape[1]
            assert all([0 <= x < conditions.shape[1] for x in self.cond_to_ignore])
            conditions[:, self.cond_to_ignore] = 0
        # pred_x_0 = self.sample(x_t, _t, library_size, pe_input, conditions, input_gene_list, mask_flag=False)
        pred_x_0, _, _ = self.sample_lbd(x_t, _t, library_size, pe_input, conditions, input_gene_list, mask_flag=False)
        x_prev = self.q_posterior(pred_x_0, x_t, _t, library_size, pe_input, conditions, input_gene_list, mask_flag=False)
        return (x_prev, pred_x_0) if return_pred else x_prev

    @torch.no_grad()
    def p_sample(self, x_start, t_start, pe_input=None, conditions=None, input_gene_list=None, 
                 return_pred=False, library_size=None):
        """Noise ``x_start`` up to ``t_start`` and run the reverse chain back to 0.

        When ``t_start`` equals ``self.timesteps`` the chain starts from all
        zeros; otherwise the starting point is drawn with ``q_sample``. The
        reverse chain then calls ``p_sample_step`` for ``t_start - 1, ..., 0``.

        Args:
            x_start: clean input counts.
            t_start: number of reverse steps to run; must not exceed ``timesteps``.
            pe_input, conditions, input_gene_list: forwarded to each reverse step.
            return_pred: return the last predicted ``x_0`` rate as second value.
            library_size: defaults to the per-row sum of ``x_start``.

        Returns:
            tuple: ``(x_0, pred_x_0)``; ``pred_x_0`` is None unless ``return_pred``.
        """
        assert t_start <= self.timesteps
        library_size = x_start.sum(1, keepdim=True) if library_size is None else library_size
        if t_start == self.timesteps:
            x_t = torch.zeros_like(x_start)
        else:
            # NOTE(review): the forward noising call does not receive pe_input /
            # conditions / input_gene_list and uses the default mask_flag=True —
            # confirm this is intended for sampling.
            _, _, x_t, _ = self.q_sample(x_0=x_start, t=t_start)

        pred_x_0 = torch.zeros_like(x_t)
        # The loop runs exactly t_start steps, so that is the progress-bar total.
        for t in tqdm(reversed(range(0, t_start)), desc='Sampling t', total=int(t_start)):
            x_t, pred_x_0 = self.p_sample_step(
                x_t, t, library_size, pe_input, conditions, input_gene_list, return_pred=True,
            )
        x_0 = x_t
        return x_0, (pred_x_0 if return_pred else None)

    def get_loss(self, x_0, t, pe_input=None, conditions=None, input_gene_list=None, mask_flag=True,
                 w_diff=None, w_nll=None):
        lbd_0, prior_dist_0, x_t, mask = self.q_sample(x_0, t, pe_input, conditions, input_gene_list, mask_flag=mask_flag)
        # if self.guidance_flag:
        #     model_out, mask = self.guidance(self.model, sample_flag=False, x_orig=x_0, x_noised=x_t, 
        #                                     t=t, pe_input=pe_input, conditions=conditions, 
        #                                     input_gene_list=input_gene_list)
        model_out, _ = self.model(x_t, t, pe_input, conditions, input_gene_list, mask=mask)
        loss, diff_loss, nll_loss = self.calculate_loss(t, x_0, lbd_0, prior_dist_0, model_out, mask, w_diff, w_nll)
        return loss, diff_loss, nll_loss
    
    def calculate_loss(self, t, x_0, lbd_0, prior_dist_0, out_dict, mask, w_diff=None, w_nll=None):
        library_size = x_0.sum(1, keepdim=True)
        lbd_t, prior_dist_t = self.sample_lbd_from_prior(out_dict, library_size)

        if self.diff_loss == "kl":  # no re-weighting
            delta = self._get_delta(lbd_0, t)
            diff_loss = flat_mean(poisson_kl(delta * lbd_0, delta * lbd_t))
        elif self.diff_loss == "kl_rev":  # no re-weighting
            delta = self._get_delta(lbd_0, t)
            diff_loss = flat_mean(poisson_kl(delta * lbd_t, delta * lbd_0))
        elif self.diff_loss == "kl_simple":  # weight: 1 / delta
            diff_loss = flat_mean(poisson_kl(lbd_0, lbd_t))
        elif self.diff_loss == "kl_alpha":  # weight: sqrt(alpha) / delta
            alpha = self._get_alpha(lbd_0, t)
            diff_loss = flat_mean(poisson_kl(lbd_0, lbd_t)) * alpha.sqrt()
        
        if self.nll_loss == 'x_t':
            nll_loss = -flat_mean(prior_dist_t.log_prob(x_0) * mask)
        elif self.nll_loss == 'x_0':
            nll_loss = -flat_mean(prior_dist_0.log_prob(x_0) * mask)
        elif self.nll_loss == 'both':
            nll_loss = -flat_mean((prior_dist_0.log_prob(x_0) + prior_dist_t.log_prob(x_0)) * mask)
        else:
            raise NotImplementedError(f'Unsupported nll loss type: {self.nll_loss}')
        
        w_diff = self.w_diff if w_diff is None else w_diff
        w_nll = self.w_nll if w_nll is None else w_nll
        loss = w_diff * diff_loss + w_nll * nll_loss
        return loss.mean(), diff_loss.mean(), nll_loss.mean()

    def forward(self, x, *args, **kwargs):
        # t = torch.randint(0, self.timesteps, (x.shape[0],), dtype=torch.int64, device=self.device)
        t = torch.randint(0, self.timesteps, (1,), dtype=torch.int64, device=self.device)
        return self.get_loss(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        if k in batch.keys():
            x = batch[k]
            if isinstance(x, torch.Tensor):
                x = x.to(memory_format=torch.contiguous_format).float()
        else:
            x = None
        return x

    def maybe_record_conditions(self, batch):
        """Gather conditions information over the full dataset in the first
        training epoch.

        """
        conditions = self.get_input(batch, self.cond_key)
        if (self.current_epoch == 0) and (conditions is not None):
            self.unique_conditions = (
                conditions.unique(dim=0)
                if self.unique_conditions is None
                else torch.cat((self.unique_conditions, conditions)).unique(dim=0)
            )

    def shared_step(self, batch):
        x = self.get_input(batch, self.input_key)
        pe_input = self.get_input(batch, self.pe_input_key)
        conditions = self.get_input(batch, self.cond_key)
        input_gene_list = self.get_input(batch, self.input_gene_list_key)
        loss, diff_loss, nll_loss = self(x, pe_input=pe_input, conditions=conditions, input_gene_list=input_gene_list)
        return loss, diff_loss, nll_loss

    def training_step(self, batch, batch_idx):
        self.maybe_record_conditions(batch)

        loss, diff_loss, nll_loss = self.shared_step(batch)
        loss_dict = {
            'train/loss': loss,
            'train/diff_loss': diff_loss,
            'train/nll_loss': nll_loss,
        }
        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log('global_step', self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss, diff_loss, nll_loss = self.shared_step(batch)
        loss_dict = {
            'val/loss': loss,
            'val/diff_loss': diff_loss,
            'val/nll_loss': nll_loss,
        }
        self.log_dict(loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Reconstruct a test batch and collect per-batch outputs for evaluation.

        The batch is reconstructed with ``p_sample`` starting from ``t_sample``.
        Depending on the flags, a latent embedding and classifier predictions are
        also computed; both receive a copy of the conditions with the
        ``cond_to_ignore`` columns zeroed. The resulting dict (all tensors on
        CPU) is appended to ``test_step_outputs`` and returned.
        """
        x = self.get_input(batch, self.input_key)
        pe_input = self.get_input(batch, self.pe_input_key)
        conditions = self.get_input(batch, self.cond_key)
        input_gene_list = self.get_input(batch, self.input_gene_list_key)

        def masked_conditions():
            # Fresh copy on every call, with the ignored condition columns zeroed.
            masked = conditions.detach().clone()
            if self.cond_to_ignore is not None:
                n_conds = masked.shape[1]
                assert len(self.cond_to_ignore) <= n_conds
                assert all([0 <= idx < n_conds for idx in self.cond_to_ignore])
                masked[:, self.cond_to_ignore] = 0
            return masked

        t_sample = torch.tensor([self.t_sample]).to(x).int()
        recon, _ = self.p_sample(x, t_sample, pe_input, conditions, input_gene_list)

        if self.latent_flag:
            latent = self.model.diffusion_model.get_latent(
                x, pe_input=pe_input, conditions=masked_conditions(), input_gene_list=input_gene_list,
            )

        if self.classify_flag:
            pred_conditions = self.classifier(x, masked_conditions())

        out = {
            'x': x.cpu(),
            'recon': recon.cpu(),
            'latent': latent.cpu() if self.latent_flag else None,
            'pred_conditions': pred_conditions.cpu() if self.classify_flag else None,
            'conditions': conditions.cpu(),
        }
        self.test_step_outputs.append(out)

        return out

    @torch.no_grad()
    def on_test_epoch_end(self):
        """Aggregate the collected test outputs, compute metrics and log them.

        Metrics: per-condition classification metrics and ``overall_acc`` (when
        ``classify_flag``), integration metrics (when ``integrate_flag``), plus
        ``rmse``, ``rmse_normed`` and ``poisson_nll`` of the reconstruction.
        The collected outputs are cleared at the end.
        """
        outputs = self.test_step_outputs

        def gather(key):
            # Concatenate one field across all collected test batches.
            return torch.cat([step_out[key].cpu() for step_out in outputs])

        x = gather('x')
        recon = gather('recon')
        conditions = gather('conditions').numpy()

        metrics_dict = {}
        if self.classify_flag:
            pred_conditions = gather('pred_conditions').numpy()
            pred_correct = pred_conditions == conditions

            # Classification metrics for individual conditions
            for i, cond_name in enumerate(self.cond_names):
                if self.cond_to_ignore is not None and i in self.cond_to_ignore:
                    continue
                pred, true = pred_conditions[:, i], conditions[:, i]
                metrics_dict.update(evaluate_annotation(true, pred, cond_name))

            # Accuracy for combined conditions (global avg, not macro)
            metrics_dict['overall_acc'] = pred_correct.all(1).mean()

        poisson_nll_fn = nn.PoissonNLLLoss(log_input=False, full=True)
        poisson_nll = poisson_nll_fn(recon, x).item()
        rmse = np.sqrt(F.mse_loss(recon, x).item())
        x_normed = self.scnormalize(x, target_sum=self.test_target_sum)
        recon_normed = self.scnormalize(recon, target_sum=self.test_target_sum)
        rmse_normed = np.sqrt(F.mse_loss(recon_normed, x_normed).item())

        if self.integrate_flag:
            obs_columns = ['batch', 'cell_type']
            # Integration is evaluated on the normalized reconstruction.
            adata = ad.AnnData(csr_matrix(recon_normed.numpy()))
            adata.obs[obs_columns] = conditions[:, :2]
            adata.obs[obs_columns] = adata.obs[obs_columns].astype('category')
            if self.latent_flag:
                adata.obsm['X_emb'] = gather('latent').numpy()
                metrics_dict.update(integration_eval(adata))
            else:
                metrics_dict.update(integration_eval(adata, use_rep=None))
            del adata

        metrics_dict['rmse'] = rmse
        metrics_dict['rmse_normed'] = rmse_normed
        metrics_dict['poisson_nll'] = poisson_nll

        epoch_only = dict(prog_bar=True, logger=True, on_step=False, on_epoch=True)
        self.log_dict(metrics_dict, **epoch_only)
        self.log('t_sample', self.t_sample, **epoch_only)

        self.test_step_outputs.clear()

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

class PoissonCellJump(BaseCellJump):
    """``BaseCellJump`` variant whose model output parameterizes a Poisson distribution.

    The model output dict is expected to contain the key ``'lbd'``.
    """

    def __init__(
            self,
            model_config,
            scheduler_config=None,
            diff_loss="kl",
            nll_loss="both",
            timesteps=1000,
            beta_start=0.001,
            beta_end='auto',
            beta_schedule='linear',
            logsnr_end=-12,
            signal_stat=1.,
            input_key="input",
            pe_input_key="coord",
            input_gene_list_key="input_gene_list",
            cond_key="cond",
            top_de_key='top_de_dict',
            t_sample=500,
            monitor="val/loss",
            latent_flag=True,
            recon_flag=True,
            integrate_flag=True,
            impute_flag=False,
            pert_flag=False,
            classify_flag=True,
            classifier_config=None,
            guidance_flag=False,
            guidance_config=None,
            cond_to_ignore=None,
            cond_names=None,
            test_target_sum=1e4,
            w_diff=1.,
            w_nll=1.,
            rescale_flag=False,
            **kwargs
    ):
        """Forward the shared options to ``BaseCellJump`` and store the loss weights.

        Args:
            w_diff: default weight of the diffusion loss term.
            w_nll: default weight of the NLL loss term.
            rescale_flag: if True the model output is multiplied by the library
                size; otherwise it is exponentiated (see ``sample_lbd_from_prior``).
            **kwargs: extra options (e.g. ``library_size_key``) passed on to the
                base constructor instead of being dropped.

        All other arguments have the same meaning as in ``BaseCellJump``.
        """
        super().__init__(
            model_config=model_config,
            scheduler_config=scheduler_config,
            diff_loss=diff_loss,
            nll_loss=nll_loss,
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            logsnr_end=logsnr_end,
            signal_stat=signal_stat,
            input_key=input_key,
            pe_input_key=pe_input_key,
            input_gene_list_key=input_gene_list_key,
            cond_names=cond_names,
            cond_key=cond_key,
            top_de_key=top_de_key,
            t_sample=t_sample,
            monitor=monitor,
            recon_flag=recon_flag,
            latent_flag=latent_flag,
            integrate_flag=integrate_flag,
            impute_flag=impute_flag,
            pert_flag=pert_flag,
            classify_flag=classify_flag,
            classifier_config=classifier_config,
            guidance_flag=guidance_flag,
            guidance_config=guidance_config,
            cond_to_ignore=cond_to_ignore,
            test_target_sum=test_target_sum,
            **kwargs,
        )

        self.dist_func = Poisson
        self.w_diff = w_diff
        self.w_nll = w_nll
        self.rescale_flag = rescale_flag

    def sample_lbd_from_prior(self, prior_dict, library_size):
        """Build the Poisson distribution from the model output.

        Returns:
            tuple: ``(rate, distribution)`` where ``rate`` is the distribution's
            ``rate`` attribute.
        """
        if self.rescale_flag:
            mean = prior_dict['lbd'] * library_size
        else:
            mean = torch.exp(prior_dict['lbd'])
        prior_dist = self.dist_func(rate=mean)
        return prior_dist.rate, prior_dist


class NegativeBinomialCellJump(BaseCellJump):
    """``BaseCellJump`` variant whose model output parameterizes a negative binomial.

    The model output dict is expected to contain the keys ``'mu'`` and ``'phi'``.
    """

    def __init__(
            self,
            model_config,
            scheduler_config=None,
            diff_loss="kl",
            nll_loss="both",
            timesteps=1000,
            beta_start=0.001,
            beta_end='auto',
            beta_schedule='linear',
            logsnr_end=-12,
            signal_stat=1.,
            input_key="input",
            pe_input_key="coord",
            input_gene_list_key="input_gene_list",
            cond_key="cond",
            top_de_key='top_de_dict',
            t_sample=500,
            monitor="val/loss",
            latent_flag=True,
            recon_flag=True,
            integrate_flag=True,
            impute_flag=False,
            pert_flag=False,
            classify_flag=True,
            classifier_config=None,
            guidance_flag=False,
            guidance_config=None,
            cond_to_ignore=None,
            cond_names=None,
            w_diff=1.,
            w_nll=1.,
            test_target_sum=1e4,
            rescale_flag=False,
            **kwargs
    ):
        """Forward the shared options to ``BaseCellJump`` and store the loss weights.

        Args:
            w_diff: default weight of the diffusion loss term.
            w_nll: default weight of the NLL loss term.
            rescale_flag: if True the ``'mu'`` output is multiplied by the library
                size; otherwise it is exponentiated (see ``sample_lbd_from_prior``).
            **kwargs: extra options (e.g. ``library_size_key``) passed on to the
                base constructor instead of being dropped.

        All other arguments have the same meaning as in ``BaseCellJump``.
        """
        super().__init__(
            model_config=model_config,
            scheduler_config=scheduler_config,
            diff_loss=diff_loss,
            nll_loss=nll_loss,
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            logsnr_end=logsnr_end,
            signal_stat=signal_stat,
            input_key=input_key,
            pe_input_key=pe_input_key,
            input_gene_list_key=input_gene_list_key,
            cond_names=cond_names,
            cond_key=cond_key,
            top_de_key=top_de_key,
            t_sample=t_sample,
            monitor=monitor,
            recon_flag=recon_flag,
            latent_flag=latent_flag,
            integrate_flag=integrate_flag,
            impute_flag=impute_flag,
            pert_flag=pert_flag,
            classify_flag=classify_flag,
            classifier_config=classifier_config,
            guidance_flag=guidance_flag,
            guidance_config=guidance_config,
            cond_to_ignore=cond_to_ignore,
            test_target_sum=test_target_sum,
            **kwargs,
        )

        self.dist_func = NegativeBinomial
        self.w_diff = w_diff
        self.w_nll = w_nll
        self.rescale_flag = rescale_flag

    def sample_lbd_from_prior(self, prior_dict, library_size):
        """Build the negative binomial distribution from the model output.

        ``phi`` is obtained by exponentiating the ``'phi'`` output.

        Returns:
            tuple: ``(mu, distribution)`` where ``mu`` is the distribution's
            ``mu`` attribute.
        """
        if self.rescale_flag:
            mean = prior_dict['mu'] * library_size
        else:
            mean = torch.exp(prior_dict['mu'])
        prior_dist = self.dist_func(
            mu=mean,
            phi=torch.exp(prior_dict['phi']),
        )
        return prior_dist.mu, prior_dist


class ZeroInflatedNegativeBinomialCellJump(BaseCellJump):
    """``BaseCellJump`` variant using a zero-inflated negative binomial likelihood.

    The model output dict is expected to contain the keys ``'mu'``, ``'phi'``
    and ``'pi_logits'``.
    """

    def __init__(
            self,
            model_config,
            scheduler_config=None,
            diff_loss="kl",
            nll_loss="both",
            timesteps=1000,
            beta_start=0.001,
            beta_end='auto',
            beta_schedule='linear',
            logsnr_end=-12,
            signal_stat=1.,
            input_key="input",
            pe_input_key="coord",
            input_gene_list_key="input_gene_list",
            cond_key="cond",
            top_de_key='top_de_dict',
            t_sample=500,
            monitor="val/loss",
            latent_flag=True,
            recon_flag=True,
            integrate_flag=True,
            impute_flag=False,
            pert_flag=False,
            classify_flag=True,
            classifier_config=None,
            guidance_flag=False,
            guidance_config=None,
            cond_to_ignore=None,
            cond_names=None,
            test_target_sum=1e4,
            w_diff=1.,
            w_nll=1.,
            rescale_flag=False,
            **kwargs
    ):
        """Forward the shared options to ``BaseCellJump`` and store the loss weights.

        Args:
            w_diff: default weight of the diffusion loss term.
            w_nll: default weight of the NLL loss term.
            rescale_flag: if True the ``'mu'`` output is multiplied by the library
                size; otherwise it is exponentiated (see ``sample_lbd_from_prior``).
            **kwargs: extra options (e.g. ``library_size_key``) passed on to the
                base constructor instead of being dropped.

        All other arguments have the same meaning as in ``BaseCellJump``.
        """
        super().__init__(
            model_config=model_config,
            scheduler_config=scheduler_config,
            diff_loss=diff_loss,
            nll_loss=nll_loss,
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            logsnr_end=logsnr_end,
            signal_stat=signal_stat,
            input_key=input_key,
            pe_input_key=pe_input_key,
            input_gene_list_key=input_gene_list_key,
            cond_names=cond_names,
            cond_key=cond_key,
            top_de_key=top_de_key,
            t_sample=t_sample,
            monitor=monitor,
            recon_flag=recon_flag,
            latent_flag=latent_flag,
            integrate_flag=integrate_flag,
            impute_flag=impute_flag,
            pert_flag=pert_flag,
            classify_flag=classify_flag,
            classifier_config=classifier_config,
            guidance_flag=guidance_flag,
            guidance_config=guidance_config,
            cond_to_ignore=cond_to_ignore,
            test_target_sum=test_target_sum,
            **kwargs,
        )

        self.dist_func = ZeroInflatedNegativeBinomial
        self.w_diff = w_diff
        self.w_nll = w_nll
        self.rescale_flag = rescale_flag

    def sample_lbd_from_prior(self, prior_dict, library_size):
        """Build the ZINB distribution and draw a zero-inflated rate from it.

        The returned rate is the distribution's ``mu`` with entries set to zero
        wherever a uniform random draw is at most ``pi_probs``; the result is
        therefore stochastic.

        Returns:
            tuple: ``(lbd, distribution)``.
        """
        if self.rescale_flag:
            mean = prior_dict['mu'] * library_size
        else:
            mean = torch.exp(prior_dict['mu'])
        prior_dist = self.dist_func(
            mu=mean,
            phi=torch.exp(prior_dict['phi']),
            pi_logits=prior_dict['pi_logits'],
        )
        lbd = prior_dist.mu
        # Sample the dropout indicator per entry and zero the rate where it fires.
        is_zero = torch.rand_like(lbd) <= prior_dist.pi_probs
        lbd = torch.where(is_zero, torch.zeros_like(lbd), lbd)
        return lbd, prior_dist
