### Preamble ##########################################################################################################

"""
PyTorch classifier pipeline for wrapping `torch`, `timm` and Hugging Face `transformers` classification models.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import torch
from typing import Union, Iterable, Optional, Tuple, Callable
from torchvision.transforms import Compose, CenterCrop, Normalize, Resize
from transformers import PreTrainedModel

#######################################################################################################################

### Pytorch Classifier Configs ########################################################################################

def _make_config(height, width, crop_pct):
    """
    Build one classifier preprocessing config.

    Every config in this module shares the same channel mean/std (the commonly used ImageNet statistics), rescales
    pixel values by 1/255 and enables resizing, rescaling and normalisation; only the input size and `crop_pct` vary.
    A fresh dict (with fresh mean/std lists) is returned on every call, so configs never share mutable state.

    :param height: int
        Target image height fed to the classifier.
    :param width: int
        Target image width fed to the classifier.
    :param crop_pct: float
        Fraction of the resized image kept by the centre crop.
    """
    return {
        "size": (3, height, width),
        "image_mean": [0.485, 0.456, 0.406],
        "image_std": [0.229, 0.224, 0.225],
        "do_normalize": True,
        "do_rescale": True,
        "crop_pct": crop_pct,
        "rescale_factor": 1.0 / 255.0,
        "do_resize": True,
    }


resnet50_config = _make_config(224, 224, crop_pct=224 / 232)

inceptionv3_config = _make_config(299, 299, crop_pct=299 / 342)

vith14_config = _make_config(224, 224, crop_pct=1)

resnet152_config = _make_config(224, 224, crop_pct=224 / 232)

swinb_config = _make_config(256, 256, crop_pct=256 / 272)

deit_config = _make_config(224, 224, crop_pct=224 / 256)

maxvit_config = _make_config(224, 224, crop_pct=1)

#######################################################################################################################


class ClassifierPipeline(torch.nn.Module):
    """
    A frozen image classifier bundled with the preprocessing it expects.

    On construction the classifier's parameters are frozen (`requires_grad = False`) and a `torchvision` preprocessing
    chain (optional resize / centre crop, rescale and normalisation) is built. `forward` moves the input batch to the
    classifier's dtype and device, applies the preprocessing (if any) and returns the classifier's output.
    """

    def __init__(
        self,
        classifier: Union[PreTrainedModel, torch.nn.Module],
        do_resize: bool,
        do_rescale: bool,
        do_normalize: bool,
        size: Union[dict, tuple[int, int, int], tuple[int, int], int],
        crop_pct: Optional[float] = None,
        rescale_factor: Optional[float] = None,
        image_mean: Optional[Union[list[float], torch.Tensor]] = None,
        image_std: Optional[Union[list[float], torch.Tensor]] = None,
        **kwargs,
    ):
        """
        :param classifier: PreTrainedModel or torch.nn.Module
            The classification model. Its parameters are frozen in place.
        :param do_resize:  bool
            Whether the image should be resized prior to being passed to the classifier.
        :param do_rescale: bool
            Whether the image should be rescaled (multiplied by `rescale_factor`) prior to being passed to the
            classifier.
        :param do_normalize: bool
            Whether to normalise the image with `image_mean` / `image_std` prior to classification.
        :param size: int, (int, int), (int, int, int), list or dict
            An integer, or a tuple/list of integers with shape (C, H, W) or (H, W), that denotes the height and width
            that the diffusion image will be resized to when passing to the classifier. A single integer gives equal
            height and width. A dict must contain `shortest_edge` (giving equal height and width) or both `height`
            and `width`.
        :param crop_pct: float
            Determines whether the image will be resized and then cropped to preserve aspect ratio. `crop_pct` is the
            percentage of the resized image that won't be cropped. If an image is to be resized to `size = (200, 200)`
            and `crop_pct = 0.8`, then the image will be resized to (250, 250) (i.e., 200 / 0.8) and then cropped to
            (200, 200).
        :param rescale_factor: float
            The scale factor to be applied to the image post resizing, but prior to normalisation. Required when
            `do_rescale` is True.
        :param image_mean: list or torch.Tensor
            The image mean to be used in normalisation. Required when `do_normalize` is True.
        :param image_std: list or torch.Tensor
            The image standard deviation to be used in normalisation. Required when `do_normalize` is True.
        :param kwargs:
            Ignored (presumably so config dicts carrying extra keys can be unpacked directly — confirm with callers).
        :raises ValueError: if `size` is a dict without the supported keys, or a tuple/list not of length 2 or 3.
        :raises TypeError: if `size` has an unsupported type, if `rescale_factor` is missing while `do_rescale` is
            True, or if `image_mean` / `image_std` are missing while `do_normalize` is True.

        Constructs the classification pipeline.
        """

        super().__init__()

        size = self._parse_size(size)

        # Validate before touching the classifier, so a bad config leaves the caller's model unmodified.
        if do_rescale and rescale_factor is None:
            raise TypeError("`rescale_factor` must be provided when `do_rescale=True`")
        if do_normalize and (image_mean is None or image_std is None):
            raise TypeError("`image_mean` and `image_std` must be provided when `do_normalize=True`")

        # Store mean/std as plain lists
        if isinstance(image_mean, torch.Tensor):
            image_mean = image_mean.tolist()
        if isinstance(image_std, torch.Tensor):
            image_std = image_std.tolist()

        for param in classifier.parameters():
            param.requires_grad = False

        self.do_resize = do_resize
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.size = size
        self.crop_pct = crop_pct
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor

        self.preprocessor = self._build_preprocessor()

        self.register_module("classifier", classifier)

    @staticmethod
    def _parse_size(size) -> tuple:
        """Convert any accepted `size` format into an (H, W) tuple."""
        if isinstance(size, dict):
            if "shortest_edge" in size:
                return (size["shortest_edge"], size["shortest_edge"])
            if "height" in size and "width" in size:
                return (size["height"], size["width"])
            raise ValueError("`size` dict must contain key `shortest_edge` or keys `height` and `width`")
        if isinstance(size, int):
            return (size, size)
        if isinstance(size, (tuple, list)):
            if len(size) == 2:
                return tuple(size)
            if len(size) == 3:  # (C, H, W): drop the channel dimension
                return tuple(size[1:])
            raise ValueError(f"`size` expected tuple of length 2 or 3, got length {len(size)}")
        raise TypeError("Got unsupported `size` type")

    def _build_preprocessor(self) -> Optional[Compose]:
        """Compose the enabled transforms from the stored settings; return None when no transform is enabled."""
        transforms = []
        if self.do_resize:
            if self.crop_pct is not None:
                # Resize and crop preserving aspect ratio: Resize with a single int matches the shorter edge,
                # then the centre crop trims the result to `size`.
                resize_size = int(round(min(self.size) / self.crop_pct))
                transforms.append(Resize(resize_size))
                transforms.append(CenterCrop(self.size))
            else:
                # Resize straight to (H, W), violating aspect ratio
                transforms.append(Resize(self.size))

        if self.do_rescale:
            # (x - 0) / (1 / rescale_factor) == x * rescale_factor, typically dividing by max pixel value (255)
            transforms.append(Normalize(0, 1 / self.rescale_factor))

        if self.do_normalize:  # Normalize with some mean and std
            transforms.append(Normalize(self.image_mean, self.image_std))

        return Compose(transforms=transforms) if transforms else None

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor
            A (N, C, H, W) batch of images.
        :return:
            Whatever the wrapped classifier returns for the preprocessed batch.

        Moves `x` to the classifier's dtype and device (taken from its first parameter, else its first buffer; left
        unchanged if it has neither), applies the preprocessing if any is configured and runs the classifier.
        """

        reference = next(self.classifier.parameters(), None)
        if reference is None:
            reference = next(self.classifier.buffers(), None)
        if reference is not None:
            x = x.to(dtype=reference.dtype, device=reference.device)

        # No preprocessing is configured when resizing, rescaling and normalisation are all disabled
        if self.preprocessor is not None:
            x = self.preprocessor(x)
        return self.classifier(x)


#######################################################################################################################
