"""Zero-shot classification from diffusion model.

Modified from
https://github.com/diffusion-classifier/diffusion-classifier/blob/a5c4eb8f4d5d68cf85067eb0847255da3b5dcf6e/eval_prob_adaptive.py

@misc{li2023diffusion,
      title={Your Diffusion Model is Secretly a Zero-Shot Classifier},
      author={Alexander C. Li and Mihir Prabhudesai and Shivam Duggal and Ellis Brown and Deepak Pathak},
      year={2023},
      eprint={2303.16203},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

"""
import inspect
import itertools
import os
import warnings
from copy import deepcopy
from functools import partial
from glob import glob
from typing import Any, Callable, Literal, List, Dict, Optional, Union, get_args

import numpy as np
import pytorch_lightning as pl
import torch
import tqdm
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange
from natsort import natsorted

from celldiff.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from celldiff.util import (
    check_str_option,
    default,
    instantiate_from_config,
    ismap,
    list_exclude,
    log_txt_as_img,
)

# Classifier backbone built by NoisyLatentImageClassifier, keyed by the kind
# of label being predicted.
__models__ = dict(
    class_label=EncoderUNetModel,
    segmentation=UNetModel,
)

# Accepted string options; validated with ``check_str_option`` /
# ``typing.get_args`` where they are consumed.
CLF_LOSS_TYPE = Literal["l1", "l2", "huber", "poisson_kl"]
CLF_QUERY_MODE = Literal["all", "seen", "batch_all", "batch_seen", "specified"]
TS_SAMPLER_TYPE = Literal["IterativeUniform"]


def disabled_train(self, mode=True):
    """No-op stand-in for ``torch.nn.Module.train``.

    Installing this in place of a module's ``train`` method freezes the
    module's current train/eval mode: any call is ignored and the first
    positional argument is handed straight back.

    Note that when it is assigned as a plain instance attribute (not bound as
    a method), a call like ``module.train(False)`` passes ``False`` as
    ``self``, so the flag is what gets returned. The module's mode is left
    unchanged either way.

    Args:
        self: Value returned unchanged.
        mode: Ignored; present only for signature compatibility.

    Returns:
        The ``self`` argument, untouched.
    """
    unchanged = self
    return unchanged


