# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py
import os
import warnings
from collections import OrderedDict
from os import PathLike
from typing import Final

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import (
    SiglipImageProcessor,
    SiglipVisionConfig,
    SiglipVisionModel,
    logging,
)
from transformers.image_processing_utils import BatchFeature
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention

# Module-level side effect: importing this file lowers the transformers
# library's logging verbosity so that only errors are reported.
logging.set_verbosity_error()

# Default download location of the pretrained v2.5 scoring-head weights,
# consumed by convert_v2_5_from_siglip.
URL: Final[str] = (
    "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/"
    "models/aesthetic_predictor_v2_5.pth"
)


class AestheticPredictorV2_5Head(nn.Module):
    """MLP that turns an image embedding into a single aesthetic score.

    Layer widths are ``hidden_size -> 1024 -> 128 -> 64 -> 16 -> 1``. A
    dropout layer follows every linear layer except the last one. There are
    no non-linear activations between the linear layers.
    """

    def __init__(self, config: SiglipVisionConfig) -> None:
        super().__init__()
        widths = [config.hidden_size, 1024, 128, 64, 16, 1]
        # Dropout probability after each linear layer; the final projection
        # to the score has none.
        drop_probs = [0.5, 0.5, 0.5, 0.2]

        stages: list[nn.Module] = []
        for idx, (d_in, d_out) in enumerate(zip(widths, widths[1:])):
            stages.append(nn.Linear(d_in, d_out))
            if idx < len(drop_probs):
                stages.append(nn.Dropout(drop_probs[idx]))

        # load_state_dict matches parameters by name, so the attribute name
        # and the Sequential indices (scoring_head.0, .2, .4, .6, .8 for the
        # linear layers) must stay exactly as they are for saved weights.
        self.scoring_head = nn.Sequential(*stages)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        """Apply the scoring MLP; the last dimension is reduced to size 1."""
        scores = self.scoring_head(image_embeds)
        return scores


class AestheticPredictorV2_5Model(SiglipVisionModel):
    """SigLIP vision encoder with an aesthetic-score regression head.

    ``forward`` takes the encoder's pooled embedding and L2-normalises it
    along the last dimension. It then feeds the result to
    :class:`AestheticPredictorV2_5Head` (stored as ``self.layers``), whose
    final linear layer outputs one value per embedding.
    """

    # NOTE(review): not referenced in this class; presumably the SigLIP patch
    # size (cf. the "patch14" default encoder name). Confirm before relying on it.
    PATCH_SIZE = 14

    def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None:
        super().__init__(config, *args, **kwargs)
        self.layers = AestheticPredictorV2_5Head(config)
        # Called again after attaching the head, presumably so Hugging Face's
        # weight initialisation / post-processing also covers ``self.layers``.
        self.post_init()
        # Standalone torchvision pipeline: resize to 384x384, convert to a
        # tensor, then normalise each channel with mean=std=0.5.
        # ``forward`` does not use it; it is exposed for callers that
        # preprocess images themselves.
        self.transforms = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
    ) -> tuple | ImageClassifierOutputWithNoAttention:
        """Score a batch of preprocessed images.

        Args:
            pixel_values: Preprocessed images, passed straight to the SigLIP
                encoder.
            labels: Optional target scores. When given, the mean-squared error
                between predictions and labels is returned as ``loss``. Both
                are flattened first, so presumably one label per image, with
                shape ``(B,)`` or ``(B, 1)``. Verify against callers.
            return_dict: Output format; defaults to
                ``self.config.use_return_dict``.

        Returns:
            ``ImageClassifierOutputWithNoAttention(loss, logits,
            hidden_states)`` when ``return_dict`` is truthy. Otherwise returns
            the tuple ``(loss, prediction, image_embeds)``.

            - ``logits`` / ``prediction`` is the head output.
            - ``hidden_states`` / ``image_embeds`` is the *un-normalised*
              pooled embedding.
            - ``loss`` is ``None`` when no labels are passed.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always request a ModelOutput from the encoder so ``pooler_output``
        # is available. Hugging Face models return a plain tuple when
        # return_dict=False, and a tuple has no such attribute. The caller's
        # ``return_dict`` only controls the format of this method's output.
        outputs = super().forward(
            pixel_values=pixel_values,
            return_dict=True,
        )
        image_embeds = outputs.pooler_output
        image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        prediction = self.layers(image_embeds_norm)

        loss = None
        if labels is not None:
            loss_fct = nn.MSELoss()
            # Flatten both sides so (B, 1) predictions line up with (B,) or
            # (B, 1) labels instead of broadcasting to (B, B). Cast the labels
            # to the predictions' dtype and device (e.g. integer labels, or
            # labels still on CPU).
            loss = loss_fct(
                prediction.reshape(-1),
                labels.reshape(-1).to(prediction),
            )

        if not return_dict:
            return (loss, prediction, image_embeds)

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=prediction,
            hidden_states=image_embeds,
        )


class AestheticPredictorV2_5Processor(SiglipImageProcessor):
    """Image processor for :class:`AestheticPredictorV2_5Model`.

    A thin subclass of ``SiglipImageProcessor``: construction and
    ``__call__`` are forwarded unchanged. ``from_pretrained`` defaults to the
    ``google/siglip-so400m-patch14-384`` checkpoint.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> BatchFeature:
        """Preprocess images exactly as ``SiglipImageProcessor`` does."""
        return super().__call__(*args, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | PathLike = "google/siglip-so400m-patch14-384",
        *args,
        **kwargs,
    ) -> "AestheticPredictorV2_5Processor":
        """Load processor settings, defaulting to the SigLIP SO400M checkpoint.

        Args:
            pretrained_model_name_or_path: Hugging Face model id or local
                directory holding the processor configuration.
            *args, **kwargs: Forwarded to
                ``SiglipImageProcessor.from_pretrained``.
        """
        # Zero-argument super() inside a classmethod binds to ``cls``, so the
        # loaded instance is of this subclass.
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)


