import os
import math
from typing import Optional, Tuple

import hydra
import lightning as L
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from main import _print_config

import dataloader
import diffusion as diffusion_mod
from models.remaskator import RemaskatorNet
from models.remaskator2 import Remaskator2Net




class RemaskatorModule(L.LightningModule):
  """Lightning module that trains Remaskator2Net against a frozen diffusion denoiser.

  Each step (training and validation share the same procedure):
    - Take clean tokens x0 from the batch ('input_ids').
    - Sample t according to `sampling.t_sampling` ('uniform' on [eps, 1] or a
      clamped constant); s is fixed to 0, so dt = t.
    - Compute x_t with `denoiser.q_xt` using move chance 1 - exp(-sigma_t).
    - Compute x_s with one `denoiser._ddpm_update` call, conditioned on the
      denoiser's text embedding of x0.
    - Label every position y=1 if x_s != x0 (should remask), else y=0; padding
      positions are forced to y=0.
    - Optimize BCE-with-logits over all positions (optionally class-balanced).
  """

  def __init__(self, config: omegaconf.DictConfig, tokenizer, denoiser: diffusion_mod.Diffusion):
    """Wrap the frozen denoiser and build the trainable network.

    Args:
      config: Run configuration. Reads `model.length`, `training.sampling_eps`,
        `sampling.t_sampling`, `sampling.t_const` and `optim.lr` here.
      tokenizer: Stored on the module; the padding id is read from the
        denoiser's own tokenizer instead.
      denoiser: Diffusion model; switched to eval mode and fully frozen.
    """
    super().__init__()
    self.save_hyperparameters(ignore=['tokenizer', 'denoiser'])
    self.config = config
    self.tokenizer = tokenizer
    self.denoiser = denoiser.eval()
    for p in self.denoiser.parameters():
      p.requires_grad = False

    # Common attributes from denoiser
    self.mask_index = int(self.denoiser.mask_index)
    self.vocab_size = int(self.denoiser.vocab_size)
    self.padding_index = int(self.denoiser.tokenizer.pad_token_id)
    self.seq_len = int(self.config.model.length)
    self.eps = float(self.config.training.sampling_eps)

    # t sampling options: 'uniform' or 'const'. If 'const', use sampling.t_const.
    # OmegaConf.select returns the default when the key is absent, so older
    # configs without these fields still work.
    self.t_sampling = omegaconf.OmegaConf.select(config, 'sampling.t_sampling', default='uniform')
    self.t_const = omegaconf.OmegaConf.select(config, 'sampling.t_const', default=1.0)

    self.net = Remaskator2Net(
      vocab_size=self.vocab_size,
      config=self.config,
      cond_dim=self.denoiser.cond_dim,
    )

    self.lr = float(self.config.optim.lr)

    # All metrics computed ad-hoc per step

  def train(self, mode: bool = True) -> 'RemaskatorModule':
    """Set train/eval mode on the module while keeping the denoiser in eval mode.

    `nn.Module.train` recurses into every registered child, which would
    otherwise put the frozen denoiser back into train mode (e.g. enabling
    dropout) whenever the module itself is switched to training.
    """
    super().train(mode)
    self.denoiser.eval()
    return self

  @torch.no_grad()
  def _sample_t_s(self, batch_size: int, device: torch.device) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Sample per-example times.

    Returns:
      (t, s, dt), each a float tensor of shape (batch_size,). s is all zeros
      and dt = t - s.

    Raises:
      ValueError: If `self.t_sampling` is neither 'uniform' nor 'const'.
    """
    if self.t_sampling == 'uniform':
      eps = torch.tensor(self.eps, device=device, dtype=torch.float32)
      t = (1 - eps) * torch.rand(batch_size, device=device) + eps
    elif self.t_sampling == 'const':
      # Clamp constant t to [eps, 1.0]
      t_value = max(float(self.eps), min(1.0, float(self.t_const)))
      t = torch.full((batch_size,), t_value, device=device, dtype=torch.float32)
    else:
      raise ValueError(f"Unknown t_sampling mode: {self.t_sampling}. Expected 'uniform' or 'const'.")

    # s is kept at 0; returned as a tensor so the result matches the annotation.
    s = torch.zeros_like(t)
    dt = t - s
    return t, s, dt

  @torch.no_grad()
  def _compute_xt_xs_and_labels(self, x0: torch.LongTensor, cond: Optional[torch.FloatTensor] = None) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.FloatTensor]:
    """Build the noised input, the denoiser's one-step output and the labels.

    Args:
      x0: Clean token ids; callers in this class pass a (batch, seq_len) tensor.
      cond: Optional precomputed `denoiser.indices_to_text_embeddings(x0)`.
        Computed here when omitted.

    Returns:
      (xt, xs, new_mask, target) where
        xt: output of `denoiser.q_xt`, with padding positions reset to the pad id.
        xs: output of `denoiser._ddpm_update`, with padding positions reset to
          the pad id.
        new_mask: True where xt was the mask token and xs is not (never at padding).
        target: 1.0 where xs != x0, else 0.0 (always 0.0 at padding).
    """
    device = x0.device
    batch_size = x0.shape[0]
    padding_mask = x0 == self.padding_index

    # Sample times
    t, _, dt = self._sample_t_s(batch_size, device)
    t_in = t.view(batch_size, 1)
    dt_in = dt.view(batch_size, 1)

    # Compute q(x_t | x0)
    sigma_t, _ = self.denoiser.noise(t)
    move_chance = 1.0 - torch.exp(-sigma_t)
    if move_chance.ndim == 1:
      move_chance = move_chance[:, None]
    xt = self.denoiser.q_xt(x0, move_chance)

    # Conditioning on the x0 embedding (reuse the caller's tensor when given)
    if cond is None:
      cond = self.denoiser.indices_to_text_embeddings(x0)

    # One-step DDPM update from t with step dt
    xs = self.denoiser._ddpm_update(xt, t_in, dt_in, condition=cond)

    # Identify newly generated tokens between xt and xs
    new_mask = (xt == self.mask_index) & (xs != self.mask_index)

    # Targets: 1 if incorrect (should remask), 0 if correct (keep)
    target = (xs != x0).to(torch.float32)

    # Padding never counts as a mistake and is shown to the network as padding.
    xt[padding_mask] = self.padding_index
    xs[padding_mask] = self.padding_index
    new_mask[padding_mask] = False
    target[padding_mask] = 0

    return xt, xs, new_mask.to(torch.bool), target

  def _compute_loss_and_metrics(self, logits: torch.FloatTensor, target: torch.FloatTensor, new_mask: torch.BoolTensor):
    """Compute the BCE loss and classification metrics over all tokens.

    Args:
      logits: Raw network outputs, same shape as `target`.
      target: 0/1 float labels (1 = denoiser mistake).
      new_mask: Currently unused; kept so the call signature stays stable.

    Returns:
      (loss, acc, mistake_ratio, precision, recall, f1). Metrics use a 0.5
      threshold on sigmoid(logits) and are all 0.0 when `logits` is empty.
    """
    weight = None
    if self.config.training.remaskator_reweighting:
      # Class-balanced weighting so that classes 0 and 1 contribute equally.
      with torch.no_grad():
        total_elems = torch.tensor(float(target.numel()), device=logits.device, dtype=logits.dtype)
        pos_count = target.sum()
        neg_count = total_elems - pos_count
        # Avoid division by zero; if a class is absent, its weight won't be used by where()
        pos_count = torch.clamp(pos_count, min=1.0)
        neg_count = torch.clamp(neg_count, min=1.0)
        # Set weights so total weight mass per class equals total_elems/2
        w_pos = (total_elems / (2.0 * pos_count)).to(dtype=logits.dtype)
        w_neg = (total_elems / (2.0 * neg_count)).to(dtype=logits.dtype)
        weight = torch.where(target > 0.5, w_pos, w_neg)
    loss = F.binary_cross_entropy_with_logits(logits, target, weight=weight, reduction='mean')

    with torch.no_grad():
      preds = (torch.sigmoid(logits) > 0.5).to(torch.int32)
      true = target.to(torch.int32)

      total = preds.numel()
      if total == 0:
        # Guard against 0/0 in the ratios below.
        acc = torch.tensor(0.0, device=logits.device)
        precision = torch.tensor(0.0, device=logits.device)
        recall = torch.tensor(0.0, device=logits.device)
        f1 = torch.tensor(0.0, device=logits.device)
        mistake_ratio = torch.tensor(0.0, device=logits.device)
      else:
        # Accuracy over all tokens
        correct = (preds == true).sum()
        acc = correct.float() / total

        # Precision, Recall, F1 over all tokens
        tp = ((preds == 1) & (true == 1)).sum().float()
        fp = ((preds == 1) & (true == 0)).sum().float()
        fn = ((preds == 0) & (true == 1)).sum().float()

        precision = tp / (tp + fp).clamp(min=1e-8)
        recall = tp / (tp + fn).clamp(min=1e-8)
        f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-8)

        # Ratio of denoiser mistakes among all tokens
        mistakes = (true == 1).sum().float()
        mistake_ratio = mistakes / total

    return loss, acc, mistake_ratio, precision, recall, f1

  def forward(self, x_tokens: torch.LongTensor, cond: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
    """Run the network on `x_tokens`; `cond` is dropped when the text embedder is disabled."""
    if not self.config.text_embedder.use_text_embedder:
      cond = None
    return self.net(x_tokens, cond)

  def _shared_step(self, batch, stage: str, on_step: bool) -> torch.FloatTensor:
    """Run one train/val step and log all metrics under the `<stage>/` prefix."""
    x0 = batch['input_ids'].to(self.device)  # (batch, seq_len)

    # Embed x0 once and reuse it for both the denoiser update and the network.
    cond = self.denoiser.indices_to_text_embeddings(x0)
    _, xs, new_mask, target = self._compute_xt_xs_and_labels(x0, cond=cond)

    logits = self.forward(xs, cond)
    loss, acc, mistake_ratio, precision, recall, f1 = self._compute_loss_and_metrics(logits, target, new_mask)

    metrics = {
      'loss': loss,
      'acc': acc,
      'denoiser_mistake_ratio': mistake_ratio,
      'precision': precision,
      'recall': recall,
      'f1': f1,
    }
    for name, value in metrics.items():
      self.log(f'{stage}/{name}', value, prog_bar=True, on_step=on_step, on_epoch=True, sync_dist=True)
    return loss

  def training_step(self, batch, batch_idx):
    """Compute and log the training loss/metrics (per step and per epoch)."""
    return self._shared_step(batch, stage='train', on_step=True)

  def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss/metrics (per epoch only)."""
    return self._shared_step(batch, stage='val', on_step=False)

  def configure_optimizers(self):
    """Build AdamW from `config.optim` and a per-step scheduler from `config.lr_scheduler`."""
    # All parameters are passed (frozen ones included) so the optimizer's
    # parameter layout is unchanged for existing checkpoints.
    optimizer = torch.optim.AdamW(
      self.parameters(),
      lr=self.lr,
      betas=(self.config.optim.beta1, self.config.optim.beta2),
      eps=self.config.optim.eps,
      weight_decay=self.config.optim.weight_decay,
    )
    scheduler = hydra.utils.instantiate(self.config.lr_scheduler, optimizer=optimizer)
    return [optimizer], [{'scheduler': scheduler, 'interval': 'step', 'name': 'trainer/lr'}]


