import itertools
import math
from dataclasses import dataclass

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn.functional as F
import torchmetrics
from torch import Tensor

import applications.drakes_dna.dataloader_gosai as dataloader_gosai
import applications.drakes_dna.models as models
import applications.drakes_dna.noise_schedule as noise_schedule
import applications.drakes_dna.utils as utils
import applications.drakes_dna.oracle as oracle
from scipy.stats import wasserstein_distance, pearsonr

LOG2 = math.log(2)
LOGGER = utils.get_logger(__name__)


def _sample_categorical(categorical_probs):
  gumbel_norm = (
    1e-10
    - (torch.rand_like(categorical_probs) + 1e-10).log())
  return (categorical_probs / gumbel_norm).argmax(dim=-1)

def _sample_categorical_gradient(categorical_probs, temp = 1.0):
  gumbel_norm = (
    1e-10
    - (torch.rand_like(categorical_probs) + 1e-10).log())
  output = torch.nn.functional.softmax((torch.log(categorical_probs)-torch.log(gumbel_norm))/temp, 2)
  return output

def _unsqueeze(x, reference):
  return x.view(
    * x.shape,
    * ((1,) * (len(reference.shape) - len(x.shape))))


@dataclass
class Loss:
  loss: torch.FloatTensor
  nlls: torch.FloatTensor
  token_mask: torch.FloatTensor


class NLL(torchmetrics.aggregation.MeanMetric):
  pass


class BPD(NLL):
  def compute(self) -> Tensor:
    """Computes the bits per dimension.

    Returns:
      bpd
    """
    return self.mean_value / self.weight / LOG2


class Perplexity(NLL):
  def compute(self) -> Tensor:
    """Computes the Perplexity.

    Returns:
     Perplexity
    """
    return torch.exp(self.mean_value / self.weight)


