import json
import logging
import os
import pickle
import sys

import numpy as np
import torch
import tqdm
from PIL import Image

from distrl.models.diffusion_models.base import DiffusionModelBase

# Module-level logger. Methods of EDM2DiffusionModel call `logger.warning` and
# `logger.info`, so the name must exist at module scope.
logger = logging.getLogger(__name__)

# Debug switches read once at import time. DISTRL_PRINT_UNETPARAMETER is tested
# for truthiness in EDM2DiffusionModel.__call__ (any non-empty value enables it).
DISTRL_PRINT_UNETPARAMETER = os.environ.get('DISTRL_PRINT_UNETPARAMETER', None)
DISTRL_DEBUG_FULLBUFFER = os.environ.get('DISTRL_DEBUG_FULLBUFFER', None)


# Put the vendored EDM2 source tree (`./original`, next to this file) on
# sys.path so `dnnlib` (and presumably other EDM2 modules referenced by pickled
# checkpoints) can be imported.
script_dir = os.path.dirname(os.path.abspath(__file__))
original_dir = os.path.join(script_dir, 'original')
if original_dir not in sys.path and os.path.isdir(original_dir):
    sys.path.insert(0, original_dir)
    print(f"Added {original_dir} to Python path")

# Import dnnlib if available; otherwise register an empty stub module under the
# same name so later lookups (e.g. the custom unpickler in from_pretrained)
# still find a `dnnlib` module object.
try:
    import dnnlib
except ImportError:
    print("dnnlib module not found, creating stub module")
    import types
    dnnlib = types.ModuleType('dnnlib')
    sys.modules['dnnlib'] = dnnlib

def print_model_params(model, model_name) -> str:
    """Build a human-readable summary of a model's parameter tensors.

    Despite the name, nothing is printed; the caller decides where the text
    goes.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        model_name: Label used in the summary header.

    Returns:
        A multi-line string listing the first and last (up to) five parameter
        tensors as ``idx | shape | mean | std``, followed by mean/std/min/max
        over all parameter values. Returns ``""`` if the model has no
        parameters.
    """
    params = list(model.parameters())
    total_params = len(params)
    if total_params == 0:
        return ""

    def describe(start, chunk):
        # One line per tensor: global index | shape | mean | std.
        return "\n".join(
            f"  {start + i}: {p.shape} | {p.mean().item():.6f} | {p.std().item():.6f}"
            for i, p in enumerate(chunk)
        )

    # Pure statistics: no_grad/detach keep torch.cat from building an autograd
    # graph over every trainable parameter.
    with torch.no_grad():
        first_params = params[:5]
        last_params = params[-5:]
        first_params_str = describe(0, first_params)
        # Offset from the real slice length so indices stay correct when the
        # model has fewer than five parameter tensors.
        last_params_str = describe(total_params - len(last_params), last_params)

        all_params = torch.cat([p.detach().flatten() for p in params])
        mean_all = all_params.mean().item()
        std_all = all_params.std().item()
        min_all = all_params.min().item()
        max_all = all_params.max().item()

    output_str = f"{model_name} Parameter Summary:\n"
    output_str += f"Total parameters: {total_params}\n"
    output_str += f"First 5 parameters (idx | shape | mean | std):\n{first_params_str}\n"
    output_str += f"Last 5 parameters (idx | shape | mean | std):\n{last_params_str}\n"
    output_str += f"All parameters - Mean: {mean_all:.6f}, Std: {std_all:.6f}, Min: {min_all:.6f}, Max: {max_all:.6f}"
    return output_str


