import os
import argparse
from functools import partial
import torch
from torch.amp import autocast
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import timm
from torch.optim import AdamW
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
from data.pairs_dataset.dataset import TemporalKneeFeatureConditioningDataset, NUM_TREATMENTS, FEATURE_DIM, PERCEPTUAL_FEATURES, TREATMENT_LABELS


# --- Constants ---
TREATMENT_NAMES = [*TREATMENT_LABELS.keys()]  # treatment names, in TREATMENT_LABELS key order
# delta_t normalisation statistics (presumably mean/std of delta_t over the training pairs — TODO confirm source)
DELTA_T_MEAN = 35.0731
DELTA_T_STD = 22.9570

# Allow float32 matmuls to use faster reduced-precision internal math (e.g. TF32) where supported.
torch.set_float32_matmul_precision('high')

# --- Configuration via ArgParse ---
def parse_args():
    """Parse command-line options and derive run-specific names and paths.

    In ``--test_only`` mode, ``--ckpt_path`` must name an existing file (otherwise
    ``parser.error`` exits the program). ``run_name`` becomes
    ``TEST_<checkpoint basename>``. No directories are created, and
    ``checkpoint_dir``/``log_dir`` are not set.

    Otherwise, ``run_name`` encodes the main hyperparameters. ``checkpoint_dir``
    (created on disk) and ``log_dir`` are derived from the base directories and
    ``run_name``.

    Returns:
        argparse.Namespace with the parsed options plus ``run_name`` (and, in
        training mode, ``checkpoint_dir`` and ``log_dir``).
    """
    parser = argparse.ArgumentParser(description="Train Temporal Adversarial Diffusion Model with RNN Discriminator")
    # --- Mode ---
    parser.add_argument("--test_only", action="store_true", help="Run in testing mode only")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint (.ckpt) for testing")
    parser.add_argument("--results_file", type=str, default="test_results.txt", help="File to save test metrics")
    # Paths
    parser.add_argument("--pairs_dir", type=str, default="data/pairs_dataset", help="Dir with train.csv/val.csv")
    parser.add_argument("--base_checkpoint_dir", type=str, default="/local/acc/OAI_checkpoints", help="Base dir for saving models")
    parser.add_argument("--log_dir_base", type=str, default="/local/acc/OAI_checkpoints/tb_logs", help="Base dir for TensorBoard logs")
    # Model Params
    parser.add_argument("--diffusion_model_pretrain", type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument("--feature_predictor_model", type=str, default='efficientformerv2_l.snap_dist_in1k')
    # Training Params
    parser.add_argument("--lr_generator", type=float, default=1e-5, help="Learning rate for diffusion generator")
    parser.add_argument("--lr_discriminator", type=float, default=1e-4, help="Learning rate for treatment discriminator")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--feature_loss_weight", type=float, default=0, help="Weight for perceptual feature loss")
    parser.add_argument("--precision", type=str, default="16-mixed", choices=["16-mixed", "bf16-mixed", "32-true"])
    parser.add_argument("--devices", type=int, nargs='+', default=[0,1], help="GPU devices to use")
    parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="Value for gradient clipping (0 to disable)")

    # --- Adversarial Params ---
    parser.add_argument("--adversarial_weight", type=float, default=0.1, help="Weight for adversarial loss")
    # --- Discriminator Architecture (RNN-based) ---
    # NOTE(review): "transformer" is accepted here, but ContextEncoderRNN raises ValueError for
    # anything other than "rnn".
    parser.add_argument("--disc_model_type", type=str, choices=["rnn", "transformer"], default="rnn", help="Discriminator sequence model type")
    parser.add_argument("--disc_hidden_dim", type=int, default=128, help="Discriminator RNN/Transformer hidden dim")
    parser.add_argument("--disc_num_layers", type=int, default=2, help="Discriminator RNN/Transformer num layers")
    parser.add_argument("--disc_dropout", type=float, default=0.2, help="Discriminator RNN/Transformer dropout")
    parser.add_argument("--disc_delta_t_feat_dim", type=int, default=8, help="Discriminator projected delta_t feature dim")
    parser.add_argument("--disc_side_feat_dim", type=int, default=4, help="Discriminator projected side feature dim")
    # NOTE(review): with the defaults, 128 + 628 + 8 + 4 = 768, so the context vector already has
    # a 768-wide output; the help text's "usually hidden_size" does not describe this default — confirm.
    parser.add_argument("--disc_current_img_cond_dim", type=int, default=628, help="Dim of img_cond from current input image (usually text_encoder.config.hidden_size)")
    parser.add_argument("--use_disc_pos_weight", action="store_true", help="Use class weight for the discriminator based on training label prevalence.")
    # Other
    parser.add_argument("--run_suffix", type=str, default="a", help="Suffix for run name")

    args = parser.parse_args()
    if args.test_only:
        if not args.ckpt_path or not os.path.exists(args.ckpt_path):
            parser.error("--ckpt_path is required and must exist for --test_only mode.")
        args.run_name = f"TEST_{os.path.basename(args.ckpt_path)}"
    else:
        # The run name encodes the hyperparameters that distinguish experiments.
        args.run_name = f"{args.run_suffix}_flw{args.feature_loss_weight}_lrg{args.lr_generator}_lrd{args.lr_discriminator}_aw{args.adversarial_weight}_pw{args.use_disc_pos_weight}_gc{args.gradient_clip_val}_b{args.batch_size}x{len(args.devices)}_dnl{args.disc_num_layers}_dhd{args.disc_hidden_dim}_icd{args.disc_current_img_cond_dim}"
        args.checkpoint_dir = os.path.join(args.base_checkpoint_dir, args.run_name)
        args.log_dir = os.path.join(args.log_dir_base, args.run_name)
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    return args

# --- Image Conditioning Module ---
class ImageConditioning(nn.Module):
    """Turn a latent feature map into a single conditioning vector.

    The latent is averaged over its last two dimensions (presumably the
    spatial H and W axes). The pooled ``in_channels`` values are then mapped to
    ``hidden_dim`` by a single linear layer, stored as ``self.linear``.
    """

    def __init__(self, in_channels=4, hidden_dim=768):
        super().__init__()
        # hidden_dim is meant to match the text-encoder embedding width.
        self.linear = nn.Linear(in_channels, hidden_dim)

    def forward(self, latent):
        pooled = torch.mean(latent, dim=[2, 3])
        return self.linear(pooled)

