
import json
import os
import torch
from PIL import Image
import sys

from distrl.models.diffusion_models.base import DiffusionModelBase

# Environment switches, read once at import time. Each holds the raw string
# value of the variable (or the default shown; ``os.getenv`` defaults to None).
DISTRL_SIT_VAE = os.getenv('DISTRL_SIT_VAE', 'ema')
DISTRL_PRINT_UNETPARAMETER = os.getenv('DISTRL_PRINT_UNETPARAMETER')
DISTRL_DEBUG_ODE = os.getenv('DISTRL_DEBUG_ODE')
DISTRL_DEBUG_TEMP = os.getenv('DISTRL_DEBUG_TEMP')
DISTRL_RL_NOTRAIN_LASTSTEP = os.getenv('DISTRL_RL_NOTRAIN_LASTSTEP')
DISTRL_DEBUG_ALIGN_SFT_RL = os.getenv('DISTRL_DEBUG_ALIGN_SFT_RL')


# Prepend the vendored SiT sources in ./original to sys.path (presumably so
# the vendored code's own absolute imports resolve), once, and only if the
# directory actually exists.
script_dir = os.path.dirname(os.path.abspath(__file__))
original_dir = os.path.join(script_dir, 'original')
if os.path.isdir(original_dir) and original_dir not in sys.path:
    sys.path.insert(0, original_dir)
    print(f"Added {original_dir} to Python path")

# Import SIT modules from original code if available
try:
    from .original.models import SiT_models
    from .original.transport import create_transport, Sampler
    from diffusers.models import AutoencoderKL
    from .original.transport.transport import RLSampler
except ImportError as e:
    print(f"Failed to import SIT modules: {e}")
    print("Make sure the original SIT implementation is in the 'original' directory.")
    raise

def print_model_params(model, model_name, num_show=5) -> str:
    """
    Build a human-readable summary of a model's parameter tensors.

    Despite its name this function does not print; it returns the summary and
    leaves printing to the caller.

    Args:
        model: Object with a ``parameters()`` method (e.g. ``torch.nn.Module``).
        model_name: Label used in the summary header.
        num_show: How many leading and trailing parameter tensors to list
            individually (default 5).

    Returns:
        A multi-line string with shape/mean/std of the first and last
        ``num_show`` tensors plus mean/std/min/max over all parameter
        elements, or ``""`` if the model has no parameters.
    """
    params = list(model.parameters())
    total_params = len(params)
    if total_params == 0:
        return ""

    def _format(indexed_params):
        # One line per tensor: "  idx: shape | mean | std".
        return "\n".join(
            f"  {i}: {p.shape} | {p.mean().item():.6f} | {p.std().item():.6f}"
            for i, p in indexed_params
        )

    # Statistics only: no autograd graph for the (possibly huge) concatenation.
    with torch.no_grad():
        head = params[:num_show]
        # Clamp so models with fewer than ``num_show`` tensors get valid,
        # non-negative indices in the tail listing.
        tail_start = max(total_params - num_show, 0)
        tail = params[tail_start:]

        first_params_str = _format(enumerate(head))
        last_params_str = _format(enumerate(tail, start=tail_start))

        all_params = torch.cat([p.detach().flatten() for p in params])
        num_elements = all_params.numel()
        mean_all = all_params.mean().item()
        std_all = all_params.std().item()
        min_all = all_params.min().item()
        max_all = all_params.max().item()

    output_str = f"{model_name} Parameter Summary:\n"
    output_str += f"Total parameters: {total_params} tensors ({num_elements} elements)\n"
    output_str += f"First {len(head)} parameters (idx | shape | mean | std):\n{first_params_str}\n"
    output_str += f"Last {len(tail)} parameters (idx | shape | mean | std):\n{last_params_str}\n"
    output_str += f"All parameters - Mean: {mean_all:.6f}, Std: {std_all:.6f}, Min: {min_all:.6f}, Max: {max_all:.6f}"
    return output_str

