"""
Refactored Adversarial Diffusion Model Training
Simplified version using base classes and utility functions.
"""
import os
import argparse
from functools import partial
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import autocast
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

from data.pairs_dataset.dataset import TemporalKneeFeatureConditioningDataset, NUM_TREATMENTS, FEATURE_DIM, PERCEPTUAL_FEATURES
from diff_base import BaseDiffusionLitModule, BaseAdversarialMixin, ContextEncoderRNN, TreatmentDiscriminator
from data_utils import create_unified_collate_fn
from diffusion_utils import safe_tensor_log, calculate_discriminator_metrics, calculate_ssim

torch.set_float32_matmul_precision('high')


def parse_args(argv=None):
    """Parse command line arguments and derive run-specific fields.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour existing
            callers rely on. Passing an explicit list makes the function usable
            from notebooks and tests.

    Returns:
        argparse.Namespace with every CLI option plus:
          * ``run_name`` - always set.
          * ``checkpoint_dir`` and ``log_dir`` - set in training mode only;
            ``checkpoint_dir`` is created on disk. In ``--test_only`` mode
            these two attributes are NOT set.

    Raises:
        SystemExit: via ``parser.error`` when ``--test_only`` is given without
            an existing ``--ckpt_path`` (or on any ordinary argparse error).
    """
    parser = argparse.ArgumentParser(description="Train Adversarial Diffusion Model")

    # Mode
    parser.add_argument("--test_only", action="store_true", help="Run in testing mode only")
    parser.add_argument("--ckpt_path", type=str, default=None, help="Path to checkpoint for testing")
    parser.add_argument("--results_file", type=str, default="test_results.txt", help="File to save test metrics")

    # Paths
    parser.add_argument("--pairs_dir", type=str, default="data/pairs_dataset", help="Dir with train.csv/val.csv")
    parser.add_argument("--base_checkpoint_dir", type=str, default="/local2/acc/OAI_checkpoints")
    parser.add_argument("--log_dir_base", type=str, default="/local2/acc/OAI_checkpoints/tb_logs")

    # Model Params
    parser.add_argument("--diffusion_model_pretrain", type=str, default="runwayml/stable-diffusion-v1-5")

    # Training Params
    parser.add_argument("--lr_generator", type=float, default=1e-5, help="Generator learning rate")
    parser.add_argument("--lr_discriminator", type=float, default=1e-4, help="Discriminator learning rate")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--feature_loss_weight", type=float, default=0.1, help="Weight for perceptual feature loss")
    parser.add_argument("--adversarial_weight", type=float, default=0.01, help="Weight for adversarial loss")
    parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--precision", type=str, default="16-mixed", choices=["16-mixed", "bf16-mixed", "32-true"])
    parser.add_argument("--devices", type=int, nargs='+', default=[0,1], help="GPU devices to use")

    # Discriminator Params
    parser.add_argument("--use_disc_pos_weight", action='store_true', help="Use positive weight for discriminator")

    # Context Encoder
    parser.add_argument("--disc_model_type", type=str, choices=["rnn", "transformer"], default="rnn")
    parser.add_argument("--disc_hidden_dim", type=int, default=128)
    parser.add_argument("--disc_num_layers", type=int, default=2)
    parser.add_argument("--disc_dropout", type=float, default=0.2)
    parser.add_argument("--disc_delta_t_feat_dim", type=int, default=8)
    parser.add_argument("--disc_side_feat_dim", type=int, default=4)
    parser.add_argument("--disc_current_img_cond_dim", type=int, default=768)

    # Other
    parser.add_argument("--run_suffix", type=str, default="a", help="Suffix for run name")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    if args.test_only:
        # Testing needs a concrete checkpoint; fail early with a usage error.
        if not args.ckpt_path or not os.path.exists(args.ckpt_path):
            parser.error("--ckpt_path is required and must exist for --test_only mode.")
        args.run_name = f"TEST_{os.path.basename(args.ckpt_path)}"
    else:
        # Encode the main hyper-parameters in the run name so that checkpoint
        # and log directories of different configurations never collide.
        args.run_name = (f"{args.run_suffix}_flw{args.feature_loss_weight}_lrg{args.lr_generator}_"
                        f"lrd{args.lr_discriminator}_aw{args.adversarial_weight}_"
                        f"pw{args.use_disc_pos_weight}_gc{args.gradient_clip_val}_"
                        f"b{args.batch_size}x{len(args.devices)}_dnl{args.disc_num_layers}_"
                        f"dhd{args.disc_hidden_dim}_icd{args.disc_current_img_cond_dim}")
        args.checkpoint_dir = os.path.join(args.base_checkpoint_dir, args.run_name)
        args.log_dir = os.path.join(args.log_dir_base, args.run_name)
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    return args


