import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torchvision
from torchvision import transforms
from typing import List
import numpy as np
import torch
from einops import rearrange

import ImageReward as reward
import PIL


class ImageRewardScorer:
    """
    ImageReward preference model loaded through the ``ImageReward`` package.
    Outputs a scalar reward per image given a text prompt.
    """

    def __init__(self, device="cuda"):
        """
        device: torch device (or device string) the model is placed on and
                on which the returned reward tensors are created.
        """
        self.device = device

        # Pass the device to the loader rather than relying on its default:
        # the loader builds the model for a specific device, and moving the
        # weights afterwards does not necessarily update whatever device the
        # model uses internally for its inputs.
        # NOTE(review): the loader is expected to return the model already in
        # eval mode -- confirm against the ImageReward documentation.
        self.model = reward.load("ImageReward-v1.0", device=self.device)
        self.model.to(self.device)

    @torch.no_grad()
    def score_images(self, images, text_prompt):
        """
        images: List[PIL.Image]
        text_prompt: prompt the images are scored against
        returns: Tensor [N] on ``self.device`` (empty tensor for no images)
        """
        if len(images) == 0:
            return torch.empty(0, device=self.device)

        rewards = self.model.score(text_prompt, images)
        # The model may return a bare float instead of a list (e.g. when only
        # one image is scored); flatten so the result is always 1-D, [N].
        return torch.tensor(rewards, device=self.device).reshape(-1)
    

class ImageNetScorer:
    """
    Classifier-based scorer built on torchvision's ResNet-50
    (``IMAGENET1K_V1`` weights).
    Outputs the raw (pre-softmax) logit of one class per image.
    """

    def __init__(self, device):
        """
        device: torch device (or device string) the classifier runs on.
        """
        self.device = device

        # Resize -> center-crop to 224 -> tensor -> per-channel normalization.
        steps = [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        self.preprocess = transforms.Compose(steps)

        classifier = torchvision.models.resnet50(weights="IMAGENET1K_V1")
        classifier.eval()
        self.model = classifier.to(device)

    @torch.no_grad()
    def score_images(self, images, class_id):
        """
        images: sequence of images accepted by ``self.preprocess``
        class_id: index into the classifier's output logits
        returns: Tensor [N], the ``class_id`` logit of every image
        """
        tensors = [self.preprocess(image) for image in images]
        batch = torch.stack(tensors).to(self.device)
        return self.model(batch)[:, class_id]



class CLIPScorer:
    """
    CLIP image-text scorer (``openai/clip-vit-base-patch32``).
    Outputs one image-text logit per image for a single text prompt.
    """

    def __init__(self, device="cuda"):
        """
        device: torch device (or device string) the model runs on.
        """
        self.device = device
        self.model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32"
        ).to(device)
        self.processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            use_fast=True
        )
        self.model.eval()

    @torch.no_grad()
    def score_images(self, images, text_prompt):
        """
        images: List[PIL.Image]
        text_prompt: single prompt every image is compared against
        returns: Tensor [N] (empty tensor for no images)
        """
        if len(images) == 0:
            return torch.empty(0, device=self.device)

        inputs = self.processor(
            text=[text_prompt],
            images=images,
            return_tensors="pt",
            padding=True,
            # CLIP's text encoder has a fixed context length; without
            # truncation an over-long prompt fails inside the model instead
            # of being clipped to the supported length.
            truncation=True,
        ).to(self.device)

        outputs = self.model(**inputs)
        # logits_per_image is (N, 1) for a single prompt; drop the text axis.
        return outputs.logits_per_image.squeeze(-1)