class ContextEncoderRNN(nn.Module):
    """Encode the treatment/covariate history plus the current-interval context into one vector.

    The output concatenates, along dim 1:
      * the final hidden state of the top LSTM layer over the history sequence,
      * the current image conditioning vector,
      * a ReLU-projected, normalised ``delta_t`` feature,
      * a ReLU-projected ``side`` feature.
    ``self.output_dim`` holds the total width.

    Args:
        args: namespace providing ``disc_model_type``, ``disc_hidden_dim``,
            ``disc_num_layers``, ``disc_dropout``, ``disc_delta_t_feat_dim`` and
            ``disc_side_feat_dim``.
        current_img_cond_dim: width of ``img_cond_current`` passed to ``forward``.
        cov_hist_dim: per-step covariate width of the history sequence.
        num_treatments_hist: per-step treatment width of the history sequence.

    Raises:
        ValueError: if ``args.disc_model_type`` is not ``"rnn"`` (case-insensitive).
    """

    def __init__(self, args, current_img_cond_dim, cov_hist_dim, num_treatments_hist):
        super().__init__()
        self.model_type = args.disc_model_type
        self.delta_t_feat_dim = args.disc_delta_t_feat_dim
        self.side_feat_dim = args.disc_side_feat_dim

        # Each history step is the concatenation of covariates and treatments.
        seq_input_dim = cov_hist_dim + num_treatments_hist

        if self.model_type.lower() == "rnn":
            # Inter-layer dropout is only passed for stacked (multi-layer) LSTMs.
            self.hist_encoder = nn.LSTM(seq_input_dim, args.disc_hidden_dim, args.disc_num_layers,
                                        batch_first=True, dropout=args.disc_dropout if args.disc_num_layers > 1 else 0)
            self.hist_encoder_output_dim = args.disc_hidden_dim
        else:
            raise ValueError(f"Unsupported context encoder model type: {self.model_type}")

        self.delta_t_processor = nn.Sequential(nn.Linear(1, self.delta_t_feat_dim), nn.ReLU())
        self.side_processor = nn.Sequential(nn.Linear(1, self.side_feat_dim), nn.ReLU())

        self.output_dim = (self.hist_encoder_output_dim
                           + current_img_cond_dim
                           + self.delta_t_feat_dim
                           + self.side_feat_dim)

    def forward(self, img_cond_current,
                cov_seq_hist, trt_seq_hist, lengths_hist,
                delta_t, side, image_seq_hist=None,
                delta_t_mean=0.0, delta_t_std=1.0):
        """Build the rich context vector.

        Args:
            img_cond_current: ``(B, D)`` conditioning from the current image; a 3-D
                ``(B, S, D)`` input is averaged over dim 1.
            cov_seq_hist: ``(B, T, cov_hist_dim)`` padded history covariates.
            trt_seq_hist: ``(B, T, num_treatments_hist)`` padded history treatments.
            lengths_hist: ``(B,)`` valid history lengths; moved to CPU for packing.
            delta_t: per-sample interval, shape ``(B, 1)`` or ``(B,)``.
            side: per-sample side indicator, shape ``(B, 1)`` or ``(B,)``.
            image_seq_hist: unused; accepted for interface compatibility.
            delta_t_mean, delta_t_std: normalisation applied to ``delta_t``.

        Returns:
            ``(B, self.output_dim)`` tensor.
        """
        combined_hist_seq = torch.cat([cov_seq_hist, trt_seq_hist], dim=2)
        packed_input = pack_padded_sequence(combined_hist_seq, lengths_hist.cpu(), batch_first=True, enforce_sorted=False)
        self.hist_encoder.flatten_parameters()
        _, (h_n, _) = self.hist_encoder(packed_input)
        hist_context_vector = h_n[-1]  # final hidden state of the top LSTM layer

        # Accept per-sample scalars given as (B,) as well as (B, 1): Linear(1, ...) needs a trailing size-1 dim.
        if delta_t.ndim == 1:
            delta_t = delta_t.unsqueeze(-1)
        if side.ndim == 1:
            side = side.unsqueeze(-1)

        delta_t_norm = (delta_t - delta_t_mean) / (delta_t_std + 1e-8)
        delta_t_features = self.delta_t_processor(delta_t_norm)
        side_features = self.side_processor(side)

        if img_cond_current.ndim == 3:
            img_cond_current = img_cond_current.mean(dim=1)

        return torch.cat([
            hist_context_vector,
            img_cond_current,  # features from the current image
            delta_t_features,
            side_features,
        ], dim=1)

class TreatmentDiscriminator(nn.Module):
    """Two-layer MLP that maps a rich context vector to one logit per treatment.

    The layout is Linear -> ReLU -> Dropout -> Linear, stored as
    ``self.treatment_head``. The submodule indices are part of the checkpoint
    keys. The hidden width and dropout come from ``args.disc_hidden_dim`` and
    ``args.disc_dropout``; the output width is ``NUM_TREATMENTS``.
    """

    def __init__(self, args, input_rich_context_dim):
        super().__init__()
        self.num_treatments_output = NUM_TREATMENTS
        hidden = args.disc_hidden_dim
        layers = [
            nn.Linear(input_rich_context_dim, hidden),
            nn.ReLU(),
            nn.Dropout(args.disc_dropout),
            nn.Linear(hidden, self.num_treatments_output),
        ]
        self.treatment_head = nn.Sequential(*layers)
        print(f"Simple TreatmentDiscriminator input dim: {input_rich_context_dim}, output logits: {self.num_treatments_output}")

    def forward(self, rich_context_vector):
        logits = self.treatment_head(rich_context_vector)
        return logits