class AdversarialDiffusionLitModule(BaseAdversarialMixin, BaseDiffusionLitModule):
    """Diffusion LightningModule trained together with a treatment discriminator.

    The "generator" is the UNet plus the image-conditioning, context-encoder and
    (optional) rich-context projector modules; the discriminator tries to predict
    the treatment labels from the encoded context. Optimisation is manual
    (``self.optimizers()`` / ``self.manual_backward``) with one optimizer each.

    Attributes such as ``context_encoder``, ``rich_context_projector``,
    ``treatment_discriminator``, ``criterion_discriminator``,
    ``feature_predictors``, ``latent_scale``, ``pipe`` and ``_test_outputs`` are
    used here but are not created in this class; they are expected to come from
    the base classes / ``_init_adversarial_components`` (not visible in this file).
    """

    def __init__(self, args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                 discriminator_pos_weight=None, delta_t_mean=0.0, delta_t_std=1.0):
        """Build the module.

        Args:
            args: Parsed CLI namespace; also stored via ``save_hyperparameters``.
            vae, unet, scheduler, text_encoder, tokenizer: Pretrained diffusion parts.
            img_conditioning: Module mapping input latents to a conditioning vector.
            discriminator_pos_weight: Optional tensor forwarded to
                ``_init_adversarial_components``.
            delta_t_mean, delta_t_std: Forwarded to the base class and later to
                the context encoder.
        """
        super().__init__(args, vae, unet, scheduler, text_encoder, tokenizer, img_conditioning,
                        delta_t_mean, delta_t_std)

        # Initialize adversarial components
        self._init_adversarial_components(args, discriminator_pos_weight)

        # Save hyperparameters including learning rates
        self.save_hyperparameters(args)

    def _amp_enabled(self):
        """Return True when the trainer precision starts with "16" or "bf16".

        ``str()`` keeps the check working if the trainer reports the precision
        as a non-string value.
        """
        return str(self.trainer.precision).startswith(("16", "bf16"))

    @staticmethod
    def _per_sample_r2(target_flat, pred_flat, reference_mean):
        """Return per-row R^2 of ``pred_flat`` against ``target_flat``, clamped to [-10, 1].

        ``reference_mean`` is the mean used for the total sum of squares and must
        broadcast against ``target_flat``. A small epsilon guards the division.
        """
        ss_res = torch.sum((target_flat - pred_flat) ** 2, dim=1)
        ss_tot = torch.sum((target_flat - reference_mean) ** 2, dim=1)
        r2 = 1 - (ss_res / (ss_tot + 1e-8))
        return torch.clamp(r2, min=-10, max=1)

    def forward(self, input_images, target_images, prompts,
                cov_seq_hist, trt_seq_hist, lengths_hist, delta_t, side):
        """Run one noise-prediction pass conditioned on text and encoded context.

        Returns:
            Tuple ``(pred_noise, noise, noisy_latents, timesteps, target_latents,
            rich_context)`` where ``target_latents`` is cast back to the dtype of
            ``input_images`` and ``rich_context`` is the raw context-encoder output
            (before projection).
        """
        input_dtype = input_images.dtype

        with autocast('cuda', enabled=self._amp_enabled()):
            # Encode both images to latents and noise the target at a random timestep.
            input_latents = self.vae.encode(input_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            target_latents = self.vae.encode(target_images.to(self.vae.dtype)).latent_dist.sample() * self.latent_scale
            timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps,
                                    (input_latents.size(0),), device=self.device).long()
            noise = torch.randn_like(target_latents)
            noisy_latents = self.scheduler.add_noise(target_latents, noise, timesteps)

            text_inputs = self.tokenizer(prompts, padding="max_length",
                                       max_length=self.tokenizer.model_max_length,
                                       truncation=True, return_tensors="pt").input_ids.to(self.device)
            text_embeddings = self.text_encoder(text_inputs)[0]

            # Image conditioning from current input
            img_cond_current = self.img_conditioning(input_latents)

            # Rich context from history
            rich_context = self.context_encoder(
                img_cond_current,
                cov_seq_hist.to(self.device).to(img_cond_current.dtype),
                trt_seq_hist.to(self.device).to(img_cond_current.dtype),
                lengths_hist.to(self.device),
                delta_t.to(self.device).to(img_cond_current.dtype),
                side.to(self.device).to(img_cond_current.dtype),
                delta_t_mean=self.delta_t_mean,
                delta_t_std=self.delta_t_std
            )

            projected_rich_context = self.rich_context_projector(rich_context)

            # Combined conditioning: the projected context is added to every
            # text-token embedding (unsqueeze(1) broadcasts over the token axis).
            unet_conditioning = text_embeddings + projected_rich_context.unsqueeze(1)
            pred_noise = self.unet(noisy_latents.to(self.unet.dtype), timesteps,
                                 encoder_hidden_states=unet_conditioning.to(self.unet.dtype)).sample

        return pred_noise, noise, noisy_latents, timesteps, target_latents.to(input_dtype), rich_context

    def training_step(self, batch, batch_idx):
        """Run one manual-optimisation step: generator update, then discriminator update.

        Returns ``None``; nothing is handed back to Lightning because both
        optimizers are stepped here.
        """
        g_opt, d_opt = self.optimizers()

        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        # Set training mode
        self.unet.train()
        self.img_conditioning.train()
        self.context_encoder.train()
        self.treatment_discriminator.train()
        if not isinstance(self.rich_context_projector, nn.Identity):
            self.rich_context_projector.train()

        pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context = self(
            input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist,
            delta_t, side
        )

        if torch.isnan(pred_noise).any():
            # NOTE(review): the trainer is configured with a DDP strategy; skipping
            # the step on a single rank may desynchronise ranks - confirm.
            print("FATAL TRAIN: NaN in pred_noise")
            return None

        # Diffusion loss
        diffusion_loss = F.mse_loss(pred_noise, noise)
        diffusion_loss = safe_tensor_log(diffusion_loss, 1000.0, self.device)

        # Feature loss
        feature_loss = torch.tensor(0.0, device=self.device)
        if self.feature_loss_weight > 0 and len(self.feature_predictors) > 0:
            from diffusion_utils import calculate_feature_loss
            feature_loss = calculate_feature_loss(
                self, pred_noise, timesteps, noisy_latents, later_features,
                self.feature_predictors, self.vae, self.latent_scale,
                self.trainer.precision, self.device
            )

        # Generator adversarial loss (want discriminator to predict 0, i.e., no treatment).
        # rich_context must stay attached to the graph here: detaching it would cut
        # the only path from this loss to the generator parameters and make the
        # adversarial term a no-op for the generator. Gradients that land on the
        # discriminator parameters in this backward pass are discarded by
        # d_opt.zero_grad() below before the discriminator's own backward.
        treatment_logits = self.treatment_discriminator(rich_context)
        target_no_treatment = torch.zeros_like(interval_labels, dtype=torch.float)
        adversarial_loss_g = self.criterion_discriminator(treatment_logits, target_no_treatment)
        adversarial_loss_g = safe_tensor_log(adversarial_loss_g, 1000.0, self.device)

        # Total generator loss
        generator_loss = (diffusion_loss +
                         self.feature_loss_weight * feature_loss +
                         self.adversarial_weight * adversarial_loss_g)
        generator_loss = safe_tensor_log(generator_loss, 1000.0, self.device)

        # Optimize generator
        g_opt.zero_grad()
        self.manual_backward(generator_loss)
        torch.nn.utils.clip_grad_norm_(
            [p for group in g_opt.param_groups for p in group['params']],
            max_norm=self.hparams.gradient_clip_val
        )
        g_opt.step()

        # Train discriminator (want to correctly predict actual treatment labels).
        # The context is detached so this loss only updates the discriminator.
        treatment_logits_d = self.treatment_discriminator(rich_context.detach())
        discriminator_loss = self.criterion_discriminator(treatment_logits_d, interval_labels.float())
        discriminator_loss = safe_tensor_log(discriminator_loss, 1000.0, self.device)

        # Optimize discriminator
        d_opt.zero_grad()
        self.manual_backward(discriminator_loss)
        torch.nn.utils.clip_grad_norm_(self.treatment_discriminator.parameters(),
                                      max_norm=self.hparams.gradient_clip_val)
        d_opt.step()

        # Logging (epoch-level, synchronised across processes)
        log_kw = dict(on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))
        self.log("train_loss", generator_loss, prog_bar=True, **log_kw)
        self.log("train_diffusion_loss", diffusion_loss, prog_bar=False, **log_kw)
        self.log("train_adversarial_loss_g", adversarial_loss_g, prog_bar=False, **log_kw)
        self.log("train_discriminator_loss", discriminator_loss, prog_bar=True, **log_kw)

        if self.feature_loss_weight > 0 and len(self.feature_predictors) > 0:
            self.log("train_feature_loss", feature_loss, prog_bar=False, **log_kw)

    def validation_step(self, batch, batch_idx):
        """Log noise-prediction, discriminator, image-reconstruction and ITE metrics.

        Image-level metrics fall back to comparing the target with itself when
        decoding fails or yields non-finite latents; ITE metrics are skipped in
        that case.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context_val = self(
            input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist,
            delta_t, side
        )

        if torch.isnan(pred_noise).any():
            print("WARN VAL: NaN pred_noise")
            return

        log_kw = dict(on_step=False, on_epoch=True, sync_dist=True, batch_size=input_images.size(0))

        # Log losses
        val_loss = F.mse_loss(pred_noise, noise)
        val_loss = safe_tensor_log(val_loss, 1000.0, self.device)
        self.log("val_loss", val_loss, prog_bar=True, **log_kw)

        val_mae_loss = F.l1_loss(pred_noise, noise)
        val_mae_loss = safe_tensor_log(val_mae_loss, 1000.0, self.device)
        self.log("val_mae_loss", val_mae_loss, prog_bar=True, **log_kw)

        # Discriminator performance
        with torch.no_grad():
            self.treatment_discriminator.eval()
            treatment_logits_val = self.treatment_discriminator(rich_context_val)
            val_disc_loss = self.criterion_discriminator(treatment_logits_val, interval_labels.float())

            # Calculate discriminator accuracy and AUC
            disc_acc, disc_auc = calculate_discriminator_metrics(treatment_logits_val, interval_labels)

            self.log("val_disc_loss", val_disc_loss, prog_bar=False, **log_kw)
            self.log("val_disc_acc", disc_acc, prog_bar=True, **log_kw)
            self.log("val_disc_auc", disc_auc, prog_bar=False, **log_kw)

        # Generate images. `generation_ok` records whether a real decode happened,
        # so later code does not need a full tensor comparison to detect the fallback.
        generated_images = target_images
        generation_ok = False
        try:
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if torch.isfinite(pred_latents).all():
                generated_images = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
                generation_ok = True
        except Exception as exc:
            # Best effort: validation continues with the fallback, but the cause
            # is reported instead of being silently dropped.
            print(f"WARN VAL: Image Generation Failed ({type(exc).__name__}: {exc})")
            generated_images = target_images
            generation_ok = False

        # Generation loss
        val_gen_loss = F.mse_loss(generated_images, target_images)
        if not torch.isnan(val_gen_loss):
            self.log("val_gen_loss", val_gen_loss, prog_bar=False, **log_kw)

        # Log SSIM metric
        val_ssim = calculate_ssim(generated_images, target_images, self.device)
        if not torch.isnan(val_ssim):
            self.log("val_ssim", val_ssim, prog_bar=False, **log_kw)

        # ITE Metrics (only meaningful for genuinely generated images)
        if len(self.feature_predictors) > 0 and generation_ok:
            from diffusion_utils import calculate_ite_metrics
            ite_metrics = calculate_ite_metrics(
                generated_images, target_images, earlier_features, later_features,
                self.feature_predictors, self.trainer.precision, self.device
            )

            for metric_name, metric_value in ite_metrics.items():
                if not np.isnan(metric_value):
                    self.log(f"val_{metric_name}", metric_value, prog_bar=False, **log_kw)

        # Log qualitative images (first batch, rank 0 only)
        if self.global_rank == 0 and batch_idx == 0:
            from diffusion_utils import log_qualitative_images
            log_qualitative_images(
                self.pipe, input_images[0:16], target_images[0:16], prompts[0:16],
                self.logger, self.current_epoch, device=self.device
            )

    def test_step(self, batch, batch_idx):
        """Compute per-batch test metrics and append them to ``self._test_outputs``.

        Returns:
            Dict of the metrics whose values are finite.

        Raises:
            ValueError: if the predicted noise contains NaN.
        """
        input_images = batch["input_image"]
        target_images = batch["target_image"]
        prompts = batch["prompt"]
        earlier_features = batch["earlier_features"]
        later_features = batch["later_features"]
        interval_labels = batch["interval_labels"].to(self.device)
        cov_seq_hist = batch["cov_seq"].to(self.device)
        trt_seq_hist = batch["trt_seq"].to(self.device)
        lengths_hist = batch["lengths"]
        delta_t = batch["delta_t"].to(self.device)
        side = batch["side"].to(self.device)

        # Treatment mask
        # NOTE(review): this selects rows whose LAST label equals 0, yet the
        # resulting metrics are named "*_treated" - confirm the label encoding.
        treat_mask = (interval_labels[:, -1] == 0)

        metrics = {}
        batch_size = pred_batch = input_images.size(0)

        with torch.no_grad():
            pred_noise, noise, noisy_latents, timesteps, target_latents, rich_context = self(
                input_images, target_images, prompts, cov_seq_hist, trt_seq_hist, lengths_hist,
                delta_t, side
            )

        if torch.isnan(pred_noise).any():
            raise ValueError("NaN in pred_noise")

        pred_batch = pred_noise.size(0)

        # Calculate losses
        loss_mse = F.mse_loss(pred_noise, noise)
        loss_mae = F.l1_loss(pred_noise, noise)

        # R-squared for noise prediction, using each sample's own mean
        noise_flat = noise.reshape(noise.size(0), -1)
        pred_noise_flat = pred_noise.reshape(pred_batch, -1)
        r2_noise = self._per_sample_r2(
            noise_flat, pred_noise_flat, torch.mean(noise_flat, dim=1, keepdim=True)
        )

        metrics['test_loss_mse'] = loss_mse.item()
        metrics['test_loss_mae'] = loss_mae.item()
        metrics['test_r2_noise'] = r2_noise.mean().item()

        # Per-sample losses
        per_sample_mse = F.mse_loss(pred_noise, noise, reduction='none').reshape(pred_batch, -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_loss_mse_treated'] = per_sample_mse[treat_mask].mean().item()
            metrics['test_r2_noise_treated'] = r2_noise[treat_mask].mean().item()

        # Discriminator metrics (evaluation only, so no graph is needed)
        with torch.no_grad():
            treatment_logits_test = self.treatment_discriminator(rich_context)
            disc_loss = self.criterion_discriminator(treatment_logits_test, interval_labels.float())
        metrics['test_disc_loss'] = disc_loss.item()

        disc_acc, disc_auc = calculate_discriminator_metrics(treatment_logits_test, interval_labels)
        metrics['test_disc_acc'] = disc_acc
        metrics['test_disc_auc'] = disc_auc

        # Generate images; fall back to the target when latents or the decoded
        # images are not usable, and remember that via `generation_ok`.
        generated_images = target_images
        generation_ok = False
        with torch.no_grad():
            pred_latents = self._calculate_predicted_x0(pred_noise, timesteps, noisy_latents)
            if torch.isfinite(pred_latents).all():
                self.vae.eval()
                with autocast('cuda', enabled=self._amp_enabled()):
                    decoded = self.vae.decode(pred_latents.to(self.vae.dtype) / self.latent_scale).sample
                if not torch.isnan(decoded).any():
                    generated_images = decoded
                    generation_ok = True

        # Generation metrics
        val_gen_loss = F.mse_loss(generated_images, target_images)
        metrics['test_gen_loss'] = val_gen_loss.item()

        # R-squared for generated images. The reference is the overall mean across
        # all pixels of the batch, not a per-image mean.
        target_flat = target_images.reshape(target_images.size(0), -1)
        generated_flat = generated_images.reshape(generated_images.size(0), -1)
        r2_gen = self._per_sample_r2(target_flat, generated_flat, torch.mean(target_flat))
        metrics['test_r2_gen'] = r2_gen.mean().item()

        # Calculate SSIM
        test_ssim = calculate_ssim(generated_images, target_images, self.device)
        metrics['test_ssim'] = test_ssim.item()

        per_sample_gen = F.mse_loss(generated_images, target_images, reduction='none').reshape(pred_batch, -1).mean(dim=1)
        if treat_mask.any():
            metrics['test_gen_loss_treated'] = per_sample_gen[treat_mask].mean().item()
            metrics['test_r2_gen_treated'] = r2_gen[treat_mask].mean().item()

        # ITE Metrics (skipped when the fallback image was used)
        if len(self.feature_predictors) > 0 and generation_ok:
            from diffusion_utils import calculate_ite_metrics
            ite_metrics = calculate_ite_metrics(
                generated_images, target_images, earlier_features, later_features,
                self.feature_predictors, self.trainer.precision, self.device, treat_mask
            )
            metrics.update({f'test_{k}': v for k, v in ite_metrics.items()})

        metrics['batch_size'] = float(batch_size)
        # The stored copy keeps non-finite values; only the returned dict is filtered.
        self._test_outputs.append(metrics.copy())

        return {k: v for k, v in metrics.items() if np.isfinite(v)}

    def configure_optimizers(self):
        """Return ``[generator_optimizer, discriminator_optimizer]`` (both AdamW).

        The generator optimizer covers the UNet, image conditioning, context
        encoder and - unless it is an ``nn.Identity`` - the rich-context projector.
        """
        # Generator parameters
        gen_params = (list(self.unet.parameters()) +
                     list(self.img_conditioning.parameters()) +
                     list(self.context_encoder.parameters()))
        if not isinstance(self.rich_context_projector, nn.Identity):
            gen_params += list(self.rich_context_projector.parameters())

        # Optimizers
        g_optimizer = AdamW(gen_params, lr=self.lr_generator)
        d_optimizer = AdamW(self.treatment_discriminator.parameters(), lr=self.lr_discriminator)

        return [g_optimizer, d_optimizer]


# Main script
if __name__ == "__main__":
    args = parse_args()
    pl.seed_everything(42, workers=True)

    print(f"Run Name: {args.run_name}")

    # Preprocessing shared by every split: resize to 224x224, convert to a
    # tensor, then normalise each channel with mean 0.5 / std 0.5.
    image_transforms = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # One collate function serves train, val and test loaders.
    collate_fn = create_unified_collate_fn(include_adversarial=True)

    def _dataset_for(split):
        """Build the dataset backed by ``<pairs_dir>/<split>.csv``."""
        return TemporalKneeFeatureConditioningDataset(
            csv_file=os.path.join(args.pairs_dir, f"{split}.csv"),
            image_transform=image_transforms,
            include_images=False,
        )

    def _loader_for(dataset, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader with the script's fixed worker settings."""
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=8,
            pin_memory=True,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )

    def _pretrained_components():
        """Load the pretrained diffusion parts and create the image-conditioning module."""
        source = args.diffusion_model_pretrain
        vae_model = AutoencoderKL.from_pretrained(source, subfolder="vae")
        unet_model = UNet2DConditionModel.from_pretrained(source, subfolder="unet")
        text_model = CLIPTextModel.from_pretrained(source, subfolder="text_encoder")
        clip_tokenizer = CLIPTokenizer.from_pretrained(source, subfolder="tokenizer")
        noise_scheduler = DDIMScheduler.from_pretrained(source, subfolder="scheduler")

        from diff_base import ImageConditioning
        conditioning = ImageConditioning(
            vae_model.config.latent_channels, args.disc_current_img_cond_dim
        )
        return {
            "vae": vae_model,
            "unet": unet_model,
            "scheduler": noise_scheduler,
            "text_encoder": text_model,
            "tokenizer": clip_tokenizer,
            "img_conditioning": conditioning,
        }

    if args.test_only:
        # ---- Evaluation of an existing checkpoint ----
        components = _pretrained_components()

        # strict=False tolerates key mismatches between checkpoint and module.
        diffusion_model = AdversarialDiffusionLitModule.load_from_checkpoint(
            args.ckpt_path, map_location="cpu", args=args, strict=False, **components
        )
        diffusion_model.hparams.results_file = args.results_file
        diffusion_model.hparams.ckpt_path = args.ckpt_path

        test_dataset = _dataset_for("test")
        test_loader = _loader_for(test_dataset, shuffle=False, drop_last=False)

        trainer = pl.Trainer(
            accelerator="gpu",
            devices=args.devices,
            precision=args.precision,
            logger=False,
            callbacks=[],
        )
        trainer.test(model=diffusion_model, dataloaders=test_loader, verbose=True)

    else:
        # ---- Training ----
        train_dataset = _dataset_for("train")
        val_dataset = _dataset_for("val")

        # Incomplete final batches are dropped for training only.
        train_loader = _loader_for(train_dataset, shuffle=True, drop_last=True)
        val_loader = _loader_for(val_dataset, shuffle=False, drop_last=False)

        components = _pretrained_components()

        # Optional fixed positive-class weight, one entry per treatment.
        discriminator_pos_weight = None
        if args.use_disc_pos_weight:
            pos_weight_value = 2.0  # Can be made configurable
            discriminator_pos_weight = torch.tensor([pos_weight_value] * NUM_TREATMENTS)
            print(f"Using discriminator positive weight: {pos_weight_value}")

        diffusion_model = AdversarialDiffusionLitModule(
            args=args,
            discriminator_pos_weight=discriminator_pos_weight,
            **components,
        )

        # Keep the two best checkpoints by validation loss plus the latest one.
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            dirpath=args.checkpoint_dir,
            filename="best-{epoch}-val_loss={val_loss:.4f}",
            auto_insert_metric_name=False,
            save_top_k=2,
            mode="min",
            save_last=True,
        )
        tensorboard_logger = TensorBoardLogger(save_dir=args.log_dir_base, name=args.run_name)

        trainer = pl.Trainer(
            max_epochs=args.max_epochs,
            accelerator="gpu",
            devices=args.devices,
            precision=args.precision,
            logger=tensorboard_logger,
            log_every_n_steps=25,
            strategy="ddp_find_unused_parameters_true",
            callbacks=[checkpoint_callback],
        )

        print(f"Starting training run: {args.run_name}")
        trainer.fit(diffusion_model, train_loader, val_loader)
        print("--- Training Finished ---")
        print(f"Best model checkpoint: {checkpoint_callback.best_model_path}")