import json
import math
import os
from concurrent.futures import ThreadPoolExecutor

import torch
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer


class RewardModel:
    """Scalar reward model wrapper around a Hugging Face sequence-classification checkpoint.

    Loads the model once onto a single device and scores batches of
    already-formatted (untokenized) text.
    """

    def __init__(self, path, device):
        """Load the reward model and its tokenizer.

        Args:
            path: Local path or hub id loadable with
                ``AutoModelForSequenceClassification`` / ``AutoTokenizer``.
            device: Value passed as ``device_map`` (e.g. ``'cuda:1'``).
        """
        self.path = path
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(
            self.path,
            num_labels=1,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            use_cache=False,
            device_map=device,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.path)
        # Left padding keeps the last real token of every row at index -1,
        # which get_reward relies on when it reads the score at [:, -1].
        self.tokenizer.padding_side = "left"
        # Some tokenizers ship without a pad token, which makes batched
        # padding fail; reuse EOS (padded positions are masked out anyway).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, text_list, max_length=8192):
        """Return one scalar reward per input text.

        Args:
            text_list: Sequence of strings, e.g. produced by
                ``tokenizer.apply_chat_template(..., tokenize=False)``.
            max_length: Per-text token budget; longer texts are truncated.

        Returns:
            A list of Python floats in the same order as ``text_list``
            (empty list for empty input).
        """
        if not text_list:
            return []
        # Pad only to the longest text in the batch. The original always
        # padded to max_length (8192 tokens per row), which wastes memory and
        # compute; padded positions are masked via attention_mask, so this
        # should not change the scores beyond numerical noise.
        # NOTE(review): chat-templated text often already contains BOS, and
        # tokenizing it again with special tokens may duplicate it — confirm
        # against this model's chat template.
        # NOTE(review): truncation follows tokenizer.truncation_side (right by
        # default), which cuts the end of overlong responses — confirm intent.
        inputs = self.tokenizer(
            text_list,
            return_tensors="pt",
            padding="longest",
            max_length=max_length,
            truncation=True,
        )
        inputs = {name: tensor.to(self.reward_model.device) for name, tensor in inputs.items()}

        with torch.no_grad():
            hidden = self.reward_model.model(**inputs).last_hidden_state
            # With left padding, position -1 is the final real token of each row.
            reward = self.reward_model.score(hidden)[:, -1]
        return reward.squeeze(-1).float().tolist()

def _format_pair(tokenizer, query, answer):
    """Render one user/assistant exchange as chat-template text (not tokenized)."""
    conversation = [
        {"role": "user", "content": query},
        {"role": "assistant", "content": answer},
    ]
    return tokenizer.apply_chat_template(conversation, tokenize=False)


def _hack_entry(pref, r0, r1, r2):
    """Build the per-item score dict from the original reward r0 and candidate rewards r1, r2.

    pref == 'chosen'   -> rejected_* keys, margins r_i - r0.
    otherwise ('rejected') -> chosen_* keys, margins r0 - r_i.
    The larger margin wins; ties go to candidate index 1.
    """
    if pref == 'chosen':
        margin_1, margin_2 = r1 - r0, r2 - r0
        prefix = 'rejected'
    else:
        margin_1, margin_2 = r0 - r1, r0 - r2
        prefix = 'chosen'
    if margin_1 > margin_2:
        return {f"{prefix}_win": 0, f"{prefix}_hack_score": margin_1}
    return {f"{prefix}_win": 1, f"{prefix}_hack_score": margin_2}


