#!/usr/bin/env python3

import torch
import torch.nn.functional as F
import numpy as np
import os
from PIL import Image as PImage
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torchvision.utils import save_image
warnings.filterwarnings("ignore")

from inference_VQ_Diffusion import VQ_Diffusion
from image_synthesis.modeling.codecs.image_codec.ema_vqvae import PatchVQVAE
from image_synthesis.utils.misc import instantiate_from_config


class CustomImageDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Files are selected by extension (``.png``, ``.jpg``, ``.jpeg``, matched
    case-insensitively) and kept in sorted path order.  Each item is an
    ``(image, path)`` pair.
    """

    def __init__(self, image_dir, transform=None):
        """Collect the image paths in ``image_dir``.

        Args:
            image_dir: Directory to list (not searched recursively).
            transform: Optional callable applied to each loaded RGB image.

        Raises:
            FileNotFoundError: If no file in ``image_dir`` has a supported
                extension.
        """
        supported = (".png", ".jpg", ".jpeg")
        self.transform = transform

        matches = []
        for entry in os.listdir(image_dir):
            if entry.lower().endswith(supported):
                matches.append(os.path.join(image_dir, entry))
        matches.sort()
        self.image_files = matches

        if len(self.image_files) == 0:
            raise FileNotFoundError(f"No images found in {image_dir}")

    def __len__(self):
        """Return the number of image files collected at construction."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, path)``."""
        path = self.image_files[idx]
        picture = PImage.open(path).convert("RGB")
        if not self.transform:
            return picture, path
        return self.transform(picture), path


class Config:
    """Settings for one real-vs-generated comparison run.

    The four constructor arguments are stored as given; every other attribute
    is a fixed default assigned here.
    """

    def __init__(self, real_dataset_name, real_folder, generated_folder, output_dir):
        # --- Per-run inputs -------------------------------------------------
        self.REAL_DATASET_NAME = real_dataset_name
        self.REAL_FOLDER = real_folder
        self.GENERATED_FOLDER = generated_folder
        self.OUTPUT_DIR = output_dir

        # --- Processing defaults --------------------------------------------
        self.BATCH_SIZE = 32
        self.MAX_IMAGES = 1000
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.IMAGE_SIZE = 256

        # --- Feature switches -----------------------------------------------
        # False disables the latent tracer computation.
        self.USE_LATENT_TRACER = True
        # False disables saving losses for distribution plotting.
        self.SAVE_FINAL_LOSSES = True

        # --- Model weights --------------------------------------------------
        # Path to finetuned VQVAE weights; None means the default weights.
        # Example:
        #   "/workspace/VQ-Diffusion/finetuned_models_BS16_L5e-05_W0.0001_E50_SEED0/vqvae_finetune_20250915_130057_enc_epochs50_final.pth"
        self.FINETUNED_VQVAE_WEIGHTS = None

        # --- Save locations -------------------------------------------------
        # Main results directory.
        self.BATCH_OUTPUT_DIR = "/workspace/VQ-Diffusion/final_no_finetuned"
        # VQ-Diffusion generated images.
        self.GENERATED_IMAGES_FOLDER = "/workspace/VQ-Diffusion/generated_images/ithq_np1000_pr2_seed0"

        # --- Numerics -------------------------------------------------------
        # Added to ratio denominators for numerical stability.
        self.EPSILON = 1e-8

        # --- Datasets available for analysis (display name -> folder) --------
        real_sets = {
            "ImageNet (val)": "/datasets/imagenet256_eval",
            "LAION": "/data/datasets/laion_1k_clean",
            "MS-COCO": "/data/datasets/mscoco2014val/val2014_subset",
        }
        generated_sets = {
            "RAR Generated": "/data/generated_datasets/rar/rar_xl_dec_original_generated",
            "LlamaGen Generated": "/data/generated_datasets/llamagen/256x256",
            "Taming Generated": "/data/generated_datasets/taming/taming_gen_images_original_decoder/images",
            "VAR Generated": "/data/generated_datasets/var/var_generated",
            "Infinity Generated": "/data/generated_datasets/infinity/infinity_2b/delta_0",
        }
        self.DATASETS = {**real_sets, **generated_sets}


# Display name of each slot in the saved loss tuple, keyed by slot index
# (position in the tuple below is the index).
EXTENDED_LOSS_MAP = dict(enumerate((
    "Codebook Loss MSE",
    "Codebook Loss MSE Double",
    "Codebook Loss MSE Double Ratio",
    "Reconstruction Quant Loss MSE",
    "Reconstruction Quant Loss MSE Double",
    "Reconstruction Quant Loss MSE Double Ratio",
    "Reconstruction No Quant Loss MSE",
    "Reconstruction No Quant Loss MSE Double",
    "Reconstruction No Quant Loss MSE Double Ratio",
    "Main Combined",
)))

# Results-dict key that fills each slot; entries line up one-to-one with
# EXTENDED_LOSS_MAP.
LOSS_KEY_MAPPING = dict(enumerate((
    "codebook_single",        # Codebook Loss MSE
    "codebook_double",        # Codebook Loss MSE Double
    "codebook_ratios",        # Codebook Loss MSE Double Ratio
    "recon_single",           # Reconstruction Quant Loss MSE
    "recon_double",           # Reconstruction Quant Loss MSE Double
    "recon_ratios",           # Reconstruction Quant Loss MSE Double Ratio
    "recon_no_quant",         # Reconstruction No Quant Loss MSE
    "recon_no_quant_double",  # Reconstruction No Quant Loss MSE Double
    "recon_no_quant_ratios",  # Reconstruction No Quant Loss MSE Double Ratio
    "combined1",              # Main Combined
)))


def save_final_losses(data, dataset_name, config):
    """Write one dataset's losses to disk as a 10-slot tuple via ``torch.save``.

    Slot ``i`` holds ``data[LOSS_KEY_MAPPING[i]]``; a missing key is reported
    and stored as an empty list.  Output goes under
    ``/results/losses/vq_diffusion/<orig_enc|ft_enc>``, chosen by whether
    ``config.FINETUNED_VQVAE_WEIGHTS`` is None.  When the latent tracer is
    enabled and present in ``data``, it is written to a separate
    ``latent_tracer`` directory in the same 10-slot layout (slot 0 only).

    Args:
        data: Mapping from loss key to that loss's values.
        dataset_name: Name embedded in the output file names.
        config: Object providing ``SAVE_FINAL_LOSSES``,
            ``FINETUNED_VQVAE_WEIGHTS`` and ``USE_LATENT_TRACER``.

    Returns:
        None.  Nothing is written when ``config.SAVE_FINAL_LOSSES`` is falsy.
    """
    if not config.SAVE_FINAL_LOSSES:
        return

    print(f"Saving final losses for {dataset_name}...")

    # The sub-directory records which encoder weights were configured.
    encoder_type = "orig_enc" if config.FINETUNED_VQVAE_WEIGHTS is None else "ft_enc"
    base_save_dir = f"/results/losses/vq_diffusion/{encoder_type}"
    os.makedirs(base_save_dir, exist_ok=True)

    # Fill the ten slots in index order.
    slots = []
    for slot_index in range(10):
        key = LOSS_KEY_MAPPING[slot_index]
        if key not in data:
            print(f"Warning: Loss key '{key}' not found in data for index {slot_index}")
            slots.append([])  # placeholder for an unavailable loss
            continue
        slots.append(data[key])

    save_path = f"{base_save_dir}/results_{dataset_name}_mse.pt"
    torch.save(tuple(slots), save_path)
    print(f"Saved losses to: {save_path}")

    # The latent tracer lives in its own file, independent of encoder type.
    if not (config.USE_LATENT_TRACER and 'latent_tracer' in data):
        return

    latent_tracer_dir = "/results/losses/vq_diffusion/latent_tracer"
    os.makedirs(latent_tracer_dir, exist_ok=True)

    # Same 10-slot shape: tracer losses first, the remaining nine slots empty.
    padded = [data['latent_tracer']] + [[]] * 9
    latent_save_path = f"{latent_tracer_dir}/results_{dataset_name}_mse.pt"
    torch.save(tuple(padded), latent_save_path)
    print(f"Saved latent tracer losses to: {latent_save_path}")