class NoisyLatentImageClassifier(pl.LightningModule):
    """Classifier trained on noised inputs of a frozen, pretrained diffusion model.

    The diffusion model is instantiated from the naturally-sorted last
    ``*-project.yaml`` config under ``<diffusion_path>/configs`` and frozen.
    Each step takes the input through the diffusion model's ``get_input``,
    noises it to some time step with the diffusion model's ``q_sample`` and
    classifies it with a separate network built from the diffusion model's
    UNet config: :class:`EncoderUNetModel` for ``'class_label'`` targets or
    :class:`UNetModel` for ``'segmentation'`` targets.

    Note:
        :meth:`configure_optimizers` reads ``self.learning_rate``, which is
        not set in this class; presumably the training script assigns it
        before fitting (TODO confirm).
    """

    def __init__(self,
                 diffusion_path,
                 num_classes,
                 ckpt_path=None,
                 pool='attention',
                 label_key=None,
                 diffusion_ckpt_path=None,
                 scheduler_config=None,
                 weight_decay=1.e-2,
                 log_steps=10,
                 monitor='val/loss',
                 *args,
                 **kwargs):
        """
        Args:
            diffusion_path: Directory of a trained diffusion run; the
                naturally-sorted last file matching ``configs/*-project.yaml``
                is used as its config.
            num_classes: Number of classifier outputs.
            ckpt_path: Optional checkpoint to restore the classifier from.
            pool: Pooling type passed to the backbone (``'class_label'`` only).
            label_key: Batch key of the targets; overridden by the diffusion
                model's ``cond_stage_key`` when that attribute exists.
            diffusion_ckpt_path: Written to the diffusion config's
                ``params.ckpt_path`` before instantiation.
            scheduler_config: Optional config of an object exposing
                ``schedule``, wrapped in :class:`LambdaLR`.
            weight_decay: AdamW weight decay.
            log_steps: Number of evenly spaced time steps shown in
                :meth:`log_images`.
            monitor: Metric name stored as ``self.monitor``.

        Raises:
            NotImplementedError: If the label key has no backbone in
                ``__models__``.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        # Number of 2x downsamplings applied to segmentation targets in
        # get_conditioning.
        self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
        self.log_steps = log_steps

        self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
            else self.diffusion_model.cond_stage_key

        assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError()

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=(), only_model=False):
        """Load weights from ``path`` non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped first.
        With ``only_model`` the state dict is loaded into ``self.model``,
        otherwise into the whole module. Missing/unexpected keys are printed.
        """
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
                    # Key is gone; checking further prefixes would re-delete it.
                    break
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def load_diffusion(self):
        """Instantiate the diffusion model from its config, in eval mode with all parameters frozen."""
        model = instantiate_from_config(self.diffusion_config)
        self.diffusion_model = model.eval()
        # Stored as an (unbound) instance attribute so later .train() calls
        # propagated from the parent module leave the diffusion model in eval mode.
        self.diffusion_model.train = disabled_train
        for param in self.diffusion_model.parameters():
            param.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Build ``self.model`` from a copy of the diffusion UNet config.

        Input channels are set to the diffusion UNet's output channels, output
        channels to ``num_classes``. Weights are then optionally restored from
        ``ckpt_path``.
        """
        model_config = deepcopy(self.diffusion_config.params.unet_config.params)
        model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
        model_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            model_config.pool = pool

        self.model = __models__[self.label_key](**model_config)
        if ckpt_path is not None:
            print('#####################################################################')
            print(f'load from ckpt "{ckpt_path}"')
            print('#####################################################################')
            self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Forward-diffuse ``x`` to time steps ``t`` with the diffusion model's ``q_sample``.

        ``noise`` defaults to standard normal noise shaped like ``x``.
        """
        noise = default(noise, lambda: torch.randn_like(x))
        continuous_sqrt_alpha_cumprod = None
        if self.diffusion_model.use_continuous_noise:
            continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
            # todo: make sure t+1 is correct here

        return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
                                             continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)

    def forward(self, x_noisy, t, *args, **kwargs):
        """Classify ``x_noisy`` at time steps ``t``; extra arguments are ignored."""
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch ``batch[k]`` as a contiguous float tensor in channels-first layout.

        A 3-D input gets a trailing singleton channel axis before the
        ``b h w c -> b c h w`` rearrangement.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        """Fetch the targets ``batch[k]`` (default: ``self.label_key``) on this module's device.

        Segmentation targets are moved to channels-first and halved in
        resolution ``self.numd`` times with nearest interpolation (presumably
        to match the latent resolution — TODO confirm).
        """
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction="mean"):
        """Top-k accuracy of ``logits`` against integer ``labels``.

        Class scores lie along dim 1 of ``logits`` (``(b, C)`` or
        ``(b, C, h, w)``), and ``labels`` is shaped like ``logits`` without
        that dim. ``k`` is clamped to the number of classes.

        Returns:
            The mean hit rate as a Python float for ``reduction="mean"``, a
            0/1 tensor shaped like ``labels`` for ``reduction="none"``, and
            ``None`` otherwise.
        """
        k = min(k, logits.shape[1])
        _, top_ks = torch.topk(logits, k, dim=1)
        # Reduce over the k candidates (dim 1). The last dim would be the image
        # width for segmentation logits.
        hits = (top_ks == labels[:, None]).float().sum(dim=1)
        if reduction == "mean":
            return hits.mean().item()
        elif reduction == "none":
            return hits

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        """Log loss, top-1/top-5 accuracy, global step and current learning rate."""
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f"{log_prefix}/loss"] = loss.mean()
        log[f"{log_prefix}/acc@1"] = self.compute_top_k(
            logits, targets, k=1, reduction="mean"
        )
        log[f"{log_prefix}/acc@5"] = self.compute_top_k(
            logits, targets, k=5, reduction="mean"
        )

        self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
        self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
        self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)

    def shared_step(self, batch, t=None):
        """Noise the batch, classify it and compute cross-entropy.

        Args:
            batch: Input batch.
            t: Fixed time step for the whole batch; if ``None``, a uniform
                random time step is drawn per sample.

        Returns:
            ``(mean loss, logits, noised input, integer targets)``.
        """
        x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
        targets = self.get_conditioning(batch)
        # One-hot/channel targets are converted to class indices.
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
        else:
            t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        """Reset per-time-step accuracy buffers (one entry every ``log_every_t`` steps)."""
        self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
                          range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and record accuracies at each tracked time step."""
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
            self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))

        return loss

    def configure_optimizers(self):
        """AdamW over the classifier parameters, optionally with a per-step LambdaLR schedule."""
        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Collect inputs, labels and (for map-like labels) noisy inputs/predictions at several time steps.

        Every logged entry is truncated to its first ``N`` items.
        """
        log = dict()
        x = self.get_input(batch, self.diffusion_model.first_stage_key)
        log['inputs'] = x

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
            log['labels'] = y

        if ismap(y):
            log['labels'] = self.diffusion_model.to_rgb(y)

            for step in range(self.log_steps):
                current_time = step * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)

                log[f'inputs@t{current_time}'] = x_noisy

                pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
                pred = rearrange(pred, 'b h w c -> b c h w')

                log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)

        for key in log:
            log[key] = log[key][:N]

        return log


class IterativeUniformTimeStepSampler:
    """Sample time steps in uniform grids with increasing granularity.

    At construction a pool of ``max_num_steps`` time steps is laid out on a
    uniform grid over ``[0, max_time_steps)``. When the grid comes out
    slightly larger, it is randomly thinned to exactly ``max_num_steps``
    entries. Each call then picks ``num_steps`` evenly spaced entries of that
    pool. With ``register_steps=True``, entries returned by earlier calls are
    skipped, so successive calls with growing ``num_steps`` refine the grid.

    Args:
        max_time_steps: Total number of diffusion time steps; all returned
            time steps lie in ``[0, max_time_steps)``.
        max_num_steps: Size of the time step pool, and upper bound for
            ``num_steps`` in :meth:`__call__`.
        num_repeats: How many times each selected time step is repeated in
            the output.
        random_state: Seed of the generator used for random thinning.
    """

    def __init__(
        self,
        max_time_steps: int,
        max_num_steps: int,
        num_repeats: int = 1,
        random_state: Optional[int] = 42,
    ):
        self.max_time_steps = max_time_steps
        self.max_num_steps = max_num_steps
        self.num_repeats = num_repeats
        self.rng = np.random.default_rng(random_state)

        time_steps = self._get_time_steps(max_time_steps, max_num_steps)
        if len(time_steps) > max_num_steps:
            time_steps = self._sample_sorted(time_steps, max_num_steps, self.rng)
        self.time_steps = time_steps

        # Pool indices already returned by calls with register_steps=True.
        self.sampled_set = set()

    @staticmethod
    def _get_time_steps(max_time_steps: int, num_steps: int) -> List[int]:
        """Return a uniform grid of at least ``num_steps`` points in ``[0, max_time_steps)``.

        Points are spaced ``max_time_steps // num_steps`` apart, starting at
        half that spacing.

        Raises:
            ValueError: If ``num_steps`` is not in ``[1, max_time_steps]``,
                since the grid spacing would then be zero or undefined.
        """
        if not 1 <= num_steps <= max_time_steps:
            raise ValueError(
                f"num_steps must be between 1 and {max_time_steps}, got {num_steps}"
            )
        interval = max_time_steps // num_steps
        start = interval // 2
        return list(range(start, max_time_steps, interval))

    @staticmethod
    def _sample_sorted(
        time_steps: List[int],
        max_num_steps: int,
        rng: Optional[np.random.Generator] = None,
    ) -> List[int]:
        """Draw ``max_num_steps`` distinct entries of ``time_steps`` and return them sorted."""
        if rng is None:
            rng = np.random.default_rng()
        picked = rng.choice(time_steps, size=max_num_steps, replace=False)
        # Convert NumPy scalars back to plain Python ints.
        return sorted(int(i) for i in picked)

    def __call__(
        self,
        num_steps: int,
        register_steps: bool = True,
        shuffle: bool = True,
    ) -> List[int]:
        """Sample time steps.

        Args:
            num_steps: Number of uniform steps to take (between 1 and
                ``max_num_steps``).
            register_steps: If set to True, then only return steps that are not
                drawn from previous draws and register the returned steps as
                drawn.
            shuffle: If set to True and the number of uniform steps exceed the
                specified num_steps, then uniformly randomly sample steps.

        Returns:
            Sorted time steps, each repeated ``num_repeats`` times. The list
            may be empty when every selected step was drawn before.

        Raises:
            ValueError: If ``num_steps`` is outside ``[1, max_num_steps]``.
        """
        idxs = self._get_time_steps(self.max_num_steps, num_steps)

        if shuffle and len(idxs) > num_steps:
            idxs = self._sample_sorted(idxs, num_steps, self.rng)

        if register_steps:
            idxs = sorted(set(idxs) - self.sampled_set)
            self.sampled_set.update(idxs)

        time_steps = [self.time_steps[i] for i in idxs]
        return np.repeat(time_steps, self.num_repeats).tolist()


class DiffusionClassifier:
    """Diffusion classifier object.

    From a high level, the classifier works as follows

        1. Use all unique conditions as candidates.
        2. Evaluate reconstruction errors of cells with different candidate
           conditions.
        3. Select the top k conditions for each cell that resulted in the
           lowest errors, averaged over all time steps evaluated so far, and
           use these as the new candidate conditions.
           Note that k gradually decreases from round to round, and eventually
           drops to 1 in the final round to select the best matching condition
           for each cell.
        4. Repeat 2 and 3 until the last round is reached.

    Args:
        n_samples_list: Number of time steps to sample in each round.
        to_keep_list: Number of top conditions to keep for the next round of
            evaluation. Must have the same length as ``n_samples_list`` and
            end with 1.
        n_trials: Number of trials (repeats) per sampled time point.
        loss: Type of loss to evaluate the error of the predictions given a
            particular condition.
        query_mode: What conditions to query for. "all" uses all possible
            combinations of the conditions, "seen" uses only the combinations
            seen during training, and "specified" uses the combinations passed
            to the :attr:`specified_conds` argument of :meth:`__call__`.
            "batch_all" and "batch_seen" are analogous to "all" and "seen" but
            only select from the batch.
        inference_mask: Mask setting forwarded to the diffusion model's
            forward call during error estimation.
        time_step_sampler: Name of the sampler that draws the time steps of
            each round. Currently only "IterativeUniform" is available.
        model: Model object to bind to. See
            :class:`celldiff.models.diffusion.ddpm.DDPM` for an example.
        conds_to_fix: Condition name(s) whose values are taken from the input
            instead of being queried.
        conds_to_null: Condition name(s) set to the null (zero) condition
            instead of being queried.

    Notes:
        There are several requirements for the binding model to work with the
        :class:`DiffusionClassifier` object.

            1. There should be a :attr:`diffusion_model` (or
               :attr:`model.diffusion_model`) attribute, which holds the
               denoising diffusion model.
            2. There should be a :attr:`num_timesteps` (or :attr:`timesteps`)
               attribute about the maximum time step size of the diffusion
               model.
            3. There should be a :meth:`q_sample` method that performs forward
               sampling of the original data, i.e., adding noise.
            4. There should be a :attr:`cond_names` attribute listing the
               condition names, and optionally :attr:`unique_conditions`
               holding the condition combinations seen during training.

    """

    def __init__(
        self,
        n_samples_list: List[int],
        to_keep_list: List[int],
        n_trials: int = 1,
        loss: CLF_LOSS_TYPE = "l2",
        query_mode: CLF_QUERY_MODE = "all",
        inference_mask: Union[bool, str] = "all",
        time_step_sampler: TS_SAMPLER_TYPE = "IterativeUniform",
        model: Optional[Any] = None,
        conds_to_fix: Optional[Union[str, List[str]]] = None,
        conds_to_null: Optional[Union[str, List[str]]] = None,
    ):
        assert len(n_samples_list) == len(to_keep_list)
        assert to_keep_list[-1] == 1, "Last trial must only select one best matching condition."
        self.n_samples_list = n_samples_list
        self.to_keep_list = to_keep_list
        self.n_trials = n_trials
        self.loss = check_str_option("loss", loss, CLF_LOSS_TYPE)
        self.query_mode = check_str_option("query_mode", query_mode, CLF_QUERY_MODE)
        # Assigning through the property setter validates the option and
        # resolves it to the sampler class instantiated in __call__.
        self.get_time_step_sampler = time_step_sampler
        self.conds_to_fix = set(default(self._as_list(conds_to_fix), []))
        self.conds_to_null = set(default(self._as_list(conds_to_null), []))
        self.inference_mask = inference_mask
        self._model = model

    @staticmethod
    def _as_list(x):
        """Wrap a single string into a one-element list; pass anything else through."""
        return [x] if isinstance(x, str) else x

    @property
    def get_time_step_sampler(self):
        """Time step sampler class resolved from the ``time_step_sampler`` option."""
        return self._time_step_sampler_cls

    @get_time_step_sampler.setter
    def get_time_step_sampler(self, val):
        if val == "IterativeUniform":
            self._time_step_sampler_cls = IterativeUniformTimeStepSampler
        elif val not in (opts := get_args(TS_SAMPLER_TYPE)):
            raise ValueError(f"Unknown time step sampler {val!r}, available options are {opts}")
        else:
            raise NotImplementedError(f"[DEVERROR] please implement {val}")

    @property
    def model(self):
        return self._model

    def __call__(
        self,
        x: torch.Tensor,
        x_conds: Dict[str, torch.Tensor],
        model: Optional[Any] = None,
        specified_conds: Optional[torch.Tensor] = None,
    ):
        """Predict conditions of each cell.

        Args:
            x: Original expression values to be used for generating noised
                expression input to the diffusion model.
            x_conds: Conditions of all cells (i.e., rows of x). These
                are the labels to be predicted against.
            model: Model object to bind to. See
                :class:`celldiff.models.diffusion.ddpm.DDPM` for an example.
                Only needed if model was not bound during initialization.
            specified_conds: Specify the conditions to query for in "specified"
                query mode.

        Returns:
            Tuple ``(pred_conds, target_conds)`` of dicts mapping each queried
            condition name to the predicted and the observed condition values
            of every cell (on CPU).

        """
        # Hook up with model
        if (model is None) and (self.model is None):
            raise ValueError("Model object not stored during init, please pass during call.")
        if (model is not None) and (self.model is not None):
            warnings.warn(
                f"Model object already specified during init: {self.model} but "
                f"also passed during call: {model}. Using the passed model "
                "in this call. Please remove duplicated model specification "
                "to suppress this warning.",
                UserWarning,
                stacklevel=2,
            )
        model = default(model, self.model)
        diffusion_model, timesteps = self.get_assets_from_model(model)
        conditions = self._get_conditions(model, x_conds, specified_conds)
        assert conditions, "Failed to extract query conditions"

        # Set constants
        max_n_samples = max(self.n_samples_list)
        num_cells = len(x)
        num_conditions = len(next(iter(conditions.values())))
        num_t_splits = len(self.n_samples_list)
        n_samples_list = self.n_samples_list
        to_keep_list = self.to_keep_list

        if num_conditions < (max_to_keep := max(to_keep_list)):
            warnings.warn(
                f"Maximum conditions to keep ({max_to_keep}) exceeds the total "
                f"number of conditions available ({num_conditions}). Implicitly "
                f"setting max number of conditions to keep to {num_conditions}.",
                UserWarning,
                stacklevel=2,
            )
            to_keep_list = [min(num_conditions, i) for i in to_keep_list]

        time_step_sampler = self._time_step_sampler_cls(
            timesteps,
            max_n_samples,
            self.n_trials,
        )

        eval_error_func = partial(
            self.eval_error,
            diffusion_model=diffusion_model,
            q_sample=model.q_sample,
            x_orig=x,
            x_conds=x_conds,
            query_conds=conditions,
            loss=self.loss,
            inference_mask=self.inference_mask,
        )

        # full_error_tensor[i, cell, cond] holds the error averaged over all
        # time steps drawn in rounds 0..i. Pairs never evaluated stay NaN.
        full_error_tensor = torch.full((num_t_splits, num_cells, num_conditions), torch.nan)
        # Candidate (original) condition indices per cell; initially all.
        best_idxs = torch.arange(num_conditions).repeat(num_cells, 1)
        prev_size = 0  # number of time steps aggregated so far
        for i, (n_samples, n_to_keep) in enumerate(zip(n_samples_list, to_keep_list)):
            curr_t_to_eval = time_step_sampler(n_samples)
            curr_size = len(curr_t_to_eval)

            if curr_size == 0 and i > 0:
                # Every time step of this round was already evaluated: reuse the
                # previous aggregate instead of averaging over zero steps (NaN).
                full_error_tensor[i] = full_error_tensor[i - 1]
            else:
                # Perform one round of evaluation
                error_tensor = torch.zeros_like(best_idxs, dtype=torch.float)
                for j in tqdm.trange(
                    best_idxs.shape[1],
                    leave=False,
                    desc=f"Round {i + 1} / {len(n_samples_list)}",
                ):
                    error_tensor[:, j] = eval_error_func(ts=curr_t_to_eval, query_cond_idx=best_idxs[:, j])

                # Place this round's errors at the original condition indices
                full_error_tensor[i].scatter_(dim=1, index=best_idxs, src=error_tensor)
                if i > 0:
                    # Aggregate errors with previous runs, weighted by the
                    # number of time steps each average was taken over
                    full_error_tensor[i] = (
                        (full_error_tensor[i] * curr_size + full_error_tensor[i - 1] * prev_size)
                        / (curr_size + prev_size)
                    )
                prev_size += curr_size

            # Find best conditions for each cell based on the aggregated errors
            # of the current candidates
            candidate_errors = full_error_tensor[i].gather(dim=1, index=best_idxs)
            new_best_vals, new_best_idxs = candidate_errors.topk(n_to_keep, dim=1, largest=False)
            assert not torch.isnan(new_best_vals).any(), "Found nans in selected entries."
            # Convert idx back to original condition idx
            best_idxs = best_idxs.gather(dim=1, index=new_best_idxs)

        pred_conds, target_conds = {}, {}
        for i, j in conditions.items():
            if len(x_conds[i].unique()) == 1:
                warnings.warn(
                    f"Current batch only has one type of {i}, try increasing batch size?",
                    RuntimeWarning,
                    stacklevel=2,
                )
            pred_conds[i] = j[best_idxs.flatten()].cpu()
            target_conds[i] = x_conds[i].cpu()

        return pred_conds, target_conds

    def _get_conditions(
        self,
        model: torch.nn.Module,
        x_conds: Dict[str, torch.Tensor],
        specified_conds: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Build the candidate conditions to query.

        Returns:
            Dict mapping each queried condition name (``model.cond_names``
            minus fixed/nulled ones) to a 1-D tensor. Entry ``k`` of every
            tensor together forms the ``k``-th candidate combination.
        """
        # First prepare the unique conditions, either from the observations
        # from the batch, or from the full training dataset (recorded by the
        # unique_conditions attribute in the model object)
        self.all_cond_names = model.cond_names
        self.cond_names = list_exclude(self.all_cond_names, self.conds_to_fix | self.conds_to_null)

        batch_mode = self.query_mode.startswith("batch_")
        recorded_conditions = getattr(model, "unique_conditions", None)
        old_version_ckpt_flag = recorded_conditions is None
        if batch_mode or old_version_ckpt_flag:
            if old_version_ckpt_flag and not batch_mode:
                warnings.warn(
                    "The model ckpt being used is from an older version that "
                    "do not contain the 'unique_conditions'. Implicitly "
                    f"switching the query_mode from {self.query_mode} to "
                    f"batch_{self.query_mode}",
                    UserWarning,
                    stacklevel=2,
                )
            conditions_tensor = torch.cat([x_conds[k].unsqueeze(1) for k in self.all_cond_names], dim=1)
            unique_conditions = conditions_tensor.unique(dim=0)
        else:
            unique_conditions = recorded_conditions

        # Extract conditions to query
        valid_cond_idx = [i for i, j in enumerate(self.all_cond_names) if j in self.cond_names]
        unique_conditions = unique_conditions[:, valid_cond_idx].unique(dim=0)

        # Remove NULL conditions from query candidates
        unique_conditions = unique_conditions[~(unique_conditions == 0).any(1)]

        # The only difference is how we prepare the unique_conditions, which
        # we have handled above
        query_mode = self.query_mode.replace("batch_", "")

        if query_mode == "all":
            # Cartesian product of the individually observed values
            individual_unique_conditions = [i.unique().tolist() for i in unique_conditions.T]
            out = torch.tensor(
                list(itertools.product(*individual_unique_conditions)),
                device=unique_conditions.device,
            )
        elif query_mode == "seen":
            out = unique_conditions
        elif query_mode == "specified":
            if specified_conds is None:
                raise ValueError("query_mode set to 'specified' but did not passed specified conditions")
            elif not isinstance(specified_conds, torch.Tensor):
                raise TypeError(f"Please pass specified contions as tensor, got {type(specified_conds)}")
            elif specified_conds.shape[1] != unique_conditions.shape[1]:
                raise ValueError(
                    f"Inconsistent condition type number. Got {specified_conds.shape[1]} "
                    f"conditions in the specified conditions, but model only recorded "
                    f"{unique_conditions.shape[1]} conditions.",
                )
            # FIX: specified conds might not match cond_names
            out = specified_conds
        else:
            raise NotImplementedError(query_mode)

        return {i: j for i, j in zip(self.cond_names, out.T)}

    @staticmethod
    def get_assets_from_model(model):
        """Return ``(diffusion_model, max_time_steps)`` from a bound model object."""
        if hasattr(model, "diffusion_model"):
            diffusion_model = model.diffusion_model
        else:
            diffusion_model = model.model.diffusion_model

        if hasattr(model, "num_timesteps"):
            timesteps = model.num_timesteps
        else:
            timesteps = model.timesteps

        return diffusion_model, timesteps

    @torch.inference_mode()
    def eval_error(
        self,
        *,
        diffusion_model: torch.nn.Module,
        q_sample: Callable,
        x_orig: torch.Tensor,
        x_conds: Dict[str, torch.Tensor],
        query_conds: Dict[str, torch.Tensor],
        query_cond_idx: torch.Tensor,
        ts: List[int],
        loss: CLF_LOSS_TYPE = 'l2',
        inference_mask: Union[bool, str] = "all",
    ) -> torch.Tensor:
        """Per-cell reconstruction error averaged over the time steps ``ts``.

        Queried conditions are taken from ``query_conds`` at ``query_cond_idx``
        (one index per cell). Fixed conditions come from ``x_conds``, and
        nulled conditions are zeros.

        Returns:
            CPU tensor with one mean error per row of ``x_orig``.
        """
        device = x_orig.device
        pred_errors = torch.zeros(len(x_orig), device=device)
        x_empty = torch.zeros_like(x_orig)  # use decoder only (no context encoder)

        conditions = {}
        for i, j in x_conds.items():
            if i in query_conds:
                conditions[i] = query_conds[i][query_cond_idx]  # query
            elif i in self.conds_to_fix:
                conditions[i] = j  # fixed from input
            elif i in self.conds_to_null:
                conditions[i] = torch.zeros_like(j)  # fixed as null
            else:
                raise ValueError(f"Unknown conditions found in x_conds: {i}")

        for t in tqdm.tqdm(ts, leave=False, desc="Estimating errors"):
            t_input = torch.tensor([t], device=device)
            # NOTE(review): noise is drawn by q_sample on each call, so
            # different candidate conditions presumably see different noise.
            x_noised = q_sample(x_orig, t_input)
            if "mask_all" in inspect.signature(diffusion_model.forward).parameters:
                raise NotImplementedError("Not tested yet")
                pred, mask = diffusion_model(x_empty, x_noised, timesteps=t_input,
                                             pe_input=None, conditions=conditions,
                                             mask=False, mask_all=inference_mask)
            else:
                pred, mask = diffusion_model(x_empty, x_noised, timesteps=t_input,
                                             pe_input=None, conditions=conditions,
                                             mask=inference_mask)

            # UPDATE: full recon eval instead to align with training obj
            # # Only evaluate performance on masked entries
            # pred = pred * mask
            # x_orig = x_orig * mask

            if loss == 'l2':
                error = F.mse_loss(x_orig, pred, reduction='none').mean(1)
            elif loss == 'l1':
                error = F.l1_loss(x_orig, pred, reduction='none').mean(1)
            elif loss == 'huber':
                error = F.huber_loss(x_orig, pred, reduction='none').mean(1)
            else:
                raise NotImplementedError(f"Unknown loss type {loss!r}")

            pred_errors += error.detach()

        return (pred_errors / len(ts)).cpu()


