import os
import copy
import random
import torch
import torch.nn.functional as F
import numpy as np
import importlib
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.cm as mpl_color_map
import distinctipy
from torchvision import transforms
from torchvision.utils import make_grid
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
)
from einops import rearrange, reduce, repeat

from backbone import UNetEncoder
from slot_attn import MultiHeadSTEVESA
from unet_with_pos import UNet2DConditionModelWithPos, UNet2DModelWithPos
from pipelines import UnconditionalDiffusionPipeline

import numpy as np
import scipy
import scipy.optimize as optimize
import torch
import torch.nn.functional as F


def load_models(ckpt_path, use_slots, device='cuda:0', weight_dtype=torch.float32, scheduler_path='/configs/faces/scheduler/scheduler_config.json'):
    """
    Load the diffusion components used for evaluation.

    Project sub-models are read from ``ckpt_path``, each from a subfolder named after its
    class in lower case. With ``use_slots`` the slot pipeline is loaded: a UNetEncoder
    backbone, MultiHeadSTEVESA slot attention and a conditional UNet. Otherwise only an
    unconditional UNet is loaded, and ``backbone`` / ``slot_attn`` are None.

    Args:
        ckpt_path: Checkpoint directory / repo containing the model subfolders.
        use_slots (bool): Whether to load the slot-conditioned pipeline.
        device: Device for the backbone, slot attention and VAE.
        weight_dtype: dtype for the backbone and slot attention.
        scheduler_path: Path to the DDPM scheduler config file.

    Returns:
        tuple: (vae, unet, scheduler, backbone, slot_attn)

    NOTE(review): only backbone and slot_attn are moved to ``device``/``weight_dtype``. The
    UNet is returned as loaded, and the VAE is moved to ``device`` without a dtype cast.
    Presumably the caller handles this; confirm.
    """
    def from_ckpt(model_cls, folder):
        # Each component lives in a lower-cased subfolder of the checkpoint.
        return model_cls.from_pretrained(ckpt_path, subfolder=folder.lower())

    backbone = slot_attn = None
    if use_slots:
        backbone = from_ckpt(UNetEncoder, "UNetEncoder").to(device=device, dtype=weight_dtype)
        slot_attn = from_ckpt(MultiHeadSTEVESA, "MultiHeadSTEVESA").to(device=device, dtype=weight_dtype)
        unet = from_ckpt(UNet2DConditionModelWithPos, "UNet2DConditionModelWithPos")
    else:
        unet = from_ckpt(UNet2DModelWithPos, "UNet2DModelWithPos")

    vae = AutoencoderKL.from_pretrained('stabilityai/stable-diffusion-2-1', subfolder="vae").to(device)
    # The noise scheduler is built from a local config file rather than the hub repo.
    scheduler = DDPMScheduler.from_config(DDPMScheduler.load_config(scheduler_path))
    return vae, unet, scheduler, backbone, slot_attn


def modify_prediction_celeba(gt_label, pred, dataset, seed=42, preferred_attrs=None):
    """
    For each sample in a batch, choose one CelebA attribute and flip it.

    Selection order per sample:
      1. a random attribute from ``preferred_attrs`` whose prediction matches the ground truth;
      2. otherwise a random correctly predicted attribute of any kind;
      3. otherwise (nothing predicted correctly) a random preferred attribute, or a random
         attribute of any kind if none of the preferred names exist in the dataset.

    Args:
        gt_label (torch.Tensor): Ground truth labels, shape [batch_size, num_attributes].
            Values are expected to be 0/1, since they are flipped as ``1 - value``.
        pred (torch.Tensor): Predicted labels, shape [batch_size, num_attributes]
        dataset: Object exposing ``attr_names``, the list of attribute names in label-column order.
        seed (int, optional): Random seed for deterministic attribute selection. When given,
            PyTorch's global RNG is also reseeded via ``torch.manual_seed``. If None, selection
            is non-deterministic.
        preferred_attrs (list, optional): Attribute names to prioritize for modification.
            Names not in ``dataset.attr_names`` are ignored. If None, the following default
            attributes (those with a visible impact) are used:
            ['Smiling', 'Mouth_Slightly_Open', 'Sideburns', 'Blond_Hair', 'Wavy_Hair',
             'Straight_Hair', 'Black_Hair', 'Pale_Skin', 'Pointy_Nose']

    Returns:
        tuple: (target_labels, changes)
            - target_labels (torch.Tensor): Copy of ``gt_label`` with the selected attributes flipped
            - changes (list): One dict per sample with keys 'sample_idx', 'attr_idx',
              'attr_name', 'original' and 'modified' ('Yes'/'No')
    """
    batch_size = gt_label.shape[0]
    num_attributes = gt_label.shape[1]

    if preferred_attrs is None:
        preferred_attrs = [
            'Smiling', 'Mouth_Slightly_Open', 'Sideburns',
            'Blond_Hair', 'Wavy_Hair', 'Straight_Hair',
            'Black_Hair', "Pale_Skin", "Pointy_Nose",
        ]

    # Column index of every preferred name the dataset actually has (unknown names skipped).
    preferred_attr_indices = [
        dataset.attr_names.index(attr) for attr in preferred_attrs if attr in dataset.attr_names
    ]

    target_labels = gt_label.clone()
    changes = []

    rng = torch.Generator()
    if seed is not None:
        # The global RNG is reseeded too; this side effect is kept for callers relying on it.
        torch.manual_seed(seed)
        rng.manual_seed(seed)
    else:
        # A fresh torch.Generator always starts from the same fixed default seed, so it
        # must be reseeded explicitly for seed=None to actually be random.
        rng.seed()

    def _randint(upper):
        return torch.randint(0, upper, (1,), generator=rng).item()

    for i in range(batch_size):
        # One host transfer per row instead of one tensor comparison per attribute.
        is_correct = (gt_label[i] == pred[i]).tolist()
        preferred_correct = [idx for idx in preferred_attr_indices if is_correct[idx]]

        if preferred_correct:
            attribute_idx = preferred_correct[_randint(len(preferred_correct))]
        else:
            all_correct = [idx for idx, ok in enumerate(is_correct) if ok]
            if all_correct:
                attribute_idx = all_correct[_randint(len(all_correct))]
            elif preferred_attr_indices:
                attribute_idx = preferred_attr_indices[_randint(len(preferred_attr_indices))]
            else:
                # Nothing predicted correctly and no preferred attribute exists in the dataset:
                # fall back to any attribute instead of crashing on randint(0, 0).
                attribute_idx = _randint(num_attributes)

        # Flip the selected attribute (0 -> 1, 1 -> 0).
        original_value = gt_label[i, attribute_idx].item()
        target_labels[i, attribute_idx] = 1 - original_value
        modified_value = target_labels[i, attribute_idx].item()

        changes.append({
            'sample_idx': i,
            'attr_idx': attribute_idx,
            'attr_name': dataset.attr_names[attribute_idx],
            'original': 'Yes' if original_value == 1 else 'No',
            'modified': 'Yes' if modified_value == 1 else 'No'
        })

    return target_labels, changes

def modify_prediction_clevrtex(gt_features, pred_features, dataset, seed=42, preferred_attrs=None):
    """
    For each sample in a batch, pick one correctly predicted ClevrTex object attribute and change it.

    Ground-truth objects and predictions are first aligned with
    ``_compute_hungarian_matching(gt_features, pred_features)``. An attribute of a matched
    (object, prediction) pair counts as correctly predicted when its feature slices agree
    within ``atol=0.1``. Selection order per sample:
      1. a random correct (object, attribute) among ``preferred_attrs``;
      2. otherwise a random correct (object, attribute) among the editable categorical
         attributes ('shape', 'size', 'material');
      3. otherwise a random matched object with a random preferred attribute.
    Editable attributes are changed by rolling their one-hot slice by one position
    (category k -> k + 1, wrapping around). Other attributes are left untouched.

    Args:
        gt_features (torch.Tensor): Ground truth feature vectors, shape [batch_size, num_objects, feature_dim]
        pred_features (torch.Tensor): Predicted feature vectors, shape [batch_size, num_objects, feature_dim]
        dataset: ClevrTexDataset instance. It is not used here; kept for API symmetry.
        seed (int, optional): Random seed for deterministic selection. When given, PyTorch's
            global RNG is also reseeded via ``torch.manual_seed``. If None, selection is
            non-deterministic.
        preferred_attrs (list, optional): Attribute types to prioritize for modification.
            If None, defaults to ['size'].

    Returns:
        tuple: (target_features, changes)
            - target_features (torch.Tensor): Copy of ``gt_features`` with the selected attributes changed
            - changes (list): One dict per sample with keys 'sample_idx', 'obj_idx', 'pred_idx',
              'attr_name', 'original' and 'modified' (the last two as NumPy arrays)

    Raises:
        ValueError: If a sample has no matched object pairs.
    """
    # Feature layout, matching SlotClassifier.
    slices = {
        'shape': slice(0, 4),
        'size': slice(4, 7),
        'material': slice(7, 67),
        'coords': slice(67, 70),
        'visibility': slice(70, 71)
    }
    # Only these one-hot attributes can actually be modified.
    editable_attrs = ('shape', 'size', 'material')

    if preferred_attrs is None:
        preferred_attrs = ['size']

    batch_size = gt_features.shape[0]
    target_features = gt_features.clone()
    changes = []

    rng = torch.Generator()
    if seed is not None:
        torch.manual_seed(seed)
        rng.manual_seed(seed)
    else:
        # A fresh torch.Generator always starts from the same fixed default seed;
        # reseed it so seed=None really means random.
        rng.seed()

    def _randint(upper):
        return torch.randint(0, upper, (1,), generator=rng).item()

    def _agrees(i, obj_idx, pred_idx, attr):
        attr_slice = slices[attr]
        return torch.allclose(gt_features[i, obj_idx, attr_slice],
                              pred_features[i, pred_idx, attr_slice], atol=0.1)

    indices = _compute_hungarian_matching(gt_features, pred_features)

    for i in range(batch_size):
        # Normalise matched (gt object, prediction) indices to plain ints; they may be
        # tensor or array elements depending on the matching helper.
        matched_pairs = [(int(obj), int(prd)) for obj, prd in indices[i]]
        if not matched_pairs:
            raise ValueError(f"Sample {i} has no matched objects to modify.")

        candidates = [(o, p, a) for o, p in matched_pairs for a in preferred_attrs if _agrees(i, o, p, a)]
        if not candidates:
            # Restrict the fallback to attributes we can edit. 'coords'/'visibility' would be
            # recorded as a change without anything being modified.
            candidates = [(o, p, a) for o, p in matched_pairs for a in editable_attrs if _agrees(i, o, p, a)]

        if candidates:
            obj_idx, pred_idx, attribute = candidates[_randint(len(candidates))]
        else:
            # Nothing agrees: pick a random matched pair (a 2-tuple) and a random preferred attribute.
            obj_idx, pred_idx = matched_pairs[_randint(len(matched_pairs))]
            attr_pool = preferred_attrs or editable_attrs
            attribute = attr_pool[_randint(len(attr_pool))]

        attr_slice = slices[attribute]
        original_value = gt_features[i, obj_idx, attr_slice].clone()
        if attribute in editable_attrs:
            # Shift the one-hot category to the next class (with wrap-around).
            target_features[i, obj_idx, attr_slice] = torch.roll(original_value, shifts=1, dims=0)

        changes.append({
            'sample_idx': i,
            'obj_idx': obj_idx,
            'pred_idx': pred_idx,
            'attr_name': attribute,
            'original': original_value.cpu().numpy(),
            'modified': target_features[i, obj_idx, attr_slice].cpu().numpy()
        })

    return target_features, changes