def create_transform(config):
    """Build the image preprocessing pipeline.

    Images are resized to a ``config.IMAGE_SIZE`` square, converted to a
    tensor, and normalized per channel with mean 0.5 and std 0.5.
    """
    side = config.IMAGE_SIZE
    steps = [
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)


def initialize_vq_codec(config):
    """Build the VQ-Diffusion model and return its content codec.

    If ``config.FINETUNED_VQVAE_WEIGHTS`` is set, that checkpoint is loaded
    into the codec with ``strict=True``.  The checkpoint may be a bare state
    dict or a dict wrapping one under ``'model_state_dict'`` or
    ``'state_dict'``.

    Loading is best-effort.  On failure the codec is restored to the weights
    it had before the attempt, and ``config.FINETUNED_VQVAE_WEIGHTS`` is reset
    to ``None`` so the rest of the run labels its outputs by the weights that
    are actually in use.

    Args:
        config: Object with a ``FINETUNED_VQVAE_WEIGHTS`` attribute (path or
            None).  Mutated only when loading fails.

    Returns:
        The codec taken from ``vq_model.model.content_codec``.
    """
    vq_model = VQ_Diffusion(
        config='/workspace/VQ-Diffusion/configs/ithq.yaml',
        path='/checkpoints/pretrained_model/ithq_learnable.pth'
    )
    codec = vq_model.model.content_codec

    weights_path = config.FINETUNED_VQVAE_WEIGHTS
    if weights_path is None:
        return codec

    print(f"Loading finetuned VQVAE weights from: {weights_path}")

    # load_state_dict copies every matching tensor before it raises on
    # missing/unexpected keys or shape mismatches, so a failed strict load can
    # leave the module partially overwritten.  Keep a copy to roll back to.
    pretrained_state = {
        name: tensor.detach().clone()
        for name, tensor in codec.state_dict().items()
    }

    try:
        checkpoint = torch.load(weights_path, map_location='cpu')

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        codec.load_state_dict(state_dict, strict=True)
    except Exception as e:
        # Roll back any partial load so the fallback message below is true.
        codec.load_state_dict(pretrained_state, strict=True)
        # Downstream code keys off this field to tell finetuned from original
        # encoder results; clear it because the original weights are in use.
        config.FINETUNED_VQVAE_WEIGHTS = None
        print(f"Warning: Failed to load finetuned weights: {e}")
        print("Continuing with default pretrained weights...")
    else:
        print("Successfully loaded finetuned VQVAE weights!")

    return codec


def compute_latent_tracer_loss(codec, images, quantized_states, device, lr=0.01, iters=100):
    """Refine each image's latent by gradient descent and report the final error.

    For every image, a copy of its latent from ``quantized_states`` is
    optimized with Adam so that decoding it (``codec.dec.post_quant_conv``
    followed by ``codec.dec.decoder``) matches the image.  Both the decoded
    output and the target are compared in [0, 1] pixel range.  The learning
    rate is halved at every 50th iteration after the first.

    Gradients are taken with respect to the latent only, so no ``.grad`` is
    accumulated on the codec's parameters.  Grad mode is enabled internally,
    so the function also works when called under ``torch.no_grad()``.

    Args:
        codec: Object exposing ``dec.post_quant_conv`` and ``dec.decoder``.
        images: Image batch; rescaled here from [-1, 1] to [0, 1].
        quantized_states: Starting latents, indexed along dim 0 in step with
            ``images``.  Not modified.
        device: Device on which the optimization runs.
        lr: Initial Adam learning rate.
        iters: Number of optimization steps per image.

    Returns:
        list: One scalar per image, the mean squared error between the decoded
        optimized latent and the image after the last step.
    """
    # Targets in [0, 1], placed on the working device once for the whole batch.
    image_batch = torch.clamp((images + 1.0) / 2.0, 0.0, 1.0).to(device)

    def _decode_unit_range(latent):
        """Decode ``latent`` and map the result from [-1, 1] to clamped [0, 1]."""
        decoded = codec.dec.decoder(codec.dec.post_quant_conv(latent))
        return torch.clamp((decoded + 1.0) / 2.0, 0.0, 1.0)

    losses = []

    for i in range(images.shape[0]):  # each image is optimized independently
        current_image = image_batch[i:i + 1]

        with torch.enable_grad():
            # Move to the device BEFORE wrapping: calling .to() on a Parameter
            # that actually changes device returns a non-leaf tensor, which
            # the optimizer refuses.
            fhat_optim = torch.nn.Parameter(
                quantized_states[i:i + 1].detach().clone().to(device)
            )
            optimizer = torch.optim.Adam([fhat_optim], lr=lr)

            for iter_idx in range(iters):
                optimizer.zero_grad()

                loss = torch.nn.functional.mse_loss(
                    _decode_unit_range(fhat_optim), current_image
                )
                # Differentiate w.r.t. the latent only; loss.backward() would
                # also fill .grad on every trainable codec parameter.
                (latent_grad,) = torch.autograd.grad(loss, [fhat_optim])
                fhat_optim.grad = latent_grad
                optimizer.step()

                # Learning rate decay every 50 iterations
                if iter_idx % 50 == 0 and iter_idx > 0:
                    for group in optimizer.param_groups:
                        group['lr'] = group['lr'] * 0.5

        # Final per-image error with the optimized latent.
        with torch.no_grad():
            final_loss = torch.nn.functional.mse_loss(
                _decode_unit_range(fhat_optim), current_image, reduction='none'
            )
            final_loss = final_loss.view(final_loss.size(0), -1).mean(dim=1).cpu().numpy()
            losses.extend(final_loss)

    return losses


@torch.no_grad()
def compute_all_losses_optimized(codec, images, device):
    """Compute per-image single-pass and double-pass losses for one batch.

    "Single" losses compare the input with its reconstruction; "double"
    losses compare a reconstruction with the reconstruction of itself.  Both
    are computed with and without codebook quantization, and the codebook loss
    is the error between a continuous latent and its codebook entry.

    Args:
        codec: Object exposing ``enc.encoder``, ``enc.quant_conv``,
            ``enc.quantize`` (``only_get_indices`` / ``get_codebook_entry``),
            ``dec.post_quant_conv`` and ``dec.decoder``.
        images: Image batch; moved to ``device`` before encoding.
        device: Device the batch is moved to.

    Returns:
        dict: Keys ``recon_single``, ``recon_double``,
        ``recon_no_quant_single``, ``recon_no_quant_double``,
        ``codebook_single`` and ``codebook_double``; each value is a NumPy
        array with one mean-squared-error entry per image.
    """
    enc, dec = codec.enc, codec.dec

    def to_latent(pixels):
        """Encoder followed by the pre-quantization conv."""
        return enc.quant_conv(enc.encoder(pixels))

    def snap_to_codebook(latent, batch_size):
        """Replace every latent vector by its selected codebook entry."""
        codes = enc.quantize.only_get_indices(latent).view(batch_size, -1)
        grid = (batch_size, latent.shape[-2], latent.shape[-1])
        return enc.quantize.get_codebook_entry(codes.view(-1), shape=grid)

    def to_pixels(latent):
        """Post-quantization conv, decoder, then clamp to [-1, 1]."""
        return torch.clamp(dec.decoder(dec.post_quant_conv(latent)), -1., 1.)

    def per_image_mse(prediction, target):
        """Mean squared error reduced over everything but the batch dim."""
        err = torch.nn.functional.mse_loss(prediction, target, reduction='none')
        return err.view(err.size(0), -1).mean(dim=1).cpu().numpy()

    images = images.to(device)

    # First pass: one encode feeds both the quantized and the continuous path.
    latent_1 = to_latent(images)
    codebook_1 = snap_to_codebook(latent_1, images.shape[0])
    recon_q = to_pixels(codebook_1)
    recon_nq = to_pixels(latent_1)

    single_q = per_image_mse(recon_q, images)
    single_nq = per_image_mse(recon_nq, images)
    single_cb = per_image_mse(codebook_1, latent_1)

    # Second pass over the quantized reconstruction.
    latent_2 = to_latent(recon_q)
    codebook_2 = snap_to_codebook(latent_2, recon_q.shape[0])
    recon_qq = to_pixels(codebook_2)

    # Second pass over the un-quantized reconstruction.
    recon_nqnq = to_pixels(to_latent(recon_nq))

    double_q = per_image_mse(recon_qq, recon_q)
    double_nq = per_image_mse(recon_nqnq, recon_nq)
    double_cb = per_image_mse(codebook_2, latent_2)

    return {
        'recon_single': single_q,
        'recon_double': double_q,
        'recon_no_quant_single': single_nq,
        'recon_no_quant_double': double_nq,
        'codebook_single': single_cb,
        'codebook_double': double_cb
    }


