import os
import argparse
from functools import partial
import torch
from torch.amp import autocast
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import make_grid
import timm
from torch.optim import AdamW
import numpy as np
from pretrain_propensity_model_temporal import TemporalPropensityModel
from data.pairs_dataset.dataset import TemporalKneeFeatureConditioningDataset, NUM_TREATMENTS, FEATURE_DIM, PERCEPTUAL_FEATURES, TREATMENT_LABELS

# Mean and standard deviation of delta_t, hard-coded from an earlier computation
# instead of being re-derived from the data at start-up.
# NOTE(review): presumably these are handed to DiffusionLitModule for delta_t
# normalisation; the units and the split they were computed on are not visible here — confirm.
delta_t_mean, delta_t_std = 35.0731, 22.9570 # Using pre-calculated values

# Allow float32 matrix multiplications to use faster, reduced-precision internal formats.
torch.set_float32_matmul_precision('high')

# --- Configuration via ArgParse ---
def parse_args(argv=None):
    """Parse command-line options and derive the run name and output directories.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, i.e. the original behaviour.

    Returns:
        argparse.Namespace holding every option declared below plus ``run_name``.
        In training mode ``checkpoint_dir`` and ``log_dir`` are added as well and
        ``checkpoint_dir`` is created on disk. In ``--test_only`` mode those two
        attributes are not set.

    Raises:
        SystemExit: raised by argparse on a usage error, including
            ``--test_only`` without an existing ``--ckpt_path``.
    """
    parser = argparse.ArgumentParser(description="Train Temporal IPW Diffusion Model")
    # --- Mode ---
    parser.add_argument("--test_only", action="store_true", help="Run in testing mode only")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint (.ckpt) for testing")
    parser.add_argument("--results_file", type=str, default="test_results.txt", help="File to save test metrics")
    # Paths
    parser.add_argument("--propensity_model_path", type=str, default="./checkpoints/Propensity_Model_Temporal/propensity_model_temporal_best.pth")
    parser.add_argument("--pairs_dir", type=str, default="data/pairs_dataset", help="Dir with train.csv/val.csv")
    parser.add_argument("--base_checkpoint_dir", type=str, default="/local/acc/OAI_checkpoints", help="Base dir for saving models")
    parser.add_argument("--log_dir_base", type=str, default="/local/acc/OAI_checkpoints/tb_logs", help="Base dir for TensorBoard logs")
    # Model Params
    parser.add_argument("--diffusion_model_pretrain", type=str, default="runwayml/stable-diffusion-v1-5")
    parser.add_argument("--feature_predictor_model", type=str, default='efficientformerv2_l.snap_dist_in1k')
    # Training Params
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate for diffusion model")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--feature_loss_weight", type=float, default=0, help="Weight for perceptual feature loss")
    parser.add_argument("--precision", type=str, default="16-mixed", choices=["16-mixed", "bf16-mixed", "32-true"])
    parser.add_argument("--devices", type=int, nargs='+', default=[0,1], help="GPU devices to use")
    # IPW Params
    parser.add_argument("--disable_ipw_loss", action='store_true', help="If set, train without IPW loss weighting (but still calculate IPW metrics in val)")
    parser.add_argument("--ipw_max_clamp", type=float, default=150.0, help="Max value to clamp IPW weights")
    # Propensity Model Params (must match saved model if loading)
    parser.add_argument("--propensity_model_type", type=str, default='rnn')
    parser.add_argument("--propensity_hidden_dim", type=int, default=64)
    parser.add_argument("--propensity_num_layers", type=int, default=2)
    parser.add_argument("--propensity_dropout", type=float, default=0.2)
    parser.add_argument("--propensity_delta_t_feat_dim", type=int, default=8)
    parser.add_argument("--propensity_model_uses_images", action='store_true')
    # Context Encoder
    parser.add_argument("--disc_model_type", type=str, choices=["rnn", "transformer"], default="rnn", help="Context encoder sequence model type")
    parser.add_argument("--disc_hidden_dim", type=int, default=128, help="Context encoder RNN/Transformer hidden dim")
    parser.add_argument("--disc_num_layers", type=int, default=2, help="Context encoder RNN/Transformer num layers")
    parser.add_argument("--disc_dropout", type=float, default=0.2, help="Context encoder RNN/Transformer dropout")
    parser.add_argument("--disc_delta_t_feat_dim", type=int, default=8, help="Context encoder projected delta_t feature dim")
    parser.add_argument("--disc_side_feat_dim", type=int, default=4, help="Context encoder projected side feature dim")
    parser.add_argument("--disc_current_img_cond_dim", type=int, default=768, help="Dim of img_cond from current input image (used for ImageConditioning and ContextEncoderRNN)")

    # Other
    parser.add_argument("--run_suffix", type=str, default="", help="Suffix for run name")

    args = parser.parse_args(argv)
    if args.test_only:
        # Testing needs an existing checkpoint; the run is named after that file.
        if not args.ckpt_path or not os.path.exists(args.ckpt_path):
            parser.error("--ckpt_path is required and must exist for --test_only mode.")
        args.run_name = f"TEST_{os.path.basename(args.ckpt_path)}"
    else:
        # The run name encodes the main hyper-parameters so runs do not overwrite each other.
        args.run_name = f"ipw_{args.run_suffix}_{args.lr}_{args.ipw_max_clamp}_b{args.batch_size}x{len(args.devices)}_flw{args.feature_loss_weight}"
        if args.disable_ipw_loss:
            args.run_name += "_NoIPWLoss"
        args.checkpoint_dir = os.path.join(args.base_checkpoint_dir, args.run_name)
        args.log_dir = os.path.join(args.log_dir_base, args.run_name)
        os.makedirs(args.checkpoint_dir, exist_ok=True)
    return args

# --- Image Conditioning Module ---
class ImageConditioning(nn.Module):
    """Turn a latent feature map into one conditioning vector per sample.

    The latent is averaged over dimensions 2 and 3 and the resulting
    per-channel means are sent through a single linear layer.
    """

    def __init__(self, in_channels=4, hidden_dim=768):
        # The default hidden_dim is chosen to match the text encoder width.
        super().__init__()
        self.linear = nn.Linear(in_channels, hidden_dim)

    def forward(self, latent):
        """Return the linear projection of ``latent`` averaged over dims 2 and 3."""
        pooled = latent.mean(dim=(2, 3))
        return self.linear(pooled)