def annotate_attribute_changes(image, changes, batch_size):
    """
    Draw a two-line label (attribute name, then "original → modified") on each sample.

    The grid is treated as ``batch_size`` equal-width columns laid out left to right. Each
    label is drawn in the top-left corner of its column over a semi-transparent black box.
    Changes whose ``sample_idx`` is not below ``batch_size`` are ignored.

    Args:
        image (PIL.Image): The generated image grid
        changes (list): Dicts with 'sample_idx', 'attr_name', 'original' and 'modified'
        batch_size (int): Number of images in the batch

    Returns:
        PIL.Image: The annotated image. It is converted to RGB whenever at least one
        label is drawn.
    """
    draw = ImageDraw.Draw(image)

    # Prefer a bold TrueType font, then Arial, then PIL's built-in bitmap font.
    font = None
    for font_name in ("DejaVuSans-Bold.ttf", "Arial.ttf"):
        try:
            font = ImageFont.truetype(font_name, 14)
            break
        except IOError:
            continue
    if font is None:
        font = ImageFont.load_default()

    cell_width = image.width // batch_size
    margin = 5
    pad = 4

    for change in changes:
        idx = change['sample_idx']
        if idx >= batch_size:
            continue
        left = idx * cell_width

        title = f"{change['attr_name']}"
        detail = f"{change['original']} → {change['modified']}"
        title_w, title_h = draw.textbbox((0, 0), title, font=font)[2:]
        detail_w, detail_h = draw.textbbox((0, 0), detail, font=font)[2:]
        box_w = max(title_w, detail_w)

        # Blend a translucent background box into the image before drawing the text.
        overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
        ImageDraw.Draw(overlay).rectangle(
            [left + margin, margin, left + margin + box_w + 2 * pad, margin + title_h + detail_h + 3 * pad],
            fill=(0, 0, 0, 180)
        )
        image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
        draw = ImageDraw.Draw(image)

        text_x = left + margin + pad
        draw.text((text_x, margin + pad), title, fill=(255, 255, 0), font=font)
        draw.text((text_x, margin + pad + title_h + pad), detail, fill=(255, 255, 255), font=font)

    return image

def _clevrtex_label_name(dataset, attr_name, idx, prefix):
    """Return a readable label for category ``idx`` of ``attr_name``, or ``'<prefix>_<idx>'``."""
    labels_to_idx = getattr(dataset, 'labels_to_idx', None)
    if labels_to_idx is not None and attr_name in labels_to_idx:
        attr_map = labels_to_idx[attr_name]
        # Prefer the explicit label -> index mapping, since the one-hot position is the index value.
        for label, label_idx in attr_map.items():
            if isinstance(label_idx, (int, np.integer)) and label_idx == idx:
                return str(label)
        # Values are not plain indices: fall back to the labels' insertion order.
        labels = list(attr_map.keys())
        if 0 <= idx < len(labels):
            return str(labels[idx])
    return f"{prefix}_{idx}"


def annotate_clevrtex_changes(image, changes, batch_size, dataset):
    """
    Add text annotations to show object attribute changes on ClevrTex images.

    The grid is treated as ``batch_size`` equal-width columns laid out left to right. For each
    change, three lines ("Object N", attribute name, "old → new") are drawn in the top-left
    corner of the sample's column over a semi-transparent black box.

    Args:
        image (PIL.Image): The generated image grid
        changes (list): Dicts with 'sample_idx', 'obj_idx', 'attr_name', 'original' and
            'modified'. For 'shape'/'size'/'material', the last two are one-hot vectors.
        batch_size (int): Number of images in the batch
        dataset: ClevrTex dataset. If it has a ``labels_to_idx`` attribute mapping
            attribute name -> {label: index}, it is used to print readable category names.
            Otherwise generic names such as ``Shape_2`` are printed.

    Returns:
        PIL.Image: Image with text annotations. It is returned unchanged if
        ``batch_size <= 0``.
    """
    if batch_size <= 0:
        return image

    draw = ImageDraw.Draw(image)

    # Try to load a font, use default if not available
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 14)
    except IOError:
        try:
            font = ImageFont.truetype("Arial.ttf", 14)
        except IOError:
            font = ImageFont.load_default()

    img_width = image.width // batch_size
    padding = 4
    line_gap = 4
    # Fallback name prefixes for categorical attributes when the dataset has no labels.
    category_prefixes = {'shape': 'Shape', 'size': 'Size', 'material': 'Mat'}

    for change in changes:
        sample_idx = change['sample_idx']
        if sample_idx >= batch_size:
            continue
        x_pos = sample_idx * img_width

        attr_name = change['attr_name']
        # int() so tensor / NumPy indices print as plain numbers.
        obj_text = f"Object {int(change['obj_idx'])}"
        attr_text = f"{attr_name.capitalize()}"

        if attr_name in category_prefixes:
            prefix = category_prefixes[attr_name]
            orig_idx = int(np.argmax(change['original']))
            mod_idx = int(np.argmax(change['modified']))
            orig_name = _clevrtex_label_name(dataset, attr_name, orig_idx, prefix)
            mod_name = _clevrtex_label_name(dataset, attr_name, mod_idx, prefix)
            change_text = f"{orig_name} → {mod_name}"
        else:
            # Generic format for other attributes
            change_text = "Modified"

        # Measure each line so the background box fits the widest one.
        obj_bbox = draw.textbbox((0, 0), obj_text, font=font)
        attr_bbox = draw.textbbox((0, 0), attr_text, font=font)
        change_bbox = draw.textbbox((0, 0), change_text, font=font)
        obj_w, obj_h = obj_bbox[2] - obj_bbox[0], obj_bbox[3] - obj_bbox[1]
        attr_w, attr_h = attr_bbox[2] - attr_bbox[0], attr_bbox[3] - attr_bbox[1]
        change_w, change_h = change_bbox[2] - change_bbox[0], change_bbox[3] - change_bbox[1]

        max_w = max(obj_w, attr_w, change_w)
        # Line heights plus 12 px: two 4 px gaps between lines and 4 px of extra slack.
        total_h = obj_h + attr_h + change_h + 12

        # Blend a translucent background box into the image before drawing the text.
        overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.rectangle(
            [x_pos + 5, 5, x_pos + 5 + max_w + 2*padding, 5 + total_h + 2*padding],
            fill=(0, 0, 0, 180)
        )
        image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
        draw = ImageDraw.Draw(image)

        text_x = x_pos + 5 + padding
        y_pos = 5 + padding
        for text, colour, height in (
            (obj_text, (255, 255, 0), obj_h),
            (attr_text, (0, 255, 255), attr_h),
            (change_text, (255, 255, 255), change_h),
        ):
            draw.text((text_x, y_pos), text, fill=colour, font=font)
            y_pos += height + line_gap

    return image

def create_attribute_mask_from_changes(changes, num_attributes, device, batch_size=None):
    """
    Create a binary mask tensor that has 1s only for attributes we want to modify.

    Each change marks row ``change['sample_idx']``. If that key is missing, the change's
    position in ``changes`` is used. Masks therefore stay aligned with the batch even when
    ``changes`` is unordered or does not cover every sample.

    Args:
        changes (list): Dicts with 'attr_idx' and (normally) 'sample_idx'
        num_attributes (int): Total number of attributes in the model
        device (str): Device to create tensor on
        batch_size (int, optional): Number of rows. Defaults to ``max(sample_idx) + 1``,
            which is 0 for an empty ``changes``.

    Returns:
        torch.Tensor: Float binary mask of shape [batch_size, num_attributes]
    """
    rows = [int(change.get('sample_idx', i)) for i, change in enumerate(changes)]
    if batch_size is None:
        batch_size = max(rows, default=-1) + 1

    mask = torch.zeros((batch_size, num_attributes), device=device)
    # Set 1s only for attributes we're intentionally changing.
    for row, change in zip(rows, changes):
        mask[row, change['attr_idx']] = 1.0

    return mask

def create_clevrtex_attribute_mask(changes, num_objects, feature_dim, device, batch_size=None):
    """
    Create a binary mask tensor with 1s only for object attributes we want to modify.

    Each change marks the feature slice of ``change['attr_name']`` for object
    ``change['obj_idx']`` in row ``change['sample_idx']``. If 'sample_idx' is missing, the
    change's position in ``changes`` is used. Unknown attribute names are ignored.

    Args:
        changes (list): Dicts with 'obj_idx', 'attr_name' and (normally) 'sample_idx'
        num_objects (int): Maximum number of objects per image
        feature_dim (int): Total feature dimension per object
        device (str): Device to create tensor on
        batch_size (int, optional): Number of rows. Defaults to ``max(sample_idx) + 1``,
            which is 0 for an empty ``changes``.

    Returns:
        torch.Tensor: Float binary mask of shape [batch_size, num_objects, feature_dim]
    """
    # Property slices matching SlotClassifier
    slices = {
        'shape': slice(0, 4),
        'size': slice(4, 7),
        'material': slice(7, 67),
        'coords': slice(67, 70),
        'visibility': slice(70, 71)
    }

    rows = [int(change.get('sample_idx', i)) for i, change in enumerate(changes)]
    if batch_size is None:
        batch_size = max(rows, default=-1) + 1

    mask = torch.zeros((batch_size, num_objects, feature_dim), device=device)
    # Set 1s only for attributes we're intentionally changing.
    for row, change in zip(rows, changes):
        attr_slice = slices.get(change['attr_name'])
        if attr_slice is not None:
            mask[row, change['obj_idx'], attr_slice] = 1.0

    return mask

