"""
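Refactored IPW Diffusion Model Training

Training/testing script for a diffusion model whose loss can be weighted by
inverse probability weights (IPW) derived from a pretrained propensity model.
Simplified version using base classes and utility functions.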
"""
import os
import argparse
from functools import partial
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import autocast

from pretrain_propensity_model_temporal import TemporalPropensityModel
from data.pairs_dataset.dataset import TemporalKneeFeatureConditioningDataset, NUM_TREATMENTS, FEATURE_DIM, PERCEPTUAL_FEATURES
from diff_base import BaseDiffusionLitModule, BaseIPWMixin, ContextEncoderRNN
from data_utils import create_unified_collate_fn
from diffusion_utils import safe_tensor_log, EPSILON_DELTA_T, EPSILON_IPW, calculate_ssim

# Normalisation constants handed to IPWDiffusionLitModule as
# delta_t_mean / delta_t_std by the main script below.
# NOTE(review): presumably the mean/std of delta_t over the training data — confirm.
delta_t_mean = 35.0731
delta_t_std = 22.9570

# Process-wide setting: trade some float32 matmul precision for speed
# (see the PyTorch docs for what 'high' permits on the current hardware).
torch.set_float32_matmul_precision('high')


def parse_args(argv=None):
    """Parse command line arguments and derive run-specific settings.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the original behaviour. Passing
            an explicit list allows programmatic use.

    Returns:
        argparse.Namespace holding every declared option plus ``run_name``.
        In training mode (no ``--test_only``) it additionally carries
        ``checkpoint_dir`` and ``log_dir``, and ``checkpoint_dir`` is created
        on disk as a side effect.

    Raises:
        SystemExit: via ``parser.error`` when ``--test_only`` is given without
            a ``--ckpt_path`` that exists on disk (and for ordinary argparse
            usage errors).
    """
    parser = argparse.ArgumentParser(description="Train IPW Diffusion Model")
    
    # Mode
    parser.add_argument("--test_only", action="store_true", help="Run in testing mode only")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint for testing")
    parser.add_argument("--results_file", type=str, default="test_results.txt", help="File to save test metrics")
    
    # Paths
    parser.add_argument("--propensity_model_path", type=str, 
                       default="./checkpoints/Propensity_Model_Temporal/propensity_model_temporal_best.pth")
    parser.add_argument("--pairs_dir", type=str, default="data/pairs_dataset", help="Dir with train.csv/val.csv")
    parser.add_argument("--base_checkpoint_dir", type=str, default="/local2/acc/OAI_checkpoints")
    parser.add_argument("--log_dir_base", type=str, default="/local2/acc/OAI_checkpoints/tb_logs")
    
    # Model Params
    parser.add_argument("--diffusion_model_pretrain", type=str, default="runwayml/stable-diffusion-v1-5")
    
    # Training Params
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--feature_loss_weight", type=float, default=0, help="Weight for perceptual feature loss")
    parser.add_argument("--precision", type=str, default="16-mixed", choices=["16-mixed", "bf16-mixed", "32-true"])
    parser.add_argument("--devices", type=int, nargs='+', default=[0,1], help="GPU devices to use")
    
    # IPW Params
    parser.add_argument("--disable_ipw_loss", action='store_true', help="Train without IPW loss weighting")
    parser.add_argument("--ipw_max_clamp", type=float, default=150.0, help="Max value to clamp IPW weights")
    
    # Propensity Model Params
    parser.add_argument("--propensity_model_type", type=str, default='transformer')
    parser.add_argument("--propensity_hidden_dim", type=int, default=128)
    parser.add_argument("--propensity_num_layers", type=int, default=2)
    parser.add_argument("--propensity_dropout", type=float, default=0.2)
    parser.add_argument("--propensity_delta_t_feat_dim", type=int, default=32)
    parser.add_argument("--propensity_model_uses_images", action='store_true')
    
    # Context Encoder
    parser.add_argument("--disc_model_type", type=str, choices=["rnn", "transformer"], default="rnn")
    parser.add_argument("--disc_hidden_dim", type=int, default=128)
    parser.add_argument("--disc_num_layers", type=int, default=2)
    parser.add_argument("--disc_dropout", type=float, default=0.2)
    parser.add_argument("--disc_delta_t_feat_dim", type=int, default=8)
    parser.add_argument("--disc_side_feat_dim", type=int, default=4)
    parser.add_argument("--disc_current_img_cond_dim", type=int, default=768)
    
    # Other
    parser.add_argument("--run_suffix", type=str, default="i", help="Suffix for run name")

    args = parser.parse_args(argv)
    
    if args.test_only:
        # Testing needs an existing checkpoint; the run is named after its file.
        if not args.ckpt_path or not os.path.exists(args.ckpt_path):
            parser.error("--ckpt_path is required and must exist for --test_only mode.")
        args.run_name = f"TEST_{os.path.basename(args.ckpt_path)}"
    else:
        # The run name encodes the main hyperparameters so that checkpoint and
        # log directories of different configurations do not collide.
        args.run_name = (f"{args.run_suffix}_{args.lr}_{args.ipw_max_clamp}_"
                        f"b{args.batch_size}x{len(args.devices)}_flw{args.feature_loss_weight}")
        if args.disable_ipw_loss:
            args.run_name += "_NoIPWLoss"
        args.checkpoint_dir = os.path.join(args.base_checkpoint_dir, args.run_name)
        args.log_dir = os.path.join(args.log_dir_base, args.run_name)
        # Only the checkpoint directory is created here; log_dir is just recorded.
        os.makedirs(args.checkpoint_dir, exist_ok=True)
        
    return args


class IPWDiffusionLitModule(BaseIPWMixin, BaseDiffusionLitModule):
    """IPW Diffusion model using base functionality."""
    
    def __init__(self, args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 propensity_model=None, delta_t_mean=0.0, delta_t_std=1.0):
        """Assemble the diffusion module, its IPW state and the context encoder.

        Args:
            args: Parsed options; forwarded to the base class, the IPW mixin
                setup, ``save_hyperparameters`` and the context encoder.
            vae, unet, scheduler, text_encoder, tokenizer, img_conditioning:
                Diffusion components, forwarded unchanged to the base class.
            propensity_model: Optional propensity model handed to
                ``_init_ipw_components``.
            delta_t_mean, delta_t_std: Normalisation constants forwarded to the
                base class.
        """
        super().__init__(
            args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
            delta_t_mean, delta_t_std,
        )

        # IPW state comes from the mixin.
        self._init_ipw_components(args, propensity_model)

        # Record the parsed options as hyperparameters.
        self.save_hyperparameters(args)

        # Encoder that turns the current-image conditioning plus history
        # inputs into a single context vector.
        encoder_kwargs = dict(
            args=args,
            current_img_cond_dim=args.disc_current_img_cond_dim,
            cov_hist_dim=FEATURE_DIM,
            num_treatments_hist=NUM_TREATMENTS,
        )
        self.context_encoder = ContextEncoderRNN(**encoder_kwargs)

        # The context vector is added to the text embeddings in forward(), so
        # it must match the text encoder's hidden size; project only if needed.
        context_dim = self.context_encoder.output_dim
        cross_attn_dim = text_encoder.config.hidden_size
        if context_dim == cross_attn_dim:
            projector = torch.nn.Identity()
        else:
            projector = torch.nn.Linear(context_dim, cross_attn_dim)
        self.rich_context_projector = projector

    def forward(self, input_images, target_images, prompts, 
                cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side):
        """Run one noise-prediction pass conditioned on text and history context.

        Both images are encoded to latents, the target latents are noised at
        randomly drawn timesteps, and the U-Net predicts that noise while
        attending to the text embeddings plus a projected context vector.

        Returns:
            Tuple ``(pred_noise, noise, noisy_latents, timesteps,
            target_latents)``; ``target_latents`` is cast back to the dtype of
            ``input_images``.
        """
        original_dtype = input_images.dtype
        precision = self.trainer.precision
        amp_enabled = precision.startswith("16") or precision.startswith("bf16")

        with autocast('cuda', enabled=amp_enabled):
            # Latents: keep the sampling order (input first, then target).
            input_posterior = self.vae.encode(input_images.to(self.vae.dtype)).latent_dist
            input_latents = input_posterior.sample() * self.latent_scale
            target_posterior = self.vae.encode(target_images.to(self.vae.dtype)).latent_dist
            target_latents = target_posterior.sample() * self.latent_scale

            # One random diffusion timestep per sample, then noise the target.
            n_samples = input_latents.size(0)
            timesteps = torch.randint(
                0, self.scheduler.config.num_train_timesteps,
                (n_samples,), device=self.device,
            ).long()
            noise = torch.randn_like(target_latents)
            noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)

            # Text conditioning.
            tokenized = self.tokenizer(
                prompts,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            token_ids = tokenized.input_ids.to(self.device)
            text_embeddings = self.text_encoder(token_ids)[0]

            # Conditioning derived from the current input latents.
            img_cond_current = self.img_conditioning(input_latents)
            cond_dtype = img_cond_current.dtype

            def _on_device(tensor):
                # History inputs follow the conditioning tensor's device/dtype.
                return tensor.to(self.device).to(cond_dtype)

            # Context vector from covariate/treatment history, delta_t and side.
            rich_context = self.context_encoder(
                img_cond_current,
                _on_device(cov_seq_hist),
                _on_device(trt_seq_hist),
                lengths_hist.to(self.device),
                _on_device(delta_t),
                _on_device(side),
                delta_t_mean=self.propensity_delta_t_mean,
                delta_t_std=self.propensity_delta_t_std,
            )
            projected_rich_context = self.rich_context_projector(rich_context)

            # Add the context vector to the text embeddings, inserting a new
            # axis at position 1 so it broadcasts against them.
            unet_conditioning = text_embeddings + projected_rich_context.unsqueeze(1)
            unet_out = self.unet(
                noisy_latents.to(self.unet.dtype),
                timesteps,
                encoder_hidden_states=unet_conditioning.to(self.unet.dtype),
            )
            pred_noise = unet_out.sample

        return pred_noise, noise, noisy_latents, timesteps, target_latents.to(original_dtype)

    def _calculate_ipw(self, batch):
        """Compute inverse-probability weights for a batch from the propensity model.

        The propensity model's logits are turned into probabilities, the
        probability of the observed ``interval_labels`` is multiplied along
        dim 1, and the weight is the clamped reciprocal of that joint
        probability.

        Args:
            batch: Mapping that must provide non-None ``cov_seq``, ``trt_seq``,
                ``lengths``, ``delta_t``, ``side`` and ``interval_labels``;
                ``image_seq`` is used when the propensity model takes images.

        Returns:
            Detached weight tensor on the propensity model's device, or ``None``
            when no propensity model is attached, a required input is missing,
            a NaN appears at any stage, or the computation raises. Callers
            treat ``None`` as "no weighting available".
        """
        if self.propensity_model is None:
            return None
            
        # interval_labels is needed below as well, so it is validated up front
        # instead of surfacing later as a swallowed KeyError.
        required_keys = ["cov_seq", "trt_seq", "lengths", "delta_t", "side", "interval_labels"]
        if not all(k in batch and batch[k] is not None for k in required_keys):
            print("WARN IPW: Missing required inputs for IPW calculation.")
            return None
            
        try:
            cov_seq = batch["cov_seq"]
            trt_seq = batch["trt_seq"] 
            lengths = batch["lengths"]
            delta_t = batch["delta_t"]
            side = batch["side"]
            interval_labels = batch["interval_labels"]
            
            # Handle image sequence if needed; stays None when absent.
            image_seq = None
            if self.hparams.propensity_model_uses_images:
                if "image_seq" in batch and batch["image_seq"] is not None:
                    image_seq = batch["image_seq"]
            
            # Inputs follow the propensity model's own device and dtype.
            first_param = next(self.propensity_model.parameters())
            prop_device = first_param.device
            model_dtype = first_param.dtype
            
            cov_seq_p = cov_seq.to(prop_device, dtype=model_dtype)
            trt_seq_p = trt_seq.to(prop_device, dtype=model_dtype)
            delta_t_p = delta_t.to(prop_device, dtype=model_dtype)
            side_p = side.to(prop_device, dtype=model_dtype)
            
            if image_seq is not None:
                image_seq = image_seq.to(prop_device, dtype=model_dtype)
                
            delta_t_norm = (delta_t_p - self.propensity_delta_t_mean) / (self.propensity_delta_t_std + EPSILON_DELTA_T)
            
            # Calculate propensity with gradients and autocast disabled.
            with torch.no_grad(), autocast('cuda', enabled=False):
                propensity_logits = self.propensity_model(
                    cov_seq_p, trt_seq_p, lengths.cpu(), delta_t_norm, image_seq, side_p
                )
                
            if torch.isnan(propensity_logits).any():
                print("WARN IPW: NaNs in logits")
                return None
                
            # Keep probabilities strictly inside (0, 1) so neither branch of
            # the product below can be exactly zero.
            propensity_probs = torch.sigmoid(propensity_logits.float())
            propensity_probs = torch.clamp(propensity_probs, 1e-6, 1.0 - 1e-6)
            
            if torch.isnan(propensity_probs).any():
                print("WARN IPW: NaNs in probs")
                return None
                
            # Probability of the observed label per entry, multiplied along dim 1.
            label = interval_labels.to(prop_device).float()
            joint_prob = torch.prod(torch.where(label == 1, propensity_probs, 1.0 - propensity_probs), dim=1)
            
            if torch.isnan(joint_prob).any():
                print("WARN IPW: NaNs in joint_prob")
                return None
                
            ipw = 1.0 / (joint_prob + EPSILON_IPW)
            ipw = torch.clamp(ipw, min=0.01, max=self.ipw_max_clamp)
            
            if torch.isnan(ipw).any():
                print("WARN IPW: NaNs in final ipw")
                return None
                
            return ipw.detach()
            
        except Exception as e:
            # Deliberate best effort: any failure is reported and signalled
            # with None rather than propagated to the caller.
            print(f"Error during IPW calculation: {e}")
            return None

    def training_step(self, batch, batch_idx):
        """Compute the (optionally IPW-weighted) diffusion loss for one batch.

        Returns:
            The total loss tensor (diffusion loss plus weighted feature loss),
            or ``None`` when a NaN is detected in the prediction, the
            per-sample loss or the total loss.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        later_features = batch["later_features"]
        cov_seq_hist = batch["cov_seq"]
        trt_seq_hist = batch["trt_seq"]
        lengths_hist = batch["lengths"]
        delta_t_hist = batch["delta_t"]
        side_hist = batch["side"]

        # Put the optimised sub-modules into training mode.
        for module in (self.unet, self.img_conditioning, self.context_encoder):
            module.train()
        if not isinstance(self.rich_context_projector, torch.nn.Identity):
            self.rich_context_projector.train()

        pred_noise, noise, noisy_latents, timesteps, _ = self(
            input_images, target_images, prompts,
            cov_seq_hist, trt_seq_hist, lengths_hist, delta_t_hist, side_hist,
        )

        if torch.isnan(pred_noise).any():
            print("FATAL TRAIN: NaN in pred_noise")
            return None

        # Mean squared error per sample (all non-batch dims averaged).
        elementwise_sq_err = F.mse_loss(pred_noise, noise, reduction='none')
        per_sample_diff_loss = elementwise_sq_err.view(pred_noise.size(0), -1).mean(dim=1)
        if torch.isnan(per_sample_diff_loss).any():
            print("FATAL TRAIN: NaN in per_sample_diff_loss")
            return None

        unweighted_diff_loss = per_sample_diff_loss.mean()

        # IPW weighting; falls back to the unweighted loss when unavailable.
        ipw = None
        weighted_diff_loss = unweighted_diff_loss
        if not self.disable_ipw_loss:
            ipw = self._calculate_ipw(batch)
            if ipw is not None and torch.isnan(ipw).any():
                ipw = None
                print("WARN TRAIN: NaN in IPW, using unweighted.")
            elif ipw is not None:
                weights = ipw.to(per_sample_diff_loss.device)
                weighted_diff_loss = (weights * per_sample_diff_loss).mean()

        # Optional feature loss, only when a weight and predictors exist.
        use_feature_loss = self.feature_loss_weight > 0 and len(self.feature_predictors) > 0
        feature_loss = torch.tensor(0.0, device=self.device)
        if use_feature_loss:
            from diffusion_utils import calculate_feature_loss
            feature_loss = calculate_feature_loss(
                self, pred_noise, timesteps, noisy_latents, later_features,
                self.feature_predictors, self.vae, self.latent_scale,
                self.trainer.precision, self.device
            )

        # Total loss
        ipw_applied = ipw is not None and not self.disable_ipw_loss
        if ipw_applied:
            diffusion_term = weighted_diff_loss
        else:
            diffusion_term = unweighted_diff_loss
        total_loss = diffusion_term + self.feature_loss_weight * feature_loss

        if torch.isnan(total_loss):
            print("FATAL TRAIN: NaN in total_loss")
            return None

        # Logging
        bs = input_images.size(0)
        epoch_kwargs = dict(on_step=False, on_epoch=True, sync_dist=True, batch_size=bs)
        self.log("train_loss", total_loss, prog_bar=True, **epoch_kwargs)
        self.log("train_diff_loss_uw", unweighted_diff_loss, prog_bar=False, **epoch_kwargs)

        if ipw_applied:
            self.log("train_diff_loss_w", weighted_diff_loss, prog_bar=False, **epoch_kwargs)
            self.log("mean_ipw", ipw.mean(), prog_bar=False, sync_dist=True, batch_size=bs)

        if use_feature_loss:
            self.log("train_feat_loss", feature_loss, prog_bar=False, sync_dist=True, batch_size=bs)

        return total_loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses, generation metrics and IPW-weighted variants.

        Noise-prediction losses are logged first; the predicted clean latents
        are then decoded to images for the generation loss, SSIM and (when
        feature predictors exist) ITE metrics. If decoding is impossible the
        target images stand in for the generated ones; that fallback is
        tracked explicitly so the ITE metrics are skipped for such batches.

        Returns:
            None. Returns early (logging nothing) when the predicted noise
            contains NaN.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"] 
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        cov_seq_hist = batch["cov_seq"]
        trt_seq_hist = batch["trt_seq"]
        lengths_hist = batch["lengths"]
        delta_t_hist = batch["delta_t"]
        side_hist = batch["side"]

        # Calculate IPW (None when unavailable)
        ipw = self._calculate_ipw(batch)

        pred_noise, noise, noisy_latents, timesteps, _ = self(
            input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist, 
            delta_t_hist, side_hist
        )
        
        if torch.isnan(pred_noise).any():
            print("WARN VAL: NaN in pred_noise")
            return

        # Log losses (values pass through the safe_tensor_log helper first)
        val_loss = F.mse_loss(pred_noise, noise)
        val_loss = safe_tensor_log(val_loss, 1000.0, self.device)
        self.log("val_loss", val_loss, prog_bar=True, on_step=False, on_epoch=True, 
                sync_dist=True, batch_size=input_images.size(0))

        val_mae_loss = F.l1_loss(pred_noise, noise)
        val_mae_loss = safe_tensor_log(val_mae_loss, 1000.0, self.device)
        self.log("val_mae_loss", val_mae_loss, prog_bar=True, on_step=False, on_epoch=True, 
                sync_dist=True, batch_size=input_images.size(0))

        # IPW-weighted validation loss
        if ipw is not None:
            per_sample_loss = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
            if not torch.isnan(per_sample_loss).any():
                ipw_val_loss = (ipw.to(per_sample_loss.device) * per_sample_loss).mean()
                if not torch.isnan(ipw_val_loss):
                    self.log("ipw_val_loss", ipw_val_loss, prog_bar=True, on_step=False, 
                            on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Generate images and calculate metrics. used_fallback records whether
        # the targets had to stand in for the generated images, so no tensor
        # comparison is needed later to find that out.
        used_fallback = False
        try:
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            else:
                generated_images = target_images
                used_fallback = True
        except Exception as e:
            # Best effort: a failed decode must not abort validation.
            print(f"ERROR val img gen: {e}")
            generated_images = target_images
            used_fallback = True

        # Generation loss
        val_gen_loss = F.mse_loss(generated_images, target_images)
        if not torch.isnan(val_gen_loss):
            self.log("val_gen_loss", val_gen_loss, prog_bar=False, on_step=False, 
                    on_epoch=True, sync_dist=True, batch_size=input_images.size(0))
        
        # Log SSIM metric
        val_ssim = calculate_ssim(generated_images, target_images, self.device)
        if not torch.isnan(val_ssim):
            self.log("val_ssim", val_ssim, prog_bar=False, on_step=False, on_epoch=True, 
                    sync_dist=True, batch_size=input_images.size(0))
            
        if ipw is not None:
            gen_loss_per_sample = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
            if not torch.isnan(gen_loss_per_sample).any():
                ipw_val_gen_loss = (ipw.to(gen_loss_per_sample.device) * gen_loss_per_sample).mean()
                if not torch.isnan(ipw_val_gen_loss):
                    self.log("ipw_gen_val_loss", ipw_val_gen_loss, prog_bar=False, on_step=False, 
                            on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # ITE Metrics: skipped when the "generated" images are the targets.
        if len(self.feature_predictors) > 0 and not used_fallback:
            from diffusion_utils import calculate_ite_metrics
            ite_metrics = calculate_ite_metrics(
                generated_images, target_images, earlier_features, later_features,
                self.feature_predictors, self.trainer.precision, self.device, ipw=ipw
            )
            
            for metric_name, metric_value in ite_metrics.items():
                # as_tensor accepts both Python numbers and existing tensors.
                if not torch.isnan(torch.as_tensor(metric_value)):
                    self.log(f"val_{metric_name}", metric_value, prog_bar=False, on_step=False, 
                            on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Log qualitative images (rank 0, first batch only, up to 16 samples)
        if self.global_rank == 0 and batch_idx == 0:
            from diffusion_utils import log_qualitative_images
            log_qualitative_images(
                self.pipe, input_images[0:16], target_images[0:16], prompts[0:16],
                self.logger, self.current_epoch, device=self.device
            )

    def test_step(self, batch, batch_idx):
        """Compute noise-prediction and generation metrics for one test batch.

        Metrics are gathered in a dict (plain, IPW-weighted and, for rows
        selected by ``treat_mask``, "treated" variants), appended to
        ``self._test_outputs`` and returned with non-finite entries removed.

        Returns:
            dict mapping metric name to value, restricted to finite values.

        Raises:
            ValueError: if the predicted noise contains NaN.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"]
        cov_seq_hist = batch["cov_seq"]
        trt_seq_hist = batch["trt_seq"]
        lengths_hist = batch["lengths"]
        delta_t_hist = batch["delta_t"]
        side_hist = batch["side"]

        # Treatment mask
        # NOTE(review): rows whose last interval label equals 0 feed the
        # "*_treated" metrics — confirm this matches the label encoding.
        treat_mask = (interval_labels[:, -1] == 0)
        has_treated = bool(treat_mask.any())
        
        # Calculate IPW. It is moved next to treat_mask once, because it is
        # indexed with that mask below.
        ipw = self._calculate_ipw(batch)
        if ipw is not None:
            ipw = ipw.to(treat_mask.device)

        # Normalised weights of the treated rows (sum to 1), shared by the
        # noise and generation "treated" metrics.
        w_t = None
        if ipw is not None and has_treated:
            w_t = ipw[treat_mask] / ipw[treat_mask].sum().clamp(min=EPSILON_IPW)
        
        metrics = {}
        
        with torch.no_grad():
            pred_noise, noise, noisy_latents, timesteps, _ = self(
                input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist, 
                delta_t_hist, side_hist
            )
            
        if torch.isnan(pred_noise).any():
            raise ValueError("NaN in pred_noise")
            
        # Calculate losses
        loss_mse = F.mse_loss(pred_noise, noise)
        loss_mae = F.l1_loss(pred_noise, noise)

        # Calculate R-squared for noise prediction (per sample, per-sample mean)
        noise_flat = noise.view(noise.size(0), -1)
        pred_noise_flat = pred_noise.view(pred_noise.size(0), -1)
        ss_res = torch.sum((noise_flat - pred_noise_flat) ** 2, dim=1)
        ss_tot = torch.sum((noise_flat - torch.mean(noise_flat, dim=1, keepdim=True)) ** 2, dim=1)
        r2_noise = 1 - (ss_res / (ss_tot + 1e-8))  # Add small epsilon to avoid division by zero
        r2_noise = torch.clamp(r2_noise, min=-10, max=1)  # Clamp to reasonable range

        metrics['test_loss_mse'] = loss_mse.item()
        metrics['test_loss_mae'] = loss_mae.item()
        metrics['test_r2_noise'] = r2_noise.mean().item()
        
        # Per-sample losses
        per_sample_mse = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        per_sample_mae = F.l1_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        
        if has_treated:
            metrics['test_loss_mse_treated'] = per_sample_mse[treat_mask].mean().item()
            metrics['test_loss_mae_treated'] = per_sample_mae[treat_mask].mean().item()
            metrics['test_r2_noise_treated'] = r2_noise[treat_mask].mean().item()

            if w_t is not None:
                ipw_mse_treated = (w_t * per_sample_mse[treat_mask]).sum()
                metrics['test_ipw_loss_treated'] = ipw_mse_treated.item()
                # IPW R-squared for treated
                ipw_r2_treated = (w_t * r2_noise[treat_mask]).sum()
                metrics['test_ipw_r2_noise_treated'] = ipw_r2_treated.item()
                
        # IPW metrics
        if ipw is not None:
            if not torch.isnan(per_sample_mse).any():
                ipw_test_loss = (ipw.to(per_sample_mse.device) * per_sample_mse).mean()
                metrics['test_ipw_loss'] = ipw_test_loss.item()
            if not torch.isnan(r2_noise).any():
                ipw_r2_noise = (ipw.to(r2_noise.device) * r2_noise).mean()
                metrics['test_ipw_r2_noise'] = ipw_r2_noise.item()
            metrics['test_mean_ipw'] = ipw.mean().item()

        # Generate images. used_fallback records whether the targets had to
        # stand in for the generated images.
        used_fallback = False
        with torch.no_grad():
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                self.vae.eval()
                with autocast('cuda', enabled=self.trainer.precision.startswith("16") or 
                             self.trainer.precision.startswith("bf16")):
                    generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            else:
                generated_images = target_images
                used_fallback = True
                
        if torch.isnan(generated_images).any():
            generated_images = target_images
            used_fallback = True

        # Generation metrics
        val_gen_loss = F.mse_loss(generated_images, target_images)
        metrics['test_gen_loss'] = val_gen_loss.item()

        # Calculate R-squared for generated images
        target_flat = target_images.view(target_images.size(0), -1)
        generated_flat = generated_images.view(generated_images.size(0), -1)
        ss_res_gen = torch.sum((target_flat - generated_flat) ** 2, dim=1)
        # The total sum of squares uses the mean over the whole batch of
        # target pixels, not a per-image mean.
        overall_mean = torch.mean(target_flat)
        ss_tot_gen = torch.sum((target_flat - overall_mean) ** 2, dim=1)
        r2_gen = 1 - (ss_res_gen / (ss_tot_gen + 1e-8))  # Add small epsilon to avoid division by zero
        r2_gen = torch.clamp(r2_gen, min=-10, max=1)  # Clamp to reasonable range
        metrics['test_r2_gen'] = r2_gen.mean().item()

        # Calculate SSIM
        test_ssim = calculate_ssim(generated_images, target_images, self.device)
        metrics['test_ssim'] = test_ssim.item()

        per_sample_gen = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if has_treated:
            metrics['test_gen_loss_treated'] = per_sample_gen[treat_mask].mean().item()
            metrics['test_r2_gen_treated'] = r2_gen[treat_mask].mean().item()
            if w_t is not None:
                ipw_gen_treated = (w_t * per_sample_gen[treat_mask]).sum()
                metrics['test_ipw_gen_loss_treated'] = ipw_gen_treated.item()
                # IPW R-squared for treated generation
                ipw_r2_gen_treated = (w_t * r2_gen[treat_mask]).sum()
                metrics['test_ipw_r2_gen_treated'] = ipw_r2_gen_treated.item()
                
        if ipw is not None:
            if not torch.isnan(per_sample_gen).any():
                ipw_test_gen_loss = (ipw.to(per_sample_gen.device) * per_sample_gen).mean()
                metrics['test_ipw_gen_loss'] = ipw_test_gen_loss.item()
            if not torch.isnan(r2_gen).any():
                ipw_r2_gen = (ipw.to(r2_gen.device) * r2_gen).mean()
                metrics['test_ipw_r2_gen'] = ipw_r2_gen.item()

        # ITE Metrics: skipped when the "generated" images are the targets.
        if len(self.feature_predictors) > 0 and not used_fallback:
            from diffusion_utils import calculate_ite_metrics
            ite_metrics = calculate_ite_metrics(
                generated_images, target_images, earlier_features, later_features,
                self.feature_predictors, self.trainer.precision, self.device, 
                treat_mask, ipw
            )
            metrics.update({f'test_{k}': v for k, v in ite_metrics.items()})

        metrics['batch_size'] = float(input_images.size(0))
        # The unfiltered metrics are stored on the module; only finite values
        # are returned.
        self._test_outputs.append(metrics.copy())
        
        return {k: v for k, v in metrics.items() if torch.isfinite(torch.tensor(v))}

    def configure_optimizers(self):
        """Return an AdamW optimizer over the generator-side parameters.

        Covers the U-Net, the image conditioning module, the context encoder
        and, unless it is an Identity, the context projector. The learning
        rate is ``self.lr``.
        """
        optimised_params = []
        for module in (self.unet, self.img_conditioning, self.context_encoder):
            optimised_params.extend(module.parameters())

        projector = self.rich_context_projector
        if not isinstance(projector, torch.nn.Identity):
            optimised_params.extend(projector.parameters())

        return AdamW(optimised_params, lr=self.lr)


# Main script
if __name__ == "__main__":
    args = parse_args()
    pl.seed_everything(42, workers=True)
    
    print(f"Run Name: {args.run_name}")
    
    # Image transforms
    image_transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # Collate function
    collate_fn = create_unified_collate_fn(
        include_ipw=True,
        propensity_model_uses_images=args.propensity_model_uses_images
    )
    
    if args.test_only:
        # Load propensity model
        print("Loading Temporal Propensity Model for testing...")
        propensity_model = TemporalPropensityModel(
            cov_dim=FEATURE_DIM, hidden_dim=args.propensity_hidden_dim,
            num_layers=args.propensity_num_layers, dropout=args.propensity_dropout,
            model_type=args.propensity_model_type, include_images=args.propensity_model_uses_images,
            delta_t_feat_dim=args.propensity_delta_t_feat_dim
        )
        state_dict = torch.load(args.propensity_model_path, map_location="cpu")
        if list(state_dict.keys())[0].startswith('module.'):
            state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
        propensity_model.load_state_dict(state_dict)
        propensity_model.eval()
        
        # Load diffusion components
        vae = AutoencoderKL.from_pretrained(args.diffusion_model_pretrain, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_pretrain, subfolder="unet")
        text_encoder = CLIPTextModel.from_pretrained(args.diffusion_model_pretrain, subfolder="text_encoder")
        tokenizer = CLIPTokenizer.from_pretrained(args.diffusion_model_pretrain, subfolder="tokenizer")
        scheduler = DDIMScheduler.from_pretrained(args.diffusion_model_pretrain, subfolder="scheduler")
        
        from diff_base import ImageConditioning
        img_conditioning = ImageConditioning(vae.config.latent_channels, args.disc_current_img_cond_dim)
        
        # Load model from checkpoint
        diffusion_model = IPWDiffusionLitModule.load_from_checkpoint(
            args.ckpt_path, map_location="cpu", args=args,
            vae=vae, unet=unet, scheduler=scheduler,
            text_encoder=text_encoder, tokenizer=tokenizer,
            img_conditioning=img_conditioning,
            propensity_model=propensity_model,
            delta_t_mean=delta_t_mean, delta_t_std=delta_t_std,
            strict=False
        )
        diffusion_model.hparams.results_file = args.results_file
        diffusion_model.hparams.ckpt_path = args.ckpt_path
        
        # Test dataset and loader
        test_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "test.csv"),
            image_transform=image_transforms,
            include_images=args.propensity_model_uses_images
        )
        test_loader = DataLoader(
            test_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=8, pin_memory=True, collate_fn=collate_fn, drop_last=False
        )
        
        # Run test
        trainer = pl.Trainer(
            accelerator="gpu", devices=args.devices, precision=args.precision,
            logger=False, callbacks=[]
        )
        trainer.test(model=diffusion_model, dataloaders=test_loader, verbose=True)
        
    else:
        # Training mode
        train_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "train.csv"),
            image_transform=image_transforms,
            include_images=args.propensity_model_uses_images
        )
        val_dataset = TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, "val.csv"),
            image_transform=image_transforms,
            include_images=args.propensity_model_uses_images
        )
        
        # Data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True, 
            num_workers=8, pin_memory=True, collate_fn=collate_fn, drop_last=True
        )
        val_loader = DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=8, pin_memory=True, collate_fn=collate_fn
        )
        
        # Load diffusion components
        vae = AutoencoderKL.from_pretrained(args.diffusion_model_pretrain, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_pretrain, subfolder="unet")
        text_encoder = CLIPTextModel.from_pretrained(args.diffusion_model_pretrain, subfolder="text_encoder")
        tokenizer = CLIPTokenizer.from_pretrained(args.diffusion_model_pretrain, subfolder="tokenizer")
        scheduler = DDIMScheduler.from_pretrained(args.diffusion_model_pretrain, subfolder="scheduler")
        
        from diff_base import ImageConditioning
        img_conditioning = ImageConditioning(vae.config.latent_channels, args.disc_current_img_cond_dim)
        
        # Load propensity model
        propensity_model = None
        if os.path.exists(args.propensity_model_path):
            try:
                propensity_model = TemporalPropensityModel(
                    cov_dim=FEATURE_DIM, hidden_dim=args.propensity_hidden_dim,
                    num_layers=args.propensity_num_layers, dropout=args.propensity_dropout,
                    model_type=args.propensity_model_type, include_images=args.propensity_model_uses_images,
                    delta_t_feat_dim=args.propensity_delta_t_feat_dim
                ).to("cuda")
                state_dict = torch.load(args.propensity_model_path, map_location="cuda")
                if list(state_dict.keys())[0].startswith('module.'):
                    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
                propensity_model.load_state_dict(state_dict)
                propensity_model.eval()
                for p in propensity_model.parameters():
                    p.requires_grad_(False)
                print(f"Loaded propensity model from {args.propensity_model_path}")
            except Exception as e:
                print(f"Error loading propensity model: {e}")
                propensity_model = None
        else:
            print(f"Propensity model not found: {args.propensity_model_path}")
            
        if propensity_model is None:
            print("ERROR: Propensity model failed to load but IPW is enabled. Exiting.")
            exit(1)
        
        # Create model
        diffusion_model = IPWDiffusionLitModule(
            args=args, vae=vae, unet=unet, scheduler=scheduler,
            text_encoder=text_encoder, tokenizer=tokenizer,
            img_conditioning=img_conditioning, propensity_model=propensity_model,
            delta_t_mean=delta_t_mean, delta_t_std=delta_t_std
        )
        
        # Callbacks and logger
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss", dirpath=args.checkpoint_dir,
            filename="best-{epoch}-val_loss={val_loss:.4f}",
            auto_insert_metric_name=False, save_top_k=2, mode="min", save_last=True
        )
        tensorboard_logger = TensorBoardLogger(save_dir=args.log_dir_base, name=args.run_name)
        
        # Trainer
        trainer = pl.Trainer(
            max_epochs=args.max_epochs, accelerator="gpu", devices=args.devices,
            precision=args.precision, logger=tensorboard_logger, log_every_n_steps=25,
            strategy="ddp_find_unused_parameters_true", callbacks=[checkpoint_callback],
            accumulate_grad_batches=args.accumulate_grad_batches
        )
        
        print(f"Starting training run: {args.run_name}")
        trainer.fit(diffusion_model, train_loader, val_loader)
        print("--- Training Finished ---")
        print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")