def _load_denoiser(config: omegaconf.DictConfig, tokenizer) -> diffusion_mod.Diffusion:
  """Build the denoiser that the remaskator is trained against.

  When `config.backbone` contains 'hf', the Diffusion model is constructed
  directly from the config and moved to CUDA if available (CPU otherwise).
  In every other case it is restored from `config.eval.checkpoint_path`
  with non-strict state-dict loading.

  Args:
    config: Run configuration; reads `backbone` and `eval.checkpoint_path`.
    tokenizer: Tokenizer forwarded to the Diffusion constructor/loader.

  Returns:
    The Diffusion instance.
  """
  if 'hf' in config.backbone:
    # Fall back to CPU instead of crashing on hosts without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return diffusion_mod.Diffusion(config, tokenizer=tokenizer).to(device)
  return diffusion_mod.Diffusion.load_from_checkpoint(
    config.eval.checkpoint_path,
    tokenizer=tokenizer,
    config=config,
    strict=False,
  )


def _resolve_resume_ckpt_path(config: omegaconf.DictConfig) -> Optional[str]:
  """Return the checkpoint path to resume from, or None to start from scratch.

  The path is accepted when it is the literal 'last' or an existing file.
  """
  if not config.checkpointing.resume_from_ckpt:
    return None
  candidate_path = str(config.checkpointing.resume_ckpt_path)
  if candidate_path == 'last' or os.path.isfile(candidate_path):
    return candidate_path
  print(f"[remaskator] Resume requested but checkpoint not found at {candidate_path}. Starting from scratch.")
  return None


