from transformers import AutoTokenizer, AutoConfig, PreTrainedModel, AutoModel, AutoModelForSequenceClassification
import torch.nn as nn
import torch
from tqdm import tqdm
import json
import os
import math
from concurrent.futures import ThreadPoolExecutor

class RewardModel:
    """Scalar reward model that scores already chat-formatted texts on one device.

    The checkpoint is loaded with ``AutoModelForSequenceClassification``
    (``num_labels=1``, bfloat16, FlashAttention-2). The reward of a text is the
    classification ``score`` head applied to the hidden state of the text's
    final token.
    """

    def __init__(self, path, device: str, max_length: int = 8192):
        """Load the reward model and its tokenizer from ``path`` onto ``device``.

        Args:
            path: Checkpoint directory or hub id, used for both model and tokenizer.
            device: Passed as ``device_map`` to ``from_pretrained`` (e.g. ``'cuda:0'``).
            max_length: Maximum number of tokens per text; longer texts are truncated.
        """
        self.path = path
        self.max_length = max_length
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            self.path,
            num_labels=1,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            use_cache=False,
            device_map=device,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
        # Left padding puts every sequence's real last token at position -1,
        # which is where get_reward reads the score from.
        self.tokenizer.padding_side = "left"
        # Batched padding raises without a pad token; padded positions are
        # masked out by the attention mask, so reusing EOS is harmless.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, text_list):
        """Return one float reward per string in ``text_list`` (same order).

        The strings are expected to be rendered with the tokenizer's chat
        template already. That template inserts its own special tokens, so the
        tokenizer is told not to add them again; otherwise a second BOS would be
        prepended.
        """
        inputs = self.tokenizer(
            text_list,
            return_tensors="pt",
            # Pad only to the longest text in the batch; padding every batch to
            # max_length wastes compute on pure padding.
            padding=True,
            max_length=self.max_length,
            # NOTE(review): the default truncation side is "right", so a text
            # longer than max_length loses its end and the score below is read
            # from a mid-response token. Confirm this is acceptable for the datasets.
            truncation=True,
            add_special_tokens=False,
        )
        inputs = {k: v.to(self.reward_model.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Run the backbone and score head directly (architectures exposing
            # `.model` / `.score`, e.g. Llama/Qwen) and take the last position,
            # which is always a real token because of left padding.
            hidden = self.reward_model.model(**inputs).last_hidden_state
            reward = self.reward_model.score(hidden)[:, -1]
        return reward.squeeze(-1).float().tolist()

# Single-slot cache of loaded reward models keyed by (model_path, devices).
# __main__ calls eval() once per dataset with the same model; without this
# cache the weights would be reloaded from disk for every dataset.
_REWARD_MODEL_CACHE = {}


def _get_reward_models(model_path, devices):
    """Return a list with one ``RewardModel`` per entry of ``devices``, reusing the last load.

    The cached models are dropped *before* a different model is loaded, so two
    different checkpoints are never kept alive at the same time.
    """
    key = (model_path, tuple(devices))
    if key not in _REWARD_MODEL_CACHE:
        _REWARD_MODEL_CACHE.clear()
        _REWARD_MODEL_CACHE[key] = [RewardModel(model_path, device) for device in devices]
    return _REWARD_MODEL_CACHE[key]


def _render_pairs(tokenizer, datas, data_name):
    """Render each item's (question, response_A) and (question, response_B) chats to text.

    Returns two lists aligned index-for-index with ``datas``: A-side texts and B-side texts.
    """
    texts_a, texts_b = [], []
    for item in tqdm(datas, total=len(datas), desc=f'Eval {data_name}...'):
        question = item['question']
        for response, texts in ((item['response_A'], texts_a), (item['response_B'], texts_b)):
            conversation = [
                {"role": "user", "content": question},
                {"role": "assistant", "content": response},
            ]
            texts.append(tokenizer.apply_chat_template(conversation, tokenize=False))
    return texts_a, texts_b


def eval(model_path, data_path, log_path, devices=("cuda:0", "cuda:1"), batch_size=8):
    """Compute pairwise accuracy of a reward model on one dataset and append it to a log file.

    ``data_path`` must hold a JSON list whose items have ``question``,
    ``response_A``, ``response_B`` and ``good_response`` keys. Both responses
    are scored. The side with the strictly higher reward is the prediction,
    which is compared with ``good_response``. A tie predicts neither side and
    counts as wrong, so ties cannot favour one position.

    Note: this function shadows the built-in ``eval``; the name is kept for
    existing callers.

    Args:
        model_path: Checkpoint path passed to ``RewardModel``.
        data_path: JSON dataset path; its file stem is used as the dataset name.
        log_path: Text file the summary is appended to; its directory is created if missing.
        devices: Exactly two device strings. One model copy is loaded on each,
            and the A and B sides of each batch are scored concurrently.
        batch_size: Number of pairs scored per step.

    Raises:
        ValueError: if ``devices`` does not contain exactly two entries.
    """
    if len(devices) != 2:
        raise ValueError(f"devices must contain exactly two entries, got {devices!r}")
    rm_a, rm_b = _get_reward_models(model_path, devices)

    data_name = os.path.splitext(os.path.basename(data_path))[0]
    with open(data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    texts_a, texts_b = _render_pairs(rm_a.tokenizer, datas, data_name)

    current_num = 0
    total_num = 0
    total_batches = math.ceil(len(datas) / batch_size)
    # One executor for the whole run instead of one per batch.
    with ThreadPoolExecutor(max_workers=2) as executor:
        for start in tqdm(range(0, len(datas), batch_size), desc='batch process...', total=total_batches):
            stop = start + batch_size
            future_a = executor.submit(rm_a.get_reward, texts_a[start:stop])
            future_b = executor.submit(rm_b.get_reward, texts_b[start:stop])
            rewards_a = future_a.result()
            rewards_b = future_b.result()

            for item, reward_a, reward_b in zip(datas[start:stop], rewards_a, rewards_b):
                # Strictly higher reward wins; a tie (or NaN) predicts neither
                # side instead of silently defaulting to 'B'.
                if reward_a > reward_b:
                    answer = 'A'
                elif reward_b > reward_a:
                    answer = 'B'
                else:
                    answer = None
                if item['good_response'] == answer:
                    current_num += 1
                total_num += 1

    # An empty dataset reports "nan" instead of raising ZeroDivisionError.
    acc = current_num / total_num if total_num else float('nan')
    output_text = (
        f"dataset: {data_name}\n"
        f"finished：current {current_num} / total {total_num}\n"
        f"acc = {acc:.4f}\n\n"
    )
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(output_text)

if __name__ == '__main__':
    # One log file per model; "{model_name}" is replaced with the model directory's basename.
    log_path_template = './log/{model_name}_test.log'
    model_paths = [
        'path/to/model'
    ]

    # Dataset paths, each passed unchanged to eval() and opened as a JSON file.
    data_lists = [
        'JudgeBench',
        'JudgeBench_A^2_adv',
        'JudgeBench_style_adv',
        'RewardBench',
        'RewardBench_A^2_adv',
        'RewardBench_style_adv',
        'HelpSteer3Eval',
        'HelpSteer3Eval_A^2_adv',
        'HelpSteer3Eval_style_adv'
    ]

    # The log directory may not exist on a fresh checkout; appending would fail.
    os.makedirs(os.path.dirname(log_path_template), exist_ok=True)
    for model_path in model_paths:
        # normpath strips a trailing "/" so the basename is never empty.
        name = os.path.basename(os.path.normpath(model_path))
        log_path = log_path_template.replace("{model_name}", name)
        for data_path in data_lists:
            eval(model_path, data_path, log_path)
