# This file defines a unified interface for constructing reward models.
import os
import torch
import logging
from reward_models.aesthetic_score.reward_model import MODEL_PATH, AestheticClassifier
from reward_models.hps_v2_score.hps_score import HPSV2Score
from reward_models.imagereward_score.imagereward_score import load
from reward_models.pickscore_score.pickscore_score import PickScore
from typing import Literal

# Experimental reward models
from hack_reward_train.hacked_clip_model import HackedGreyReward

# factory functions for constructing reward models
def get_aestehtic_score_model(model_path=MODEL_PATH, device='cuda'):
    """Build the aesthetic-score reward model.

    Args:
        model_path: Checkpoint path handed to ``AestheticClassifier``;
            defaults to ``MODEL_PATH`` exported by the aesthetic_score package.
        device: Device string forwarded to the classifier.

    Returns:
        An ``AestheticClassifier`` instance.
    """
    return AestheticClassifier(model_path, device=device)


# Correctly spelled alias. The misspelled name above is kept so existing
# callers keep working.
get_aesthetic_score_model = get_aestehtic_score_model

def get_hpsv2_model(device='cuda'):
    """Construct an HPS v2 scorer, forwarding ``device`` to ``HPSV2Score``."""
    scorer = HPSV2Score(device=device)
    return scorer

def get_imagereward_model(*args):
    """Load the "ImageReward-v1.0" model via the project ``load`` helper.

    Any positional arguments are accepted but ignored. In particular, no
    device is passed to ``load`` here.
    """
    checkpoint_name = "ImageReward-v1.0"
    return load(checkpoint_name)

def get_pickscore_model(device='cuda'):
    """Construct a PickScore scorer, forwarding ``device`` to ``PickScore``."""
    scorer = PickScore(device=device)
    return scorer

def get_hacked_grey_reward_model(
    hacked_model_path="custom_reward_train/model/hack_reward_model.pth",
    clip_model_path="openai/clip-vit-base-patch32",
    device='cuda',
):
    """Construct the experimental ``HackedGreyReward`` model.

    Args:
        hacked_model_path: Path to the hacked reward weights, passed through
            as the first positional argument.
        clip_model_path: CLIP model identifier, passed through as the second
            positional argument.
        device: Device string forwarded as a keyword argument.
    """
    # NOTE(review): the default path's directory ("custom_reward_train") differs
    # from the package this class is imported from ("hack_reward_train");
    # confirm the default is still valid.
    reward = HackedGreyReward(hacked_model_path, clip_model_path, device=device)
    return reward

class UnifiedReward:
    """
    Facade that builds one reward model selected by name and forwards scoring to it.

    Supported names: 'aesthetic', 'hps_v2', 'imagereward', 'pickscore',
    'hacked_grey_reward'. Any other name raises NotImplementedError.
    """
    def __init__(
        self,
        model_name: Literal['aesthetic', 'hps_v2', 'imagereward', 'pickscore', 'hacked_grey_reward'],
        device: str = 'cuda',
    ):
        logging.info(f"Using unified reward model, model_name={model_name}")

        # Ordered (name, builder) pairs, checked with `model_name == name` in
        # sequence. Builders are lambdas so a factory is only looked up and
        # called when its name is selected.
        builders = (
            ("aesthetic", lambda: get_aestehtic_score_model(device=device)),
            ("hps_v2", lambda: get_hpsv2_model(device=device)),
            # NOTE: `device` is not forwarded for this model; the factory
            # ignores its arguments.
            ("imagereward", lambda: get_imagereward_model()),
            ("pickscore", lambda: get_pickscore_model(device=device)),
            ("hacked_grey_reward", lambda: get_hacked_grey_reward_model(device=device)),
        )
        for candidate, build in builders:
            if model_name == candidate:
                self.model = build()
                break
        else:
            raise NotImplementedError(f"Model {model_name} not implemented")

    def score(self, img, prompt):
        """
        Return ``self.model.score(img, prompt)``.

        The original note says the input is expected to be a tensor. The exact
        shape/dtype requirements depend on the selected model and are not
        visible here (TODO confirm per model).
        """
        scorer = self.model
        return scorer.score(img, prompt)
    