class ContextEncoderRNN(nn.Module):
    """Encode history, current-image features, delta_t and side into one context vector.

    An LSTM summarises the concatenated covariate/treatment history. Its final
    hidden state is concatenated with the current image conditioning vector
    and with small learned embeddings of ``delta_t`` and ``side``.

    Args:
        args: Namespace providing ``disc_model_type``, ``disc_hidden_dim``,
            ``disc_num_layers``, ``disc_dropout``, ``disc_delta_t_feat_dim``
            and ``disc_side_feat_dim``.
        current_img_cond_dim: Width of the current-image conditioning vector.
        cov_hist_dim: Feature width of each covariate history step.
        num_treatments_hist: Feature width of each treatment history step.

    Raises:
        ValueError: if ``args.disc_model_type`` is anything other than "rnn"
            (case-insensitive); no other sequence model is implemented here.
    """

    def __init__(self, args, current_img_cond_dim, cov_hist_dim, num_treatments_hist):
        super().__init__()
        self.model_type = args.disc_model_type
        self.delta_t_feat_dim = args.disc_delta_t_feat_dim
        self.side_feat_dim = args.disc_side_feat_dim

        # Covariates and treatments are concatenated per time step before the LSTM.
        seq_input_dim = cov_hist_dim + num_treatments_hist

        if self.model_type.lower() == "rnn":
            # nn.LSTM only applies dropout between layers, so it is disabled for a single layer.
            self.hist_encoder = nn.LSTM(seq_input_dim, args.disc_hidden_dim, args.disc_num_layers,
                                        batch_first=True, dropout=args.disc_dropout if args.disc_num_layers > 1 else 0)
            self.hist_encoder_output_dim = args.disc_hidden_dim
        else:
            raise ValueError(f"Unsupported context encoder model type: {self.model_type}")

        self.delta_t_processor = nn.Sequential(nn.Linear(1, self.delta_t_feat_dim), nn.ReLU())
        self.side_processor = nn.Sequential(nn.Linear(1, self.side_feat_dim), nn.ReLU())

        # Output dimension of this context encoder
        self.output_dim = (
            self.hist_encoder_output_dim
            + current_img_cond_dim
            + self.delta_t_feat_dim
            + self.side_feat_dim
        )

    def forward(self, img_cond_current, # From current X_t
                cov_seq_hist, trt_seq_hist, lengths_hist, # Historical data
                delta_t, side, image_seq_hist=None, # Context for current interval
                delta_t_mean=0.0, delta_t_std=1.0):
        """Build the context vector for a batch.

        Args:
            img_cond_current: Current-image features, 2-D ``[B, D]``; a 3-D
                input is averaged over dim 1 first.
            cov_seq_hist: Covariate history, batch-first and 3-D.
            trt_seq_hist: Treatment history, concatenated with ``cov_seq_hist``
                along dim 2.
            lengths_hist: Valid length of each history sequence; moved to CPU
                for ``pack_padded_sequence``.
            delta_t: Raw time gap, shape ``[B, 1]`` or ``[B]``.
            side: Side indicator, shape ``[B, 1]`` or ``[B]``.
            image_seq_hist: Accepted for interface compatibility; not used.
            delta_t_mean: Mean used to standardise ``delta_t``.
            delta_t_std: Standard deviation used to standardise ``delta_t``.

        Returns:
            Tensor of shape ``[B, self.output_dim]``.
        """
        combined_hist_seq = torch.cat([cov_seq_hist, trt_seq_hist], dim=2)
        packed_input = pack_padded_sequence(combined_hist_seq, lengths_hist.cpu(), batch_first=True, enforce_sorted=False)
        self.hist_encoder.flatten_parameters()
        _, (h_n, _) = self.hist_encoder(packed_input)
        # Final hidden state of the top LSTM layer.
        hist_context_vector = h_n[-1]

        # The scalar processors are nn.Linear(1, ...) and need a trailing feature
        # axis, so per-sample scalars given as [B] are expanded to [B, 1].
        if delta_t.ndim == 1:
            delta_t = delta_t.unsqueeze(-1)
        if side.ndim == 1:
            side = side.unsqueeze(-1)

        # Process delta_t and side
        delta_t_norm = (delta_t - delta_t_mean) / (delta_t_std + 1e-8)
        delta_t_features = self.delta_t_processor(delta_t_norm)
        side_features = self.side_processor(side)

        if img_cond_current.ndim == 3:
            img_cond_current = img_cond_current.mean(dim=1)

        # Concatenate features that form the rich context
        rich_context_vector = torch.cat([
            hist_context_vector,
            img_cond_current, # Features from current image X_t
            delta_t_features,
            side_features
        ], dim=1)
        return rich_context_vector