def _ratio_with_epsilon(single, double, epsilon):
    """Return element-wise ``single / (double + epsilon)`` as a float64 array."""
    numerator = np.asarray(single, dtype=np.float64)
    denominator = np.asarray(double, dtype=np.float64) + epsilon
    return numerator / denominator


def process_images_folder(codec, device, folder_path, config, label_name):
    """Compute every per-image loss and derived metric for one image folder.

    Images are read through ``CustomImageDataset`` (limited to the first
    ``config.MAX_IMAGES`` files when that limit is positive) and processed in
    batches of ``config.BATCH_SIZE``.  Sample reconstructions are saved for
    the first batch only.

    Derived metrics per image:
        * ``*_ratios``: ``single / (double + config.EPSILON)``
        * ``combined1``: ``codebook_single * recon_no_quant_ratio``
        * ``combined2``: ``codebook_ratio * recon_no_quant_ratio``

    NaN inputs propagate to NaN outputs.  When the latent tracer is disabled,
    ``latent_tracer`` is filled with ``0.0`` placeholders.

    Args:
        codec: VQ codec passed through to the loss helpers.
        device: Device on which the codec is run.
        folder_path: Directory containing the images.
        config: Object providing ``MAX_IMAGES``, ``BATCH_SIZE``,
            ``USE_LATENT_TRACER``, ``OUTPUT_DIR``, ``EPSILON`` and the fields
            read by ``create_transform``.
        label_name: Name used in progress messages and sample output.

    Returns:
        dict: Lists with one entry per image under the keys ``recon_single``,
        ``recon_double``, ``recon_ratios``, ``recon_no_quant``,
        ``recon_no_quant_double``, ``recon_no_quant_ratios``,
        ``codebook_single``, ``codebook_double``, ``codebook_ratios``,
        ``latent_tracer``, ``combined1``, ``combined2`` and ``paths``.
    """
    results = {
        'recon_single': [], 'recon_double': [], 'recon_ratios': [],
        'recon_no_quant': [], 'recon_no_quant_double': [], 'recon_no_quant_ratios': [],
        'codebook_single': [], 'codebook_double': [], 'codebook_ratios': [],
        'latent_tracer': [],
        'combined1': [], 'combined2': [],
        'paths': [],
    }

    dataset = CustomImageDataset(folder_path, transform=create_transform(config))
    if config.MAX_IMAGES > 0 and len(dataset) > config.MAX_IMAGES:
        dataset.image_files = dataset.image_files[:config.MAX_IMAGES]

    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=0, drop_last=False)
    print(f"Processing {len(dataset)} {label_name} images...")

    if config.USE_LATENT_TRACER:
        print(f"Latent tracer computation is ENABLED for {label_name}")
    else:
        print(f"Latent tracer computation is DISABLED for {label_name}")

    for batch_idx, (batch_images, batch_paths) in enumerate(tqdm(dataloader, desc=f"Processing {label_name}")):
        batch_losses = compute_all_losses_optimized(codec, batch_images, device)

        if config.USE_LATENT_TRACER:
            # Only the starting latents are needed from the encoder, so run it
            # without autograd; otherwise a graph over the whole batch is
            # built and held for nothing.
            with torch.no_grad():
                f_continuous = codec.enc.quant_conv(codec.enc.encoder(batch_images.to(device)))
                indices = codec.enc.quantize.only_get_indices(f_continuous).view(batch_images.shape[0], -1)
                z_quantized = codec.enc.quantize.get_codebook_entry(
                    indices.view(-1),
                    shape=(batch_images.shape[0], f_continuous.shape[-2], f_continuous.shape[-1]),
                )

            # The latent optimization itself needs gradients.
            with torch.enable_grad():
                batch_losses['latent_tracer'] = compute_latent_tracer_loss(
                    codec, batch_images, z_quantized, device
                )
        else:
            # Placeholder values so every result list keeps the same length.
            batch_losses['latent_tracer'] = [0.0] * batch_images.shape[0]

        # Save sample reconstructions for the first batch only
        if batch_idx == 0:
            save_sample_reconstructions(codec, batch_images, batch_paths, config.OUTPUT_DIR, label_name)

        results['recon_single'].extend(batch_losses['recon_single'])
        results['recon_double'].extend(batch_losses['recon_double'])
        results['recon_no_quant'].extend(batch_losses['recon_no_quant_single'])
        results['recon_no_quant_double'].extend(batch_losses['recon_no_quant_double'])
        results['codebook_single'].extend(batch_losses['codebook_single'])
        results['codebook_double'].extend(batch_losses['codebook_double'])
        results['latent_tracer'].extend(batch_losses['latent_tracer'])
        results['paths'].extend(batch_paths)

        # Each ratio is computed once per batch and reused for the combined
        # metrics; epsilon keeps the denominator away from zero.
        eps = config.EPSILON
        recon_ratio = _ratio_with_epsilon(batch_losses['recon_single'], batch_losses['recon_double'], eps)
        recon_nq_ratio = _ratio_with_epsilon(
            batch_losses['recon_no_quant_single'], batch_losses['recon_no_quant_double'], eps
        )
        codebook_ratio = _ratio_with_epsilon(batch_losses['codebook_single'], batch_losses['codebook_double'], eps)

        results['recon_ratios'].extend(recon_ratio)
        results['recon_no_quant_ratios'].extend(recon_nq_ratio)
        results['codebook_ratios'].extend(codebook_ratio)

        # NaN in either factor yields NaN in the product, so no explicit
        # NaN branch is needed.
        codebook_single = np.asarray(batch_losses['codebook_single'], dtype=np.float64)
        results['combined1'].extend(codebook_single * recon_nq_ratio)
        results['combined2'].extend(codebook_ratio * recon_nq_ratio)

    return results