def calculate_attribute_success_rate(changes, classifier, generated_images, original_images, device):
    """
    Calculate success and failure rates for attribute modifications.

    The classifier is run on the original and on the generated images, and logits > 0 are
    taken as positive predictions. For each change:
      - Success: the target attribute's prediction changed and now equals the target value
        (``change['modified'] == 'Yes'`` -> 1). A change to the other value is added to
        ``failure_count``.
      - Failure: every other attribute of the same sample whose prediction changed is added
        to ``failure_count``. Each such attribute increments ``total_unchanged_attrs``.

    NOTE(review): wrong-direction target changes are counted in ``failure_count`` but not in
    ``total_unchanged_attrs``, so ``failure_rate`` mixes two kinds of failure and can exceed 100%.

    Args:
        changes (list): Dicts with 'sample_idx', 'attr_idx' and 'modified' ('Yes'/'No')
        classifier: Classifier returning logits, or an object with a ``.logits`` attribute
        generated_images (torch.Tensor): Generated/modified images
        original_images (torch.Tensor): Original input images
        device: Device to run inference on

    Returns:
        tuple: (success_rate, failure_rate, success_count, failure_count, total_target_attrs,
            total_unchanged_attrs). Rates are percentages.
    """
    if len(changes) == 0:
        return 0.0, 0.0, 0, 0, 0, 0

    def _binary_predictions(images):
        # Accept raw logits or an output object exposing `.logits`.
        outputs = classifier(images.to(device))
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        return (logits > 0).float()

    with torch.no_grad():
        orig_preds = _binary_predictions(original_images)
        gen_preds = _binary_predictions(generated_images)

    # Move everything to the host once, instead of syncing on every element comparison.
    flipped = (orig_preds != gen_preds).tolist()
    gen_values = gen_preds.tolist()

    success_count = 0
    failure_count = 0
    total_target_attrs = 0
    total_unchanged_attrs = 0

    for change in changes:
        sample_idx = change['sample_idx']
        attr_idx = change['attr_idx']
        target_value = 1.0 if change['modified'] == 'Yes' else 0.0
        row_flipped = flipped[sample_idx]

        # Target attribute: only a change to the target value is a success.
        if row_flipped[attr_idx]:
            if gen_values[sample_idx][attr_idx] == target_value:
                success_count += 1
            else:
                failure_count += 1  # changed, but in the wrong direction
        total_target_attrs += 1

        # Any other attribute whose prediction changed is an unintended edit.
        others = [flag for j, flag in enumerate(row_flipped) if j != attr_idx]
        failure_count += sum(others)
        total_unchanged_attrs += len(others)

    success_rate = (success_count / total_target_attrs) * 100 if total_target_attrs > 0 else 0.0
    failure_rate = (failure_count / total_unchanged_attrs) * 100 if total_unchanged_attrs > 0 else 0.0

    return success_rate, failure_rate, success_count, failure_count, total_target_attrs, total_unchanged_attrs

def calculate_clevrtex_accuracy(changes, classifier, generated_images, original_images, device):
    """
    Calculate success and failure rates for attribute modifications of ClevrTex objects.

    The classifier is run on the original and on the generated images, and objects of the two
    predictions are aligned with ``_compute_hungarian_matching``. For each change whose
    attribute is categorical ('shape', 'size', 'material'), predicted classes are compared by
    argmax over the attribute's feature slice:
      - target attribute: success if its predicted class changed and now equals the argmax of
        ``change['modified']``. A change to any other class is added to ``failure_count``.
        Every such change increments ``total_target_attrs``.
      - the other two categorical attributes: each one whose predicted class changed is added
        to ``failure_count``, and each increments ``total_unchanged_attrs``.
    Changes to 'coords' or 'visibility' are not scored.

    NOTE(review): wrong-direction target changes are added to ``failure_count`` but not to
    ``total_unchanged_attrs``, so ``failure_rate`` mixes two kinds of failure and can exceed 100%.

    Args:
        changes (list): Dicts with at least 'sample_idx', 'obj_idx', 'attr_name' and 'modified'
        classifier: SlotClassifier model for object property prediction. Its output is indexed
            below as ``outputs[sample, object, feature]``; presumably a
            [batch, slots, feature_dim] tensor (TODO confirm).
        generated_images (torch.Tensor): Generated/modified images
        original_images (torch.Tensor): Original input images
        device: Device to run inference on

    Returns:
        tuple: (success_rate, failure_rate, success_count, failure_count, total_target_attrs,
            total_unchanged_attrs). Rates are percentages.
    """
    if len(changes) == 0:
        return 0.0, 0.0, 0, 0, 0, 0
    
    # Property slices matching SlotClassifier
    slices = {
        'shape': slice(0, 4),
        'size': slice(4, 7),
        'material': slice(7, 67),
        'coords': slice(67, 70),
        'visibility': slice(70, 71)
    }
    
    # Get predictions for both original and generated images
    with torch.no_grad():
        # Original image predictions
        orig_outputs = classifier(original_images.to(device))
        
        # Generated image predictions
        gen_outputs = classifier(generated_images.to(device))
        
        # For each batch item, compute Hungarian matching between original and generated predictions
        # NOTE(review): matching the original predictions with themselves presumably yields the
        # identity assignment; it appears to be used only for its index layout. Confirm against
        # _compute_hungarian_matching.
        orig_indices = _compute_hungarian_matching(orig_outputs, orig_outputs)  # Just to get indices structure
        gen_indices = _compute_hungarian_matching(gen_outputs, orig_outputs)
    
    # Track changes
    # NOTE(review): batch_size is currently unused.
    batch_size = original_images.shape[0]
    success_count = 0
    total_target_attrs = 0
    failure_count = 0
    total_unchanged_attrs = 0
    
    # For each sample in the batch
    for i, change in enumerate(changes):
        sample_idx = change['sample_idx']
        obj_idx = change['obj_idx']
        attr_name = change['attr_name']
        attr_slice = slices[attr_name]
        
        # Find which object in the generated output corresponds to our target object
        # using the Hungarian matching.
        # NOTE(review): assumes the matching result is indexable as [batch, pair, 2]. Which
        # column refers to which argument depends on _compute_hungarian_matching, so confirm
        # that column 1 of gen_indices indexes gen_outputs. obj_idx comes from the GT<->prediction
        # matching done when the change was created; verify it is a valid pair row here.
        orig_obj_idx = orig_indices[sample_idx, obj_idx, 0].item()  # Original object index
        gen_obj_idx = gen_indices[sample_idx, obj_idx, 1].item()   # Generated object index
        
        # Check if target attribute was successfully changed
        orig_attr = orig_outputs[sample_idx, orig_obj_idx, attr_slice]
        gen_attr = gen_outputs[sample_idx, gen_obj_idx, attr_slice]
        
        # For categorical attributes, compare argmax
        if attr_name in ['shape', 'size', 'material']:
            orig_val = torch.argmax(orig_attr).item()
            gen_val = torch.argmax(gen_attr).item()
            # change['modified'] is the edited one-hot slice (a NumPy array when produced by
            # modify_prediction_clevrtex).
            target_val = torch.argmax(torch.tensor(change['modified'])).item()
            
            if orig_val != gen_val:
                # Only count as success if it changed to the target value
                if gen_val == target_val:
                    success_count += 1
                else:
                    # Changed but in the wrong way
                    failure_count += 1
            total_target_attrs += 1
            
            # Check other categorical attributes for unintentional changes
            for other_attr in ['shape', 'size', 'material']:
                if other_attr != attr_name:
                    other_slice = slices[other_attr]
                    other_orig = torch.argmax(orig_outputs[sample_idx, orig_obj_idx, other_slice]).item()
                    other_gen = torch.argmax(gen_outputs[sample_idx, gen_obj_idx, other_slice]).item()
                    
                    if other_orig != other_gen:
                        failure_count += 1
                    total_unchanged_attrs += 1
        
        # Special handling for coordinates or other continuous attributes could be added here
    
    # Calculate rates (as percentages)
    success_rate = (success_count / total_target_attrs) * 100 if total_target_attrs > 0 else 0.0
    failure_rate = (failure_count / total_unchanged_attrs) * 100 if total_unchanged_attrs > 0 else 0.0
    
    return success_rate, failure_rate, success_count, failure_count, total_target_attrs, total_unchanged_attrs