def filter_adv_sample(data_path, pref: str, model_path='path/to/model',
                      devices=('cuda:1', 'cuda:2', 'cuda:3'), batch_size=16):
    """Score two adversarial responses per item against the item's original answer.

    For each item in the JSON list at ``data_path``, the original answer
    ``item[pref]`` and the candidates ``item['response'][0]`` and
    ``item['response'][1]`` are scored. Three reward-model replicas, one per
    device, score the three lists concurrently. One dict per item is written
    (see ``_hack_entry`` for the keys) to ``data_path`` with ".json"
    replaced by "_only_with_adv_chosen_score.json".

    Args:
        data_path: JSON file with a list of dicts carrying 'question',
            ``pref`` and a 'response' list of at least two strings.
        pref: Which original answer to compare against: 'chosen' or 'rejected'.
        model_path: Reward-model checkpoint loaded on each device.
        devices: Exactly three devices, used for the original/candidate-1/
            candidate-2 replicas respectively.
        batch_size: Number of items scored per step.

    Raises:
        ValueError: if ``pref`` is invalid, ``devices`` does not have three
            entries, or ``data_path`` contains no ".json" (the score file
            would otherwise overwrite the input).
    """
    if pref not in ('chosen', 'rejected'):
        raise ValueError(f"pref must be 'chosen' or 'rejected', got {pref!r}")
    if len(devices) != 3:
        raise ValueError(f"expected exactly 3 devices, got {len(devices)}")
    # NOTE: the suffix says "chosen" regardless of pref; kept because callers
    # derive the same path.
    save_path = data_path.replace(".json", "_only_with_adv_chosen_score.json")
    if save_path == data_path:
        raise ValueError(
            f"data_path must contain '.json' so the score file does not overwrite it: {data_path!r}"
        )

    with open(data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    rm1, rm2, rm3 = (RewardModel(model_path, device) for device in devices)
    tokenizer = rm1.tokenizer

    orig_texts, texts_1, texts_2 = [], [], []
    for item in datas:
        query = item['question']
        orig_texts.append(_format_pair(tokenizer, query, item[pref]))
        texts_1.append(_format_pair(tokenizer, query, item['response'][0]))
        texts_2.append(_format_pair(tokenizer, query, item['response'][1]))

    score_list = []
    total = math.ceil(len(orig_texts) / batch_size)
    # One pool for the whole run; each replica scores one of the three lists
    # so the three devices work concurrently on every batch.
    with ThreadPoolExecutor(max_workers=3) as executor:
        for start in tqdm(range(0, len(orig_texts), batch_size), desc='batch process...', total=total):
            end = start + batch_size
            future_orig = executor.submit(rm1.get_reward, orig_texts[start:end])
            future_1 = executor.submit(rm2.get_reward, texts_1[start:end])
            future_2 = executor.submit(rm3.get_reward, texts_2[start:end])
            for r0, r1, r2 in zip(future_orig.result(), future_1.result(), future_2.result()):
                cur = _hack_entry(pref, r0, r1, r2)
                print(cur)
                score_list.append(cur)

    # Written once after scoring (previously rewritten after every batch).
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(score_list, f, indent=2, ensure_ascii=False)
    print('save success！✅')

def combin_data(data_path, score_path, pref):
    """Merge per-item hack scores into the dataset file, rewriting it in place.

    ``score_path`` must hold a JSON list aligned 1:1 with the JSON list in
    ``data_path``. For each pair, ``<pref>_win`` and ``<pref>_hack_score``
    are copied from the score entry onto the data item.

    Args:
        data_path: JSON file holding a list of dict items; overwritten.
        score_path: JSON file holding a list of score dicts.
        pref: Key prefix to copy, e.g. 'chosen' or 'rejected'.

    Raises:
        ValueError: if the two lists differ in length (data file untouched).
        KeyError: if a score entry lacks one of the expected keys.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    with open(score_path, 'r', encoding='utf-8') as f:
        score_list = json.load(f)

    # Explicit check rather than assert, which is stripped under `python -O`.
    if len(score_list) != len(datas):
        raise ValueError(
            f"score list length ({len(score_list)}) does not match "
            f"data length ({len(datas)})"
        )

    win_key = f'{pref}_win'
    score_key = f'{pref}_hack_score'
    for item, score_info in zip(datas, score_list):
        item[win_key] = score_info[win_key]
        item[score_key] = score_info[score_key]

    # Write to a sibling temp file and atomically swap it in, so an
    # interrupted write cannot leave the original dataset truncated.
    tmp_path = f'{data_path}.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(datas, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, data_path)

    print(f'{pref}_hack_score和{pref}_win combine finished✅')

if __name__ == '__main__':
    # Pipeline: score the adversarial responses against each item's
    # "rejected" answer (which yields chosen_* fields), then merge those
    # fields back into the dataset file.
    input_path = 'path/to/data'
    scores_file = input_path.replace(".json", "_only_with_adv_chosen_score.json")
    filter_adv_sample(input_path, 'rejected')
    combin_data(input_path, scores_file, 'chosen')

