from typing import List, Optional, Union
import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils import deprecate

from typing import List, Optional, Union, Callable, Dict, Any
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline, ImagePipelineOutput
from diffusers.utils import deprecate
from src.models.set_utils import hungarian_accuracy, _compute_hungarian_matching

class UnconditionalDiffusionPipeline(DiffusionPipeline):
    """Unconditional latent-diffusion sampling pipeline.

    Draws Gaussian latents, denoises them with ``unet`` following the
    ``scheduler`` timesteps, and decodes the result to image space with ``vae``.
    """

    def __init__(self, vae, unet, scheduler):
        """Register the sub-modules used by ``__call__``.

        Args:
            vae: Autoencoder; ``decode``, ``config.latent_channels`` and
                ``config.scaling_factor`` are used.
            unet: Denoising network, called as ``unet(sample, t).sample``.
            scheduler: Scheduler providing ``set_timesteps``,
                ``scale_model_input`` and ``step``.
        """
        super().__init__()
        self.register_modules(vae=vae, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        """Generate a batch of images from noise.

        Args:
            batch_size: Number of samples; ``None`` is treated as 1.
            height: Image height; defaults to ``unet.config.sample_size``.
            width: Image width; defaults to ``unet.config.sample_size``.
            num_inference_steps: Number of scheduler steps.
            generator: A single generator, or a list with one generator per
                sample, used to draw the initial latents.
            latents: Optional pre-drawn latents of shape
                ``(batch_size, latent_channels, height // 8, width // 8)``.
            output_type: ``"pil"`` converts the result to PIL images; any other
                value returns the ``[0, 1]``-clamped tensor.
            return_dict: If ``False`` return a 1-tuple instead of an
                ``ImagePipelineOutput``.
            **kwargs: Accepted and ignored.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.

        Raises:
            ValueError: If ``height``/``width`` are not divisible by 8, if
                ``latents`` has the wrong shape, or if a generator list does
                not have ``batch_size`` entries.
        """
        # 0. Default height and width to the UNet configuration.
        # NOTE(review): the latent grid below is height // 8, so this default
        # assumes sample_size is expressed in image pixels -- confirm.
        height = height or self.unet.config.sample_size
        width = width or self.unet.config.sample_size

        # 1. Check inputs
        self.check_inputs(height, width)

        # 2. Define call parameters
        batch_size = 1 if batch_size is None else batch_size

        # 3. Prepare latent variables
        latents_shape = (batch_size, self.vae.config.latent_channels, height // 8, width // 8)
        if latents is None:
            latents = self._sample_latents(latents_shape, generator)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Incorrect latents shape. Expected {latents_shape}, got {latents.shape}")
            latents = latents.to(device=self.device, dtype=self.unet.dtype)

        # 4. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # 5. Denoising loop (purely unconditional: no guidance is applied)
        for t in self.progress_bar(timesteps):
            # 5.1 Let the scheduler rescale the input if it needs to
            latent_model_input = self.scheduler.scale_model_input(latents, t)

            # 5.2 Predict the noise residual
            noise_pred = self.unet(latent_model_input, t).sample

            # 5.3 Compute previous sample: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # 6. Decode and map from [-1, 1] to [0, 1]
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)

        # 7. Convert to PIL
        if output_type == "pil":
            image = self._to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def _sample_latents(self, shape, generator):
        """Draw initial Gaussian latents, supporting one generator per sample."""
        if isinstance(generator, (list, tuple)):
            if len(generator) != shape[0]:
                raise ValueError(
                    f"Got {len(generator)} generators for a batch of {shape[0]} samples."
                )
            single = (1,) + tuple(shape[1:])
            return torch.cat(
                [
                    torch.randn(single, generator=g, device=self.device, dtype=self.unet.dtype)
                    for g in generator
                ],
                dim=0,
            )
        return torch.randn(shape, generator=generator, device=self.device, dtype=self.unet.dtype)

    def _to_pil(self, image):
        """Convert a ``[0, 1]`` image tensor (N, C, H, W) to a list of PIL images.

        ``__init__`` does not create an ``image_processor``; one is only used
        if it has been attached to the pipeline externally. Otherwise the
        conversion goes through ``DiffusionPipeline.numpy_to_pil``.
        """
        processor = getattr(self, "image_processor", None)
        if processor is not None:
            # NOTE(review): `image` is already in [0, 1]; confirm the attached
            # processor is configured not to denormalize a second time.
            return processor.postprocess(image, output_type="pil")
        array = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return self.numpy_to_pil(array)

    def check_inputs(self, height, width):
        """Raise ``ValueError`` unless both dimensions are divisible by 8."""
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        

class UnconditionalReconstructionPipeline(DiffusionPipeline):
    """Noise-and-denoise reconstruction with optional classifier guidance.

    Images are encoded to latents, noised to ``target_step`` and denoised back.
    When a classifier is supplied, each denoising step first nudges the latents
    along the negative gradient of a classification loss evaluated on the
    decoded estimate of the clean image.
    """

    def __init__(self, vae, unet, scheduler):
        """Register the sub-modules.

        Args:
            vae: Autoencoder used for ``encode`` / ``decode``.
            unet: Denoising network, called as ``unet(sample, t).sample``.
            scheduler: Scheduler providing ``set_timesteps``, ``add_noise``,
                ``scale_model_input``, ``step`` and ``alphas_cumprod``.
        """
        super().__init__()
        self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
        # Kept for backward compatibility; nothing in this class reads them.
        self._alpha_coefficients = {}
        self._beta_coefficients = {}

    @torch.no_grad()
    def __call__(
        self,
        images: torch.FloatTensor,
        reconstruction_steps: int = 100,
        target_step: Optional[int] = 500,
        guidance_scale: float = 1.0,
        classifier: Optional[Callable] = None,
        target_labels: Optional[torch.Tensor] = None,
        attribute_mask: Optional[torch.Tensor] = None,
        dataset_type: str = "celeba",
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Reconstruct ``images``, optionally steering them with a classifier.

        Args:
            images: Batch of input images (presumably scaled to ``[-1, 1]``,
                mirroring the decode step -- confirm against the caller).
            reconstruction_steps: Number of scheduler timesteps to set.
            target_step: Timestep to which the latents are noised. ``None``
                selects ``scheduler.timesteps[0]``, the first scheduled step.
            guidance_scale: Step size of the latent gradient update; guidance
                is skipped when it is not positive.
            classifier: Callable applied to decoded images. Enables guidance.
            target_labels: Targets for the classifier loss; required whenever
                ``classifier`` is given. First dimension must match the batch.
            attribute_mask: Optional mask restricting the loss to selected
                attributes.
            dataset_type: ``"clevrtex"`` (case-insensitive) selects the
                Hungarian-matched set loss; anything else the BCE loss.
            generator: Accepted for API compatibility; currently unused.
            output_type: ``"pil"`` converts the result to PIL images; any other
                value returns the ``[0, 1]``-clamped tensor.
            return_dict: If ``False`` return a 1-tuple instead of an
                ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.

        Raises:
            ValueError: If ``classifier`` is given without ``target_labels`` or
                the label batch size differs from the image batch size.
        """
        # 0. Validate inputs
        if classifier is not None and target_labels is None:
            raise ValueError("If classifier is provided, target_labels must also be provided")

        if target_labels is not None and target_labels.shape[0] != images.shape[0]:
            raise ValueError("Number of target labels must match batch size")

        # 1. Encode images into (scaled) latent space
        images = images.to(device=self.device, dtype=self.vae.dtype)
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # 2. Set timesteps
        self.scheduler.set_timesteps(reconstruction_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # Without an explicit target, start from the first scheduled timestep.
        if target_step is None:
            target_step = timesteps[0]
        # int() accepts both Python numbers and 0-d tensors.
        start_timestep = int(target_step)
        noise_timesteps = torch.full(
            (images.shape[0],), start_timestep, dtype=torch.long, device=self.device
        )

        # 3. Add noise to reach the target timestep
        noise = torch.randn_like(latents)
        latents = self.scheduler.add_noise(latents, noise, noise_timesteps)

        # 4. Pick the guidance routine once
        guide = None
        if classifier is not None and guidance_scale > 0:
            if dataset_type.lower() == "clevrtex":
                guide = self._clevrtex_classifier_guidance
            else:  # Default to CelebA
                guide = self._simplified_classifier_guidance

        # 5. Denoising loop; timesteps noisier than the start are skipped
        for t in self.progress_bar(timesteps):
            if t > start_timestep:
                continue

            if guide is not None:
                latents, noise_pred = guide(
                    latents, t, classifier, target_labels, guidance_scale, attribute_mask
                )
            else:
                noise_pred = self.unet(self.scheduler.scale_model_input(latents, t), t).sample
            # Compute previous sample: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # 6. Decode latents and map from [-1, 1] to [0, 1]
        images = self.vae.decode(latents / self.vae.config.scaling_factor).sample
        images = (images / 2 + 0.5).clamp(0, 1)

        # 7. Convert to PIL if needed
        if output_type == "pil":
            images = self._to_pil(images)

        if not return_dict:
            return (images,)

        return ImagePipelineOutput(images=images)

    def _to_pil(self, images):
        """Convert a ``[0, 1]`` image tensor (N, C, H, W) to a list of PIL images.

        ``__init__`` does not create an ``image_processor``; one is only used
        if it has been attached to the pipeline externally. Otherwise the
        conversion goes through ``DiffusionPipeline.numpy_to_pil``.
        """
        processor = getattr(self, "image_processor", None)
        if processor is not None:
            # NOTE(review): `images` is already in [0, 1]; confirm the attached
            # processor is configured not to denormalize a second time.
            return processor.postprocess(images, output_type="pil")
        array = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return self.numpy_to_pil(array)

    def _predict_original_latents(self, latents, noise_pred, t):
        """Estimate x_0 as ``(x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)``.

        Assumes the UNet output is a noise (epsilon) prediction. Computed
        directly from ``scheduler.alphas_cumprod`` rather than through
        ``scheduler.step`` so that the scheduler is stepped exactly once per
        timestep by the main loop.
        """
        alpha_prod = self.scheduler.alphas_cumprod[t]
        sqrt_alpha = (alpha_prod ** 0.5).flatten().to(device=latents.device, dtype=latents.dtype)
        sqrt_one_minus_alpha = ((1 - alpha_prod) ** 0.5).flatten().to(
            device=latents.device, dtype=latents.dtype
        )
        # Append trailing singleton dimensions so the factors broadcast.
        while sqrt_alpha.dim() < latents.dim():
            sqrt_alpha = sqrt_alpha.unsqueeze(-1)
            sqrt_one_minus_alpha = sqrt_one_minus_alpha.unsqueeze(-1)
        return (latents - sqrt_one_minus_alpha * noise_pred) / sqrt_alpha

    @staticmethod
    def _attribute_loss(clf_output, target_labels, attribute_mask=None):
        """Binary cross-entropy on classifier logits, optionally masked.

        ``clf_output`` may be a tensor of logits or an object exposing
        ``.logits``. With a mask, the summed loss is divided by the mask sum.
        """
        logits = clf_output.logits if hasattr(clf_output, "logits") else clf_output
        if attribute_mask is None:
            return F.binary_cross_entropy_with_logits(logits, target_labels)
        summed = F.binary_cross_entropy_with_logits(
            logits * attribute_mask, target_labels * attribute_mask, reduction="sum"
        )
        # Small epsilon avoids division by zero for an all-zero mask.
        return summed / (attribute_mask.sum() + 1e-8)

    @staticmethod
    def _clevrtex_loss(pred_vectors, target_labels, attribute_mask=None):
        """Set-prediction loss after Hungarian matching.

        Property layout of the last dimension: shape 0-3, size 4-6,
        material 7-66, coordinates 67-69, visibility 70.

        Returns a scalar tensor, or the int ``0`` when a mask is given but
        selects nothing.
        """
        categorical_groups = (slice(0, 4), slice(4, 7), slice(7, 67))
        coords = slice(67, 70)
        visibility = 70

        indices = _compute_hungarian_matching(pred_vectors, target_labels)
        # Reorder predictions with column 1 of the matching.
        # NOTE(review): presumably the matched prediction index per target
        # slot -- confirm against set_utils.
        gather_index = indices[:, :, 1].unsqueeze(-1).expand(-1, -1, pred_vectors.shape[-1])
        matched = torch.gather(pred_vectors, dim=1, index=gather_index)
        batch_size = pred_vectors.shape[0]

        if attribute_mask is None:
            # All properties contribute.
            cat_loss = 0
            for i in range(batch_size):
                for group in categorical_groups:
                    cat_loss += F.cross_entropy(
                        matched[i, :, group].reshape(-1, group.stop - group.start),
                        target_labels[i, :, group].argmax(-1),
                    )
            coord_loss = F.mse_loss(matched[:, :, coords], target_labels[:, :, coords])
            vis_loss = F.binary_cross_entropy_with_logits(
                matched[:, :, visibility], target_labels[:, :, visibility]
            )
            return cat_loss + coord_loss + vis_loss

        loss = 0

        # Categorical groups: a group counts for an object if any of its
        # entries is selected by the mask.
        if attribute_mask[:, :, :67].sum() > 0:
            for i in range(batch_size):
                for obj_idx in range(matched.shape[1]):
                    for group in categorical_groups:
                        if attribute_mask[i, obj_idx, group].sum() > 0:
                            loss += F.cross_entropy(
                                matched[i, obj_idx, group].unsqueeze(0),
                                target_labels[i, obj_idx, group].argmax(-1).unsqueeze(0),
                            )

        # Coordinates: masked MSE normalized by the number of selected entries.
        coord_mask = attribute_mask[:, :, coords]
        if coord_mask.sum() > 0:
            loss += F.mse_loss(
                matched[:, :, coords] * coord_mask,
                target_labels[:, :, coords] * coord_mask,
                reduction="sum",
            ) / (coord_mask.sum() + 1e-8)

        # Visibility: masked BCE normalized the same way.
        if attribute_mask[:, :, visibility].sum() > 0:
            vis_mask = attribute_mask[:, :, visibility].unsqueeze(-1)
            loss += F.binary_cross_entropy_with_logits(
                matched[:, :, visibility].unsqueeze(-1) * vis_mask,
                target_labels[:, :, visibility].unsqueeze(-1) * vis_mask,
                reduction="sum",
            ) / (vis_mask.sum() + 1e-8)

        return loss

    def _guided_update(self, latents, t, classifier, guidance_scale, loss_fn):
        """Take one gradient step on the latents and re-predict the noise.

        ``loss_fn`` maps the classifier output to a scalar loss. Returns the
        updated, detached latents and the noise prediction for them.
        """
        latents = latents.detach().clone().requires_grad_(True)

        with torch.enable_grad():
            # Scaling happens inside enable_grad so that schedulers which
            # really rescale the input keep the UNet path differentiable.
            model_input = self.scheduler.scale_model_input(latents, t)
            noise_pred = self.unet(model_input, t).sample
            pred_orig_latents = self._predict_original_latents(latents, noise_pred, t)
            decoded = self.vae.decode(pred_orig_latents / self.vae.config.scaling_factor).sample
            loss = loss_fn(classifier(decoded))
            # A non-tensor loss means the mask selected nothing: no update.
            grad = torch.autograd.grad(loss, latents)[0] if torch.is_tensor(loss) else None

        if grad is not None:
            latents = latents - guidance_scale * grad.detach()
        latents = latents.detach()

        # Recompute the noise prediction for the updated latents.
        with torch.no_grad():
            noise_pred = self.unet(self.scheduler.scale_model_input(latents, t), t).sample

        return latents, noise_pred.detach()

    def _simplified_classifier_guidance(self, latents, t, classifier, target_labels, guidance_scale, attribute_mask=None):
        """Guidance step for multi-label attribute classifiers (CelebA).

        Returns ``(latents, noise_pred)``, both detached.
        """
        def loss_fn(clf_output):
            return self._attribute_loss(clf_output, target_labels, attribute_mask)

        return self._guided_update(latents, t, classifier, guidance_scale, loss_fn)

    def _clevrtex_classifier_guidance(self, latents, t, classifier, target_labels, guidance_scale, attribute_mask=None):
        """Guidance step for CLEVRTex set predictors with Hungarian matching.

        Returns ``(latents, noise_pred)``, both detached.
        """
        def loss_fn(pred_vectors):
            return self._clevrtex_loss(pred_vectors, target_labels, attribute_mask)

        return self._guided_update(latents, t, classifier, guidance_scale, loss_fn)


class SlotConditionedReconstructionPipeline(DiffusionPipeline):
    """Slot-conditioned noise-and-denoise reconstruction with optional guidance.

    Images are encoded to latents, noised to ``target_step`` and denoised back
    with the UNet conditioned on ``prompt_embeds``. When a classifier is
    supplied, each step first nudges the latents (and optionally the
    conditioning embeddings) along the negative gradient of a classification
    loss evaluated on the decoded estimate of the clean image.
    """

    def __init__(self, vae, unet, scheduler):
        """Register the sub-modules.

        Args:
            vae: Autoencoder used for ``encode`` / ``decode``.
            unet: Conditional denoising network, called as
                ``unet(sample, t, encoder_hidden_states=cond).sample``.
            scheduler: Scheduler providing ``set_timesteps``, ``add_noise``,
                ``scale_model_input``, ``step`` and ``alphas_cumprod``.
        """
        super().__init__()
        self.register_modules(vae=vae, unet=unet, scheduler=scheduler)
        # Caches are cleared at the end of __call__; nothing here fills them.
        self._alpha_coefficients = {}
        self._beta_coefficients = {}

    @torch.no_grad()
    def __call__(
        self,
        images: torch.FloatTensor,
        prompt_embeds: torch.FloatTensor,
        reconstruction_steps: int = 100,
        target_step: Optional[int] = None,
        guidance_scale: float = 1.0,
        conditioning_scale: float = 1.0,
        classifier: Optional[Callable] = None,
        target_labels: Optional[torch.Tensor] = None,
        attribute_mask: Optional[torch.Tensor] = None,
        dataset_type: str = "celeba",
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Reconstruct ``images`` under slot conditioning.

        Args:
            images: Batch of input images (presumably scaled to ``[-1, 1]``,
                mirroring the decode step -- confirm against the caller).
            prompt_embeds: Conditioning passed to the UNet as
                ``encoder_hidden_states``. The caller's tensor is not modified.
            reconstruction_steps: Number of scheduler timesteps to set.
            target_step: Timestep to which the latents are noised. ``None``
                selects ``scheduler.timesteps[0]``, the first scheduled step.
            guidance_scale: Step size of the latent gradient update; guidance
                is skipped when it is not positive.
            conditioning_scale: Step size of the gradient update applied to the
                conditioning; the conditioning is left as is when not positive.
            classifier: Callable applied to decoded images. Enables guidance.
            target_labels: Targets for the classifier loss; required whenever
                ``classifier`` is given. First dimension must match the batch.
            attribute_mask: Optional mask restricting the loss to selected
                attributes.
            dataset_type: ``"clevrtex"`` (case-insensitive) selects the
                Hungarian-matched set loss; anything else the BCE loss.
            generator: Accepted for API compatibility; currently unused.
            output_type: ``"pil"`` converts the result to PIL images; any other
                value returns the ``[0, 1]``-clamped tensor.
            return_dict: If ``False`` return a 1-tuple instead of an
                ``ImagePipelineOutput``.

        Returns:
            ``ImagePipelineOutput(images=...)`` or ``(images,)``.

        Raises:
            ValueError: If ``classifier`` is given without ``target_labels`` or
                the label batch size differs from the image batch size.
        """
        # 0. Validate inputs
        if classifier is not None and target_labels is None:
            raise ValueError("If classifier is provided, target_labels must also be provided")

        if target_labels is not None and target_labels.shape[0] != images.shape[0]:
            raise ValueError("Number of target labels must match batch size")

        # 1. Encode images into (scaled) latent space
        images = images.to(device=self.device, dtype=self.vae.dtype)
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * self.vae.config.scaling_factor

        # 2. Set timesteps
        self.scheduler.set_timesteps(reconstruction_steps, device=self.device)
        timesteps = self.scheduler.timesteps

        # Without an explicit target, start from the first scheduled timestep.
        if target_step is None:
            target_step = timesteps[0]
        # int() accepts both Python numbers and 0-d tensors.
        start_timestep = int(target_step)
        noise_timesteps = torch.full(
            (images.shape[0],), start_timestep, dtype=torch.long, device=self.device
        )

        # 3. Add noise to reach the target timestep
        noise = torch.randn_like(latents)
        latents = self.scheduler.add_noise(latents, noise, noise_timesteps)

        # 4. Work on a private copy of the conditioning; guidance may update it
        cond = prompt_embeds.clone().detach().to(device=self.device, dtype=self.unet.dtype)

        # 5. Pick the guidance routine once
        guide = None
        if classifier is not None and guidance_scale > 0:
            if dataset_type.lower() == "clevrtex":
                guide = self._clevrtex_classifier_guidance
            else:  # Default to CelebA
                guide = self._simplified_classifier_guidance

        # 6. Denoising loop; timesteps noisier than the start are skipped
        for t in self.progress_bar(timesteps):
            if t > start_timestep:
                continue

            if guide is not None:
                latents, cond, noise_pred = guide(
                    latents, t, classifier, target_labels, guidance_scale,
                    cond, conditioning_scale, attribute_mask
                )
            else:
                noise_pred = self._predict_noise(latents, t, cond)
            # Compute previous sample: x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Clear cached coefficients
        self._alpha_coefficients.clear()
        self._beta_coefficients.clear()

        # 7. Decode latents and map from [-1, 1] to [0, 1]
        images = self.vae.decode(latents / self.vae.config.scaling_factor).sample
        images = (images / 2 + 0.5).clamp(0, 1)

        # 8. Convert to PIL if needed
        if output_type == "pil":
            images = self._to_pil(images)

        if not return_dict:
            return (images,)

        return ImagePipelineOutput(images=images)

    def _to_pil(self, images):
        """Convert a ``[0, 1]`` image tensor (N, C, H, W) to a list of PIL images.

        ``__init__`` does not create an ``image_processor``; one is only used
        if it has been attached to the pipeline externally. Otherwise the
        conversion goes through ``DiffusionPipeline.numpy_to_pil``.
        """
        processor = getattr(self, "image_processor", None)
        if processor is not None:
            # NOTE(review): `images` is already in [0, 1]; confirm the attached
            # processor is configured not to denormalize a second time.
            return processor.postprocess(images, output_type="pil")
        array = images.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return self.numpy_to_pil(array)

    def _predict_noise(self, latents, t, cond=None):
        """Run the UNet on scheduler-scaled latents, conditioned if ``cond`` is set."""
        model_input = self.scheduler.scale_model_input(latents, t)
        if cond is None:
            return self.unet(model_input, t).sample
        return self.unet(model_input, t, encoder_hidden_states=cond).sample

    def _predict_original_latents(self, latents, noise_pred, t):
        """Estimate x_0 as ``(x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)``.

        Assumes the UNet output is a noise (epsilon) prediction. Computed
        directly from ``scheduler.alphas_cumprod`` rather than through
        ``scheduler.step`` so that the scheduler is stepped exactly once per
        timestep by the main loop.
        """
        alpha_prod = self.scheduler.alphas_cumprod[t]
        sqrt_alpha = (alpha_prod ** 0.5).flatten().to(device=latents.device, dtype=latents.dtype)
        sqrt_one_minus_alpha = ((1 - alpha_prod) ** 0.5).flatten().to(
            device=latents.device, dtype=latents.dtype
        )
        # Append trailing singleton dimensions so the factors broadcast.
        while sqrt_alpha.dim() < latents.dim():
            sqrt_alpha = sqrt_alpha.unsqueeze(-1)
            sqrt_one_minus_alpha = sqrt_one_minus_alpha.unsqueeze(-1)
        return (latents - sqrt_one_minus_alpha * noise_pred) / sqrt_alpha

    @staticmethod
    def _attribute_loss(clf_output, target_labels, attribute_mask=None):
        """Binary cross-entropy on classifier logits, optionally masked.

        ``clf_output`` may be a tensor of logits or an object exposing
        ``.logits``. With a mask, the summed loss is divided by the mask sum.
        """
        logits = clf_output.logits if hasattr(clf_output, "logits") else clf_output
        if attribute_mask is None:
            return F.binary_cross_entropy_with_logits(logits, target_labels)
        summed = F.binary_cross_entropy_with_logits(
            logits * attribute_mask, target_labels * attribute_mask, reduction="sum"
        )
        # Small epsilon avoids division by zero for an all-zero mask.
        return summed / (attribute_mask.sum() + 1e-8)

    @staticmethod
    def _clevrtex_loss(pred_vectors, target_labels, attribute_mask=None):
        """Set-prediction loss after Hungarian matching.

        Property layout of the last dimension: shape 0-3, size 4-6,
        material 7-66, coordinates 67-69, visibility 70.

        Returns a scalar tensor, or the int ``0`` when a mask is given but
        selects nothing.
        """
        categorical_groups = (slice(0, 4), slice(4, 7), slice(7, 67))
        coords = slice(67, 70)
        visibility = 70

        indices = _compute_hungarian_matching(pred_vectors, target_labels)
        # Reorder predictions with column 1 of the matching.
        # NOTE(review): presumably the matched prediction index per target
        # slot -- confirm against set_utils.
        gather_index = indices[:, :, 1].unsqueeze(-1).expand(-1, -1, pred_vectors.shape[-1])
        matched = torch.gather(pred_vectors, dim=1, index=gather_index)
        batch_size = pred_vectors.shape[0]

        if attribute_mask is None:
            # All properties contribute.
            cat_loss = 0
            for i in range(batch_size):
                for group in categorical_groups:
                    cat_loss += F.cross_entropy(
                        matched[i, :, group].reshape(-1, group.stop - group.start),
                        target_labels[i, :, group].argmax(-1),
                    )
            coord_loss = F.mse_loss(matched[:, :, coords], target_labels[:, :, coords])
            vis_loss = F.binary_cross_entropy_with_logits(
                matched[:, :, visibility], target_labels[:, :, visibility]
            )
            return cat_loss + coord_loss + vis_loss

        loss = 0

        # Categorical groups: a group counts for an object if any of its
        # entries is selected by the mask.
        if attribute_mask[:, :, :67].sum() > 0:
            for i in range(batch_size):
                for obj_idx in range(matched.shape[1]):
                    for group in categorical_groups:
                        if attribute_mask[i, obj_idx, group].sum() > 0:
                            loss += F.cross_entropy(
                                matched[i, obj_idx, group].unsqueeze(0),
                                target_labels[i, obj_idx, group].argmax(-1).unsqueeze(0),
                            )

        # Coordinates: masked MSE normalized by the number of selected entries.
        coord_mask = attribute_mask[:, :, coords]
        if coord_mask.sum() > 0:
            loss += F.mse_loss(
                matched[:, :, coords] * coord_mask,
                target_labels[:, :, coords] * coord_mask,
                reduction="sum",
            ) / (coord_mask.sum() + 1e-8)

        # Visibility: masked BCE normalized the same way.
        if attribute_mask[:, :, visibility].sum() > 0:
            vis_mask = attribute_mask[:, :, visibility].unsqueeze(-1)
            loss += F.binary_cross_entropy_with_logits(
                matched[:, :, visibility].unsqueeze(-1) * vis_mask,
                target_labels[:, :, visibility].unsqueeze(-1) * vis_mask,
                reduction="sum",
            ) / (vis_mask.sum() + 1e-8)

        return loss

    def _guided_update(self, latents, t, classifier, guidance_scale, cond, conditioning_scale, loss_fn):
        """Take one gradient step on latents (and conditioning) and re-predict noise.

        ``loss_fn`` maps the classifier output to a scalar loss. The
        conditioning is only differentiated and updated when it is not ``None``
        and ``conditioning_scale`` is positive.

        Returns ``(latents, cond, noise_pred)``; tensors are detached and
        ``cond`` is ``None`` if it was ``None`` on entry.
        """
        latents = latents.detach().clone().requires_grad_(True)
        update_cond = cond is not None and conditioning_scale > 0
        cond_input = cond.detach().clone().requires_grad_(True) if update_cond else cond

        with torch.enable_grad():
            # Scaling (inside _predict_noise) happens under enable_grad so that
            # schedulers which really rescale keep the UNet path differentiable.
            noise_pred = self._predict_noise(latents, t, cond_input)
            pred_orig_latents = self._predict_original_latents(latents, noise_pred, t)
            decoded = self.vae.decode(pred_orig_latents / self.vae.config.scaling_factor).sample
            loss = loss_fn(classifier(decoded))
            # A non-tensor loss means the mask selected nothing: no update.
            grads = None
            if torch.is_tensor(loss):
                wrt = [latents, cond_input] if update_cond else [latents]
                grads = torch.autograd.grad(loss, wrt)

        if grads is not None:
            latents = latents - guidance_scale * grads[0].detach()
            if update_cond:
                cond = cond - conditioning_scale * grads[1].detach()
        latents = latents.detach()
        if cond is not None:
            cond = cond.detach()

        # Recompute the noise prediction for the updated inputs.
        with torch.no_grad():
            noise_pred = self._predict_noise(latents, t, cond)

        return latents, cond, noise_pred.detach()

    def _simplified_classifier_guidance(self, latents, t, classifier, target_labels, guidance_scale,
                                        cond=None, conditioning_scale=1.0, attribute_mask=None):
        """Guidance step for multi-label attribute classifiers (CelebA).

        Returns ``(latents, cond, noise_pred)``.
        """
        def loss_fn(clf_output):
            return self._attribute_loss(clf_output, target_labels, attribute_mask)

        return self._guided_update(
            latents, t, classifier, guidance_scale, cond, conditioning_scale, loss_fn
        )

    def _clevrtex_classifier_guidance(self, latents, t, classifier, target_labels, guidance_scale,
                                      cond=None, conditioning_scale=1.0, attribute_mask=None):
        """Guidance step for CLEVRTex set predictors with Hungarian matching.

        Returns ``(latents, cond, noise_pred)``.
        """
        def loss_fn(pred_vectors):
            return self._clevrtex_loss(pred_vectors, target_labels, attribute_mask)

        return self._guided_update(
            latents, t, classifier, guidance_scale, cond, conditioning_scale, loss_fn
        )