def calculate_detection_metrics(real_scores, generated_scores, metric_name):
    """Compute AUC and TPR at 1% FPR for detecting generated images.

    Generated images are the positive class (1), real images the negative
    class (0).  Each metric value is negated before scoring, so a lower
    loss gives a higher detection score, i.e. "more likely generated".
    Non-finite values are dropped first.

    Args:
        real_scores: Metric values for real images.
        generated_scores: Metric values for generated images.
        metric_name: Name used in warning messages.

    Returns:
        dict: ``auc``, ``tpr_at_1_fpr``, ``fpr`` and ``tpr``.  When the
        metrics cannot be computed, ``fpr``/``tpr`` are empty lists,
        ``tpr_at_1_fpr`` is 0.0 and ``auc`` is 0.0 (0.5 when the scores have
        no variance).
    """
    def _degenerate(auc_value):
        """Result returned whenever no ROC curve is available."""
        return {'auc': auc_value, 'tpr_at_1_fpr': 0.0, 'fpr': [], 'tpr': []}

    def _keep_finite(values):
        """Drop NaN and infinite entries."""
        return [v for v in values if np.isfinite(v)]

    real_clean = _keep_finite(real_scores)
    gen_clean = _keep_finite(generated_scores)

    if not real_clean or not gen_clean:
        print(f"Warning: No valid scores for {metric_name}")
        return _degenerate(0.0)

    # Negate so that a low loss becomes a high "generated" score.
    negated_real = [-v for v in real_clean]
    negated_gen = [-v for v in gen_clean]

    y_true = np.array([0] * len(negated_real) + [1] * len(negated_gen))
    y_scores = np.array(negated_real + negated_gen)

    if len(np.unique(y_true)) < 2:
        print(f"Warning: Only one class present for {metric_name}")
        return _degenerate(0.0)

    if np.std(y_scores) < 1e-10:
        print(f"Warning: No score variance for {metric_name}")
        return _degenerate(0.5)

    try:
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)

        # TPR at 1% FPR: put the threshold at the real-score order statistic
        # that leaves at most 1% of real images strictly above it.
        target_fpr = 0.01
        negatives = y_scores[y_true == 0]
        n_real = len(negated_real)
        cut = int(n_real * (1 - target_fpr))

        tpr_at_1_fpr = 0.0
        if cut < len(negatives):
            threshold = np.sort(negatives)[cut]
            positives = y_scores[y_true == 1]
            n_gen = len(positives)
            if n_gen > 0:
                # Fraction of generated images scoring strictly above it.
                tpr_at_1_fpr = np.sum(positives > threshold) / n_gen

        return {'auc': roc_auc, 'tpr_at_1_fpr': tpr_at_1_fpr, 'fpr': fpr, 'tpr': tpr}

    except ValueError as e:
        print(f"Warning: Could not calculate metrics for {metric_name}: {e}")
        return _degenerate(0.0)


def save_sample_reconstructions(codec, batch_images, batch_paths, output_dir, label_name):
    """Write reconstructions of up to 8 images from a batch for visual checks.

    For each sampled image this saves PNGs of the original, the single and
    double reconstructions (with and without quantization) and two absolute
    difference maps, plus one CSV with the per-sample loss values.  Files go
    to ``<output_dir>/sample_reconstructions_<label_name>``.

    Args:
        codec: VQ codec; its first parameter's device is used for the batch.
        batch_images: Image batch in the codec's input range.
        batch_paths: Source paths of the batch (not used here).
        output_dir: Parent directory for the sample folder.
        label_name: Suffix of the sample folder name.
    """
    sample_dir = os.path.join(output_dir, f"sample_reconstructions_{label_name}")
    os.makedirs(sample_dir, exist_ok=True)

    device = next(codec.parameters()).device
    batch_images = batch_images.to(device)

    def to_latent(pixels):
        """Encoder followed by the pre-quantization conv."""
        return codec.enc.quant_conv(codec.enc.encoder(pixels))

    def snap_to_codebook(latent, batch_size):
        """Replace every latent vector by its selected codebook entry."""
        codes = codec.enc.quantize.only_get_indices(latent).view(batch_size, -1)
        grid = (batch_size, latent.shape[-2], latent.shape[-1])
        return codec.enc.quantize.get_codebook_entry(codes.view(-1), shape=grid)

    def to_pixels(latent):
        """Post-quantization conv, decoder, then clamp to [-1, 1]."""
        return torch.clamp(codec.dec.decoder(codec.dec.post_quant_conv(latent)), -1., 1.)

    with torch.no_grad():
        batch_losses = compute_all_losses_optimized(codec, batch_images, device)

        # Same pipeline as compute_all_losses_optimized, keeping the images.
        latent = to_latent(batch_images)
        recon_single = to_pixels(snap_to_codebook(latent, batch_images.shape[0]))
        recon_no_quant_single = to_pixels(latent)
        recon_double = to_pixels(snap_to_codebook(to_latent(recon_single), recon_single.shape[0]))
        recon_no_quant_double = to_pixels(to_latent(recon_no_quant_single))

    def as_unit_array(tensor):
        """Map a [-1, 1] tensor to a clamped [0, 1] NumPy array on the CPU."""
        return ((tensor.cpu() + 1.0) / 2.0).clamp(0, 1).numpy()

    def write_png(chw, stem):
        """Save a channels-first [0, 1] array as an 8-bit PNG."""
        hwc = (chw.transpose(1, 2, 0) * 255).astype(np.uint8)
        PImage.fromarray(hwc).save(f"{sample_dir}/{stem}.png")

    originals = as_unit_array(batch_images)
    # File-name suffix -> image stack, in the order the files are written.
    stages = [
        ("01_original", originals),
        ("02_single_quant", as_unit_array(recon_single)),
        ("03_double_quant", as_unit_array(recon_double)),
        ("04_single_no_quant", as_unit_array(recon_no_quant_single)),
        ("05_double_no_quant", as_unit_array(recon_no_quant_double)),
    ]
    # Difference maps compare the original with each single reconstruction.
    diff_maps = [
        ("06_diff_single", stages[1][1]),
        ("07_diff_no_quant", stages[3][1]),
    ]
    loss_columns = (
        'recon_single', 'recon_double',
        'recon_no_quant_single', 'recon_no_quant_double',
        'codebook_single', 'codebook_double',
    )

    sample_losses = []
    for i in range(min(8, batch_images.shape[0])):
        for suffix, stack in stages:
            write_png(stack[i], f"sample_{i:02d}_{suffix}")
        for suffix, stack in diff_maps:
            write_png(np.abs(originals[i] - stack[i]), f"sample_{i:02d}_{suffix}")

        row = {'sample_idx': i}
        for column in loss_columns:
            values = batch_losses[column]
            row[column] = values[i] if i < len(values) else np.nan
        sample_losses.append(row)

    # Save loss values as CSV
    pd.DataFrame(sample_losses).to_csv(f"{sample_dir}/sample_loss_values.csv", index=False)

    print(f"Sample reconstructions saved to {sample_dir}/")