# --- Diffusion Lightning Module ---
class AdversarialDiffusionLitModule(pl.LightningModule):
    def __init__(self, args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 delta_t_mean, delta_t_std, discriminator_pos_weight=None):
        """Assemble the conditioned diffusion generator and the treatment discriminator.

        Args:
            args: parsed CLI namespace. It is saved as hyperparameters and read for
                learning rates, loss weights and the context-encoder/discriminator sizes.
            vae, text_encoder: Stable Diffusion components, frozen by ``_freeze_components``.
            unet: noise-prediction network. Its parameters are toggled together with
                the other generator parameters in ``toggle_optimizer``.
            scheduler: noise scheduler (``add_noise``, ``alphas_cumprod``, ``config``).
            tokenizer: tokenizer for the text prompts.
            img_conditioning: module mapping input-image latents to a conditioning
                vector. It is toggled with the generator parameters.
            delta_t_mean, delta_t_std: normalisation constants forwarded to the context
                encoder for ``delta_t``.
            discriminator_pos_weight: optional ``pos_weight`` tensor for the
                discriminator's ``BCEWithLogitsLoss``. Presumably it has one entry per
                treatment — TODO confirm.
        """
        super().__init__()
        # NOTE(review): with a single Namespace argument, the `ignore` list most likely has no
        # effect on what is saved (the hparams are the CLI args) — confirm.
        self.save_hyperparameters(args, ignore=['vae', 'unet', 'scheduler', 'text_encoder', 'tokenizer',
                                                 'img_conditioning', 'feature_predictors', 'context_encoder'])
        self.vae = vae; self.unet = unet; self.scheduler = scheduler
        self.text_encoder = text_encoder; self.tokenizer = tokenizer
        self.img_conditioning = img_conditioning  # applied to the input image's latents in forward()
        # Encodes covariate/treatment history + current image conditioning + delta_t + side.
        self.context_encoder = ContextEncoderRNN(
            args=args,
            current_img_cond_dim=args.disc_current_img_cond_dim,
            cov_hist_dim=FEATURE_DIM,
            num_treatments_hist=NUM_TREATMENTS
        )
        # Width of the U-Net's cross-attention inputs (taken to equal the text embedding width).
        unet_cross_attn_dim = text_encoder.config.hidden_size

        # forward() adds the rich context to every text token, so it is projected to the
        # text width unless the two widths already agree.
        if self.context_encoder.output_dim != unet_cross_attn_dim:
            self.rich_context_projector = nn.Linear(self.context_encoder.output_dim, unet_cross_attn_dim)
        else:
            self.rich_context_projector = nn.Identity()

        # The discriminator consumes the unprojected context_encoder output.
        self.treatment_discriminator = TreatmentDiscriminator(
            args=args,
            input_rich_context_dim=self.context_encoder.output_dim
        )
        # Kept as a plain attribute. When given, it is also registered as a buffer so that it
        # moves with the module. NOTE(review): BCEWithLogitsLoss keeps its own copy of
        # pos_weight, so it may appear twice in the state dict — confirm this is intended.
        self.discriminator_pos_weight = discriminator_pos_weight
        if self.discriminator_pos_weight is not None:
             self.register_buffer("disc_pos_weight_buffer", self.discriminator_pos_weight)
             self.criterion_discriminator = nn.BCEWithLogitsLoss(pos_weight=self.disc_pos_weight_buffer)
        else:
             # No class weighting (no weights were provided, e.g. their calculation failed).
             self.criterion_discriminator = nn.BCEWithLogitsLoss()

        # Store params from args
        self.latent_scale = 0.18215  # multiplied in after VAE encode, divided out before VAE decode
        self.lr_generator = args.lr_generator
        self.lr_discriminator = args.lr_discriminator
        self.adversarial_weight = args.adversarial_weight  # weight (lambda) of the generator's confusion loss
        self.feature_loss_weight = args.feature_loss_weight

        # delta_t normalisation constants, passed to the context encoder in forward()
        self.delta_t_mean = delta_t_mean
        self.delta_t_std = delta_t_std

        # Freezes the VAE and text encoder now. The feature predictors do not exist yet;
        # _load_feature_predictors freezes them itself.
        self._freeze_components()
        self.transform_to_pil = T.ToPILImage()
        self.feature_predictors = nn.ModuleDict()
        self._load_feature_predictors()
        self.pipe = None  # img2img pipeline, built lazily in setup()
        self.automatic_optimization = False  # training_step drives both optimizers manually
        self._test_outputs = []  # presumably filled by test hooks defined later in the class — TODO confirm

    def _freeze_components(self):
        """Switch the frozen sub-models to eval mode and disable their gradients.

        This affects ``vae``, ``text_encoder`` and every module in
        ``feature_predictors``, each only if that attribute already exists. In
        ``__init__`` it runs before ``feature_predictors`` is created, so there it
        only freezes the VAE and the text encoder; ``_load_feature_predictors``
        freezes each predictor itself.
        """
        for name in ('vae', 'text_encoder'):
            if hasattr(self, name):
                module = getattr(self, name)
                module.eval()
                module.requires_grad_(False)
        # Ensure feature predictors remain frozen
        if hasattr(self, 'feature_predictors'):
            for predictor in self.feature_predictors.values():
                predictor.eval()
                predictor.requires_grad_(False)
        print("VAE, Text Encoder, and Feature Predictors frozen.")

    def _load_feature_predictors(self, checkpoint_root="checkpoints"):
        """Load one frozen timm classifier per entry of ``PERCEPTUAL_FEATURES``.

        For each feature ``feat``, the state dict is read on CPU from
        ``<checkpoint_root>/<feat>/<feat>_model.pth``. The classifier head is resized
        to ``state_dict['head.weight'].shape[0]`` outputs before a strict
        ``load_state_dict``. Each predictor is put in eval mode with gradients
        disabled and stored in ``self.feature_predictors[feat]``.

        The architecture comes from the ``feature_predictor_model`` hyperparameter
        (the ``--feature_predictor_model`` CLI option). It falls back to
        ``'efficientformerv2_l.snap_dist_in1k'`` when that hyperparameter is absent.

        Args:
            checkpoint_root: directory containing the per-feature checkpoint folders.

        Errors from ``torch.load`` / ``load_state_dict`` (missing file, mismatched
        keys or shapes) are not handled here and propagate to the caller.
        """
        print("Loading and freezing feature predictors...")
        model_name = getattr(self.hparams, 'feature_predictor_model', 'efficientformerv2_l.snap_dist_in1k')
        for feat in PERCEPTUAL_FEATURES:
            model_path = os.path.join(checkpoint_root, feat, f"{feat}_model.pth")
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            state_dict = torch.load(model_path, map_location="cpu")
            num_outputs = state_dict['head.weight'].shape[0]
            # pretrained=False: the strict load below overwrites every weight, so downloading
            # ImageNet weights first only costs time and requires network access.
            predictor = timm.create_model(model_name, pretrained=False)
            predictor.reset_classifier(num_outputs)
            predictor.load_state_dict(state_dict)
            predictor.eval()
            predictor.requires_grad_(False)
            self.feature_predictors[feat] = predictor

    def toggle_optimizer(self, optimizer):
        """Enable gradients only for the side of the GAN that ``optimizer`` updates.

        ``self.optimizers()[1]`` is treated as the discriminator optimizer; any other
        optimizer counts as the generator optimizer. The generator side is the U-Net,
        ``img_conditioning``, ``context_encoder`` and (unless it is ``nn.Identity``)
        ``rich_context_projector``. The discriminator side is
        ``treatment_discriminator``. Other modules (VAE, text encoder, feature
        predictors) are left untouched.

        NOTE(review): this overrides LightningModule.toggle_optimizer without recording
        state for ``untoggle_optimizer``. In the visible code it is only called from
        ``training_step``.
        """
        disc_active = optimizer == self.optimizers()[1]

        generator_modules = [self.unet, self.img_conditioning, self.context_encoder]
        if hasattr(self, 'rich_context_projector') and not isinstance(self.rich_context_projector, nn.Identity):
            generator_modules.append(self.rich_context_projector)

        for module in generator_modules:
            module.requires_grad_(not disc_active)
        self.treatment_discriminator.requires_grad_(disc_active)

    def setup(self, stage=None):
        """Lightning hook: move the feature predictors and build the img2img pipeline once.

        The pipeline reuses this module's VAE, U-Net, scheduler, text encoder and
        tokenizer, with no safety checker or feature extractor. It is created only
        while ``self.pipe`` is still ``None``. The completion message is printed on
        every call.
        """
        target_device = self.device
        for predictor in self.feature_predictors.values():
            predictor.to(target_device)

        if self.pipe is None:
            pipeline = StableDiffusionImg2ImgPipeline(
                vae=self.vae,
                unet=self.unet,
                scheduler=self.scheduler,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                safety_checker=None,
                feature_extractor=None,
            )
            self.pipe = pipeline.to(target_device)
        print("Inference pipeline setup complete.")

    def forward(self, input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side):
        """Run one noise-prediction pass of the context-conditioned U-Net.

        The frozen VAE encodes ``input_images`` and ``target_images``. The target
        latents are noised at one random timestep per sample. The U-Net is
        conditioned on the text embeddings plus the projected rich context,
        broadcast over the token axis, and predicts the added noise.

        Args:
            input_images: images whose latents feed ``img_conditioning``.
            target_images: images whose latents are noised (the diffusion target).
            prompts: prompt strings for the tokenizer.
            cov_seq_hist, trt_seq_hist, lengths_hist: padded history sequences and
                lengths, forwarded to the context encoder.
            delta_t, side: per-sample interval context, forwarded to the context encoder.

        Returns:
            Tuple ``(pred_noise, noise, noisy_latents, timesteps, target_latents,
            rich_context)``. The last two are cast to ``input_images.dtype``.
        """
        input_dtype = input_images.dtype
        # Match autocast to the trainer precision. autocast('cuda') without an explicit dtype
        # uses float16, which would silently override bf16 autocast under "bf16-mixed".
        precision = str(self.trainer.precision)
        use_bf16 = precision.startswith("bf16")
        use_amp = precision.startswith("16") or use_bf16
        amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
        with autocast('cuda', dtype=amp_dtype, enabled=use_amp):
            input_latents = self.vae.encode(input_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            target_latents = self.vae.encode(target_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            # One random diffusion timestep per sample; the target latents become x_t.
            timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (input_latents.size(0),), device=self.device).long()
            noise = torch.randn_like(target_latents)
            noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)
            text_inputs = self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt").input_ids.to(self.device)
            text_embeddings = self.text_encoder(text_inputs)[0]  # [B, SeqLen, EmbDim]

            # Conditioning from the *input* image latents (not the noised target).
            img_cond_current = self.img_conditioning(input_latents)

            # Rich context: history encoding + current image + delta_t + side.
            rich_context = self.context_encoder(
                img_cond_current,
                cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side,
                image_seq_hist=None,  # historical images are not used by the context encoder
                delta_t_mean=self.delta_t_mean, delta_t_std=self.delta_t_std,
            )
            projected_rich_context = self.rich_context_projector(rich_context)  # [B, cross-attn dim]

            # Add the projected context to every text-token embedding.
            unet_conditioning = text_embeddings + projected_rich_context.unsqueeze(1)
            pred_noise = self.unet(noisy_latents.to(self.unet.dtype), timesteps, encoder_hidden_states=unet_conditioning.to(self.unet.dtype)).sample

        # rich_context is returned so the discriminator can consume it in training/validation.
        return (pred_noise, noise, noisy_latents, timesteps,
                target_latents.to(input_dtype), rich_context.to(input_dtype))


    def _calculate_predicted_x0(self, pred_noise, timesteps, noisy_latents):
        """Invert the forward-noising step to estimate the clean latent x0.

        x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / (sqrt(alpha_bar_t) + 1e-8)

        ``alpha_bar_t`` is the scheduler's ``alphas_cumprod``, indexed per sample by
        ``timesteps``. The per-sample coefficients are reshaped to ``(B, 1, 1, 1)``,
        so the latents are expected to be 4-D.
        """
        alpha_bar = self.scheduler.alphas_cumprod.to(pred_noise.device)[timesteps]
        signal_coef = (alpha_bar ** 0.5).reshape(-1, 1, 1, 1)
        noise_coef = ((1 - alpha_bar) ** 0.5).reshape(-1, 1, 1, 1)
        return (noisy_latents - noise_coef * pred_noise) / (signal_coef + 1e-8)


    def training_step(self, batch, batch_idx):
        """One manual-optimization step: a generator update, then a discriminator update.

        Generator (``opt_g``) loss:
            MSE(predicted noise, true noise)
            + ``feature_loss_weight`` * perceptual feature loss (see
              ``_compute_train_feature_loss``)
            + ``adversarial_weight`` * confusion loss from the treatment discriminator,
              whose parameters are frozen by ``toggle_optimizer(opt_g)``.

        Discriminator (``opt_d``) loss: ``criterion_discriminator`` (BCE with logits)
        between its logits on the detached rich context and ``interval_labels``.

        A NaN in the noise prediction or in any loss aborts the step: all gradients
        are zeroed and ``None`` is returned, so the remaining updates are skipped.

        Returns:
            Always ``None`` (manual optimization).
        """
        opt_g, opt_d = self.optimizers()
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        # Explicitly put every trainable component in train mode.
        self.unet.train()
        self.img_conditioning.train()
        self.context_encoder.train()
        if hasattr(self, 'rich_context_projector') and not isinstance(self.rich_context_projector, nn.Identity):
            self.rich_context_projector.train()
        self.treatment_discriminator.train()

        # --- Generator update ---
        self.toggle_optimizer(opt_g)
        pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context_g = \
            self(input_images, target_images, prompts,
                 cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side)

        if torch.isnan(pred_noise).any():
            print("FATAL TRAIN GEN: NaN in pred_noise")
            self.zero_grad()
            return None

        diffusion_loss = F.mse_loss(pred_noise, noise)
        if torch.isnan(diffusion_loss):
            print("FATAL TRAIN GEN: NaN in diffusion_loss")
            self.zero_grad()
            return None

        feature_loss = self._compute_train_feature_loss(pred_noise, timesteps, noisy_latents, later_features)

        # Confusion loss: mean negative log-softmax over treatments, i.e. cross-entropy
        # against a uniform target. The discriminator weights are frozen for this step, so
        # gradients only reach the generator-side modules that produced rich_context_g.
        # NOTE(review): the discriminator itself is trained with multi-label BCE (independent
        # sigmoids). A softmax confusion loss ignores a shared logit offset and does not push
        # the sigmoid outputs toward 0.5 — confirm this pairing is intended.
        treatment_logits_adv = self.treatment_discriminator(rich_context_g)
        adversarial_loss_g = -F.log_softmax(treatment_logits_adv, dim=1).mean()

        generator_loss = (diffusion_loss
                          + self.hparams.feature_loss_weight * feature_loss
                          + self.hparams.adversarial_weight * adversarial_loss_g)
        if torch.isnan(adversarial_loss_g):
            print("FATAL GEN: NaN adv_loss_g")
            self.zero_grad()
            return None
        if torch.isnan(generator_loss):
            print("FATAL GEN: NaN generator_loss")
            self.zero_grad()
            return None
        opt_g.zero_grad()
        self.manual_backward(generator_loss)
        if self.hparams.gradient_clip_val > 0:  # 0 disables clipping
            self.clip_gradients(opt_g, gradient_clip_val=self.hparams.gradient_clip_val, gradient_clip_algorithm="norm")
        opt_g.step()

        bs = input_images.size(0)
        self.log_dict({
            "train_gen_loss": generator_loss.detach(),
            "train_diffusion_loss": diffusion_loss.detach(),
            "train_feature_loss": feature_loss.detach(),
            "train_adv_loss_g": adversarial_loss_g.detach(),
        }, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)

        # --- Discriminator update ---
        self.toggle_optimizer(opt_d)
        # Detach so the discriminator loss cannot reach the generator; the history inputs come
        # straight from the batch and carry no graph.
        rich_context_d = rich_context_g.detach()

        treatment_logits_d = self.treatment_discriminator(rich_context_d)
        discriminator_loss = self.criterion_discriminator(treatment_logits_d, interval_labels.float())
        if torch.isnan(discriminator_loss):
            print("FATAL DISC: NaN discriminator_loss")
            self.zero_grad()
            return None

        opt_d.zero_grad()
        self.manual_backward(discriminator_loss)
        if self.hparams.gradient_clip_val > 0:  # 0 disables clipping
            self.clip_gradients(opt_d, gradient_clip_val=self.hparams.gradient_clip_val, gradient_clip_algorithm="norm")
        opt_d.step()
        self.log("train_disc_loss", discriminator_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)
        return None

    def _compute_train_feature_loss(self, pred_noise, timesteps, noisy_latents, later_features):
        """Return the mean cross-entropy of the frozen feature predictors on the predicted x0.

        The one-step x0 estimate is decoded with the VAE, mapped from [-1, 1] to
        ImageNet-normalised inputs, and scored by each predictor against the matching
        column of ``later_features``. Targets that are negative or not below the
        number of classes are ignored. The result stays attached to the graph so it
        back-propagates into the generator.

        Returns a zero tensor when the loss is disabled, when no predictor contributes,
        on NaN/Inf intermediates (with a warning), or when the computation raises (the
        error is printed and ``train_feat_loss_err`` is logged). Per-feature errors only
        skip that feature.
        """
        feature_loss = torch.tensor(0.0, device=self.device)
        if not (self.feature_loss_weight > 0 and len(self.feature_predictors) > 0):
            return feature_loss
        try:
            # Same precision handling as forward(): an explicit dtype is needed for bf16.
            precision = str(self.trainer.precision)
            use_bf16 = precision.startswith("bf16")
            use_amp = precision.startswith("16") or use_bf16
            amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any():
                print("WARN TRAIN GEN: NaN/Inf in pred_latents for feature loss")
                return feature_loss
            generated_images = self.vae.decode(pred_latents / self.latent_scale).sample
            if torch.isnan(generated_images).any():
                print("WARN TRAIN GEN: NaN in generated_images for feature loss")
                return feature_loss

            f_loss_total = 0.0
            f_counts = 0
            gen_images_0_to_1 = generated_images * 0.5 + 0.5
            imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            for feat_name, predictor in self.feature_predictors.items():
                try:
                    with autocast('cuda', dtype=amp_dtype, enabled=use_amp):
                        feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                        gt = later_features[:, feat_idx].long().to(self.device)
                        valid_mask = (gt != -1) & (gt >= 0)
                        if valid_mask.sum() == 0:
                            continue
                        # Cast inputs to the predictor's weight dtype.
                        head_layer = getattr(predictor, 'head', getattr(predictor, 'fc', getattr(predictor, 'classifier', None)))
                        head_dtype = head_layer.weight.dtype
                        pred_logits = predictor(imagenet_norm(gen_images_0_to_1.to(head_dtype)).to(head_dtype))
                        if torch.isnan(pred_logits).any():
                            continue
                        valid_mask &= (gt < pred_logits.shape[-1])  # GT must be a valid class index
                        if valid_mask.sum() == 0:
                            continue
                        current_loss = F.cross_entropy(pred_logits[valid_mask].float(), gt[valid_mask])
                        if not torch.isnan(current_loss):
                            f_loss_total += current_loss
                            f_counts += 1
                except (ValueError, IndexError, KeyError):
                    print(f"Warn: Skipping feature {feat_name} in feature loss calculation.")
                    continue
                except Exception as pred_e:
                    print(f"Err predictor {feat_name}: {pred_e}")
            if f_counts > 0:
                feature_loss = f_loss_total / f_counts
        except Exception as e:
            print(f"Err train feat loss block: {e}")
            self.log("train_feat_loss_err", 1.0, sync_dist=True)
        if torch.isnan(feature_loss):
            feature_loss = torch.tensor(0.0, device=self.device)
        return feature_loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses, discriminator metrics, ITE metrics and sample images.

        Epoch-level metrics, synced across devices:
          * ``val_loss`` / ``val_mae_loss``: MSE / L1 between predicted and true noise
            (1000.0 is substituted for NaN).
          * ``val_disc_loss`` / ``val_disc_acc`` / ``val_disc_auc``: discriminator BCE,
            element-wise accuracy at threshold 0.5, and mean per-treatment ROC-AUC over
            the treatments whose labels vary within the batch (0.0 if none do).
          * ``val_gen_loss``: MSE between the VAE-decoded one-step x0 estimate and the
            target images.
          * ``val_ite_<feat>`` / ``val_model_ite_<feat>``: mean absolute difference
            between the feature change predicted on generated images and, respectively,
            the ground-truth change and the change the predictor sees on the real target.

        Returns early, with nothing logged, if the noise prediction contains NaN. On
        rank 0, batch 2 also logs qualitative images.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)
        bs = input_images.size(0)

        pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context_val = \
            self(input_images, target_images, prompts,
                 cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side)

        if torch.isnan(pred_noise).any():
            print("WARN VAL: NaN pred_noise")
            return

        # --- Unweighted noise-prediction losses ---
        val_loss = F.mse_loss(pred_noise, noise)
        if torch.isnan(val_loss):
            print("WARN VAL: NaNs in val_loss")
            val_loss = torch.tensor(1000.0, device=self.device)
        self.log("val_loss", val_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)

        val_mae_loss = F.l1_loss(pred_noise, noise)
        if torch.isnan(val_mae_loss):
            print("WARN VAL: NaNs in val_mae_loss")
            val_mae_loss = torch.tensor(1000.0, device=self.device)
        self.log("val_mae_loss", val_mae_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)

        # --- Discriminator performance ---
        with torch.no_grad():
            disc_was_training = self.treatment_discriminator.training
            self.treatment_discriminator.eval()
            treatment_logits_val = self.treatment_discriminator(rich_context_val)
            val_disc_loss = self.criterion_discriminator(treatment_logits_val, interval_labels.float())
            probs_val = torch.sigmoid(treatment_logits_val)
            preds_val = (probs_val >= 0.5).int()
            labels_np = interval_labels.cpu().numpy()
            preds_np = preds_val.cpu().numpy()
            # NumPy has no bfloat16, so probabilities from bf16 autocast must be upcast first.
            probs_np = probs_val.float().cpu().numpy()

            # Element-wise accuracy over all (sample, treatment) pairs.
            val_disc_acc = accuracy_score(labels_np.flatten(), preds_np.flatten())
            try:
                # Mean ROC-AUC over treatments whose labels are not constant in this batch.
                aucs_val = [
                    roc_auc_score(labels_np[:, i], probs_np[:, i])
                    for i in range(NUM_TREATMENTS)
                    if len(np.unique(labels_np[:, i])) > 1
                ]
                val_disc_auc = np.mean(aucs_val) if aucs_val else 0.0
            except Exception as e:
                print(f"WARN VAL: Could not calculate AUC: {e}")
                val_disc_auc = 0.0

            self.log_dict({
                "val_disc_loss": val_disc_loss,
                "val_disc_acc": val_disc_acc,
                "val_disc_auc": val_disc_auc
            }, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)
            # Restore the previous mode rather than forcing train() in the middle of validation.
            self.treatment_discriminator.train(disc_was_training)

        # --- One-step x0 reconstruction (falls back to the targets on failure) ---
        generation_ok = False
        try:
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any():
                print("WARN VAL: NaN/Inf in pred_latents")
                generated_images = target_images
            else:
                generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
                generation_ok = True
        except Exception:
            print("WARN VAL: Image Generation Failed")
            generated_images = target_images
        if generated_images is None:
            print("generated images are target images")
            generated_images = target_images
            generation_ok = False

        # NOTE(review): when generation failed, generated_images is the target itself, so a
        # 0.0 loss is logged and pulls the epoch mean down.
        val_gen_loss = F.mse_loss(generated_images, target_images)
        if not torch.isnan(val_gen_loss):
            self.log("val_gen_loss", val_gen_loss, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)

        # --- ITE metrics from the frozen feature predictors ---
        if len(self.feature_predictors) > 0 and generation_ok:
            min_expected_earlier_dim = max([PERCEPTUAL_FEATURES.index(f) for f in self.feature_predictors.keys() if f in PERCEPTUAL_FEATURES] + [-1]) + 1
            if earlier_features.shape[1] >= min_expected_earlier_dim:
                # Same precision handling as forward(): an explicit dtype is needed for bf16.
                precision = str(self.trainer.precision)
                use_bf16 = precision.startswith("bf16")
                use_amp = precision.startswith("16") or use_bf16
                amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
                # Map [-1, 1] images to ImageNet-normalised predictor inputs (loop-invariant).
                gen_images_0_to_1 = (generated_images * 0.5) + 0.5
                target_images_0_to_1 = (target_images * 0.5) + 0.5
                imagenet_norm_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                for feat_name, predictor in self.feature_predictors.items():
                    try:
                        feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                        earlier_feat_gt = earlier_features[:, feat_idx].long()
                        later_feat_gt = later_features[:, feat_idx].long()
                        # -1 presumably marks a missing value at either visit — TODO confirm.
                        valid_mask = (earlier_feat_gt != -1) & (later_feat_gt != -1)
                        n_valid = int(valid_mask.sum())
                        if n_valid == 0:
                            continue
                        predictor.to(self.device)
                        with autocast('cuda', dtype=amp_dtype, enabled=use_amp):
                            head_layer = getattr(predictor, 'head', getattr(predictor, 'fc', getattr(predictor, 'classifier', None)))
                            head_dtype = head_layer.weight.dtype
                            pred_feat_logits = predictor(imagenet_norm_transform(gen_images_0_to_1.to(head_dtype)).to(head_dtype))
                            target_feat_logits = predictor(imagenet_norm_transform(target_images_0_to_1.to(head_dtype)).to(head_dtype))

                        if torch.isnan(pred_feat_logits).any():
                            print(f"WARN VAL ITE {feat_name}: NaN in predictor logits.")
                            continue
                        if torch.isnan(target_feat_logits).any():
                            print(f"WARN VAL ITE {feat_name}: NaN in target_feat_logits.")
                            continue

                        # Predicted class indices on generated and on real target images.
                        pred_feat_val = pred_feat_logits.argmax(dim=-1).float()[valid_mask]
                        target_feat_val = target_feat_logits.argmax(dim=-1).float()[valid_mask]

                        # Ground-truth values at both visits.
                        earlier_feat_masked = earlier_feat_gt[valid_mask].float()
                        later_feat_masked = later_feat_gt[valid_mask].float()

                        pred_delta = pred_feat_val - earlier_feat_masked      # change implied by the diffusion model
                        gt_delta = later_feat_masked - earlier_feat_masked    # true change
                        gt_model_delta = target_feat_val - earlier_feat_masked  # change the predictor sees on the real target

                        val_ite_feat = torch.abs(pred_delta - gt_delta).mean()
                        val_model_ite_feat = torch.abs(pred_delta - gt_model_delta).mean()

                        if not torch.isnan(val_ite_feat):
                            self.log(f"val_ite_{feat_name}", val_ite_feat, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=n_valid)
                        if not torch.isnan(val_model_ite_feat):
                            self.log(f"val_model_ite_{feat_name}", val_model_ite_feat, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=n_valid)

                    except (ValueError, IndexError, KeyError):
                        print(f"WARN VAL ITE: Skipping feature {feat_name} due to index/key error.")
                        continue
                    except Exception as e:
                        print(f"Err val ITE {feat_name}: {e}")

        # Log qualitative images
        if self.global_rank == 0 and batch_idx == 2:
            self.log_images_qualitative(input_images[0:16], target_images[0:16], prompts[0:16], suffix="")

    def log_images_qualitative(self, input_imgs, target_imgs, prompts, suffix="", strength=0.75, guidance_scale=7.5):
        """Sample images with the img2img pipeline and log input/target/generated grids.

        Best-effort: any failure is printed and swallowed so it never interrupts the
        surrounding validation loop.

        Args:
            input_imgs: Input image batch; mapped with ``x * 0.5 + 0.5`` and
                ``value_range=(-1, 1)``, i.e. expected in [-1, 1].
            target_imgs: Target image batch, same range as ``input_imgs``.
            prompts: One text prompt per image, passed to ``self.pipe``.
            suffix: Appended to every TensorBoard tag.
            strength: img2img noising strength forwarded to the pipeline.
            guidance_scale: Classifier-free guidance scale forwarded to the pipeline.
        """
        if self.pipe is None: print("Warn: Pipe not ready for qualitative logging."); return
        # Without a logger there is nowhere to write the grids, so skip the expensive
        # 50-step sampling instead of running it and then failing on `.experiment`.
        if self.logger is None:
            print("Warn: No logger attached; skipping qualitative logging.")
            return
        num_images = input_imgs.shape[0]
        print(f"Generating qualitative images ({num_images} samples)...")
        try:
            # The pipeline consumes PIL images, so undo the [-1, 1] normalisation first.
            input_pil = [self.transform_to_pil(((img * 0.5) + 0.5).cpu().float()) for img in input_imgs]
            with torch.no_grad(), autocast('cuda', enabled=False): # Use float32 for inference pipeline
                generated_pil = self.pipe(prompt=prompts, image=input_pil, strength=strength, guidance_scale=guidance_scale, num_inference_steps=50).images

            # PIL -> [0, 1] tensors. They only feed make_grid (CPU), so they never need the GPU.
            to_tensor_transform = T.ToTensor()
            generated_tensors = torch.stack([to_tensor_transform(img) for img in generated_pil])
            input_grid = make_grid(input_imgs.detach().cpu().float(), normalize=True, value_range=(-1, 1))
            target_grid = make_grid(target_imgs.detach().cpu().float(), normalize=True, value_range=(-1, 1))
            generated_grid = make_grid(generated_tensors)

            writer = self.logger.experiment
            writer.add_image(f"Input{suffix}", input_grid, self.current_epoch)
            writer.add_image(f"Target{suffix}", target_grid, self.current_epoch)
            writer.add_image(f"Generated{suffix}", generated_grid, self.current_epoch)
            prompts_text = "\n".join([f"Img {i+1}: {p}" for i, p in enumerate(prompts)])
            writer.add_text(f"Val/Prompts{suffix}", prompts_text, self.current_epoch)
            print("Qualitative images logged.")
        except Exception as e:
            print(f"Warn: Qualitative logging failed: {e}")

    def configure_optimizers(self):
        """Create the generator and discriminator AdamW optimizers.

        Generator parameters come from the UNet, the image-conditioning module and
        the context encoder. The rich-context projector's parameters are appended
        as well when that attribute exists and is not an ``nn.Identity`` placeholder.
        The discriminator optimizer only covers ``treatment_discriminator``.

        Returns:
            list: ``[opt_g, opt_d]``, using ``hparams.lr_generator`` and
            ``hparams.lr_discriminator`` respectively.
        """
        generator_modules = [self.unet, self.img_conditioning, self.context_encoder]
        has_real_projector = hasattr(self, 'rich_context_projector') and not isinstance(self.rich_context_projector, nn.Identity)
        if has_real_projector:
            generator_modules.append(self.rich_context_projector)
        generator_params = [param for module in generator_modules for param in module.parameters()]

        generator_opt = AdamW(generator_params, lr=self.hparams.lr_generator)
        discriminator_opt = AdamW(self.treatment_discriminator.parameters(), lr=self.hparams.lr_discriminator)
        return [generator_opt, discriminator_opt]
    
    def test_step(self, batch, batch_idx):
        input_images=batch["input_image"]; target_images=batch["target_image"]; prompts=batch["prompt"]
        earlier_features=batch["earlier_features"]; later_features=batch["later_features"]; interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist=batch["cov_seq"].to(self.device); trt_seq_hist=batch["trt_seq"].to(self.device)
        lengths_hist=batch["lengths"]; delta_t=batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)
        # Mask for samples that received treatment (last label == 0 means treatment occurred)
        treat_mask = (interval_labels[:, -1] == 0)
        metrics = {}
        with torch.no_grad(): # Ensure no gradients during testing
            pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context_test = self(
                input_images, target_images, prompts,
                cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side
            )

        if torch.isnan(pred_noise).any(): raise ValueError("NaN in pred_noise")

        # --- Calculate Losses ---
        loss_mse = F.mse_loss(pred_noise, noise)
        loss_mae = F.l1_loss(pred_noise, noise)
        metrics['test_loss_mse'] = loss_mse.item()
        metrics['test_loss_mae'] = loss_mae.item()

        # Compute per-sample losses for treated subset
        per_sample_mse = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        per_sample_mae = F.l1_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_loss_mse_treated'] = per_sample_mse[treat_mask].mean().item()
            metrics['test_loss_mae_treated'] = per_sample_mae[treat_mask].mean().item()

        # --- Generate Images for Gen Loss and ITE ---
        generated_images = None
        with torch.no_grad(): # Ensure VAE decode is without grad
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                self.vae.eval() # Ensure VAE is in eval mode
                # Use autocast consistent with validation/training if applicable
                with autocast('cuda', enabled=self.trainer.precision.startswith("16") or self.trainer.precision.startswith("bf16")):
                    generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            if generated_images is None or torch.isnan(generated_images).any(): generated_images = target_images # Fallback

        val_gen_loss = F.mse_loss(generated_images, target_images)
        metrics['test_gen_loss'] = val_gen_loss.item()

        per_sample_gen = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_gen_loss_treated'] = per_sample_gen[treat_mask].mean().item()

        # --- Calculate ITE Metrics ---
        if len(self.feature_predictors) > 0 and not torch.equal(generated_images, target_images):
            min_expected_earlier_dim = 0
            try:
                min_expected_earlier_dim = max([PERCEPTUAL_FEATURES.index(f) for f in self.feature_predictors.keys() if f in PERCEPTUAL_FEATURES] + [-1]) + 1
            except ValueError:
                print("WARN ITE: Mismatch between feature_predictors keys and PERCEPTUAL_FEATURES list.")

            if earlier_features.shape[1] >= min_expected_earlier_dim:
                # Define ImageNet normalization here for clarity
                imagenet_norm_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

                for feat_name, predictor in self.feature_predictors.items():
                    try:
                        predictor.to(self.device)
                        feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                        earlier_feat_gt = earlier_features[:, feat_idx].long().to(self.device)
                        later_feat_gt = later_features[:, feat_idx].long().to(self.device)
                        # --- Replacement block for ITE metrics ---
                        # Base validity mask (all samples with valid ground truth)
                        base_mask = (earlier_feat_gt != -1) & (later_feat_gt != -1) & (later_feat_gt != -1.0)
                        if base_mask.sum() == 0:
                            continue
                        # Treated subset mask
                        treated_mask = base_mask & treat_mask
                        # Prepare images for predictor (denormalize [-1,1] -> [0,1], then ImageNet norm)
                        gen_images_0_to_1 = (generated_images.to(self.device) * 0.5) + 0.5
                        target_images_0_to_1 = (target_images.to(self.device) * 0.5) + 0.5
                        with torch.no_grad(), autocast('cuda', enabled=self.trainer.precision.startswith("16") or self.trainer.precision.startswith("bf16")):
                            head_layer = getattr(predictor, 'head', getattr(predictor, 'fc', getattr(predictor, 'classifier', None)))
                            if head_layer is None:
                                print(f"WARN ITE: Could not find head layer for predictor {feat_name}")
                                continue
                            processed_gen_images = imagenet_norm_transform(gen_images_0_to_1.to(head_layer.weight.dtype))
                            processed_target_images = imagenet_norm_transform(target_images_0_to_1.to(head_layer.weight.dtype))
                            pred_feat_logits = predictor(processed_gen_images)
                            target_feat_logits = predictor(processed_target_images)
                        if torch.isnan(pred_feat_logits).any() or torch.isnan(target_feat_logits).any():
                            print(f"WARN ITE {feat_name}: NaN in predictor logits.")
                            continue
                        # Get predicted values and ground truth for all valid samples (base_mask)
                        pred_vals = pred_feat_logits.argmax(dim=-1).float()
                        target_vals = target_feat_logits.argmax(dim=-1).float()
                        gt_vals = earlier_feat_gt.float()
                        # Deltas for all valid
                        pred_delta_all = pred_vals[base_mask] - gt_vals[base_mask]
                        gt_delta_all = later_feat_gt.float()[base_mask] - gt_vals[base_mask]
                        model_delta_all = target_vals[base_mask] - gt_vals[base_mask]
                        ite_err_all = torch.abs(pred_delta_all - gt_delta_all)
                        model_ite_err_all = torch.abs(pred_delta_all - model_delta_all)
                        metrics[f'test_ite_{feat_name}'] = ite_err_all.mean().item()
                        metrics[f'test_model_ite_{feat_name}'] = model_ite_err_all.mean().item()
                        # Now compute metrics for treated only, if any
                        if treated_mask.sum() > 0:
                            pred_delta_tr = pred_vals[treated_mask] - gt_vals[treated_mask]
                            gt_delta_tr = later_feat_gt.float()[treated_mask] - gt_vals[treated_mask]
                            model_delta_tr = target_vals[treated_mask] - gt_vals[treated_mask]
                            ite_err_tr = torch.abs(pred_delta_tr - gt_delta_tr)
                            model_ite_err_tr = torch.abs(pred_delta_tr - model_delta_tr)
                            metrics[f'test_ite_{feat_name}_treated'] = ite_err_tr.mean().item()
                            metrics[f'test_model_ite_{feat_name}_treated'] = model_ite_err_tr.mean().item()
                    except (ValueError, IndexError, KeyError) as e:
                        print(f"WARN ITE: Skipping feature {feat_name} due to index/key error: {e}")
                        continue
                    except Exception as e:
                        print(f"WARN ITE: Error calculating ITE for {feat_name}: {e}")

        metrics['batch_size'] = float(input_images.size(0))
        # accumulate for on_test_epoch_end
        self._test_outputs.append(metrics.copy())
        # Return only finite values
        return {k: v for k, v in metrics.items() if np.isfinite(v)}

    def on_test_epoch_end(self):
        outputs = self._test_outputs
        if not outputs: print("No valid outputs collected during testing."); return

        aggregated_metrics = {}
        valid_outputs = [out for out in outputs if out is not None] # Filter Nones
        if not valid_outputs: print("All test steps failed or returned None."); return

        # Aggregate all keys found in the outputs, except 'batch_size'
        metric_keys = {k for out in valid_outputs for k in out.keys()} - {'batch_size'}
        total_samples = sum(out.get('batch_size', 0) for out in valid_outputs)

        print(f"\n--- Aggregating Test Results ({len(valid_outputs)} Batches, {total_samples:.0f} Samples) ---")

        for key in sorted(list(metric_keys)):
            # Filter outputs that contain the key
            key_valid_outputs = [out for out in valid_outputs if key in out]
            if not key_valid_outputs: continue

            # Calculate weighted average using batch_size
            total_value = sum(out[key] * out['batch_size'] for out in key_valid_outputs)
            total_weight = sum(out['batch_size'] for out in key_valid_outputs)
            aggregated_metrics[key] = total_value / total_weight if total_weight > 0 else float('nan')

        print("--- Final Test Metrics ---")
        results_str = [f"Checkpoint: {getattr(self.hparams, 'ckpt_path', 'N/A')}", f"Total Samples: {total_samples:.0f}", "---"]
        for key, value in aggregated_metrics.items():
            print(f"{key}: {value:.6f}")
            results_str.append(f"{key}: {value:.6f}")

        # Save results to file (path specified in hparams, potentially added after loading)
        results_file_path = getattr(self.hparams, 'results_file', None)
        if results_file_path:
            print(f"Saving results to {results_file_path}...")
            try:
                # Only make parent directory if one is specified
                dirpath = os.path.dirname(results_file_path)
                if dirpath:
                    os.makedirs(dirpath, exist_ok=True)
                with open(results_file_path, 'w') as f:
                    f.write("\n".join(results_str))
                print("Results saved.")
            except Exception as e:
                print(f"Error saving results: {e}")
        # Optionally clear test outputs after aggregation
        self._test_outputs.clear()

