import os
# Debug switch read once at import time. It is None when the environment
# variable is unset; any non-empty string (including "0") is truthy and makes
# create_diffusion_model() force weight_dtype to torch.float32.
DISTRL_DEBUG_MODEL_FP32_JUST_AUTOCAST = os.getenv("DISTRL_DEBUG_MODEL_FP32_JUST_AUTOCAST")

import torch

from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from distrl.models.diffusion_models.edm2 import EDM2DiffusionModel
from distrl.models.diffusion_models.sit import SITDiffusionModel



def create_diffusion_model(model_type, pretrained_path=None, sft_path=None, revision=None,
                         weight_dtype=torch.float32, device=None, retry_on_error=False, **kwargs):
    """
    Create and initialize a diffusion model based on model type.

    Args:
        model_type: Type of diffusion model to create, matched
            case-insensitively. Supported values are 'edm2' and 'sit'.
            'sdxl' is recognized but not implemented yet.
        pretrained_path: Path to pretrained model weights. Only used when
            ``sft_path`` is falsy.
        sft_path: Path to SFT (supervised fine-tuned) weights. When truthy it
            takes precedence over ``pretrained_path``.
        revision: Specific model revision to use (forwarded to the loader).
        weight_dtype: Data type for model weights. Forced to ``torch.float32``
            when the module-level ``DISTRL_DEBUG_MODEL_FP32_JUST_AUTOCAST``
            flag is set.
        device: Device to move the model to ('cuda', 'cpu', etc.). ``None`` or
            an empty string leaves the model wherever the loader put it.
        retry_on_error: Whether to retry loading on error (forwarded to the
            loader; useful for distributed training).
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        The model returned by the selected class's ``from_sft`` /
        ``from_pretrained`` loader. No scheduler is attached by this function.

    Raises:
        NotImplementedError: If ``model_type`` is 'sdxl'.
        ValueError: If ``model_type`` is not a string or names an unknown
            model type.
    """
    # Validate the requested type first so bad input fails fast, before any
    # weights are touched. Non-string input gets the same error as an unknown
    # name instead of an opaque AttributeError from ``.lower()``.
    if not isinstance(model_type, str):
        raise ValueError(f"Unknown model type: {model_type}")

    kind = model_type.lower()
    if kind == 'sdxl':
        # Example placeholder for future model types
        raise NotImplementedError("SDXL model type not yet supported")
    if kind not in ('edm2', 'sit'):
        raise ValueError(f"Unknown model type: {model_type}")

    # Both supported classes are driven through the same from_sft /
    # from_pretrained calls, so only the class differs between model types.
    # No scheduler is configured here for either of them.
    model_cls = EDM2DiffusionModel if kind == 'edm2' else SITDiffusionModel

    # Debug override: keep the weights in fp32 regardless of the request.
    if DISTRL_DEBUG_MODEL_FP32_JUST_AUTOCAST:
        weight_dtype = torch.float32

    if sft_path:
        # Initialize from SFT weights
        model = model_cls.from_sft(
            sft_path=sft_path,
            revision=revision,
            weight_dtype=weight_dtype,
            retry_on_error=retry_on_error,
        )
    else:
        # Initialize from pretrained weights
        model = model_cls.from_pretrained(
            pretrained_model_path=pretrained_path,
            revision=revision,
            weight_dtype=weight_dtype,
            retry_on_error=retry_on_error,
        )

    # Move model to device if specified. Compare against None rather than
    # relying on truthiness so that a device ordinal of 0 is not silently
    # dropped; an empty string is still treated as "no device requested".
    # NOTE(review): assumes to_device accepts whatever device spec the caller
    # passes (string, torch.device, or ordinal) -- confirm against the model class.
    if device is not None and device != "":
        model.to_device(device, weight_dtype)

    return model
