from transformers import AutoTokenizer, AutoConfig, PreTrainedModel, AutoModel, AutoModelForSequenceClassification
import torch.nn as nn
import torch
from tqdm import tqdm
import json
import os
import math
from concurrent.futures import ThreadPoolExecutor

class RewardModel():
    """Wrapper around a sequence-classification checkpoint used as a reward model.

    The checkpoint is loaded with a single-logit head (``num_labels=1``) in
    bfloat16 with FlashAttention-2, and the tokenizer is configured for left
    padding so that the last position of every row in a padded batch holds a
    real token rather than padding.
    """

    def __init__(self, path, device: str):
        """Load the model and tokenizer found at ``path``.

        Args:
            path: Checkpoint location handed to ``from_pretrained``.
            device: Value forwarded as ``device_map`` (e.g. ``'cuda:0'``).
        """
        self.path = path
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            self.path,
            num_labels=1,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            use_cache=False,
            device_map=device,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
        # get_reward() reads the score at the final position, so padding must
        # go on the left to keep the last real token at index -1.
        self.tokenizer.padding_side = "left"
        # Some tokenizers ship without a pad token, which makes any padded
        # call raise. The attention mask hides padding, so reusing EOS is safe.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, text_list, max_length=8192, padding="max_length"):
        """Score a batch of already chat-formatted texts.

        Args:
            text_list: A list of strings (a single string is treated as a
                batch of one).
            max_length: Token limit used for truncation and, with the default
                padding strategy, for padding.
            padding: Padding strategy forwarded to the tokenizer. The default
                pads every batch to ``max_length``; ``"longest"`` pads only to
                the longest text in the batch.

        Returns:
            A list with one float per input text: the score head's output at
            the last sequence position. An empty input yields ``[]``.
        """
        if isinstance(text_list, str):
            text_list = [text_list]
        if len(text_list) == 0:
            # The tokenizer cannot build tensors from an empty batch.
            return []

        # NOTE(review): truncation uses the tokenizer's configured truncation
        # side; if that cuts the end of an over-long text, the scored position
        # is no longer the end of the response -- confirm this is acceptable.
        inputs = self.tokenizer(
            text_list,
            return_tensors="pt",
            padding=padding,
            max_length=max_length,
            truncation=True,
        )

        target_device = self.reward_model.device
        inputs = {k: v.to(target_device) for k, v in inputs.items()}

        with torch.no_grad():
            # Run the backbone and the score head separately so the score can
            # be taken at the last position explicitly. This relies on the
            # loaded architecture exposing ``.model`` and ``.score``.
            hidden = self.reward_model.model(**inputs).last_hidden_state
            scores = self.reward_model.score(hidden)[:, -1]
        return scores.squeeze(-1).float().tolist()
    
def eval(data_path, if_train=False, *, model_path='/path/to/model',
         log_path='/path/to/log', devices=('cuda:0', 'cuda:1'), batch_size=8):
    """Score paired responses with the reward model and append a summary to a log.

    Two copies of the reward model are loaded, one per device; the first
    scores every first response and the second scores every second response,
    with both running concurrently for each batch.

    Note: this function shadows the built-in ``eval``; the name is kept so
    existing callers keep working.

    Args:
        data_path: JSON file holding a list of records. With
            ``if_train=False`` each record provides ``question``,
            ``response_A``, ``response_B`` and ``good_response``; with
            ``if_train=True`` it provides ``instruction``, ``chosen`` and
            ``rejected``.
        if_train: Selects which of the two record layouts is read.
        model_path: Checkpoint passed to ``RewardModel``.
        log_path: File the summary is appended to.
        devices: Exactly two device strings, one per model copy.
        batch_size: Number of pairs scored per forward pass.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``devices`` does
            not contain exactly two entries.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if len(devices) != 2:
        raise ValueError(f"devices must contain exactly two entries, got {devices!r}")

    data_name = os.path.splitext(os.path.basename(data_path))[0]
    # Read the dataset before loading the models so a bad path fails fast.
    with open(data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    rm1 = RewardModel(model_path, devices[0])
    rm2 = RewardModel(model_path, devices[1])
    tokenizer = rm1.tokenizer

    if if_train:
        question_key, first_key, second_key = 'instruction', 'chosen', 'rejected'
    else:
        question_key, first_key, second_key = 'question', 'response_A', 'response_B'

    batch_sentence1 = []
    batch_sentence2 = []
    for item in tqdm(datas, total=len(datas), desc=f'Eval {data_name}...'):
        question = item[question_key]
        response1 = item[first_key]
        response2 = item[second_key]
        for response, bucket in ((response1, batch_sentence1), (response2, batch_sentence2)):
            conversation = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": response},
            ]
            bucket.append(tokenizer.apply_chat_template(conversation, tokenize=False))

    total_num = 0
    current_num = 0
    total = math.ceil(len(batch_sentence1) / batch_size)
    # One executor for the whole run instead of a new thread pool per batch.
    with ThreadPoolExecutor(max_workers=2) as executor:
        for i in tqdm(range(0, len(batch_sentence1), batch_size), desc='batch process...', total=total):
            sents1_batch = batch_sentence1[i: i + batch_size]
            sents2_batch = batch_sentence2[i: i + batch_size]
            future_1 = executor.submit(rm1.get_reward, sents1_batch)
            future_2 = executor.submit(rm2.get_reward, sents2_batch)
            reward_1 = future_1.result()
            reward_2 = future_2.result()

            for item, r1, r2 in zip(datas[i: i + batch_size], reward_1, reward_2):
                # Ties are resolved in favour of the second response.
                choice = 'A' if r1 > r2 else 'B'
                # NOTE(review): in both modes the counter grows when the
                # higher-scored response is NOT the labelled-good / chosen
                # one, yet the log calls it "success". Kept unchanged because
                # this may be intentional (e.g. an attack success rate) --
                # confirm the intended metric.
                if if_train:
                    if choice == 'B':
                        current_num += 1
                elif item['good_response'] != choice:
                    current_num += 1
                total_num += 1

    # An empty dataset used to raise ZeroDivisionError here.
    rate = current_num / total_num if total_num else float('nan')
    output_text = (
        f"dataset: {data_name}\n"
        f"finished: success num: {current_num} / total: {total_num}\n"
        f"success rate = {rate:.4f}\n\n"
    )
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(output_text)
    
if __name__ == '__main__':
    # Dataset files to evaluate; each one is processed in turn and appends
    # its own summary to the log.
    data_list = ['/path/to/data']

    for data_path in data_list:
        eval(data_path)
