"""
Refactored Adversarial + IPW Diffusion Model Training
Simplified version using base classes and utility functions.
"""
import os
import argparse
from functools import partial
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import autocast
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

from pretrain_propensity_model_temporal import TemporalPropensityModel
from data.pairs_dataset.dataset import TemporalKneeFeatureConditioningDataset, NUM_TREATMENTS, FEATURE_DIM, PERCEPTUAL_FEATURES
from diff_base import BaseDiffusionLitModule, BaseAdversarialMixin, BaseIPWMixin, ContextEncoderRNN, TreatmentDiscriminator
from data_utils import create_unified_collate_fn
from diffusion_utils import safe_tensor_log, calculate_discriminator_metrics, EPSILON_DELTA_T, EPSILON_IPW, calculate_ssim

# Let float32 matrix multiplications use faster reduced-internal-precision
# kernels (e.g. TF32) where the hardware supports them; trades a little
# matmul precision for throughput.
torch.set_float32_matmul_precision('high')


def parse_args():
    """Build the command-line interface, parse ``sys.argv`` and derive run paths.

    Returns:
        argparse.Namespace with every option below plus:

        * ``run_name`` -- always set. With ``--test_only`` it is
          ``TEST_<checkpoint file name>``; otherwise
          ``<run_suffix>[_NoIPWLoss]_b<batch_size>x<number of devices>``.
        * ``checkpoint_dir`` / ``log_dir`` -- only set when training (not
          ``--test_only``); ``checkpoint_dir`` is also created on disk.

    Exits through ``parser.error`` (SystemExit with status 2) when
    ``--test_only`` is given without an existing ``--ckpt_path``.
    """
    parser = argparse.ArgumentParser(description="Train Adversarial + IPW Diffusion Model")

    # (flag, add_argument keyword arguments), registered in exactly this order so
    # the generated --help output keeps its layout.
    option_specs = [
        # Mode
        ("--test_only", dict(action="store_true", help="Run in testing mode only")),
        ("--ckpt_path", dict(type=str, default=None, help="Path to checkpoint for testing")),
        ("--results_file", dict(type=str, default="test_results.txt", help="File to save test metrics")),
        # Paths
        ("--propensity_model_path", dict(
            type=str,
            default="./checkpoints/Propensity_Model_Temporal/propensity_model_temporal_best.pth",
        )),
        ("--pairs_dir", dict(type=str, default="data/pairs_dataset", help="Dir with train.csv/val.csv")),
        ("--base_checkpoint_dir", dict(type=str, default="/local2/acc/OAI_checkpoints")),
        ("--log_dir_base", dict(type=str, default="/local2/acc/OAI_checkpoints/tb_logs")),
        # Model params
        ("--diffusion_model_pretrain", dict(type=str, default="runwayml/stable-diffusion-v1-5")),
        # Training params
        ("--lr_generator", dict(type=float, default=1e-5, help="Generator learning rate")),
        ("--lr_discriminator", dict(type=float, default=1e-4, help="Discriminator learning rate")),
        ("--batch_size", dict(type=int, default=32)),
        ("--max_epochs", dict(type=int, default=50)),
        ("--feature_loss_weight", dict(type=float, default=0.1, help="Weight for perceptual feature loss")),
        ("--adversarial_weight", dict(type=float, default=0.01, help="Weight for adversarial loss")),
        ("--gradient_clip_val", dict(type=float, default=1.0, help="Gradient clipping value")),
        ("--precision", dict(type=str, default="16-mixed", choices=["16-mixed", "bf16-mixed", "32-true"])),
        ("--devices", dict(type=int, nargs='+', default=[0, 1], help="GPU devices to use")),
        # IPW params
        ("--disable_ipw_loss", dict(action='store_true', help="Train without IPW loss weighting")),
        ("--ipw_max_clamp", dict(type=float, default=150.0, help="Max value to clamp IPW weights")),
        # Discriminator params
        ("--use_disc_pos_weight", dict(action='store_true', help="Use positive weight for discriminator")),
        # Propensity model params
        ("--propensity_model_type", dict(type=str, default='transformer')),
        ("--propensity_hidden_dim", dict(type=int, default=128)),
        ("--propensity_num_layers", dict(type=int, default=2)),
        ("--propensity_dropout", dict(type=float, default=0.2)),
        ("--propensity_delta_t_feat_dim", dict(type=int, default=32)),
        ("--propensity_model_uses_images", dict(action='store_true')),
        # Context encoder / discriminator
        ("--disc_model_type", dict(type=str, choices=["rnn", "transformer"], default="rnn")),
        ("--disc_hidden_dim", dict(type=int, default=128)),
        ("--disc_num_layers", dict(type=int, default=2)),
        ("--disc_dropout", dict(type=float, default=0.2)),
        ("--disc_delta_t_feat_dim", dict(type=int, default=8)),
        ("--disc_side_feat_dim", dict(type=int, default=4)),
        ("--disc_current_img_cond_dim", dict(type=int, default=768)),
        # Other
        ("--run_suffix", dict(type=str, default="ai", help="Suffix for run name")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Evaluation-only runs: require an existing checkpoint and name the run after it.
    if args.test_only:
        checkpoint_ok = bool(args.ckpt_path) and os.path.exists(args.ckpt_path)
        if not checkpoint_ok:
            parser.error("--ckpt_path is required and must exist for --test_only mode.")
        args.run_name = "TEST_" + os.path.basename(args.ckpt_path)
        return args

    # Training runs: encode the IPW setting and effective batch layout in the name.
    ipw_tag = "_NoIPWLoss" if args.disable_ipw_loss else ""
    args.run_name = f"{args.run_suffix}{ipw_tag}_b{args.batch_size}x{len(args.devices)}"
    args.checkpoint_dir = os.path.join(args.base_checkpoint_dir, args.run_name)
    args.log_dir = os.path.join(args.log_dir_base, args.run_name)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    return args


class AdversarialIPWDiffusionLitModule(BaseAdversarialMixin, BaseIPWMixin, BaseDiffusionLitModule):
    """Combined Adversarial + IPW diffusion model built on the shared base classes.

    The generator side consists of ``unet``, ``img_conditioning``,
    ``context_encoder`` and, when it is not ``nn.Identity``,
    ``rich_context_projector``. It is trained on three terms:

    * the denoising MSE, optionally weighted per sample by inverse propensity
      weights (IPW) from a propensity model;
    * an optional perceptual feature loss (``feature_loss_weight``);
    * an adversarial loss that pushes ``treatment_discriminator`` towards
      predicting 0 ("no treatment") from the encoded history context.

    The discriminator has its own optimizer and is trained to predict the
    actual ``interval_labels`` from the detached context. Both optimizers are
    stepped manually inside ``training_step``.

    This class reads attributes such as ``context_encoder``,
    ``treatment_discriminator``, ``criterion_discriminator``,
    ``propensity_model``, ``feature_predictors``, ``lr_generator`` and
    ``lr_discriminator``. They are set up by the base class / mixin
    initializers, which are not visible in this file.
    """

    def __init__(self, args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 propensity_model=None, discriminator_pos_weight=None, delta_t_mean=0.0, delta_t_std=1.0):
        """Build the module.

        Args:
            args: parsed CLI namespace; also saved as the module hyperparameters.
            vae, unet, scheduler, text_encoder, tokenizer: Stable Diffusion components.
            img_conditioning: module applied to the input latents to produce the
                current-image conditioning fed to the context encoder.
            propensity_model: optional propensity model used for IPW weights.
                When it is ``None``, ``_calculate_ipw`` returns ``None``
                (assuming ``_init_ipw_components`` stores it as
                ``self.propensity_model``).
            discriminator_pos_weight: forwarded to the adversarial component setup.
            delta_t_mean, delta_t_std: forwarded to the base-class initializer.
        """
        super().__init__(args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                         delta_t_mean, delta_t_std)

        # Set up the discriminator/context side and the IPW side.
        self._init_adversarial_components(args, discriminator_pos_weight)
        self._init_ipw_components(args, propensity_model)

        # Persist every CLI argument (learning rates included) as hyperparameters.
        self.save_hyperparameters(args)

    def _calculate_ipw(self, batch):
        """Compute per-sample inverse propensity weights for ``batch``.

        The propensity model yields logits that are turned into P(label == 1)
        via a sigmoid. The joint probability of the observed label sequence is
        the product over dim 1, which assumes both the logits and
        ``interval_labels`` have shape (batch, intervals) — TODO confirm.
        The weight is ``1 / (joint_prob + EPSILON_IPW)``, clamped to
        ``[0.01, self.ipw_max_clamp]``.

        Returns:
            A detached tensor with one weight per sample, on the propensity
            model's device. Returns ``None`` when there is no propensity model,
            an input is missing, a NaN appears, or any exception is raised.
            This is deliberately best-effort: failures are printed, never raised.
        """
        if self.propensity_model is None:
            return None

        # interval_labels is consumed below too; check it here so a missing key
        # yields the explicit warning instead of a generic exception message.
        required_keys = ["cov_seq", "trt_seq", "lengths", "delta_t", "side", "interval_labels"]
        if not all(k in batch and batch[k] is not None for k in required_keys):
            print("WARN IPW: Missing required inputs for IPW calculation.")
            return None

        try:
            cov_seq = batch["cov_seq"]
            trt_seq = batch["trt_seq"]
            lengths = batch["lengths"]
            delta_t = batch["delta_t"]
            side = batch["side"]
            interval_labels = batch["interval_labels"]

            # Optional image history, only when the propensity model consumes images.
            image_seq = None
            if self.hparams.propensity_model_uses_images:
                if "image_seq" in batch and batch["image_seq"] is not None:
                    image_seq = batch["image_seq"]

            # Match the propensity model's device and parameter dtype.
            prop_device = next(self.propensity_model.parameters()).device
            model_dtype = next(self.propensity_model.parameters()).dtype

            cov_seq_p = cov_seq.to(prop_device, dtype=model_dtype)
            trt_seq_p = trt_seq.to(prop_device, dtype=model_dtype)
            delta_t_p = delta_t.to(prop_device, dtype=model_dtype)
            side_p = side.to(prop_device, dtype=model_dtype)

            if image_seq is not None:
                image_seq = image_seq.to(prop_device, dtype=model_dtype)

            delta_t_norm = (delta_t_p - self.propensity_delta_t_mean) / (self.propensity_delta_t_std + EPSILON_DELTA_T)

            # Run the propensity model without gradients and outside mixed precision.
            with torch.no_grad(), autocast('cuda', enabled=False):
                propensity_logits = self.propensity_model(
                    cov_seq_p, trt_seq_p, lengths.cpu(), delta_t_norm, image_seq, side_p
                )

            if torch.isnan(propensity_logits).any():
                print("WARN IPW: NaNs in logits")
                return None

            # Clamp away from 0/1 so the joint probability and its inverse stay finite.
            propensity_probs = torch.sigmoid(propensity_logits.float())
            propensity_probs = torch.clamp(propensity_probs, 1e-6, 1.0 - 1e-6)

            if torch.isnan(propensity_probs).any():
                print("WARN IPW: NaNs in probs")
                return None

            # Likelihood of the observed labels: p where label == 1, (1 - p) otherwise.
            label = interval_labels.to(prop_device).float()
            joint_prob = torch.prod(torch.where(label == 1, propensity_probs, 1.0 - propensity_probs), dim=1)

            if torch.isnan(joint_prob).any():
                print("WARN IPW: NaNs in joint_prob")
                return None

            ipw = 1.0 / (joint_prob + EPSILON_IPW)
            ipw = torch.clamp(ipw, min=0.01, max=self.ipw_max_clamp)

            if torch.isnan(ipw).any():
                print("WARN IPW: NaNs in final ipw")
                return None

            return ipw.detach()

        except Exception as e:
            print(f"Error during IPW calculation: {e}")
            return None

    def forward(self, input_images, target_images, prompts,
                cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side):
        """Run one denoising pass conditioned on text and patient history.

        Both images are VAE-encoded and multiplied by ``self.latent_scale``.
        A random timestep is drawn per sample and noise is added to the
        *target* latents. The UNet is conditioned on the CLIP text embeddings
        plus the projected history context, which is unsqueezed on dim 1 and
        broadcast over the token dimension. Autocast is enabled when the
        trainer precision string starts with "16" or "bf16".

        Returns:
            Tuple ``(pred_noise, noise, noisy_latents, timesteps,
            target_latents, rich_context)``. ``target_latents`` is cast back
            to the dtype of ``input_images``. ``rich_context`` is the raw
            (unprojected) context-encoder output used by the discriminator.
        """
        input_dtype = input_images.dtype

        with autocast('cuda', enabled=self.trainer.precision.startswith(("16", "bf16"))):
            input_latents = self.vae.encode(input_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            target_latents = self.vae.encode(target_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps,
                                      (input_latents.size(0),), device=self.device).long()
            noise = torch.randn_like(target_latents)
            noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)

            text_inputs = self.tokenizer(prompts, padding="max_length",
                                         max_length=self.tokenizer.model_max_length,
                                         truncation=True, return_tensors="pt").input_ids.to(self.device)
            text_embeddings = self.text_encoder(text_inputs)[0]

            # Image conditioning from the current (input) image.
            img_cond_current = self.img_conditioning(input_latents)

            # History context; history tensors are cast to the conditioning dtype.
            rich_context = self.context_encoder(
                img_cond_current,
                cov_seq_hist.to(self.device).to(img_cond_current.dtype),
                trt_seq_hist.to(self.device).to(img_cond_current.dtype),
                lengths_hist.to(self.device),
                delta_t.to(self.device).to(img_cond_current.dtype),
                side.to(self.device).to(img_cond_current.dtype),
                delta_t_mean=self.propensity_delta_t_mean,
                delta_t_std=self.propensity_delta_t_std
            )

            projected_rich_context = self.rich_context_projector(rich_context)

            # Add the projected context to every text token embedding.
            unet_conditioning = text_embeddings + projected_rich_context.unsqueeze(1)
            pred_noise = self.unet(noisy_latents.to(self.unet.dtype), timesteps,
                                   encoder_hidden_states=unet_conditioning.to(self.unet.dtype)).sample

        return pred_noise, noise, noisy_latents, timesteps, target_latents.to(input_dtype), rich_context

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step: a generator update, then a discriminator update.

        Always returns ``None``. If the UNet prediction contains NaNs, the step
        is abandoned before any optimizer update.
        """
        g_opt, d_opt = self.optimizers()

        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        # Put every trainable component back in training mode.
        self.unet.train()
        self.img_conditioning.train()
        self.context_encoder.train()
        self.treatment_discriminator.train()
        if not isinstance(self.rich_context_projector, nn.Identity):
            self.rich_context_projector.train()

        pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context = self(
            input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist,
            delta_t, side
        )

        if torch.isnan(pred_noise).any():
            print("FATAL TRAIN: NaN in pred_noise")
            return None

        # Per-sample diffusion loss, so it can be IPW-weighted.
        per_sample_diff_loss = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        unweighted_diff_loss = per_sample_diff_loss.mean()

        # IPW weighting. ipw is only non-None when IPW is enabled and valid, and
        # weighted_diff_loss falls back to the unweighted loss otherwise.
        weighted_diff_loss = unweighted_diff_loss
        ipw = None
        if not self.disable_ipw_loss:
            ipw = self._calculate_ipw(batch)
            if ipw is not None:
                if torch.isnan(ipw).any():
                    ipw = None
                    print("WARN TRAIN: NaN in IPW, using unweighted.")
                else:
                    weighted_diff_loss = (ipw.to(per_sample_diff_loss.device) * per_sample_diff_loss).mean()

        diffusion_loss = safe_tensor_log(weighted_diff_loss, 1000.0, self.device)

        # Optional perceptual feature loss.
        feature_loss = torch.tensor(0.0, device=self.device)
        if self.feature_loss_weight > 0 and len(self.feature_predictors) > 0:
            from diffusion_utils import calculate_feature_loss
            feature_loss = calculate_feature_loss(
                self, pred_noise, timesteps, noisy_latents, later_features,
                self.feature_predictors, self.vae, self.latent_scale,
                self.trainer.precision, self.device
            )

        # Generator adversarial loss: push the discriminator towards predicting
        # 0 ("no treatment") from the history context. rich_context must NOT be
        # detached here; otherwise this term sends no gradient into the context
        # encoder / image conditioning and the adversarial objective is a no-op.
        # The discriminator gradients this backward also produces are cleared by
        # d_opt.zero_grad() before the discriminator's own update below.
        treatment_logits = self.treatment_discriminator(rich_context)
        target_no_treatment = torch.zeros_like(interval_labels, dtype=torch.float)
        adversarial_loss_g = self.criterion_discriminator(treatment_logits, target_no_treatment)
        adversarial_loss_g = safe_tensor_log(adversarial_loss_g, 1000.0, self.device)

        # Total generator loss.
        generator_loss = (diffusion_loss +
                          self.feature_loss_weight * feature_loss +
                          self.adversarial_weight * adversarial_loss_g)
        generator_loss = safe_tensor_log(generator_loss, 1000.0, self.device)

        # Generator update.
        # NOTE(review): under "16-mixed" precision, manual_backward yields
        # loss-scaled gradients, and nothing visible here unscales them before
        # this clip, so the threshold may apply to scaled values. Confirm
        # against the Lightning version in use.
        g_opt.zero_grad()
        self.manual_backward(generator_loss)
        torch.nn.utils.clip_grad_norm_(
            [p for group in g_opt.param_groups for p in group['params']],
            max_norm=self.hparams.gradient_clip_val
        )
        g_opt.step()

        # Discriminator update: predict the actual treatment labels from the
        # detached context, so no gradient reaches the generator.
        treatment_logits_d = self.treatment_discriminator(rich_context.detach())
        discriminator_loss = self.criterion_discriminator(treatment_logits_d, interval_labels.float())
        discriminator_loss = safe_tensor_log(discriminator_loss, 1000.0, self.device)

        d_opt.zero_grad()
        self.manual_backward(discriminator_loss)
        torch.nn.utils.clip_grad_norm_(self.treatment_discriminator.parameters(),
                                       max_norm=self.hparams.gradient_clip_val)
        d_opt.step()

        # Logging
        bs = input_images.size(0)
        self.log("train_loss", generator_loss, prog_bar=True, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=bs)
        self.log("train_diffusion_loss", diffusion_loss, prog_bar=False, on_step=False,
                 on_epoch=True, sync_dist=True, batch_size=bs)
        self.log("train_diff_loss_uw", unweighted_diff_loss, prog_bar=False, on_step=False,
                 on_epoch=True, sync_dist=True, batch_size=bs)
        self.log("train_adversarial_loss_g", adversarial_loss_g, prog_bar=False, on_step=False,
                 on_epoch=True, sync_dist=True, batch_size=bs)
        self.log("train_discriminator_loss", discriminator_loss, prog_bar=True, on_step=False,
                 on_epoch=True, sync_dist=True, batch_size=bs)

        if ipw is not None:
            self.log("train_diff_loss_w", weighted_diff_loss, prog_bar=False, on_step=False,
                     on_epoch=True, sync_dist=True, batch_size=bs)
            self.log("mean_ipw", ipw.mean(), prog_bar=False, sync_dist=True, batch_size=bs)

        if self.feature_loss_weight > 0 and len(self.feature_predictors) > 0:
            self.log("train_feature_loss", feature_loss, prog_bar=False, on_step=False,
                     on_epoch=True, sync_dist=True, batch_size=bs)

    def validation_step(self, batch, batch_idx):
        """Log noise-prediction, IPW, discriminator, generation and ITE metrics for one batch.

        Generated images come from a single-step x0 estimate decoded by the VAE.
        When decoding fails or yields NaN/Inf latents, the targets are used
        instead, and the ITE metrics are then skipped. Qualitative images are
        logged for the first batch on rank 0.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        # IPW is computed for metrics even when the IPW training loss is disabled.
        ipw = self._calculate_ipw(batch)

        pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context_val = self(
            input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist,
            delta_t, side
        )

        if torch.isnan(pred_noise).any():
            print("WARN VAL: NaN pred_noise")
            return

        # Noise-prediction losses
        val_loss = F.mse_loss(pred_noise, noise)
        val_loss = safe_tensor_log(val_loss, 1000.0, self.device)
        self.log("val_loss", val_loss, prog_bar=True, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=input_images.size(0))

        val_mae_loss = F.l1_loss(pred_noise, noise)
        val_mae_loss = safe_tensor_log(val_mae_loss, 1000.0, self.device)
        self.log("val_mae_loss", val_mae_loss, prog_bar=True, on_step=False, on_epoch=True,
                 sync_dist=True, batch_size=input_images.size(0))

        # IPW-weighted validation loss
        if ipw is not None:
            per_sample_loss = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
            if not torch.isnan(per_sample_loss).any():
                ipw_val_loss = (ipw.to(per_sample_loss.device) * per_sample_loss).mean()
                if not torch.isnan(ipw_val_loss):
                    self.log("ipw_val_loss", ipw_val_loss, prog_bar=True, on_step=False,
                             on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Discriminator performance on the validation context
        with torch.no_grad():
            self.treatment_discriminator.eval()
            treatment_logits_val = self.treatment_discriminator(rich_context_val)
            val_disc_loss = self.criterion_discriminator(treatment_logits_val, interval_labels.float())

            disc_acc, disc_auc = calculate_discriminator_metrics(treatment_logits_val, interval_labels)

            self.log("val_disc_loss", val_disc_loss, prog_bar=False, on_step=False, on_epoch=True,
                     sync_dist=True, batch_size=input_images.size(0))
            self.log("val_disc_acc", disc_acc, prog_bar=True, on_step=False, on_epoch=True,
                     sync_dist=True, batch_size=input_images.size(0))
            self.log("val_disc_auc", disc_auc, prog_bar=False, on_step=False, on_epoch=True,
                     sync_dist=True, batch_size=input_images.size(0))

        # Single-step image reconstruction; fall back to the targets on failure.
        generated_images = None
        try:
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            else:
                generated_images = target_images
        except Exception:
            print("WARN VAL: Image Generation Failed")
            generated_images = target_images

        if generated_images is None:
            generated_images = target_images

        # Generation loss
        val_gen_loss = F.mse_loss(generated_images, target_images)
        if not torch.isnan(val_gen_loss):
            self.log("val_gen_loss", val_gen_loss, prog_bar=False, on_step=False,
                     on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # SSIM
        val_ssim = calculate_ssim(generated_images, target_images, self.device)
        if not torch.isnan(val_ssim):
            self.log("val_ssim", val_ssim, prog_bar=False, on_step=False, on_epoch=True,
                     sync_dist=True, batch_size=input_images.size(0))

        if ipw is not None:
            gen_loss_per_sample = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
            if not torch.isnan(gen_loss_per_sample).any():
                ipw_val_gen_loss = (ipw.to(gen_loss_per_sample.device) * gen_loss_per_sample).mean()
                if not torch.isnan(ipw_val_gen_loss):
                    self.log("ipw_gen_val_loss", ipw_val_gen_loss, prog_bar=False, on_step=False,
                             on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # ITE metrics (skipped when generation fell back to the targets)
        if (len(self.feature_predictors) > 0 and not torch.equal(generated_images, target_images)):
            from diffusion_utils import calculate_ite_metrics
            ite_metrics = calculate_ite_metrics(
                generated_images, target_images, earlier_features, later_features,
                self.feature_predictors, self.trainer.precision, self.device, ipw=ipw
            )

            for metric_name, metric_value in ite_metrics.items():
                if not np.isnan(metric_value):
                    self.log(f"val_{metric_name}", metric_value, prog_bar=False, on_step=False,
                             on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Qualitative images, first batch on rank 0 only
        if self.global_rank == 0 and batch_idx == 0:
            from diffusion_utils import log_qualitative_images
            log_qualitative_images(
                self.pipe, input_images[0:16], target_images[0:16], prompts[0:16],
                self.logger, self.current_epoch, device=self.device
            )

    def test_step(self, batch, batch_idx):
        """Compute test metrics for one batch and append them to ``self._test_outputs``.

        Metrics cover noise prediction (MSE/MAE/R²), treated-subset and
        IPW-weighted variants, discriminator loss/acc/AUC, single-step
        generation (MSE/R²/SSIM) and ITE metrics.

        Returns:
            The metrics dict restricted to finite values. The unfiltered dict,
            including ``batch_size``, is what gets appended to ``_test_outputs``.

        Raises:
            ValueError: if the UNet noise prediction contains NaNs.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        # Treated samples: last interval label == 1. Label 1 is the treated class
        # in this module: _calculate_ipw uses P(label == 1) as the propensity, and
        # the generator's adversarial target 0 means "no treatment". Selecting
        # `== 0` would report untreated samples under the *_treated metric names.
        treat_mask = (interval_labels[:, -1] == 1)

        ipw = self._calculate_ipw(batch)

        metrics = {}

        with torch.no_grad():
            pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context = self(
                input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist,
                delta_t, side
            )

        if torch.isnan(pred_noise).any():
            raise ValueError("NaN in pred_noise")

        # Noise-prediction losses
        loss_mse = F.mse_loss(pred_noise, noise)
        loss_mae = F.l1_loss(pred_noise, noise)

        # Per-sample R² of the noise prediction (each sample centred on its own mean).
        noise_flat = noise.view(noise.size(0), -1)
        pred_noise_flat = pred_noise.view(pred_noise.size(0), -1)
        ss_res = torch.sum((noise_flat - pred_noise_flat) ** 2, dim=1)
        ss_tot = torch.sum((noise_flat - torch.mean(noise_flat, dim=1, keepdim=True)) ** 2, dim=1)
        r2_noise = 1 - (ss_res / (ss_tot + 1e-8))  # epsilon avoids division by zero
        r2_noise = torch.clamp(r2_noise, min=-10, max=1)  # bound outliers for averaging

        metrics['test_loss_mse'] = loss_mse.item()
        metrics['test_loss_mae'] = loss_mae.item()
        metrics['test_r2_noise'] = r2_noise.mean().item()

        # Treated-subset noise metrics; IPW variants use normalized weights.
        per_sample_mse = F.mse_loss(pred_noise, noise, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_loss_mse_treated'] = per_sample_mse[treat_mask].mean().item()
            metrics['test_r2_noise_treated'] = r2_noise[treat_mask].mean().item()

            if ipw is not None:
                w_t = ipw[treat_mask] / ipw[treat_mask].sum().clamp(min=EPSILON_IPW)
                ipw_mse_treated = (w_t * per_sample_mse[treat_mask]).sum()
                metrics['test_ipw_loss_treated'] = ipw_mse_treated.item()
                ipw_r2_treated = (w_t * r2_noise[treat_mask]).sum()
                metrics['test_ipw_r2_noise_treated'] = ipw_r2_treated.item()

        # Whole-batch IPW metrics (unnormalized weights, mean over samples)
        if ipw is not None:
            if not torch.isnan(per_sample_mse).any():
                ipw_test_loss = (ipw.to(per_sample_mse.device) * per_sample_mse).mean()
                metrics['test_ipw_loss'] = ipw_test_loss.item()
            if not torch.isnan(r2_noise).any():
                ipw_r2_noise = (ipw.to(r2_noise.device) * r2_noise).mean()
                metrics['test_ipw_r2_noise'] = ipw_r2_noise.item()
            metrics['test_mean_ipw'] = ipw.mean().item()

        # Discriminator metrics
        treatment_logits_test = self.treatment_discriminator(rich_context)
        disc_loss = self.criterion_discriminator(treatment_logits_test, interval_labels.float())
        metrics['test_disc_loss'] = disc_loss.item()

        disc_acc, disc_auc = calculate_discriminator_metrics(treatment_logits_test, interval_labels)
        metrics['test_disc_acc'] = disc_acc
        metrics['test_disc_auc'] = disc_auc

        # Single-step image reconstruction; fall back to the targets on NaN/Inf.
        with torch.no_grad():
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if not (torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any()):
                self.vae.eval()
                with autocast('cuda', enabled=self.trainer.precision.startswith(("16", "bf16"))):
                    generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
            else:
                generated_images = target_images

        if generated_images is None or torch.isnan(generated_images).any():
            generated_images = target_images

        # Generation metrics
        val_gen_loss = F.mse_loss(generated_images, target_images)
        metrics['test_gen_loss'] = val_gen_loss.item()

        # Per-sample R² for generated images. The denominator centres on the
        # batch-wide pixel mean rather than a per-image mean, unlike r2_noise above.
        target_flat = target_images.view(target_images.size(0), -1)
        generated_flat = generated_images.view(generated_images.size(0), -1)
        ss_res_gen = torch.sum((target_flat - generated_flat) ** 2, dim=1)
        overall_mean = torch.mean(target_flat)
        ss_tot_gen = torch.sum((target_flat - overall_mean) ** 2, dim=1)
        r2_gen = 1 - (ss_res_gen / (ss_tot_gen + 1e-8))  # epsilon avoids division by zero
        r2_gen = torch.clamp(r2_gen, min=-10, max=1)  # bound outliers for averaging
        metrics['test_r2_gen'] = r2_gen.mean().item()

        test_ssim = calculate_ssim(generated_images, target_images, self.device)
        metrics['test_ssim'] = test_ssim.item()

        # Treated-subset generation metrics
        per_sample_gen = F.mse_loss(generated_images, target_images, reduction='none').view(pred_noise.size(0), -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_gen_loss_treated'] = per_sample_gen[treat_mask].mean().item()
            metrics['test_r2_gen_treated'] = r2_gen[treat_mask].mean().item()
            if ipw is not None:
                w_t = ipw[treat_mask] / ipw[treat_mask].sum().clamp(min=EPSILON_IPW)
                ipw_gen_treated = (w_t * per_sample_gen[treat_mask]).sum()
                metrics['test_ipw_gen_loss_treated'] = ipw_gen_treated.item()
                ipw_r2_gen_treated = (w_t * r2_gen[treat_mask]).sum()
                metrics['test_ipw_r2_gen_treated'] = ipw_r2_gen_treated.item()

        if ipw is not None:
            if not torch.isnan(per_sample_gen).any():
                ipw_test_gen_loss = (ipw.to(per_sample_gen.device) * per_sample_gen).mean()
                metrics['test_ipw_gen_loss'] = ipw_test_gen_loss.item()
            if not torch.isnan(r2_gen).any():
                ipw_r2_gen = (ipw.to(r2_gen.device) * r2_gen).mean()
                metrics['test_ipw_r2_gen'] = ipw_r2_gen.item()

        # ITE metrics (skipped when generation fell back to the targets)
        if (len(self.feature_predictors) > 0 and not torch.equal(generated_images, target_images)):
            from diffusion_utils import calculate_ite_metrics
            ite_metrics = calculate_ite_metrics(
                generated_images, target_images, earlier_features, later_features,
                self.feature_predictors, self.trainer.precision, self.device,
                treat_mask, ipw
            )
            metrics.update({f'test_{k}': v for k, v in ite_metrics.items()})

        metrics['batch_size'] = float(input_images.size(0))
        self._test_outputs.append(metrics.copy())

        return {k: v for k, v in metrics.items() if np.isfinite(v)}

    def configure_optimizers(self):
        """Return ``[generator AdamW, discriminator AdamW]``.

        The generator optimizer covers the UNet, the image conditioning, the
        context encoder and, when it is not ``nn.Identity``, the rich-context
        projector. The VAE and text encoder are not included.
        """
        gen_params = (list(self.unet.parameters()) +
                      list(self.img_conditioning.parameters()) +
                      list(self.context_encoder.parameters()))
        if not isinstance(self.rich_context_projector, nn.Identity):
            gen_params += list(self.rich_context_projector.parameters())

        g_optimizer = AdamW(gen_params, lr=self.lr_generator)
        d_optimizer = AdamW(self.treatment_discriminator.parameters(), lr=self.lr_discriminator)

        return [g_optimizer, d_optimizer]


# Main script
_NUM_WORKERS = 8  # DataLoader worker processes for every split


def _abort(message):
    """Print ``message`` and terminate the process with exit status 1."""
    import sys

    print(message)
    sys.exit(1)


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``module.`` removed from each key.

    Weights saved from a ``DataParallel``/``DistributedDataParallel``-wrapped
    model carry this prefix, but the bare model expects plain names. Keys
    without the prefix are kept unchanged, and an empty dict is returned as-is.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_propensity_model(args, device="cpu"):
    """Build the temporal propensity model and load frozen weights into it.

    Args:
        args: parsed CLI namespace providing the ``propensity_*``
            hyper-parameters and ``propensity_model_path``.
        device: where the weights are placed. The default is CPU so that
            Lightning, not this script, decides each process's device. The
            test path has always handed a CPU-loaded model to the LitModule.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on all parameters.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise for a missing or
        mismatched checkpoint file.
    """
    model = TemporalPropensityModel(
        cov_dim=FEATURE_DIM, hidden_dim=args.propensity_hidden_dim,
        num_layers=args.propensity_num_layers, dropout=args.propensity_dropout,
        model_type=args.propensity_model_type, include_images=args.propensity_model_uses_images,
        delta_t_feat_dim=args.propensity_delta_t_feat_dim,
    )
    state_dict = torch.load(args.propensity_model_path, map_location=device)
    model.load_state_dict(_strip_module_prefix(state_dict))
    model.to(device)
    model.eval()
    # Used only for IPW; its parameters are not in any optimizer, so freeze them.
    for param in model.parameters():
        param.requires_grad_(False)
    return model


def _load_diffusion_components(args):
    """Load the pretrained Stable Diffusion parts and build ``ImageConditioning``.

    Returns:
        dict with keys ``vae``, ``unet``, ``text_encoder``, ``tokenizer``,
        ``scheduler`` and ``img_conditioning``. These match the keyword
        arguments passed to ``AdversarialIPWDiffusionLitModule``.
    """
    from diff_base import ImageConditioning

    source = args.diffusion_model_pretrain
    vae = AutoencoderKL.from_pretrained(source, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(source, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(source, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(source, subfolder="tokenizer")
    scheduler = DDIMScheduler.from_pretrained(source, subfolder="scheduler")
    img_conditioning = ImageConditioning(vae.config.latent_channels, args.disc_current_img_cond_dim)
    return dict(
        vae=vae, unet=unet, text_encoder=text_encoder, tokenizer=tokenizer,
        scheduler=scheduler, img_conditioning=img_conditioning,
    )


def _make_dataset(args, split, image_transform):
    """Create the pairs dataset backed by ``<pairs_dir>/<split>.csv``."""
    return TemporalKneeFeatureConditioningDataset(
        csv_file=os.path.join(args.pairs_dir, f"{split}.csv"),
        image_transform=image_transform,
        include_images=args.propensity_model_uses_images,
    )


def _make_loader(dataset, args, collate_fn, *, shuffle, drop_last=False):
    """Wrap ``dataset`` in a DataLoader with the script's shared settings."""
    return DataLoader(
        dataset, batch_size=args.batch_size, shuffle=shuffle,
        num_workers=_NUM_WORKERS, pin_memory=True, collate_fn=collate_fn,
        drop_last=drop_last,
    )


def _run_test(args, image_transforms, collate_fn):
    """Evaluate the checkpoint at ``args.ckpt_path`` on the ``test`` split."""
    if not args.ckpt_path:
        _abort("ERROR: --ckpt_path is required when running with --test_only.")

    print("Loading Temporal Propensity Model for testing...")
    propensity_model = _load_propensity_model(args, device="cpu")
    components = _load_diffusion_components(args)

    diffusion_model = AdversarialIPWDiffusionLitModule.load_from_checkpoint(
        args.ckpt_path, map_location="cpu", args=args,
        propensity_model=propensity_model, strict=False, **components,
    )
    # Store where the test metrics go and which checkpoint produced them.
    diffusion_model.hparams.results_file = args.results_file
    diffusion_model.hparams.ckpt_path = args.ckpt_path

    test_dataset = _make_dataset(args, "test", image_transforms)
    test_loader = _make_loader(test_dataset, args, collate_fn, shuffle=False)

    trainer = pl.Trainer(
        accelerator="gpu", devices=args.devices, precision=args.precision,
        logger=False, callbacks=[],
    )
    trainer.test(model=diffusion_model, dataloaders=test_loader, verbose=True)


def _run_training(args, image_transforms, collate_fn):
    """Train the adversarial + IPW diffusion model on the ``train``/``val`` splits."""
    train_dataset = _make_dataset(args, "train", image_transforms)
    val_dataset = _make_dataset(args, "val", image_transforms)
    train_loader = _make_loader(train_dataset, args, collate_fn, shuffle=True, drop_last=True)
    val_loader = _make_loader(val_dataset, args, collate_fn, shuffle=False)

    # Built before the propensity model to keep the seeded RNG consumption
    # (random init of ImageConditioning) in its original order.
    components = _load_diffusion_components(args)

    propensity_model = None
    if os.path.exists(args.propensity_model_path):
        try:
            propensity_model = _load_propensity_model(args, device="cpu")
            print(f"Loaded propensity model from {args.propensity_model_path}")
        except Exception as e:  # top-level boundary: report, then abort below
            print(f"Error loading propensity model: {e}")
    else:
        print(f"Propensity model not found: {args.propensity_model_path}")

    if propensity_model is None:
        _abort("ERROR: Propensity model failed to load but IPW is enabled. Exiting.")

    discriminator_pos_weight = None
    if args.use_disc_pos_weight:
        pos_weight_value = 2.0  # fixed value; could be exposed as a CLI flag
        discriminator_pos_weight = torch.full((NUM_TREATMENTS,), pos_weight_value)
        print(f"Using discriminator positive weight: {pos_weight_value}")

    diffusion_model = AdversarialIPWDiffusionLitModule(
        args=args, propensity_model=propensity_model,
        discriminator_pos_weight=discriminator_pos_weight, **components,
    )

    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss", dirpath=args.checkpoint_dir,
        filename="best-{epoch}-val_loss={val_loss:.4f}",
        auto_insert_metric_name=False, save_top_k=2, mode="min", save_last=True,
    )
    tensorboard_logger = TensorBoardLogger(save_dir=args.log_dir_base, name=args.run_name)

    trainer = pl.Trainer(
        max_epochs=args.max_epochs, accelerator="gpu", devices=args.devices,
        precision=args.precision, logger=tensorboard_logger, log_every_n_steps=25,
        strategy="ddp_find_unused_parameters_true", callbacks=[checkpoint_callback],
    )

    print(f"Starting training run: {args.run_name}")
    trainer.fit(diffusion_model, train_loader, val_loader)
    print("--- Training Finished ---")
    print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")


def main():
    """Parse CLI arguments, then run evaluation (``--test_only``) or training."""
    args = parse_args()
    pl.seed_everything(42, workers=True)

    print(f"Run Name: {args.run_name}")

    # ToTensor yields [0, 1]; Normalize with mean=std=0.5 maps that to [-1, 1].
    image_transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    collate_fn = create_unified_collate_fn(
        include_adversarial=True,
        include_ipw=True,
        propensity_model_uses_images=args.propensity_model_uses_images,
    )

    if args.test_only:
        _run_test(args, image_transforms, collate_fn)
    else:
        _run_training(args, image_transforms, collate_fn)


if __name__ == "__main__":
    main()