def convert_v2_5_from_siglip(
    predictor_name_or_path: str | PathLike | None = None,
    encoder_model_name: str = "google/siglip-so400m-patch14-384",
    *args,
    **kwargs,
) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]:
    """Build the aesthetic predictor from a SigLIP checkpoint plus head weights.

    Args:
        predictor_name_or_path: Source of the scoring-head weights.
            - ``None``: download the default weights from ``URL``.
            - An existing local file: load it with ``torch.load``.
            - An ``http://`` / ``https://`` string: download from that URL.
            - Anything else (e.g. a mistyped path): emit a warning and fall
              back to ``URL``.
        encoder_model_name: Hugging Face id or path of the SigLIP vision
            checkpoint, used for both the encoder and the image processor.
        *args, **kwargs: Forwarded to both ``from_pretrained`` calls.

    Returns:
        ``(model, processor)``, with the model switched to eval mode.

    Raises:
        TypeError: If the loaded weights are not a dict-style state dict.
    """
    model = AestheticPredictorV2_5Model.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    processor = AestheticPredictorV2_5Processor.from_pretrained(
        encoder_model_name, *args, **kwargs
    )

    state_dict = _load_head_state_dict(predictor_name_or_path)

    # Strict load: key names must match AestheticPredictorV2_5Head's modules.
    model.layers.load_state_dict(state_dict)
    model.eval()

    return model, processor


def _load_head_state_dict(predictor_name_or_path: str | PathLike | None) -> dict:
    """Load the scoring-head state dict onto CPU from a local file or a URL."""
    # NOTE(review): torch.load (also used by load_state_dict_from_url)
    # unpickles the file. Only point this at weights from a trusted source.
    location = (
        None if predictor_name_or_path is None else os.fspath(predictor_name_or_path)
    )

    if location is not None and os.path.exists(location):
        state_dict = torch.load(location, map_location="cpu")
    else:
        if location is None:
            url = URL
        elif isinstance(location, str) and location.startswith(("http://", "https://")):
            url = location
        else:
            # Previously this fallback was silent, hiding typos in the path.
            warnings.warn(
                f"Predictor weights not found at {location!r}; "
                f"falling back to the default weights at {URL}.",
                stacklevel=3,
            )
            url = URL
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")

    # An explicit check instead of `assert`, which is stripped under `python -O`.
    # OrderedDict (what torch.save'd state dicts load as) is a dict subclass.
    if not isinstance(state_dict, dict):
        raise TypeError(
            "Expected the predictor weights to be a state dict, "
            f"got {type(state_dict).__name__}."
        )
    return state_dict