@torch.no_grad()
def generate_images(
    val_dataset,
    unet,
    vae,
    scheduler,
    mode='sample',  # 'sample' or 'reconstruct'
    classifier=None,
    target_step=500,
    num_inference_steps=200,
    guidance_scale=1.0,
    conditioning_scale=1.0,  # Slot conditioning gradient strength (slot models only)
    backbone=None,
    slot_attn=None,
    weight_dtype=torch.bfloat16,
    n_images=16,
    device="cuda:0",
    visualize_attn=True,  # Whether to visualize attention maps when using slots
    mask_attributes=True,  # Whether to mask attributes for focused editing
    dataset_type="celeba",  # 'celeba' or 'clevrtex'
):
    """
    Unified function for image generation and reconstruction.

    Iterates over ``val_dataset`` in batches of ``min(8, n_images)`` until at least
    ``n_images`` inputs have been processed. For every batch it builds one
    visualization grid: the inputs, their VAE round trip, and the pipeline output,
    or an attention heatmap when slots are visualized.

    Only ``mode='sample'`` is runnable in this version. Any other mode raises
    NotImplementedError before any data is processed. The reconstruction and
    classifier-guidance code paths are kept for when those pipelines are available.

    Args:
        val_dataset: Dataset whose items contain ``"pixel_values"`` (and, for
            classifier guidance, ``"labels"`` or ``"feature_vectors"``).
        unet: UNet model for diffusion.
        vae: VAE model for encoding/decoding images.
        scheduler: Diffusion scheduler. Its config is reused to build a DDIM scheduler.
        mode: Either 'sample' for generation or 'reconstruct' for reconstruction.
        classifier: Optional classifier for guidance (reconstruct mode only).
        target_step: Target diffusion step for reconstruction.
        num_inference_steps: Number of inference steps for diffusion.
        guidance_scale: Scale factor for classifier/unconditional guidance.
        conditioning_scale: Scale factor for slot conditioning gradient strength
            (slot models, reconstruct mode only).
        backbone: Feature extractor backbone for slot attention.
        slot_attn: Optional slot attention module. When given, slots are used as
            ``prompt_embeds`` for the pipeline.
        weight_dtype: dtype the input pixels are cast to.
        n_images: Minimum number of inputs to process. The last batch may overshoot.
        device: Device to run generation on.
        visualize_attn: Whether to render attention heatmaps when using slot attention.
        mask_attributes (bool): If True, only consider target attributes for classifier guidance.
        dataset_type (str): Type of dataset - 'celeba' or 'clevrtex'.

    Returns:
        list: One PIL image (visualization grid) per processed dataloader batch.

    Raises:
        NotImplementedError: If ``mode`` is not 'sample'.
    """
    use_slots = slot_attn is not None
    use_classifier = classifier is not None

    # Initialize colorizer for slot visualization if needed
    colorizer = None
    if use_slots and visualize_attn:
        colorizer = ColorMask(
            num_slots=slot_attn.config.num_slots,
            log_img_size=256,
            norm_mean=0,
            norm_std=1,
        )

    batch_size = min(8, n_images)
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    # Rebuild the given scheduler's config as a DDIM scheduler for inference.
    scheduler_args = {"variance_type": "fixed_small"} if "variance_type" in scheduler.config else {}
    scheduler = getattr(importlib.import_module("diffusers"), "DDIMScheduler").from_config(
        scheduler.config, **scheduler_args
    )

    # Select the pipeline. The reconstruction pipelines (slot-conditioned and
    # unconditional) are not shipped with this version.
    if mode == 'sample':
        pipeline_class = StableDiffusionPipeline if use_slots else UnconditionalDiffusionPipeline
    else:  # reconstruct
        raise NotImplementedError("Reconstruction pipelines not available in this code_appendix version")

    pipe_args = {
        "vae": vae,
        "unet": unet,
        "scheduler": scheduler,
    }
    if use_slots and mode == "sample":
        # StableDiffusionPipeline requires these components. Slots are passed as
        # prompt_embeds instead of text embeddings, so they can be None.
        pipe_args.update({
            "text_encoder": None,
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
        })

    pipeline = pipeline_class(**pipe_args).to(device)
    generator = torch.Generator(device=device).manual_seed(42)

    images = []
    image_count = 0

    # Attribute-modification metrics accumulated across batches
    overall_success_count = 0
    overall_target_count = 0
    overall_failure_count = 0
    overall_unchanged_count = 0

    try:
        for batch in val_dataloader:
            pixel_values = batch["pixel_values"].to(device=device, dtype=weight_dtype)
            n_in_batch = pixel_values.shape[0]

            with torch.autocast("cuda"):
                # VAE round trip. The latents are decoded directly, so no scaling factor is applied.
                model_input = vae.encode(pixel_values).latent_dist.sample()
                pixel_values_recon = vae.decode(model_input).sample

                target_labels = None
                changes = None
                attribute_mask = None

                # Classifier guidance targets (reconstruct mode only)
                if use_classifier and mode == 'reconstruct':
                    if dataset_type.lower() == "celeba" and 'labels' in batch:
                        gt_labels = batch['labels'].to(device=device, dtype=weight_dtype)

                        outputs = classifier(pixel_values)
                        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                        preds = (logits > 0).float()

                        # Flip selected attributes to obtain target labels
                        target_labels, changes = modify_prediction_celeba(gt_labels, preds, val_dataset)
                        target_labels = target_labels.to(device=device, dtype=weight_dtype)

                        if mask_attributes and changes:
                            attribute_mask = create_attribute_mask_from_changes(
                                changes, target_labels.shape[1], device=device
                            )

                    elif dataset_type.lower() == "clevrtex" and 'feature_vectors' in batch:
                        gt_features = batch['feature_vectors'].to(device=device, dtype=weight_dtype)

                        pred_features = classifier(pixel_values)

                        # Modify selected object attributes to obtain target features
                        target_labels, changes = modify_prediction_clevrtex(gt_features, pred_features, val_dataset)
                        target_labels = target_labels.to(device=device, dtype=weight_dtype)

                        if mask_attributes and changes:
                            attribute_mask = create_clevrtex_attribute_mask(
                                changes, target_labels.shape[1], target_labels.shape[2], device=device
                            )

                # Slot extraction: the backbone features get a singleton dim at axis 1
                # for the slot module, and that dim is indexed away afterwards.
                slots = None
                attn = None
                if use_slots:
                    feat = backbone(pixel_values)
                    slots, attn = slot_attn(feat[:, None])
                    slots = slots[:, 0]

                pipeline_kwargs = {
                    "generator": generator,
                    "output_type": "pt",
                }
                if mode == 'sample':
                    pipeline_kwargs.update({
                        "height": 256,
                        "width": 256,
                        "num_inference_steps": num_inference_steps,
                        "guidance_scale": guidance_scale,
                    })
                    if use_slots:
                        pipeline_kwargs["prompt_embeds"] = slots
                    else:
                        pipeline_kwargs["batch_size"] = n_in_batch
                else:  # reconstruct
                    pipeline_kwargs.update({
                        "images": pixel_values,
                        "target_step": target_step,
                        "reconstruction_steps": num_inference_steps,
                        "guidance_scale": guidance_scale,
                        "classifier": classifier,
                        "target_labels": target_labels,
                        "attribute_mask": attribute_mask,
                        "dataset_type": dataset_type,
                    })
                    if use_slots:
                        pipeline_kwargs["prompt_embeds"] = slots
                        pipeline_kwargs["conditioning_scale"] = conditioning_scale

                images_gen = pipeline(**pipeline_kwargs).images

                # Per-batch success/failure rates for attribute modifications
                success_rate = failure_rate = 0.0
                if use_classifier and changes and mode == 'reconstruct':
                    metric_fn = (
                        calculate_attribute_success_rate
                        if dataset_type.lower() == "celeba"
                        else calculate_clevrtex_accuracy
                    )
                    (success_rate, failure_rate, success_count, failure_count,
                     total_target_attrs, total_unchanged_attrs) = metric_fn(
                        changes, classifier, images_gen, pixel_values, device
                    )
                    overall_success_count += success_count
                    overall_target_count += total_target_attrs
                    overall_failure_count += failure_count
                    overall_unchanged_count += total_unchanged_attrs

                # Visualization grid. Inputs are mapped with x * 0.5 + 0.5, which
                # assumes they are in [-1, 1]. images_gen is used as-is.
                if colorizer is not None:
                    grid_image = colorizer.get_heatmap(
                        img=(pixel_values * 0.5 + 0.5),
                        attn=reduce(attn[:, 0], 'b num_h (h w) s -> b s h w',
                                    h=int(np.sqrt(attn.shape[-2])), reduction='mean'),
                        recon=[pixel_values_recon * 0.5 + 0.5, images_gen],
                    )
                else:
                    # Rows: original, VAE reconstruction, generated. Batch laid out horizontally.
                    grid_image = torch.cat([
                        rearrange(pixel_values * 0.5 + 0.5, 'b c h w -> c h (b w)'),
                        rearrange(pixel_values_recon * 0.5 + 0.5, 'b c h w -> c h (b w)'),
                        rearrange(images_gen, 'b c h w -> c h (b w)'),
                    ], dim=1)

                # CHW float tensor -> HWC uint8 PIL image
                ndarr = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                pil_image = Image.fromarray(ndarr)

                if mode == 'reconstruct' and use_classifier and changes:
                    # NOTE(review): the nominal batch_size is passed even for a short final
                    # batch. Confirm the annotators do not assume a full-width grid.
                    if dataset_type.lower() == "celeba":
                        pil_image = annotate_attribute_changes(pil_image, changes, batch_size)
                    else:
                        pil_image = annotate_clevrtex_changes(pil_image, changes, batch_size, val_dataset)
                    pil_image = _draw_metrics_overlay(pil_image, success_rate, failure_rate)

                images.append(pil_image)

                image_count += n_in_batch
                if image_count >= n_images:
                    break
    finally:
        # Release the pipeline (and its GPU memory) even if generation fails.
        del pipeline
        torch.cuda.empty_cache()

    if overall_target_count > 0:
        overall_success_rate = (overall_success_count / overall_target_count) * 100
        # Guard separately: targets can exist while no unchanged attributes were counted.
        overall_failure_rate = (
            (overall_failure_count / overall_unchanged_count) * 100
            if overall_unchanged_count > 0 else 0.0
        )
        print("=== Attribute Modification Results ===")
        print(f"Success Rate: {overall_success_rate:.1f}% ({overall_success_count}/{overall_target_count} attributes modified as intended)")
        print(f"Failure Rate: {overall_failure_rate:.1f}% ({overall_failure_count}/{overall_unchanged_count} attributes changed unintentionally)")

    return images


def _metrics_text_color(line_idx, success_rate, failure_rate):
    """Return the RGB colour for metrics line ``line_idx`` (0 = success line, otherwise failure line)."""
    if line_idx == 0:
        return (0, 255, 0) if success_rate > 50 else (255, 100, 0)
    return (255, 50, 50) if failure_rate > 10 else (200, 200, 200)


def _draw_metrics_overlay(pil_image, success_rate, failure_rate):
    """Return a new RGB image with success/failure rates drawn on a translucent box in the top-left corner."""
    draw = ImageDraw.Draw(pil_image)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 16)
    except IOError:
        font = ImageFont.load_default()

    metrics_text = [
        f"Success Rate: {success_rate:.1f}%",
        f"Failure Rate: {failure_rate:.1f}%",
    ]

    # Measure each line once. The bbox is anchored at (0, 0), so [2] is the width and [3] the height.
    bboxes = [draw.textbbox((0, 0), text, font=font) for text in metrics_text]
    text_heights = [bbox[3] for bbox in bboxes]
    max_width = max(bbox[2] for bbox in bboxes)
    total_height = sum(text_heights) + 10  # 5px after each of the two lines

    padding = 8
    overlay = Image.new('RGBA', pil_image.size, (0, 0, 0, 0))
    ImageDraw.Draw(overlay).rectangle(
        [10, 10, 10 + max_width + 2 * padding, 10 + total_height + 2 * padding],
        fill=(0, 0, 0, 180),
    )
    pil_image = Image.alpha_composite(pil_image.convert('RGBA'), overlay).convert('RGB')
    draw = ImageDraw.Draw(pil_image)

    y_pos = 10 + padding
    for i, (text, height) in enumerate(zip(metrics_text, text_heights)):
        draw.text(
            (10 + padding, y_pos),
            text,
            fill=_metrics_text_color(i, success_rate, failure_rate),
            font=font,
        )
        y_pos += height + 5
    return pil_image

