import torch
import json
import re
from openrlhf.models import get_llm_for_sequence_regression
from openrlhf.utils import get_processor, get_strategy, get_tokenizer


# Placeholder reward_func stub, kept for reference (returns random 0/1 rewards):
# def reward_func(queries, prompts, labels, **kwargs):
#     # queries is prompts + responses
#     # labels is answers
#     print(queries)
#     return torch.randint(0, 2, (len(queries),)).float()

# NOTE(review): not referenced anywhere in the visible part of this file.
# Presumably the maximum sequence length (in tokens) for reward-model
# inputs -- confirm before relying on it.
MAX_LEN = 2048

def compute_Tx(samples, file_path, eval=False):
    """Look up the stored proxy scores for the prompt of every sample.

    The prompt of a sample is the text between the first ``<|prompter|>``
    marker and the next ``<|endoftext|>`` marker. It is stripped and rebuilt
    as ``<|prompter|>{prompt}<|endoftext|><|assistant|>``, which is the key
    searched for in the ``"prompt"`` field of the score file. Samples without
    such a segment are skipped.

    Args:
        samples: Iterable of strings (prompt + response). May be empty.
        file_path: Path to a UTF-8 JSON-lines file. Every non-blank line must
            be a JSON object; its ``"prompt"`` and ``"proxy_score_0"`` fields
            are read.
        eval: Unused. Kept only so existing callers that pass it keep
            working (the name shadows the builtin but is part of the
            signature).

    Returns:
        A dict mapping each rebuilt prompt to the ``"proxy_score_0"`` value
        of its matching record (``[]`` when the record has no such field),
        or to ``None`` when no record matched. When no sample contains a
        prompt segment the result is an empty dict and the file is not
        opened.

    Raises:
        OSError: If ``file_path`` cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    prompt_pattern = re.compile(r"<\|prompter\|>(.*?)<\|endoftext\|>", re.DOTALL)

    def normalize_text(text):
        # Loose matching key: drop every space character and lowercase.
        # Newlines and tabs are left untouched.
        return text.replace(" ", "").lower()

    def find_proxy_score_by_prompts(file_path, target_prompts):
        results = dict.fromkeys(target_prompts)
        if not results:
            # Nothing to look up, so there is no reason to touch the file.
            return results

        # Built from the full list so that, when several targets share one
        # normalized form, the last one wins (same as a plain dict build).
        normalized_target_map = {normalize_text(p): p for p in target_prompts}
        # Prompts that still have no score; lets us stop reading early
        # without rescanning every target after each line.
        pending = set(results)

        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                data = json.loads(line)
                prompt = data.get("prompt", "")

                if prompt in results:
                    # Exact match first; it always overwrites, so it takes
                    # priority over an earlier normalized match.
                    key = prompt
                else:
                    # Fall back to the normalized match, which only fills a
                    # slot that is still empty.
                    key = normalized_target_map.get(normalize_text(prompt))
                    if key is None or results[key] is not None:
                        continue

                score = data.get("proxy_score_0", [])
                results[key] = score
                if score is None:
                    # An explicit JSON null does not count as found.
                    pending.add(key)
                else:
                    pending.discard(key)

                if not pending:
                    break
        return results

    original_prompts = []
    for sample in samples:
        match = prompt_pattern.search(sample)
        if match:
            prompt_text = match.group(1).strip()
            original_prompts.append(f"<|prompter|>{prompt_text}<|endoftext|><|assistant|>")

    return find_proxy_score_by_prompts(file_path, original_prompts)


class reward_score():
    """Container for a list of frozen reward models and their tokenizers.

    Attributes are filled by ``__init__``:
      * ``reward_models`` -- one loaded model per entry of ``model_names``.
      * ``tokenizers`` -- the tokenizer obtained for each model, same order.
    """

    def __init__(self, model_names=None, args=None, strategy=None):
        """Load every reward model in ``model_names`` for inference.

        Args:
            model_names: Iterable of model identifiers passed to
                ``get_llm_for_sequence_regression``. Defaults to
                ``['LxzGordon/URM-LLaMa-3-8B']`` when ``None``.
            args: Optional namespace. When given, ``args.pretrain`` (if set)
                is used as the tokenizer source and
                ``args.disable_fast_tokenizer`` (if set) disables the fast
                tokenizer; it is also used to build ``strategy`` when none
                is supplied.
            strategy: Optional pre-built strategy object. When neither
                ``strategy`` nor ``args`` is supplied, the models are used
                as loaded, without ``strategy.prepare``.
        """
        if model_names is None:
            model_names = ['LxzGordon/URM-LLaMa-3-8B', ]
        value_head_prefix = 'value_head'
        self.reward_models = []
        self.tokenizers = []

        # A strategy can only be built from args; without either, skip it.
        if strategy is None and args is not None:
            strategy = get_strategy(args)

        use_fast = not getattr(args, "disable_fast_tokenizer", False)

        for model_name in model_names:
            reward_model = get_llm_for_sequence_regression(
                model_name,
                "reward",
                normalize_reward=True,
                use_flash_attention_2=False,
                bf16=True,
                value_head_prefix=value_head_prefix,
            )
            # Inference only: no dropout, no gradients.
            reward_model.eval()
            reward_model.requires_grad_(False)
            reward_model.to(torch.cuda.current_device())

            # Fall back to the reward model's own name when args does not
            # name a tokenizer source.
            tokenizer_source = getattr(args, "pretrain", None) or model_name
            tokenizer = get_tokenizer(
                tokenizer_source, reward_model, "left", strategy, use_fast=use_fast
            )

            if strategy is not None:
                # Store the prepared model, not the raw one.
                reward_model = strategy.prepare(reward_model)
                reward_model.eval()

            self.reward_models.append(reward_model)
            self.tokenizers.append(tokenizer)

    def compute_func(self, queries):
        """Placeholder: scoring is not implemented yet; returns ``None``."""
        pass

def reward_func(queries, prompts, labels, **kwargs):
    """Return the stored proxy scores for the prompts found in ``queries``.

    Args:
        queries: Sequence of strings; each is prompt + response.
        prompts: Unused here.
        labels: Answers; unused here.
        **kwargs: ``file_path`` (required) is the JSON-lines score file
            handed to ``compute_Tx``; ``eval`` (optional, default ``False``)
            is forwarded unchanged.

    Returns:
        Whatever ``compute_Tx`` returns for ``queries`` and ``file_path``.

    Raises:
        ValueError: If no ``file_path`` keyword argument is supplied.
    """
    # NOTE(review): the caller's expected return type is not visible in this
    # file (the placeholder stub returned a float tensor with one entry per
    # query) -- confirm and convert the lookup result if required.
    file_path = kwargs.get("file_path")
    if not file_path:
        raise ValueError(
            "reward_func requires a 'file_path' keyword argument pointing to "
            "the proxy-score JSON-lines file"
        )

    average_proxy_rewards = compute_Tx(
        queries, file_path, eval=kwargs.get("eval", False)
    )
    return average_proxy_rewards