def create_visualizations(real_data, generated_data, real_dataset_name, output_dir):
    """Plot a 2x3 grid of loss and ratio histograms for one comparison.

    The top row shows the raw losses, the bottom row the single/double ratios
    (non-finite ratios removed; a ratio panel is left empty when either side
    has no finite values).  Each panel overlays the real dataset and
    VQ-Diffusion as density histograms.

    Args:
        real_data: Results dict for the real dataset.
        generated_data: Results dict for the generated images.
        real_dataset_name: Display name; also used to derive the file name.
        output_dir: Directory the PNG is written to.

    Returns:
        str: Path of the saved figure.
    """
    plt.style.use('default')
    sns.set_palette("husl")

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'{real_dataset_name} vs VQ-Diffusion: Loss Distributions', fontsize=16, fontweight='bold')

    def overlay(ax, real_values, gen_values):
        """Draw the two density histograms on ``ax``."""
        ax.hist(real_values, bins=50, alpha=0.7, label=f'{real_dataset_name}', density=True)
        ax.hist(gen_values, bins=50, alpha=0.7, label='VQ-Diffusion', density=True)

    def decorate(ax, title, x_label):
        """Apply the title, axis labels, legend and grid shared by all panels."""
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel('Density')
        ax.legend()
        ax.grid(True, alpha=0.3)

    def finite_only(values):
        """Drop NaN and infinite entries."""
        return [r for r in values if not np.isnan(r) and np.isfinite(r)]

    # Top row: raw losses, plotted as-is.
    loss_panels = (
        ('recon_single', 'Reconstruction Loss (w/ Quantization)'),
        ('recon_no_quant', 'Reconstruction Loss (No Quantization)'),
        ('codebook_single', 'Codebook Loss'),
    )
    for column, (key, title) in enumerate(loss_panels):
        panel = axes[0, column]
        overlay(panel, real_data[key], generated_data[key])
        decorate(panel, title, 'Loss')

    # Bottom row: ratios, plotted only when both sides have finite values.
    ratio_panels = (
        ('recon_ratios', 'Reconstruction Ratio (Single/Double)'),
        ('recon_no_quant_ratios', 'No-Quantization Ratio (Single/Double)'),
        ('codebook_ratios', 'Codebook Loss Ratio'),
    )
    for column, (key, title) in enumerate(ratio_panels):
        panel = axes[1, column]
        real_clean = finite_only(real_data[key])
        gen_clean = finite_only(generated_data[key])
        if real_clean and gen_clean:
            overlay(panel, real_clean, gen_clean)
        decorate(panel, title, 'Ratio')

    plt.tight_layout()

    # File name: dataset name lower-cased, spaces -> underscores, parentheses removed.
    slug = real_dataset_name.replace(' ', '_').replace('(', '').replace(')', '').lower()
    plot_filename = os.path.join(output_dir, f"{slug}_loss_distributions.png")
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close()

    return plot_filename

def create_roc_curves(real_data, generated_data, real_dataset_name, output_dir, config=None):
    """Draw one ROC panel per detection metric and save the grid as a PNG.

    Args:
        real_data: Mapping from score name to the scores of the real dataset.
        generated_data: Mapping with the same keys for the generated images.
        real_dataset_name: Dataset label used in the title and the file name.
        output_dir: Directory the PNG is written into (not created here).
        config: Optional settings object. The latent tracer panel is drawn
            when ``config`` is None or ``config.USE_LATENT_TRACER`` is truthy.

    Returns:
        Path of the saved ``*_roc_curves.png`` file.
    """
    panels = [
        ('recon_single', 'Reconstruction Loss (w/ Quantization)'),
        ('recon_no_quant', 'Reconstruction Loss (No Quantization)'),
        ('codebook_single', 'Codebook Loss'),
        ('recon_ratio', 'Reconstruction Ratio (Single/Double)'),
        ('recon_no_quant_ratio', 'No-Quantization Ratio (Single/Double)'),
        ('codebook_ratio', 'Codebook Loss Ratio'),
        ('combined1', 'Combined1 (Codebook × No-Quant Ratio)'),
        ('combined2', 'Combined2 (Codebook Ratio × No-Quant Ratio)'),
    ]
    if config is None or config.USE_LATENT_TRACER:
        panels.append(('latent_tracer', 'Latent Tracer Loss'))

    # Three ratio scores are stored under pluralised keys; every other metric
    # is stored under its own name.
    score_keys = {
        'recon_ratio': 'recon_ratios',
        'recon_no_quant_ratio': 'recon_no_quant_ratios',
        'codebook_ratio': 'codebook_ratios',
    }

    # Pick the smallest predefined grid that holds all panels.
    panel_count = len(panels)
    if panel_count <= 6:
        n_rows, n_cols, fig_size = 2, 3, (18, 10)
    elif panel_count <= 8:
        n_rows, n_cols, fig_size = 2, 4, (24, 10)
    else:
        n_rows, n_cols, fig_size = 3, 4, (24, 15)
    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=fig_size)
    panel_axes = axes_grid.flatten()

    fig.suptitle(f'{real_dataset_name} vs VQ-Diffusion: ROC Curves', fontsize=16, fontweight='bold')

    for ax, (metric_key, metric_title) in zip(panel_axes, panels):
        score_key = score_keys.get(metric_key, metric_key)
        real_scores = real_data[score_key]
        gen_scores = generated_data[score_key]

        metrics = calculate_detection_metrics(real_scores, gen_scores, metric_title)

        has_curve = len(metrics['fpr']) > 0 and len(metrics['tpr']) > 0
        if not has_curve:
            ax.text(0.5, 0.5, 'No valid data', ha='center', va='center',
                    transform=ax.transAxes)
        else:
            ax.plot(metrics['fpr'], metrics['tpr'], linewidth=2,
                    label=f'AUC = {metrics["auc"]:.3f}')
            ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Random Classifier')
            ax.axvline(x=0.01, color='red', linestyle=':', alpha=0.7, label='1% FPR')

            # Highlight the operating point at 1% false positive rate.
            tpr_at_1_fpr = metrics['tpr_at_1_fpr']
            ax.plot(0.01, tpr_at_1_fpr, 'ro', markersize=8,
                    label=f'TPR@1%FPR = {tpr_at_1_fpr:.3f}')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(metric_title)
        ax.legend(loc="lower right", fontsize=8)
        ax.grid(True, alpha=0.3)

    # Grid cells beyond the last metric stay empty; hide them.
    for spare_ax in panel_axes[panel_count:]:
        spare_ax.set_visible(False)

    plt.tight_layout()

    file_stem = real_dataset_name.replace(' ', '_').replace('(', '').replace(')', '').lower()
    roc_filename = os.path.join(output_dir, f"{file_stem}_roc_curves.png")
    plt.savefig(roc_filename, dpi=300, bbox_inches='tight')
    plt.close()

    return roc_filename

def create_combined_roc_curve(real_data, generated_data, real_dataset_name, output_dir, config=None):
    """Overlay the ROC curves of every detection metric on one axes.

    Args:
        real_data: Mapping from score name to the scores of the real dataset.
        generated_data: Mapping with the same keys for the generated images.
        real_dataset_name: Dataset label used in the title and the file name.
        output_dir: Directory the PNG is written into (not created here).
        config: Optional settings object. The latent tracer curve is drawn
            when ``config`` is None or ``config.USE_LATENT_TRACER`` is truthy.

    Returns:
        Path of the saved ``*_combined_roc_curve.png`` file.
    """
    # One fixed colour per metric, in the order the metrics are listed below.
    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#bcbd22', '#17becf']

    curves = [
        ('recon_single', 'Reconstruction Loss (w/ Quantization)'),
        ('recon_no_quant', 'Reconstruction Loss (No Quantization)'),
        ('codebook_single', 'Codebook Loss'),
        ('recon_ratio', 'Reconstruction Ratio (Single/Double)'),
        ('recon_no_quant_ratio', 'No-Quantization Ratio (Single/Double)'),
        ('codebook_ratio', 'Codebook Loss Ratio'),
        ('combined1', 'Combined1 (Codebook × No-Quant Ratio)'),
        ('combined2', 'Combined2 (Codebook Ratio × No-Quant Ratio)'),
    ]
    if config is None or config.USE_LATENT_TRACER:
        curves.append(('latent_tracer', 'Latent Tracer Loss'))

    # Three ratio scores are stored under pluralised keys; every other metric
    # is stored under its own name.
    score_keys = {
        'recon_ratio': 'recon_ratios',
        'recon_no_quant_ratio': 'recon_no_quant_ratios',
        'codebook_ratio': 'codebook_ratios',
    }

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Random Classifier')

    for (metric_key, metric_title), color in zip(curves, palette):
        score_key = score_keys.get(metric_key, metric_key)
        real_scores = real_data[score_key]
        gen_scores = generated_data[score_key]

        metrics = calculate_detection_metrics(real_scores, gen_scores, metric_title)

        # Metrics without a usable curve are left out of the plot and legend.
        if len(metrics['fpr']) == 0 or len(metrics['tpr']) == 0:
            continue
        ax.plot(metrics['fpr'], metrics['tpr'], linewidth=2, color=color,
                label=f'{metric_title} (AUC = {metrics["auc"]:.3f})')

    # Reference line at the 1% false positive rate operating point.
    ax.axvline(x=0.01, color='red', linestyle=':', alpha=0.7, label='1% FPR')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title(f'{real_dataset_name} vs VQ-Diffusion: Combined ROC Curves', fontsize=14, fontweight='bold')
    # The legend sits outside the axes, to the right of the plot area.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    file_stem = real_dataset_name.replace(' ', '_').replace('(', '').replace(')', '').lower()
    combined_roc_filename = os.path.join(output_dir, f"{file_stem}_combined_roc_curve.png")
    fig.savefig(combined_roc_filename, dpi=300, bbox_inches='tight')
    plt.close(fig)

    return combined_roc_filename