def generate_reconstruction_sequence(
    val_dataset,
    unet,
    vae,
    scheduler,
    target_steps=(1, 50, 125, 250, 500, 750),
    n_images=16,
    device="cuda:0",
    weight_dtype=torch.bfloat16,
    backbone=None,
    slot_attn=None,
    guidance_scale=1.0,
    conditioning_scale=1.0,
    mask_attributes=True,
    visualize_attn=False,
    dataset_type="celeba",
    figsize=None,
    save_path=None,
    show_plot=True
):
    """
    Generate a sequence of reconstructions with different target steps, stacked vertically.

    For each entry of ``target_steps``, ``generate_images`` is run in 'reconstruct'
    mode. It returns one grid per dataloader batch, so the generated row of every
    grid is cropped and the crops are concatenated horizontally into one row per
    step. The rows are then stacked top to bottom in ``target_steps`` order.

    Args:
        val_dataset: Dataset containing validation images.
        unet: UNet model for diffusion.
        vae: VAE model for encoding/decoding images.
        scheduler: Diffusion scheduler.
        target_steps (Sequence[int]): Timesteps to use for reconstruction, one row each.
        n_images (int): Number of images to generate per timestep.
        device (str): Device to run generation on.
        weight_dtype: Data type for model weights.
        backbone: Feature extractor backbone for slot attention.
        slot_attn: Optional slot attention module.
        guidance_scale: Scale factor for classifier/unconditional guidance.
        conditioning_scale: Scale factor for slot conditioning gradient strength.
        mask_attributes (bool): Whether to mask attributes for focused editing.
        visualize_attn (bool): Whether to visualize attention maps when using slot
            attention. NOTE(review): the cropping below assumes the plain
            3-row grid layout, which may not hold for attention heatmaps.
        dataset_type (str): Type of dataset - 'celeba' or 'clevrtex'.
        figsize (tuple): Figure size for the plot. If None, it is derived from the image aspect ratio.
        save_path (str): Path to save the visualization. If None, nothing is saved.
        show_plot (bool): Whether to display the plot.

    Returns:
        tuple: (PIL.Image, matplotlib.figure.Figure or None) - the combined image and
        the figure, if one was created.

    Raises:
        ValueError: If ``target_steps`` is empty or the dataset yields no images.
    """
    target_steps = list(target_steps)
    if not target_steps:
        raise ValueError("target_steps must contain at least one timestep")

    step_rows = []
    for step in target_steps:
        images = generate_images(
            val_dataset,
            unet,
            vae,
            scheduler,
            mode='reconstruct',
            target_step=step,
            n_images=n_images,
            device=device,
            weight_dtype=weight_dtype,
            backbone=backbone,
            slot_attn=slot_attn,
            guidance_scale=guidance_scale,
            conditioning_scale=conditioning_scale,
            mask_attributes=mask_attributes,
            visualize_attn=visualize_attn,
            dataset_type=dataset_type,
        )

        # Keep only the generated (bottom) third of each batch grid, dropping the
        # original and VAE-reconstruction rows.
        crops = []
        for img in images:
            width, height = img.size
            row_height = height // 3
            crops.append(img.crop((0, row_height * 2, width, row_height * 3)))
        if not crops:
            raise ValueError("val_dataset yielded no images to reconstruct")

        # Join all batches of this step side by side so each row is exactly one timestep.
        row = Image.new('RGB', (sum(c.size[0] for c in crops), max(c.size[1] for c in crops)))
        x_offset = 0
        for crop in crops:
            row.paste(crop, (x_offset, 0))
            x_offset += crop.size[0]
        step_rows.append(row)

    num_steps = len(step_rows)
    single_width = max(row.size[0] for row in step_rows)
    single_height = max(row.size[1] for row in step_rows)

    # Stack one row per timestep vertically
    final_image = Image.new('RGB', (single_width, single_height * num_steps))
    for i, row in enumerate(step_rows):
        final_image.paste(row, (0, i * single_height))

    fig = None
    if show_plot or save_path:
        if figsize is None:
            aspect_ratio = single_width / (single_height * num_steps)
            figsize = (20, 20 / aspect_ratio)

        fig = plt.figure(figsize=figsize)
        plt.imshow(final_image)
        plt.axis('off')

        # Label each row at its vertical centre (axes coords: y = 0 is the bottom).
        for i, step in enumerate(target_steps):
            plt.text(-0.05, (num_steps - i - 0.5) / num_steps, f't={step}',
                     transform=plt.gca().transAxes, va='center', fontsize=12,
                     fontweight='bold')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)

        if show_plot:
            plt.show()
        else:
            plt.close(fig)

    return final_image, fig

def add_latent_noise(pixel_values, vae, noise_scheduler, device, noise_strength=1.0):
    """
    Add noise to images in a batch by encoding to latent space, adding noise, then decoding.

    Each image is encoded with ``vae`` and scaled by ``vae.config.scaling_factor``.
    It is then noised with ``noise_scheduler.add_noise`` at an independently
    sampled timestep, unscaled, and decoded back to pixel space.

    Args:
        pixel_values (torch.Tensor): Batch of images [B, C, H, W].
        vae (AutoencoderKL): VAE model for encoding/decoding.
        noise_scheduler (DDPMScheduler): Noise scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        device: Device to move ``pixel_values`` to before encoding.
        noise_strength (float): Fraction of the training timestep range to sample
            from (nominally 0-1). Timesteps are drawn uniformly from
            ``[0, int(num_train_timesteps * noise_strength))``. The upper bound is
            clamped to ``[1, num_train_timesteps]``, so 0 always yields timestep 0
            and values above 1 stay within the scheduler's training range.

    Returns:
        tuple: (noisy_pixel_values, timesteps), where ``timesteps`` has shape [B].
    """
    with torch.no_grad():
        # .to() is a no-op when the tensor already lives on ``device``.
        pixel_values = pixel_values.to(device)

        # Encode images to (scaled) latent space
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # torch.randint needs high > low. Timesteps must also stay below num_train_timesteps.
        num_train_timesteps = noise_scheduler.config.num_train_timesteps
        max_timestep = int(num_train_timesteps * noise_strength)
        max_timestep = min(max(max_timestep, 1), num_train_timesteps)

        timesteps = torch.randint(
            0, max_timestep, (latents.shape[0],), device=latents.device
        ).long()

        noise = torch.randn_like(latents)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Undo the latent scaling before decoding back to image space
        noisy_pixel_values = vae.decode(noisy_latents / vae.config.scaling_factor).sample

    return noisy_pixel_values, timesteps