def load_sit_model(pretrained_path, weight_dtype=torch.float32):
    """
    Load a SiT model and the matching Stable Diffusion VAE from a checkpoint.

    The checkpoint may be a training checkpoint (a dict with a ``"model"``
    entry and optionally an ``"ema"`` entry; EMA weights are preferred when
    present) or a bare state dict.

    The architecture is inferred from ``pretrained_path`` (not from the
    weights). If the path contains a ``SiT_models`` key such as ``"SiT-B/2"``,
    or its file-name-safe spelling ``"SiT-B-2"``, that variant is built.
    Otherwise ``"SiT-XL/2"`` is used.

    The input resolution is read from ``DISTRL_SIT_IMAGE_SIZE`` (default 256).
    The sampling code reads the same variable to size its noise latents.

    Args:
        pretrained_path: Path to the checkpoint file.
        weight_dtype: dtype both returned modules are cast to.

    Returns:
        Tuple of (sit_model, vae). This function does not move them to any
        device.

    Raises:
        FileNotFoundError: If ``pretrained_path`` does not exist.
    """
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Model checkpoint not found at {pretrained_path}")

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(pretrained_path, map_location="cpu")

    if isinstance(checkpoint, dict) and "model" in checkpoint:
        # Format written by the training script.
        state_dict = checkpoint["model"]
        ema_state_dict = checkpoint.get("ema")
        print("Loaded SIT model from training checkpoint")
    else:
        # Direct state dict.
        state_dict = checkpoint
        ema_state_dict = None
        print("Loaded SIT model from direct state dict")

    # Infer the model variant from the checkpoint path. Keys look like
    # "SiT-XL/2"; because "/" cannot appear in a file name, also accept the
    # "-" spelling (e.g. "SiT-XL-2-256x256.pt").
    model_size = "SiT-XL/2"  # Default
    for key in SiT_models.keys():
        if key in pretrained_path or key.replace("/", "-") in pretrained_path:
            model_size = key
            break

    # Keep the model's input size consistent with the latents produced at
    # sampling time, which are sized from the same environment variable.
    image_size = int(os.environ.get('DISTRL_SIT_IMAGE_SIZE', '256'))
    num_classes = 1000
    latent_size = image_size // 8  # the SD VAE downsamples by 8
    vae_type = DISTRL_SIT_VAE

    model = SiT_models[model_size](
        input_size=latent_size,
        num_classes=num_classes,
        learn_sigma=True,
    )

    model.load_state_dict(ema_state_dict if ema_state_dict is not None else state_dict)

    # VAE variant selected by DISTRL_SIT_VAE ("ema" by default).
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{vae_type}")

    model = model.to(dtype=weight_dtype)
    vae = vae.to(dtype=weight_dtype)

    return model, vae