def create_overall_combined_roc_curve(results_summary, batch_output_dir, config=None):
    """Plot, for every metric, the ROC curves of all experiments in one grid.

    Args:
        results_summary: Sequence of result dicts, each providing
            ``'dataset_name'`` and a ``'metrics'`` mapping whose entries carry
            ``'fpr'``, ``'tpr'`` and ``'auc'``.
        batch_output_dir: Directory that receives
            ``overall_combined_roc_curves.png`` (not created here).
        config: Optional settings object. The latent tracer panel is drawn
            when ``config`` is None or ``config.USE_LATENT_TRACER`` is truthy.
    """
    panels = [
        ('recon_single', 'Reconstruction Loss (w/ Quantization)'),
        ('recon_no_quant', 'Reconstruction Loss (No Quantization)'),
        ('codebook_single', 'Codebook Loss'),
        ('recon_ratio', 'Reconstruction Ratio (Single/Double)'),
        ('recon_no_quant_ratio', 'No-Quantization Ratio (Single/Double)'),
        ('codebook_ratio', 'Codebook Loss Ratio'),
        ('combined1', 'Combined1 (Codebook × No-Quant Ratio)'),
        ('combined2', 'Combined2 (Codebook Ratio × No-Quant Ratio)'),
    ]
    if config is None or config.USE_LATENT_TRACER:
        panels.append(('latent_tracer', 'Latent Tracer Loss'))

    # Pick the smallest predefined grid that holds all panels.
    panel_count = len(panels)
    if panel_count <= 6:
        n_rows, n_cols, fig_size = 2, 3, (18, 10)
    elif panel_count <= 8:
        n_rows, n_cols, fig_size = 2, 4, (24, 10)
    else:
        n_rows, n_cols, fig_size = 3, 3, (27, 12)
    fig, axes_grid = plt.subplots(n_rows, n_cols, figsize=fig_size)
    panel_axes = axes_grid.flatten()

    fig.suptitle('Overall Combined ROC Curves: All Experiments', fontsize=16, fontweight='bold')

    # One colour per experiment, sampled evenly from the tab10 colormap.
    experiment_colors = plt.cm.tab10(np.linspace(0, 1, len(results_summary)))

    for ax, (metric_key, metric_title) in zip(panel_axes, panels):
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Random Classifier')

        for curve_color, result in zip(experiment_colors, results_summary):
            dataset_name = result['dataset_name']
            metrics = result['metrics']

            # Skip experiments that lack this metric or have an empty curve.
            if metric_key not in metrics:
                continue
            curve = metrics[metric_key]
            if len(curve['fpr']) == 0:
                continue

            ax.plot(curve['fpr'], curve['tpr'],
                    color=curve_color, linewidth=2,
                    label=f'{dataset_name} (AUC={curve["auc"]:.3f})')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(metric_title)
        ax.legend(loc="lower right", fontsize=8)
        ax.grid(True, alpha=0.3)

    # Grid cells beyond the last metric stay empty; hide them.
    for spare_ax in panel_axes[panel_count:]:
        spare_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{batch_output_dir}/overall_combined_roc_curves.png', dpi=300, bbox_inches='tight')
    plt.close()


def create_per_method_roc_curves(results_summary, batch_output_dir, config=None):
    """Write one ROC figure per detection method, overlaying every dataset.

    Args:
        results_summary: Sequence of result dicts, each providing
            ``'dataset_name'`` and a ``'metrics'`` mapping whose entries carry
            ``'fpr'``, ``'tpr'``, ``'auc'`` and ``'tpr_at_1_fpr'``.
        batch_output_dir: Directory that receives the
            ``roc_curve_<method>_all_datasets.png`` files (not created here).
        config: Optional settings object. The latent tracer figure is written
            when ``config`` is None or ``config.USE_LATENT_TRACER`` is truthy.

    Returns:
        Always True.
    """
    methods = [
        ('recon_single', 'Reconstruction Loss (w/ Quantization)'),
        ('recon_no_quant', 'Reconstruction Loss (No Quantization)'),
        ('codebook_single', 'Codebook Loss'),
        ('recon_ratio', 'Reconstruction Ratio (Single/Double)'),
        ('recon_no_quant_ratio', 'No-Quantization Ratio (Single/Double)'),
        ('codebook_ratio', 'Codebook Loss Ratio'),
        ('combined1', 'Combined1 (Codebook × No-Quant Ratio)'),
        ('combined2', 'Combined2 (Codebook Ratio × No-Quant Ratio)'),
    ]
    if config is None or config.USE_LATENT_TRACER:
        methods.append(('latent_tracer', 'Latent Tracer Loss'))

    # One colour per dataset, sampled evenly from the tab10 colormap and
    # reused across all figures.
    dataset_colors = plt.cm.tab10(np.linspace(0, 1, len(results_summary)))

    for metric_key, metric_title in methods:
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, linewidth=2, label='Random Classifier')

        for curve_color, result in zip(dataset_colors, results_summary):
            dataset_name = result['dataset_name']
            metrics = result['metrics']

            # Skip datasets that lack this metric or have an empty curve.
            if metric_key not in metrics:
                continue
            curve = metrics[metric_key]
            if len(curve['fpr']) == 0:
                continue

            auc_score = curve['auc']
            tpr_1pct = curve['tpr_at_1_fpr']
            ax.plot(curve['fpr'], curve['tpr'],
                    color=curve_color, linewidth=3,
                    label=f'{dataset_name} (AUC={auc_score:.3f}, TPR@1%FPR={tpr_1pct:.3f})')

        # Reference line at the 1% false positive rate operating point.
        ax.axvline(x=0.01, color='red', linestyle=':', alpha=0.7, linewidth=2, label='1% FPR')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title(f'{metric_title}\nROC Curves for All Datasets', fontsize=14, fontweight='bold')
        ax.legend(loc="lower right", fontsize=10)
        ax.grid(True, alpha=0.3)

        # File names use hyphens in place of the underscores in the metric key.
        safe_filename = metric_key.replace('_', '-')
        filename = os.path.join(batch_output_dir, f'roc_curve_{safe_filename}_all_datasets.png')
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close(fig)

        print(f"Saved per-method ROC curve: {filename}")

    return True


