import torch
from transformers import AutoModel, AutoTokenizer

# TODO: add an abstract base class that defines what every task reward function should take as input, etc.

class RMTaskUsefulness:
    """Scores (query, answer) pairs with a pretrained reward model.

    The model and tokenizer are loaded from the Hugging Face hub with
    ``trust_remote_code=True``; the model object is expected to expose a
    ``get_scores(tokenizer, chats)`` method, which is what the scoring
    helper calls.
    """

    def __init__(
        self,
        model_name: str = "internlm/internlm2-1_8b-reward",
        device: str = 'cuda',
    ) -> None:
        """Load the reward model and its tokenizer.

        Args:
            model_name: Hub identifier (or path) passed to ``from_pretrained``
                for both the model and the tokenizer.
            device: Value forwarded as ``device_map`` when loading the model.
        """
        # Prefer bfloat16 when torch reports support for it, else float16.
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

        self.model = AutoModel.from_pretrained(
            model_name,
            device_map=device,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    def get_reward(self, user_queries, answers, max_batch_size: int):
        """Return one reward score per (query, answer) pair.

        Currently disabled on purpose: the queries handed to this method do
        not include the task prompt, so this method always raises rather
        than score them.

        Args:
            user_queries: Sequence of user query strings.
            answers: Sequence of answer strings, aligned with ``user_queries``.
            max_batch_size: Maximum number of pairs scored per model call.

        Raises:
            NotImplementedError: Always, until queries carry the task prompt.
        """
        # TODO: user query should just have the entire query (inc task prompt)
        # Once that holds, replace this guard with:
        #     return self._score_in_batches(user_queries, answers, max_batch_size)
        raise NotImplementedError("User queries currently don't contain task prompt")

    def _score_in_batches(self, user_queries, answers, max_batch_size: int):
        """Score (query, answer) pairs in chunks of at most ``max_batch_size``.

        Args:
            user_queries: Sequence of user query strings (must support
                ``len`` and slicing).
            answers: Sequence of answer strings, same length as
                ``user_queries``.
            max_batch_size: Positive maximum number of pairs per model call.

        Returns:
            A flat list with the scores of every batch, in input order.
            Presumably one score per pair -- this depends on what the
            model's ``get_scores`` returns; confirm against the model code.

        Raises:
            ValueError: If ``max_batch_size`` is not positive or the two
                inputs differ in length.
        """
        if max_batch_size <= 0:
            raise ValueError(
                f"max_batch_size must be a positive integer, got {max_batch_size}"
            )
        n_items = len(user_queries)
        if len(answers) != n_items:
            # zip() would silently drop the unmatched tail otherwise.
            raise ValueError(
                f"user_queries and answers must have the same length, "
                f"got {n_items} and {len(answers)}"
            )

        all_scores = []
        for start_idx in range(0, n_items, max_batch_size):
            # Slicing clamps at the end of the sequence, so the last batch
            # may be shorter than max_batch_size.
            end_idx = start_idx + max_batch_size
            batch_queries = user_queries[start_idx:end_idx]
            batch_answers = answers[start_idx:end_idx]

            # Each chat is a single user turn followed by the assistant reply.
            batch_chats = [
                [
                    {"role": "user", "content": query},
                    {"role": "assistant", "content": answer},
                ]
                for query, answer in zip(batch_queries, batch_answers)
            ]

            # Inference only: no gradients are needed for scoring.
            with torch.no_grad():
                batch_scores = self.model.get_scores(self.tokenizer, batch_chats)

            all_scores.extend(batch_scores)

        return all_scores
