### Preamble ##########################################################################################################

"""
A GController module that implements classifier guidance for the StableDiffusionXL model.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import torch
from torchvision.transforms.v2 import TenCrop, Grayscale
from torchvision.transforms.v2.functional import rotate
from torchvision.transforms import Compose, CenterCrop, Normalize, Resize
from typing import Union, Iterable, Optional, Tuple, Callable

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL
from diffusers.configuration_utils import register_to_config

from transformers import PreTrainedModel

from ..controller_utils import GController

from gcontrol.utils.im_utils import array_to_PIL

#######################################################################################################################


class ClassifierGuidance(GController):
    """
    GController that applies classifier guidance to a StableDiffusionXL-style diffusion pipeline.

    At every guided step the scheduler's estimate of the clean latents is decoded by the pipeline VAE, classified,
    and the log-probability of the target class is back-propagated to the diffusion latents. The (optionally
    normalised) latent gradient is then used to shift the unconditional noise prediction towards the target class.
    """

    def __init__(
        self,
        classifier: Union[PreTrainedModel, torch.nn.Module],
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Union[dict, tuple[int, int, int], tuple[int, int], list[int], int],
        crop_pct: Optional[float] = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[list[float], torch.Tensor]] = None,
        image_std: Optional[Union[list[float], torch.Tensor]] = None,
        **kwargs,
    ):
        """
        :param classifier: PreTrainedModel or torch.nn.Module
            The classification model used to guide the diffusion process. Should return either a `torch.Tensor` of
            logits, or an object with a `logits` attribute.
        :param do_resize: bool
            Whether the decoded image should be resized prior to being passed to the classifier.
        :param do_rescale: bool
            Whether the decoded image should be rescaled (multiplied by `rescale_factor`) prior to classification.
        :param do_normalize: bool
            Whether to normalise the image with `image_mean` / `image_std` prior to classification.
        :param size: int, (int, int), (int, int, int), list, or dict
            The (H, W) the decoded image is resized to before classification. A single integer gives equal height
            and width; a length-3 sequence is interpreted as (C, H, W); a dict must contain either `shortest_edge`
            (used for both height and width) or both `height` and `width`.
        :param crop_pct: float
            Determines whether the image will be resized and then centre-cropped. `crop_pct` is the fraction of the
            resized image that won't be cropped. If an image is to be resized to `size = (200, 200)` and
            `crop_pct = 0.8`, then the image will be resized to 250 (i.e., 200 / 0.8) and then cropped to (200, 200).
        :param rescale_factor: float
            The scale factor applied to the image after resizing, but prior to normalisation. Required when
            `do_rescale` is True.
        :param image_mean: list or torch.Tensor
            The image mean used in normalisation. Required when `do_normalize` is True.
        :param image_std: list or torch.Tensor
            The image standard deviation used in normalisation. Required when `do_normalize` is True.
        :raises ValueError: if `size` has an unsupported length/keys, or a required rescale/normalise value is missing.
        :raises TypeError: if `size` has an unsupported type.

        Returns a `gcontrol` classifier guidance for use in the stableXL pipeline.
        """

        # Drop keys found in transformers image-processor dicts that this controller does not use, so such a dict
        # can be splatted straight into the constructor. (`crop_pct` is a named parameter and never reaches kwargs.)
        for attr in ("_processor_class", "resample", "image_processor_type"):
            kwargs.pop(attr, None)

        super().__init__(**kwargs)

        size = self._parse_size(size)

        if do_rescale and rescale_factor is None:
            raise ValueError("`rescale_factor` must be provided when `do_rescale` is True")
        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("`image_mean` and `image_std` must be provided when `do_normalize` is True")

        if isinstance(image_mean, torch.Tensor):
            image_mean = image_mean.tolist()
        if isinstance(image_std, torch.Tensor):
            image_std = image_std.tolist()

        self.register_module("classifier", classifier)

        self.register_to_config(
            crop_pct=crop_pct,
            do_resize=do_resize,
            size=size,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
        )

        # Initialising preprocessor
        transforms = []
        if self.config.crop_pct is not None and self.config.do_resize:  # Resize and crop preserving aspect ratio
            resize_size = int(round(min(self.config.size) / self.config.crop_pct))
            transforms.append(Resize(resize_size))
            transforms.append(CenterCrop(self.config.size))
        elif self.config.do_resize:  # Resize violating aspect ratio
            transforms.append(Resize(self.config.size))

        if self.config.do_rescale:
            # Normalize(mean=0, std=1/f) computes x * f, i.e. a pure rescale
            transforms.append(Normalize(0, 1 / self.config.rescale_factor))

        if self.config.do_normalize:  # Normalize with some mean and std
            transforms.append(Normalize(self.config.image_mean, self.config.image_std))

        # None means "no preprocessing"; forward() passes the image through unchanged in that case
        self.preprocessor = Compose(transforms=transforms) if transforms else None

        self._requires_uncond_noise = True
        self._requires_latent_grad = True

    @staticmethod
    def _parse_size(size) -> tuple:
        """Normalise the accepted `size` formats to an (H, W) tuple."""
        if isinstance(size, dict):
            if "shortest_edge" in size:
                return (size["shortest_edge"], size["shortest_edge"])
            if "height" in size and "width" in size:
                return (size["height"], size["width"])
            raise ValueError("`size` dict must contain key `shortest_edge` or keys `height` and `width`")
        if isinstance(size, int):
            return (size, size)
        # Lists are accepted because tuples become lists after a JSON config round-trip
        if isinstance(size, (tuple, list)):
            if len(size) == 2:
                return tuple(size)
            if len(size) == 3:
                return tuple(size[1:])
            raise ValueError(f"`size` expected tuple of length 2 or 3, got length {len(size)}")
        raise TypeError("Got unsupported `size` type")

    def forward(
        self,
        _pipeline: DiffusionPipeline,
        _unconditional_noise: Union[torch.Tensor, None],
        _conditional_noise: torch.Tensor,
        _latents: torch.Tensor,
        _t,
        _extra_step_kwargs,
        target_idx: Union[int, list[int]] = 0,
        g_w: float = 5,
        grad_norm: Optional[int] = 2,
        augmentations: Optional[Union[list[Callable], tuple[Callable, ...], Callable, str]] = None,
        grad_zero_threshold: Optional[float] = None,
        debug: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the classifier-guided noise prediction for one diffusion step.

        :param _pipeline: DiffusionPipeline
            The running pipeline; its `scheduler`, `vae` and `unet` are used.
        :param _unconditional_noise: torch.Tensor
            The unconditional noise prediction. Must not be None.
        :param _conditional_noise: torch.Tensor
            The conditional noise prediction. Only used for the grad check and debug output.
        :param _latents: torch.Tensor
            The current latents. Must have `requires_grad=True`; its `.grad` is populated by this call.
        :param _t: torch.Tensor
            The current timestep.
        :param _extra_step_kwargs: dict
            Extra keyword arguments forwarded to `scheduler.step`.
        :param target_idx: int or sequence of int
            The target class index for each sample in the batch. A single index is applied to every sample.
        :param g_w: float
            Equivalent to the guidance scale used in the classifier guidance scheme. Controls how much the
            diffusion follows classifier guidance.
        :param grad_norm: int | None
            The normalisation applied to the classifier gradient. A value of None means no normalisation; any
            other value selects the l_{grad_norm} norm over the (C, H, W) dimensions of each sample.
        :param augmentations: list/tuple of functions, function, or str
            Optional augmentation(s) applied to the images prior to classification. The classification logits are
            averaged across the predictions for each transformation and the unaugmented image. If an augmentation
            changes the spatial shape of the image, it is resized back using `torchvision.transforms.Resize`. Pass
            "recommended" to apply the recommended ten-crop and greyscale transforms. Pass `None` for no
            transforms.
        :param grad_zero_threshold: float | None
            If set, entries of the latent gradient whose magnitude does not exceed this value are set to 0. This
            guards against floating point underflow and can improve sampling stability. If None, the gradient is
            not changed.
        :param debug: bool
            Print diagnostics for the first sample of the batch.
        :raises ValueError: on missing gradients or noise, a mismatched `target_idx`, or an unknown augmentation
            preset.
        :raises TypeError: if the scheduler does not provide `pred_original_sample`.

        Returns the guided noise prediction cast to `_latents.dtype`.
        """

        # Check that module has been passed latents with gradients
        if not torch.is_grad_enabled():
            raise ValueError("ClassifierGuidance requires torch.is_grad_enabled() to be true")
        if not _latents.requires_grad:
            raise ValueError("ClassifierGuidance expected _latents to have requires_grad = true")
        if _conditional_noise.grad_fn is None:
            raise ValueError("ClassifierGuidance expected _conditional_noise to have a non None grad_fn")
        if _unconditional_noise is None:
            raise ValueError("ClassifierGuidance requires _unconditional_noise, but got None")

        # One target class per sample. A single index is broadcast so that every sample is guided; indexing with a
        # length-1 list would otherwise guide only the first sample of the batch.
        batch_size = _latents.shape[0]
        target = torch.as_tensor(target_idx, dtype=torch.long)
        if target.ndim == 0:
            target = target.repeat(batch_size)
        target = target.reshape(-1)
        if target.numel() != batch_size:
            raise ValueError(
                f"`target_idx` must be a single index or one index per sample ({batch_size}), got {target.numel()}"
            )

        if isinstance(augmentations, str):
            if augmentations != "recommended":
                raise ValueError(f"Unknown `augmentations` preset {augmentations!r}; expected 'recommended'")
            augmentations = [TenCrop(int(0.6 * self.config.size[0])), Grayscale(3)]

        # Optionally zero small gradient entries. The handle is removed after the backward pass so that hooks do
        # not pile up if the same latent tensor is passed again.
        hook_handle = None
        if grad_zero_threshold is not None:
            hook_handle = _latents.register_hook(
                lambda grad: torch.where(grad.abs() > grad_zero_threshold, grad, 0.0)
            )

        # Handling VAE casting to prevent overflow
        needs_upcasting = _pipeline.vae.dtype == torch.float16
        if needs_upcasting:
            _pipeline.upcast_vae()
        try:
            if needs_upcasting:
                # Feed the VAE in whatever dtype its input layer has after upcasting; pred_latent inherits it below
                latents = _latents.to(next(iter(_pipeline.vae.post_quant_conv.parameters())).dtype)
            else:
                latents = _latents

            # Diffusion step standard deviation: add_noise(original=0, noise=1, t)
            noise_sd = _pipeline.scheduler.add_noise(
                torch.tensor([0], dtype=latents.dtype),
                torch.tensor([1], dtype=latents.dtype),
                torch.tensor([_t], dtype=_t.dtype),
            ).to(dtype=latents.dtype, device=latents.device)

            pred_latent = self._predict_original_latents(
                _pipeline, _unconditional_noise, _t, latents, _extra_step_kwargs
            )

            # Classifier prediction
            image = self._latents_to_image(_pipeline.vae, pred_latent)
            processed_image = image if self.preprocessor is None else self.preprocessor(image)
            logits = self._classify_with_dtype_fallback(processed_image, augmentations)
            log_prob = torch.log_softmax(logits, dim=1)

            # Samples are independent, so the summed objective yields each sample's own gradient
            target = target.to(log_prob.device)
            objective = torch.sum(log_prob[torch.arange(batch_size, device=log_prob.device), target])
            self._backpropagate(_pipeline, objective)
        finally:
            if hook_handle is not None:
                hook_handle.remove()
            # Reset the VAE to float16 if required, even when an error occurred
            if needs_upcasting:
                _pipeline.vae.to(dtype=torch.float16)

        # Compute the noise estimate
        latent_grad = _latents.grad
        if latent_grad is None:
            raise ValueError("ClassifierGuidance: no gradient reached _latents during the backward pass")
        if grad_norm is not None:
            grad = torch.nn.functional.normalize(latent_grad, p=grad_norm, dim=(-3, -2, -1))
        else:
            grad = latent_grad
        noise = (_unconditional_noise - noise_sd * g_w * grad).to(_latents.dtype)

        if debug:
            self._print_debug(
                _t, noise_sd, logits, log_prob, target, latent_grad, _unconditional_noise, _conditional_noise
            )

        return noise

    @staticmethod
    def _predict_original_latents(pipeline, noise_pred, t, latents, extra_step_kwargs) -> torch.Tensor:
        """
        Return the scheduler's estimate of the clean latents, cast to `latents.dtype`.

        The scheduler's `_step_index` (when present and set) is wound back after the call so that the pipeline's
        own step for this timestep is not skipped.
        """
        step_output = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
        pred_latent = getattr(step_output, "pred_original_sample", None)
        if getattr(pipeline.scheduler, "_step_index", None) is not None:
            pipeline.scheduler._step_index -= 1  # Reset the step for use in the diffusion pipeline
        if pred_latent is None:
            raise TypeError(
                f"Diffusion pipeline scheduler: `{type(pipeline.scheduler)}` does not provide original "
                "sample predictions. Try another scheduler such as "
                "`<diffusers.schedulers.scheduling_ddim.DDIMScheduler>`"
            )
        return pred_latent.to(latents.dtype)  # Recast for compatibility issues

    def _classify_with_dtype_fallback(self, image: torch.Tensor, augmentations) -> torch.Tensor:
        """
        Run the classifier. On a RuntimeError, retry once with the image cast to the classifier's parameter dtype.

        The retry happens only if that dtype differs from the image's; otherwise the original error is re-raised.
        """
        try:
            return self._do_classifier(image, augmentations=augmentations)
        except RuntimeError:
            first_param = next(self.classifier.parameters(), None)
            if first_param is None or first_param.dtype == image.dtype:
                raise
            return self._do_classifier(image.to(first_param.dtype), augmentations=augmentations)

    @staticmethod
    def _backpropagate(pipeline, objective: torch.Tensor) -> None:
        """Run `objective.backward()`, temporarily moving the UNet onto the VAE's device if they differ."""
        if pipeline.unet.device == pipeline.vae.device:
            objective.backward()
            return
        # CPU offloading: the UNet may live on another device; always move it back, even if backward fails
        original_unet_device = pipeline.unet.device
        pipeline.unet.to(pipeline.vae.device)
        try:
            objective.backward()
        finally:
            pipeline.unet.to(original_unet_device)

    @staticmethod
    def _first_sample_stats(tensor: torch.Tensor) -> list:
        """[min, mean, max] of the first sample of a batched tensor, as Python floats."""
        sample = tensor[0]
        return [sample.min().item(), sample.mean().item(), sample.max().item()]

    def _print_debug(
        self, t, noise_sd, logits, log_prob, target, latent_grad, unconditional_noise, conditional_noise
    ) -> None:
        """Print guidance diagnostics for the first sample in the batch."""
        target_class = int(target[0])
        predicted_label = logits[0].argmax(-1).item()
        print(f"t={t}, s={noise_sd[0].item()}-----------------------------------")
        print("Target Class Predicted Probability:", torch.exp(log_prob[0, target_class]).item())
        id2label = getattr(getattr(self.classifier, "config", None), "id2label", None)
        if id2label is not None:
            print("Predicted Class:", id2label.get(predicted_label, predicted_label))
        print(
            "Gradient (min, mean, max):",
            self._first_sample_stats(latent_grad),
            "NANS:",
            torch.isnan(latent_grad).any().item(),
        )
        print(
            "Unconditional Patch (min, mean, max):",
            self._first_sample_stats(unconditional_noise),
            "NANS:",
            torch.isnan(unconditional_noise[0]).any().item(),
        )
        print(
            "Average Conditional Patch (min, mean, max):",
            self._first_sample_stats(conditional_noise),
            "NANS:",
            torch.isnan(conditional_noise[0]).any().item(),
        )

    @staticmethod
    def do_gcontrol(g_w: float, **kwargs):
        """Guidance is active whenever the guidance scale is non-zero."""
        return g_w != 0

    def _latents_to_image(self, vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
        """
        :param vae: AutoencoderKL
            The variational autoencoder used by the diffusion pipeline.
        :param latents: torch.Tensor
            (B, C, h, w) The latents of the diffusion model.

        Decodes the latents with `vae.decode` and maps the output from [-1, 1] to pixel values clamped to [0, 255].
        This function does not handle upcasting of the VAE or latents, so that autograd can flow through it to
        downstream tasks; the forward method handles VAE upcasting/downcasting.
        """

        # Un-scale / de-normalise the latents, using the per-channel mean and std when the VAE config provides them
        latents_mean = getattr(vae.config, "latents_mean", None)
        latents_std = getattr(vae.config, "latents_std", None)
        if latents_mean is not None and latents_std is not None:
            mean = torch.tensor(latents_mean).view(1, -1, 1, 1).to(latents.device, latents.dtype)
            std = torch.tensor(latents_std).view(1, -1, 1, 1).to(latents.device, latents.dtype)
            latents = latents * std / vae.config.scaling_factor + mean
        else:
            latents = latents / vae.config.scaling_factor

        image = vae.decode(latents, return_dict=False)[0]
        return (image / 2 + 0.5).clamp(min=0, max=1) * 255

    @staticmethod
    def _extract_logits(classifier_output):
        """Return `classifier_output.logits` for model-output objects, or the output itself for plain tensors."""
        if not isinstance(classifier_output, torch.Tensor) and hasattr(classifier_output, "logits"):
            return classifier_output.logits
        return classifier_output

    def _do_classifier(
        self,
        image: torch.Tensor,
        augmentations: Optional[Union[list[Callable], tuple[Callable, ...], Callable]] = None,
    ) -> torch.Tensor:
        """
        :param image: torch.Tensor
            (B, 3, h, w) A batch of images.
        :param augmentations: list/tuple of functions or function
            Optional augmentation(s) applied to the images prior to classification. The classification logits are
            averaged across the predictions for each transformation and the unaugmented image. If `None`, no
            transformations are applied. An augmentation may return a tensor or a list/tuple of tensors. If an
            augmented image's spatial shape differs from the input's, it is resized back using
            `torchvision.transforms.Resize`.

        Returns a tensor of shape (B, S) containing the classification logits for each class, where `S` is the total
        number of classes. Logits are averaged across transformations if any are applied.
        """

        if augmentations is None:
            return self._extract_logits(self.classifier(image))

        original_shape = image.shape[-2:]
        # Build a new list rather than appending to the caller's: the same list is reused on every diffusion step,
        # and mutating it would add another identity transform each time.
        if callable(augmentations):
            augmentations = [augmentations]
        augmentations = [*augmentations, lambda x: x]  # The identity keeps the unaugmented image in the average

        augmented_images = []
        for aug in augmentations:
            augmented = aug(image)
            if isinstance(augmented, torch.Tensor):
                augmented_images.append(augmented)
            elif isinstance(augmented, (list, tuple)):  # e.g. TenCrop returns a tuple of crops
                augmented_images.extend(augmented)
            else:
                raise TypeError("Unsupported augmentation return type.")

        resize_back = Resize(original_shape)
        logits = []
        for augmented in augmented_images:
            if not isinstance(augmented, torch.Tensor):
                raise TypeError("Unsupported augmentation return type.")
            if augmented.shape[-2:] != original_shape:
                augmented = resize_back(augmented)
            logits.append(self._extract_logits(self.classifier(augmented)))

        # (B, K, S) -> mean over the K transforms -> (B, S)
        return torch.stack(logits, dim=1).mean(dim=1)

    @property
    def requires_latent_grad(self):
        """
        Whether the guidance controller requires gradients of the latents to be computed through the UNet.
        """

        return self._requires_latent_grad

#######################################################################################################################
