### Preamble ##########################################################################################################

"""
Base GController class
"""

#######################################################################################################################

### Imports ###########################################################################################################

import os
from typing import Union, List, Tuple
import torch
from diffusers import ModelMixin, ConfigMixin, DiffusionPipeline
from diffusers.configuration_utils import register_to_config
from transformers.utils import is_torch_xla_available, ENV_VARS_TRUE_VALUES

#######################################################################################################################

# Snapshot of the XLA bfloat16 environment flags, taken once at import time. Upper-cased (unset -> "0") so they can
# be compared directly against `ENV_VARS_TRUE_VALUES` when resolving parameter dtypes.
XLA_USE_BF16, XLA_DOWNCAST_BF16 = (
    os.environ.get(_flag_name, "0").upper() for _flag_name in ("XLA_USE_BF16", "XLA_DOWNCAST_BF16")
)


class GController(ModelMixin, ConfigMixin):
    """
    Abstract GController class to be extended by all classes implementing some form of diffusion guidance control.

    Extending classes must set `_requires_uncond_noise` in their `__init__` and implement `forward` and a static
    `do_gcontrol` method. A zero-element `_dummy_param` parameter is always registered so that `self.parameters()` is
    never empty, which keeps the `dtype` and `device` properties (used by HuggingFace internals) well defined.
    """

    @register_to_config
    def __init__(self, **kwargs):
        """
        The `__init__()` function of extending classes must set `_requires_uncond_noise` to either True or False. This
        denotes whether the diffusion pipeline will perform unconditional noise estimates or not. Will save
        computation time if the unconditional noise estimates are unnecessary. Additionally, the `__init__` function
        of the extending class must have the `@register_to_config` decorator from the diffusers package, or manually
        register args to config using `self.register_to_config(param1=param1,...)`.

        param torch_dtype: torch.dtype, optional keyword argument
            dtype given to the dummy parameter until the first real parameter is registered. Defaults to
            `torch.float32`. It is consumed here and not forwarded to the parent `__init__`.
        """

        # `torch_dtype` only configures the dummy parameter. It is popped rather than forwarded because
        # `ModelMixin.__init__` takes no keyword arguments; forwarding it raised a TypeError before the dtype could
        # be used. Any other keyword arguments are still forwarded unchanged.
        dtype = kwargs.pop("torch_dtype", torch.float32)
        super().__init__(**kwargs)

        # Must be overwritten with True/False by the extending class; see `requires_uncond_noise`.
        self._requires_uncond_noise = None

        # Setting a dummy parameter for instances where a GController does not define any model parameters. Without
        # it, interaction with the HuggingFace library breaks due to an undefined dtype and device.
        self._dummy_param_updated = False
        super().register_parameter("_dummy_param", torch.nn.Parameter(torch.tensor([], dtype=dtype)))

    def forward(
        self,
        _pipeline: DiffusionPipeline,
        _unconditional_noise: Union[torch.Tensor, None],
        _conditional_noise: torch.Tensor,
    ) -> torch.Tensor:
        """
        The `forward` method of the extending class should accept at a minimum `_pipeline : DiffusionPipeline`,
        `_latents: torch.Tensor`, `_unconditional_noise : Union[torch.Tensor, None]`, and
        `_conditional_noise : torch.Tensor` arguments. The signature here is only a placeholder; this base
        implementation always raises.
        """

        raise NotImplementedError("A `forward` method must be implemented when extending `GController`")

    @staticmethod
    def do_gcontrol(*args, **kwargs) -> bool:
        """
        Whether the GController forward method will change the estimated noise based on the input arguments. The
        `do_gcontrol` method of the extending class should accept the same arguments as the additional args sent to
        the `forward` method. To improve efficiency gcontrol modules should specify a set of parameters that they do
        not run for (and thus unconditional noise estimates are not needed for), i.e., w = 1 in classifier-free
        guidance.

        This is a `staticmethod`, so it receives no implicit `self`. The base version accepts any arguments so that
        calling it always reports the missing override rather than a signature error.
        """

        raise NotImplementedError("A static `do_gcontrol` method must be implemented when extending `GController`")

    @property
    def requires_uncond_noise(self):
        """
        Whether the extending class requires unconditional noise estimates during the `forward` method. If
        `requires_uncond_noise` is False, the unconditional estimate is not needed, and `_unconditional_noise` will be
        None in the `forward` method.

        raises NotImplementedError
            If the extending class never set `_requires_uncond_noise`.
        """

        if self._requires_uncond_noise is None:
            raise NotImplementedError(
                "The `_requires_uncond_noise` attribute must be set to True or False when extending `GController`"
            )
        return self._requires_uncond_noise

    @property
    def dtype(self):
        """
        Returns the first floating point datatype of tensors in self.parameters() (adjusted for the XLA bf16
        environment flags). Otherwise, if there are no floating point tensors, returns the datatype of the last tensor
        in self.parameters(). `_dummy_param` is registered first in `__init__` and re-registered in place, so it is
        normally the first tensor inspected. This property is for compatibility with some of the HuggingFace
        internals in the diffusers and transformers libraries.
        """

        return self.__get_parameter_dtype()

    @property
    def device(self):
        """
        Returns the device of the first tensor in self.parameters(). `_dummy_param` always exists, so there is always
        at least one tensor to inspect.
        """

        return self.__get_parameter_device()

    def register_parameter(self, name: str, param: Union[torch.nn.Parameter, None]):
        """
        param name: str
            Name of the parameter to register.
        param param: torch.nn.Parameter or None
            torch parameter to register. `None` is accepted, as in `torch.nn.Module.register_parameter`.

        HuggingFace expects `torch.nn.Module`s to have at least one parameter. In cases where there are no learnable
        parameters (classifier-free guidance) the lack of a parameter list can cause issues. Thus, we overload the
        `register_parameter` method to keep track of a dummy variable and cast it to the first parameter declared by
        the child class.
        """

        # A `None` placeholder has no dtype/device to copy, so the dummy waits for the first real tensor.
        if param is not None and not self._dummy_param_updated:
            self.__sync_dummy_param(param)
        return super().register_parameter(name, param)

    def register_module(self, name: str, module: Union[torch.nn.Module, None]):
        """
        param name: str
            Name of the module to register.
        param module: torch.nn.Module or None
            torch module to register. `None` is accepted, as in `torch.nn.Module.register_module`.

        HuggingFace expects `torch.nn.Module`s to have at least one parameter. In cases where there are no learnable
        parameters (classifier-free guidance) the lack of a parameter list can cause issues. Thus, we overload the
        `register_module` method to keep track of a dummy variable and cast it to the first parameter declared by
        the child class. If a module is registered first, the dummy variable is cast to the same type as the
        module's first parameter. Modules without parameters leave the dummy unchanged.

        NOTE(review): `torch.nn.Module.__setattr__` appears to store sub-modules directly, without calling
        `register_module`. If so, modules assigned as attributes (`self.net = ...`) do not trigger this hook; confirm
        against the pinned torch version.
        """

        if module is not None and not self._dummy_param_updated:
            first_param = next(module.parameters(), None)
            if first_param is not None:
                self.__sync_dummy_param(first_param)
        return super().register_module(name, module)

    def __sync_dummy_param(self, reference: torch.Tensor) -> None:
        """
        Re-register `_dummy_param` with the dtype and device of `reference`, and mark it as updated so this only
        happens for the first parameter the extending class declares.
        """

        # Only floating point and complex tensors can require gradients. Copying an integer dtype with the default
        # `requires_grad=True` would make `torch.nn.Parameter` raise.
        can_require_grad = reference.is_floating_point() or reference.is_complex()
        updated_dummy = torch.nn.Parameter(
            self._dummy_param.to(dtype=reference.dtype, device=reference.device),
            requires_grad=self._dummy_param.requires_grad and can_require_grad,
        )
        super().register_parameter("_dummy_param", updated_dummy)
        self._dummy_param_updated = True

    def __get_parameter_dtype(self):
        """
        Returns the first floating point datatype of tensors in self.parameters(). Otherwise, if there are no floating
        point tensors, returns the datatype of the last tensor in self.parameters(). Since `_dummy_param` is always
        registered, at least one tensor is inspected. This property is for compatibility with some of the HuggingFace
        internals in the diffusers and transformers libraries.

        Taken in part from transformers.modeling_utils.get_parameter_dtype
        """

        last_dtype = None
        for t in self.parameters():
            last_dtype = t.dtype
            if t.is_floating_point():
                # Adding fix for https://github.com/pytorch/xla/issues/4152
                # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
                # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
                # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
                if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                    return torch.bfloat16
                if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                    if t.dtype == torch.float:
                        return torch.bfloat16
                    if t.dtype == torch.double:
                        return torch.float32
                return t.dtype

        return last_dtype

    def __get_parameter_device(self):
        """
        Returns the device of the first tensor in self.parameters(). `_dummy_param` is registered in `__init__`, so
        there is always at least one parameter. This property is for compatibility with some of the HuggingFace
        internals in the diffusers and transformers libraries.

        Taken in part from transformers.modeling_utils.get_parameter_device
        """

        return next(self.parameters()).device

#######################################################################################################################