def create_performance_heatmaps(results_summary, batch_output_dir, config=None):
    """Create annotated heatmaps of AUC and TPR@1%FPR (datasets x methods).

    Args:
        results_summary: Sequence of result dicts, each providing
            ``'dataset_name'`` and a ``'metrics'`` mapping whose entries carry
            ``'auc'`` and ``'tpr_at_1_fpr'``.
        batch_output_dir: Directory that receives
            ``auc_performance_heatmap.png`` and
            ``tpr_1pct_fpr_performance_heatmap.png`` (not created here).
        config: Optional settings object. The latent tracer column is included
            when ``config`` is None or ``config.USE_LATENT_TRACER`` is truthy.

    Returns:
        True when both heatmaps were written, False when ``results_summary``
        is empty and nothing was drawn.
    """
    if not results_summary:
        print("No results available; skipping performance heatmaps.")
        return False

    # (metric key, column label) pairs; labels contain line breaks so they
    # fit under the heatmap columns.
    metric_names = [
        ('recon_single', 'Recon\n(w/ Quant)'),
        ('recon_no_quant', 'Recon\n(No Quant)'),
        ('codebook_single', 'Codebook'),
        ('recon_ratio', 'Recon\nRatio'),
        ('recon_no_quant_ratio', 'NoQuant\nRatio'),
        ('codebook_ratio', 'Codebook\nRatio'),
        ('combined1', 'Combined1\n(CB×NQR)'),
        ('combined2', 'Combined2\n(CBR×NQR)'),
    ]
    if config is None or config.USE_LATENT_TRACER:
        metric_names.append(('latent_tracer', 'Latent\nTracer'))

    dataset_names = [result['dataset_name'] for result in results_summary]
    method_labels = [label for _, label in metric_names]

    # Cells stay NaN (and therefore unannotated) when a result lacks a metric.
    matrix_shape = (len(dataset_names), len(metric_names))
    auc_matrix = np.full(matrix_shape, np.nan)
    tpr_matrix = np.full(matrix_shape, np.nan)
    for i, result in enumerate(results_summary):
        metrics = result['metrics']
        for j, (metric_key, _) in enumerate(metric_names):
            if metric_key in metrics:
                auc_matrix[i, j] = metrics[metric_key]['auc']
                tpr_matrix[i, j] = metrics[metric_key]['tpr_at_1_fpr']

    auc_heatmap_file = os.path.join(batch_output_dir, 'auc_performance_heatmap.png')
    _draw_performance_heatmap(
        auc_matrix, dataset_names, method_labels,
        value_range=(0.5, 1.0),
        colorbar_label='AUC Score',
        title='AUC Performance Heatmap\nVQ-Diffusion Generated Image Detection',
        out_path=auc_heatmap_file,
    )

    tpr_heatmap_file = os.path.join(batch_output_dir, 'tpr_1pct_fpr_performance_heatmap.png')
    _draw_performance_heatmap(
        tpr_matrix, dataset_names, method_labels,
        value_range=(0.0, 1.0),
        colorbar_label='TPR@1%FPR',
        title='TPR@1%FPR Performance Heatmap\nVQ-Diffusion Generated Image Detection',
        out_path=tpr_heatmap_file,
    )

    print(f"Saved AUC heatmap: {auc_heatmap_file}")
    print(f"Saved TPR@1%FPR heatmap: {tpr_heatmap_file}")

    return True


def _draw_performance_heatmap(matrix, row_labels, col_labels, value_range,
                              colorbar_label, title, out_path):
    """Render ``matrix`` as one annotated 'hot' heatmap and save it to ``out_path``."""
    vmin, vmax = value_range
    # The 'hot' colormap runs from black through red to yellow/white, so cells
    # in the upper half of the colour range need dark text to stay readable.
    dark_text_from = (vmin + vmax) / 2.0

    plt.figure(figsize=(10, 8))
    im = plt.imshow(matrix, cmap='hot', aspect='auto', vmin=vmin, vmax=vmax)

    cbar = plt.colorbar(im, shrink=0.8)
    cbar.set_label(colorbar_label, fontsize=12, fontweight='bold')

    n_rows, n_cols = matrix.shape
    plt.xticks(range(n_cols), col_labels, fontsize=10, rotation=0)
    plt.yticks(range(n_rows), row_labels, fontsize=10)
    plt.xlabel('Detection Method', fontsize=12, fontweight='bold')
    plt.ylabel('Dataset', fontsize=12, fontweight='bold')
    plt.title(title, fontsize=14, fontweight='bold', pad=20)

    # Write each value into its cell; NaN cells (missing metrics) stay blank.
    for (i, j), value in np.ndenumerate(matrix):
        if np.isnan(value):
            continue
        text_color = 'white' if value < dark_text_from else 'black'
        plt.text(j, i, f'{value:.3f}',
                 ha='center', va='center', fontsize=9,
                 fontweight='bold', color=text_color)

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()


def run_single_comparison_with_precomputed(dataset_name, dataset_path, generated_data, output_dir, codec, device, config):
    """Compare one real dataset against already-computed generated-image scores.

    Processes the real images, saves their losses, computes detection metrics
    for every score, writes the per-dataset plots and prints a summary.

    Args:
        dataset_name: Label of the real dataset; also used to derive the
            saved-losses tag and the per-dataset output subdirectory.
        dataset_path: Folder passed to ``process_images_folder``.
        generated_data: Score mapping of the generated images, computed once
            by the caller.
        output_dir: Parent directory; a subdirectory named after the dataset
            is created inside it.
        codec: Codec forwarded to ``process_images_folder``.
        device: Device forwarded to ``process_images_folder``.
        config: Settings object; ``USE_LATENT_TRACER`` decides whether the
            latent tracer metric is computed or replaced by a placeholder.

    Returns:
        A dict with ``'dataset_name'``, ``'metrics'``, ``'real_data'`` and
        ``'generated_data'``, or None when no real image was processed.
    """
    banner = '=' * 50
    print(f"\n{banner}")
    print(f"Processing {dataset_name}")
    print(f"{banner}")

    real_data = process_images_folder(codec, device, dataset_path, config, f"real {dataset_name}")

    if not real_data['recon_single']:
        print(f"No valid real images processed for {dataset_name}")
        return None

    # Persist the real-data losses before doing anything else with them.
    loss_tag = dataset_name.lower().replace(' ', '_').replace('(', '').replace(')', '')
    save_final_losses(real_data, loss_tag, config)

    individual_output_dir = os.path.join(output_dir, dataset_name.lower().replace(' ', '_'))
    os.makedirs(individual_output_dir, exist_ok=True)

    def detect(score_key, label):
        # Real scores are passed first, generated scores second.
        return calculate_detection_metrics(real_data[score_key], generated_data[score_key], label)

    latent_tracer_enabled = config.USE_LATENT_TRACER

    # Reconstruction and codebook metrics.
    metrics = {
        'recon_single': detect('recon_single', "Reconstruction Loss"),
        'recon_no_quant': detect('recon_no_quant', "Reconstruction Loss (No Quantization)"),
        'recon_ratio': detect('recon_ratios', "Reconstruction Loss Ratio"),
        'recon_no_quant_ratio': detect('recon_no_quant_ratios', "No-Quantization Ratio"),
        'codebook_single': detect('codebook_single', "Codebook Loss"),
        'codebook_ratio': detect('codebook_ratios', "Codebook Loss Ratio"),
    }

    # Latent tracer metric, or an empty placeholder so the key always exists.
    if latent_tracer_enabled:
        metrics['latent_tracer'] = detect('latent_tracer', "Latent Tracer Loss")
    else:
        metrics['latent_tracer'] = {'auc': 0.0, 'tpr_at_1_fpr': 0.0, 'fpr': [], 'tpr': []}

    # Combined metrics.
    metrics['combined1'] = detect('combined1', "Combined1 (Codebook × No-Quant Ratio)")
    metrics['combined2'] = detect('combined2', "Combined2 (Codebook Ratio × No-Quant Ratio)")

    # Per-dataset figures.
    create_visualizations(real_data, generated_data, dataset_name, individual_output_dir)
    create_roc_curves(real_data, generated_data, dataset_name, individual_output_dir, config)
    create_combined_roc_curve(real_data, generated_data, dataset_name, individual_output_dir, config)

    def report(label, metric_key):
        entry = metrics[metric_key]
        print(f"  {label} AUC: {entry['auc']:.4f}, TPR@1%FPR: {entry['tpr_at_1_fpr']:.4f}")

    print(f"\n--- RESULTS for {dataset_name} ---")
    print("RECONSTRUCTION METRICS:")
    report("Single (w/ Quant)", 'recon_single')
    report("Single (No Quant)", 'recon_no_quant')
    report("Ratio (Single/Double)", 'recon_ratio')
    report("No-Quant Ratio", 'recon_no_quant_ratio')
    print("CODEBOOK METRICS:")
    report("Single", 'codebook_single')
    report("Ratio", 'codebook_ratio')
    if latent_tracer_enabled:
        print("LATENT TRACER METRICS:")
        report("Latent Tracer", 'latent_tracer')
    else:
        print("LATENT TRACER METRICS: DISABLED")
    print("COMBINED METRICS:")
    report("Combined1 (Codebook × No-Quant Ratio)", 'combined1')
    report("Combined2 (Codebook Ratio × No-Quant Ratio)", 'combined2')

    return {
        'dataset_name': dataset_name,
        'metrics': metrics,
        'real_data': real_data,
        'generated_data': generated_data,
    }