def generate_counterfactuals(
    val_dataset,
    unet,
    vae,
    scheduler,
    classifier=None,
    target_step=500,
    num_inference_steps=200,
    guidance_scale=1.0,
    conditioning_scale=1.0,
    backbone=None,
    slot_attn=None,
    weight_dtype=torch.bfloat16,
    n_images=16,
    batch_size=8,
    preferred_attrs=None,
    device="cuda:0",
    visualize_attn=True,
    mask_attributes=True,
    dataset_type="celeba",  # Added dataset type parameter
    save_all_modifications=True,  # New parameter to control behavior
    calculate_fid=True,  # New parameter to control FID calculation
):
    """
    Generate counterfactual images by modifying specific attributes of input images.
    For each preferred attribute, the function processes images from the dataset until 
    reaching the target number of counterfactuals.
    
    Args:
        val_dataset: Dataset containing validation images
        unet: UNet model for diffusion
        vae: VAE model for encoding/decoding images
        scheduler: Diffusion scheduler
        classifier: Classifier model for guidance and evaluation
        target_step: Target diffusion step for reconstruction
        num_inference_steps: Number of inference steps for diffusion
        guidance_scale: Scale factor for classifier guidance
        conditioning_scale: Scale factor for slot conditioning gradient strength
        backbone: Feature extractor backbone for slot attention
        slot_attn: Optional slot attention module
        weight_dtype: Data type for model weights
        n_images: Number of counterfactual images to generate per attribute
        batch_size: Batch size for processing
        preferred_attrs: List of attribute names to focus on for counterfactuals
        device: Device to run generation on
        visualize_attn: Whether to visualize attention maps when using slot attention
        mask_attributes: Whether to mask attributes for focused editing
        save_all_modifications: Whether to save all modifications or just valid ones
        calculate_fid: Whether to calculate FID score between original and counterfactual images
    
    Returns:
        tuple: (results, metrics, fid_scores)
            - results: Dictionary mapping attribute names to lists of tuples (original_image, counterfactual_image, change_info)
            - metrics: Dictionary containing success/failure statistics for each attribute
            - fid_scores: Dictionary containing FID scores for each attribute
    """
    try:
        import torch_fidelity
    except ImportError:
        print("Warning: torch-fidelity package not found. FID calculation will be skipped.")
        print("To install, run: pip install torch-fidelity")
        calculate_fid = False
    
    if preferred_attrs is None:
        preferred_attrs = [
            'Smiling', 'Mouth_Slightly_Open', 'Sideburns',
            'Blond_Hair', 'Wavy_Hair', 'Straight_Hair', 
            'Black_Hair', "Pale_Skin", "Pointy_Nose", 
        ]
    
    # Validate that classifier is provided
    if classifier is None:
        raise ValueError("Classifier is required for generating counterfactuals")
    
    # Configure and initialize scheduler
    scheduler_args = {"variance_type": "fixed_small"} if "variance_type" in scheduler.config else {}
    import importlib
    scheduler = getattr(importlib.import_module("diffusers"), "DDIMScheduler").from_config(
        scheduler.config, **scheduler_args
    )
    
    # Select appropriate pipeline based on model configuration
    if slot_attn:
        # pipeline_class = SlotConditionedReconstructionPipeline
        raise NotImplementedError("SlotConditionedReconstructionPipeline not available in this code_appendix version")
    else:
        # pipeline_class = UnconditionalReconstructionPipeline
        raise NotImplementedError("UnconditionalReconstructionPipeline not available in this code_appendix version")
    
    # Prepare pipeline arguments
    pipe_args = {
        "vae": vae,
        "unet": unet,
        "scheduler": scheduler,
    }
    
    # Initialize pipeline
    pipeline = pipeline_class(**pipe_args).to(device)
    generator = torch.Generator(device=device).manual_seed(42)
    
    # Dictionary to store results
    results = {attr: [] for attr in preferred_attrs}
    
    # Dictionary to store metrics
    metrics = {
        attr: {
            'success_count': 0,
            'total_attempts': 0,
            'failure_count': 0,
            'unchanged_count': 0
        } for attr in preferred_attrs
    }
    
    # Dictionary to store image tensors for FID calculation
    if calculate_fid:
        fid_originals = {attr: [] for attr in preferred_attrs}
        fid_counterfactuals = {attr: [] for attr in preferred_attrs}
    
    # Set up dataloader
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,  # Shuffle to get diverse samples
    )
    
    # For each attribute
    for attr in preferred_attrs:
        images_generated = 0
        attr_idx = val_dataset.attr_names.index(attr) if attr in val_dataset.attr_names else -1
        
        if attr_idx == -1:
            print(f"Warning: Attribute '{attr}' not found in dataset. Skipping.")
            continue
        
        print(f"\nProcessing attribute: {attr} (index: {attr_idx})")
        
        # Process batches until we get enough counterfactuals for this attribute
        for batch in val_dataloader:
            if images_generated >= n_images:
                break
                
            pixel_values = batch["pixel_values"].to(device=device, dtype=weight_dtype)
            gt_labels = batch['labels'].to(device=device, dtype=weight_dtype) if 'labels' in batch else None
            
            if gt_labels is None:
                continue
                
            # Get classifier predictions
            with torch.no_grad():
                outputs = classifier(pixel_values)
                if hasattr(outputs, 'logits'):
                    logits = outputs.logits
                else:
                    logits = outputs
                
                # Calculate probabilities using sigmoid for binary classification
                probs = torch.sigmoid(logits)
                preds = (logits > 0).float()
            
            # Create target labels by flipping only the current attribute
            target_labels = gt_labels.clone()
            
            # Track which images in the batch will be modified
            valid_indices = []
            all_indices = []
            
            # Process all images in batch
            for i in range(pixel_values.shape[0]):
                # Always modify the target attribute
                target_labels[i, attr_idx] = 1 - gt_labels[i, attr_idx]
                all_indices.append(i)
                
                # But track which ones have matching predictions for metrics
                if preds[i, attr_idx] == gt_labels[i, attr_idx]:
                    valid_indices.append(i)
            
            # Create changes list for all images
            changes = []
            for idx in all_indices:
                changes.append({
                    'sample_idx': idx,
                    'attr_idx': attr_idx,
                    'attr_name': attr,
                    'original': 'Yes' if gt_labels[idx, attr_idx].item() == 1 else 'No',
                    'modified': 'Yes' if target_labels[idx, attr_idx].item() == 1 else 'No',
                    'prediction_matches': idx in valid_indices,
                    'gt_label': gt_labels[idx, attr_idx].item(),
                    'orig_prediction': preds[idx, attr_idx].item(),
                    'orig_probability': probs[idx, attr_idx].item(),  # Store raw probability
                    'orig_logit': logits[idx, attr_idx].item(),  # Store raw logit value
                    'target_label': target_labels[idx, attr_idx].item()
                })
            
            # Create attribute mask focusing only on the target attribute
            attribute_mask = None
            if mask_attributes:
                attribute_mask = torch.zeros((pixel_values.shape[0], gt_labels.shape[1]), device=device)
                for idx in all_indices:
                    attribute_mask[idx, attr_idx] = 1.0
            
            # Process slots if available
            slots = None
            attn = None
            if slot_attn is not None:
                feat = backbone(pixel_values)
                slots, attn = slot_attn(feat[:, None])
                slots = slots[:, 0]
            
            # Configure pipeline parameters
            pipeline_kwargs = {
                "images": pixel_values,
                "target_step": target_step,
                "reconstruction_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "classifier": classifier,
                "target_labels": target_labels,
                "attribute_mask": attribute_mask,
                "generator": generator,
                "output_type": "pt",
            }
            
            if slot_attn:
                pipeline_kwargs["prompt_embeds"] = slots
                pipeline_kwargs["conditioning_scale"] = conditioning_scale
            
            # Generate counterfactual images
            counterfactual_images = pipeline(**pipeline_kwargs).images
            
            # Now get the predictions for the counterfactual images to store with the results
            with torch.no_grad():
                cf_outputs = classifier(counterfactual_images)
                if hasattr(cf_outputs, 'logits'):
                    cf_logits = cf_outputs.logits
                    cf_probs = torch.sigmoid(cf_logits)
                    cf_preds = (cf_logits > 0).float()
                else:
                    # code for handling SlotClassifier
                    pass
                
            
            # Calculate success and failure rates only for valid indices (where prediction matched ground truth)
            # This ensures our metrics remain accurate while we still save all modificationslculate_attribute_success_rate(
            valid_changes = [change for change in changes if change['prediction_matches']]
            
            # Only calculate metrics on valid indices (prediction matches ground truth)
            if valid_indices:
                success_rate, failure_rate, success_count, failure_count, total_target_attrs, total_unchanged_attrs = calculate_attribute_success_rate(
                    valid_changes, classifier, counterfactual_images, pixel_values, device)
                # Update metrics
                metrics[attr]['success_count'] += success_count
                metrics[attr]['total_attempts'] += total_target_attrs
                metrics[attr]['failure_count'] += failure_count
                metrics[attr]['unchanged_count'] += total_unchanged_attrs
            
            # Store results for all modifications (or just valid ones depending on save_all_modifications flag)
            indices_to_save = all_indices if save_all_modifications else valid_indices
            for idx in indices_to_save:
                if images_generated >= n_images:
                    break
                    
                # Extract original and counterfactual images
                original = pixel_values[idx]
                counterfactual = counterfactual_images[idx]
                change_info = next((c for c in changes if c['sample_idx'] == idx), None)
                
                # Add counterfactual prediction information to change_info
                if change_info:
                    # For the specific attribute
                    change_info['cf_prediction'] = cf_preds[idx, attr_idx].item()
                    change_info['cf_probability'] = cf_probs[idx, attr_idx].item()  # Store raw probability
                    change_info['cf_logit'] = cf_logits[idx, attr_idx].item()  # Store raw logit value
                    change_info['target_success'] = (cf_preds[idx, attr_idx].item() == target_labels[idx, attr_idx].item())
                    
                    # Also track predictions for all other attributes
                    other_attrs = {
                        'orig_all_preds': preds[idx].detach().cpu().tolist(),
                        'orig_all_probs': probs[idx].detach().cpu().tolist(),  # Store all probabilities
                        'orig_all_logits': logits[idx].detach().cpu().tolist(),  # Store all logits
                        'cf_all_preds': cf_preds[idx].detach().cpu().tolist(),
                        'cf_all_probs': cf_probs[idx].detach().cpu().tolist(),  # Store all probabilities
                        'cf_all_logits': cf_logits[idx].detach().cpu().tolist(),  # Store all logits
                        'gt_all_labels': gt_labels[idx].detach().cpu().tolist(),
                    }
                    change_info.update(other_attrs)
                
                # Convert tensors to PIL images for storage
                original_img = (original * 0.5 + 0.5).clamp(0, 1)
                counterfactual_img = counterfactual.clamp(0, 1)
                
                # Add to results
                results[attr].append((original_img, counterfactual_img, change_info))
                
                # Store images for FID calculation
                if calculate_fid:
                    # Convert normalized tensors to range [0, 255] for FID calculation
                    fid_originals[attr].append(original_img.mul(255).byte().cpu())
                    fid_counterfactuals[attr].append(counterfactual_img.mul(255).byte().cpu())
                
                images_generated += 1
        
        # Calculate and print success rate for this attribute
        success_rate = (metrics[attr]['success_count'] / metrics[attr]['total_attempts']) * 100 if metrics[attr]['total_attempts'] > 0 else 0
        failure_rate = (metrics[attr]['failure_count'] / metrics[attr]['unchanged_count']) * 100 if metrics[attr]['unchanged_count'] > 0 else 0
        
        print(f"Attribute: {attr}")
        print(f"  Success Rate: {success_rate:.1f}% ({metrics[attr]['success_count']}/{metrics[attr]['total_attempts']})")
        print(f"  Failure Rate: {failure_rate:.1f}% ({metrics[attr]['failure_count']}/{metrics[attr]['unchanged_count']})")
    
    # Calculate FID scores - overall, not per attribute
    fid_score = float('nan')
    if calculate_fid:
        try:
            import os
            import tempfile
            from torch_fidelity import calculate_metrics
            
            print("\nCalculating overall FID score...")
            
            # Collect all originals and counterfactuals across attributes
            all_originals = []
            all_counterfactuals = []
            
            for attr in preferred_attrs:
                if attr in results and results[attr]:
                    for orig, cf, _ in results[attr]:
                        all_originals.append(orig)
                        all_counterfactuals.append(cf)
            
            # Need at least 8 images for FID
            if len(all_originals) >= 8:
                # Create temporary directories for images
                with tempfile.TemporaryDirectory() as orig_dir, tempfile.TemporaryDirectory() as cf_dir:
                    # Save original images
                    for i, img_tensor in enumerate(all_originals):
                        if isinstance(img_tensor, torch.Tensor):
                            img_np = img_tensor.cpu().permute(1, 2, 0).numpy()
                            img = Image.fromarray((img_np.clip(0, 1) * 255).astype(np.uint8))
                        else:
                            img = img_tensor
                        img.save(os.path.join(orig_dir, f"{i:05d}.png"))
                    
                    # Save counterfactual images
                    for i, img_tensor in enumerate(all_counterfactuals):
                        if isinstance(img_tensor, torch.Tensor):
                            img_np = img_tensor.cpu().permute(1, 2, 0).numpy()
                            img = Image.fromarray((img_np.clip(0, 1) * 255).astype(np.uint8))
                        else:
                            img = img_tensor
                        img.save(os.path.join(cf_dir, f"{i:05d}.png"))
                    
                    try:
                        # Calculate FID with more explicit arguments
                        metrics_dict = calculate_metrics(
                            input1=orig_dir,
                            input2=cf_dir,
                            cuda=torch.cuda.is_available(),
                            isc=False,
                            fid=True,
                            kid=False,
                            prc=False,
                            verbose=True,
                            samples_find_deep=True,
                            samples_find_ext='png'
                        )
                        
                        # The key is 'frechet_inception_distance' not 'fid'
                        if 'frechet_inception_distance' in metrics_dict:
                            fid_score = metrics_dict['frechet_inception_distance']
                            print(f"\nOverall FID: {fid_score:.4f}")
                        else:
                            print(f"FID calculation failed, keys in metrics_dict: {list(metrics_dict.keys())}")
                    except Exception as inner_e:
                        print(f"FID calculation error: {str(inner_e)}")
            else:
                print(f"Not enough images for FID calculation (min 8, got {len(all_originals)})")
        
        except Exception as e:
            print(f"Error calculating FID scores: {str(e)}")
            import traceback
            traceback.print_exc()  # Print full traceback for better debugging
    
    # Create a simple fid_scores dict with just the overall score
    fid_scores = {'overall': fid_score}
    
    # Clean up resources
    del pipeline
    torch.cuda.empty_cache()
    
    return results, metrics, fid_scores




def hungarian_huber_loss(x, y, coord_scale=1.):
    """Permutation-invariant Huber loss between two sets of objects.

    For every batch element the objects of ``y`` are paired one-to-one with
    the objects of ``x`` by ``scipy.optimize.linear_sum_assignment``. The
    assignment cost is the smooth-L1 (Huber) loss averaged over the feature
    dimension. The costs of the chosen pairs are summed over objects and then
    averaged over the batch.

    Args:
        x: Tensor of shape [batch_size, n_objs, dim].
        y: Tensor of shape [batch_size, n_objs, dim].
        coord_scale: Accepted for API compatibility; not used.

    Returns:
        Scalar tensor. The assignment is solved on detached values, but the
        returned loss is built from the (non-detached) cost tensor, so
        gradients flow to ``x`` and ``y``.
    """
    n_objs = x.shape[1]
    # cost[b, i, j] = mean smooth-L1 distance between y[b, i] and x[b, j].
    y_rows = torch.unsqueeze(y, -2).expand(-1, -1, n_objs, -1)
    x_cols = torch.unsqueeze(x, -3).expand(-1, n_objs, -1, -1)
    cost = F.smooth_l1_loss(y_rows, x_cols, reduction='none').mean(dim=-1)

    # One (row_ind, col_ind) solution per batch element.
    solutions = [optimize.linear_sum_assignment(c) for c in cost.detach().cpu().numpy()]
    # pairs[b, i] = (row_ind[i], col_ind[i]) -> shape [batch, n_objs, 2].
    pairs = np.stack([np.stack(sol, axis=-1) for sol in solutions])
    pair_index = torch.LongTensor(pairs).to(cost.device)

    # Column 1 of the gathered values is cost[b, i, col_ind[i]].
    matched_cost = torch.gather(cost, dim=-1, index=pair_index)[:, :, 1]
    per_sample = matched_cost.sum(dim=1)
    return per_sample.mean()