class SITDiffusionModel(DiffusionModelBase):
    """
    Wrapper exposing a class-conditional SiT (Scalable Interpolant Transformer)
    model through the same interface as the other diffusion models in the
    AutoCFG framework.

    Text prompts are mapped to ImageNet class indices (see
    ``_extract_class_from_prompt``). Sampling is done by the SiT
    transport/sampler code rather than by a diffusers scheduler.
    """

    # Size of the final "Mean" step of the SDE sampler. forward_calculate_logprob
    # rebuilds a time grid from this value, so both must use the same constant.
    _LAST_STEP_SIZE = 0.04
    # Latent scaling factor of the SD VAE (same value Stable Diffusion uses).
    _VAE_SCALE_FACTOR = 0.18215
    # Number of VAE latent channels.
    _LATENT_CHANNELS = 4
    # Number of ImageNet classes used for one-hot class "embeddings".
    _NUM_CLASSES = 1000

    class _Output:
        """Minimal stand-in for a diffusers pipeline output exposing ``.images``."""

        def __init__(self, images):
            self.images = images

    def __init__(self, sit_model, vae_model):
        """
        Initialize the wrapper.

        Args:
            sit_model: SiT network; exposed as ``self.unet`` for interface compatibility.
            vae_model: VAE used to encode/decode images to/from latents.
        """
        self.unet = sit_model
        self.vae = vae_model
        self.pipe = self  # Self-reference as pipe for compatibility

        # Optional list whose position i holds the name of class i; used to
        # resolve prompts such as "an image of <class name>".
        module_dir = os.path.dirname(os.path.abspath(__file__))
        simple_json_path = os.path.join(module_dir, '..', '..', '..', '..', 'dataset', 'imagenet1k', 'simple.json')
        try:
            with open(simple_json_path, 'r', encoding='utf-8') as f:
                self.imagenet_classes = json.load(f)
        except FileNotFoundError:
            print(f"Could not find ImageNet class mapping at {simple_json_path}")
            self.imagenet_classes = []

        self.transport = create_transport("Linear", "velocity")

        # DISTRL_RL (read per instance) selects RLSampler. Its sample results
        # carry ``log_probs``, which forward_collect_traj_ddim relies on.
        self.rl = os.environ.get('DISTRL_RL', None)
        self.sampler = RLSampler(self.transport) if self.rl else Sampler(self.transport)

        # Built lazily. The step count is remembered so that a call with a
        # different num_inference_steps rebuilds the sampler.
        self.sample_fn = None
        self._sample_fn_num_steps = None

    @property
    def device(self):
        """Device of the SiT network's first parameter."""
        return next(self.unet.parameters()).device

    @property
    def dtype(self):
        """dtype of the SiT network's first parameter."""
        return next(self.unet.parameters()).dtype

    def configure_sampling_function(self, num_steps):
        """
        Build (and cache) the SDE sampling function with hard-coded settings.

        Args:
            num_steps: Number of sampler steps.

        Returns:
            The configured sampling function (also stored in ``self.sample_fn``).
        """
        # DISTRL_DEBUG_ODE switches the sampling method string passed to sample_sde.
        sampling_method = "ODE" if DISTRL_DEBUG_ODE else "Euler"

        print(f"Using SDE sampler with {sampling_method}, steps={num_steps}")
        self.sample_fn = self.sampler.sample_sde(
            sampling_method=sampling_method,
            diffusion_form="sigma",
            diffusion_norm=1.0,
            last_step="Mean",
            last_step_size=self._LAST_STEP_SIZE,
            num_steps=num_steps,
        )
        self._sample_fn_num_steps = num_steps
        return self.sample_fn

    def _ensure_sample_fn(self, num_steps):
        """Return a sampling function configured for ``num_steps``, rebuilding it if needed."""
        configured_steps = getattr(self, "_sample_fn_num_steps", None)
        # Rebuild if there is none yet, or if it was built for a different
        # step count. A sample_fn assigned externally (no recorded step count)
        # is left untouched.
        if self.sample_fn is None or (configured_steps is not None and configured_steps != num_steps):
            self.configure_sampling_function(num_steps)
        if self.sample_fn is None:
            raise RuntimeError("Failed to configure sampling function. Check SIT configuration.")
        return self.sample_fn

    def to(self, device, dtype=None):
        """
        Move model components to ``device`` and optionally cast them to ``dtype``.

        Returns:
            Self for chaining.
        """
        self.to_device(device, dtype)
        return self

    @classmethod
    def from_pretrained(cls, pretrained_model_path, revision=None,
                        weight_dtype=torch.float32, retry_on_error=False, **kwargs):
        """
        Create a model from a SiT checkpoint file.

        Args:
            pretrained_model_path: Path to the SiT checkpoint file.
            revision: Ignored for SiT.
            weight_dtype: dtype for model weights.
            retry_on_error: Ignored for SiT (a message is printed if set).
            **kwargs: Ignored.

        Returns:
            SITDiffusionModel instance.
        """
        if retry_on_error:
            print("retry_on_error is not supported for SIT model")

        print(f"Loading SIT model from {pretrained_model_path}")
        try:
            sit_model, vae = load_sit_model(pretrained_model_path, weight_dtype)
        except Exception as e:
            print(f"Error loading SIT model: {e}")
            raise

        return cls(sit_model=sit_model, vae_model=vae)

    @classmethod
    def from_sft(cls, sft_path, revision=None,
                 weight_dtype=torch.float32, retry_on_error=False, **kwargs):
        """
        Create a model from SFT (supervised fine-tuned) weights.

        Loads ``<sft_path>/policy_model.pt`` through ``from_pretrained``.

        Args:
            sft_path: Directory containing ``policy_model.pt``.
            revision: Ignored for SiT.
            weight_dtype: dtype for model weights.
            retry_on_error: Forwarded to ``from_pretrained`` (ignored there).
            **kwargs: Forwarded to ``from_pretrained``.

        Returns:
            SITDiffusionModel instance.
        """
        checkpoint_path = os.path.join(sft_path, "policy_model.pt")
        return cls.from_pretrained(checkpoint_path, revision, weight_dtype, retry_on_error, **kwargs)

    def to_device(self, device, dtype=None):
        """
        Move the SiT network and VAE to ``device`` and optionally cast to ``dtype``.

        Args:
            device: Target device ('cuda', 'cpu', ...).
            dtype: Optional target dtype.

        Returns:
            Self for chaining.
        """
        cast_kwargs = {} if dtype is None else {"dtype": dtype}
        self.unet.to(device, **cast_kwargs)
        self.vae.to(device, **cast_kwargs)
        return self

    def set_gradient(self, enable=True):
        """
        Enable or disable gradients for the SiT network.

        This also puts the network in train mode when ``enable`` is True and
        in eval mode otherwise. The original note says SiT needs train mode
        for class(-label) dropout. NOTE(review): that dropout is then also
        active in forward_calculate_logprob; confirm it is intended.

        Args:
            enable: Whether to enable gradients.

        Returns:
            Self for chaining.
        """
        for param in self.unet.parameters():
            param.requires_grad_(enable)
        self.unet.train(enable)
        return self

    def enable_xformers(self):
        """
        Report that xformers attention is not available for SiT.

        Returns:
            False, always.
        """
        print("xformers optimization not supported for SIT model")
        return False

    def get_components(self):
        """
        Return the model components using the keys other diffusion models use.

        Returns:
            Dict with "pipe", "unet", "vae", "text_encoder" (None) and "tokenizer" (None).
        """
        return {
            "pipe": self,  # Self as pipe for compatibility
            "unet": self.unet,  # Main network as unet
            "vae": self.vae,  # VAE encoder/decoder
            "text_encoder": None,  # No text encoder in SIT
            "tokenizer": None  # No tokenizer in SIT
        }

    def set_scheduler(self, scheduler_type):
        """
        Accept and ignore a scheduler; SiT uses its own transport-based sampler.

        Args:
            scheduler_type: Scheduler class or instance (ignored).

        Returns:
            Self for chaining.
        """
        print("SIT uses its own transport-based sampling algorithm and doesn't require a diffusers scheduler")
        return self

    def _extract_class_from_prompt(self, prompt):
        """
        Map a prompt such as "an image of X" to an ImageNet class index.

        X may be a decimal class index in [0, 1000) or a class name present
        in ``self.imagenet_classes``.

        Args:
            prompt: Prompt string or list of prompt strings.

        Returns:
            An int class index, or a tensor of indices for a list input.

        Raises:
            ValueError: If the class cannot be resolved.
        """
        if isinstance(prompt, list):
            return torch.tensor([self._extract_class_from_prompt(p) for p in prompt])

        prefix = "an image of "
        class_name = prompt[len(prefix):] if prompt.startswith(prefix) else prompt

        # isdecimal() guarantees int() succeeds.
        if class_name.isdecimal():
            class_idx = int(class_name)
            if 0 <= class_idx < self._NUM_CLASSES:
                return class_idx

        try:
            # An empty list (mapping file missing) also raises ValueError here.
            return self.imagenet_classes.index(class_name)
        except ValueError:
            raise ValueError(f"Class '{class_name}' not found in ImageNet classes") from None

    def decode_latents(self, latents):
        """
        Decode latents to images with the VAE (no gradients).

        Args:
            latents: Latents from the sampling process, BCHW.

        Returns:
            Decoded images, BCHW. They are in the VAE's output range, which the
            "pt" conversion in this class treats as [-1, 1]. They are not
            rescaled to [0, 1].
        """
        vae_dtype = next(self.vae.parameters()).dtype
        with torch.no_grad():
            images = self.vae.decode(latents.to(dtype=vae_dtype) / self._VAE_SCALE_FACTOR).sample
        return images

    def numpy_to_pil(self, images):
        """
        Convert tensor images in [0, 1] to PIL images.

        Args:
            images: A CHW/HWC image, or a batch of them (BCHW). Values are
                clamped to [0, 1].

        Returns:
            List of PIL images.
        """
        def _to_pil(img):
            img = (img.detach().float().clamp(0, 1) * 255).round().to(torch.uint8).cpu()
            if img.ndim == 3 and img.shape[0] in (1, 3):  # channels-first -> HWC for PIL
                img = img.permute(1, 2, 0)
            if img.ndim == 3 and img.shape[-1] == 1:  # single channel -> 2-D grayscale
                img = img[..., 0]
            return Image.fromarray(img.numpy())

        if images.ndim == 4:
            return [_to_pil(image) for image in images]
        return [_to_pil(images)]

    def _get_unet_attr(self, attr_name, default=None):
        """
        Read an attribute from the network, looking through a ``.module`` wrapper.

        Raises:
            AttributeError: If missing and ``default`` is None.
        """
        if hasattr(self.unet, attr_name):
            return getattr(self.unet, attr_name)
        if hasattr(self.unet, "module") and hasattr(self.unet.module, attr_name):
            return getattr(self.unet.module, attr_name)
        if default is not None:
            return default
        raise AttributeError(f"UNet does not have attribute {attr_name}")

    @staticmethod
    def _latent_size():
        """Latent side length: DISTRL_SIT_IMAGE_SIZE (default 256) divided by the VAE factor 8."""
        return int(os.environ.get('DISTRL_SIT_IMAGE_SIZE', '256')) // 8

    @staticmethod
    def _as_prompt_list(prompt):
        """Wrap a single prompt string in a list; reject a missing prompt."""
        if isinstance(prompt, str):
            return [prompt]
        if prompt is None:
            raise ValueError("No prompt provided")
        return prompt

    @staticmethod
    def _is_rank_zero():
        """
        Best-effort check for being the rank-0 process.

        Uses torch.distributed when it is initialized. Otherwise it falls back
        to the RANK / LOCAL_RANK / SLURM_PROCID environment variables.
        """
        try:
            import torch.distributed as torch_dist
            if torch_dist.is_available() and torch_dist.is_initialized():
                return torch_dist.get_rank() == 0
        except Exception:
            # Deliberate best effort: fall through to the environment check.
            pass
        return all(os.environ.get(env_var, '0') == '0' for env_var in ('RANK', 'LOCAL_RANK', 'SLURM_PROCID'))

    def _postprocess(self, final_sample, output_type):
        """
        Turn a final latent sample into the requested output format.

        - "latent": the latent itself.
        - "pil": a list of PIL images.
        - "pt": a uint8 tensor in [0, 255], BCHW.
        - anything else: the decoded float images.
        """
        if output_type == "latent":
            return final_sample
        images = self.decode_latents(final_sample)  # BCHW, [-1, 1]
        if output_type == "pil":
            return self.numpy_to_pil((images + 1) / 2)
        if output_type == "pt":
            # Same mapping as SiT's sample_ddp.py: 127.5 * x + 128.0, clamped.
            return torch.clamp(127.5 * images + 128.0, 0, 255).to(dtype=torch.uint8)
        return images

    def __call__(
        self,
        prompt = None,
        height = None,
        width = None,
        num_inference_steps = 50,
        guidance_scale = 1,
        negative_prompt = None,
        num_images_per_prompt = 1,
        generator = None,
        latents = None,
        output_type = None,
        return_dict = None
    ):
        """
        Generate images without tracking intermediate states.

        Args:
            prompt: Prompt or list of prompts (see _extract_class_from_prompt).
            height: Ignored; the resolution comes from DISTRL_SIT_IMAGE_SIZE.
            width: Ignored; the resolution comes from DISTRL_SIT_IMAGE_SIZE.
            num_inference_steps: Number of sampler steps.
            guidance_scale: Must be <= 1.0; CFG is not implemented.
            negative_prompt: Ignored.
            num_images_per_prompt: Images generated per prompt.
            generator: torch.Generator or list of them (one per image).
            latents: Optional pre-generated noise latents.
            output_type: "latent", "pil", "pt", or None (decoded float images).
            return_dict: Ignored.

        Returns:
            Object with an ``images`` attribute.
        """
        device = self.device

        # Optional one-time parameter dump, only on rank 0.
        if DISTRL_PRINT_UNETPARAMETER and self._is_rank_zero() and not getattr(self.unet, 'printed', False):
            print(print_model_params(self.unet, "SIT Model"))
            self.unet.printed = True

        if guidance_scale > 1.0:
            raise NotImplementedError("CFG not supported for SIT")

        prompt = self._as_prompt_list(prompt)
        y = torch.tensor([self._extract_class_from_prompt(p) for p in prompt], device=device, dtype=torch.long)
        if num_images_per_prompt > 1:
            y = y.repeat(num_images_per_prompt)
        batch_size = len(y)

        z = self.prepare_latents(batch_size, self._LATENT_CHANNELS, height, width,
                                 self.dtype, device, generator, latents)
        model_kwargs = dict(y=y)

        sample_fn = self._ensure_sample_fn(num_inference_steps)

        # Casts any parameters whose dtype differs from the first parameter's.
        self.unet.to(dtype=self.dtype)

        # TODO: remove this debug hack (hard-coded noise file and class 164).
        if DISTRL_DEBUG_TEMP:
            z = torch.load("/apdcephfs/rhli_private_sz/workspace/distrl/SiT/a.pth")
            model_kwargs = dict(y=torch.tensor([164]).cuda())

        samples = sample_fn(z, self.unet.forward, **model_kwargs)
        images = self._postprocess(samples[-1], output_type)
        return self._Output(images)

    def generate_noise_latents(self, batch_size):
        """
        Generate standard-normal noise latents in SiT's latent shape.

        Args:
            batch_size: Number of latents.

        Returns:
            Tensor of shape (batch_size, 4, latent_size, latent_size) on CPU,
            in the default dtype.
        """
        latent_size = self._latent_size()
        return torch.randn((batch_size, self._LATENT_CHANNELS, latent_size, latent_size))

    def forward_collect_traj_ddim(
        self,
        prompt = None,
        height = None,
        width = None,
        num_inference_steps = 50,
        guidance_scale = 1,
        negative_prompt = None,
        num_images_per_prompt = 1,
        eta = 1.0,
        generator = None,
        latents = None,
        prompt_embeds = None,
        negative_prompt_embeds = None,
        output_type = None,
        return_dict = True,
        callback = None,
        callback_steps = 1,
        cross_attention_kwargs = None,
        is_ddp = False,
        unet_copy=None,
        soft_reward=False,
    ):
        """
        Sample images and collect the trajectory (for RL training).

        Most keyword arguments exist only for interface compatibility and are
        ignored.

        Returns:
            Tuple of:
            - image
            - latents_list: CPU latents, initial noise first
            - unconditional_embeds: all zeros, CPU
            - conditional_embeds: one-hot class vectors, one row per sample, CPU
            - log_prob_list: per-step log-probs from RLSampler; in non-RL mode
              it is the sampled latents instead (placeholder)
            - None: no KL path
        """
        if soft_reward:
            raise NotImplementedError("Soft reward not supported for SIT")
        device = self.device

        prompt = self._as_prompt_list(prompt)
        y = torch.tensor([self._extract_class_from_prompt(p) for p in prompt], device=device, dtype=torch.long)
        if num_images_per_prompt > 1:
            y = y.repeat(num_images_per_prompt)
        batch_size = len(y)

        # One-hot class "embeddings", built after the repeat so they have one
        # row per generated sample, in latent order. The unconditional
        # counterparts are all zeros. forward_calculate_logprob takes the
        # argmax of the second half of the embeddings it receives.
        conditional_embeds = torch.zeros((batch_size, self._NUM_CLASSES), device=device)
        conditional_embeds[torch.arange(batch_size, device=device), y] = 1.0
        unconditional_embeds = torch.zeros_like(conditional_embeds)

        latents = self.prepare_latents(batch_size, self._LATENT_CHANNELS, height, width,
                                       self.dtype, device, generator, latents)
        model_kwargs = dict(y=y)

        sample_fn = self._ensure_sample_fn(num_inference_steps)

        latents_list = [latents.detach().cpu()]
        samples = sample_fn(latents, self.unet.forward, **model_kwargs)
        latents_list.extend(sample.detach().cpu() for sample in samples)
        final_sample = samples[-1]

        if self.rl:
            # RLSampler results carry per-step log-probabilities.
            log_prob_list = [log_prob.detach().cpu() for log_prob in samples.log_probs]
        else:
            log_prob_list = latents_list[1:]  # HACK: placeholder when not in RL mode

        if DISTRL_RL_NOTRAIN_LASTSTEP:
            latents_list = latents_list[:-1]
            log_prob_list = log_prob_list[:-1]

        image = self._postprocess(final_sample, output_type)

        # Return format expected by pipeline_stable_diffusion_extended.py.
        return (
            image,
            latents_list,
            unconditional_embeds.detach().cpu(),
            conditional_embeds.detach().cpu(),
            log_prob_list,
            None,  # No KL path
        )

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """
        Prepare initial noise latents.

        ``num_channels_latents``, ``height`` and ``width`` are ignored. SiT
        always uses 4 channels and a side length derived from
        DISTRL_SIT_IMAGE_SIZE.

        Args:
            batch_size: Number of latents.
            num_channels_latents: Ignored.
            height: Ignored.
            width: Ignored.
            dtype: dtype of newly generated latents.
            device: Target device.
            generator: torch.Generator, or a list with one generator per latent.
            latents: Optional pre-generated latents. They are only moved to
                ``device``, not cast.

        Returns:
            Latents tensor of shape (batch_size, 4, latent_size, latent_size).

        Raises:
            ValueError: If a generator list's length differs from ``batch_size``.
        """
        if latents is not None:
            return latents.to(device)

        latent_size = self._latent_size()
        channels = self._LATENT_CHANNELS

        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            return torch.cat([
                torch.randn((1, channels, latent_size, latent_size), generator=g, device=device, dtype=dtype)
                for g in generator
            ])

        return torch.randn(
            (batch_size, channels, latent_size, latent_size),
            generator=generator,
            device=device,
            dtype=dtype,
        )

    def encode_images(self, images):
        """
        Encode images to scaled latents by sampling the VAE posterior (no gradients).
        """
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample().mul_(self._VAE_SCALE_FACTOR)
        return latents

    def forward_calculate_logprob(
        self,
        prompt_embeds,
        latents,
        next_latents,
        ts,
        unet_copy=None,
        height=None,
        width=None,
        num_inference_steps=50,
        guidance_scale=1.0,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=1.0,
        generator=None,
        negative_prompt_embeds=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        cross_attention_kwargs=None,
        is_ddp=False,
        soft_reward=False,
    ):
        """
        Compute the log-probability of one SDE transition under the current model.

        Args:
            prompt_embeds: Class embeddings [batch * 2, num_classes]. The first
                half is zeros (unconditional); the second half is one-hot class
                vectors.
            latents: Current state [batch, channels, height, width].
            next_latents: Next state [batch, channels, height, width].
            ts: Integer step indices [batch] into the time grid.
            unet_copy: Optional reference model for the KL regularizer.
            num_inference_steps: Number of steps in the time grid. It should
                match the value used during sampling.
            Other args: Kept for interface compatibility; ignored.

        Returns:
            Tuple (log_prob, kl_regularizer). kl_regularizer is zeros when
            ``unet_copy`` is None.
        """
        device = self.device

        # Only the conditional (second) half carries class information.
        y = torch.argmax(prompt_embeds[len(prompt_embeds) // 2:], dim=1).to(device)
        batch_size = latents.shape[0]
        model_kwargs = dict(y=y)

        # Map step indices to continuous times. This presumably mirrors the
        # sampler's grid for the same num_steps/last_step_size
        # (TODO confirm against Sampler.sample_sde).
        t_steps = torch.linspace(0, 1 - self._LAST_STEP_SIZE, num_inference_steps, device=device)
        ts = t_steps[ts]

        with torch.autocast("cuda", dtype=self.dtype):
            _, _, _, step_dist = self.sampler.sde.Euler_Maruyama_step(
                latents, ts, self.unet.forward, **model_kwargs
            )
            log_prob = step_dist.log_prob(next_latents)

        if unet_copy is not None:
            # Only the reference model's pass is gradient-free. The divergence
            # itself must keep the policy's graph, otherwise the regularizer
            # contributes no gradient.
            with torch.no_grad():
                with torch.autocast("cuda", dtype=self.dtype):
                    _, _, _, dist_ref = self.sampler.sde.Euler_Maruyama_step(
                        latents, ts, unet_copy.forward, **model_kwargs
                    )
            kl_regularizer = step_dist.kl_divergence(dist_ref)
        else:
            kl_regularizer = torch.zeros(batch_size, device=device)

        if DISTRL_DEBUG_ALIGN_SFT_RL:
            kl_regularizer.dist_ = step_dist

        return log_prob, kl_regularizer