class EDM2DiffusionModel(DiffusionModelBase):
    """
    Implementation of EDM2 model for AutoCFG framework.
    Adapts EDM2 to work with the same interface as SD15DiffusionModel.
    """

    def __init__(self, net, gnet=None):
        """
        Wrap an EDM2 network (plus optional guidance network) in the
        SD15DiffusionModel-style interface.

        Args:
            net: EDM2 network; stored as ``self.unet``.
            gnet: Optional guidance network. When omitted, ``net`` itself is
                used, so ``self.gnet is self.unet``.

        Side effects:
            Loads ``dataset/imagenet1k/simple.json`` (four directories above
            this file) into ``self.imagenet_classes``; if the file is missing a
            message is printed and the list is left empty. Also sets
            ``in_channels``/``out_channels`` on ``net`` from its
            ``img_channels`` attribute.
        """
        self.unet = net
        self.gnet = net if gnet is None else gnet
        # Callers written for diffusers pipelines reach components via `.pipe`.
        self.pipe = self

        class_map_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            '..', '..', '..', '..', 'dataset', 'imagenet1k', 'simple.json',
        )
        try:
            with open(class_map_path, 'r') as handle:
                self.imagenet_classes = json.load(handle)
        except FileNotFoundError:
            print(f"Could not find ImageNet class mapping at {class_map_path}")
            self.imagenet_classes = []

        # Filled in by from_pretrained.
        self.encoder = None

        # Expose diffusers-style channel attributes on the EDM2 network.
        channels = self.unet.img_channels
        self.unet.in_channels = channels
        self.unet.out_channels = channels

    def to(self, device, dtype=None):
        """
        ``nn.Module``-style alias for ``to_device``.

        Args:
            device: Target device.
            dtype: Optional dtype to cast the networks' parameters to.

        Returns:
            ``self``, so calls can be chained.
        """
        self.to_device(device, dtype=dtype)
        return self

    @classmethod
    def from_pretrained(cls, pretrained_model_path, revision=None,
                        weight_dtype=torch.float32, retry_on_error=False,
                        encoder_batch_size=None, verbose=False, **kwargs):
        """
        Create a model from a pickled EDM2 checkpoint.

        The unpickled object must support ``data['ema']`` (the network to use)
        and ``data.get('encoder')``. When no encoder is stored, a
        ``StandardRGBEncoder`` from the vendored EDM2 code is created instead.

        SECURITY: the checkpoint is read with ``pickle``, which can execute
        arbitrary code during loading. Only load checkpoints from trusted
        sources.

        Args:
            pretrained_model_path: Path to the pickled EDM2 checkpoint.
            revision: Ignored for EDM2.
            weight_dtype: dtype the ``'ema'`` network is cast to.
            retry_on_error: Not supported; a notice is printed when True.
            encoder_batch_size: Currently unused.
            verbose: Currently unused.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            EDM2DiffusionModel instance with ``encoder`` initialized.

        Raises:
            Any error from opening/unpickling the file or from
            ``data['ema'].to(...)`` (e.g. FileNotFoundError, KeyError) is
            printed and re-raised unchanged.
        """
        if retry_on_error:
            print("retry_on_error is not supported for EDM2 model")

        print(f"Loading EDM2 model from {pretrained_model_path}")

        class CustomUnpickler(pickle.Unpickler):
            # Resolve top-level `dnnlib.<name>` through the (possibly stub)
            # module imported at file scope; names it lacks degrade to `object`
            # instead of aborting the load.
            def find_class(self, module, name):
                if module == 'dnnlib':
                    return getattr(dnnlib, name, object)
                return super().find_class(module, name)

        try:
            with open(pretrained_model_path, 'rb') as f:
                data = CustomUnpickler(f).load()
            net = data['ema'].to(weight_dtype)
        except Exception as e:
            print(f"Error loading model: {e}")
            raise  # bare raise preserves the original traceback

        encoder = data.get('encoder', None)
        if encoder is None:
            print("No encoder found, creating a default one")
            # Import StandardRGBEncoder from EDM2's training module
            from .original.training.encoders import StandardRGBEncoder
            encoder = StandardRGBEncoder()

        model = cls(net=net)

        # Device for the encoder: the network's own `device` attribute if any,
        # else the device of its first parameter, else CUDA. A network with no
        # parameters falls through to CUDA instead of raising StopIteration.
        if hasattr(net, 'device'):
            device = net.device
        else:
            first_param = next(net.parameters(), None) if hasattr(net, 'parameters') else None
            device = first_param.device if first_param is not None else 'cuda'

        encoder.init(device)
        model.encoder = encoder
        return model

    @classmethod
    def from_sft(cls, sft_path, revision=None,
                weight_dtype=torch.float32, retry_on_error=False, **kwargs):
        """
        Load a supervised fine-tuned (SFT) EDM2 checkpoint.

        SFT checkpoints use the same pickle format as pretrained ones, so this
        forwards every argument unchanged to ``from_pretrained``.

        Args:
            sft_path: Path to the SFT checkpoint.
            revision: Ignored for EDM2.
            weight_dtype: dtype the network is cast to.
            retry_on_error: Forwarded (not supported by EDM2 loading).
            **kwargs: Forwarded to ``from_pretrained``.

        Returns:
            EDM2DiffusionModel instance.
        """
        return cls.from_pretrained(
            sft_path,
            revision=revision,
            weight_dtype=weight_dtype,
            retry_on_error=retry_on_error,
            **kwargs,
        )

    def to_device(self, device, dtype=None):
        """
        Move the main network to ``device`` (optionally casting to ``dtype``).

        The guidance network is also moved when it is a separate object.

        Args:
            device: Target device (e.g. 'cuda', 'cpu').
            dtype: Optional dtype; when None only the device changes.

        Returns:
            ``self``, for chaining.
        """
        cast_kwargs = {} if dtype is None else {"dtype": dtype}
        networks = [self.unet] if self.gnet is self.unet else [self.unet, self.gnet]
        for network in networks:
            network.to(device, **cast_kwargs)
        return self

    def set_gradient(self, enable=True):
        """
        Toggle training of the main network.

        Sets ``requires_grad`` on every main-network parameter to ``enable`` and
        puts it in train mode (``enable``) or eval mode (otherwise). A separate
        guidance network is always frozen and kept in eval mode.

        Args:
            enable: Whether the main network should receive gradients.

        Returns:
            ``self``, for chaining.
        """
        has_separate_guide = self.gnet is not self.unet

        for weight in self.unet.parameters():
            weight.requires_grad_(enable)
        if has_separate_guide:
            # Guidance network stays frozen regardless of `enable`.
            for weight in self.gnet.parameters():
                weight.requires_grad_(False)

        if enable:
            self.unet.train()
        else:
            self.unet.eval()
        if has_separate_guide:
            self.gnet.eval()

        return self

    def enable_xformers(self):
        """
        No-op stand-in for the diffusers ``enable_xformers`` hook.

        This EDM2 wrapper has no xformers integration, so it only logs a
        warning.

        Returns:
            False, signalling that xformers was not enabled.
        """
        message = "xformers optimization not supported for EDM2 model"
        logger.warning(message)
        return False

    def get_components(self):
        """
        Expose components under SD15DiffusionModel's key names.

        Returns:
            Dict with keys ``pipe`` (this object), ``unet`` (main network),
            ``vae`` (the EDM2 encoder, possibly None), and ``text_encoder`` /
            ``tokenizer`` (always None; EDM2 is class-conditional).
        """
        components = dict(
            pipe=self,
            unet=self.unet,
            vae=self.encoder,
            text_encoder=None,
            tokenizer=None,
        )
        return components

    def set_scheduler(self, scheduler_type):
        """
        Accept and ignore a scheduler, for SD15DiffusionModel interface parity.

        Sampling in this class uses its own built-in noise schedule, so the
        argument is discarded and a warning is logged.

        Args:
            scheduler_type: Scheduler class or instance; unused.

        Returns:
            ``self``, for chaining.
        """
        del scheduler_type  # intentionally unused
        logger.warning("EDM2 uses its own sampling algorithm and doesn't require an external scheduler")
        return self

    def _extract_class_from_prompt(self, prompt):
        """
        Map a prompt such as "an image of X" to an ImageNet class index.

        The class name is ``prompt`` with a leading "an image of " removed (or
        the whole prompt otherwise). It is looked up by exact match in
        ``self.imagenet_classes``. An unknown name prints a message and falls
        back to a uniformly random class index drawn from NumPy's global RNG.

        Args:
            prompt: A prompt string, or a list of prompt strings.

        Returns:
            An int class index for a string, or a 1-D ``torch.Tensor`` of
            indices for a list.

        Raises:
            ValueError: if the name is unknown and no class mapping is loaded
                (``self.imagenet_classes`` is empty), so there is no class to
                fall back to.
        """
        if isinstance(prompt, list):
            return torch.tensor([self._extract_class_from_prompt(p) for p in prompt])

        prefix = "an image of "
        class_name = prompt[len(prefix):] if prompt.startswith(prefix) else prompt

        try:
            return self.imagenet_classes.index(class_name)
        except ValueError:
            pass

        # Without a class map the random fallback would call randint(0, 0),
        # which fails with an unhelpful "low >= high" error; say why instead.
        if not self.imagenet_classes:
            raise ValueError(
                f"Class '{class_name}' cannot be resolved: no ImageNet class mapping is loaded"
            )

        print(f"Class '{class_name}' not found in ImageNet classes. Using random class.")
        return np.random.randint(0, len(self.imagenet_classes))

    def decode_latents(self, latents):
        """
        Decode sampler output through the attached EDM2 encoder.

        Args:
            latents: Final samples produced by the sampling loop.

        Returns:
            Whatever ``self.encoder.decode(latents)`` returns. Its value range
            and dtype depend on the encoder in use.

        Raises:
            ValueError: if no encoder has been attached yet.
        """
        if self.encoder is None:
            raise ValueError("Encoder is not initialized")
        return self.encoder.decode(latents)

    def numpy_to_pil(self, images):
        """
        Convert decoded image tensors to PIL images.

        Despite the name (kept for parity with diffusers pipelines), the input
        is a ``torch.Tensor`` in channel-first layout. It is either a batch
        ``[batch, channels, height, width]`` or a single image
        ``[channels, height, width]``. Each image is detached, moved to CPU and
        transposed to ``[height, width, channels]`` before
        ``PIL.Image.fromarray``.

        NOTE(review): the tensor presumably comes from ``decode_latents`` and
        must already have a dtype/range PIL accepts (typically uint8 in
        [0, 255]); confirm against the encoder in use.

        Args:
            images: 4-D batch or 3-D single-image tensor, channel-first.

        Returns:
            List of PIL images (length 1 for a 3-D input).
        """
        batch = images if images.ndim == 4 else images.unsqueeze(0)
        pil_images = []
        for image in batch:
            array = image.detach().permute(1, 2, 0).cpu().numpy()
            # PIL cannot build an image from an (H, W, 1) array; drop the
            # singleton channel so single-channel outputs become grayscale.
            if array.shape[-1] == 1:
                array = array[..., 0]
            pil_images.append(Image.fromarray(array))
        return pil_images

    def _get_unet_attr(self, attr_name):
        """
        Read ``attr_name`` from the main network, looking through a
        ``.module`` wrapper (e.g. DistributedDataParallel) when the attribute
        is not on the outer object.

        Raises:
            AttributeError: if neither the network nor its ``.module`` has it.
        """
        target = self.unet
        if not hasattr(target, attr_name):
            if not hasattr(target, "module"):
                raise AttributeError(f"UNet does not have attribute {attr_name}")
            target = target.module
        return getattr(target, attr_name)

    def forward_ddim_stepwise_cfg_scales_with_latents(
        self,
        prompt = None,
        height = None,
        width = None,
        num_inference_steps = 50,
        guidance_scale_base = 7.5,
        guidance_scale_bias = None,
        negative_prompt = None,
        num_images_per_prompt = 1,
        eta = 1.0,
        generator = None,
        latents = None,
        prompt_embeds = None,
        negative_prompt_embeds = None,
        output_type = None,
        return_dict = True,
        callback = None,
        callback_steps = 1,
        cross_attention_kwargs = None,
        is_ddp = False,
    ):
        """
        Sample with the EDM Euler sampler, recording per-step guidance scales
        and intermediate latents. Called by AutoCFG; mirrors the interface of
        StableDiffusionPipelineExtended.

        Prompts are mapped to ImageNet class indices and fed to the network as
        one-hot labels.

        NOTE: classifier-free guidance is not applied yet (see the TODO in the
        loop). Each step's guidance scale is computed and returned in
        ``guidance_scales_used``, but the conditional network output is used
        as-is.

        Args:
            prompt: Prompt string or list of prompts ("an image of X"); None
                means a single empty prompt.
            height: Ignored; the network's ``img_resolution`` is always used.
            width: Ignored; see ``height``.
            num_inference_steps: Number of Euler steps.
            guidance_scale_base: Base guidance scale.
            guidance_scale_bias: Optional per-sample, per-step adjustment indexed
                as ``[:, step]``, with one row per generated sample.
            negative_prompt: Ignored.
            num_images_per_prompt: Number of images per prompt.
            eta: Unused (compatibility).
            generator: Optional ``torch.Generator`` for the initial noise.
            latents: Optional initial noise with batch size
                ``len(prompt) * num_images_per_prompt``; moved to the model
                device.
            prompt_embeds: Ignored.
            negative_prompt_embeds: Ignored.
            output_type: "latent" returns raw samples, "pil" returns PIL images
                via ``numpy_to_pil``, anything else returns ``decode_latents``
                output.
            return_dict: Ignored.
            callback: Optional ``callback(step, t_cur, x_next)``, invoked every
                ``callback_steps`` steps.
            callback_steps: Callback frequency.
            cross_attention_kwargs: Ignored.
            is_ddp: Ignored.

        Returns:
            Tuple ``(image, guidance_scales_used, latents_list,
            unconditional_embeds, conditional_embeds)``:
            - ``guidance_scales_used`` is a list of per-step CPU tensors of shape
              (batch, 1, 1, 1).
            - ``latents_list`` holds the initial noise followed by the state
              after each step, on CPU.
            - The embeds are the all-zero and one-hot label tensors, on CPU.

        Raises:
            ValueError: if ``latents`` has the wrong batch size.
        """
        device = next(self.unet.parameters()).device

        # Normalize prompts to a list.
        if isinstance(prompt, str):
            prompt = [prompt]
        elif prompt is None:
            logger.warning("No prompt provided, using empty prompt")
            prompt = [""]

        batch_size = len(prompt)

        # EDM2 always samples at its native resolution.
        resolution = self._get_unet_attr("img_resolution")
        height = height or resolution
        width = width or resolution

        # Prompts -> one-hot class labels.
        label_indices = [self._extract_class_from_prompt(p) for p in prompt]
        labels = torch.zeros((batch_size, self._get_unet_attr("label_dim")), device=device)
        for i, idx in enumerate(label_indices):
            labels[i, idx] = 1.0

        # Replicate prompts and labels in the same order ([a, b, a, b, ...]).
        if num_images_per_prompt > 1:
            prompt = prompt * num_images_per_prompt
            labels = labels.repeat(num_images_per_prompt, 1)
            batch_size = len(prompt)

        unconditional_embeds = torch.zeros_like(labels)
        conditional_embeds = labels

        # EDM noise-schedule constants.
        sigma_min = 0.002
        sigma_max = 80
        rho = 7

        if latents is None:
            if generator is None:
                seed = torch.randint(0, 2**32 - 1, (1,)).item()
                generator = torch.Generator(device=device).manual_seed(int(seed))
            latents = torch.randn(
                (batch_size, self._get_unet_attr("img_channels"), resolution, resolution),
                device=device,
                generator=generator,
            )
        else:
            if latents.shape[0] != batch_size:
                raise ValueError(f"Latents batch size ({latents.shape[0]}) must match prompt batch size ({batch_size})")
            latents = latents.to(device)

        # rho-spaced noise levels from sigma_max down to sigma_min, then a final
        # t_N = 0. max(..., 1) avoids a 0/0 (NaN) schedule for a single step.
        step_indices = torch.arange(num_inference_steps, device=device)
        denom = max(num_inference_steps - 1, 1)
        t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])

        guidance_scales_used = []
        latents_list = [latents.detach().cpu()]

        # Start from unit noise scaled to the largest noise level.
        x_next = latents * t_steps[0]

        with tqdm.tqdm(total=num_inference_steps) as progress_bar:
            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
                # One guidance scale per sample, shape (batch, 1, 1, 1). The
                # reshape also rejects a bias whose batch size does not match.
                if guidance_scale_bias is not None:
                    current_guidance_scale = (
                        guidance_scale_base + guidance_scale_bias[:, i].view(-1, 1, 1, 1)
                    ).reshape(batch_size, 1, 1, 1)
                else:
                    current_guidance_scale = torch.full((batch_size, 1, 1, 1), guidance_scale_base, device=device)
                guidance_scales_used.append(current_guidance_scale.detach().cpu())

                x_cur = x_next
                t_tensor = torch.full([x_cur.shape[0]], t_cur, device=device)
                net_out = self.unet(x_cur, t_tensor, labels)
                if net_out.ndim == 4:
                    # TODO: add CFG, i.e. gnet_out + scale * (net_out - gnet_out).
                    # Until then the guidance network is not evaluated here,
                    # because its output would be discarded.
                    denoised = net_out
                else:
                    # Fallback for non-[B, C, H, W] outputs.
                    gnet_out = self.gnet(x_cur, t_tensor, None)  # unconditional
                    denoised = gnet_out.lerp(net_out, current_guidance_scale.mean().item())

                # First-order Euler step (the Heun 2nd-order correction is not used).
                d_cur = (x_cur - denoised) / t_cur
                x_next = x_cur + (t_next - t_cur) * d_cur

                latents_list.append(x_next.detach().cpu())
                progress_bar.update(1)

                if callback is not None and i % callback_steps == 0:
                    callback(i, t_cur, x_next)

        samples = x_next
        if output_type == "latent":
            image = samples
        elif output_type == "pil":
            image = self.numpy_to_pil(self.decode_latents(samples))
        else:
            image = self.decode_latents(samples)

        return (
            image,
            guidance_scales_used,
            latents_list,
            unconditional_embeds.detach().cpu(),
            conditional_embeds.detach().cpu()
        )

    def __call__(
        self,
        prompt = None,
        height = None,
        width = None,
        num_inference_steps = 50,
        guidance_scale = 7.5,
        negative_prompt = None,
        num_images_per_prompt = 1,
        generator = None,
        latents = None,
        output_type = None,
        return_dict = None
    ):
        """
        Generate images with the EDM Euler sampler (diffusers-style entry point).

        A lighter counterpart of
        ``forward_ddim_stepwise_cfg_scales_with_latents``: no intermediate
        latents or guidance scales are recorded.

        Args:
            prompt: Prompt string or list of prompts ("an image of X"); None
                means a single empty prompt.
            height: Ignored; the network's ``img_resolution`` is always used.
            width: Ignored; see ``height``.
            num_inference_steps: Number of Euler steps.
            guidance_scale: Float, 0-D/1-D tensor, list, tuple or ndarray. A 1-D
                value needs one entry per prompt. Classifier-free guidance is
                not implemented, so every entry must equal 1.0.
            negative_prompt: Ignored.
            num_images_per_prompt: Number of images per prompt.
            generator: Optional ``torch.Generator`` for the initial noise.
            latents: Optional initial noise with batch size
                ``len(prompt) * num_images_per_prompt``; moved to the model
                device.
            output_type: "latent" returns raw samples, "pil" returns PIL images,
                anything else returns ``decode_latents`` output.
            return_dict: Ignored.

        Returns:
            Object whose ``images`` attribute holds the generated output.

        Raises:
            ValueError: on guidance-scale / latents batch-size mismatches or a
                guidance-scale tensor with more than one dimension.
            NotImplementedError: if any guidance scale differs from 1.0.
        """
        device = next(self.unet.parameters()).device

        def _is_rank_zero():
            # Prefer torch.distributed when a process group exists. Otherwise
            # fall back to launcher rank variables (torchrun: RANK/LOCAL_RANK,
            # SLURM: SLURM_PROCID).
            try:
                if torch.distributed.is_available() and torch.distributed.is_initialized():
                    return torch.distributed.get_rank() == 0
            except Exception:  # best-effort: a debug print must never break sampling
                pass
            return all(os.environ.get(var, '0') == '0' for var in ('RANK', 'LOCAL_RANK', 'SLURM_PROCID'))

        # Optional one-time parameter dump (rank 0 only) for debugging.
        if DISTRL_PRINT_UNETPARAMETER and _is_rank_zero() and not getattr(self.unet, 'printed', False):
            logger.info(print_model_params(self.unet, "UNet"))
            self.unet.printed = True

        if isinstance(prompt, str):
            prompt = [prompt]
        elif prompt is None:
            logger.warning("No prompt provided, using empty prompt")
            prompt = [""]

        batch_size = len(prompt)

        # EDM2 always samples at its native resolution.
        resolution = self._get_unet_attr("img_resolution")
        height = height or resolution
        width = width or resolution

        # Prompts -> one-hot class labels.
        label_indices = [self._extract_class_from_prompt(p) for p in prompt]
        labels = torch.zeros((batch_size, self._get_unet_attr("label_dim")), device=device)
        for i, idx in enumerate(label_indices):
            labels[i, idx] = 1.0

        # Normalize guidance_scale to shape (batch, 1, 1, 1).
        if isinstance(guidance_scale, (list, tuple, np.ndarray)):
            guidance_scale = torch.tensor(guidance_scale, device=device)

        if isinstance(guidance_scale, torch.Tensor):
            if guidance_scale.ndim == 0:
                guidance_scale = guidance_scale.view(1, 1, 1, 1).expand(batch_size, 1, 1, 1)
            elif guidance_scale.ndim == 1:
                if len(guidance_scale) != batch_size:
                    raise ValueError(f"Guidance scale tensor length ({len(guidance_scale)}) must match batch size ({batch_size})")
                guidance_scale = guidance_scale.view(batch_size, 1, 1, 1)
            else:
                raise ValueError(f"Guidance scale tensor must be 0D or 1D, got {guidance_scale.ndim}D")
        else:
            guidance_scale = torch.full((batch_size, 1, 1, 1), guidance_scale, device=device)

        if num_images_per_prompt > 1:
            prompt = prompt * num_images_per_prompt
            labels = labels.repeat(num_images_per_prompt, 1)
            guidance_scale = guidance_scale.repeat(num_images_per_prompt, 1, 1, 1)
            batch_size = len(prompt)

        if latents is None:
            if generator is None:
                seed = torch.randint(0, 2**32 - 1, (1,)).item()
                generator = torch.Generator(device=device).manual_seed(int(seed))
            latents = torch.randn(
                (batch_size, self._get_unet_attr("img_channels"), resolution, resolution),
                device=device,
                generator=generator,
            )
        else:
            if latents.shape[0] != batch_size:
                raise ValueError(f"Latents batch size ({latents.shape[0]}) must match prompt batch size ({batch_size})")
            latents = latents.to(device)

        # CFG is not implemented. Check once before sampling rather than inside
        # every step, which also avoids a per-step host/device sync.
        if (guidance_scale != 1.0).any():
            raise NotImplementedError("EDM2 does not CFG at the moment")

        # EDM noise-schedule constants.
        sigma_min = 0.002
        sigma_max = 80
        rho = 7

        # rho-spaced noise levels from sigma_max down to sigma_min, then a final
        # t_N = 0. max(..., 1) avoids a 0/0 (NaN) schedule for a single step.
        step_indices = torch.arange(num_inference_steps, device=device)
        denom = max(num_inference_steps - 1, 1)
        t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])

        x_next = latents * t_steps[0]
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            x_cur = x_next
            t_tensor = torch.full([x_cur.shape[0]], t_cur, device=device)
            denoised = self.unet(x_cur, t_tensor, labels)
            # First-order Euler step (the Heun 2nd-order correction is not used).
            d_cur = (x_cur - denoised) / t_cur
            x_next = x_cur + (t_next - t_cur) * d_cur

        samples = x_next
        if output_type == "latent":
            images = samples
        elif output_type == "pil":
            images = self.numpy_to_pil(self.decode_latents(samples))
        else:
            images = self.decode_latents(samples)

        # Minimal stand-in for a diffusers pipeline output.
        class SimpleOutput:
            def __init__(self, images):
                self.images = images

        return SimpleOutput(images)

    def generate_noise_latents(self, batch_size, generator=None):
        """
        Draw standard-normal noise shaped like the network input.

        Args:
            batch_size: Number of latent samples to generate.
            generator: Optional ``torch.Generator`` for reproducibility. The
                noise is created on CPU, so this must be a CPU generator.
                Defaults to None (global RNG), matching the previous behavior.

        Returns:
            Tensor of shape (batch_size, img_channels, img_resolution,
            img_resolution) on CPU.
        """
        resolution = self._get_unet_attr("img_resolution")
        channels = self._get_unet_attr("img_channels")
        return torch.randn((batch_size, channels, resolution, resolution), generator=generator)

    def forward_collect_traj_ddim(
        self,
        prompt = None,
        height = None,
        width = None,
        num_inference_steps = 50,
        guidance_scale = 1,
        negative_prompt = None,
        num_images_per_prompt = 1,
        eta = 1.0,
        generator = None,
        latents = None,
        prompt_embeds = None,
        negative_prompt_embeds = None,
        output_type = None,
        return_dict = True,
        callback = None,
        callback_steps = 1,
        cross_attention_kwargs = None,
        is_ddp = False,
        unet_copy=None,
        soft_reward=False,
    ):
        """
        Run EDM sampling with a first-order (Euler) solver and collect the trajectory.

        Follows the input/output format of
        StableDiffusionPipelineExtended.forward_collect_traj_ddim. Text prompts are
        mapped to one-hot class labels via ``self._extract_class_from_prompt``.

        Args:
            prompt: Text prompt or list of prompts (in EDM2, should be like "an image of X").
                ``None`` falls back to a single empty prompt.
            height: Image height (ignored, uses model's default)
            width: Image width (ignored, uses model's default)
            num_inference_steps: Number of denoising steps (must be >= 1)
            guidance_scale: Guidance scale; values > 1 blend ``self.gnet`` with ``self.unet``,
                otherwise only ``self.unet`` is evaluated
            negative_prompt: Negative prompt (ignored for EDM2)
            num_images_per_prompt: Number of images to generate per prompt
            eta: Unused parameter (for compatibility)
            generator: Random generator used when ``latents`` is None
            latents: Pre-generated unit-variance noise; moved to the model device
            prompt_embeds: Pre-generated text embeddings (ignored for EDM2)
            negative_prompt_embeds: Pre-generated negative text embeddings (ignored for EDM2)
            output_type: Output format ("pil", "latent", or None for tensor)
            return_dict: Whether to return a dict (ignored)
            callback: Optional callback, invoked as ``callback(i, t_cur, x_next)``
            callback_steps: Invoke ``callback`` every ``callback_steps`` steps
            cross_attention_kwargs: Kwargs for cross attention (ignored)
            is_ddp: Whether using DistributedDataParallel (unused here)
            unet_copy: Optional reference model (unused here)
            soft_reward: Must be False

        Returns:
            Tuple ``(image, latents_list, unconditional_embeds, conditional_embeds,
            log_prob_list, kl_path)``:
              - image: decoded samples, or the raw final state when output_type == "latent".
              - latents_list: ``num_inference_steps + 1`` CPU tensors, i.e. the scaled
                initial noise followed by the state after every step.
              - unconditional_embeds: all-zero tensor shaped like the labels (CPU).
              - conditional_embeds: one-hot class labels (CPU).
              - log_prob_list: one CPU tensor per step. NOTE: these are currently the
                denoiser outputs D(x_t, t) used for the step, NOT log-probabilities
                (placeholder -- a proper transition log-prob is still a FIXME).
              - kl_path: always None.

        Raises:
            ValueError: If ``num_inference_steps`` < 1.
        """
        assert soft_reward is False
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        device = next(self.unet.parameters()).device

        # Process input prompts
        if isinstance(prompt, str):
            prompt = [prompt]
        elif prompt is None:
            logger.warning("No prompt provided, using empty prompt")
            prompt = [""]

        batch_size = len(prompt)

        # EDM2 always samples at the model's native resolution; height/width are ignored.
        resolution = self._get_unet_attr("img_resolution")

        # Convert prompts to one-hot class labels for EDM2
        label_indices = [self._extract_class_from_prompt(p) for p in prompt]
        labels = torch.zeros((batch_size, self._get_unet_attr("label_dim")), device=device)
        for i, idx in enumerate(label_indices):
            labels[i, idx] = 1.0

        # Handle multiple images per prompt (prompt list and label rows are tiled the same way)
        if num_images_per_prompt > 1:
            prompt = prompt * num_images_per_prompt
            labels = labels.repeat(num_images_per_prompt, 1)
            batch_size = len(prompt)

        unconditional_embeds = torch.zeros_like(labels)  # Match shape with labels
        conditional_embeds = labels  # Use actual class labels instead of dummy ones

        # EDM noise-level schedule parameters
        sigma_min = 0.002
        sigma_max = 80
        rho = 7

        # Generate initial noise if not provided
        if latents is None:
            if generator is None:
                seed = torch.randint(0, 2**32 - 1, (1,)).item()
                generator = torch.Generator(device=device).manual_seed(int(seed))

            latents = torch.randn(
                (batch_size, self._get_unet_attr("img_channels"), resolution, resolution),
                device=device,
                generator=generator,
            )
        else:
            # The network and t_steps live on `device`; avoid a device mismatch.
            latents = latents.to(device)

        # EDM time-step discretization. max(..., 1) avoids 0/0 (NaN sigmas) when a
        # single step is requested; for >= 2 steps this is the standard formula.
        step_indices = torch.arange(num_inference_steps, device=device)
        denom = max(num_inference_steps - 1, 1)
        t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

        def denoise(x, t_val):
            """Return the (optionally guided) denoiser output D(x, t_val)."""
            t_tensor = torch.full([x.shape[0]], t_val, device=device)
            if guidance_scale > 1.0:
                # Guidance: self.gnet is called without labels, self.unet with labels.
                gnet_out = self.gnet(x, t_tensor, None)
                net_out = self.unet(x, t_tensor, labels)
                return gnet_out + guidance_scale * (net_out - gnet_out)
            return self.unet(x, t_tensor, labels)

        log_prob_list = []

        with tqdm.tqdm(total=num_inference_steps) as progress_bar:
            x_next = latents * t_steps[0]
            latents_list = [x_next.detach().clone().cpu()]

            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
                x_cur = x_next

                # Single denoiser evaluation per step, reused for the Euler update and
                # for the stored per-step output (previously evaluated twice).
                denoised = denoise(x_cur, t_cur)

                # First-order Euler step; EDM's 2nd-order (Heun) correction is disabled.
                d_cur = (x_cur - denoised) / t_cur
                x_next = x_cur + (t_next - t_cur) * d_cur

                # FIXME: placeholder -- stores D(x_t, t) rather than a transition log-prob.
                log_prob_list.append(denoised.detach().cpu())
                latents_list.append(x_next.detach().clone().cpu())

                progress_bar.update(1)

                if callback is not None and i % callback_steps == 0:
                    callback(i, t_cur, x_next)

        samples = x_next

        # Process result based on output_type
        if output_type == "latent":
            image = samples
        elif output_type == "pil":
            image = self.decode_latents(samples)
            image = self.numpy_to_pil(image)
        else:
            image = self.decode_latents(samples)

        kl_path = None

        return (
            image,
            latents_list,
            unconditional_embeds.detach().cpu(),
            conditional_embeds.detach().cpu(),
            log_prob_list,
            kl_path,
        )

    def forward_calculate_logprob(
        self,
        prompt_embeds,
        latents,
        next_latents,
        ts,
        unet_copy=None,
        height = None,
        width = None,
        num_inference_steps = 50,
        guidance_scale = 7.5,
        negative_prompt = None,
        num_images_per_prompt = 1,
        eta = 1.0,
        generator = None,
        negative_prompt_embeds = None,
        output_type = "pil",
        return_dict = True,
        callback = None,
        callback_steps = 1,
        cross_attention_kwargs = None,
        is_ddp = False,
        soft_reward=False,
    ):
        """
        Score a (latents -> next_latents) transition under the EDM2 denoiser.

        Follows the input/output format of
        StableDiffusionPipelineExtended.forward_calculate_logprob.

        Args:
            prompt_embeds: Only its batch dimension is used: ``shape[0] // 2`` rows of
                (all-zero) class labels are built. NOTE(review): assumes prompt_embeds
                stacks unconditional and conditional halves -- confirm with the caller.
            latents: Current state, moved to the model device.
            next_latents: Next state, moved to the model device.
            ts: Step index into the EDM sigma schedule. Either a scalar (int or
                single-element tensor) shared by the whole batch, or one index per sample.
            unet_copy: Optional reference model (unused here)
            height: Image height (ignored)
            width: Image width (ignored)
            num_inference_steps: Length of the sigma schedule (must be >= 1)
            guidance_scale: Values > 1 blend ``self.gnet`` with ``self.unet``; otherwise
                only ``self.unet`` is evaluated (same rule as forward_collect_traj_ddim)
            negative_prompt: Ignored for EDM2
            num_images_per_prompt: Ignored
            eta: Unused parameter (for compatibility)
            generator: Unused
            negative_prompt_embeds: Ignored for EDM2
            output_type: Output format (ignored for this method)
            return_dict: Whether to return a dict (ignored)
            callback: Optional callback function (ignored)
            callback_steps: Frequency of callback (ignored)
            cross_attention_kwargs: Kwargs for cross attention (ignored)
            is_ddp: Whether using DistributedDataParallel (unused here)
            soft_reward: Must be False

        Returns:
            Tuple ``(log_prob, kl_regularizer)``, both of shape ``(latents.shape[0],)``;
            ``kl_regularizer`` is always zeros.

        Raises:
            ValueError: If ``num_inference_steps`` < 1, or if ``ts`` has more than one
                element but not one per sample.
        """
        assert soft_reward is False
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
        device = next(self.unet.parameters()).device

        # Trajectories may have been stored on CPU; the networks live on `device`.
        latents = latents.to(device)
        next_latents = next_latents.to(device)
        num_samples = latents.shape[0]

        # Divide by 2 as prompt_embeds is presumed to hold unconditional + conditional rows.
        batch_size = prompt_embeds.shape[0] // 2

        # FIXME: EDM2 has no text embeddings; these all-zero labels are placeholders,
        # so the conditional branch does not see the real class labels.
        labels = torch.zeros((batch_size, self._get_unet_attr("label_dim")), device=device)

        # EDM noise-level schedule (same discretization as the trajectory collector).
        sigma_min = 0.002
        sigma_max = 80
        rho = 7
        step_indices = torch.arange(num_inference_steps, device=device)
        denom = max(num_inference_steps - 1, 1)  # avoid 0/0 (NaN) for a single step
        t_steps = (sigma_max ** (1 / rho) + step_indices / denom * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

        # Per-sample noise levels: a scalar index is broadcast over the batch.
        step_ids = torch.as_tensor(ts, device=device).reshape(-1)
        sigmas = t_steps[step_ids]
        if sigmas.numel() == 1:
            sigmas = sigmas.repeat(num_samples)
        elif sigmas.numel() != num_samples:
            raise ValueError(
                f"ts must be a scalar or have one entry per sample ({num_samples}), "
                f"got {sigmas.numel()} entries"
            )
        unet_t_tensor = sigmas

        # Use the same denoiser as forward_collect_traj_ddim so the transition is scored
        # under the policy that generated it.
        if guidance_scale > 1.0:
            noise_pred_uncond = self.gnet(latents, unet_t_tensor, None)
            noise_pred_text = self.unet(latents, unet_t_tensor, labels)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        else:
            noise_pred = self.unet(latents, unet_t_tensor, labels)

        kl_regularizer = torch.zeros(num_samples, device=device)

        # FIXME: placeholder Gaussian-style score, not a true transition log-prob.
        # NOTE(review): pred_next = latents - sigma * (latents - noise_pred) / sigma,
        # which reduces to noise_pred (a step to sigma = 0, not to the next t).
        sigma_b = sigmas.reshape(-1, *([1] * (latents.ndim - 1)))
        d_cur = (latents - noise_pred) / sigma_b
        pred_next = latents - sigma_b * d_cur

        # Squared error over all non-batch dims, scaled by the noise level.
        reduce_dims = tuple(range(1, latents.ndim))
        mse = torch.sum((next_latents - pred_next) ** 2, dim=reduce_dims)
        log_prob = -0.5 * mse / (sigmas ** 2)

        return log_prob, kl_regularizer

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Return the initial noise tensor for sampling.

        The shape is always ``(batch_size, img_channels, img_resolution, img_resolution)``
        taken from the wrapped network; ``num_channels_latents``, ``height`` and
        ``width`` are accepted for pipeline compatibility but ignored.

        Args:
            batch_size: Number of samples.
            num_channels_latents: Ignored.
            height: Ignored.
            width: Ignored.
            dtype: dtype of freshly sampled noise.
            device: Target device.
            generator: None, a single ``torch.Generator``, or a list with one generator
                per sample (each sample is then drawn from its own generator).
            latents: Optional pre-generated noise; if given it is only moved to ``device``.

        Returns:
            The latents tensor (not scaled; EDM scaling by sigma_max happens in the sampler).

        Raises:
            ValueError: If a list of generators does not match ``batch_size``.
        """
        channels = self._get_unet_attr("img_channels")
        resolution = self._get_unet_attr("img_resolution")
        shape = (batch_size, channels, resolution, resolution)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            if isinstance(generator, list):
                # torch.randn accepts a single Generator only, so draw one sample per
                # generator and stack them along the batch dimension.
                per_sample = [
                    torch.randn((1, *shape[1:]), generator=g, device=device, dtype=dtype)
                    for g in generator
                ]
                latents = (
                    torch.cat(per_sample, dim=0)
                    if per_sample
                    else torch.empty(shape, device=device, dtype=dtype)
                )
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        return latents