def _compute_hungarian_matching(x, y):
    """Helper function to compute Hungarian matching using Huber loss.
    
    Args:
        x: Tensor of shape [batch_size, n_points, dim_points]
        y: Tensor of shape [batch_size, n_points, dim_points]
    """
    n_objs = x.shape[1]
    pairwise_cost = F.smooth_l1_loss(
        torch.unsqueeze(y, -2).expand(-1, -1, n_objs, -1),
        torch.unsqueeze(x, -3).expand(-1, n_objs, -1, -1),
        reduction='none'
    ).mean(dim=-1)
    
    indices = np.array(list(map(scipy.optimize.linear_sum_assignment, pairwise_cost.detach().cpu().numpy())))
    transposed_indices = np.transpose(indices, axes=(0, 2, 1))
    return torch.LongTensor(transposed_indices).to(pairwise_cost.device)

def hungarian_accuracy(preds, targets, slices=None):
    """Score per-property predictions after Hungarian-matching them to targets.

    Predictions are first permuted so that each target object is paired with
    the prediction chosen by ``_compute_hungarian_matching`` (mean Huber cost
    over the full feature vector). Every property is then scored only over
    target objects whose visibility entry is > 0.5.

    Args:
        preds: Tensor of shape [batch_size, n_objects, dim] with per-object
            predictions laid out according to ``slices``.
        targets: Tensor with the same shape as ``preds`` holding ground truth.
        slices: Optional mapping from property name to a ``slice`` over the
            last dimension. It must contain a ``'visibility'`` entry.
            ``'coords'``, if present, is scored by mean squared error. Every
            other entry is treated as a one-hot categorical block and scored
            by argmax accuracy. Defaults to the SlotClassifier layout
            (shape 0:4, size 4:7, material 7:67, coords 67:70,
            visibility 70:71).

    Returns:
        Dict mapping each non-visibility property to its score: accuracy for
        categorical properties, MSE (lower is better) for ``'coords'``. A
        score is 0.0 when no target object is visible. The dict also has a
        ``'mean'`` key: the average over the categorical properties only
        (NaN if there are none).
    """
    if slices is None:
        # Property slices matching SlotClassifier
        slices = {
            'shape': slice(0, 4),
            'size': slice(4, 7),
            'material': slice(7, 67),
            'coords': slice(67, 70),
            'visibility': slice(70, 71),
        }

    # Get matching indices using the same method as hungarian_huber_loss.
    # indices[b, i, 1] is the prediction index assigned to target object i.
    indices = _compute_hungarian_matching(preds, targets)

    # Reorder predictions so that matched_preds[b, i] lines up with targets[b, i].
    matched_preds = torch.gather(
        preds, dim=1,
        index=indices[:, :, 1].unsqueeze(-1).expand(-1, -1, preds.shape[-1]))

    # Only score target objects that are present (visibility > 0.5).
    vis_mask = (targets[..., slices['visibility']] > 0.5).squeeze(-1)

    accuracies = {}
    for prop_name, prop_slice in slices.items():
        if prop_name == 'visibility':
            continue
        pred_block = matched_preds[..., prop_slice]
        true_block = targets[..., prop_slice]
        if prop_name == 'coords':
            # Regression target: per-object mean squared error.
            per_object = F.mse_loss(pred_block, true_block, reduction='none').mean(dim=-1)
        else:
            # Categorical one-hot block: exact argmax agreement.
            per_object = (pred_block.argmax(dim=-1) == true_block.argmax(dim=-1)).float()
        visible = per_object[vis_mask]
        accuracies[prop_name] = visible.mean().item() if visible.numel() > 0 else 0.0

    # 'coords' is an error, not an accuracy, so it is excluded from the mean.
    categorical_props = [p for p in accuracies if p != 'coords']
    accuracies['mean'] = (
        sum(accuracies[p] for p in categorical_props) / len(categorical_props)
        if categorical_props else float('nan'))

    return accuracies

