import os
import collections
import copy
import pickle
import sys

import fsspec
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L

import trainer_base
import utils

import math
import scipy.stats as stats
from lightning.pytorch.callbacks.progress import TQDMProgressBar
import torch.utils.checkpoint as checkpoint

import torch.distributed as dist

# MDLM is adapted from Sahoo, Subham Sekhar, et al. "The diffusion duality." arXiv preprint arXiv:2506.10892 (2025).
class MDLM(trainer_base.AbsorbingState):
  """Masked (absorbing-state) diffusion language model.

  Adds to ``trainer_base.AbsorbingState``: post-processing of the backbone
  output into log-probabilities, a per-token loss term, a score function,
  checkpoint hooks that drop ``teacher*`` weights, and periodic printing of
  the training loss.
  """

  def __init__(self, config, tokenizer):
    """Build the model and read logging options from ``config.algo``.

    Args:
      config: Run configuration. ``config.mode`` and ``config.algo`` are
        read here; in ``sample_eval`` mode ``config.eval.checkpoint_path``
        is read as well.
      tokenizer: Forwarded unchanged to the parent constructor.
    """
    super().__init__(config, tokenizer)
    self._validate_configuration()
    if self.config.mode == "sample_eval":
      # Stored so that it can be shown in log output later.
      self._loaded_checkpoint_path = self.config.eval.checkpoint_path
    self.print_every = getattr(config.algo, 'print_every', 10)
    self.enable_progress_bar = getattr(config.algo, 'enable_progress_bar', True)
    # Last global step for which the loss was printed (-1: none yet).
    self.current_print = -1

  def on_train_start(self):
    """Optionally replace the progress bar, then save an initial checkpoint.

    The checkpoint is written as ``student_before_train.ckpt`` into the
    checkpoint callback's ``dirpath`` when that is set, otherwise into
    ``trainer.default_root_dir``.
    """
    if not self.enable_progress_bar:
      for i, callback in enumerate(self.trainer.callbacks):
        if isinstance(callback, TQDMProgressBar):
          self.trainer.callbacks[i] = DisabledProgressBar()
          break

    save_dir = self.trainer.default_root_dir
    ckpt_callback = self.trainer.checkpoint_callback
    # The attribute can exist and still be None; joining a path onto None
    # would raise TypeError, so fall back to the trainer's root directory.
    callback_dir = getattr(ckpt_callback, 'dirpath', None) if ckpt_callback else None
    if callback_dir:
      save_dir = callback_dir
    self.print(f"----------------------------------------------------------------------------------------------------\n"
               f"  --> Checkpoint save to: \n"
               f"  --> \"{save_dir}\"\n"
               f"----------------------------------------------------------------------------------------------------\n")
    filepath = os.path.join(save_dir, "student_before_train.ckpt")
    os.makedirs(save_dir, exist_ok=True)
    self.trainer.save_checkpoint(filepath)

  def _validate_configuration(self):
    """Hook for configuration checks; this class performs none."""
    pass

  def _process_model_output(self, model_output, xt, sigma):
    """Convert raw backbone output into per-position log-probabilities.

    The mask-token entry is pushed down by ``self.neg_infinity`` and the
    result is normalised with log-sum-exp over the last dimension. Every
    position of ``xt`` that is not the mask token is then overwritten so that
    the token already present gets log-probability 0 and all others get
    ``self.neg_infinity``.

    Note: the first statement edits ``model_output`` in place, so the
    caller's tensor has its mask-token column modified.

    Args:
      model_output: Backbone output indexed as ``[batch, position, token]``.
      xt: Current token ids; compared against ``self.mask_index``.
      sigma: Unused.

    Returns:
      The processed log-probability tensor.
    """
    del sigma
    model_output[:, :, self.mask_index] += self.neg_infinity

    model_output = model_output - torch.logsumexp(
      model_output, dim=-1, keepdim=True)
    unmasked_indices = (xt != self.mask_index)
    model_output[unmasked_indices] = self.neg_infinity
    model_output[unmasked_indices, xt[unmasked_indices]] = 0
    return model_output

  def nll_per_token(self, log_x_theta, xt, x0, alpha_t,
                    dalpha_t, low_var=False):
    """Return the per-token loss term.

    The log-probability assigned to each target token in ``x0`` is gathered
    from ``log_x_theta`` and scaled by ``dalpha_t / (1 - alpha_t)``.

    Args:
      log_x_theta: Log-probabilities indexed as ``[batch, position, token]``.
      xt: Unused.
      x0: Target token ids used as gather indices.
      alpha_t: Schedule value; ``1 - alpha_t`` is the denominator.
      dalpha_t: Multiplied into the gathered log-probabilities.
      low_var: Accepted but not used by this implementation.

    Returns:
      ``log p_theta(x0) * dalpha_t / (1 - alpha_t)`` for every token.
    """
    del xt
    log_p_theta = torch.gather(
      input=log_x_theta,
      dim=-1,
      index=x0[:, :, None]).squeeze(-1)
    return log_p_theta * dalpha_t / (1 - alpha_t)

  def _get_score(self, x, sigma):
    """Return the exponentiated score for the tokens ``x`` at noise ``sigma``.

    With ``log_k = -log(exp(sigma) - 1)``:
      * masked positions use ``forward(x, sigma) + log_k``, with the
        mask-token entry set to 0;
      * unmasked positions use ``self.neg_infinity`` everywhere except 0 at
        the observed token and ``-log_k`` at the mask token.

    The two cases are blended with the mask indicator and exponentiated.

    Args:
      x: Token ids, compared against ``self.mask_index``.
      sigma: Noise level; after ``squeeze(-1)`` it must be one-dimensional.

    Returns:
      ``exp`` of the blended log-score tensor.
    """
    model_output = self.forward(x, sigma)

    log_k = - torch.log(torch.expm1(sigma)).squeeze(-1)
    assert log_k.ndim == 1

    masked_score = model_output + log_k[:, None, None]
    masked_score[:, :, self.mask_index] = 0

    unmasked_score = self.neg_infinity * torch.ones_like(
      model_output)
    unmasked_score = torch.scatter(
      unmasked_score,
      -1,
      x[..., None],
      torch.zeros_like(unmasked_score[..., :1]))
    unmasked_score[:, :, self.mask_index] = - (
      log_k[:, None] * torch.ones_like(x))

    # 1.0 where the input token is the mask token, 0.0 elsewhere.
    masked_indices = (x == self.mask_index).to(
      model_output.dtype)[:, :, None]
    model_output = (
      masked_score * masked_indices
      + unmasked_score * (1 - masked_indices))
    return model_output.exp()

  @staticmethod
  def _without_teacher_weights(state_dict):
    """Return an ordered copy of ``state_dict`` without ``teacher*`` keys."""
    return collections.OrderedDict(
      (k, v) for k, v in state_dict.items()
      if not k.startswith('teacher'))

  def on_save_checkpoint(self, checkpoint):
    """Drop ``teacher*`` entries from the state dict before saving."""
    checkpoint['state_dict'] = self._without_teacher_weights(
      checkpoint['state_dict'])
    super().on_save_checkpoint(checkpoint)

  def on_load_checkpoint(self, checkpoint):
    """Drop ``teacher*`` entries from the state dict before loading."""
    checkpoint['state_dict'] = self._without_teacher_weights(
      checkpoint['state_dict'])
    super().on_load_checkpoint(checkpoint)

  def training_step(self, batch, batch_idx):
    """Run the parent training step and print the loss periodically.

    The loss is printed at most once per global step, every
    ``self.print_every`` steps, and never at step 0. A falsy
    ``print_every`` disables printing.

    Returns:
      The loss returned by the parent ``training_step``.
    """
    loss = super().training_step(batch, batch_idx)
    step = self.global_step
    # `print_every` is tested first so that 0 never reaches the modulo.
    should_print = (
      self.print_every
      and step % self.print_every == 0
      and step > 0
      and self.current_print != step)
    if should_print:
      self.current_print = step
      self.print(f"Global Step: {self.global_step}, Train Loss: {loss.item():.4f}")
    return loss
  
