"""
Based on: https://github.com/kuleshov-group/mdlm

"""
import itertools
import math
import os
import typing
from dataclasses import dataclass
from pathlib import Path

import hydra.utils
import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import transformers
import torchvision.models as vision_models

import dataloader
import models
import noise_schedule
import utils
import logging
import hashlib
from typing import Dict, Tuple, Iterable

from logging.handlers import RotatingFileHandler

from transformers import CLIPVisionModel, CLIPProcessor

import csv
import matplotlib.pyplot as plt
from PIL import Image as PILImage
import json
from datetime import datetime

import torchvision.transforms as transforms
from ImageEncoder import ImageEncoder
import diffusion_utils
import traceback


### ==== INITIAL CONFIGURATION ====

# Relative directory (resolved against the current working directory) that
# receives the rotating log file; created at import time if missing.
log_dir = 'logs'
os.makedirs(log_dir, exist_ok=True)

# Set up logger with a rotating file handler
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # Set to INFO to avoid too much verbosity

# Set up a rotating log file handler (max size of 1MB and 3 backups)
handler = RotatingFileHandler(os.path.join(log_dir, 'output.log'), maxBytes=10**6, backupCount=3)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# Attach the handler only if this logger does not already carry a
# RotatingFileHandler, so repeated execution of this module does not
# duplicate every log line.
if not any(isinstance(h, RotatingFileHandler) for h in logger.handlers):
    logger.addHandler(handler)

# Natural logarithm of 2; BPD.compute divides the mean NLL by this constant.
LOG2 = math.log(2)

# Repository root: the REPO_ROOT environment variable when set, otherwise the
# current working directory. parent_dir is an alias of the same path.
repo_root = Path(os.environ.get("REPO_ROOT", os.getcwd()))
parent_dir = repo_root
        
@dataclass
class Loss:
    """Bundle of values returned by ``Diffusion._loss``.

    As populated by ``Diffusion._loss`` in this file:

    * ``loss`` is the scalar mean loss per (unmasked) token over the batch.
    * ``nlls`` holds, per sample, the sum of attention-masked token losses.
    * ``token_mask`` holds, per sample, the attention-mask token count
      (clamped to a tiny positive minimum).
    """
    loss: torch.FloatTensor
    nlls: torch.FloatTensor  # Should be sum of NLLs per sample for metric
    token_mask: torch.FloatTensor  # Should be count of tokens per sample for metric


class NLL(torchmetrics.aggregation.MeanMetric):
    """Running mean of negative log-likelihood values.

    No behaviour is added on top of ``MeanMetric``; the class exists so the
    metric has its own name and can be subclassed by ``BPD`` and
    ``Perplexity``.
    """
    # NOTE(review): MeanMetric.update(value, weight) accumulates a *weighted*
    # mean, i.e. sum(value * weight) / sum(weight). Callers must therefore
    # pass per-token (or per-token-averaged) values together with their
    # weights; passing per-sample sums weighted by token counts does not
    # yield sum(NLL) / sum(token_count) -- confirm against the update() call.
    pass


class BPD(NLL):
    """Bits-per-dimension metric: the aggregated mean NLL divided by ln(2)."""

    def compute(self) -> torch.Tensor:
        """Return mean NLL / LOG2, or a NaN tensor when the mean is unusable.

        A NaN scalar is returned when the parent result is not a tensor, or
        is NaN/inf. The NaN lives on the parent result's device when that
        result is a tensor, otherwise on the CPU.
        """
        nats = super().compute()
        if not torch.is_tensor(nats):
            return torch.tensor(float('nan'), device=torch.device("cpu"))
        if torch.isnan(nats) or torch.isinf(nats):
            return torch.tensor(float('nan'), device=nats.device)
        return nats / LOG2


class Perplexity(NLL):
    """Perplexity metric: the exponential of the aggregated mean NLL."""

    def compute(self) -> torch.Tensor:
        """Return exp(mean NLL), or a NaN tensor when the mean is unusable.

        A NaN scalar is returned when the parent result is not a tensor, or
        is NaN/inf. The NaN lives on the parent result's device when that
        result is a tensor, otherwise on the CPU.
        """
        nats = super().compute()
        if not torch.is_tensor(nats):
            return torch.tensor(float('nan'), device=torch.device("cpu"))
        if torch.isnan(nats) or torch.isinf(nats):
            return torch.tensor(float('nan'), device=nats.device)
        return torch.exp(nats)

# NOTE(review): this repeats the logger setup from the initial configuration
# section earlier in the file. logging.getLogger(__name__) returns the same
# logger object, so these two lines only re-bind the name and re-apply the
# INFO level; they could be removed.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


### ==== DIFFUSION CLASS ====  