@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(config: omegaconf.DictConfig):
  """Train the remaskator: load the frozen denoiser, fit the module, save a final checkpoint.

  Raises:
    ValueError: If `eval.checkpoint_path` is not set.
  """
  _print_config(config, resolve=True, save_cfg=True)

  # Validate explicitly (asserts are stripped under `python -O`) and before
  # the dataloaders are built so a misconfiguration fails fast.
  if not config.eval.checkpoint_path:
    raise ValueError('Please set eval.checkpoint_path to the denoiser checkpoint.')

  # Tokenizer and dataloaders
  tokenizer = dataloader.get_tokenizer(config)
  train_loader, valid_loader = dataloader.get_dataloaders(config, tokenizer)

  # Denoiser (frozen)
  denoiser = _load_denoiser(config, tokenizer)
  denoiser.eval()
  for p in denoiser.parameters():
    p.requires_grad = False

  # Module
  module = RemaskatorModule(config=config, tokenizer=tokenizer, denoiser=denoiser)
  # Initialise the network's backbone from the denoiser's weights, then swap
  # the final layer via the network's own hook.
  missing_keys, unexpected_keys = module.net.dit.load_state_dict(denoiser.backbone.state_dict(), strict=False)
  print(f"Missing keys: {missing_keys}")
  print(f"Unexpected keys: {unexpected_keys}")
  module.net.change_final_layer()

  if config.sampling.freeze_backbone:
    module.net.freeze_backbone()

  # Determine resume checkpoint path (optional)
  resume_ckpt_path = _resolve_resume_ckpt_path(config)

  # Setup wandb logger. Tags are converted to a plain list so a null
  # `wandb.tags` does not crash and wandb never receives a ListConfig.
  wandb_logger = WandbLogger(
    project=config.wandb.project,
    name=f"remaskator_{config.wandb.name}",
    group=config.wandb.group,
    job_type="remaskator_training",
    tags=list(config.wandb.tags or []) + ["remaskator"],
    notes=f"Remaskator training - {config.wandb.notes}",
    id=f"remaskator_{config.wandb.id}",
    save_dir=config.checkpointing.save_dir,
    resume='allow' if resume_ckpt_path else None,
  )

  # Setup callbacks: LR monitor plus anything declared in the config
  callbacks = [LearningRateMonitor(logging_interval='step')]
  if hasattr(config, 'callbacks') and config.callbacks:
    callbacks.extend(hydra.utils.instantiate(cb) for cb in config.callbacks.values())

  # Trainer
  trainer: L.Trainer = hydra.utils.instantiate(config.trainer, logger=wandb_logger, callbacks=callbacks)

  # Fit (optionally resume from checkpoint)
  if resume_ckpt_path:
    print(f"[remaskator] Resuming training from checkpoint: {resume_ckpt_path}")
  trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=valid_loader, ckpt_path=resume_ckpt_path)

  # Save final checkpoint
  ckpt_dir = os.path.join(config.checkpointing.save_dir, 'checkpoints_remaskator')
  os.makedirs(ckpt_dir, exist_ok=True)
  trainer.save_checkpoint(os.path.join(ckpt_dir, 'last.ckpt'))


if __name__ == '__main__':
  main()