# TODO: check why performs bad
class CellJumpClassifier:
    """Zero-shot condition classifier driven by diffusion reconstruction errors.

    From a high level, the classifier works as follows

        1. Use all unique conditions as candidates.
        2. Evaluate the reconstruction errors of the cells under each
           remaining candidate condition.
        3. For each cell, keep the top k candidate conditions with the lowest
           errors (averaged over every round evaluated so far) as the new
           candidate conditions. k is taken from ``to_keep_list`` for each
           round, and the final entry must be 1 so that the last round selects
           the best matching condition for each cell.
        4. Repeat 2 and 3 until the last round is reached.

    Args:
        n_samples_list: Number of time steps requested from the time step
            sampler in each round.
        to_keep_list: Number of top conditions to keep after each round. Must
            have the same length as ``n_samples_list`` and end with 1.
        n_trials: Number of trials per sampled time point (forwarded to the
            time step sampler).
        query_mode: What conditions to query for. "all" uses all possible
            combinations of the conditions, "seen" uses only the combinations
            seen during training, and "specified" uses the combinations passed
            to the :attr:`specified_conds` argument of :meth:`__call__`.
            "batch_all" and "batch_seen" are analogous to "all" and "seen" but
            only select from the batch.
        inference_mask: Forwarded to ``model.get_loss`` as ``mask_flag``.
        time_step_sampler: Name of the time step sampler. Only
            "IterativeUniform" is currently implemented.
        model: Model object to bind to. See
            :class:`celldiff.models.diffusion.ddpm.DDPM` for an example.

    Notes:
        The binding model is accessed through the following members:

            1. A :attr:`diffusion_model` (or :attr:`model.diffusion_model`)
               attribute; its absence raises ``AttributeError``.
            2. A :attr:`num_timesteps` (or :attr:`timesteps`) attribute, passed
               to the time step sampler as the number of diffusion time steps.
            3. A :meth:`get_loss` method, called as
               ``get_loss(x, t, conditions=..., mask_flag=..., w_diff=0)``.
               Its result is accumulated into a buffer of length ``len(x)``,
               so it is presumably one error value per cell (TODO confirm).
            4. A :attr:`unique_conditions` attribute for the non-batch query
               modes. If it is missing or ``None`` (older checkpoints), the
               unique conditions of the batch are used instead.

    """

    def __init__(
        self,
        n_samples_list: List[int],
        to_keep_list: List[int],
        n_trials: int = 1,
        query_mode: CLF_QUERY_MODE = "all",
        inference_mask: bool = False,
        time_step_sampler: TS_SAMPLER_TYPE = "IterativeUniform",
        model: Optional[Any] = None,
    ):
        assert len(n_samples_list) == len(to_keep_list)
        assert to_keep_list[-1] == 1, "Last trial must only select one best matching condition."
        self.n_samples_list = n_samples_list
        self.to_keep_list = to_keep_list
        self.n_trials = n_trials
        self.query_mode = check_str_option("query_mode", query_mode, CLF_QUERY_MODE)
        self.inference_mask = inference_mask
        # Assigning through the property setter validates the sampler name and
        # resolves it to the sampler class used in ``__call__`` (previously
        # this argument was accepted but silently ignored).
        self.get_time_step_sampler = time_step_sampler
        self._model = model

    @property
    def get_time_step_sampler(self):
        """Time step sampler class resolved from the ``time_step_sampler`` name."""
        return self._time_step_sampler_cls

    @get_time_step_sampler.setter
    def get_time_step_sampler(self, val):
        if val == "IterativeUniform":
            self._time_step_sampler_cls = IterativeUniformTimeStepSampler
        elif val not in (opts := get_args(TS_SAMPLER_TYPE)):
            raise ValueError(f"Unknown time step sampler {val!r}, available options are {opts}")
        else:
            # A declared option without an implementation branch above.
            raise NotImplementedError(f"[DEVERROR] please implement {val}")

    @property
    def model(self):
        """Model object bound during initialization (``None`` if not bound)."""
        return self._model

    def __call__(
        self,
        x: torch.Tensor,
        x_conds: torch.Tensor,
        model: Optional[Any] = None,
        specified_conds: Optional[torch.Tensor] = None,
    ):
        """Predict the condition of each cell.

        Args:
            x: Original expression values; forwarded to ``model.get_loss`` to
                evaluate reconstruction errors under each candidate condition.
            x_conds: Conditions of all cells (i.e., rows of ``x``). Used to
                derive the candidate conditions in the batch query modes, or
                when the model does not record ``unique_conditions``.
            model: Model object to bind to. See
                :class:`celldiff.models.diffusion.ddpm.DDPM` for an example.
                Only needed if model was not bound during initialization.
            specified_conds: Specify the conditions to query for in "specified"
                query mode.

        Returns:
            Tensor of predicted conditions with one row per cell.

        Raises:
            ValueError: If no model was bound at init nor passed here.

        """
        # Hook up with model
        if (model is None) and (self.model is None):
            raise ValueError("Model object not stored during init, please pass during call.")
        if (model is not None) and (self.model is not None):
            warnings.warn(
                f"Model object already specified during init: {self.model} but "
                f"also passed during call: {model}. Using the passed model "
                "in this call. Please remove duplicated model specification "
                "to suppress this warning.",
                UserWarning,
                stacklevel=2,
            )
        # Explicit None check: do not rely on the model object's truthiness.
        if model is None:
            model = self.model
        # Only the number of time steps is needed here; the lookup still fails
        # early if the model does not expose a diffusion model.
        _, timesteps = self.get_assets_from_model(model)
        conditions = self._get_conditions(model, x_conds, specified_conds)

        # Set constants
        max_n_samples = max(self.n_samples_list)
        num_cells = len(x)
        num_conditions = len(conditions)
        num_t_splits = len(self.n_samples_list)
        n_samples_list = self.n_samples_list
        to_keep_list = self.to_keep_list

        if num_conditions < (max_to_keep := max(to_keep_list)):
            warnings.warn(
                f"Maximum conditions to keep ({max_to_keep}) exceeds the total "
                f"number of conditions available ({num_conditions}). Implicitly "
                f"setting max number of conditions to keep to {num_conditions}.",
                UserWarning,
                stacklevel=2,
            )
            to_keep_list = [min(num_conditions, i) for i in to_keep_list]

        time_step_sampler = self.get_time_step_sampler(
            timesteps,
            max_n_samples,
            self.n_trials,
        )

        # full_error_tensor[i, c, k]: error of cell c under condition k averaged
        # over the time steps of rounds 0..i; NaN where condition k was no
        # longer a candidate for cell c in round i.
        full_error_tensor = torch.full((num_t_splits, num_cells, num_conditions), torch.nan)
        # best_idxs[c, j]: original condition index of the j-th remaining
        # candidate of cell c (all conditions in the first round).
        best_idxs = torch.arange(num_conditions).repeat(num_cells, 1)
        for i, (n_samples, n_to_keep) in enumerate(zip(n_samples_list, to_keep_list)):
            error_tensor = torch.zeros_like(best_idxs, dtype=torch.float)
            curr_t_to_eval = time_step_sampler(n_samples)

            # Perform one round of evaluation
            for j in tqdm.trange(
                best_idxs.shape[1],
                leave=False,
                desc=f"Round {i + 1} / {len(n_samples_list)}",
            ):
                query_conds = conditions[best_idxs[:, j]]
                error_tensor[:, j] = self.eval_error(model, x, curr_t_to_eval, query_conds,
                                                     self.inference_mask)

            # Aggregate evaluation results across rounds
            if i == 0:
                full_error_tensor[i] = error_tensor
                prev_size = len(curr_t_to_eval)
            else:
                curr_size = len(curr_t_to_eval)
                full_error_tensor[i].scatter_(dim=1, index=best_idxs, src=error_tensor)
                # Weighted mean with all previous rounds, weighting each round
                # by its number of time steps. The current candidates were
                # candidates in every earlier round, so their previous entries
                # are never NaN.
                full_error_tensor[i] = (
                    (full_error_tensor[i] * curr_size + full_error_tensor[i-1] * prev_size)
                    / (curr_size + prev_size)
                )
                prev_size += curr_size

            # Rank the remaining candidates by their errors aggregated over all
            # rounds so far (identical to error_tensor in the first round);
            # previously the aggregate was computed but never used.
            candidate_errors = full_error_tensor[i].gather(dim=1, index=best_idxs)
            new_best_vals, new_best_idxs = candidate_errors.topk(n_to_keep, dim=1, largest=False)
            assert not torch.isnan(new_best_vals).any(), "Found nans in selected entries."
            # Convert idx back to original condition idx
            best_idxs = best_idxs.gather(dim=1, index=new_best_idxs)

        pred_conditions = conditions[best_idxs.flatten()]

        return pred_conditions

    def _get_conditions(
        self,
        model: torch.nn.Module,
        x_conds: torch.Tensor,
        specified_conds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build the candidate conditions to query for.

        Args:
            model: Model object; its ``unique_conditions`` attribute (if present
                and not ``None``) is used in the non-batch query modes.
            x_conds: Conditions of the cells in the batch; their unique rows are
                used in the batch query modes, or as a fallback.
            specified_conds: Conditions to query for in "specified" mode.

        Returns:
            Tensor of candidate conditions, one row per candidate.

        Raises:
            ValueError: In "specified" mode, if ``specified_conds`` is missing
                or has a different number of condition columns.
            TypeError: In "specified" mode, if ``specified_conds`` is not a
                tensor.
            NotImplementedError: For an unknown query mode.

        """
        # First prepare the unique conditions, either from the observations
        # from the batch, or from the full training dataset (recorded by the
        # unique_conditions attribute in the model object)
        batch_mode = self.query_mode.startswith("batch_")
        unique_conditions = None if batch_mode else getattr(model, "unique_conditions", None)
        if unique_conditions is None:
            if not batch_mode:
                # Older checkpoints do not record the unique conditions.
                warnings.warn(
                    "The model ckpt being used is from an older version that "
                    "do not contain the 'unique_conditions'. Implicitly "
                    f"switching the query_mode from {self.query_mode} to "
                    f"batch_{self.query_mode}",
                    UserWarning,
                    # 3 = the caller of ``__call__`` (this helper is invoked
                    # from ``__call__``).
                    stacklevel=3,
                )
            unique_conditions = x_conds.unique(dim=0)

        # The only difference is how we prepare the unique_conditions, which
        # we have handled above
        query_mode = self.query_mode.replace("batch_", "")

        if query_mode == "all":
            # Cartesian product of the unique values of each condition column.
            individual_unique_conditions = [col.unique().tolist() for col in unique_conditions.T]
            return torch.tensor(
                list(itertools.product(*individual_unique_conditions)),
                device=unique_conditions.device,
            )
        elif query_mode == "seen":
            return unique_conditions
        elif query_mode == "specified":
            if specified_conds is None:
                raise ValueError("query_mode set to 'specified' but did not passed specified conditions")
            elif not isinstance(specified_conds, torch.Tensor):
                raise TypeError(f"Please pass specified contions as tensor, got {type(specified_conds)}")
            elif specified_conds.shape[1] != unique_conditions.shape[1]:
                raise ValueError(
                    f"Inconsistent condition type number. Got {specified_conds.shape[1]} "
                    f"conditions in the specified conditions, but model only recorded "
                    f"{unique_conditions.shape[1]} conditions.",
                )
            return specified_conds
        else:
            raise NotImplementedError(query_mode)

    @staticmethod
    def get_assets_from_model(model):
        """Return ``(diffusion_model, timesteps)`` from the bound model.

        ``diffusion_model`` is read from ``model.diffusion_model`` or else
        ``model.model.diffusion_model``; ``timesteps`` from
        ``model.num_timesteps`` or else ``model.timesteps``.
        """
        if hasattr(model, "diffusion_model"):
            diffusion_model = model.diffusion_model
        else:
            diffusion_model = model.model.diffusion_model

        if hasattr(model, "num_timesteps"):
            timesteps = model.num_timesteps
        else:
            timesteps = model.timesteps

        return diffusion_model, timesteps

    @staticmethod
    @torch.inference_mode()
    def eval_error(
        model: torch.nn.Module,
        x_orig: torch.Tensor,
        ts: List[int],
        conditions: torch.Tensor,
        inference_mask: bool = False,
    ) -> torch.Tensor:
        """Average ``model.get_loss`` over the given time steps.

        Args:
            model: Model exposing ``get_loss``.
            x_orig: Original expression values of the cells.
            ts: Time steps to evaluate; must be non-empty.
            conditions: Condition to evaluate for each cell.
            inference_mask: Forwarded to ``get_loss`` as ``mask_flag``.

        Returns:
            CPU tensor of length ``len(x_orig)`` with the mean error per cell.

        Raises:
            ValueError: If ``ts`` is empty (the mean would be 0/0 = NaN).

        """
        if len(ts) == 0:
            raise ValueError("No time steps to evaluate, cannot average errors over an empty set.")
        device = x_orig.device
        pred_errors = torch.zeros(len(x_orig), device=device)

        for t in tqdm.tqdm(ts, leave=False, desc="Estimating errors"):
            t_input = torch.tensor([t], device=device)
            # w_diff=0 is passed as-is to get_loss; its meaning is defined by
            # the model (NOTE(review): confirm against get_loss).
            error = model.get_loss(x_orig, t_input, conditions=conditions, mask_flag=inference_mask, w_diff=0)
            pred_errors += error.detach()

        return (pred_errors / len(ts)).cpu()