class Diffusion(L.LightningModule):
    def __init__(
        self,
        config,
        tokenizer: transformers.PreTrainedTokenizer
    ):
        """Build the diffusion module from ``config`` and ``tokenizer``.

        Sets up the optional image encoder, the DiT backbone (optionally
        initialised from pretrained weights), the noise schedule, EMA,
        metrics, and the on-disk directories used for generation tracking.

        Args:
            config: Experiment configuration. Attributes read here include
                ``parameterization``, ``sampling``, ``T``, ``training``,
                ``time_conditioning``, ``backbone``, ``model``, ``optim``
                and ``subs_masking``.
            tokenizer: Tokenizer whose ``mask_token_id`` must be set; its
                length is used as the vocabulary size.

        Raises:
            ValueError: If the tokenizer has no mask token, if
                ``config.backbone`` is not ``'dit'``, or if
                ``model.init_from_pretrained_backbone`` is enabled without a
                ``model.pretrained_backbone_path``.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['tokenizer'])
        self.config = config

        self.tokenizer = tokenizer

        self.parameterization = self.config.parameterization
        self.sampler = self.config.sampling.get("predictor", "analytic")
        self.T = self.config.T
        self.change_of_variables = self.config.training.change_of_variables
        self.time_conditioning = self.config.time_conditioning

        self.train_loss_history = []
        self.val_loss_history = []

        self.vocab_size = len(self.tokenizer)

        # ==== Mask tokens ====
        if hasattr(tokenizer, 'mask_token_id') and tokenizer.mask_token_id is not None:
            self.mask_index = tokenizer.mask_token_id
            logger.info(f"Using existing mask token: {tokenizer.mask_token_id}")
        else:
            # A mask token is required; none is added automatically.
            raise ValueError("Tokenizer must have a mask token")

        # ==== Image conditioning flags ====
        self.use_image_conditioning = getattr(self.config, "use_image_conditioning", False)
        if self.use_image_conditioning:
            self.image_encoder = ImageEncoder(self.config)
        else:
            self.image_encoder = None

        # ==== Backbone ====
        if self.config.backbone == 'dit':
            dit_image_feature_dim = None
            if self.use_image_conditioning and self.image_encoder is not None:
                dit_image_feature_dim = self.image_encoder.output_dim

            logger.info("Implementing DIT backbone")
            self.backbone = models.dit_v2.DIT(
                config=self.config,
                vocab_size=self.vocab_size,
                image_feature_dim=dit_image_feature_dim
            )

            # ===== Cross-attention enhancement ====
            self.backbone = diffusion_utils.cross_attention_enhancement(
                self.use_image_conditioning, dit_image_feature_dim, self.config, logger, self.backbone
            )

            # ===== Noise ====
            self.noise = noise_schedule.get_noise(self.config, dtype=self.dtype)

            # Stays None unless one of the pretrained-initialisation options
            # below runs; the loading summary is skipped in that case.
            report = None
            prefix = "backbone."

            # ===== Load pretrained backbone weights if specified (from our own
            # pretrained backbone with the specified tokenizer) =====
            if getattr(self.config.model, "init_from_pretrained_backbone", False):
                logger.info("Loading pretrained backbone weights...")
                backbone_model_path = self.config.model.get("pretrained_backbone_path", None)
                if not backbone_model_path:
                    raise ValueError(
                        "model.init_from_pretrained_backbone is enabled but "
                        "model.pretrained_backbone_path is not set"
                    )
                logger.info(f"Loading from: {backbone_model_path}")

                if os.path.isfile(backbone_model_path):
                    # Local checkpoint file.
                    # NOTE(security): torch.load unpickles the file; only load
                    # checkpoints from trusted sources.
                    checkpoint = torch.load(backbone_model_path, map_location='cpu')
                    if 'state_dict' in checkpoint:
                        pretrained_state = checkpoint['state_dict']
                    else:
                        pretrained_state = checkpoint
                else:
                    # Not a local file: treat the value as a HuggingFace model id/path.
                    # NOTE(security): trust_remote_code=True executes code shipped
                    # with the model repository.
                    hf_model_object = transformers.AutoModelForMaskedLM.from_pretrained(
                        backbone_model_path, trust_remote_code=True
                    )

                    if hasattr(hf_model_object, "backbone"):
                        logger.info("Using hf_model.backbone.state_dict()")
                        pretrained_state = hf_model_object.backbone.state_dict()
                    else:
                        # Strip the "backbone." prefix only from keys that carry it.
                        pretrained_state = {
                            (name[len(prefix):] if name.startswith(prefix) else name): param
                            for name, param in hf_model_object.state_dict().items()
                        }

                # Load into backbone using existing function
                report = diffusion_utils.load_hf_into_dit_and_report(
                    dit=self.backbone,
                    hf_state=pretrained_state,
                    ignore_prefixes=("image_adapter.", "image_embedder.", "image_encoder."),  # Ignore image related params
                    try_map_head_and_embed=True,
                    logger=logger,
                )

            # ==== Load pretrained backbone weights (openwebtext 1M HF pretrained - gpt2Tokenizer) ====
            if getattr(self.config.model, "init_from_hf_dit", False):
                logger.info("Loading pretrained weights from HuggingFace DiT model...")

                hf_model_path = self.config.model.get("hf_dit_path", "mdlm-owt-model")
                # NOTE(security): trust_remote_code=True executes code shipped
                # with the model repository.
                hf_model_object_to_get_state_dict_from = transformers.AutoModelForMaskedLM.from_pretrained(
                    hf_model_path, trust_remote_code=True
                )

                logger.info("HF loaded")

                if hasattr(hf_model_object_to_get_state_dict_from, "backbone"):
                    logger.info("Using hf_model.backbone.state_dict()")
                    hf_state = hf_model_object_to_get_state_dict_from.backbone.state_dict()
                else:
                    logger.info("Using hf_model.state_dict()")
                    raw = hf_model_object_to_get_state_dict_from.state_dict()
                    # The conditional must select the *key*: keys without the
                    # prefix are kept verbatim instead of being truncated.
                    hf_state = {
                        (name[len(prefix):] if name.startswith(prefix) else name): param
                        for name, param in raw.items()
                    }
                # Copy into our DIT and get a report
                report = diffusion_utils.load_hf_into_dit_and_report(
                    dit=self.backbone,
                    hf_state=hf_state,
                    ignore_prefixes=("image_adapter.",),  # leave new params random
                    try_map_head_and_embed=True,
                    logger=logger,
                )

            if report is None:
                logger.info("No pretrained backbone weights requested; backbone keeps its initial parameters.")
            else:
                total_trainable = sum(p.numel() for p in self.backbone.parameters())
                # Build the state dict once instead of once per matched key.
                backbone_state = self.backbone.state_dict()
                loaded_params = sum(backbone_state[k].numel() for k in report["matched"])
                pct = 100.0 * loaded_params / total_trainable if total_trainable else 0.0
                logger.info(f"Loaded backbone params: {loaded_params}/{total_trainable} (~{pct:.1f}%)")

                if pct < 50.0:
                    logger.warning("Less than 50% of backbone parameters matched; check tokenizer/vocab sizes and naming.")
                else:
                    logger.info("Successfully loaded pretrained backbone weights for fine-tuning!")

            # To restart EMA parameters
            if self.config.training.ema > 0:
                ema_params = list(self.backbone.parameters()) + list(self.noise.parameters())

                # Add image encoder parameters to EMA
                if self.use_image_conditioning and self.image_encoder is not None:
                    ema_params += list(self.image_encoder.parameters())

                self._ema_params = ema_params
                self.ema = models.ema.ExponentialMovingAverage(self._ema_params, decay=self.config.training.ema)
            else:
                self.ema = None
                self._ema_params = []
        else:
            raise ValueError(f'Unknown backbone: {self.config.backbone}')

        self.antithetic_sampling = self.config.training.antithetic_sampling
        self.importance_sampling = self.config.training.importance_sampling
        self.subs_masking = self.config.subs_masking
        self.softplus = torch.nn.Softplus()
        metrics_collection = torchmetrics.MetricCollection({
            'nll': NLL(), 'bpd': BPD(), 'ppl': Perplexity(),
        })
        metrics_collection.set_dtype(torch.float64)
        self.train_metrics = metrics_collection.clone(prefix='train/')
        self.valid_metrics = metrics_collection.clone(prefix='val/')
        self.test_metrics = metrics_collection.clone(prefix='test/')
        self.gen_ppl_metric = Perplexity().set_dtype(torch.float64)

        self.lr = self.config.optim.lr
        self.sampling_eps = self.config.training.sampling_eps
        self.neg_infinity = -1000000.0
        self.fast_forward_epochs = None
        self.fast_forward_batches = None
        self._validate_configuration()

        # Add tracking variables for image-text logging
        self.tracking_images = None
        self.tracking_log_dir = "generation_tracking"
        self.tracking_interval = getattr(config, "tracking_interval", 100)  # Log every N steps
        os.makedirs(self.tracking_log_dir, exist_ok=True)

        # Create subdirectories
        for subdir in ("images", "generations", "combined"):
            os.makedirs(os.path.join(self.tracking_log_dir, subdir), exist_ok=True)

        # Initialize tracking log
        self.tracking_log = []
        
        
    def forward(self, x: torch.Tensor, sigma: torch.Tensor, 
                image_features: typing.Optional[torch.Tensor] = None):
        """
        Args:
            x: Input token indices [B, S]. Noisy tokens? [B, seq_len]
            sigma: Noise level [B] or [B, 1]
            image_features: Optional image features [B, D] or [B, N, D]
        """
        sigma_processed = self._process_sigma(sigma)

        if self.config.backbone == 'dit':
            # Pass through DiT backbone with image conditioning
            logits = self.backbone(x, sigma_processed, image_features=image_features)        
        else:
            logits = self.backbone(x, sigma_processed)

        # Apply parameterization
        if self.parameterization == 'subs':
            return self._subs_parameterization(logits=logits, xt=x)        
        return logits  # [B, seq_len, vocab_size]
        
        
    # ==== Process sigma & xt ====
    def _process_sigma(self, sigma): 
        if sigma is None:
            assert self.parameterization == 'ar'
            return sigma
        if sigma.ndim > 1:
            sigma = sigma.squeeze(-1)
        if not self.time_conditioning:
            sigma = torch.zeros_like(sigma)
        assert sigma.ndim == 1, sigma.shape
        return sigma    

    def q_xt(self, x, move_chance):
        """Computes the noisy sample xt.

        Args:
          x: int torch.Tensor with shape (batch_size, diffusion_model_input_length), input.
          move_chance: float torch.Tensor with shape (batch_size, 1).
        """
        move_indices = torch.rand(*x.shape, device=x.device) < move_chance.to(x.device)
        xt = torch.where(move_indices, self.mask_index, x)
        return xt
          
        
    # ==================== TRAIN THE MODEL =========================
    
    def on_train_start(self):  
        if hasattr(self, "ema") and self.ema:
            self.ema.move_shadow_params_to_device(self.device)
        
    def on_train_epoch_start(self):
        """Put the backbone and the noise schedule into training mode."""
        self.backbone.train()
        self.noise.train()

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Also records the loss in ``train_loss_history`` and, the first time a
        batch with images is seen, stores up to four images/captions as
        tracking samples.

        Args:
            batch: Mapping with at least ``'input_ids'``; ``'images'``,
                ``'attention_mask'`` and ``'caption'`` are optional.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss tensor returned by ``_compute_loss``.
        """
        images = batch.get("images", None)
        image_features = None

        # Encode the images only when conditioning is enabled and usable.
        can_encode = (
            self.use_image_conditioning
            and images is not None
            and self.image_encoder is not None
        )
        if can_encode:
            image_features = self.image_encoder(images)
            if self.training and self.logger and image_features is not None:
                mean_norm = image_features.norm(dim=-1).mean()
                self.log("image_feat/norm", mean_norm, on_step=True, on_epoch=False)

        # Compute the loss using the diffusion process
        loss = self._compute_loss(batch, prefix='train', raw_image_features_from_step=image_features)
        loss_scalar = loss.item()
        self.log('train/loss_step', loss_scalar, on_step=True, on_epoch=False,
                 sync_dist=True, prog_bar=True)

        # NOTE(review): this runs before backward() for the current batch (the
        # loss is only returned below), so it cannot clip this step's
        # gradients. Confirm whether clipping is meant to be configured on the
        # trainer instead.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.train_loss_history.append((self.global_step, loss_scalar))

        # ===== MONITORING LOGIC =====
        # Set up tracking images once, from the first batch that has images.
        tracking_missing = getattr(self, 'tracking_images', None) is None
        if tracking_missing and self.use_image_conditioning and images is not None:
            print("Setting up tracking images from current batch...")

            captions = batch.get('caption', [''] * len(images))

            num_tracking = min(4, len(images))
            self.tracking_images = images[:num_tracking].to(self.device)
            self.tracking_captions = captions[:num_tracking]

            # Best effort: a failure to save tracking images must not stop training.
            try:
                self._save_tracking_images()
                print(f"Successfully setup {num_tracking} tracking images")
            except Exception as e:
                print(f"Error setting up tracking images: {e}")
        return loss
        
    def on_train_end(self):
        """Hand the recorded train/validation loss histories to ``diffusion_utils.save_loss_history``, then run the parent hook."""
        diffusion_utils.save_loss_history(self.train_loss_history, self.val_loss_history)
        super().on_train_end()  
        
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Execute monitoring after training step is complete"""
        if hasattr(self, '_should_log_after_step') and self._should_log_after_step:
	          self._should_log_after_step = False
	          print(f"Executing generation logging for step {self.global_step}")
	          try: self.log_generation_progress()
	          except Exception as e:
	              print(f"Error during generation logging: {e}")
	              traceback.print_exc()
        
    # ==================== COMPUTE LOSS =========================    
    def _compute_loss(self, batch, prefix, raw_image_features_from_step=None):
        """Compute the loss for ``batch`` and update/log the ``prefix`` metrics.

        Args:
            batch: Mapping with ``'input_ids'`` and optionally
                ``'attention_mask'``; a missing (or ``None``) mask defaults to
                all ones.
            prefix: Selects the metric collection ``self.<prefix>_metrics``
                (for example ``'train'`` or ``'valid'``).
            raw_image_features_from_step: Optional image features forwarded
                to ``_loss``.

        Returns:
            The scalar loss from the ``Loss`` object returned by ``_loss``.
        """
        input_ids = batch['input_ids']
        attention_mask = batch.get('attention_mask')
        if attention_mask is None:
            # Build the default mask only when it is actually needed.
            attention_mask = torch.ones_like(input_ids, device=self.device)

        losses_obj = self._loss(input_ids, attention_mask, raw_image_features_from_step)
        metrics_to_update = getattr(self, f"{prefix}_metrics")

        # The metrics are weighted means: sum(value * weight) / sum(weight).
        # ``nlls`` holds per-sample NLL *sums*, so convert to per-sample mean
        # NLL per token and weight by the token count. The aggregate is then
        # sum(NLL) / sum(tokens); passing the sums directly would multiply
        # each sample's NLL by its token count a second time.
        per_token_nll = losses_obj.nlls / losses_obj.token_mask
        metrics_to_update.update(per_token_nll, losses_obj.token_mask)
        metrics = metrics_to_update.compute()

        # Log every aggregated metric; only 'val/nll' goes to the progress bar.
        for k, v in metrics.items():
            self.log(k, v, on_step=(prefix == 'train'), on_epoch=True, sync_dist=True, prog_bar=(k == 'val/nll'))

        return losses_obj.loss
        
        
    def _loss(self, x0, attention_mask, image_features=None):
        """Compute the attention-masked diffusion loss for a batch of tokens.

        Args:
            x0: Clean text tokens [B, seq_len].
            attention_mask: Mask multiplied into the per-token loss.
            image_features: Optional image features passed to the forward pass.

        Returns:
            A ``Loss`` with the batch mean loss per token, the per-sample sum
            of masked losses, and the per-sample (clamped) token counts.
        """
        # Sub-sample if needed (for the autoregressive case); the shifted
        # target tokens are not used by the diffusion loss.
        tokens_in, _, mask = self._maybe_sub_sample(x0, attention_mask)

        # Diffusion: add noise and predict clean tokens.
        per_token_loss = self._forward_pass_diffusion(tokens_in, image_features=image_features)

        per_sample_nll = (per_token_loss * mask).sum(dim=-1)
        per_sample_tokens = mask.sum(dim=-1).clamp(min=1e-9)
        batch_mean = per_sample_nll.sum() / per_sample_tokens.sum().clamp(min=1e-9)

        return Loss(
            loss=batch_mean,
            nlls=per_sample_nll,
            token_mask=per_sample_tokens,
        )
    def _maybe_sub_sample(self, x0, attention_mask): 
        seqlen = x0.shape[1]
        if seqlen > self.config.model.length:
            assert seqlen == 2 * self.config.model.length
            start = np.random.choice(self.config.model.length)
            end = start + self.config.model.length
            input_tokens = x0[:, start: end]
            output_tokens = x0[:, start + 1: end + 1]
            new_attention_mask = attention_mask[:, start: end]
            if self.tokenizer.bos_token_id is not None:
                input_tokens[:, 0] = self.tokenizer.bos_token_id
            if self.tokenizer.eos_token_id is not None:
                output_tokens[:, -1] = self.tokenizer.eos_token_id
        elif self.parameterization == 'ar':
            input_tokens = x0[:, :-1]
            output_tokens = x0[:, 1:]
            new_attention_mask = attention_mask[:, :-1] if attention_mask.shape == x0.shape else attention_mask[:, 1:]
        else:
            input_tokens = x0
            output_tokens = x0
            new_attention_mask = attention_mask
        return input_tokens, output_tokens, new_attention_mask        
        
    def _forward_pass_diffusion(self, x0, image_features=None):
        """Noise ``x0``, run the model and return the per-token loss.

        Steps: sample a timestep per batch element, derive the masking
        probability from the noise schedule (or from the change-of-variables
        formula), mask tokens with ``q_xt``, run ``forward`` and gather the
        model output at the clean token ids.

        Args:
            x0: Clean token ids; the first dimension is the batch.
            image_features: Optional image features passed to ``forward``.

        Returns:
            Tensor with one loss value per token of ``x0``. In training mode
            on the continuous-time path a scalar minimum-length penalty is
            added to every entry.

        Raises:
            NotImplementedError: If ``self.T > 0`` and the parameterization
                is not ``'subs'``; no reconstruction loss is defined for that
                combination here.
        """
        # Step 1: Sample random timestep t for each batch element
        t = self._sample_t(x0.shape[0], x0.device)

        if self.T > 0:
            # Discretise t onto a grid of T steps (shifted up by one step, capped at 1).
            t_int = (t * self.T).to(torch.int)
            t = t_int / self.T
            t = torch.clamp(t + (1 / self.T), max=1.0)

        sigma = dsigma = None
        if self.change_of_variables:
            unet_conditioning = t.unsqueeze(-1)
            f_T = torch.log1p(-torch.exp(-self.noise.sigma_max))
            f_0 = torch.log1p(-torch.exp(-self.noise.sigma_min))
            move_chance = torch.exp(f_0 + t * (f_T - f_0)).unsqueeze(-1)
        else:
            # Step 2: Get noise level at time t
            sigma, dsigma = self.noise(t)  # sigma controls how much masking
            unet_conditioning = sigma.unsqueeze(-1)
            # Step 3: Create noisy version by masking tokens
            move_chance = (1 - torch.exp(-sigma)).unsqueeze(-1)

        xt = self.q_xt(x0, move_chance)  # Replace tokens with [MASK] based on move_chance

        # Step 4: Model predicts clean tokens from noisy input
        model_output = self.forward(xt, unet_conditioning, image_features=image_features)

        if self.training and hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            eos_predictions = (model_output.argmax(dim=-1) == self.tokenizer.eos_token_id).float().mean()
            self.log('train/eos_prediction_rate', eos_predictions, on_step=True)

        utils.print_nans(model_output, 'model_output')

        # Step 5: Model output gathered at the true clean token of every position
        log_p_theta_x0 = torch.gather(
            input=model_output, dim=-1, index=x0.unsqueeze(-1)
        ).squeeze(-1)

        if self.T > 0:
            if self.parameterization != 'subs':
                # Previously this path failed with an unbound local variable.
                raise NotImplementedError(
                    "Discrete-time loss (T > 0) is only implemented for the "
                    f"'subs' parameterization, got {self.parameterization!r}"
                )
            # For 'subs' the reconstruction term is zero.
            return -log_p_theta_x0

        if self.change_of_variables or self.importance_sampling:
            return -log_p_theta_x0 * torch.log1p(-torch.exp(-self.noise.sigma_min))

        # change_of_variables is False here, so sigma/dsigma were set above.
        weighting = (dsigma / torch.expm1(sigma)).unsqueeze(-1)
        nll = -log_p_theta_x0 * weighting  # Negative log likelihood

        # Penalize very short sequences during training. Skipped when the
        # tokenizer defines no pad token, since length cannot be measured then.
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if self.training and pad_id is not None:
            sequence_lengths = (x0 != pad_id).sum(dim=1).float()
            min_length_penalty = torch.relu(5.0 - sequence_lengths).mean() * 0.1
            return nll + min_length_penalty

        return nll
        
        
    def _sample_t(self, n, device):
        _eps_t = torch.rand(n, device=device)
        if self.antithetic_sampling:
            offset = torch.arange(n, device=device) / n
            _eps_t = (_eps_t / n + offset) % 1
        t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
        if self.importance_sampling:
            return self.noise.importance_sampling_transformation(t)
        return t  
                
    # ==================== CHECKPOINTS =========================
    
    def on_load_checkpoint(self, checkpoint):
        if hasattr(self, "ema") and self.ema:
            self.ema.load_state_dict(checkpoint['ema'])
        self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed']
        self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed']

    def on_save_checkpoint(self, checkpoint):
        """Add EMA and sampler state to ``checkpoint`` and rewrite batch progress.

        The fit loop's batch-progress counters are overwritten with the
        optimizer step counts multiplied by ``accumulate_grad_batches``, and
        ``_batches_that_stepped`` is set to the total optimizer step count.
        The training sampler's ``random_state`` is stored when the sampler
        exposes ``state_dict``; otherwise ``None`` is stored.
        """
        if getattr(self, "ema", None):
            checkpoint['ema'] = self.ema.state_dict()

        fit_loop = checkpoint['loops']['fit_loop']
        optimizer_steps = fit_loop['epoch_loop.automatic_optimization.optim_progress']['optimizer']['step']
        batch_progress = fit_loop['epoch_loop.batch_progress']

        # Express batch progress in units of batches rather than optimizer steps.
        batch_progress['total']['completed'] = (
            optimizer_steps['total']['completed'] * self.trainer.accumulate_grad_batches
        )
        batch_progress['current']['completed'] = (
            optimizer_steps['current']['completed'] * self.trainer.accumulate_grad_batches
        )
        fit_loop['epoch_loop.state_dict']['_batches_that_stepped'] = optimizer_steps['total']['completed']

        sampler_entry = checkpoint.setdefault('sampler', {})
        if hasattr(self.trainer.train_dataloader.sampler, 'state_dict'):
            sampler_state = self.trainer.train_dataloader.sampler.state_dict()
            sampler_entry['random_state'] = sampler_state.get('random_state', None)
        else:
            sampler_entry['random_state'] = None
        
        
    # ==================== VALIDATION =========================
    
    def on_validation_epoch_start(self):
        """Prepare the module for a validation epoch.

        Swaps in the EMA weights (when EMA is enabled), puts the backbone,
        noise schedule and image encoder in eval mode, resets the validation
        metrics and the sample buffer, and runs the image-conditioning
        self-test on every fifth epoch after epoch 0.
        """
        if getattr(self, "ema", None):
            # Keep a copy of the training weights, then evaluate with EMA weights.
            self.ema.store(self._ema_params)
            self.ema.copy_to(self._ema_params)

        self.backbone.eval()
        self.noise.eval()

        if self.use_image_conditioning and self.image_encoder is not None:
            self.image_encoder.eval()

        self.valid_metrics.reset()
        self.validation_samples = []  # Store samples for quality assessment

        # Test image conditioning effectiveness
        run_self_test = self.current_epoch % 5 == 0 and self.current_epoch > 0
        if run_self_test:
            self._test_image_conditioning_effectiveness()
        

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and collect sample images.

        During the first three batches, while fewer than 12 samples are
        stored, up to four (image, caption) pairs per batch are appended to
        ``self.validation_samples`` for later generation-quality checks.

        Returns:
            The scalar validation loss tensor.
        """
        images = batch.get("images", None)
        raw_img_features = None

        can_encode = (
            self.use_image_conditioning
            and images is not None
            and self.image_encoder is not None
        )
        if can_encode:
            raw_img_features = self.image_encoder(images)

            # Store validation samples for quality assessment (first few batches only)
            if batch_idx < 3 and len(self.validation_samples) < 12:
                captions = batch.get('caption', [''] * len(images))
                take = min(4, len(images))
                self.validation_samples.extend(
                    {'image': images[i].clone(), 'ground_truth': captions[i]}
                    for i in range(take)
                )

        val_loss = self._compute_loss(batch, prefix='valid', raw_image_features_from_step=raw_img_features)
        self.val_loss_history.append((self.global_step, val_loss.item()))
        return val_loss

    def on_validation_epoch_end(self):
        """Optionally evaluate generation quality, then restore training weights.

        Sampling runs only on epochs that are multiples of 3 and greater
        than 2, outside the sanity check, when validation samples were
        collected and ``config.eval.generate_samples`` (default True) is set.
        """
        epoch_is_due = self.current_epoch % 3 == 0 and self.current_epoch > 2
        if (
            epoch_is_due
            and not self.trainer.sanity_checking
            and len(self.validation_samples) > 0
            and self.config.eval.get("generate_samples", True)
        ):
            self._evaluate_image_text_quality()

        # Restore training weights
        if getattr(self, "ema", None):
            self.ema.restore(self._ema_params)
            
    def _evaluate_image_text_quality(self):
        """Generate captions for up to six stored validation images and log quality metrics.

        Best effort: any exception is printed and swallowed so that a
        sampling failure does not abort validation.
        """
        # NOTE(review): sample_with_image_conditioning also calls
        # self.ema.store(...)/restore(...). If store() keeps a single
        # snapshot, the training weights saved in on_validation_epoch_start
        # would be overwritten by this nested call -- confirm against
        # models.ema.ExponentialMovingAverage.
        try:
            num_sampling_steps = 25

            # Sample from a few validation images (just 6)
            subset = self.validation_samples[:6]
            images = torch.stack([entry['image'] for entry in subset]).to(self.device)
            ground_truths = [entry['ground_truth'] for entry in subset]

            # Generate captions
            generated_tokens = self.sample_with_image_conditioning(
                images=images,
                num_steps=num_sampling_steps,
                eps=1e-3,
            )
            generated_texts = self.tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )

            # Compute meaningful metrics
            diffusion_utils.log_generation_quality(generated_texts, ground_truths, self.log)
        except Exception as e:
            print(f"Validation sampling failed: {e}")
    
                
    def _test_image_conditioning_effectiveness(self):
        """Print diagnostics on how image features change the model's scores.

        Only runs for the ``'subs'`` parameterization. Compares
        ``get_score`` with and without image features on an all-mask input,
        then prints scores for a short sequence mixing fixed token ids with
        mask tokens. Real image features are used when
        ``self.validation_images_for_sampling`` is available and an image
        encoder exists; otherwise random features are used.
        """
        if self.parameterization != 'subs':
            return

        print(f"\nTesting Image Conditioning (Epoch {self.current_epoch})")

        # Create test data
        batch_size = 2
        seq_len = 10
        x_with_masks = torch.full((batch_size, seq_len), self.mask_index, device=self.device)
        sigma_test = torch.tensor([0.5, 0.5], device=self.device)

        # Get some real image features from validation data. The encoder is
        # checked for None here as well, so a stored image batch without an
        # encoder falls back to dummy features instead of calling None.
        stored_images = getattr(self, 'validation_images_for_sampling', None)
        if stored_images is not None and self.image_encoder is not None:
            test_images = stored_images[:batch_size]
            image_features = self.image_encoder(test_images)
        else:
            # Create dummy image features
            feature_dim = self.image_encoder.output_dim if self.image_encoder else 512
            image_features = torch.randn(batch_size, feature_dim, device=self.device)

        with torch.no_grad():
            # Test 1: Image conditioning effect
            score_with_img = self.get_score(x_with_masks, sigma_test, image_features)
            score_without_img = self.get_score(x_with_masks, sigma_test, None)

            score_diff = (score_with_img - score_without_img).abs().mean()
            print(f"Score difference with/without images: {score_diff:.6f}")

            # Test 2: Check unmasked token behavior on a single mixed sequence
            x_mixed = torch.tensor([[1, 2, self.mask_index, 4, self.mask_index]], device=self.device)
            scores_mixed = self.get_score(x_mixed, sigma_test[:1], image_features[:1])

            print("Unmasked token probabilities (should be ~1.0):")
            print(f"  Token 1 at pos 0: {scores_mixed[0, 0, 1]:.6f}")
            print(f"  Token 2 at pos 1: {scores_mixed[0, 1, 2]:.6f}")
            print(f"  Token 4 at pos 3: {scores_mixed[0, 3, 4]:.6f}")

            print("Masked positions (should have distributed probabilities):")
            mask_probs_pos2 = scores_mixed[0, 2, :5]  # First 5 vocab tokens
            mask_probs_pos4 = scores_mixed[0, 4, :5]
            print(f"  Pos 2 (masked): {mask_probs_pos2}")
            print(f"  Pos 4 (masked): {mask_probs_pos4}")

        print("End Image Conditioning Test\n")
    
    
    # ==================== SAMPLING ===========================
    @torch.no_grad()
    def sample_with_image_conditioning(self, images, num_steps=None, eps=1e-5):
        """Sample token sequences by iterative denoising, optionally image-conditioned.

        Sampling starts from ``self._sample_prior`` and applies ``num_steps``
        updates of the sampler selected by ``self.sampler`` over a linearly
        spaced time grid running from 1.0 down to ``eps``. If an EMA object is
        attached, its weights are swapped in for the duration of the call.

        The EMA swap and the train/eval mode of the backbone and image encoder
        are undone in a ``finally`` block, so an exception raised while sampling
        cannot leave EMA weights in the live model, and a caller that was
        already in eval mode is not switched to train mode.

        Args:
            images: Batch of images whose first dimension is the batch size, or
                ``None`` to sample ``config.loader.eval_batch_size`` (default 4)
                sequences without image features.
            num_steps: Number of denoising updates. Defaults to ``self.T`` when
                that is positive, otherwise ``config.sampling.steps`` (1000).
            eps: Smallest time value of the schedule.

        Returns:
            The tensor of token ids produced by the last update.

        Raises:
            ValueError: If ``self.sampler`` is not 'ddpm', 'ddpm_cache' or
                'analytic'.
        """
        has_image_encoder = self.use_image_conditioning and self.image_encoder is not None
        use_ema = bool(hasattr(self, "ema") and self.ema)

        # Remember the current modes so they can be put back afterwards rather
        # than unconditionally forcing train mode.
        backbone_was_training = self.backbone.training
        encoder_was_training = self.image_encoder.training if has_image_encoder else False
        ema_swapped = False

        try:
            self.backbone.eval()
            if has_image_encoder:
                self.image_encoder.eval()

            # Store the live weights and sample with the EMA weights instead.
            if use_ema:
                self.ema.store(self._ema_params)
                ema_swapped = True
                self.ema.copy_to(self._ema_params)

            if images is not None:
                batch_size = images.shape[0]
            else:
                batch_size = self.config.loader.get("eval_batch_size", 4)

            # Image features are computed once and reused for every step.
            image_features = None
            if has_image_encoder and images is not None:
                image_features = self.image_encoder(images.to(self.device))

            # Start from the prior (every position set to the mask index).
            x = self._sample_prior(batch_size, self.config.model.length).to(self.device)

            if num_steps is None:
                num_steps = self.T if self.T > 0 else self.config.sampling.get("steps", 1000)

            timesteps_lin = torch.linspace(1.0, eps, num_steps + 1, device=self.device)
            dt = (1.0 - eps) / num_steps
            p_x0_cache = None

            for i in range(num_steps):
                t_current = timesteps_lin[i] * torch.ones(x.shape[0], device=self.device)
                t_for_update = t_current.unsqueeze(-1)

                if self.sampler == 'ddpm':
                    x = self._ddpm_update(x, t_for_update, dt, image_features)
                elif self.sampler == 'ddpm_cache':
                    p_x0_cache, x_next = self._ddpm_caching_update(
                        x, t_for_update, dt, image_features, p_x0_cache
                    )
                    # The cached prediction is only reusable while x is
                    # unchanged and the model does not depend on the time.
                    if not torch.allclose(x_next, x) or self.time_conditioning:
                        p_x0_cache = None
                    x = x_next
                elif self.sampler == 'analytic':
                    x = self._analytic_update(x, t_for_update, dt, image_features)
                else:
                    raise ValueError(f"Unknown sampler: {self.sampler}")

            # Optional final clean-up step at the last time of the schedule.
            if self.config.sampling.get("noise_removal", False):
                t_final = timesteps_lin[-1] * torch.ones(x.shape[0], device=self.device).unsqueeze(-1)
                if self.sampler == 'analytic' or self.config.sampling.get("denoiser_for_noise_removal", True):
                    x = self._denoiser_update(x, t_final, image_features)
                else:
                    sigma_final, _ = self.noise(t_final.squeeze(-1))
                    final_model_output = self.forward(x, sigma_final, image_features)
                    x = final_model_output.argmax(dim=-1)
        finally:
            # Always put the original weights and modes back, even on failure.
            if ema_swapped:
                self.ema.restore(self._ema_params)
            self.backbone.train(backbone_was_training)
            if has_image_encoder:
                self.image_encoder.train(encoder_was_training)

        return x
        
    def _sample_prior(self, *batch_dims):
        return self.mask_index * torch.ones(*batch_dims, dtype=torch.long, device=self.device)
        
    def _analytic_update(self, x, t, step_size, image_features=None):
        """Take one analytic sampler step from time ``t`` to ``t - step_size``.

        The score at the current noise level is staggered by the change in
        sigma, multiplied element-wise by the transposed transition for the
        current tokens, and the product is sampled with
        ``diffusion_utils.sample_categorical``.

        Args:
            x: Current token ids.
            t: Current time passed to ``self.noise``.
            step_size: Amount subtracted from ``t`` to get the next time.
            image_features: Optional conditioning forwarded to ``get_score``.

        Returns:
            The result of ``diffusion_utils.sample_categorical``.
        """
        sigma_now, _ = self.noise(t)
        sigma_next, _ = self.noise(t - step_size)
        sigma_gap = sigma_now - sigma_next

        staggered = self._staggered_score(
            self.get_score(x, sigma_now, image_features), sigma_gap
        )
        transition = self._transp_transition(x, sigma_gap)
        return diffusion_utils.sample_categorical(staggered * transition)
        
    def get_score(self, x, sigma, image_features=None):
        """Convert the model output into per-token scores.

        For every parameterization except 'subs' this is a softmax over the
        last dimension of the model output. For 'subs' the output is treated as
        log-values, shifted by a noise-dependent constant, and exponentiated.

        Args:
            x: Token ids, indexed as (batch, position).
            sigma: Noise level per batch element; a trailing singleton
                dimension is squeezed away when present.
            image_features: Optional conditioning forwarded to ``self.forward``.

        Returns:
            Tensor with the same shape as the model output.
        """
        log_probs = self.forward(x, sigma, image_features=image_features)

        if self.parameterization != 'subs':
            return F.softmax(log_probs, dim=-1)

        # 'subs' is designed for substitution-based diffusion, where tokens are
        # replaced by MASK tokens in the forward process and the analytic
        # sampler tries to predict the original tokens.
        flat_sigma = sigma.squeeze(-1) if sigma.ndim > 1 else sigma
        log_k = -torch.log(torch.expm1(flat_sigma))  # depends on the noise level only
        assert log_k.ndim == 1

        # Rows used where x is still MASK: shifted model output, with the MASK
        # entry itself pinned to a log-score of 0. This gives higher scores to
        # the tokens the model thinks should replace the MASK.
        score_if_masked = log_probs + log_k[:, None, None]
        score_if_masked[:, :, self.mask_index] = 0

        # Rows used where x is already revealed: every token is blocked except
        # the current one (log-score 0) and MASK (log-score -log_k).
        score_if_revealed = self.neg_infinity * torch.ones_like(log_probs)
        score_if_revealed = torch.scatter(
            score_if_revealed, -1, x[..., None],
            torch.zeros_like(score_if_revealed[..., :1]),
        )
        score_if_revealed[:, :, self.mask_index] = -(
            log_k[:, None] * torch.ones_like(x, dtype=log_k.dtype)
        )

        # Pick the matching row per position via a 0/1 weight, then leave log space.
        is_masked = (x == self.mask_index).to(log_probs.dtype)[:, :, None]
        combined = score_if_masked * is_masked + score_if_revealed * (1 - is_masked)
        return combined.exp()
        
    def _staggered_score(self, score, dsigma):
        score = score.clone()
        if dsigma.ndim == 1:
            dsigma_exp = dsigma.exp().unsqueeze(-1).unsqueeze(-1)  # (B,1,1)
        elif dsigma.ndim == 2:
            dsigma_exp = dsigma.exp().unsqueeze(-1)  # (B,1,1)
        else:
            raise ValueError(f"Unexpected dsigma shape: {dsigma.shape}")

        score_sum = score.sum(dim=-1, keepdim=True)  # (B,S,1)
        extra_const = (1 - dsigma_exp) * score_sum
        score_updated = score * dsigma_exp

        score_updated[..., self.mask_index:self.mask_index + 1] = score_updated[..., self.mask_index:self.mask_index + 1] + extra_const
        return score_updated
        
    def _transp_transition(self, i, sigma):
        """Build the transition weights used by the analytic and denoiser updates.

        The result is ``exp(-sigma)`` times the one-hot encoding of ``i``, plus
        ``1 - exp(-sigma)`` added to every vocabulary entry at positions where
        ``i`` equals ``self.mask_index``.

        Args:
            i: Token ids.
            sigma: Noise amount, expanded by ``diffusion_utils.unsqueeze`` to
                line up with ``i[..., None]``.

        Returns:
            Tensor with a trailing dimension of size ``self.vocab_size``.
        """
        sigma_b = diffusion_utils.unsqueeze(sigma, reference=i[..., None])
        one_hot_tokens = F.one_hot(i, num_classes=self.vocab_size)
        edge = torch.exp(-sigma_b) * one_hot_tokens

        leave_prob = 1 - torch.exp(-sigma_b).squeeze(-1)
        at_mask = i == self.mask_index
        edge += torch.where(at_mask, leave_prob, 0)[..., None]
        return edge
    
    def _validate_configuration(self): 
        assert not (self.change_of_variables and self.importance_sampling)
        if self.parameterization == 'sedd':
            assert not self.importance_sampling
            assert not self.change_of_variables
        if self.parameterization == 'd3pm':
            assert self.T > 0
        if self.T > 0:
            assert self.parameterization in {'d3pm', 'subs'}
        if self.subs_masking:
            assert self.parameterization == 'd3pm'
    
    def _subs_parameterization(self, logits, xt):
        logits[:, :, self.mask_index] += self.neg_infinity
        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
        unmasked_indices = (xt != self.mask_index)
        logits[unmasked_indices] = self.neg_infinity
        logits[unmasked_indices, xt[unmasked_indices]] = 0
        return logits
   
    def _ddpm_caching_update(self, x, t, dt,
                                        image_features=None,
                                        p_x0=None):
        """Take one DDPM-style step, reusing a cached model prediction when given.

        Args:
            x: Current token ids.
            t: Current time; a trailing singleton dimension is squeezed away.
            dt: Step size subtracted from ``t``.
            image_features: Optional conditioning forwarded to ``self.forward``.
            p_x0: Previously computed ``self.forward(...).exp()``; when ``None``
                the model is evaluated here.

        Returns:
            Tuple ``(p_x0, x_next)``: the prediction that was used (so the
            caller can cache it) and the updated token ids.
        """
        # NOTE(review): t is used directly as the move chance below, which
        # presumably relies on the log-linear schedule asserted here — confirm.
        assert self.config.noise.type == 'loglinear'
        sigma_t, _ = self.noise(t)
        if t.ndim > 1:
            t = t.squeeze(-1)
        assert t.ndim == 1
        # Expand to three dimensions so both values broadcast against p_x0.
        move_chance_t = t[:, None, None]
        move_chance_s = (t - dt)[:, None, None]
        assert move_chance_t.ndim == 3, move_chance_t.shape

        # Only run the model when no cached prediction was supplied.
        if p_x0 is None:
            p_x0 = self.forward(
                x, sigma_t,
                image_features=image_features
            ).exp()
            
        # Align the move chances with p_x0 when p_x0 has one dimension fewer.
        if p_x0.ndim != move_chance_t.ndim:
            _mc_t = move_chance_t.squeeze(-1) if p_x0.ndim == move_chance_t.ndim - 1 else move_chance_t
            _mc_s = move_chance_s.squeeze(-1) if p_x0.ndim == move_chance_s.ndim - 1 else move_chance_s
        else:
            _mc_t, _mc_s = move_chance_t, move_chance_s

        # q_xs is a new tensor, so the in-place write below leaves p_x0 intact.
        q_xs = p_x0 * (_mc_t - _mc_s)
        # The mask entry is overwritten with the move chance at the next time.
        if _mc_s.ndim > q_xs.ndim - 1:
            q_xs[:, :, self.mask_index] = _mc_s.squeeze(-1)
        else:
            q_xs[:, :, self.mask_index] = _mc_s

        _x = diffusion_utils.sample_categorical(q_xs)
        # Positions that are not the mask token keep their value; the rest take the sample.
        copy_flag = (x != self.mask_index).to(x.dtype)
        return p_x0, copy_flag * x + (1 - copy_flag) * _x        
        
    def _denoiser_update(self, x, t, image_features=None):
        """Run the final clean-up step at time ``t``.

        Uses the same score / staggered score / transition product as the
        analytic update, with the full sigma at ``t`` in place of a sigma
        difference, then zeroes the entry at ``self.mask_index`` before
        sampling.

        Args:
            x: Current token ids.
            t: Time passed to ``self.noise``.
            image_features: Optional conditioning forwarded to ``get_score``.

        Returns:
            The result of ``diffusion_utils.sample_categorical``.
        """
        sigma, _ = self.noise(t)
        # sigma itself plays the role of dsigma here.
        staggered = self._staggered_score(
            self.get_score(x, sigma, image_features), sigma
        )
        weights = staggered * self._transp_transition(x, sigma)
        weights[..., self.mask_index] = 0
        return diffusion_utils.sample_categorical(weights)


    def _ddpm_update(self, x, t, dt, image_features=None):
        """Take one simplified DDPM step from time ``t`` to ``t - dt``.

        Args:
            x: Current token ids, indexed as (batch, position).
            t: Current time with a trailing singleton dimension, which is
                squeezed before calling ``self.noise``.
            dt: Step size subtracted from ``t``.
            image_features: Optional conditioning forwarded to ``self.forward``.

        Returns:
            Updated token ids: positions that are not the mask token keep
            their value, the others take the sampled token.
        """
        # NOTE(review): assumes self.noise returns sigma with the same shape as
        # its input, i.e. one value per batch element — confirm.
        sigma_t, _ = self.noise(t.squeeze(-1))
        sigma_s, _ = self.noise((t - dt).squeeze(-1))

        move_chance_t = (1 - torch.exp(-sigma_t)).unsqueeze(-1).unsqueeze(-1)
        move_chance_s = (1 - torch.exp(-sigma_s)).unsqueeze(-1).unsqueeze(-1)

        log_p_x0 = self.forward(x, sigma_t, image_features=image_features)

        q_xs = log_p_x0.exp() * (move_chance_t - move_chance_s)
        # Squeeze only once: a (batch, 1) value broadcasts over the positions.
        # Squeezing twice gave a (batch,) tensor, which broadcasts along the
        # position axis instead and fails unless batch == sequence length
        # (or batch == 1).
        q_xs[:, :, self.mask_index] = move_chance_s.squeeze(-1)

        _x = diffusion_utils.sample_categorical(q_xs)
        copy_flag = (x != self.mask_index).to(x.dtype)
        return copy_flag * x + (1 - copy_flag) * _x
      

    # ==================== OPTIMIZER ===========================     
    def optimizer_step(self, *args, **kwargs):
        """Run the parent optimizer step, then update the EMA if one is attached."""
        super().optimizer_step(*args, **kwargs)
        ema = getattr(self, "ema", None)
        if ema:
            ema.update(self._ema_params)
            
    def configure_optimizers(self):
        """Create the AdamW optimizer and, if configured, its LR scheduler.

        All parameters of the backbone, the noise schedule and (when image
        conditioning is active) the image encoder go into one parameter group.

        Returns:
            The optimizer alone when the config has no ``lr_scheduler`` entry,
            otherwise ``([optimizer], [scheduler_dict])``.
        """
        modules = [self.backbone, self.noise]
        if self.use_image_conditioning and self.image_encoder is not None:
            modules.append(self.image_encoder)
        trainable = [param for module in modules for param in module.parameters()]

        optim_cfg = self.config.optim
        optimizer = torch.optim.AdamW(
            trainable,
            lr=optim_cfg.lr,
            betas=(optim_cfg.beta1, optim_cfg.beta2),
            eps=optim_cfg.eps,
            weight_decay=optim_cfg.weight_decay,
        )

        scheduler_config = self.config.get('lr_scheduler', None)
        if not scheduler_config:
            return optimizer

        scheduler = hydra.utils.instantiate(scheduler_config, optimizer=optimizer)
        scheduler_dict = {
            'scheduler': scheduler,
            'interval': scheduler_config.get('interval', 'step'),
            'monitor': scheduler_config.get('monitor', 'val/loss'),
            'name': 'lr_scheduler',
        }
        return [optimizer], [scheduler_dict]

    # ==================== BACKWARD ===========================       
    def on_after_backward(self):
        """Log gradient norms of the image-conditioning path every 100 steps.

        Does nothing unless image conditioning is enabled, the module is in
        training mode and ``self.global_step`` is a multiple of 100. Two
        parameters are inspected when present: the ``conv1`` weight of
        ``self.image_encoder.encoder`` and the weight of the first layer of
        ``self.backbone.image_embedder``. A warning is logged when a gradient
        is ``None``.
        """
        if not (self.use_image_conditioning and self.training):
            return
        if self.global_step % 100 != 0:
            return

        base_log_name = f"grad_norm_step_{self.global_step}"

        # Missing attributes simply skip the check instead of raising.
        conv1 = getattr(getattr(self.image_encoder, 'encoder', None), 'conv1', None)
        conv1_weight = getattr(conv1, 'weight', None)
        if conv1_weight is not None:
            if conv1_weight.grad is not None:
                # Include the norm itself, matching the image_embedder log line.
                logger.info(f"{base_log_name}/image_encoder_conv1: {conv1_weight.grad.norm()}")
            else:
                logger.warning(f"DEBUG: {base_log_name}/image_encoder_conv1 grad is None")

        embedder = getattr(self.backbone, 'image_embedder', None)
        if (
            isinstance(embedder, nn.Sequential)
            and len(embedder) > 0
            and hasattr(embedder[0], 'weight')
        ):
            param_to_log_grad = embedder[0].weight
            if param_to_log_grad.grad is not None:
                logger.info(f"{base_log_name}/dit_internal_image_embed_weight: {param_to_log_grad.grad.norm()}")
            else:
                logger.warning(f"DEBUG: {base_log_name}/dit_internal_image_embed_weight grad is None")
    

    def _save_tracking_images(self):
        """Write the tracking images and their reference captions to disk.

        Each tensor in ``self.tracking_images`` is denormalized with the CLIP
        mean/std constants below, clamped to [0, 1] and saved as
        ``images/tracking_image_<i>.png`` under ``self.tracking_log_dir``.
        The captions are written to ``original_captions.json`` in the same
        directory, keyed ``image_<i>``.
        """
        clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
        clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)
        image_dir = os.path.join(self.tracking_log_dir, "images")

        for idx, normalized in enumerate(self.tracking_images):
            # Undo the normalization on CPU and keep values in a displayable range.
            restored = (normalized.cpu() * clip_std + clip_mean).clamp(0, 1)
            picture = transforms.ToPILImage()(restored)
            picture.save(os.path.join(image_dir, f"tracking_image_{idx}.png"))

        captions_file = os.path.join(self.tracking_log_dir, "original_captions.json")
        with open(captions_file, 'w') as f:
            captions_by_image = {}
            for idx, caption in enumerate(self.tracking_captions):
                captions_by_image[f"image_{idx}"] = caption
            json.dump(captions_by_image, f, indent=2)

    @torch.no_grad()
    def log_generation_progress(self):
        """Generate captions for the tracking images and record them.

        Does nothing when ``self.tracking_images`` is ``None``. Otherwise the
        module is put in eval mode, text is sampled for the tracking images
        (50 steps, ``eps=1e-3``), decoded with ``self.tokenizer``, appended to
        ``self.tracking_log`` and written to ``generation_log.json`` under
        ``self.tracking_log_dir``. A combined visualization is then created
        and, if a logger is attached, the entry is sent to it.

        This is best-effort: any exception is reported on stdout together with
        its traceback and swallowed. Train mode is restored afterwards if the
        module was training.

        EMA weights are not swapped here because
        ``sample_with_image_conditioning`` already stores, applies and restores
        them around sampling. NOTE(review): a second ``ema.store`` issued while
        EMA weights are already live would overwrite the backup of the training
        weights if the EMA helper keeps a single backup, as the reference MDLM
        implementation does — confirm against the EMA class in use.
        """
        if self.tracking_images is None:
            return
        print(f"Logging generation progress at step {self.global_step}...")

        was_training = self.training
        self.eval()

        try:
            generated_tokens = self.sample_with_image_conditioning(
                images=self.tracking_images,
                num_steps=50,
                eps=1e-3
            )

            generated_texts = self.tokenizer.batch_decode(
                generated_tokens,
                skip_special_tokens=True
            )

            log_entry = {
                "step": self.global_step,
                "epoch": self.current_epoch,
                "timestamp": datetime.now().isoformat(),
                "generations": {}
            }

            # One record per tracking image; zip stops at the shorter sequence.
            for i, (original_caption, generated_text) in enumerate(zip(self.tracking_captions, generated_texts)):
                log_entry["generations"][f"image_{i}"] = {
                    "original_caption": original_caption,
                    "generated_text": generated_text,
                    "length": len(generated_text.split())
                }

            self.tracking_log.append(log_entry)

            # The whole history is rewritten on every call.
            log_path = os.path.join(self.tracking_log_dir, "generation_log.json")
            with open(log_path, 'w') as f:
                json.dump(self.tracking_log, f, indent=2)

            self._create_combined_visualization(log_entry)

            if self.logger:
                self._log_to_experiment_tracker(log_entry)

        except Exception as e:
            # Never interrupt training, but do not hide where the failure
            # happened either.
            print(f"Error during generation logging: {e}")
            traceback.print_exc()

        finally:
            if was_training:
                self.train()

    def _create_combined_visualization(self, log_entry):
        """Save a figure pairing each tracking image with its captions.

        The top row shows the denormalized tracking images; the bottom row
        shows the original and generated caption of each image, each truncated
        to 100 characters. The figure is written to
        ``combined/step_<step>.png`` and ``log_entry`` is also dumped to
        ``generations/step_<step>.json``, both under ``self.tracking_log_dir``.

        The figure is closed in a ``finally`` block so that a failure while
        plotting or saving does not leave it open.

        Args:
            log_entry: Dict with a ``"step"`` value and a ``"generations"``
                mapping keyed ``image_<i>`` whose values hold
                ``"original_caption"`` and ``"generated_text"``.
        """
        step = log_entry["step"]
        num_images = len(self.tracking_images)

        # squeeze=False keeps ``axes`` two-dimensional even for a single image.
        fig, axes = plt.subplots(2, num_images,
                                 figsize=(5 * num_images, 8),
                                 squeeze=False)
        try:
            # Denormalization constants for CLIP
            mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
            std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)

            for i, img_tensor in enumerate(self.tracking_images):
                img_denorm = torch.clamp(img_tensor.cpu() * std + mean, 0, 1)
                # Move the channel dimension last for imshow.
                img_np = img_denorm.permute(1, 2, 0).numpy()

                image_ax = axes[0, i]
                image_ax.imshow(img_np)
                image_ax.set_title(f"Image {i}")
                image_ax.axis('off')

                generation = log_entry["generations"][f"image_{i}"]
                original = generation["original_caption"]
                generated = generation["generated_text"]

                text_info = f"Original:\n{original[:100]}{'...' if len(original) > 100 else ''}\n\n"
                text_info += f"Generated (Step {step}):\n{generated[:100]}{'...' if len(generated) > 100 else ''}"

                text_ax = axes[1, i]
                text_ax.text(0.05, 0.95, text_info,
                             transform=text_ax.transAxes,
                             verticalalignment='top',
                             fontsize=8,
                             wrap=True)
                text_ax.set_xlim(0, 1)
                text_ax.set_ylim(0, 1)
                text_ax.axis('off')

            fig.tight_layout()

            save_path = os.path.join(self.tracking_log_dir, "combined", f"step_{step:06d}.png")
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            plt.close(fig)

        # Also save individual generation files
        gen_dir = os.path.join(self.tracking_log_dir, "generations")
        with open(os.path.join(gen_dir, f"step_{step:06d}.json"), 'w') as f:
            json.dump(log_entry, f, indent=2)

    def _log_to_experiment_tracker(self, log_entry):
        # Log to wandb/tensorboard if available
        step = log_entry["step"]
        
        # Log generation examples as text
        for i, gen_data in log_entry["generations"].items():
            # Log text lengths
            self.log(f"generation/text_length_{i}", gen_data["length"], 
                    on_step=True, on_epoch=False)
            
            # Log to wandb table if available
            if hasattr(self.logger, 'log_table'):
                self.logger.log_table(
                    key=f'generation_tracking_step_{step}',
                    columns=['Image', 'Original Caption', 'Generated Text', 'Step'],
                    data=[[i, gen_data['original_caption'], gen_data['generated_text'], step]]
                ) 
