# Adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/simple_inference.py
# Please download the MLP weights from that repository before use.

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchvision.transforms import ToTensor
import clip
from PIL import Image
from reward_models.aesthetic_score.differentiable_transform import DifferentiableTransform
import argparse

# Absolute, machine-specific path to the MLP head weights
# ("sac+logos+ava1-l14-linearMSE.pth"); adjust it for your environment.
MODEL_PATH = "/home/ubuntu/workspace/transfer-learning-for-DMs/base_models/aesthetic_score/sac+logos+ava1-l14-linearMSE.pth"

class MLP(pl.LightningModule):
    """Regression head mapping an image embedding to a single scalar score.

    The ``nn.Sequential`` stored in ``self.layers`` must keep its exact module
    order (and therefore its indices): state dicts are loaded by key
    (``layers.0.weight``, ``layers.2.weight``, ...).

    Args:
        input_size: Width of the input embedding (last dimension of ``x``).
        xcol: Name of the embedding column; stored only, unused in this class.
        ycol: Name of the target column; stored only, unused in this class.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        # ReLU activations are deliberately omitted, so in eval mode (dropout
        # disabled) the stack reduces to one affine map. NOTE(review): the
        # loaded weights presumably expect no activations -- confirm before
        # adding any.
        stack = [
            nn.Linear(input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw score(s) for embeddings ``x`` (last dim == ``input_size``)."""
        return self.layers(x)

    def configure_optimizers(self):
        """Lightning hook: Adam over all parameters with a fixed lr of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

def normalized(a, axis=-1, order=2):
    """Scale NumPy array ``a`` to unit ``order``-norm along ``axis``.

    Slices whose norm is zero are divided by 1, so they come back unchanged.
    Because the norms pass through ``np.atleast_1d``, a 1-D input comes back
    with shape ``(1, n)`` rather than ``(n,)``.
    """
    import numpy as np  # pylint: disable=import-outside-toplevel

    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))
    # Substitute 1 for zero norms to avoid dividing by zero.
    safe_norms = np.where(norms == 0, 1, norms)
    return a / np.expand_dims(safe_norms, axis)

def torch_normalized(a, axis=-1, order=2):
    """Scale tensor ``a`` to unit ``order``-norm along ``axis`` using torch ops.

    Slices whose norm is zero are divided by 1, so they come back unchanged.
    No ``atleast_1d`` step is applied, so a 1-D input keeps its shape.
    """
    # keepdim=True leaves a size-1 axis in place, ready for broadcasting.
    norms = torch.linalg.norm(a, ord=order, dim=axis, keepdim=True)
    # masked_fill keeps the dtype/device of ``norms``; no constant tensor needed.
    return a / norms.masked_fill(norms == 0, 1.0)

class AestheticClassifier:
    r"""
    AestheticClassifier class to predict the aesthetic score of an image.

    A CLIP image encoder (``self.model2``) embeds the image, the embedding is
    L2-normalized, and an ``MLP`` head (``self.model``) maps it to a score.

    Args:
        model_path (str): Path to the state dict of the MLP head.
        clip_model (str): Model name passed to ``clip.load``.
        device (str or torch.device): Device hosting both models and inputs.
    """

    def __init__(self, model_path, clip_model="ViT-L/14", device='cpu'):
        self.device = device
        # NOTE(review): 768 presumably matches the embedding width of the
        # default ViT-L/14 encoder; another ``clip_model`` needs MLP weights of
        # matching width.
        self.model = MLP(768)
        # torch.load unpickles the checkpoint: only load trusted files.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.model2, self.preprocess = clip.load(clip_model, device=device)
        # Used by ``score`` instead of ``self.preprocess`` (which works on PIL
        # images), presumably so gradients can flow back to the pixel tensor.
        self.differentiable_transform = DifferentiableTransform(224)

    def _mlp_input(self, embedding):
        """Move ``embedding`` to ``self.device`` in the MLP's parameter dtype.

        CLIP features may not match the MLP's dtype (e.g. half precision), so
        cast explicitly. Targeting ``self.device`` keeps CPU and non-default
        GPUs working, unlike a hard-coded CUDA tensor type.
        """
        mlp_dtype = next(self.model.parameters()).dtype
        return embedding.to(device=self.device, dtype=mlp_dtype)

    @staticmethod
    def _as_rgb(img):
        """Return ``img`` (a file path or PIL image) as an RGB PIL image.

        Paths are opened in a ``with`` block so the file handle is closed;
        ``convert`` reads the pixel data before the block exits.
        """
        if isinstance(img, str):
            with Image.open(img) as opened:
                return opened.convert("RGB")
        return img if img.mode == "RGB" else img.convert("RGB")

    def predict(self, img):
        r"""
        Predict the aesthetic score of an image (inference only, no gradients).
        Args:
            img (str or PIL.Image): The image for which to predict the aesthetic score.
        Returns:
            float: The predicted aesthetic score.
        """
        if isinstance(img, str):
            # Keep the file open only while preprocessing reads the pixels.
            with Image.open(img) as opened:
                return self.predict(opened)

        pixels = self.preprocess(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model2.encode_image(pixels)
            # Normalized on the CPU with NumPy; ``score`` uses the torch equivalent.
            im_emb_arr = normalized(image_features.cpu().numpy())
            prediction = self.model(self._mlp_input(torch.from_numpy(im_emb_arr)))
        return prediction.item()

    def score(self, img, prompt=None):
        r"""
        Predict the aesthetic score of an image, keeping the autograd graph.
        Args:
            img (str, PIL.Image or torch.Tensor): The image to score. Paths and
                PIL images are converted to RGB and then to a tensor via
                ``ToTensor``. A tensor is passed to the transform as-is
                (presumably (C, H, W) in [0, 1] like ``ToTensor`` output --
                confirm against ``DifferentiableTransform``).
            prompt (str): The prompt NOT used in this model, simply placeholder.
        Returns:
            torch.Tensor: The predicted aesthetic score. It is computed without
            ``torch.no_grad`` so callers can backpropagate from it; the shape is
            presumably (1, 1) for a single image.
        """
        if isinstance(img, (str, Image.Image)):
            # ToTensor keeps e.g. RGBA / L images as 4 / 1 channels; the CLIP
            # encoder presumably expects 3, so convert to RGB first.
            img = ToTensor()(self._as_rgb(img)).to(self.device)

        img_transformed = self.differentiable_transform.forward(img).unsqueeze(0).to(self.device)
        image_features = self.model2.encode_image(img_transformed)
        im_emb_arr = torch_normalized(image_features)
        return self.model(self._mlp_input(im_emb_arr))