### Preamble ##########################################################################################################

"""
A GController module that implements the standard classifier-free guidance scheme from 
(https://arxiv.org/pdf/2207.12598). Note that this guidance controller is compatible with all diffusion pipelines.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import torch
from typing import Union, Iterable, Optional, Tuple
from diffusers import ModelMixin, ConfigMixin
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import register_to_config
from ..controller_utils import GController

#######################################################################################################################


class ClassifierFreeGuidance(GController):

    @register_to_config
    def __init__(self, **kwargs):
        """
        Returns a `gcontrol` classifier-free guidance module for use in the stable1 pipeline.
        """

        super().__init__(**kwargs)

        self._requires_uncond_noise = True

    def forward(
        self,
        _pipeline: DiffusionPipeline,
        _unconditional_noise: torch.Tensor,
        _conditional_noise: torch.Tensor,
        guidance_scale: float = 5,
        **kwargs,
    ) -> torch.Tensor:
        """
        :param guidance_scale: float
            The guidance scale used in the classifier-free guidance scheme. Note that this implementation treats the
            guidance scale as `w` of equation 2. of the [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).

        Note that common GuidanceController forward methods must accept kwargs to ensure forward compatibility with
        other pipeline specific guidance controllers.
        """

        return _unconditional_noise + guidance_scale * (_conditional_noise - _unconditional_noise)

    @staticmethod
    def do_gcontrol(guidance_scale: float, **kwargs):
        return guidance_scale != 1


#######################################################################################################################