def average_precision_clevrtex(pred, attributes, distance_threshold ,ClevrTexDataset):
  """Computes the average precision for CLEVRTex-style object predictions.

  Adapted from the CLEVR average-precision metric. First, the predictions of
  all images are pooled and sorted by confidence (highest confidence first).
  Then, for each prediction we check whether there was a corresponding object
  in the same input image. A prediction is considered a true positive if:
  - its discrete features (size, material, shape, color) are all predicted
    correctly;
  - its predicted position is within ``distance_threshold`` of that
    ground-truth object; and
  - that object has not already been claimed by a higher-confidence
    prediction.
  Otherwise it counts as a false positive.

  Per-object feature layout read by ``process_targets`` below:
    [0:4] shape one-hot, [4:7] size one-hot, [7:15] color one-hot,
    [15:75] material one-hot, [75:78] coordinates, [78] real-object flag.
  Discrete features are compared via argmax over their block. For ``pred``,
  index 78 is ignored during matching and the last column is used as the
  confidence for sorting.

  Args:
    pred: Tensor of shape [batch_size, num_elements, dimension] containing
      predictions (``.detach().cpu().numpy()`` is applied). The last
      dimension is expected to be the confidence of the prediction. The
      predictions are reshaped using the last dimension of ``attributes``, so
      both inputs must share the same ``dimension``.
    attributes: Tensor of shape [batch_size, num_elements, dimension] containing
      ground-truth object properties. Its last column is summed to obtain the
      number of ground-truth objects used as the recall denominator.
      NOTE(review): this coincides with the index-78 real-object flag used for
      matching only when dimension == 79 — confirm.
    distance_threshold: Threshold to accept match. -1 indicates no threshold.
    ClevrTexDataset: Not used by this function.

  Returns:
    Average precision of the predictions, as computed by
    ``compute_average_precision``.
  """
  pred = pred.detach().cpu().numpy()
  attributes = attributes.detach().cpu().numpy()
  [batch_size, _, element_size] = attributes.shape
  [_, predicted_elements, _] = pred.shape

  def unsorted_id_to_image(detection_id, predicted_elements):
    """Find the index of the image from the unsorted detection index."""
    return int(detection_id // predicted_elements)

  # Pool all predictions of the batch and sort them by confidence (last column).
  flat_size = batch_size * predicted_elements
  flat_pred = np.reshape(pred, [flat_size, element_size])
  sort_idx = np.argsort(flat_pred[:, -1], axis=0)[::-1]  # Reverse order.

  sorted_predictions = np.take_along_axis(
      flat_pred, np.expand_dims(sort_idx, axis=1), axis=0)
  # Maps a position in the sorted order back to the flat (unsorted) index.
  idx_sorted_to_unsorted = np.take_along_axis(
      np.arange(flat_size), sort_idx, axis=0)

  def process_targets(target):
    """Unpacks the target into the CLEVRTex properties."""
    coords = target[75:78]
    object_size = np.argmax(target[4:7])
    material = np.argmax(target[15:75])
    shape = np.argmax(target[:4])
    color = np.argmax(target[7:15])
    real_obj = target[78]
    return coords, object_size, material, shape, color, real_obj

  true_positives = np.zeros(sorted_predictions.shape[0])
  false_positives = np.zeros(sorted_predictions.shape[0])

  # (image index, ground-truth object index) pairs already claimed by a
  # higher-confidence prediction; later hits on them are false positives.
  detection_set = set()

  for detection_id in range(sorted_predictions.shape[0]):
    # Extract the current prediction.
    current_pred = sorted_predictions[detection_id, :]
    # Find which image the prediction belongs to. Get the unsorted index from
    # the sorted one and then apply to unsorted_id_to_image function that undoes
    # the reshape.
    original_image_idx = unsorted_id_to_image(
        idx_sorted_to_unsorted[detection_id], predicted_elements)
    # Get the ground truth image.
    gt_image = attributes[original_image_idx, :, :]

    # Initialize the best (smallest) distance found so far and the id of the
    # ground-truth object that was found.
    best_distance = 10000
    best_id = None

    # Unpack the prediction by taking the argmax on the discrete attributes.
    (pred_coords, pred_object_size, pred_material, pred_shape, pred_color,
     _) = process_targets(current_pred)

    # Loop through all objects in the ground-truth image to check for hits.
    for target_object_id in range(gt_image.shape[0]):
      target_object = gt_image[target_object_id, :]
      # Unpack the targets taking the argmax on the discrete attributes.
      (target_coords, target_object_size, target_material, target_shape,
       target_color, target_real_obj) = process_targets(target_object)
      # Only consider real objects as matches.
      if target_real_obj:
        # For the match to be valid all attributes need to be correctly
        # predicted.
        pred_attr = [pred_object_size, pred_material, pred_shape, pred_color]
        target_attr = [
            target_object_size, target_material, target_shape, target_color]
        match = pred_attr == target_attr
        if match:
          # If a match was found, we check if the distance is below the
          # specified threshold. Recall that we have rescaled the coordinates
          # in the dataset from [-3, 3] to [0, 1], both for `target_coords` and
          # `pred_coords`. To compare in the original scale, we thus need to
          # multiply the distance values by 6 before applying the norm.
          # NOTE(review): this [-3, 3] scale comes from the CLEVR setup;
          # confirm it also holds for the CLEVRTex coordinates.
          distance = np.linalg.norm((target_coords - pred_coords) * 6.)

          # If this is the best match we've found so far we remember it.
          if distance < best_distance:
            best_distance = distance
            best_id = target_object_id
    if best_distance < distance_threshold or distance_threshold == -1:
      # We have detected an object correctly within the distance confidence.
      # If this object was not detected before it's a true positive.
      if best_id is not None:
        if (original_image_idx, best_id) not in detection_set:
          true_positives[detection_id] = 1
          detection_set.add((original_image_idx, best_id))
        else:
          false_positives[detection_id] = 1
      else:
        false_positives[detection_id] = 1
    else:
      false_positives[detection_id] = 1
  accumulated_fp = np.cumsum(false_positives)
  accumulated_tp = np.cumsum(true_positives)
  # NOTE(review): if `attributes` contains no real objects the denominator is
  # 0 and recall becomes nan/inf.
  recall_array = accumulated_tp / np.sum(attributes[:, :, -1])
  # Every detection is either a TP or an FP, so the denominator is >= 1.
  precision_array = np.divide(accumulated_tp, (accumulated_fp + accumulated_tp))

  return compute_average_precision(
      np.array(precision_array, dtype=np.float32),
      np.array(recall_array, dtype=np.float32))


def compute_average_precision(precision, recall):
    """Integrate a precision/recall curve into an average-precision score.

    Args:
        precision: NumPy array of precision values, one per ranked detection.
        recall: NumPy array of recall values, aligned with ``precision``.

    Returns:
        Python float: the area under the interpolated precision envelope.
        The curve is padded with recall 0 and 1 (and precision 0 at both
        ends). Precision is made non-increasing from right to left, and only
        the points where recall changes contribute to the area.
    """
    rec = [0] + recall.tolist() + [1]
    prec = [0] + precision.tolist() + [0]

    # Replace each precision with the maximum precision at any higher recall.
    for k in reversed(range(1, len(prec))):
        prec[k - 1] = max(prec[k - 1], prec[k])

    return sum(
        (prec[k + 1] * (rec[k + 1] - rec[k])
         for k in range(len(rec) - 1)
         if rec[k + 1] != rec[k]),
        0.,
    )


def random_colors(N, randomize=False, rng=42):
    """Return ``N`` visually distinct colours from ``distinctipy.get_colors``.

    Args:
        N: Number of colours to generate.
        randomize: If True, shuffle the palette in place before returning it.
            The shuffle uses the module-level ``random`` state, not ``rng``.
        rng: Seed/RNG forwarded to ``distinctipy.get_colors``.

    Returns:
        The list of colours produced by distinctipy.
    """
    palette = distinctipy.get_colors(N, rng=rng)
    if not randomize:
        return palette
    random.shuffle(palette)
    return palette


class ColorMask(object):
    def __init__(self, num_slots, log_img_size, norm_mean, 
                 norm_std, rng=42, img_tmp_pth=None, reshape_first=False):
        """Set up colours, resizing and un-normalisation for slot visualisations.

        Args:
            num_slots: Number of slots; one colour is generated per slot.
            log_img_size: Target size for ``transforms.Resize`` when images
                are shrunk for logging.
            norm_mean: Per-channel mean used when the images were normalised.
            norm_std: Per-channel std used when the images were normalised.
            rng: Seed forwarded to ``random_colors`` for the slot palette.
            img_tmp_pth: Default output path used by ``log_heatmap``.
            reshape_first: If True, ``get_heatmap`` resizes the attention
                maps to the image size before deriving masks.
        """
        self.img_tmp_pth = img_tmp_pth
        self.num_slots = num_slots
        self.log_img_size = log_img_size
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.reshape_first = reshape_first

        # Fixed (non-shuffled) palette, one colour per slot.
        self.color = torch.tensor(random_colors(num_slots, randomize=False, rng=rng))

        self.log_image_resize = transforms.Resize(
            log_img_size,
            interpolation=transforms.InterpolationMode.BILINEAR,
            antialias=True,
        )

        # Normalize computes (x - mean) / std, so undo it in two steps:
        # first x * std (mean 0, std 1/std), then x + mean (mean -mean, std 1).
        undo_std = transforms.Normalize(mean=[0., 0., 0.],
                                        std=1 / torch.tensor(norm_std))
        undo_mean = transforms.Normalize(mean=-torch.tensor(norm_mean),
                                         std=[1., 1., 1.])
        self.img_unnorm = transforms.Compose([undo_std, undo_mean])

    def apply_colormap_on_image(self, org_im, activation, colormap_name, alpha=0.5):
        """Overlay a colour-mapped activation map on an image.

        Args:
            org_im (PIL.Image): Original image; converted to RGBA for
                compositing.
            activation (numpy.ndarray): 2-D activation map with the same
                height/width as ``org_im`` (required by ``alpha_composite``).
                It is passed straight to the matplotlib colormap, so floats are
                read on a [0, 1] scale and integers as lookup-table indices.
            colormap_name (str): Name of a registered matplotlib colormap.
            alpha (float): Opacity of the heatmap layer in the overlay.

        Returns:
            Tuple ``(no_trans_heatmap, heatmap_on_image)`` of PIL images: the
            RGBA heatmap with the colormap's own alpha, and that heatmap at
            opacity ``alpha`` composited over ``org_im``.
        """
        # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and
        # removed in 3.9; ``pyplot.get_cmap`` works on old and new releases.
        color_map = plt.get_cmap(colormap_name)
        no_trans_heatmap = color_map(activation)
        # Change alpha on a copy so the untouched heatmap can be returned too.
        heatmap = no_trans_heatmap.copy()
        heatmap[:, :, 3] = alpha
        heatmap = Image.fromarray((heatmap * 255).astype(np.uint8))
        no_trans_heatmap = Image.fromarray((no_trans_heatmap * 255).astype(np.uint8))

        # Composite: transparent canvas <- original image <- translucent heatmap.
        heatmap_on_image = Image.new("RGBA", org_im.size)
        heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
        heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
        return no_trans_heatmap, heatmap_on_image

    def _apply_mask(self, image, mask, alpha=0.5, color=None):

        B, C, H, W = image.size()
        B, N, H, W = mask.size()

        image = image.clone()
        mask_only = torch.ones_like(image)

        if color is None:
            color = random_colors(N)

        for n in range(N):
            for c in range(3):
                image[..., c, :, :] = torch.where(
                    mask[:, n] == 1,
                    image[..., c, :, :] * (1 - alpha) + alpha * (color[n][c] if isinstance(color, list) else
                                                                 color[..., n, c][..., None, None]),
                    image[..., c, :, :]
                )
                mask_only[..., c, :, :] = torch.where(
                    mask[:, n] == 1,
                    mask_only[..., c, :, :] * (1 - 0.99) + 0.99 * (color[n][c] if isinstance(color, list) else
                                                                   color[..., n, c][..., None, None]),
                    mask_only[..., c, :, :]
                )
        return image, mask_only

    def get_heatmap(self, img, attn, recon=None, mask_pred_sorted=None, return_all=False):
        """Build a visualisation grid of images, slot masks and attention heatmaps.

        All inputs are moved to the CPU; ``img``, ``attn`` and ``recon`` are
        also cast to float32, while ``mask_pred_sorted`` keeps its dtype.

        :param img: b, c, h, w image batch, normalised with ``norm_mean`` /
            ``norm_std`` (it is passed through ``self.img_unnorm``).
        :param attn: b, s, h', w' per-slot attention maps; resized to the image
            resolution where needed.
        :param recon: optional tensor or list of tensors, normalised like
            ``img``; they are prepended to each grid row.
        :param mask_pred_sorted: optional b, s, h, w masks used instead of the
            argmax-over-slots masks derived from ``attn``.
        :param return_all: if True, also return the intermediate tensors.
        :return: ``grid_image`` (output of ``make_grid``). Each grid row is one
            batch element: [recons..., image, mask overlay, one heatmap per
            slot]. With ``return_all``, the tuple
            ``(grid_image, img_overlay, color_mask, heatmap_on_image)``.
        """
        img = img.to(torch.device('cpu'), dtype=torch.float32)
        attn = attn.to(torch.device('cpu'), dtype=torch.float32)
        if recon is not None:
            if not isinstance(recon, list):
                recon = [recon]
            recon = [r.to(torch.device('cpu'), dtype=torch.float32) for r in recon]
        if mask_pred_sorted is not None:
            mask_pred_sorted = mask_pred_sorted.to(torch.device('cpu'))
        bs, inp_channel, h, w = img.size()

        # Undo the dataset normalisation so the images display in [0, 1].
        img = self.img_unnorm(img).clamp(0., 1.)
        if recon is not None:
            recon = [self.img_unnorm(r).clamp(0., 1.) for r in recon]

        # Shrink large images to the logging resolution; h, w track the new size.
        if h > self.log_img_size:
            img = self.log_image_resize(img)
            h, w = img.shape[-2:]

        num_s = attn.size(1)
        # --------------------------------------------------------------------------
        # reshape first to get nicer visualization
        if self.reshape_first and (attn.shape[-2] != img.shape[-2] or attn.shape[-1] != img.shape[-1]):
            attn = transforms.Resize(size=img.shape[-2:], interpolation=transforms.InterpolationMode.BILINEAR,
                                     antialias=True)(attn)

        # --------------------------------------------------------------------------
        # --------------------------------------------------------------------------
        # get color map
        # Hard one-hot masks: 1 where slot k has the largest attention at a pixel.
        if mask_pred_sorted is None:
            mask_pred = (attn.argmax(1, keepdim=True) == torch.arange(attn.size(1))[None, :, None, None]).float()
        else:
            mask_pred = mask_pred_sorted

        # Nearest-neighbour resize keeps the masks binary.
        if mask_pred.shape[-2] != img.shape[-2] or mask_pred.shape[-1] != img.shape[-1]:
            mask_pred = transforms.Resize(size=img.shape[-2:], interpolation=transforms.InterpolationMode.NEAREST)(
                mask_pred)

        # b c h w
        img_overlay, color_mask = self._apply_mask(img, mask_pred, alpha=0.5, color=self.color)

        # --------------------------------------------------------------------------

        if attn.shape[-2] != img.shape[-2] or attn.shape[-1] != img.shape[-1]:
            attn = transforms.Resize(size=img.shape[-2:], interpolation=transforms.InterpolationMode.BILINEAR,
                                     antialias=True)(attn)

        # Lay out every (batch, slot) map as one row of h*w pixels, and repeat the
        # image the same way, so a single colormap/composite call covers all slots.
        # NOTE(review): attention values are passed to the 'gray' colormap as
        # floats, presumably in [0, 1] — confirm.
        attn = rearrange(attn, 'b s h w -> b s (h w)')

        attn = rearrange(attn, 'b s h_w -> (b s) h_w').detach().numpy()

        img_reshape = repeat(img, 'b c h w -> c (b s) (h w) ', s=num_s)

        img_pil = transforms.ToPILImage()(img_reshape)

        no_trans_heatmap, heatmap_on_image = self.apply_colormap_on_image(img_pil, attn, 'gray')

        heatmap_on_image = transforms.ToTensor()(heatmap_on_image.convert('RGB'))

        # Undo the row layout: back to one heatmap image per (batch, slot).
        heatmap_on_image = rearrange(heatmap_on_image, 'c (b s) (h w) -> b s c h w', b=bs, c=inp_channel, h=h, w=w)

        grid_image = torch.cat([img[:, None], img_overlay[:, None], heatmap_on_image], dim=1)
        if recon is not None:
            # NOTE(review): recon is already a list at this point, so this check is redundant.
            if not isinstance(recon, list):
                recon = [recon]
            # NOTE(review): if the image was not resized (h <= log_img_size), a recon
            # whose height differs from h is resized to log_img_size, which may
            # still not equal h — confirm callers never hit this case.
            recon = [self.log_image_resize(r) if r.shape[-2] != h else r for r in recon]
            grid_image = torch.cat([*[r[:, None] for r in recon if r is not None], grid_image], dim=1)
        # One grid row per batch element (nrow = number of tiles per element).
        grid_image = make_grid(rearrange(grid_image, 'b n c h w -> (b n) c h w'),
                               nrow=grid_image.size(1), padding=1,
                               pad_value=0.8)
        if return_all:
            return grid_image, img_overlay, color_mask, heatmap_on_image
        return grid_image

    def log_heatmap(self, img, attn, recon=None, mask_pred_sorted=None, path=None):
        assert path is not None or self.img_tmp_pth is not None, 'path is None and img_tmp_pth is None'

        grid_image = self.get_heatmap(img, attn, recon, mask_pred_sorted)
        # save_image(grid_image, self.img_tmp_pth)
        ndarr = grid_image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        im = Image.fromarray(ndarr)
        img_path = path if path is not None else self.img_tmp_pth
        im.save(img_path, optimize=True, quality=95)
        return img_path