from typing import Dict, Any, Optional
import torch

class DiffusionModelBase:
    """Base class for diffusion models used in the AutoCFG framework.

    Defines the minimal interface required for diffusion models to work with the AutoCFG pipeline.
    Each model implementation should extend this class and implement its methods.

    This is a plain class (not an ``abc.ABC``): it can be instantiated and
    subclassed without overriding everything. Any method a subclass does not
    override raises ``NotImplementedError`` when called. The error message
    names the concrete class and the missing method.
    """

    @classmethod
    def from_pretrained(cls, pretrained_path: str, **kwargs) -> "DiffusionModelBase":
        """Create model from pretrained weights.

        Args:
            pretrained_path: Path or identifier for pretrained model
            **kwargs: Additional arguments for model loading

        Returns:
            Initialized diffusion model instance

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{cls.__name__} must implement from_pretrained()")

    @classmethod
    def from_sft(cls, sft_path: str, **kwargs) -> "DiffusionModelBase":
        """Create model from SFT (supervised fine-tuned) weights.

        Args:
            sft_path: Path to SFT model weights
            **kwargs: Additional arguments for model loading

        Returns:
            Initialized diffusion model instance

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{cls.__name__} must implement from_sft()")

    def to_device(self, device: str, dtype: Optional[torch.dtype] = None) -> "DiffusionModelBase":
        """Move model to specified device and cast to dtype.

        Args:
            device: Device to move model to ('cuda', 'cpu', etc)
            dtype: Data type to cast model parameters to; ``None`` presumably
                means "keep the current dtype" (implementation-defined).

        Returns:
            Self for chaining

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement to_device()")

    def set_gradient(self, enable: bool = True) -> "DiffusionModelBase":
        """Enable or disable gradients for model components.

        Args:
            enable: Whether to enable gradients

        Returns:
            Self for chaining

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement set_gradient()")

    def enable_xformers(self) -> bool:
        """Enable xformers memory efficient attention if available.

        Returns:
            True if xformers was successfully enabled, False otherwise

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement enable_xformers()")

    def get_components(self) -> Dict[str, Any]:
        """Return individual model components.

        Returns:
            Dictionary containing model components (unet, vae, text_encoder, etc)

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement get_components()")

    def set_scheduler(self, scheduler_type: Any) -> "DiffusionModelBase":
        """Set the scheduler for the model.

        Args:
            scheduler_type: Scheduler class or instance

        Returns:
            Self for chaining

        Raises:
            NotImplementedError: If the subclass does not override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement set_scheduler()")