class Diffusion(L.LightningModule):
  def __init__(
    self,
    config,
    eval=False):
    super().__init__()
    self.save_hyperparameters()
    self.config = config
    self.vocab_size = 4
    self.sampler = self.config.sampling.predictor
    self.antithetic_sampling = self.config.training.antithetic_sampling
    self.importance_sampling = self.config.training.importance_sampling
    self.change_of_variables = self.config.training.change_of_variables
    self.mask_index = self.vocab_size
    self.vocab_size += 1
    self.parameterization = self.config.parameterization
    if self.config.backbone == 'cnn':
      self.backbone = models.dnaconv.CNNModel(
        self.config.model, alphabet_size=self.vocab_size, num_cls=3) # num_cls is not used since classifier is always set to False
    else:
      raise ValueError(
        f'Unknown backbone: {self.config.backbone}')

    self.T = self.config.T
    self.subs_masking = self.config.subs_masking

    self.softplus = torch.nn.Softplus()
    # metrics are automatically reset at end of epoch
    metrics = torchmetrics.MetricCollection({
      'nll': NLL(),
      'bpd': BPD(),
      'ppl': Perplexity(),
    })
    metrics.set_dtype(torch.float64)
    self.train_metrics = metrics.clone(prefix='train/')
    self.valid_metrics = metrics.clone(prefix='val/')
    self.test_metrics = metrics.clone(prefix='test/')

    # generative perplexity
    self.gen_ppl_metric = Perplexity()
    self.noise = noise_schedule.get_noise(self.config,
                                          dtype=self.dtype)
    if self.config.training.ema > 0:
      self.ema = models.ema.ExponentialMovingAverage(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()),
        decay=self.config.training.ema)
    else:
      self.ema = None
    
    self.lr = self.config.optim.lr
    self.sampling_eps = self.config.training.sampling_eps
    self.time_conditioning = self.config.time_conditioning
    self.neg_infinity = -1000000.0
    self.fast_forward_epochs = None
    self.fast_forward_batches = None
    self._validate_configuration()

    # subset of data for evaluation
    if eval:
      self.eval_sets_sp = oracle.subset_for_eval(n=config.eval.subset_size) 
      self.eval_sets_sp_clss = oracle.subset_eval_groundtruth(self.eval_sets_sp)
      self.eval_sets_sp_preds = oracle.subset_eval_preds(self.eval_sets_sp) 
      self.eval_sets_sp_kmers = oracle.subset_eval_kmers(self.eval_sets_sp) 
      self.emb_pca = oracle.cal_emb_pca(oracle.subset_for_eval(n=40000), n_components=50)
      self.eval_sets_sp_embs_pca = oracle.subset_eval_embs_pca(self.eval_sets_sp, self.emb_pca) 
  
  def _validate_configuration(self):
    assert not (self.change_of_variables
                and self.importance_sampling)
    assert self.parameterization == 'subs'

  def on_load_checkpoint(self, checkpoint):
    if self.ema:
      self.ema.load_state_dict(checkpoint['ema'])
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
    self.fast_forward_epochs = checkpoint['loops'][
      'fit_loop']['epoch_progress']['current']['completed']
    self.fast_forward_batches = checkpoint['loops'][
      'fit_loop']['epoch_loop.batch_progress'][
        'current']['completed']

  def on_save_checkpoint(self, checkpoint):
    if self.ema:
      checkpoint['ema'] = self.ema.state_dict()
    # Copied from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
    # ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration
    # behind, so we're using the optimizer's progress.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['total'][
        'completed'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total'][
              'completed'] * self.trainer.accumulate_grad_batches
    checkpoint['loops']['fit_loop'][
      'epoch_loop.batch_progress']['current'][
        'completed'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['current'][
              'completed'] * self.trainer.accumulate_grad_batches
    # _batches_that_stepped tracks the number of global steps, not the number
    # of local steps, so we don't multiply with self.trainer.accumulate_grad_batches here.
    checkpoint['loops']['fit_loop'][
      'epoch_loop.state_dict'][
        '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
          'epoch_loop.automatic_optimization.optim_progress'][
            'optimizer']['step']['total']['completed']
    if 'sampler' not in checkpoint.keys():
      checkpoint['sampler'] = {}
    if hasattr(self.trainer.train_dataloader.sampler,
               'state_dict'):
      sampler_state_dict = self.trainer.\
        train_dataloader.sampler.state_dict()
      checkpoint['sampler'][
        'random_state'] = sampler_state_dict.get(
          'random_state', None)
    else:
      checkpoint['sampler']['random_state'] = None

  def on_train_start(self):
    if self.ema:
      self.ema.move_shadow_params_to_device(self.device)
    # Adapted from:
    # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
    distributed = (
      self.trainer._accelerator_connector.use_distributed_sampler
      and self.trainer._accelerator_connector.is_distributed)
    
    print('distributed:', distributed)
    # TODO: need to check these two functions
    if distributed:
      sampler_cls = dataloader_gosai.FaultTolerantDistributedSampler
    else:
      sampler_cls = dataloader_gosai.RandomFaultTolerantSampler
    
    updated_dls = []
    for dl in self.trainer.fit_loop._combined_loader.flattened:
      if hasattr(dl.sampler, 'shuffle'):
        dl_sampler = sampler_cls(
          dl.dataset, shuffle=dl.sampler.shuffle)
      else:
        dl_sampler = sampler_cls(dl.dataset)
      if (distributed
          and self.fast_forward_epochs is not None
          and self.fast_forward_batches is not None):
        dl_sampler.load_state_dict({
          'epoch': self.fast_forward_epochs,
          'counter': (self.fast_forward_batches
                      * self.config.loader.batch_size)})
      updated_dls.append(
        torch.utils.data.DataLoader(
          dl.dataset,
          batch_size=self.config.loader.batch_size,
          num_workers=self.config.loader.num_workers,
          pin_memory=self.config.loader.pin_memory,
          sampler=dl_sampler,
          shuffle=False,
          persistent_workers=True))
    self.trainer.fit_loop._combined_loader.flattened = updated_dls

  def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)
    if self.ema:
      self.ema.update(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))

  def _subs_parameterization(self, logits, xt):
    logits[:, :, self.mask_index] += self.neg_infinity
    logits = logits - torch.logsumexp(logits, dim=-1,
                                      keepdim=True)
    if xt.ndim > 2 and xt.shape[-1] == self.vocab_size:
      # this is for finetuning setting when the input is one-hot encoded or probs
      xt = xt.argmax(dim=-1)
    unmasked_indices = (xt != self.mask_index)
    logits[unmasked_indices] = self.neg_infinity
    logits[unmasked_indices, xt[unmasked_indices]] = 0
    return logits

  def _process_sigma(self, sigma):
    if sigma is None:
      assert self.parameterization == 'ar'
      return sigma
    if sigma.ndim > 1:
      sigma = sigma.squeeze(-1)
    if not self.time_conditioning:
      sigma = torch.zeros_like(sigma)
    assert sigma.ndim == 1, sigma.shape
    return sigma

  def forward(self, x, sigma):
    """Returns log score."""
    sigma = self._process_sigma(sigma)

    with torch.cuda.amp.autocast(dtype=torch.float32):
      logits = self.backbone(x, sigma)
    # print('logits shape', logits.shape)
    if self.parameterization == 'subs':
      return self._subs_parameterization(logits=logits,
                                         xt=x)
    return logits

  def _compute_loss(self, batch, prefix):
    if 'attention_mask' in batch:
      attention_mask = batch['attention_mask']
    else:
      attention_mask = None
    losses = self._loss(batch['seqs'], attention_mask)
    loss = losses.loss

    if prefix == 'train':
      self.train_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.train_metrics
    elif prefix == 'val':
      self.valid_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.valid_metrics
    elif prefix == 'test':
      self.test_metrics.update(losses.nlls, losses.token_mask)
      metrics = self.test_metrics
    else:
      raise ValueError(f'Invalid prefix: {prefix}')

    self.log_dict(metrics,
                  on_step=False,
                  on_epoch=True,
                  sync_dist=True)
    return loss

  def on_train_epoch_start(self):
    self.backbone.train()
    self.noise.train()

  def training_step(self, batch, batch_idx):
    loss = self._compute_loss(batch, prefix='train')
    self.log(name='trainer/loss',
             value=loss.item(),
             on_step=True,
             on_epoch=False,
             sync_dist=True)
    return loss

  def on_validation_epoch_start(self):
    if self.ema:
      self.ema.store(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))
      self.ema.copy_to(itertools.chain(
        self.backbone.parameters(),
        self.noise.parameters()))
    self.backbone.eval()
    self.noise.eval()
    assert self.valid_metrics.nll.mean_value == 0
    assert self.valid_metrics.nll.weight == 0

  def validation_step(self, batch, batch_idx):
    return self._compute_loss(batch, prefix='val')

  def on_validation_epoch_end(self):
    if ((self.config.eval.compute_perplexity_on_sanity
         or not self.trainer.sanity_checking)
         and self.config.eval.generate_samples
         and not self.parameterization == 'ar'):
      all_samples, all_detoeknized_samples = [], []
      for _ in range(
        self.config.sampling.num_sample_batches):
        samples = self._sample().detach().cpu().numpy()
        detokenized_samples = dataloader_gosai.batch_dna_detokenize(samples)
        all_samples.append(samples)
        all_detoeknized_samples.extend(detokenized_samples)
      all_samples = np.concatenate(all_samples, axis=0)
      ws_distance_dict = self.cal_wasserstein_distance(all_detoeknized_samples)
      pearsonr_list = self.cal_kmer_pearsonr(all_detoeknized_samples)
      ws_embpca_list = self.cal_ws_distance_embpca(all_detoeknized_samples)
      
      current_step = self.trainer.global_step
      LOGGER.info(f'Current step: {current_step}')
      LOGGER.info(f'Wasserstein distance: {ws_distance_dict}')
      LOGGER.info(f'3mer Pearsonr: {pearsonr_list}')
      LOGGER.info(f'Wasserstein distance embpca: {ws_embpca_list}')
      self.log('val/3mer_pearsonr', pearsonr_list, on_step=False, on_epoch=True, sync_dist=True)
      self.log('val/ws_embpca', ws_embpca_list, on_step=False, on_epoch=True, sync_dist=True)

      for key in ws_distance_dict:
        for cell_type in ws_distance_dict[key]:
          metric_values = ws_distance_dict[key][cell_type]
          if metric_values:  # Check if the list is not empty
              # Assuming metric_values contains [train_metric, valid_metric, test_metric]
              self.log(f'val/{key}_{cell_type}', metric_values[0], on_step=False, on_epoch=True, sync_dist=True)

    if self.ema:
      self.ema.restore(
        itertools.chain(self.backbone.parameters(),
                        self.noise.parameters()))
      
  def cal_wasserstein_distance(self, seqs):
    generated_preds = oracle.cal_gosai_pred_new(seqs)
    ws_distance_dict = {'truth': {'hepg2': [], 'k562': [], 'sknsh': []}, 
                        'preds': {'hepg2': [], 'k562': [], 'sknsh': []}} 
    ws_distance_dict['truth']['hepg2'].append(wasserstein_distance(generated_preds[:, 0], self.eval_sets_sp_clss[:, 0]))
    ws_distance_dict['truth']['k562'].append(wasserstein_distance(generated_preds[:, 1], self.eval_sets_sp_clss[:, 1]))
    ws_distance_dict['truth']['sknsh'].append(wasserstein_distance(generated_preds[:, 2], self.eval_sets_sp_clss[:, 2]))   
    ws_distance_dict['preds']['hepg2'].append(wasserstein_distance(generated_preds[:, 0], self.eval_sets_sp_preds[:, 0]))
    ws_distance_dict['preds']['k562'].append(wasserstein_distance(generated_preds[:, 1], self.eval_sets_sp_preds[:, 1]))
    ws_distance_dict['preds']['sknsh'].append(wasserstein_distance(generated_preds[:, 2], self.eval_sets_sp_preds[:, 2])) 
    return ws_distance_dict

  def cal_ws_distance_embpca(self, seqs):
    generated_embs = oracle.cal_gosai_emb(seqs)
    generated_embs_pca = self.emb_pca.transform(generated_embs.reshape(generated_embs.shape[0], -1))
    return oracle.get_wasserstein_dist(generated_embs_pca, self.eval_sets_sp_embs_pca)
  
  def compare_kmer(self, kmer1, kmer2, n_sp1, n_sp2):
    kmer_set = set(kmer1.keys()) | set(kmer2.keys())
    counts = np.zeros((len(kmer_set), 2))
    for i, kmer in enumerate(kmer_set):
        if kmer in kmer1:
            counts[i][1] = kmer1[kmer] * n_sp2 / n_sp1
        if kmer in kmer2:
            counts[i][0] = kmer2[kmer]
    return pearsonr(counts[:, 0], counts[:, 1])[0]

  def cal_kmer_pearsonr(self, seqs):
    generated_kmer = oracle.count_kmers(seqs)
    return self.compare_kmer(self.eval_sets_sp_kmers, generated_kmer, self.config.eval.subset_size, len(seqs))


  def configure_optimizers(self):
    # TODO(yair): Lightning currently giving this warning when using `fp16`:
    #  "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
    #  Not clear if this is a problem or not.
    #  See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
    optimizer = torch.optim.AdamW(
      itertools.chain(self.backbone.parameters(),
                      self.noise.parameters()),
      lr=self.config.optim.lr,
      betas=(self.config.optim.beta1,
             self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay)

    scheduler = hydra.utils.instantiate(
      self.config.lr_scheduler, optimizer=optimizer)
    scheduler_dict = {
      'scheduler': scheduler,
      'interval': 'step',
      'monitor': 'val/loss',
      'name': 'trainer/lr',
    }
    return [optimizer], [scheduler_dict]

  def q_xt(self, x, move_chance):
    """Computes the noisy sample xt.

    Args:
      x: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input. 
      move_chance: float torch.Tensor with shape (batch_size, 1).
    """
    move_indices = torch.rand(
      * x.shape, device=x.device) < move_chance
    xt = torch.where(move_indices, self.mask_index, x)
    return xt

  def _sample_prior(self, *batch_dims):
    return self.mask_index * torch.ones(
      * batch_dims, dtype=torch.int64)

  def _ddpm_caching_update(self, x, t, dt, p_x0=None):
    assert self.config.noise.type == 'loglinear'
    sigma_t, _ = self.noise(t)
    if t.ndim > 1:
      t = t.squeeze(-1)
    assert t.ndim == 1
    move_chance_t = t[:, None, None]
    move_chance_s = (t - dt)[:, None, None]
    assert move_chance_t.ndim == 3, move_chance_t.shape
    if p_x0 is None:
      p_x0 = self.forward(x, sigma_t).exp()
    
    assert move_chance_t.ndim == p_x0.ndim
    q_xs = p_x0 * (move_chance_t - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)
    
    copy_flag = (x != self.mask_index).to(x.dtype)
    return p_x0, copy_flag * x + (1 - copy_flag) * _x

  def _ddpm_update(self, x, t, dt, return_process=False):
    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t) # t
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning)
    assert move_chance_t.ndim == log_p_x0.ndim
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical(q_xs)
    copy_flag = (x != self.mask_index).to(x.dtype)
    if return_process:
      return copy_flag * x + (1 - copy_flag) * _x, x, unet_conditioning, move_chance_t, copy_flag
    else:
      return copy_flag * x + (1 - copy_flag) * _x
  
  def _ar_sampler(self, bsz):
    # precompute token buffer
    num_pred_tokens = self.config.model.length - 1
    x = torch.zeros(
      (bsz, num_pred_tokens + 1),
      dtype=torch.long,
      device=self.device)
    x[:, 0] = self.tokenizer.bos_token_id
    # precompute noise
    noise = (torch.distributions.Gumbel(0, 1)
             .sample((bsz, num_pred_tokens, self.vocab_size))
             .to(self.device))
    for i in range(num_pred_tokens):
      next_logits = self.forward(x[:, :i + 1], None)[:, -1]
      y = (next_logits + noise[:, i]).argmax(-1)
      x[:, i + 1] = y
    return x

  @torch.no_grad()
  def _sample(self, num_steps=None, eps=1e-5, eval_sp_size=None):
    """Generate samples from the model."""
    if eval_sp_size is None:
      batch_size_per_gpu = self.config.loader.eval_batch_size
    else:
      batch_size_per_gpu = eval_sp_size
    if self.parameterization == 'ar':
      return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason
    if num_steps is None:
      num_steps = self.config.sampling.steps
    x = self._sample_prior(
      batch_size_per_gpu,
      self.config.model.length).to(self.device)
    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        x = self._ddpm_update(x, t, dt)
      elif self.sampler == 'ddpm_cache':
        p_x0_cache, x_next = self._ddpm_caching_update(
          x, t, dt, p_x0=p_x0_cache)
        if (not torch.allclose(x_next, x)
            or self.time_conditioning):
          p_x0_cache = None
        x = x_next
      else:
        x = self._analytic_update(x, t, dt)

    if self.config.sampling.noise_removal:
      t = timesteps[-1] * torch.ones(x.shape[0], 1,
                                     device=self.device)
      if self.sampler == 'analytic':
        x = self._denoiser_update(x, t)
      else:
        unet_conditioning = self.noise(t)[0]
        logits = self.forward(x, unet_conditioning)
        x = logits[:, :, :-1].argmax(dim=-1)
    return x
  
  def _ddpm_update_finetune_gradient(self, x, t, dt, copy_flag_temp, return_process=False):
    
    if x.ndim == 2 or x.shape[-1] != self.vocab_size:
      x = F.one_hot(x, num_classes=self.vocab_size).to(torch.float32)

    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t) # (1-eps)*t
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning)
    assert move_chance_t.ndim == log_p_x0.ndim
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    _x = _sample_categorical_gradient(q_xs, temp=self.config.finetuning.gumbel_softmax_temp)
    
    if copy_flag_temp is not None:
      copy_flag_prob = 1 - x[:, :, self.mask_index].unsqueeze(-1)
      soft_copy_flag = torch.nn.functional.sigmoid(copy_flag_prob/copy_flag_temp)
    else:
      soft_copy_flag = 1 - x[:, :, self.mask_index].unsqueeze(-1)

    if return_process:
      return soft_copy_flag * x + (1 - soft_copy_flag) * _x, x, unet_conditioning, move_chance_t, soft_copy_flag
    else:
      return soft_copy_flag * x + (1 - soft_copy_flag) * _x
    
   
  def _sample_finetune_gradient(self, num_steps=None, eps=1e-5, eval_sp_size=None, copy_flag_temp=None):
    """Generate samples from the model."""
    assert self.parameterization == 'subs' and self.sampler == 'ddpm'
    if eval_sp_size is None:
      batch_size_per_gpu = self.config.loader.eval_batch_size
    else:
      batch_size_per_gpu = eval_sp_size
    if num_steps is None:
      num_steps = self.config.sampling.steps
    x = self._sample_prior(
      batch_size_per_gpu,
      self.config.model.length).to(self.device)
    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    last_x_list = []
    condt_list = []
    move_chance_t_list = []
    copy_flag_list = []

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        if i < num_steps - self.config.finetuning.truncate_steps:
          x, last_x, condt, move_chance_t, copy_flag = self._ddpm_update(x, t, dt, return_process=True)
          x = x.detach()
          copy_flag = copy_flag.unsqueeze(-1)
          last_x = F.one_hot(last_x, num_classes=self.vocab_size).to(torch.float32).detach()
        else: 
          x, last_x, condt, move_chance_t, copy_flag = self._ddpm_update_finetune_gradient(x, t, dt, copy_flag_temp, return_process=True)
      
      last_x_list.append(last_x)
      condt_list.append(condt)
      move_chance_t_list.append(move_chance_t)
      copy_flag_list.append(copy_flag)

    x_argmax = x[:, :, :-1].argmax(dim=-1)
    x_argmax = torch.nn.functional.one_hot(x_argmax, num_classes=self.vocab_size-1).to(torch.float32)
    return x[:, :, :-1] + (x_argmax - x[:, :, :-1]).detach(), last_x_list, condt_list, move_chance_t_list, copy_flag_list
  
  @torch.no_grad()
  def _ddpm_update_finetune_controlled_SMC(self, x, t, dt, reward_model, alpha = 1.0):

    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t)
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning)
    assert move_chance_t.ndim == log_p_x0.ndim
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    copy_flag = (x != self.mask_index).to(x.dtype)
    sample = copy_flag * x + (1 - copy_flag) * _sample_categorical(q_xs)
    '''
    Calcualte exp(v_{t-1}(x_{t-1})/alpha)
    '''
    expected_x0 = self.forward(sample, sigma_s) # Calcualte E[x_0|x_{t-1}]
    expected_x0_arg = torch.argmax(expected_x0,dim=2)
    expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
    reward_num = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
    '''
    Calcualte exp(v_{t}(x_{t})/alpha)
    '''
    expected_x0 = self.forward(x, sigma_s) # Calcualte E[x_0|x_t]
    expected_x0_arg = torch.argmax(expected_x0,dim=2) 
    expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
    reward_den = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
  
    ratio = torch.exp(1.0/alpha * (reward_num - reward_den)) # Now calculate exp( (v_{t-1}(x_{t-1) -v_{t}(x_{t}) /alpha) 
    ratio = ratio.detach().cpu().numpy()
    final_sample_indices = np.random.choice(reward_num.shape[0], reward_num.shape[0], p =  ratio/ratio.sum() ) 
   
    return sample[final_sample_indices]
  
  def _ddpm_update_finetune_controlled_CG(self, x, t, dt, reward_model,  guidance_scale):

    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t)
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning)
    assert move_chance_t.ndim == log_p_x0.ndim
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    x_onehot = F.one_hot(x, num_classes=5).float()

    x_grad = self.compute_gradient_CG(x_onehot, x, reward_model, sigma_s )
    guidance = guidance_scale * (x_grad - x_grad[:, :, self.mask_index][:, :, None])
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    q_xs = q_xs * guidance.exp()

    _x = _sample_categorical(q_xs)
    copy_flag = (x != self.mask_index).to(x.dtype)
    return copy_flag * x + (1 - copy_flag) * _x 

  def compute_gradient_CG(self, x_onehot, x, reward_model, sigma_s):
    x_onehot.requires_grad_(True)
    expected_x0 = self.forward(x_onehot, sigma_s) # Calcualte E[x_0|x_t]
    scores = reward_model(expected_x0.transpose(1, 2)[:,0:4,:])[:, 0]
    scores = scores.mean()
    scores.backward()
    x_grad = x_onehot.grad.clone()
    return x_grad

  def _ddpm_update_finetune_controlled_TDS(self, x, t, dt, reward_model, alpha = 1.0, guidance_scale=1000):
    # SMC with the twisted proposal

    sigma_t, _ = self.noise(t)
    sigma_s, _ = self.noise(t - dt)
    if sigma_t.ndim > 1:
      sigma_t = sigma_t.squeeze(-1)
    if sigma_s.ndim > 1:
      sigma_s = sigma_s.squeeze(-1)
    assert sigma_t.ndim == 1, sigma_t.shape
    assert sigma_s.ndim == 1, sigma_s.shape
    move_chance_t = 1 - torch.exp(-sigma_t)
    move_chance_s = 1 - torch.exp(-sigma_s)
    move_chance_t = move_chance_t[:, None, None]
    move_chance_s = move_chance_s[:, None, None]
    unet_conditioning = sigma_t
    log_p_x0 = self.forward(x, unet_conditioning)
    assert move_chance_t.ndim == log_p_x0.ndim
    q_xs = log_p_x0.exp() * (move_chance_t
                             - move_chance_s)
    x_onehot = F.one_hot(x, num_classes=5).float()

    x_grad = self.compute_gradient_CG(x_onehot, x, reward_model, sigma_s )
    guidance = guidance_scale * (x_grad - x_grad[:, :, self.mask_index][:, :, None])
    q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
    # print(q_xs.sum(-1))
    q_xs = q_xs * guidance.exp()

    _x = _sample_categorical(q_xs)
    copy_flag = (x != self.mask_index).to(x.dtype)
    sample = copy_flag * x + (1 - copy_flag) * _x
    prob_multiplier = (1 - copy_flag) * torch.gather(guidance.exp(), 2, _x.unsqueeze(-1)).squeeze(-1) + copy_flag * torch.ones_like(_x)
    '''
    Calcualte exp(v_{t-1}(x_{t-1})/alpha)
    '''
    expected_x0 = self.forward(sample, sigma_s) # Calcualte E[x_0|x_{t-1}]
    expected_x0_arg = torch.argmax(expected_x0,dim=2)
    expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
    reward_num = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
    '''
    Calcualte exp(v_{t}(x_{t})/alpha)
    '''
    expected_x0 = self.forward(x, sigma_s) # Calcualte E[x_0|x_t]
    expected_x0_arg = torch.argmax(expected_x0,dim=2) 
    expected_x0_onehot = torch.nn.functional.one_hot(expected_x0_arg)
    reward_den = reward_model(expected_x0_onehot.float().transpose(1, 2)).detach()[:, 0][:, 0]
    
    # set the nan values to 1
    prob_multiplier[torch.isnan(prob_multiplier)] = 1
    ratio = torch.exp(1.0/alpha * (reward_num - reward_den)) / prob_multiplier.prod(dim=-1)
    ratio = ratio.detach().cpu().numpy()
    final_sample_indices = np.random.choice(reward_num.shape[0], reward_num.shape[0], p =  ratio/ratio.sum() ) 
   
    return sample[final_sample_indices]
  
  @torch.no_grad()
  def controlled_sample_SMC(self, reward_model, alpha, num_steps=None, eps=1e-5, eval_sp_size=None):
    """Generate samples from the model."""
    if eval_sp_size is None:
      batch_size_per_gpu = self.config.loader.eval_batch_size
    else:
      batch_size_per_gpu = eval_sp_size
    if self.parameterization == 'ar':
      return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason
    if num_steps is None:
      num_steps = self.config.sampling.steps
    x = self._sample_prior(
      batch_size_per_gpu,
      self.config.model.length).to(self.device)
    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        x  = self._ddpm_update_finetune_controlled_SMC(x, t, dt, reward_model, alpha)
      else:
        x = self._analytic_update(x, t, dt)

    if self.config.sampling.noise_removal:
      t = timesteps[-1] * torch.ones(x.shape[0], 1,
                                     device=self.device)
      if self.sampler == 'analytic':
        x = self._denoiser_update(x, t)
      else:
        unet_conditioning = self.noise(t)[0]
        logits = self.forward(x, unet_conditioning)
        x = logits[:, :, :-1].argmax(dim=-1)
    return x

  def controlled_sample_CG(self, reward_model, guidance_scale, num_steps=None, eps=1e-5, eval_sp_size=None):
    """Generate samples from the model."""
    if eval_sp_size is None:
      batch_size_per_gpu = self.config.loader.eval_batch_size
    else:
      batch_size_per_gpu = eval_sp_size
    if self.parameterization == 'ar':
      return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason
    if num_steps is None:
      num_steps = self.config.sampling.steps
    x = self._sample_prior(
      batch_size_per_gpu,
      self.config.model.length).to(self.device)
    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        x  = self._ddpm_update_finetune_controlled_CG(x, t, dt, reward_model, guidance_scale)
      else:
        x = self._analytic_update(x, t, dt)

    if self.config.sampling.noise_removal:
      t = timesteps[-1] * torch.ones(x.shape[0], 1,
                                     device=self.device)
      if self.sampler == 'analytic':
        x = self._denoiser_update(x, t)
      else:
        unet_conditioning = self.noise(t)[0]
        logits = self.forward(x, unet_conditioning)
        x = logits[:, :, :-1].argmax(dim=-1)
    return x

  def controlled_sample_TDS(self, reward_model, alpha, guidance_scale, num_steps=None, eps=1e-5, eval_sp_size=None):
    """Generate samples from the model."""
    if eval_sp_size is None:
      batch_size_per_gpu = self.config.loader.eval_batch_size
    else:
      batch_size_per_gpu = eval_sp_size
    if self.parameterization == 'ar':
      return self._ar_sampler(batch_size_per_gpu)
    # Lightning auto-casting is not working in this method for some reason
    if num_steps is None:
      num_steps = self.config.sampling.steps
    x = self._sample_prior(
      batch_size_per_gpu,
      self.config.model.length).to(self.device)
    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=self.device)
    dt = (1 - eps) / num_steps
    p_x0_cache = None

    for i in range(num_steps):
      t = timesteps[i] * torch.ones(
        x.shape[0], 1, device=self.device)
      if self.sampler == 'ddpm':
        x  = self._ddpm_update_finetune_controlled_TDS(x, t, dt, reward_model,alpha, guidance_scale)
      else:
        x = self._analytic_update(x, t, dt)

    if self.config.sampling.noise_removal:
      t = timesteps[-1] * torch.ones(x.shape[0], 1,
                                     device=self.device)
      if self.sampler == 'analytic':
        x = self._denoiser_update(x, t)
      else:
        unet_conditioning = self.noise(t)[0]
        logits = self.forward(x, unet_conditioning)
        x = logits[:, :, :-1].argmax(dim=-1)
    return x

  @torch.no_grad()
  def get_likelihood(self, x0, num_steps=None, eps=1e-5, n_samples=1):
    """Compute the likelihood of a sequence under the model.
    x0: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length)
    """
    if num_steps is None:
      num_steps = self.config.sampling.steps
    timesteps = torch.linspace(
      1, eps, num_steps + 1, device=self.device) # t=0 is clean data
    dt = (1 - eps) / num_steps
    log_p_sample_list = []
    for _ in range(n_samples):
      log_p_at_time_list = []
      for i in range(num_steps):
        t = timesteps[i] * torch.ones(
          x0.shape[0], 1, device=self.device)
        sigma_t, _ = self.noise(t)
        sigma_s, _ = self.noise(t - dt)
        if sigma_t.ndim > 1:
          sigma_t = sigma_t.squeeze(-1)
        if sigma_s.ndim > 1:
          sigma_s = sigma_s.squeeze(-1)
        assert sigma_t.ndim == 1, sigma_t.shape
        assert sigma_s.ndim == 1, sigma_s.shape
        move_chance_t = 1 - torch.exp(-sigma_t) # (1-eps)*t
        move_chance_s = 1 - torch.exp(-sigma_s)
        move_chance_t = move_chance_t[:, None] # [bsz, 1]
        move_chance_s = move_chance_s[:, None]
        unet_conditioning = sigma_t # [bsz]
        multiplier = (move_chance_t - move_chance_s)/move_chance_t # [bsz, 1]
        xt = self.q_xt(x0, move_chance_t) # [bsz, seq_len]
        # log prob, already apply subs parametrization (unmasked token remains unchanged)
        model_output = self.forward(xt, unet_conditioning) # [bsz, seq_len, vocab_size]
        # take the log prob of the token that corresponds to x0
        log_p_x0 = model_output.gather(-1, x0[..., None]).squeeze(-1) # [bsz, seq_len]
        log_p_x0 = log_p_x0 * multiplier
        log_p_at_time_list.append(log_p_x0)
      log_p_x0 = torch.stack(log_p_at_time_list, dim=0).sum(dim=0) # [bsz, seq_len]
      log_p_sample_list.append(log_p_x0.sum(dim=-1))
    log_p_sample = torch.stack(log_p_sample_list, dim=0).mean(dim=0)
    return log_p_sample

  def get_score(self, x, sigma):
    model_output = self.forward(x, sigma)
    if self.parameterization == 'subs':
      # score(x, t) = p_t(y) / p_t(x)
      # => log score(x, t) = log p_t(y) - log p_t(x)
      
      # case 1: x = masked
      #   (i) y = unmasked
      #     log score(x, t) = log p_\theta(x)|_y + log k
      #     where k = exp(- sigma) / (1 - exp(- sigma))
      #   (ii) y = masked
      #     log score(x, t) = 0

      # case 2: x = unmasked
      #   (i) y != masked, y != x
      #     log score(x_i, t) = - inf
      #   (ii) y = x 
      #     log score(x_i, t) = 0
      #   (iii) y = masked token
      #     log score(x_i, t) = - log k
      #     where k = exp(- sigma) / (1 - exp(- sigma))
      
      log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
      assert log_k.ndim == 1
      
      masked_score = model_output + log_k[:, None, None]
      masked_score[:, :, self.mask_index] = 0

      unmasked_score = self.neg_infinity * torch.ones_like(
        model_output)
      unmasked_score = torch.scatter(
        unmasked_score,
        -1,
        x[..., None],
        torch.zeros_like(unmasked_score[..., :1]))
      unmasked_score[:, :, self.mask_index] = - (
        log_k[:, None] * torch.ones_like(x))
      
      masked_indices = (x == self.mask_index).to(
        model_output.dtype)[:, :, None]
      model_output = (
        masked_score * masked_indices
        + unmasked_score * (1 - masked_indices))
    return model_output.exp()

  def _staggered_score(self, score, dsigma):
    score = score.clone()
    extra_const = (1 - dsigma.exp()) * score.sum(dim=-1)
    score *= dsigma.exp()[:, None]
    score[..., self.mask_index] += extra_const
    return score

  def _analytic_update(self, x, t, step_size):
    curr_sigma, _ = self.noise(t)
    next_sigma, _ = self.noise(t - step_size)
    dsigma = curr_sigma - next_sigma
    score = self.get_score(x, curr_sigma)
    stag_score = self._staggered_score(score, dsigma)
    probs = stag_score * self._transp_transition(x, dsigma)
    return _sample_categorical(probs)

  def _denoiser_update(self, x, t):
    sigma, _ = self.noise(t)
    score = self.get_score(x, sigma)
    stag_score = self._staggered_score(score, sigma)
    probs = stag_score * self._transp_transition(x, sigma)
    probs[..., self.mask_index] = 0
    samples = _sample_categorical(probs)
    return samples

  def _transp_transition(self, i, sigma):
    sigma = _unsqueeze(sigma, reference=i[..., None])
    edge = torch.exp(-sigma) * F.one_hot(
      i, num_classes=self.vocab_size)
    edge += torch.where(i == self.mask_index,
                        1 - torch.exp(-sigma).squeeze(-1),
                        0)[..., None]
    return edge

  def _sample_t(self, n, device):
    _eps_t = torch.rand(n, device=device)
    if self.antithetic_sampling:
      # for variance reduction
      offset = torch.arange(n, device=device) / n
      _eps_t = (_eps_t / n + offset) % 1
    t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
    if self.importance_sampling:
      return self.noise.importance_sampling_transformation(t)
    return t

  def _maybe_sub_sample(self, x0, attention_mask):
    seqlen = x0.shape[1]
    if seqlen > self.config.model.length:
      raise NotImplementedError('Sub-sampling not implemented')
    elif self.parameterization == 'ar':
      input_tokens = x0[:, :-1]
      output_tokens = x0[:, 1:]
      new_attention_mask = attention_mask[:, 1:]
    else:
      input_tokens = x0
      output_tokens = None
      new_attention_mask = attention_mask
    return input_tokens, output_tokens, new_attention_mask

  def _reconstruction_loss(self, x0):
    t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
                     device=self.device)
    assert self.config.noise.type == 'loglinear'
    # The above assert is for d3pm parameterization
    unet_conditioning = self.noise(t0)[0][:, None]
    model_output_t0 = self.forward(x0, unet_conditioning)
    return - torch.gather(input=model_output_t0,
                          dim=-1,
                          index=x0[:, :, None]).squeeze(-1)

  def _forward_pass_diffusion(self, x0):
    t = self._sample_t(x0.shape[0], x0.device)
    if self.T > 0:
      # else ts are between 0 and 1
      t = (t * self.T).to(torch.int)
      t = t / self.T
      # t \in {1/T, 2/T, ..., 1}
      t += (1 / self.T)

    if self.change_of_variables: # False
      unet_conditioning = t[:, None]
      f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
      f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
      move_chance = torch.exp(f_0 + t * (f_T - f_0))
      move_chance = move_chance[:, None]
    else:
      sigma, dsigma = self.noise(t) # total noise, rate noise
      unet_conditioning = sigma[:, None]
      move_chance = 1 - torch.exp(-sigma[:, None])

    xt = self.q_xt(x0, move_chance) # q(xt|x0)
    model_output = self.forward(xt, unet_conditioning)
    # print(model_output.shape)
    utils.print_nans(model_output, 'model_output')

    if self.parameterization == 'sedd':
      return dsigma[:, None] * self._score_entropy(
        model_output, sigma[:, None], xt, x0)
    
    if self.T > 0:
      diffusion_loss = self._d3pm_loss(
        model_output=model_output, xt=xt, x0=x0, t=t)
      if self.parameterization == 'd3pm':
        reconstruction_loss = self._reconstruction_loss(x0)
      elif self.parameterization == 'subs':
        reconstruction_loss = 0
      return reconstruction_loss + diffusion_loss
    
    # SUBS parameterization, continuous time.
    log_p_theta = torch.gather(
      input=model_output,
      dim=-1,
      index=x0[:, :, None]).squeeze(-1)
    
    # print(f'log_p_theta shape: {log_p_theta.shape}, avg: {log_p_theta.mean()}')
    
    if self.change_of_variables or self.importance_sampling:
      return log_p_theta * torch.log1p(
        - torch.exp(- self.noise.sigma_min))
    
    return - log_p_theta * (
      dsigma / torch.expm1(sigma))[:, None]

  def _loss(self, x0, attention_mask):
    (input_tokens, output_tokens,
     attention_mask) = self._maybe_sub_sample(
       x0, attention_mask)

    if self.parameterization == 'ar':
      logprobs = self.backbone(input_tokens, None)
      loss = - logprobs.gather(
        -1, output_tokens[:, :, None])[:, :, 0]
    else:
      loss = self._forward_pass_diffusion(input_tokens)
      
    # print(f'D_KL: {-loss.mean().item()}')
    
    nlls = loss * attention_mask
    count = attention_mask.sum()

    batch_nll = nlls.sum()
    token_nll = batch_nll / count

    return Loss(loss=token_nll,
                nlls=nlls,
                token_mask=attention_mask)

  def _score_entropy(self, log_score, sigma, xt, x0):
    """Computes the SEDD loss.

    Args:
      log_score: float torch.Tensor with shape (batch_size,
          diffusion_model_input_length, vocab_size),
          log score, output of the denoising network.
      xt: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input.
      x0: int torch.Tensor with shape (batch_size,
          diffusion_model_input_length), input.
      sigma: float torch.Tensor with shape (batch_size, 1).

    Returns:
      loss with shape (batch_size, diffusion_model_input_length)
    """
    # seems that it takes y=x0,xt=M case
    # what is the const term for, seems to be y=M,xt=x0 case and x0 is known so score estimation is precise
    masked_indices = xt == self.mask_index

    expsig_minus_1 = torch.expm1(sigma).expand_as(xt)
    q_ratio = 1 / expsig_minus_1[masked_indices]

    words_that_were_masked = x0[masked_indices]

    neg_term = q_ratio * torch.gather(
      log_score[masked_indices],
      -1,
      words_that_were_masked[..., None]).squeeze(-1)
    score = log_score[masked_indices].exp()
    if self.mask_index == self.vocab_size - 1:
      pos_term = score[:, :-1].sum(dim=-1)
    else:
      pos_term = score[:, : self.mask_index].sum(
        dim=-1) + score[:, self.mask_index + 1:].sum(dim=-1)
    const = q_ratio * (q_ratio.log() - 1)

    entropy = torch.zeros(* xt.shape, device=xt.device)
    entropy[masked_indices] += pos_term - neg_term + const
    return entropy