class DiffInstruct(MDLM):
  """
  Implements DiDi-Instruct to train a few-step generator via distillation.
  """
  class DiscriminatorHead(nn.Module):
    """Small MLP head producing one scalar per input feature vector.

    The stack is spectral-norm Linear -> SiLU -> spectral-norm Linear and is
    stored as ``self.main``.
    """

    def __init__(self, hidden_size):
      """Create the head for features of width ``hidden_size``."""
      super().__init__()
      with_spectral_norm = torch.nn.utils.spectral_norm
      # Layers are created one after another, in stack order, so random
      # initialisation is drawn in the same sequence as the stack itself.
      hidden_proj = with_spectral_norm(nn.Linear(hidden_size, hidden_size))
      activation = nn.SiLU()
      score_proj = with_spectral_norm(nn.Linear(hidden_size, 1))
      self.main = nn.Sequential(hidden_proj, activation, score_proj)

    def forward(self, x, c=None):
      """Apply the head to ``x``; ``c`` is accepted but ignored."""
      del c
      return self.main(x)
    
  def __init__(self, config, tokenizer):
    """Set up the student, its frozen companions and all hyperparameters.

    Three deep copies of ``self.backbone`` are registered, each with every
    parameter frozen: ``teacher`` (eval mode), ``discriminator`` (mode left
    as copied) and ``student_ema`` (eval mode). Optional settings are read
    from ``config.algo`` with the defaults shown below;
    ``config.algo.discriminator_optim``, ``config.algo.lr_scheduler`` and
    ``config.algo.save_after_n_steps`` are required.

    Args:
      config: Run configuration (``config.optim`` and ``config.algo`` are
        read here).
      tokenizer: Forwarded unchanged to the parent constructor.
    """
    super().__init__(config, tokenizer)
    # Optimizers are driven manually; see configure_optimizers.
    self.automatic_optimization = False

    def frozen_backbone_copy(*, switch_to_eval):
      """Deep-copy the backbone and disable gradients on the copy."""
      clone = copy.deepcopy(self.backbone)
      if switch_to_eval:
        clone.eval()
      for weight in clone.parameters():
        weight.requires_grad = False
      return clone

    # Registration order is kept: teacher, discriminator, student_ema.
    self.teacher = frozen_backbone_copy(switch_to_eval=True)
    self.discriminator = frozen_backbone_copy(switch_to_eval=False)
    # EMA student for stable generation during validation/testing
    self.student_ema = frozen_backbone_copy(switch_to_eval=True)

    algo = config.algo

    # --- Core Hyperparameters ---
    self.discriminator_lr = getattr(algo, 'discriminator_lr', 1e-5)
    # beta for Jeffrey Div. 0=FKL, 1=RKL
    self.jeffrey_beta = getattr(algo, 'jeffrey_beta', 0.5)
    # For KL regularization
    self.kl_beta = getattr(algo, 'kl_beta', 0.05)
    self.entropy_beta = getattr(algo, 'entropy_beta', 0.02)
    # Forward KL, Backward KL, Jeffrey
    self.regularization_type = getattr(algo, 'regularization_type', 'Jeffrey')

    # --- Guided Sampling Hyperparameters ---
    self.num_candidates = getattr(algo, 'num_candidates', 1)
    self.guidance_scale_start = getattr(algo, 'guidance_scale_start', 0.2)
    self.guidance_scale_end = getattr(algo, 'guidance_scale_end', 1.0)
    self.rerank_steps_ratio = getattr(algo, 'rerank_steps_ratio', 0.5)

    # --- Logging & Housekeeping ---
    self.print_every = getattr(algo, 'print_every', 10)
    self.current_print = -1
    self.optim_config_student = config.optim
    self.optim_config_discriminator = algo.discriminator_optim
    self.lr_scheduler_config = algo.lr_scheduler
    self.save_hyperparameters()
    self.save_model_every = algo.save_after_n_steps
    self.last_growth_step = -1
    self.out_dir = getattr(algo, 'output_dir', None)
    # Step counter kept on the module itself.
    self.global_steps = 0
    self.ema_beta = getattr(algo, 'ema_beta', 0.999)
    self.num_data = None

    self.register_buffer('reward_baseline', torch.tensor(0.0))

  @torch.no_grad()
  def update_ema(self):
    """Move every ``student_ema`` parameter towards the current backbone.

    Each averaged parameter is interpolated in place towards its live
    counterpart with weight ``1 - self.ema_beta``, i.e.
    ``ema <- ema_beta * ema + (1 - ema_beta) * param``. Only parameters are
    visited (buffers are not).
    """
    step_weight = 1.0 - self.ema_beta
    paired = zip(self.student_ema.parameters(), self.backbone.parameters())
    for averaged, live in paired:
      averaged.data.lerp_(live.data, step_weight)

  def on_train_start(self):
    """Replace the first TQDM progress bar callback when bars are disabled.

    Does nothing when ``self.enable_progress_bar`` is truthy or when the
    trainer has no ``TQDMProgressBar`` callback.
    """
    if self.enable_progress_bar:
      return
    callbacks = self.trainer.callbacks
    first_tqdm = next(
      (pos for pos, cb in enumerate(callbacks)
       if isinstance(cb, TQDMProgressBar)),
      None)
    if first_tqdm is not None:
      callbacks[first_tqdm] = DisabledProgressBar()

  def to(self, *args, **kwargs):
    """Move the module and its three backbone copies; return ``self``.

    The arguments are passed unchanged first to the parent ``to`` and then
    to ``teacher``, ``discriminator`` and ``student_ema`` in that order.
    """
    super().to(*args, **kwargs)
    for companion in (self.teacher, self.discriminator, self.student_ema):
      companion.to(*args, **kwargs)
    return self

  def _eval_mode(self):
    """Switch the model to evaluation mode by calling ``self.eval()``."""
    self.eval()

  def on_save_checkpoint(self, checkpoint):
    """Remove ``teacher*`` weights from the checkpoint before it is saved.

    ``checkpoint['state_dict']`` is replaced by an ``OrderedDict`` holding
    the remaining entries in their original order. The hook of the class
    after ``MDLM`` in the MRO is then invoked (``MDLM``'s own hook is
    skipped).
    """
    kept = collections.OrderedDict()
    for key, value in checkpoint['state_dict'].items():
      if key.startswith('teacher'):
        continue
      kept[key] = value
    checkpoint['state_dict'] = kept
    super(MDLM, self).on_save_checkpoint(checkpoint)

  def on_load_checkpoint(self, checkpoint):
    """Remove ``teacher*`` weights from the checkpoint before it is loaded.

    ``checkpoint['state_dict']`` is replaced by an ``OrderedDict`` holding
    the remaining entries in their original order. The hook of the class
    after ``MDLM`` in the MRO is then invoked (``MDLM``'s own hook is
    skipped).
    """
    kept = collections.OrderedDict()
    for key, value in checkpoint['state_dict'].items():
      if key.startswith('teacher'):
        continue
      kept[key] = value
    checkpoint['state_dict'] = kept
    super(MDLM, self).on_load_checkpoint(checkpoint)

  def load_state_dict(self, state_dict, strict=False, **kwargs):
    """Load weights, defaulting to non-strict key matching.

    Checkpoints written by this class have their ``teacher*`` entries
    removed, so a strict load would report those keys as missing; the
    default is therefore ``strict=False``.

    Args:
      state_dict: Mapping of parameter/buffer names to tensors.
      strict: Whether keys must match exactly. Defaults to ``False``.
      **kwargs: Any further keyword arguments supported by the parent
        implementation (for example ``assign`` in recent PyTorch releases);
        forwarded unchanged.

    Returns:
      Whatever the parent ``load_state_dict`` returns.
    """
    return super().load_state_dict(state_dict, strict=strict, **kwargs)
    
  def _get_lr(self):
    """Look up the quantities a learning-rate schedule would need.

    NOTE(review): only the lookups remain in this file; nothing is computed
    from them and the method returns ``None``. The trailing comment suggests
    it was meant to return the two optimizers -- confirm the intended
    behaviour before relying on this method.
    """
    step = self.global_steps  # step counter stored on the module
    opt_student, opt_discriminator = self.optimizers()  # unpacks exactly two optimizers
    warmup_steps = self.lr_scheduler_config.warmup_steps
    total_steps = self.trainer.max_steps
    # pass
    # return opt_student, opt_discriminator
  
  def _get_tau(self, batch_size):
    """Sample intermediate timesteps ``tau`` with their schedule values.

    ``self.tau_mode`` selects the sampling distribution:

      * starts with ``"beta"``: ``Beta(a, b)`` where ``a`` and ``b`` are the
        single characters at positions 4 and 5 of the mode string
        (``"beta"`` -> Beta(2, 2), ``"beta3"`` -> Beta(3, 3),
        ``"beta25"`` -> Beta(2, 5));
      * anything else: uniform on ``[0, 1)`` with density 1.

    Args:
      batch_size: Number of timesteps to draw.

    Returns:
      Tuple ``(tau, alpha_tau, sigma_tau, pi_tau)``: the sampled times, the
      second output of ``self.noise(tau)``, the matching sigma from
      ``self._sigma_from_alphat``, and the sampling density at each ``tau``.

    Raises:
      ValueError: If a character following ``"beta"`` is not a digit.
    """
    # Drawn unconditionally so the random stream does not depend on the mode.
    uniform_draw = torch.rand(batch_size, device=self.device)
    # Uniform default; without it `tau` would be unbound outside beta mode.
    tau = uniform_draw
    pi_tau = torch.ones_like(uniform_draw)

    if self.tau_mode.startswith("beta"):
      a = float(self.tau_mode[4]) if len(self.tau_mode) > 4 else 2.0
      b = float(self.tau_mode[5]) if len(self.tau_mode) > 5 else a
      beta_dist = torch.distributions.Beta(
        torch.tensor(a, device=self.device),
        torch.tensor(b, device=self.device))
      tau = beta_dist.sample((batch_size,))
      pi_tau = beta_dist.log_prob(tau).exp()

    # Get alpha and sigma from the noise schedule.
    _, alpha_tau = self.noise(tau)
    sigma_tau = self._sigma_from_alphat(alpha_tau.unsqueeze(-1)).squeeze(-1)

    return tau, alpha_tau, sigma_tau, pi_tau
  
  def _corrupt_with_mask_ratio(self, x0, alpha_t):
    """Mask each token of ``x0`` independently with probability ``1 - alpha_t``.

    Args:
      x0: Clean token ids; not modified.
      alpha_t: Keep probability; a trailing axis is added so that it
        broadcasts against ``x0``.

    Returns:
      Tuple ``(zt, mask)``: a copy of ``x0`` with the selected positions set
      to ``self.mask_index``, and the boolean selection itself.
    """
    uniform = torch.rand_like(x0, dtype=torch.float32)
    mask = uniform < (1.0 - alpha_t.unsqueeze(-1))
    zt = x0.masked_fill(mask, self.mask_index)
    return zt, mask

  def _get_masked_input(self, batch_size):
    """Return a ``(batch_size, self.num_tokens)`` long tensor of mask tokens."""
    return torch.full(
      size=(batch_size, self.num_tokens),
      fill_value=self.mask_index,
      dtype=torch.long,
      device=self.device,
    )
  
  def _get_corrupted_inputs(self, batch_size, tau_gen=None, alpha_gen=None, sigma_gen=None):
    """Placeholder: the body is ``pass`` and the method returns ``None``.

    NOTE(review): the trailing comment suggests the intended return value is
    ``(t_corr, alpha_corr, sigma_corr)`` -- confirm against the full
    implementation.
    """
    pass
    # return t_corr, alpha_corr, sigma_corr
  
  @torch.no_grad()
  def _get_rewards(self, zt_fake, sigma_corr, mask_fake, original_batch_size, G):
    """
    Calculates the reward signal from the discriminator.

    NOTE(review): placeholder -- the body is ``pass`` and the method returns
    ``None``. The trailing comment lists the values it is presumably meant
    to return; confirm against the full implementation.
    """
    pass
    # return final_reward, reward_mean, reward_std, reward_fraction_clipped, advantage

  def _get_backward(self, x_prev, t_hi, t_lo):
    """
    Performs a single reverse step.

    NOTE(review): placeholder -- the body is ``pass`` and the method returns
    ``None``. The trailing comment lists the values it is presumably meant
    to return; confirm against the full implementation.
    """
    pass
    # return x_next, logp_step, logits, prev_was_mask
  
  def _get_regularization(self, batch_size, x_T, x_tau, sigma_gen, logits_step1, mask_step1, 
                          logits_step2, mask_step2):
    """Compute the teacher-KL penalty and the entropy of the student logits.

    The frozen teacher is evaluated (without gradients) on ``x_T`` at the
    sigma corresponding to ``t = 1`` and on ``x_tau`` at ``sigma_gen``. For
    each step the KL divergence from the teacher's softmax to the student's
    softmax is summed over the last axis and averaged using the step's mask
    as weights. The two averages are combined as
    ``kl_hi_coef * step1 + kl_lo_coef * step2``. The entropy of each step's
    student distribution is averaged the same way and the two steps are
    then averaged.

    Args:
      batch_size: Number of ``t = 1`` entries used to build the high sigma.
      x_T: Teacher input for step 1.
      x_tau: Teacher input for step 2.
      sigma_gen: Sigma passed to the teacher for step 2.
      logits_step1: Student logits for step 1.
      mask_step1: Multiplicative weights for step 1.
      logits_step2: Student logits for step 2.
      mask_step2: Multiplicative weights for step 2.

    Returns:
      Tuple ``(loss_kl, entropy)``.
    """
    def masked_mean(values, weights):
      # The small constant keeps an all-zero mask from dividing by zero.
      return (values * weights).sum() / (weights.sum() + 1e-8)

    def kl_from_teacher(student_logits, teacher_logits):
      pointwise = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction='none',
        log_target=False)
      return pointwise.sum(dim=-1)

    def token_entropy(logits):
      return torch.distributions.Categorical(logits=logits).entropy()

    with torch.no_grad():
      t_max = torch.full((batch_size,), 1.0, device=self.device)
      alpha_hi = self.noise(t_max)[1]
      sigma_hi = self._sigma_from_alphat(alpha_hi.unsqueeze(-1)).squeeze(-1)
      teacher_hi = self.teacher(x_T, sigma=sigma_hi)
      teacher_lo = self.teacher(x_tau, sigma=sigma_gen)

    kl_hi = kl_from_teacher(logits_step1, teacher_hi)
    kl_lo = kl_from_teacher(logits_step2, teacher_lo)
    loss_kl_hi = masked_mean(kl_hi, mask_step1)
    loss_kl_lo = masked_mean(kl_lo, mask_step2)
    loss_kl = self.kl_hi_coef * loss_kl_hi + self.kl_lo_coef * loss_kl_lo

    # Entropy bonus for both steps
    entropy_hi = masked_mean(token_entropy(logits_step1), mask_step1)
    entropy_lo = masked_mean(token_entropy(logits_step2), mask_step2)
    entropy = (entropy_hi + entropy_lo) / 2

    return loss_kl, entropy

  
  @torch.no_grad()
  def _get_ancestral_update(self, x, t, dt, p_x0=None, noise_removal_step=False):
    """
    Ancestral update step (placeholder).

    NOTE(review): the body is ``pass`` and the method returns ``None``. The
    original docstring said this overrides the parent class's ancestral
    update, but the sampling loop in this class calls
    ``_ancestral_update``; whether a parent method with this exact name
    exists cannot be seen here -- confirm. The trailing comment lists the
    values it is presumably meant to return.
    """
    pass
    # return p_x0_for_grad, copy_flag * x + (1 - copy_flag) * _x

  def _get_gradient_tilting_step(self, x, t, dt, guidance_scale, discriminator):
    """
    Performs a guided ancestral step using Gradient-Based Logits Tilting.

    NOTE(review): placeholder -- the body is ``pass`` and the method returns
    ``None``. The trailing comment suggests already-unmasked tokens of ``x``
    are kept and the rest replaced by sampled tokens; confirm against the
    full implementation.
    """
    pass
    # return torch.where(x != self.mask_index, x, sampled_tokens)
  
  def _get_guided_ancestral_step(self, x, t, dt, guidance_scale, discriminator):
    """
    Performs a single unified guided ancestral sampling step.

    NOTE(review): placeholder -- the body is ``pass`` and the method returns
    ``None``. The trailing comment suggests it is meant to return the next
    state ``x_next``; confirm against the full implementation.
    """
    pass
    # return x_next
  
  @torch.no_grad()
  def generate_samples(self, num_samples, num_steps=None, eps=1e-5):
    """Generate token sequences by iterating the configured reverse sampler.

    Starting from ``self.prior_sample``, ``num_steps`` reverse steps are
    taken on a linear time grid from 1 down to ``eps``. ``self.sampler``
    selects the update rule:

      * contains ``'ancestral'``: the inherited ``_ancestral_update``;
      * ``'analytic'``: the inherited ``_analytic_update``;
      * ``'guided'``: discriminator-guided steps whose guidance scale is
        interpolated linearly from ``guidance_scale_start`` to
        ``guidance_scale_end``. The first
        ``int(num_steps * (1 - rerank_steps_ratio))`` steps use
        ``_get_gradient_tilting_step``; the remaining ones use
        ``_get_guided_ancestral_step``.

    A final clean-up pass follows ``config.sampling.noise_removal``
    (``'ancestral'`` or ``'greedy'``; any other value leaves the sample
    unchanged).

    The backbone and discriminator are put in eval mode for the duration of
    the call and returned to whatever mode they were in beforehand, even if
    sampling raises.

    NOTE(review): the whole method runs under ``torch.no_grad()``; a guided
    step that needs gradients must re-enable them itself -- confirm.

    Args:
      num_samples: Number of sequences to generate.
      num_steps: Number of reverse steps; defaults to
        ``config.sampling.steps``.
      eps: Final time of the grid.

    Returns:
      The sampled tensor after the clean-up pass.

    Raises:
      ValueError: If ``self.sampler`` is not one of the kinds listed above.
    """
    if self.config.mode == "sample_eval":
      print("\n========================================================")
      print("Generating samples from model:")
      if hasattr(self, '_loaded_checkpoint_path'):
        print(f"  --> {self._loaded_checkpoint_path}")
      print("========================================================")

    if num_steps is None:
      num_steps = self.config.sampling.steps

    # Sampling runs on `self.backbone` (the live student), not on
    # `self.student_ema`. Remember the current modes so they can be restored.
    backbone_was_training = self.backbone.training
    discriminator_was_training = self.discriminator.training
    self.backbone.eval()
    self.discriminator.eval()

    try:
      x = self.prior_sample(num_samples, self.num_tokens)
      timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
      dt = (1 - eps) / num_steps
      rerank_start_step = int(num_steps * (1 - self.rerank_steps_ratio))
      # With a single step there is nothing to interpolate over; clamping
      # the denominator keeps the scale at its start value instead of
      # dividing by zero.
      progress_denominator = max(num_steps - 1, 1)

      for i in range(num_steps):
        t = timesteps[i] * torch.ones(x.shape[0], 1, device=self.device)
        if self.config.mode == "sample_eval":
          print(f"[Eval] Backward step {i+1}/{num_steps}, t = {t.mean().item():.4f}")

        # --- DISPATCH TO SAMPLER ---
        if 'ancestral' in self.sampler:
          _, x = super(MDLM, self)._ancestral_update(x=x, t=t, dt=dt, p_x0=None)
        elif self.sampler == 'analytic':
          x = super(MDLM, self)._analytic_update(x=x, t=t, dt=dt)
        elif self.sampler == 'guided':
          progress = i / progress_denominator
          guid_scale = self.guidance_scale_start + progress * (
            self.guidance_scale_end - self.guidance_scale_start)
          if i < rerank_start_step:
            x = self._get_gradient_tilting_step(x, t, dt, guid_scale, self.discriminator)
          else:
            x = self._get_guided_ancestral_step(x, t, dt, guid_scale, self.discriminator)
        else:
          raise ValueError(f"Unknown sampler: {self.sampler}")

      t0 = timesteps[-1] * torch.ones(x.shape[0], 1, device=self.device)
      if self.config.sampling.noise_removal == 'ancestral':
        x = self._denoiser_update(x=x, t=t0)
      elif self.config.sampling.noise_removal == 'greedy':
        sigma = self._sigma_from_alphat(self.noise(t0)[1])
        model_output = self.forward(xt=x, sigma=sigma)
        x = self._process_model_output(model_output, xt=x, sigma=sigma).argmax(dim=-1)
      return x
    finally:
      # Restore the caller's modes rather than forcing train mode, so that
      # sampling during evaluation does not switch dropout back on.
      self.backbone.train(backbone_was_training)
      self.discriminator.train(discriminator_was_training)
  
  def training_step(self, batch, batch_idx):
      """Run one training step (only the outline is present in this file).

      NOTE(review): the body consists of section comments followed by
      ``pass``, so the method currently returns ``None`` and performs no
      optimisation. The section headings below describe the presumably
      intended stages; confirm against the full implementation.
      """

      # --- Repeat inputs for grouped sampling ---
      
      # --- STUDENT GENERATION & LOG PROB (BRANCH FOR 1-STEP vs 2-STEP) ---

      # --- Discriminator Update ---

      # --- Student Update ---
      
      # --- Logging ---

      # --- Save Checkpoints ---
      pass
  
  def configure_optimizers(self):
    """Create the AdamW optimizers for the student and the discriminator.

    The student optimizer covers every backbone parameter and takes its
    learning rate, weight decay and betas from ``self.optim_config_student``.
    The discriminator optimizer covers only discriminator parameters with
    ``requires_grad=True``; it uses ``self.discriminator_lr`` together with
    the betas and weight decay from ``self.optim_config_discriminator``.

    Returns:
      ``[optimizer_student, optimizer_discriminator]``.

    Raises:
      ValueError: If the discriminator has no trainable parameters.
    """
    student_cfg = self.optim_config_student
    optimizer_student = torch.optim.AdamW(
      self.backbone.parameters(),
      lr=student_cfg.lr,
      weight_decay=student_cfg.weight_decay,
      betas=(student_cfg.beta1, student_cfg.beta2),
    )

    # Materialised as a list so that emptiness can be checked up front.
    trainable_discriminator_params = [
      p for p in self.discriminator.parameters() if p.requires_grad
    ]
    if not trainable_discriminator_params:
      # AdamW would fail anyway, but with a message that does not say which
      # optimizer is affected or why.
      raise ValueError(
        "DiffInstruct.configure_optimizers: the discriminator has no "
        "parameters with requires_grad=True, so its optimizer cannot be "
        "built. Unfreeze or attach trainable discriminator parameters "
        "before training starts.")

    discriminator_cfg = self.optim_config_discriminator
    optimizer_discriminator = torch.optim.AdamW(
      trainable_discriminator_params,
      lr=self.discriminator_lr,
      betas=(discriminator_cfg.beta1, discriminator_cfg.beta2),
      weight_decay=discriminator_cfg.weight_decay,
    )
    return [optimizer_student, optimizer_discriminator]
  