# --- Main Script Logic ---
def collate_fn_adversarial_rnn(batch):
    """Collate dataset samples into the batch dict consumed by the Lightning module.

    - Stacks images, earlier/later feature vectors and interval labels.
    - Right-pads the variable-length historical covariate/treatment sequences with
      zeros and records their true lengths.
    - Returns ``delta_t`` as a (B, 1) float tensor.
    - Encodes ``side`` as a (B, 1) float tensor that is 1.0 for "R" (case-insensitive)
      and 0.0 otherwise.

    Defined at module level, not under ``__main__``, so DataLoader workers started
    with the ``spawn`` method can pickle it.
    """
    input_images = torch.stack([b["input_image"] for b in batch])
    target_images = torch.stack([b["target_image"] for b in batch])
    prompts = [b["prompt"] for b in batch]
    earlier_features = torch.stack([b["earlier_features"] for b in batch]) # Tabular features for current X_t
    later_features = torch.stack([b["later_features"] for b in batch])
    interval_labels = torch.stack([b["interval_labels"] for b in batch])

    # Historical data (past X_0...X_{t-1})
    covs_hist = [b["cov_seq"] for b in batch] # Historical covariates
    trts_hist = [b["trt_seq"] for b in batch] # Historical treatments
    lengths_hist = torch.tensor([seq.size(0) for seq in covs_hist], dtype=torch.long)
    covs_hist_padded = pad_sequence(covs_hist, batch_first=True, padding_value=0.0)
    trts_hist_padded = pad_sequence(trts_hist, batch_first=True, padding_value=0.0)

    delta_t = torch.tensor([b["delta_t"] for b in batch], dtype=torch.float).unsqueeze(1) # For current interval X_t -> X_{t+1}
    side_tensor = torch.tensor([1.0 if b["side"].upper() == "R" else 0.0 for b in batch], dtype=torch.float).unsqueeze(1)

    return {
        "input_image": input_images, "target_image": target_images, "prompt": prompts,
        "earlier_features": earlier_features, "later_features": later_features,
        "interval_labels": interval_labels,
        "cov_seq": covs_hist_padded, "trt_seq": trts_hist_padded, "lengths": lengths_hist,
        "delta_t": delta_t, "side": side_tensor, "image_seq": None
    }