# --- Diffusion Lightning Module ---
class DiffusionLitModule(pl.LightningModule):
    def __init__(self, args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 propensity_model=None, delta_t_mean=0.0, delta_t_std=1.0):
        """Assemble the diffusion training module from pre-built components.

        Args:
            args: Parsed command-line namespace; stored as hyper-parameters and
                read for ``lr``, ``feature_loss_weight``, ``disable_ipw_loss``,
                ``ipw_max_clamp`` and the ``disc_*`` context-encoder settings.
            vae: Autoencoder used to encode images to latents and decode them.
            unet: Noise-prediction network.
            scheduler: Noise scheduler (provides ``add_noise`` and ``alphas_cumprod``).
            text_encoder: Prompt encoder; its ``config.hidden_size`` is taken as
                the conditioning width expected by the U-Net.
            tokenizer: Tokenizer paired with ``text_encoder``.
            img_conditioning: Module mapping the input latent to a conditioning vector.
            propensity_model: Optional model used for IPW weights; ``None`` disables them.
            delta_t_mean: Mean used to normalise ``delta_t``.
            delta_t_std: Standard deviation used to normalise ``delta_t``.
        """
        super().__init__()
        # Record the run configuration once. The heavy sub-models are listed in
        # ``ignore`` so they never end up in the stored hyper-parameters.
        self.save_hyperparameters(args, ignore=['vae', 'unet', 'scheduler', 'text_encoder', 'tokenizer',
                                                 'img_conditioning', 'propensity_model', 'feature_predictors'])
        self.vae = vae
        self.unet = unet
        self.scheduler = scheduler
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.img_conditioning = img_conditioning # For current X_t's latent
        self.context_encoder = ContextEncoderRNN(
            args=args,
            current_img_cond_dim=args.disc_current_img_cond_dim,
            cov_hist_dim=FEATURE_DIM,
            num_treatments_hist=NUM_TREATMENTS
        )

        # U-Net's cross-attention dimension
        # NOTE(review): read from the text encoder, which assumes it equals the
        # U-Net's own cross-attention width — confirm for non-default backbones.
        unet_cross_attn_dim = text_encoder.config.hidden_size

        # The rich context is added to the text embeddings, so it must have the same width.
        if self.context_encoder.output_dim != unet_cross_attn_dim:
            self.rich_context_projector = nn.Linear(self.context_encoder.output_dim, unet_cross_attn_dim)
            print(f"Projecting rich context from {self.context_encoder.output_dim} to {unet_cross_attn_dim}")
        else:
            self.rich_context_projector = nn.Identity()
        # Store params from args
        self.latent_scale = 0.18215 # Standard SD scale
        self.lr = args.lr
        self.feature_loss_weight = args.feature_loss_weight
        self.disable_ipw_loss = args.disable_ipw_loss # Store IPW enable flag
        # Store propensity model related info
        self.propensity_model = propensity_model
        self.propensity_delta_t_mean = delta_t_mean
        self.propensity_delta_t_std = delta_t_std
        self.ipw_max_clamp = args.ipw_max_clamp # Store clamp value
        # Feature predictors are not attached yet at this point; they are loaded
        # and frozen by _load_feature_predictors() below.
        self._freeze_components()
        self.transform_to_pil = T.ToPILImage()
        self.feature_predictors = nn.ModuleDict()
        self._load_feature_predictors()
        self.pipe = None  # inference pipeline, created in setup()
        self._test_outputs = []

    def _freeze_components(self):
        """Switch the non-trainable sub-models to eval mode and disable their gradients.

        Covers the VAE and text encoder (when set), the propensity model (when
        not ``None``) and every entry of ``feature_predictors`` (when that
        attribute already exists).
        """
        def _freeze(module):
            module.eval()
            for param in module.parameters():
                param.requires_grad_(False)

        for attr_name in ('vae', 'text_encoder'):
            if hasattr(self, attr_name):
                _freeze(getattr(self, attr_name))
        if self.propensity_model is not None:
            _freeze(self.propensity_model)
        # Ensure feature predictors remain frozen
        if hasattr(self, 'feature_predictors'):
            for predictor in self.feature_predictors.values():
                _freeze(predictor)
        print("VAE, Text Encoder, Propensity Model, and Feature Predictors frozen.")

    def _load_feature_predictors(self):
        """Load one frozen classifier per perceptual feature into ``self.feature_predictors``.

        For each name in ``PERCEPTUAL_FEATURES`` the weights are read from
        ``checkpoints/<feat>/<feat>_model.pth``. The number of classes is taken
        from the saved ``head.weight`` and the model is put in eval mode with
        gradients disabled.

        Raises:
            FileNotFoundError: if a checkpoint file is missing (from ``torch.load``).
            RuntimeError: if a checkpoint does not match the architecture
                (strict ``load_state_dict``).
        """
        print("Loading and freezing feature predictors...")
        # Honour --feature_predictor_model; fall back to the architecture that
        # used to be hard-coded when the hyper-parameter is absent.
        backbone_name = getattr(self.hparams, "feature_predictor_model", 'efficientformerv2_l.snap_dist_in1k')
        for feat in PERCEPTUAL_FEATURES:
            model_path = os.path.join("checkpoints", feat, f"{feat}_model.pth")
            # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
            state_dict = torch.load(model_path, map_location="cpu")
            num_outputs = state_dict['head.weight'].shape[0]
            # No ImageNet download: the strict load below replaces every stored
            # parameter and buffer, so pretrained weights would be discarded anyway.
            predictor = timm.create_model(backbone_name, pretrained=False)
            predictor.reset_classifier(num_outputs)
            predictor.load_state_dict(state_dict)
            predictor.eval()
            for param in predictor.parameters():
                param.requires_grad = False
            self.feature_predictors[feat] = predictor

    def setup(self, stage=None):
        """Move the feature predictors to this module's device and build the inference pipeline once.

        Args:
            stage: Passed in by the caller; not used here.
        """
        target_device = self.device
        for predictor in self.feature_predictors.values():
            predictor.to(target_device)
        # Avoid re-creating the pipeline if setup() runs more than once.
        if self.pipe is None:
            pipeline = StableDiffusionImg2ImgPipeline(
                vae=self.vae,
                unet=self.unet,
                scheduler=self.scheduler,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                safety_checker=None,
                feature_extractor=None,
            )
            self.pipe = pipeline.to(self.device)
        print("Inference pipeline setup complete.")

    def forward(self, input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side):
        """Noise the target latents at random timesteps and predict that noise.

        The U-Net is conditioned on the prompt embeddings plus a projected
        context vector built from the input image, the history sequences,
        ``delta_t`` and ``side``.

        Returns:
            Tuple ``(pred_noise, noise, noisy_latents, timesteps, target_latents)``;
            ``target_latents`` is cast back to the dtype of ``input_images``.
        """
        input_dtype = input_images.dtype
        precision = self.trainer.precision
        use_amp = precision.startswith("16") or precision.startswith("bf16")
        with autocast('cuda', enabled=use_amp):
            # Encode both images; each call draws a sample from the VAE posterior.
            input_latents = self.vae.encode(input_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            target_latents = self.vae.encode(target_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale

            # One random timestep per sample, then noise the target latents (x_t).
            batch_size = input_latents.size(0)
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps, (batch_size,), device=self.device
            ).long()
            noise = torch.randn_like(target_latents)
            noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)

            # Prompt embeddings: [B, SeqLen, EmbDim]
            tokenized = self.tokenizer(
                prompts,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            token_ids = tokenized.input_ids.to(self.device)
            text_embeddings = self.text_encoder(token_ids)[0]

            # [B, EmbDim] summary of the current image X_t
            img_cond_current = self.img_conditioning(input_latents)
            cond_dtype = img_cond_current.dtype

            def _match(tensor):
                # Same device as the module and same dtype as the image conditioning.
                return tensor.to(self.device).to(cond_dtype)

            # Get rich context from the context encoder
            rich_context = self.context_encoder(
                img_cond_current,
                _match(cov_seq_hist),
                _match(trt_seq_hist),
                lengths_hist.to(self.device), # lengths only need the device transfer
                _match(delta_t),
                _match(side),
                image_seq_hist=None, # Not used by the copied ContextEncoderRNN
                delta_t_mean=self.propensity_delta_t_mean, # Use stored delta_t stats
                delta_t_std=self.propensity_delta_t_std,
            )
            # [B, unet_cross_attn_dim]
            projected_rich_context = self.rich_context_projector(rich_context)

            # U-Net Conditioning: the context vector is broadcast-added to every text token.
            unet_conditioning = text_embeddings + projected_rich_context.unsqueeze(1)
            unet_dtype = self.unet.dtype
            unet_out = self.unet(
                noisy_latents.to(unet_dtype),
                timesteps,
                encoder_hidden_states=unet_conditioning.to(unet_dtype),
            )
            pred_noise = unet_out.sample
        return pred_noise, noise, noisy_latents, timesteps, target_latents.to(input_dtype)

    def _calculate_predicted_x0(self, pred_noise, timesteps, noisy_latents):
        """Estimate the clean latent from a noisy latent and the predicted noise.

        Computes ``(x_t - sqrt(1 - a_t) * eps) / (sqrt(a_t) + 1e-8)`` where
        ``a_t`` is the scheduler's ``alphas_cumprod`` at each sample's timestep.
        """
        alpha_bar_t = self.scheduler.alphas_cumprod.to(pred_noise.device)[timesteps]
        # Reshape the per-sample scalars so they broadcast over the three trailing latent dims.
        signal_scale = (alpha_bar_t ** 0.5).flatten()[:, None, None, None]
        noise_scale = ((1 - alpha_bar_t) ** 0.5).flatten()[:, None, None, None]
        denoised = noisy_latents - noise_scale * pred_noise
        # The small epsilon guards against division by zero.
        return denoised / (signal_scale + 1e-8)

    def _calculate_ipw(self, batch):
        """Compute clamped inverse-propensity weights for a batch.

        The propensity model yields one logit per treatment. The probability of
        the observed ``interval_labels`` combination is the product over
        treatments of ``p`` (label 1) or ``1 - p`` (otherwise), and the weight
        is its inverse, clamped to ``[0.01, self.ipw_max_clamp]``.

        Args:
            batch: Mapping that must contain non-None ``cov_seq``, ``trt_seq``,
                ``lengths``, ``delta_t``, ``side`` and ``interval_labels``;
                ``image_seq`` is used only when
                ``hparams.propensity_model_uses_images`` is set.

        Returns:
            Detached tensor with one weight per sample, or ``None`` when there
            is no propensity model, an input is missing, a NaN appears, or any
            exception is raised (best-effort: the caller falls back to
            unweighted behaviour).
        """
        # Every key read below is checked up front. "interval_labels" is part of
        # the list so a batch without labels is skipped explicitly instead of
        # ending up in the generic exception handler.
        required_keys = ["cov_seq", "trt_seq", "lengths", "delta_t", "side", "interval_labels"]
        if self.propensity_model is None or \
           not all(k in batch and batch[k] is not None for k in required_keys):
            print("WARN IPW: Missing one or more required inputs for IPW calculation. Skipping.") # Optional: for debugging
            return None
        try:
            cov_seq, trt_seq, lengths, delta_t = batch["cov_seq"], batch["trt_seq"], batch["lengths"], batch["delta_t"]
            side = batch["side"]
            interval_labels = batch["interval_labels"]

            # The image history is forwarded only when the propensity model was trained with it.
            image_seq = None
            if self.hparams.propensity_model_uses_images:
                if "image_seq" in batch and batch["image_seq"] is not None:
                    image_seq = batch["image_seq"]

            # Match the propensity model's own dtype and device.
            model_dtype = next(self.propensity_model.parameters()).dtype
            prop_device = next(self.propensity_model.parameters()).device

            cov_seq_p = cov_seq.to(prop_device, dtype=model_dtype)
            trt_seq_p = trt_seq.to(prop_device, dtype=model_dtype)
            delta_t_p = delta_t.to(prop_device, dtype=model_dtype)
            side_p = side.to(prop_device, dtype=model_dtype)

            image_seq_p = None
            if image_seq is not None:
                image_seq_p = image_seq.to(prop_device, dtype=model_dtype)

            delta_t_norm = (delta_t_p - self.propensity_delta_t_mean) / (self.propensity_delta_t_std + 1e-8)

            # Autocast is disabled so the propensity model runs in its own dtype.
            with torch.no_grad(), autocast('cuda', enabled=False):
                propensity_logits = self.propensity_model(
                    cov_seq_p, trt_seq_p, lengths.cpu(), # lengths must be on CPU for pack_padded_sequence
                    delta_t_norm, image_seq_p, side_p
                )

            if torch.isnan(propensity_logits).any(): print("WARN IPW: NaNs in logits"); return None
            propensity_probs = torch.sigmoid(propensity_logits.float())
            # Keep probabilities away from 0 and 1 so the product below never vanishes exactly.
            propensity_probs = torch.clamp(propensity_probs, 1e-6, 1.0 - 1e-6)
            if torch.isnan(propensity_probs).any(): print("WARN IPW: NaNs in probs"); return None

            label = interval_labels.to(prop_device).float()
            # Joint probability of the observed treatment combination (product over dim 1).
            joint_prob = torch.prod(torch.where(label == 1, propensity_probs, 1.0 - propensity_probs), dim=1)
            if torch.isnan(joint_prob).any(): print("WARN IPW: NaNs in joint_prob"); return None

            ipw = 1.0 / (joint_prob + 1e-7)
            # The upper clamp limits the influence of very unlikely treatment combinations.
            ipw = torch.clamp(ipw, min=0.01, max=self.ipw_max_clamp)
            if torch.isnan(ipw).any(): print("WARN IPW: NaNs in final ipw"); return None

            return ipw.detach()
        except Exception as e:
            # Deliberate best-effort: a failure here must not abort the step.
            print(f"Error during IPW calculation: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _compute_feature_loss(self, pred_noise, timesteps, noisy_latents, later_features):
        """Return the mean cross-entropy of the frozen feature predictors on the decoded x0 estimate.

        Returns a zero tensor when the feature loss is disabled, when no
        predictor yields a usable loss, or when anything in the block fails
        (failures are printed, not raised).
        """
        feature_loss = torch.tensor(0.0, device=self.device)
        if not (self.feature_loss_weight > 0 and len(self.feature_predictors) > 0):
            return feature_loss
        try:
            # Calculate predicted x0 (original image) from noise prediction
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any():
                print("WARN TRAIN GEN: NaN/Inf in pred_latents for feature loss")
                return feature_loss
            # Decode predicted latents to image space. The cast to the VAE dtype
            # matches what validation_step does before decoding.
            generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            if torch.isnan(generated_images).any():
                print("WARN TRAIN GEN: NaN in generated_images for feature loss")
                return feature_loss

            # Rescale with x * 0.5 + 0.5, then apply ImageNet normalisation for the predictors.
            gen_images_0_to_1 = (generated_images * 0.5) + 0.5
            imagenet_norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            precision = self.trainer.precision
            amp_enabled = precision.startswith("16") or precision.startswith("bf16")

            f_loss_total = 0.0
            f_counts = 0
            for feat_name, predictor in self.feature_predictors.items():
                try:
                    with autocast('cuda', enabled=amp_enabled):
                        feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                        gt = later_features[:, feat_idx].long().to(self.device) # Ensure GT is on device
                        # Negative labels (e.g. -1) mark missing ground truth.
                        valid_mask = gt >= 0
                        if valid_mask.sum() == 0:
                            continue
                        # The final layer is only used to find the dtype the predictor expects.
                        head_layer = getattr(predictor, 'head', getattr(predictor, 'fc', getattr(predictor, 'classifier', None)))
                        predictor_input = imagenet_norm(gen_images_0_to_1.to(head_layer.weight.dtype))
                        pred_logits = predictor(predictor_input)
                        if torch.isnan(pred_logits).any():
                            continue
                        # Drop labels outside the predictor's class range.
                        valid_mask &= (gt < pred_logits.shape[-1])
                        if valid_mask.sum() == 0:
                            continue
                        current_loss = F.cross_entropy(pred_logits[valid_mask].float(), gt[valid_mask])
                        if not torch.isnan(current_loss):
                            f_loss_total += current_loss
                            f_counts += 1
                except (ValueError, IndexError, KeyError): # Catch potential errors finding index or feature
                    print(f"Warn: Skipping feature {feat_name} in feature loss calculation.")
                    continue
                except Exception as pred_e:
                    print(f"Err predictor {feat_name}: {pred_e}")
            if f_counts > 0:
                feature_loss = f_loss_total / f_counts
        except Exception as e:
            print(f"Err train feat loss block: {e}")
            self.log("train_feat_loss_err", 1.0, sync_dist=True)
        if torch.isnan(feature_loss):
            # Ensure it's zero if NaN occurs
            feature_loss = torch.tensor(0.0, device=self.device)
        return feature_loss

    def training_step(self, batch, batch_idx):
        """Run one optimisation step: (optionally IPW-weighted) noise MSE plus weighted feature loss.

        Args:
            batch: Mapping with ``input_image``, ``target_image``, ``prompt``,
                ``later_features``, ``cov_seq``, ``trt_seq``, ``lengths``,
                ``delta_t`` and ``side`` (plus whatever ``_calculate_ipw`` reads).
            batch_idx: Index of the batch; not used.

        Returns:
            The scalar total loss, or ``None`` when a NaN is detected so the
            step is skipped.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        later_features = batch["later_features"]
        cov_seq_hist = batch["cov_seq"]
        trt_seq_hist = batch["trt_seq"]
        lengths_hist = batch["lengths"]
        delta_t_hist = batch["delta_t"]
        side_hist = batch["side"]

        # Trainable parts run in train mode ...
        self.unet.train()
        self.img_conditioning.train()
        self.context_encoder.train()
        if hasattr(self, 'rich_context_projector') and not isinstance(self.rich_context_projector, nn.Identity):
            self.rich_context_projector.train()
        # ... while the frozen sub-models are pinned to eval mode. A train() call
        # on the whole module (which some Lightning versions appear to make at
        # the start of fitting / after validation) would otherwise re-enable
        # dropout and BatchNorm statistic updates in them.
        self.vae.eval()
        self.text_encoder.eval()
        if self.propensity_model is not None:
            self.propensity_model.eval()
        for predictor in self.feature_predictors.values():
            predictor.eval()

        pred_noise, noise, noisy_latents, timesteps, _ = self(
            input_images, target_images, prompts,
            cov_seq_hist, trt_seq_hist, lengths_hist, delta_t_hist, side_hist)
        if torch.isnan(pred_noise).any(): print("FATAL TRAIN: NaN in pred_noise"); return None
        per_sample_diff_loss = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if torch.isnan(per_sample_diff_loss).any(): print("FATAL TRAIN: NaN in per_sample_diff_loss"); return None
        unweighted_diff_loss = per_sample_diff_loss.mean()

        # The weighted loss equals the unweighted one unless IPW is enabled AND
        # the weights could be computed.
        weighted_diff_loss = unweighted_diff_loss
        ipw = None
        if not self.hparams.disable_ipw_loss:
            ipw = self._calculate_ipw(batch)
            if ipw is None:
                print("IPW calculation failed")
            elif torch.isnan(ipw).any():
                ipw = None
                print("WARN TRAIN: NaN in IPW, using unweighted.")
            else:
                weighted_diff_loss = (ipw.to(per_sample_diff_loss.device) * per_sample_diff_loss).mean()
        if torch.isnan(weighted_diff_loss): print("FATAL TRAIN: NaN in weighted_diff_loss"); return None

        # --- Feature Loss Calculation ---
        feature_loss = self._compute_feature_loss(pred_noise, timesteps, noisy_latents, later_features)

        # weighted_diff_loss already falls back to the unweighted value whenever IPW is off or unavailable.
        total_loss = weighted_diff_loss + self.feature_loss_weight * feature_loss

        if torch.isnan(total_loss): print("FATAL TRAIN: NaN in total_loss"); return None

        # Logging
        bs = input_images.size(0)
        self.log("train_loss", total_loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True, batch_size=bs)
        self.log("train_diff_loss_uw", unweighted_diff_loss, prog_bar=False, sync_dist=True, on_step=False, on_epoch=True, batch_size=bs)
        if ipw is not None and not self.hparams.disable_ipw_loss: # Log weighted only if used
            self.log("train_diff_loss_w", weighted_diff_loss, prog_bar=False, sync_dist=True, on_step=False, on_epoch=True, batch_size=bs)
            self.log("mean_ipw", ipw.mean(), prog_bar=False, sync_dist=True, batch_size=bs)
        if self.feature_loss_weight > 0 and len(self.feature_predictors) > 0:
            self.log("train_feat_loss", feature_loss, prog_bar=False, sync_dist=True, batch_size=bs)

        return total_loss

    def validation_step(self, batch, batch_idx):
        input_images=batch["input_image"]; target_images=batch["target_image"]; prompts=batch["prompt"]
        earlier_features=batch["earlier_features"]; later_features=batch["later_features"]
        # History for ContextEncoderRNN
        cov_seq_hist=batch["cov_seq"]; trt_seq_hist=batch["trt_seq"]
        lengths_hist=batch["lengths"]; delta_t_hist=batch["delta_t"]; side_hist=batch["side"]


        # Always calculate IPW in validation if model exists
        ipw = self._calculate_ipw(batch)

        pred_noise, noise, noisy_latents, timesteps, target_latents = self(input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist, delta_t_hist, side_hist)
        if torch.isnan(pred_noise).any(): print("WARN VAL: NaN in pred_noise"); return

        # Log unweighted loss
        val_loss = F.mse_loss(pred_noise, noise)
        if torch.isnan(val_loss): print("WARN VAL: NaNs in val_loss"); val_loss=torch.tensor(1000.0, device=self.device)
        self.log("val_loss", val_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        val_mae_loss = F.l1_loss(pred_noise, noise)
        if torch.isnan(val_mae_loss): print("WARN VAL: NaNs in val_mae_loss"); val_mae_loss=torch.tensor(1000.0, device=self.device)
        self.log("val_mae_loss", val_mae_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Log weighted loss if IPW available
        if ipw is not None:
             per_sample_loss = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
             if not torch.isnan(per_sample_loss).any():
                  ipw_val_loss = (ipw.to(per_sample_loss.device) * per_sample_loss).mean()
                  if not torch.isnan(ipw_val_loss): self.log("ipw_val_loss", ipw_val_loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Generate Image & Calculate Generation Loss
        generated_images = None
        try: # Use manual calc for metrics generation
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                 try: generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
                 except Exception as decode_e: print(f"WARN VAL: VAE decode: {decode_e}"); generated_images = target_images
            else: print("WARN VAL: NaN/Inf in pred_latents"); generated_images = target_images
        except Exception as e: print(f"ERROR val img gen: {e}"); self.log("val_gen_err", 1.0, sync_dist=True); generated_images = target_images
        if generated_images is None: 
            print("generated images are target images")
            generated_images = target_images

        # Log Gen Loss
        val_gen_loss = F.mse_loss(generated_images, target_images)
        if not torch.isnan(val_gen_loss):
            self.log("val_gen_loss", val_gen_loss, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))
        if ipw is not None:
            gen_loss_per_sample = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
            if not torch.isnan(gen_loss_per_sample).any():
                 ipw_val_gen_loss = (ipw.to(gen_loss_per_sample.device) * gen_loss_per_sample).mean()
                 if not torch.isnan(ipw_val_gen_loss): self.log("ipw_gen_val_loss", ipw_val_gen_loss, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # ITE Metrics (unweighted and weighted)
        if len(self.feature_predictors) > 0 and not torch.equal(generated_images, target_images):
            min_expected_earlier_dim = max([PERCEPTUAL_FEATURES.index(f) for f in self.feature_predictors.keys() if f in PERCEPTUAL_FEATURES] + [-1]) + 1
            if earlier_features.shape[1] >= min_expected_earlier_dim:
                for feat_name, predictor in self.feature_predictors.items():
                    try:
                        feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                        earlier_feat_gt = earlier_features[:, feat_idx].long()
                        later_feat_gt = later_features[:, feat_idx].long()
                        valid_mask = (earlier_feat_gt != -1) & (later_feat_gt != -1) & (later_feat_gt != -1.0)
                        if valid_mask.sum() == 0: continue
                        gen_images_0_to_1 = (generated_images * 0.5) + 0.5
                        target_images_0_to_1 = (target_images * 0.5) + 0.5
                        imagenet_norm_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                        with autocast('cuda', enabled=self.trainer.precision == "16-mixed"):
                            head_layer = getattr(predictor, 'head', getattr(predictor, 'fc', getattr(predictor, 'classifier', None)))
                            processed_gen_images_for_predictor = imagenet_norm_transform(gen_images_0_to_1.to(head_layer.weight.dtype))
                            processed_target_images_for_predictor = imagenet_norm_transform(target_images_0_to_1.to(head_layer.weight.dtype)) 
                            pred_feat_logits = predictor(processed_gen_images_for_predictor.to(head_layer.weight.dtype))
                            target_feat_logits = predictor(processed_target_images_for_predictor.to(head_layer.weight.dtype))
                        # Get predicted values (indices)
                        pred_feat_val = pred_feat_logits.argmax(dim=-1).float()[valid_mask]
                        target_feat_val = target_feat_logits.argmax(dim=-1).float()[valid_mask] # Predicted value for target image

                        # Get ground truth values
                        earlier_feat_masked = earlier_feat_gt[valid_mask].float()
                        later_feat_masked = later_feat_gt[valid_mask].float() # True value for target image

                        # Calculate deltas
                        pred_delta = pred_feat_val - earlier_feat_masked # Predicted change by diffusion model
                        gt_delta = later_feat_masked - earlier_feat_masked # True change
                        gt_model_delta = target_feat_val - earlier_feat_masked # Change predicted by feature model on real images

                        # Calculate ITE errors
                        val_ite_feat = torch.abs(pred_delta - gt_delta).mean() # Diff model vs Ground Truth change
                        val_model_ite_feat = torch.abs(pred_delta - gt_model_delta).mean() # Diff model vs Feature Predictor change on real target
                        # Log ITE metrics
                        if not torch.isnan(val_ite_feat):
                            self.log(f"val_ite_{feat_name}", val_ite_feat, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=valid_mask.sum())
                        if not torch.isnan(val_model_ite_feat):
                            self.log(f"val_model_ite_{feat_name}", val_model_ite_feat, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=valid_mask.sum())
                        if ipw is not None: # Only log weighted if ipw exists
                            ipw_masked = ipw[valid_mask].to(val_ite_feat.device)
                            ipw_val_ite_feat = (torch.abs(pred_delta - gt_delta) * ipw_masked).mean()
                            ipw_val_model_ite_feat = (torch.abs(pred_delta - gt_model_delta) * ipw_masked).mean()
                            self.log(f"ipw_val_ite_{feat_name}", ipw_val_ite_feat, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=valid_mask.sum())
                            self.log(f"ipw_val_model_ite_{feat_name}", ipw_val_model_ite_feat, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True, batch_size=valid_mask.sum())
                    except (ValueError, IndexError, KeyError):
                        print(f"WARN VAL ITE: Skipping feature {feat_name} due to index/key error.")
                        continue
                    except Exception as e:
                        print(f"Err val ITE {feat_name}: {e}")

        # Log qualitative images
        if self.global_rank == 0 and batch_idx == 0:
            self.log_images_qualitative(input_images[0:16], target_images[0:16], prompts[0:16], suffix="") # Log more images

    def log_images_qualitative(self, input_imgs, target_imgs, prompts, suffix="", strength=0.75, guidance_scale=7.5):
        """Generate images with ``self.pipe`` and log them next to inputs/targets.

        The inputs are mapped from [-1, 1] to [0, 1], converted with
        ``self.transform_to_pil`` and passed to ``self.pipe`` together with the
        prompts. Input, target and generated grids plus the prompt text are
        written to ``self.logger.experiment`` at ``self.current_epoch``.

        Args:
            input_imgs: Batch of conditioning images, normalised to [-1, 1].
            target_imgs: Batch of ground-truth images, normalised to [-1, 1].
            prompts: Prompts forwarded to the pipeline.
            suffix: String appended to every logged tag.
            strength: Forwarded to the pipeline call.
            guidance_scale: Forwarded to the pipeline call.

        Returns:
            None. Any failure is reported with a printed warning, not raised.
        """
        if self.pipe is None:
            print("Warn: Pipe not ready for qualitative logging.")
            return

        sample_count = input_imgs.shape[0]
        print(f"Generating qualitative images ({sample_count} samples)...")
        try:
            # The pipeline consumes PIL images, so undo the [-1, 1] normalisation first.
            pil_inputs = []
            for img in input_imgs:
                unit_range_img = ((img * 0.5) + 0.5).cpu().float()
                pil_inputs.append(self.transform_to_pil(unit_range_img))

            # Inference runs with autocast disabled (float32) and without gradients.
            with torch.no_grad(), autocast('cuda', enabled=False):
                pipe_output = self.pipe(
                    prompt=prompts,
                    image=pil_inputs,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=50,
                )
            pil_outputs = pipe_output.images

            # Back to tensors so the generated samples can be tiled into a grid.
            pil_to_tensor = T.ToTensor()
            generated_batch = torch.stack([pil_to_tensor(pil_img) for pil_img in pil_outputs]).to(self.device)

            tagged_grids = (
                ("Input", make_grid(input_imgs.cpu().detach(), normalize=True, value_range=(-1, 1))),
                ("Target", make_grid(target_imgs.cpu().detach(), normalize=True, value_range=(-1, 1))),
                ("Generated", make_grid(generated_batch.cpu().detach())),
            )
            for tag, grid in tagged_grids:
                self.logger.experiment.add_image(f"{tag}{suffix}", grid, self.current_epoch)

            prompt_lines = [f"Img {idx + 1}: {text}" for idx, text in enumerate(prompts)]
            self.logger.experiment.add_text(f"Val/Prompts{suffix}", "\n".join(prompt_lines), self.current_epoch)
            print("Qualitative images logged.")
        except Exception as e:
            print(f"Warn: Qualitative logging failed: {e}")

    def configure_optimizers(self):
        """Build the AdamW optimizer over the generator-side modules.

        Parameters are collected, in order, from ``self.unet``,
        ``self.img_conditioning`` and ``self.context_encoder``; the
        ``rich_context_projector`` is added only when it exists and is not an
        ``nn.Identity``.

        Returns:
            AdamW: Optimizer with a single parameter group using ``self.lr``.
        """
        optimized_modules = [self.unet, self.img_conditioning, self.context_encoder]
        if hasattr(self, 'rich_context_projector') and not isinstance(self.rich_context_projector, nn.Identity):
            optimized_modules.append(self.rich_context_projector)

        trainable = [param for module in optimized_modules for param in module.parameters()]
        return AdamW(trainable, lr=self.lr)

    def test_step(self, batch, batch_idx):
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"]

        # Mask for samples that received treatment (last label == 0 means treatment occurred)
        treat_mask = (interval_labels[:, -1] == 0)

        # Calculate IPW (uses propensity model stored in self.propensity_model)
        ipw = self._calculate_ipw(batch)  # Inherited method

        metrics = {}
        # History for ContextEncoderRNN
        cov_seq_hist = batch["cov_seq"]
        trt_seq_hist = batch["trt_seq"]
        lengths_hist = batch["lengths"]
        delta_t_hist = batch["delta_t"]
        side_hist = batch["side"]

        with torch.no_grad():  # Ensure no gradients during testing
            pred_noise, noise, noisy_latents, timesteps, target_latents = self(
                input_images, target_images, prompts,
                cov_seq_hist, trt_seq_hist, lengths_hist, delta_t_hist, side_hist
            )

        if torch.isnan(pred_noise).any(): raise ValueError("NaN in pred_noise")

        # --- Calculate Losses ---
        loss_mse = F.mse_loss(pred_noise, noise)
        loss_mae = F.l1_loss(pred_noise, noise)
        metrics['test_loss_mse'] = loss_mse.item()
        metrics['test_loss_mae'] = loss_mae.item()

        # Compute per-sample losses for treated subset
        per_sample_mse = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        per_sample_mae = F.l1_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_loss_mse_treated'] = per_sample_mse[treat_mask].mean().item()
            metrics['test_loss_mae_treated'] = per_sample_mae[treat_mask].mean().item()
            if ipw is not None:
                w_t = ipw[treat_mask]
                # normalize weights
                w_t = w_t / w_t.sum().clamp(min=1e-6)
                ipw_mse_treated = (w_t * per_sample_mse[treat_mask]).sum()
                metrics['test_ipw_loss_treated'] = ipw_mse_treated.item()

        if ipw is not None:
            per_sample_loss = per_sample_mse
            if not torch.isnan(per_sample_loss).any():
                ipw_test_loss = (ipw.to(per_sample_loss.device) * per_sample_loss).mean()
                metrics['test_ipw_loss'] = ipw_test_loss.item() # Use test_ prefix
            metrics['test_mean_ipw'] = ipw.mean().item()

        # --- Generate Images for Gen Loss and ITE ---
        generated_images = None
        with torch.no_grad(): # Ensure VAE decode is without grad
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                self.vae.eval() # Ensure VAE is in eval mode
                # Use autocast consistent with validation/training if applicable
                with autocast('cuda', enabled=self.trainer.precision.startswith("16") or self.trainer.precision.startswith("bf16")):
                    generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            if generated_images is None or torch.isnan(generated_images).any(): generated_images = target_images # Fallback

        val_gen_loss = F.mse_loss(generated_images, target_images)
        metrics['test_gen_loss'] = val_gen_loss.item()

        per_sample_gen = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_gen_loss_treated'] = per_sample_gen[treat_mask].mean().item()
            if ipw is not None:
                w_t = ipw[treat_mask] / ipw[treat_mask].sum().clamp(min=1e-6)
                ipw_gen_treated = (w_t * per_sample_gen[treat_mask]).sum()
                metrics['test_ipw_gen_loss_treated'] = ipw_gen_treated.item()

        if ipw is not None:
            gen_loss_per_sample = per_sample_gen
            if not torch.isnan(gen_loss_per_sample).any():
                ipw_test_gen_loss = (ipw.to(gen_loss_per_sample.device) * gen_loss_per_sample).mean()
                metrics['test_ipw_gen_loss'] = ipw_test_gen_loss.item()

        # --- Calculate ITE Metrics ---
        if len(self.feature_predictors) > 0 and not torch.equal(generated_images, target_images):
            min_expected_earlier_dim = 0
            try:
                min_expected_earlier_dim = max([PERCEPTUAL_FEATURES.index(f) for f in self.feature_predictors.keys() if f in PERCEPTUAL_FEATURES] + [-1]) + 1
            except ValueError:
                print("WARN ITE: Mismatch between feature_predictors keys and PERCEPTUAL_FEATURES list.")

            if earlier_features.shape[1] >= min_expected_earlier_dim:
                # Define ImageNet normalization here for clarity
                imagenet_norm_transform = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

                for feat_name, predictor in self.feature_predictors.items():
                    try:
                        predictor.to(self.device)
                        feat_idx = PERCEPTUAL_FEATURES.index(feat_name)
                        earlier_feat_gt = earlier_features[:, feat_idx].long().to(self.device)
                        later_feat_gt = later_features[:, feat_idx].long().to(self.device)
                        # --- Replacement block for ITE metrics ---
                        # Base validity mask (all samples with valid ground truth)
                        base_mask = (earlier_feat_gt != -1) & (later_feat_gt != -1) & (later_feat_gt != -1.0)
                        if base_mask.sum() == 0:
                            continue
                        # Treated subset mask
                        treated_mask = base_mask & treat_mask
                        # Prepare images for predictor (denormalize [-1,1] -> [0,1], then ImageNet norm)
                        gen_images_0_to_1 = (generated_images.to(self.device) * 0.5) + 0.5
                        target_images_0_to_1 = (target_images.to(self.device) * 0.5) + 0.5
                        with torch.no_grad(), autocast('cuda', enabled=self.trainer.precision.startswith("16") or self.trainer.precision.startswith("bf16")):
                            head_layer = getattr(predictor, 'head', getattr(predictor, 'fc', getattr(predictor, 'classifier', None)))
                            if head_layer is None:
                                print(f"WARN ITE: Could not find head layer for predictor {feat_name}")
                                continue
                            processed_gen_images = imagenet_norm_transform(gen_images_0_to_1.to(head_layer.weight.dtype))
                            processed_target_images = imagenet_norm_transform(target_images_0_to_1.to(head_layer.weight.dtype))
                            pred_feat_logits = predictor(processed_gen_images)
                            target_feat_logits = predictor(processed_target_images)
                        if torch.isnan(pred_feat_logits).any() or torch.isnan(target_feat_logits).any():
                            print(f"WARN ITE {feat_name}: NaN in predictor logits.")
                            continue
                        # Get predicted values and ground truth for all valid samples (base_mask)
                        pred_vals = pred_feat_logits.argmax(dim=-1).float()
                        target_vals = target_feat_logits.argmax(dim=-1).float()
                        gt_vals = earlier_feat_gt.float()
                        # Deltas for all valid
                        pred_delta_all = pred_vals[base_mask] - gt_vals[base_mask]
                        gt_delta_all = later_feat_gt.float()[base_mask] - gt_vals[base_mask]
                        model_delta_all = target_vals[base_mask] - gt_vals[base_mask]
                        ite_err_all = torch.abs(pred_delta_all - gt_delta_all)
                        model_ite_err_all = torch.abs(pred_delta_all - model_delta_all)
                        metrics[f'test_ite_{feat_name}'] = ite_err_all.mean().item()
                        metrics[f'test_model_ite_{feat_name}'] = model_ite_err_all.mean().item()
                        if ipw is not None:
                            ipw_vals = ipw[base_mask].to(ite_err_all.device)
                            w_norm = ipw_vals / ipw_vals.sum().clamp(min=1e-6)
                            metrics[f'test_ipw_ite_{feat_name}'] = (ite_err_all * w_norm).sum().item()
                            metrics[f'test_ipw_model_ite_{feat_name}'] = (model_ite_err_all * w_norm).sum().item()
                        # Now compute metrics for treated only, if any
                        if treated_mask.sum() > 0:
                            pred_delta_tr = pred_vals[treated_mask] - gt_vals[treated_mask]
                            gt_delta_tr = later_feat_gt.float()[treated_mask] - gt_vals[treated_mask]
                            model_delta_tr = target_vals[treated_mask] - gt_vals[treated_mask]
                            ite_err_tr = torch.abs(pred_delta_tr - gt_delta_tr)
                            model_ite_err_tr = torch.abs(pred_delta_tr - model_delta_tr)
                            metrics[f'test_ite_{feat_name}_treated'] = ite_err_tr.mean().item()
                            metrics[f'test_model_ite_{feat_name}_treated'] = model_ite_err_tr.mean().item()
                            if ipw is not None:
                                ipw_tr = ipw[treated_mask].to(ite_err_tr.device)
                                w_tr = ipw_tr / ipw_tr.sum().clamp(min=1e-6)
                                metrics[f'test_ipw_ite_{feat_name}_treated'] = (ite_err_tr * w_tr).sum().item()
                                metrics[f'test_ipw_model_ite_{feat_name}_treated'] = (model_ite_err_tr * w_tr).sum().item()
                    except (ValueError, IndexError, KeyError) as e:
                        print(f"WARN ITE: Skipping feature {feat_name} due to index/key error: {e}")
                        continue
                    except Exception as e:
                        print(f"WARN ITE: Error calculating ITE for {feat_name}: {e}")

        metrics['batch_size'] = float(input_images.size(0))
        # accumulate for on_test_epoch_end
        self._test_outputs.append(metrics.copy())
        # Return only finite values
        return {k: v for k, v in metrics.items() if np.isfinite(v)}

    def on_test_epoch_end(self):
        """Aggregate the per-batch test metrics, print them and optionally save them.

        Every metric found in ``self._test_outputs`` is averaged across
        batches, weighted by each batch's ``batch_size``. Batches that lack a
        metric, or hold a non-finite value for it, are left out of that
        metric's average so a single NaN cannot poison the final number. A
        metric with no usable contribution is reported as NaN.

        When ``self.hparams.results_file`` is set, the report is also written
        to that path (best effort; failures are printed, not raised).
        ``self._test_outputs`` is emptied afterwards.
        """
        outputs = self._test_outputs
        if not outputs:
            print("No valid outputs collected during testing.")
            return

        valid_outputs = [out for out in outputs if out is not None]  # Filter Nones
        if not valid_outputs:
            print("All test steps failed or returned None.")
            self._test_outputs.clear()
            return

        # Aggregate all keys found in the outputs, except 'batch_size'
        metric_keys = {k for out in valid_outputs for k in out} - {'batch_size'}
        total_samples = sum(out.get('batch_size', 0) for out in valid_outputs)

        print(f"\n--- Aggregating Test Results ({len(valid_outputs)} Batches, {total_samples:.0f} Samples) ---")

        aggregated_metrics = {}
        for key in sorted(metric_keys):
            total_value = 0.0
            total_weight = 0.0
            for out in valid_outputs:
                value = out.get(key)
                # Skip batches without this metric or with a NaN/inf value for it.
                if value is None or not np.isfinite(value):
                    continue
                weight = out.get('batch_size', 0)
                total_value += value * weight
                total_weight += weight
            aggregated_metrics[key] = total_value / total_weight if total_weight > 0 else float('nan')

        print("--- Final Test Metrics ---")
        results_str = [
            f"Checkpoint: {getattr(self.hparams, 'ckpt_path', 'N/A')}",
            f"Total Samples: {total_samples:.0f}",
            "---",
        ]
        for key, value in aggregated_metrics.items():
            line = f"{key}: {value:.6f}"
            print(line)
            results_str.append(line)

        # Save results to file (path specified in hparams, potentially added after loading)
        results_file_path = getattr(self.hparams, 'results_file', None)
        if results_file_path:
            print(f"Saving results to {results_file_path}...")
            try:
                # Only make parent directory if one is specified
                dirpath = os.path.dirname(results_file_path)
                if dirpath:
                    os.makedirs(dirpath, exist_ok=True)
                with open(results_file_path, 'w', encoding='utf-8') as f:
                    f.write("\n".join(results_str))
                print("Results saved.")
            except (OSError, TypeError, ValueError) as e:
                # Saving is best effort: the metrics were already printed above.
                print(f"Error saving results: {e}")

        # Drop the accumulated batches so a later test run starts from scratch.
        self._test_outputs.clear()

# --- Main Script Logic ---
if __name__ == "__main__":
    # Everything below is configured from the parsed command-line options.
    args = parse_args()
    # Fixed seed (42) so repeated runs are comparable; `workers=True` is passed
    # through to Lightning's seeding helper.
    pl.seed_everything(42, workers=True)
    def collate_fn_diffusion(batch, include_propensity_history=True, propensity_model_uses_images=False):
        """Collate dataset samples into one batch dict for the diffusion module.

        Image, feature and label tensors are stacked and prompts are gathered
        into a list. With ``include_propensity_history`` the variable-length
        history is added as well: ``cov_seq``/``trt_seq`` zero-padded along the
        sequence axis, ``lengths`` taken from each ``cov_seq``, ``delta_t`` as a
        column vector, ``side`` encoded as 1.0 when the string starts with
        "R" (case-insensitive) and 0.0 otherwise, and ``image_seq`` (padded
        only when requested and present in every sample, else ``None``).

        If building the history fails, a message is printed and all six
        history keys are set to ``None``.

        Args:
            batch: List of sample dicts produced by the dataset.
            include_propensity_history: Whether to add the history keys.
            propensity_model_uses_images: Whether to collate ``image_seq``.

        Returns:
            dict: The collated batch.
        """
        def _stack(key):
            return torch.stack([sample[key] for sample in batch])

        def _pad(key):
            return pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=0.0)

        collated_batch = {
            "input_image": _stack("input_image"),
            "target_image": _stack("target_image"),
            "prompt": [sample["prompt"] for sample in batch],
            "earlier_features": _stack("earlier_features"),
            "later_features": _stack("later_features"),
            "interval_labels": _stack("interval_labels"),  # For IPW calculation
        }

        if not include_propensity_history:
            return collated_batch

        history_keys = ("cov_seq", "trt_seq", "lengths", "delta_t", "side", "image_seq")
        try:
            collated_batch["cov_seq"] = _pad("cov_seq")
            collated_batch["trt_seq"] = _pad("trt_seq")
            seq_lengths = [sample["cov_seq"].size(0) for sample in batch]
            collated_batch["lengths"] = torch.tensor(seq_lengths, dtype=torch.long)
            intervals = [sample["delta_t"] for sample in batch]
            collated_batch["delta_t"] = torch.tensor(intervals, dtype=torch.float).unsqueeze(1)

            # Knee side arrives as a string; encode right as 1.0 and anything else as 0.0.
            sides = [sample["side"] for sample in batch]
            side_codes = [1.0 if side.upper().startswith("R") else 0.0 for side in sides]
            collated_batch["side"] = torch.tensor(side_codes, dtype=torch.float).unsqueeze(1)

            collated_batch["image_seq"] = None
            if propensity_model_uses_images:
                every_sample_has_images = all(
                    "image_seq" in sample and sample["image_seq"] is not None for sample in batch
                )
                if every_sample_has_images:
                    collated_batch["image_seq"] = _pad("image_seq")
                else:
                    # image_seq stays None in this case.
                    print("Warn: Propensity model uses images, but image_seq missing/None in a batch sample during collate.")
        except KeyError as e:
            print(f"Error collating history for IPW: Missing key {e}.")
            # Null every history key so downstream code sees a consistent "no history" batch.
            collated_batch.update(dict.fromkeys(history_keys))
        except Exception as ex:
            print(f"Unexpected error in collate_fn_diffusion: {ex}")
            collated_batch.update(dict.fromkeys(history_keys))

        return collated_batch

    print(f"Run Name: {args.run_name}")
    # Resize to 224x224 and map pixel values from [0, 1] to [-1, 1].
    image_transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    def _load_propensity_state_dict(path, map_location):
        """Load a propensity state dict, dropping a leading ``module.`` from keys that carry it."""
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        state_dict = torch.load(path, map_location=map_location)
        prefix = 'module.'
        if any(key.startswith(prefix) for key in state_dict):
            state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
        return state_dict

    if args.test_only:
        if not args.ckpt_path:
            # Fail early with a clear message instead of an obscure loader error.
            raise SystemExit("ERROR: --test_only requires --ckpt_path.")

        # Load Propensity Model (needed to pass to LitModule loading)
        print("Loading Temporal Propensity Model for testing...")
        propensity_model_instance = TemporalPropensityModel(
            cov_dim=FEATURE_DIM, hidden_dim=args.propensity_hidden_dim,
            num_layers=args.propensity_num_layers, dropout=args.propensity_dropout,
            model_type=args.propensity_model_type, include_images=args.propensity_model_uses_images,
            delta_t_feat_dim=args.propensity_delta_t_feat_dim
            # Add image_model_name/img_feat_dim if needed
        )
        state_dict = _load_propensity_state_dict(args.propensity_model_path, map_location="cpu")
        propensity_model_instance.load_state_dict(state_dict)
        propensity_model_instance.eval()
        print(f"Loaded temporal propensity model from {args.propensity_model_path}")

        # Load Lightning Module from Checkpoint
        print(f"Loading Diffusion LitModule from checkpoint: {args.ckpt_path}")
        vae = AutoencoderKL.from_pretrained(args.diffusion_model_pretrain, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_pretrain, subfolder="unet")
        text_encoder = CLIPTextModel.from_pretrained(args.diffusion_model_pretrain, subfolder="text_encoder")
        tokenizer = CLIPTokenizer.from_pretrained(args.diffusion_model_pretrain, subfolder="tokenizer")
        scheduler = DDIMScheduler.from_pretrained(args.diffusion_model_pretrain, subfolder="scheduler")
        # We need the hidden size for ImageConditioning; reuse the already-loaded
        # encoder's config instead of reading the weights from disk a second time.
        text_encoder_config = text_encoder.config
        # NOTE(review): the training branch builds ImageConditioning with
        # args.disc_current_img_cond_dim instead of the text-encoder hidden size;
        # confirm both values agree with the checkpoint being loaded.
        img_conditioning = ImageConditioning(vae.config.latent_channels, text_encoder_config.hidden_size)
        # Use the original DiffusionLitModule class for loading
        diffusion_model = DiffusionLitModule.load_from_checkpoint(
            args.ckpt_path,
            map_location="cpu",
            args=args,
            vae=vae, unet=unet, scheduler=scheduler,
            text_encoder=text_encoder, tokenizer=tokenizer,
            img_conditioning=img_conditioning,
            propensity_model=propensity_model_instance,
            delta_t_mean=delta_t_mean, delta_t_std=delta_t_std,
            strict=False
        )
        # Read back by on_test_epoch_end when it writes the report.
        diffusion_model.hparams.results_file = args.results_file
        diffusion_model.hparams.ckpt_path = args.ckpt_path
        prop_uses_images = getattr(diffusion_model.hparams, 'propensity_model_uses_images', False)

        test_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "test.csv"),
            image_transform=image_transforms,
            include_images=prop_uses_images
        )
        print(f"Test dataset size: {len(test_dataset)}")
        collate_fn_test = partial(collate_fn_diffusion,
                                  include_propensity_history=True,
                                  propensity_model_uses_images=prop_uses_images)

        test_loader = DataLoader(
            test_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=8, pin_memory=True,
            collate_fn=collate_fn_test, drop_last=False
        )
        trainer = pl.Trainer(
            accelerator="gpu", devices=args.devices, precision=args.precision,
            logger=False, callbacks=[]
        )
        trainer.test(model=diffusion_model, dataloaders=test_loader, verbose=True)
    else:
        train_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "train.csv"),
            image_transform=image_transforms,  # This will make image_seq [-1,1]
            include_images=args.propensity_model_uses_images  # This flag controls if image_seq is loaded *by the dataset*
        )
        val_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "val.csv"),
            image_transform=image_transforms,
            include_images=args.propensity_model_uses_images
        )
        train_loader = DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True,
            collate_fn=partial(collate_fn_diffusion,
                               include_propensity_history=True,
                               propensity_model_uses_images=args.propensity_model_uses_images),
            drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True,
            collate_fn=partial(collate_fn_diffusion,
                               include_propensity_history=True,
                               propensity_model_uses_images=args.propensity_model_uses_images),
        )

        # Stable Diffusion components
        vae = AutoencoderKL.from_pretrained(args.diffusion_model_pretrain, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_pretrain, subfolder="unet")
        text_encoder = CLIPTextModel.from_pretrained(args.diffusion_model_pretrain, subfolder="text_encoder")
        tokenizer = CLIPTokenizer.from_pretrained(args.diffusion_model_pretrain, subfolder="tokenizer")
        scheduler = DDIMScheduler.from_pretrained(args.diffusion_model_pretrain, subfolder="scheduler")
        print("Base pipeline components loaded.")

        img_conditioning = ImageConditioning(vae.config.latent_channels, args.disc_current_img_cond_dim)

        # --- Load Propensity Model (required: the run aborts below if loading fails) ---
        temporal_propensity_model_instance = None
        print("Loading Temporal Propensity Model...")
        if os.path.exists(args.propensity_model_path):
            try:
                temporal_propensity_model_instance = TemporalPropensityModel(
                    cov_dim=FEATURE_DIM, hidden_dim=args.propensity_hidden_dim, num_layers=args.propensity_num_layers,
                    dropout=args.propensity_dropout, model_type=args.propensity_model_type, include_images=args.propensity_model_uses_images,
                    delta_t_feat_dim=args.propensity_delta_t_feat_dim).to("cuda")  # Load to GPU
                state_dict = _load_propensity_state_dict(args.propensity_model_path, map_location="cuda")
                temporal_propensity_model_instance.load_state_dict(state_dict)
                # The propensity model is only used for inference: eval mode, no gradients.
                temporal_propensity_model_instance.eval()
                for param in temporal_propensity_model_instance.parameters():
                    param.requires_grad_(False)
                print(f"Loaded temporal propensity model from {args.propensity_model_path}")
            except Exception as e:
                print(f"Error loading propensity model: {e}")
                temporal_propensity_model_instance = None
        else:
            print(f"Propensity model checkpoint not found: {args.propensity_model_path}")

        if temporal_propensity_model_instance is None:
            print("ERROR: IPW loss was enabled, but propensity model failed to load. Exiting.")
            # SystemExit works even where the interactive `exit` helper is not installed.
            raise SystemExit(1)

        # --- Instantiate Lightning Module ---
        print("Instantiating DiffusionLitModule...")
        diffusion_model = DiffusionLitModule(
            args=args, vae=vae, unet=unet, scheduler=scheduler, text_encoder=text_encoder,
            tokenizer=tokenizer, img_conditioning=img_conditioning,
            propensity_model=temporal_propensity_model_instance,  # Loaded and frozen above
            delta_t_mean=delta_t_mean, delta_t_std=delta_t_std
        )

        # --- Callbacks and Logger ---
        monitor_metric = "val_loss"
        filename_metric = "val_loss={val_loss:.4f}"
        print(f"ModelCheckpoint monitoring: {monitor_metric}")
        checkpoint_callback = ModelCheckpoint(
            monitor=monitor_metric,
            dirpath=args.checkpoint_dir,
            filename=f"best-{{epoch}}-{filename_metric}",
            auto_insert_metric_name=False,
            save_top_k=2, mode="min", save_last=True
        )
        tensorboard_logger = TensorBoardLogger(save_dir=args.log_dir_base, name=args.run_name)

        trainer = pl.Trainer(
            max_epochs=args.max_epochs, accelerator="gpu", devices=args.devices, precision=args.precision,
            logger=tensorboard_logger, log_every_n_steps=25,
            strategy="ddp_find_unused_parameters_true",
            callbacks=[checkpoint_callback], accumulate_grad_batches=args.accumulate_grad_batches
        )
        print(f"Starting training run: {args.run_name}")
        print(f" Logs: {os.path.join(args.log_dir_base, args.run_name)}")
        trainer.fit(diffusion_model, train_loader, val_loader)

        print("--- Training Finished ---")
        print(f"Best model checkpoint path: {checkpoint_callback.best_model_path}")