### Preamble ##########################################################################################################

"""
A GController module that implements mixture classifier-free guidance for the StableDiffusion model. A separate 
controller is used for this to avoid tracking the latent gradient. 
"""

#######################################################################################################################

### Imports ###########################################################################################################

import torch
from torchvision.transforms.v2 import TenCrop, Grayscale
from torchvision.transforms.v2.functional import rotate
from typing import Union, Iterable, Optional, Tuple, Callable

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL
from diffusers.configuration_utils import register_to_config

from transformers import PreTrainedModel

from gcontrol.guidance_controllers.controller_utils import GController

#######################################################################################################################


class MixtureGuidance(GController):
    """
    Guidance controller that mixes classifier-free guidance with a second, "adversarial" conditional direction.

    The controller only combines noise predictions that were already computed; it performs no computation that
    needs gradients with respect to the latents.
    """

    @register_to_config
    def __init__(self, **kwargs):
        """
        Builds a `gcontrol` mixture guidance module for use in the stable1 pipeline.

        :param kwargs:
            Forwarded unchanged to the `GController` constructor.
        """

        super().__init__(**kwargs)

        # `forward` needs the unconditional noise prediction to build the guidance directions.
        self._requires_uncond_noise = True
        # `forward` is plain tensor arithmetic on noise predictions, so no latent gradient is needed.
        self._requires_latent_grad = False

    def forward(
        self,
        _pipeline: DiffusionPipeline,
        _unconditional_noise: torch.Tensor,
        _conditional_noise: torch.Tensor,
        _gconditional_noise: Optional[torch.Tensor],
        g_w: float = 5,
        g_p: float = 2,
        g_m: Optional[float] = 0.5,
        **kwargs,
    ) -> torch.Tensor:
        """
        Combines the noise predictions into a single guided noise prediction.

        With `class_vec = _conditional_noise - _unconditional_noise` and
        `adv_vec = _gconditional_noise - _unconditional_noise`, the result is

            `_unconditional_noise + g_w * class_vec`                                           if `g_m` is `None`
            `_unconditional_noise + g_w * class_vec + g_m * (g_p * adv_vec - g_w * class_vec)`  otherwise

        so `g_m = 0` reproduces classifier-free guidance and `g_m = 1` follows only the adversarial direction.

        :param _pipeline: DiffusionPipeline
            The calling pipeline. Not used by this controller.
        :param _unconditional_noise: torch.Tensor
            Noise prediction for the unconditional input.
        :param _conditional_noise: torch.Tensor
            Noise prediction for the conditional input.
        :param _gconditional_noise: Optional[torch.Tensor]
            Noise prediction for the guidance (adversarial) conditioning. May only be `None` when `g_m` is `None`.
        :param g_w: float
            Equivalent to the guidance scale used in the classifier-free guidance scheme. How much the diffusion
            follows classifier-free guidance.
        :param g_p: float
            The guidance scale applied to the adversarial direction. A larger value adds stronger adversarial
            features at the cost of image diversity.
        :param g_m: Optional[float]
            The mixing scale for classifier-free adversarial guidance. A higher value incorporates more information
            from the classifier-free adversarial diffusion. If `None`, the adversarial direction is ignored and only
            classifier-free guidance is used.
        :param kwargs:
            Ignored; accepted so callers can pass a shared set of guidance arguments.
        :return: torch.Tensor
            The guided noise prediction.
        :raises ValueError:
            If `_gconditional_noise` is `None` while `g_m` is not `None`.
        """

        if (_gconditional_noise is None) and (g_m is not None):
            raise ValueError(f"_gconditional_noise is `None`, but `g_m` is {g_m}, expected `g_m = {None}`")

        # Direction from the unconditional towards the conditional prediction, scaled as in classifier-free guidance.
        class_vec = _conditional_noise - _unconditional_noise
        cfg_term = g_w * class_vec

        if g_m is None:
            return _unconditional_noise + cfg_term

        # Direction from the unconditional towards the adversarially-conditioned prediction.
        adv_vec = _gconditional_noise - _unconditional_noise

        # Move a fraction `g_m` of the way from the classifier-free term to the adversarial term.
        return _unconditional_noise + cfg_term + g_m * (-cfg_term + g_p * adv_vec)

    @staticmethod
    def do_gcontrol(g_w: float, g_m: Optional[float], **kwargs) -> bool:
        """
        Whether guidance changes the result for the given scales.

        Returns `False` only when `g_w == 1` and mixing is off (`g_m` is `None` or `0`); in that case `forward`
        reduces to the conditional noise prediction (up to floating-point rounding).

        :param g_w: float
            The classifier-free guidance scale, as in `forward`.
        :param g_m: Optional[float]
            The mixing scale, as in `forward`. `None` disables mixing and is treated like `0` here.
        :param kwargs:
            Ignored; accepted so callers can pass a shared set of guidance arguments.
        """

        # `None != 0` is True in Python, so `None` must be checked explicitly to count as "mixing disabled".
        mixing_active = (g_m is not None) and (g_m != 0)
        return (g_w != 1) or mixing_active

    @property
    def requires_latent_grad(self):
        """
        Whether the guidance controller requires gradients of the latents to be calculated through the UNet.
        """

        return self._requires_latent_grad


#######################################################################################################################