def save_comparison_results(results_summary, batch_output_dir, config=None):
    """Write the batch summary CSVs and the cross-dataset figures.

    Args:
        results_summary: Sequence of result dicts, each providing
            ``'dataset_name'`` and a ``'metrics'`` mapping whose entries carry
            ``'auc'`` and ``'tpr_at_1_fpr'``.
        batch_output_dir: Directory that receives all output files; it is
            created if it does not exist.
        config: Optional settings object. Latent tracer columns are written
            when ``config`` is None or ``config.USE_LATENT_TRACER`` is truthy.

    Returns:
        None. With an empty ``results_summary`` nothing is written.
    """
    # An empty summary would produce a DataFrame without a 'Dataset' column
    # and fail on the column selection below, so bail out early.
    if not results_summary:
        print("No comparison results to save.")
        return

    if batch_output_dir:
        os.makedirs(batch_output_dir, exist_ok=True)

    # (CSV column prefix, metrics key) in the column order of the summary.
    column_specs = [
        ('Recon_Single', 'recon_single'),
        ('Recon_NoQuant', 'recon_no_quant'),
        ('Recon_Ratio', 'recon_ratio'),
        ('ReconNoQuant_Ratio', 'recon_no_quant_ratio'),
        ('Codebook_Single', 'codebook_single'),
        ('Codebook_Ratio', 'codebook_ratio'),
        ('Combined1', 'combined1'),
        ('Combined2', 'combined2'),
    ]
    if config is None or config.USE_LATENT_TRACER:
        column_specs.append(('LatentTracer', 'latent_tracer'))

    summary_data = []
    for result in results_summary:
        metrics = result['metrics']
        row_data = {'Dataset': result['dataset_name']}
        for prefix, metric_key in column_specs:
            row_data[f'{prefix}_AUC'] = metrics[metric_key]['auc']
            row_data[f'{prefix}_TPR@1%FPR'] = metrics[metric_key]['tpr_at_1_fpr']
        summary_data.append(row_data)

    df_summary = pd.DataFrame(summary_data)
    df_summary.to_csv(os.path.join(batch_output_dir, 'comparison_summary.csv'), index=False)

    # Split the summary into an AUC-only and a TPR-only table, indexed by dataset.
    auc_columns = [col for col in df_summary.columns if 'AUC' in col]
    tpr_columns = [col for col in df_summary.columns if 'TPR' in col]

    df_auc = df_summary[['Dataset'] + auc_columns].set_index('Dataset')
    df_tpr = df_summary[['Dataset'] + tpr_columns].set_index('Dataset')

    df_auc.to_csv(os.path.join(batch_output_dir, 'auc_results.csv'))
    df_tpr.to_csv(os.path.join(batch_output_dir, 'tpr_results.csv'))

    # Cross-dataset figures: one grid of all metrics, one figure per method,
    # and the AUC / TPR@1%FPR heatmaps.
    create_overall_combined_roc_curve(results_summary, batch_output_dir, config)
    create_per_method_roc_curves(results_summary, batch_output_dir, config)
    create_performance_heatmaps(results_summary, batch_output_dir, config)

    banner = '=' * 60
    print(f"\n{banner}")
    print("BATCH PROCESSING COMPLETE!")
    print(f"{banner}")
    print(f"Summary results saved to: {batch_output_dir}/")
    print("Files created:")
    print("  - comparison_summary.csv (all metrics)")
    print("  - auc_results.csv (AUC values only)")
    print("  - tpr_results.csv (TPR@1%FPR values only)")
    print("  - overall_combined_roc_curves.png")
    print("  - roc_curve_<method>_all_datasets.png (one per detection method)")
    print("  - auc_performance_heatmap.png")
    print("  - tpr_1pct_fpr_performance_heatmap.png")


def main():
    """Run the batch analysis: score the generated images once, then compare
    them against every configured real dataset and save the combined results.

    Reads ``DATASETS``, ``GENERATED_IMAGES_FOLDER`` and ``BATCH_OUTPUT_DIR``
    from the ``Config`` instance. Returns early, after printing an error, when
    no generated image could be processed.
    """
    print("=== VQ-DIFFUSION BATCH LOSS ANALYSIS ===")

    # Configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # The positional arguments are placeholders; the paths used below come
    # from attributes of the config object itself.
    print("Initializing configuration...")
    config = Config("batch", "", "", "")

    datasets = config.DATASETS
    print(f"Analyzing {len(datasets)} datasets: {list(datasets.keys())}")

    generated_folder = config.GENERATED_IMAGES_FOLDER
    batch_output_dir = config.BATCH_OUTPUT_DIR

    print(f"Generated images folder: {generated_folder}")
    print(f"Results will be saved to: {batch_output_dir}")

    os.makedirs(batch_output_dir, exist_ok=True)

    # Initialize codec
    print("Initializing VQ codec...")
    codec = initialize_vq_codec(config)
    codec = codec.to(device)
    codec.eval()

    # The generated images are shared by every comparison, so score them once.
    print("\nProcessing VQ-Diffusion generated images (once for all comparisons)...")
    generated_data = process_images_folder(codec, device, generated_folder, config, "VQ-Diffusion")

    if not generated_data['recon_single']:
        print("Error: No valid generated images processed")
        return

    print(f"VQ-Diffusion images processed: {len(generated_data['recon_single'])}")

    # Save generated data losses immediately
    save_final_losses(generated_data, "vq_diffusion_generated", config)

    # Process each real dataset; comparisons that return None are dropped.
    results_summary = []
    for dataset_name, dataset_path in datasets.items():
        # NOTE(review): run_single_comparison_with_precomputed appends its own
        # per-dataset subdirectory to the directory passed here, so per-dataset
        # plots end up two levels below batch_output_dir. Kept as-is so
        # existing output paths do not change.
        individual_output_dir = os.path.join(batch_output_dir, dataset_name.lower())
        result = run_single_comparison_with_precomputed(dataset_name, dataset_path, generated_data, individual_output_dir, codec, device, config)
        if result:
            results_summary.append(result)

    # Save comprehensive comparison results
    if results_summary:
        save_comparison_results(results_summary, batch_output_dir, config)

        print(f"\n{'='*80}")
        print("BATCH ANALYSIS COMPLETE!")
        print(f"{'='*80}")
        print(f"Total comparisons completed: {len(results_summary)}")
        print(f"Results saved in: {batch_output_dir}")
        print("Individual results in subdirectories")
        # These are the CSV files save_comparison_results actually writes.
        print("Comprehensive summary in comparison_summary.csv, auc_results.csv and tpr_results.csv")
    else:
        print("No successful comparisons completed.")


# Run the batch analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