def compute_disc_pos_weight(samples):
    """Return the per-treatment ``negatives / positives`` ratio over ``samples``.

    Each sample's ``"interval_labels"`` tensor is summed to count positives per
    treatment. Positive counts are clamped to 1e-6 to avoid division by zero.
    Returns ``None`` when ``samples`` is empty.
    """
    pos_counts = None
    n_samples = 0
    for sample in samples:
        labels = sample["interval_labels"].detach().cpu().float()
        pos_counts = labels.clone() if pos_counts is None else pos_counts + labels
        n_samples += 1
    if n_samples == 0:
        return None
    neg_counts = n_samples - pos_counts
    return neg_counts / torch.clamp(pos_counts, min=1e-6)


def load_sd_components(pretrained):
    """Load (vae, unet, text_encoder, tokenizer, scheduler) from a diffusers model path/id."""
    vae = AutoencoderKL.from_pretrained(pretrained, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(pretrained, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(pretrained, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained, subfolder="tokenizer")
    scheduler = DDIMScheduler.from_pretrained(pretrained, subfolder="scheduler")
    return vae, unet, text_encoder, tokenizer, scheduler


if __name__ == "__main__":
    args = parse_args()
    pl.seed_everything(42, workers=True)

    print(f"Run Name: {args.run_name}")
    image_transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    if args.test_only:
        if not args.ckpt_path:
            raise SystemExit("ERROR: --ckpt_path is required when --test_only is set.")
        vae, unet, text_encoder, tokenizer, scheduler = load_sd_components(args.diffusion_model_pretrain)
        img_conditioning = ImageConditioning(vae.config.latent_channels, args.disc_current_img_cond_dim)

        # Compute pos_w_disc if needed (same computation as training)
        pos_w_disc = None
        if args.use_disc_pos_weight:
            train_dataset = TemporalKneeFeatureConditioningDataset(
                csv_file=os.path.join(args.pairs_dir, "train.csv"),
                image_transform=image_transforms,
                include_images=False
            )
            pos_w_disc = compute_disc_pos_weight(train_dataset.samples)

        # Load model from checkpoint
        diffusion_model = AdversarialDiffusionLitModule.load_from_checkpoint(
            args.ckpt_path,
            map_location="cpu",
            args=args,
            vae=vae, unet=unet, scheduler=scheduler,
            text_encoder=text_encoder, tokenizer=tokenizer,
            img_conditioning=img_conditioning,
            delta_t_mean=DELTA_T_MEAN, delta_t_std=DELTA_T_STD,
            discriminator_pos_weight=pos_w_disc,
            strict=False
        )
        diffusion_model.hparams.results_file = args.results_file
        diffusion_model.hparams.ckpt_path = args.ckpt_path

        test_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "test.csv"),
            image_transform=image_transforms,
            include_images=False
        )
        print(f"Test dataset size: {len(test_dataset)}")
        test_loader = DataLoader(
            test_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=8, pin_memory=True,
            collate_fn=collate_fn_adversarial_rnn
        )
        trainer = pl.Trainer(
            accelerator="gpu", devices=args.devices, precision=args.precision,
            logger=False, callbacks=[]
        )
        trainer.test(model=diffusion_model, dataloaders=test_loader, verbose=True)
    else:
        train_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "train.csv"),
            image_transform=image_transforms, # This will make image_seq [-1,1]
            include_images=False
        )
        val_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "val.csv"),
            image_transform=image_transforms,
            include_images=False
        )
        pos_w_disc = None
        if args.use_disc_pos_weight:
            print("Calculating pos_weight for discriminator...")
            pos_w_disc = compute_disc_pos_weight(train_dataset.samples)
            if pos_w_disc is not None:
                print(f"  Discriminator pos_w calculated: {pos_w_disc.numpy()}")
            else:
                print("WARN: Could not calculate pos_w for discriminator (n_train=0 or samples missing). Using unweighted loss.")
        else:
            print("Discriminator class weighting disabled; using unweighted loss.")

        train_loader = DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True,
            collate_fn=collate_fn_adversarial_rnn, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True,
            collate_fn=collate_fn_adversarial_rnn
        )
        # Stable Diffusion components
        vae, unet, text_encoder, tokenizer, scheduler = load_sd_components(args.diffusion_model_pretrain)
        print("Base pipeline components loaded.")

        img_conditioning = ImageConditioning(vae.config.latent_channels, args.disc_current_img_cond_dim)

        diffusion_model = AdversarialDiffusionLitModule(
            args=args, vae=vae, unet=unet, scheduler=scheduler, text_encoder=text_encoder,
            tokenizer=tokenizer, img_conditioning=img_conditioning,
            delta_t_mean=DELTA_T_MEAN, delta_t_std=DELTA_T_STD, discriminator_pos_weight=pos_w_disc
        )
        monitor_metric = "val_loss"
        filename_metric = "val_loss={val_loss:.4f}"
        print(f"ModelCheckpoint monitoring: {monitor_metric}")
        checkpoint_callback = ModelCheckpoint(
            monitor=monitor_metric,
            dirpath=args.checkpoint_dir,
            filename=f"best-{{epoch}}-{filename_metric}",
            auto_insert_metric_name=False,
            save_top_k=2, mode="min", save_last=True
        )
        tensorboard_logger = TensorBoardLogger(save_dir=args.log_dir_base, name=args.run_name)

        trainer = pl.Trainer(
            max_epochs=args.max_epochs, accelerator="gpu", devices=args.devices, precision=args.precision,
            logger=tensorboard_logger, log_every_n_steps=25,
            strategy="ddp_find_unused_parameters_true",
            callbacks=[checkpoint_callback]
        )
        print(f"Starting training run: {args.run_name}"); print(f" Logs: {os.path.join(args.log_dir_base, args.run_name)}")
        trainer.fit(diffusion_model, train_loader, val_loader)

        print("--- Training Finished ---")
        print(f"Best model checkpoint path: {checkpoint_callback.best_model_path}")