              
                                                      
                                          
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor

from tasks.math_rl_v3.math_rule_rm import (
    cal_accuracy_reward,
    cal_format_reward,
    validate_samples_useful,
)
from gpatch.core.utils import list_for_tensor_tolist
from gpatch.core.models.gpt import GptPpoCriticModel
from megatron.training import get_tokenizer
from megatron.core import mpu
from typing_extensions import override
from typing import List, Dict, Union, Any
import re
import string
import torch

from transformers import AutoTokenizer
import requests

import os
# Reward recipe selector, read once at import time. Code below tests it for
# substrings such as "reasoning_reward" and "GT".
ADVANTAGE_METHOD = os.environ.get("ADVANTAGE_METHOD", None)
# Raise explicitly rather than `assert`, so the check survives `python -O`.
if ADVANTAGE_METHOD is None:
    raise RuntimeError(
        "ADVANTAGE_METHOD environment variable must be set for ppo critic")
print(f"Using ADVANTAGE_METHOD {ADVANTAGE_METHOD} for ppo critic")


# Absolute path of this module. Not referenced by the visible code in this
# file; presumably kept for other importers / debugging.
current_file = str(Path(__file__).resolve())
# Base URLs of the sglang servers queried for coherence and relevance scores.
# The "TODO" placeholders stay the defaults; set COHERENT_SG_URL /
# RELEVANCE_SG_URL in the environment to point at real servers.
coherent_SG_URL = os.environ.get("COHERENT_SG_URL", "TODO")
relevance_SG_URL = os.environ.get("RELEVANCE_SG_URL", "TODO")

if "reasoning_reward" in ADVANTAGE_METHOD:
    # Tokenizer used to build the input_ids posted to coherent_SG_URL.
    coherent_MODEL_PATH = os.environ.get(
        "COHERENT_MODEL_PATH",
        "../../../models/models--nvidia--Llama-3.3-Nemotron-70B-Reward-Multilingual")
    coherent_tokenizer = AutoTokenizer.from_pretrained(coherent_MODEL_PATH)
    # Tokenizer used to build the input_ids posted to relevance_SG_URL.
    relevance_MODEL_PATH = os.environ.get(
        "RELEVANCE_MODEL_PATH",
        "../../../models/models--Qwen--Qwen3-Reranker-8B")
    relevance_tokenizer = AutoTokenizer.from_pretrained(
        relevance_MODEL_PATH, padding_side='left')
    # Id of the "yes" token; the probability the reranker assigns to it is
    # used as the relevance score.
    token_true_id = relevance_tokenizer.convert_tokens_to_ids("yes")
    # Fixed chat-template wrapper placed around every reranker prompt.
    prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    prefix_tokens = relevance_tokenizer.encode(
        prefix, add_special_tokens=False)
    suffix_tokens = relevance_tokenizer.encode(
        suffix, add_special_tokens=False)


def get_coherent_score_sglang(prompt_strs, response_strs):
    """
    调用 sglang server 计算每个 (prompt, response) 的 reward（logprob）

    参数:
        prompt_strs: list[str]  用户输入的 prompt
        response_strs: list[str]  模型生成的 response

    返回:
        list[float]  每个样本的 reward（logprob）
    """
    assert len(prompt_strs) == len(
        response_strs), "prompt_strs 和 response_strs 长度必须一致"

                   
    batch_texts = []
    for p, r in zip(prompt_strs, response_strs):
        msg = [
            {"role": "user", "content": p},
            {"role": "assistant", "content": r}
        ]
        batch_texts.append(msg)

                              
    batch_token_ids = coherent_tokenizer.apply_chat_template(
        batch_texts, tokenize=True, add_generation_prompt=False
    )

                         
    try:
        resp = requests.post(
            f"{coherent_SG_URL}/generate",
            json={
                "input_ids": batch_token_ids,
                "sampling_params": {
                    "temperature": 0.0,
                    "max_new_tokens": 1,
                },
                "return_logprob": True,
                "token_ids_logprob": [0]
            }
        )
        resp.raise_for_status()
        data = resp.json()

                      
        rewards = []
        for sample in data:
                            
            assert sample['meta_info']['output_token_ids_logprobs'][0][0][1] == 0
            rewards.append(sample['meta_info']
                           ['output_token_ids_logprobs'][0][0][0])

        return rewards

    except Exception as e:
        print(f"[get_coherent_score_sglang] 请求失败: {e}")
        return [None] * len(prompt_strs)


def get_relevance_score_sglang(prompt_strs, response_strs):
    """
    使用 sglang server 对 Query + Document 进行 yes/no 评分，返回 yes 概率列表

    参数:
        prompt_strs: list[str]  每个元素是 Query
        response_strs: list[str] 每个元素是 Document

    返回:
        list[float]  每个元素是 yes 概率
    """
    assert len(prompt_strs) == len(response_strs), "输入列表长度不一致"

          
    task_relevance_all = """
    You are given a Query and a Document. The Query contains a user question and related information, and the Document contains a reasoning process that leads to a final answer. Your task is to retrieve the Document with reasoning that:
    - is relevant to the information in the Query.
    - directly helps to answer the user question.
    - is logically coherent and leads to the final answer.
    """.strip()

    def format_instruction(instruction, query, doc):
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

             
    pairs = [
        format_instruction(task_relevance_all, q, d)
        for q, d in zip(prompt_strs, response_strs)
    ]

                   
    batch_token_ids = []
    for text in pairs:
        encoded = relevance_tokenizer.encode(text, add_special_tokens=False)
        input_ids = prefix_tokens + encoded + suffix_tokens
        batch_token_ids.append(input_ids)

                      
    try:
        resp = requests.post(
            f"{relevance_SG_URL}/generate",
            json={
                "input_ids": batch_token_ids,
                "sampling_params": {
                    "temperature": 0.0,
                    "max_new_tokens": 1
                },
                "return_logprob": True,
                                          
                "token_ids_logprob": [token_true_id]
            }
        )
        resp.raise_for_status()
        data = resp.json()
        scores = []
        for idx, sample in enumerate(data):
            logprob = sample['meta_info']['output_token_ids_logprobs'][0][0][0]
            prob_yes = pow(2.718281828, logprob)             
            scores.append(prob_yes)
        return scores
    except Exception as e:
        print(f"[get_relevance_score_sglang] 请求失败: {e}")
        return [None] * len(batch_token_ids)


_ANSWER_SPAN_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer_positions(msg, tokenizer):
    """Locate the token span of the last ``<answer>...</answer>`` body in *msg*.

    *msg* is re-tokenized with *tokenizer* (which must support
    ``return_offsets_mapping``) WITHOUT special tokens. The caller adds the
    prompt length to these indices to address the full sequence, so a
    prepended BOS would shift every index by one.

    Args:
        msg: decoded response text.
        tokenizer: a Hugging Face-style callable tokenizer.

    Returns:
        ``{"answer_start": s, "answer_end": e}`` where ``s`` is the first token
        ending after the answer body starts, and ``e`` is the first token
        starting at/after the body end (``len(tokens)`` if none). Both are
        ``-1`` when *msg* contains no ``<answer>...</answer>`` pair.
    """
    matches = list(_ANSWER_SPAN_RE.finditer(msg))
    if not matches:
        return {"answer_start": -1, "answer_end": -1}

    # Character span of the last answer's body (tags excluded).
    last_match = matches[-1]
    start_char = last_match.start(1)
    end_char = last_match.end(1)

    enc = tokenizer(msg, return_offsets_mapping=True, add_special_tokens=False)
    offsets = enc["offset_mapping"]

    # First token overlapping or following the body start; 0 as a fallback.
    answer_start_token = next(
        (i for i, (_, end) in enumerate(offsets) if end > start_char), 0)
    # First token lying entirely at/after the body end (exclusive bound);
    # len(offsets) when the body runs to the end of the text.
    answer_end_token = next(
        (i for i, (start, _) in enumerate(offsets) if start >= end_char),
        len(offsets))

    return {
        "answer_start": answer_start_token,
        "answer_end": answer_end_token,
    }


def extract_tag(string, tag, check=True):
    """Return the stripped text enclosed by the last ``<tag>`` in *string*.

    The text after the last opening tag is cut at the first closing tag that
    follows it, or kept to the end if no closing tag follows. Returns ``""``
    when the opening tag is absent.

    With *check* true, asserts that the closing tag is present and that
    neither tag occurs more than once.
    """
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    if check:
        assert close_tag in string, f"{close_tag} not found"
        assert string.count(open_tag) <= 1 and string.count(
            close_tag) <= 1, f"{tag} should only appear once in {string}"
    if open_tag not in string:
        return ""
    # rpartition keeps what follows the LAST opening tag ...
    _, _, after_open = string.rpartition(open_tag)
    # ... and partition keeps what precedes the FIRST closing tag after it.
    inner, _, _ = after_open.partition(close_tag)
    return inner.strip()


# Deletion table for ASCII punctuation, built once instead of rebuilding a
# set of string.punctuation for every character.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace.

    Steps, in order: lowercase; delete ``string.punctuation`` characters;
    replace the standalone words "a", "an", "the" with a space; collapse
    whitespace runs to single spaces and trim the ends.
    """
    lowered = text.lower()
    no_punc = lowered.translate(_PUNCTUATION_TABLE)
    no_articles = re.sub(r'\b(a|an|the)\b', ' ', no_punc)
    return ' '.join(no_articles.split())


def compute_em(prediction: str, gold: str):
    """Return True when *prediction* equals *gold* after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_gold = normalize_answer(gold)
    return normalized_prediction == normalized_gold


class RuleGptPpoCriticModel(GptPpoCriticModel):
    """PPO critic that augments reward-model scores with rule-based rewards.

    On the model-parallel head rank, responses are decoded and compared with
    each sample's ``gt_label`` by exact match. When ``ADVANTAGE_METHOD``
    contains ``"reasoning_reward"``, coherence and relevance scores are also
    fetched from the sglang servers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def infer_rule_based_rm(
        self,
        rewards,
        per_token_rewards=None,
        sequence_lengths: torch.Tensor = None,
        prompt_lengths: torch.Tensor = None,
        batches: List[Dict[str, Union[int, List[Any]]]] = None,
    ):
        """Compute the rewards for a rollout batch.

        Only the model-parallel head (first pipeline stage, tensor-parallel
        rank 0) does any work. Other ranks must pass ``rewards=None`` and
        ``per_token_rewards=None``, and get them back with ``None`` extras.

        Args:
            rewards: reward-model scores on the head rank (``None`` elsewhere);
                presumably shape [b, 1] -- TODO confirm.
            per_token_rewards: optional per-token rewards, passed through.
            sequence_lengths: per-sample valid length of ``batch["tokens"]``
                (prompt + response).
            prompt_lengths: per-sample prompt length in tokens.
            batches: micro-batches, each providing ``"tokens"`` and
                ``"gt_label"`` lists.

        Returns:
            ``(reward, per_token_rewards, extras)``. The content depends on
            ``self.config.ppo_grpo_reward_type``:

            * ``"rm_only"``: RM rewards unchanged; extras ``{'rm_rewards'}``.
            * ``"rule_only"``: exact-match reward of shape [b, 1], ``None``
              per-token rewards, and [b, 1] auxiliary tensors. These are
              constant confidence/critic placeholders, relevance/coherence
              scores, and the answer token span in full-sequence indices.
            * ``"rm_with_rule"``: ``sigmoid(alpha * rm) + (beta * rule - 1)``.
            * anything else: RM rewards unchanged.
        """
        tokenizer = get_tokenizer()

        # Never populated in this implementation; returned as None.
        acc_reward_tensor = None
        fmt_reward_tensor = None
        is_mp_head = (mpu.is_pipeline_first_stage()
                      and mpu.get_tensor_model_parallel_rank() == 0)
        if not is_mp_head:
            assert rewards is None
            assert per_token_rewards is None
        else:
            if self.config.ppo_grpo_reward_type == "rm_only":
                return rewards, per_token_rewards, {
                    'rm_rewards': rewards,
                }

            # Flatten micro-batches into per-sample lists.
            inputs_list: List[torch.Tensor] = []
            gt_label: List[Any] = []
            for batch in batches:
                inputs_list.extend(batch["tokens"])
                gt_label.extend(batch["gt_label"])

            tokens_cpu: List[List[int]] = list_for_tensor_tolist(
                inputs_list, False)
            seq_len_cpu: List[int] = sequence_lengths.tolist()
            prompt_length_list: List[int] = prompt_lengths.tolist()
            assert len(tokens_cpu) == len(seq_len_cpu)
            assert len(tokens_cpu) == len(gt_label)
            assert len(tokens_cpu) == len(prompt_length_list)

            # Drop everything past each sample's valid length.
            tokens_cpu = [toks[:length]
                          for toks, length in zip(tokens_cpu, seq_len_cpu)]

            prompt_strs = tokenizer._tokenizer.batch_decode(
                [toks[:plen] for toks, plen in zip(tokens_cpu, prompt_length_list)],
                skip_special_tokens=True
            )
            response_strs = tokenizer._tokenizer.batch_decode(
                [toks[plen:] for toks, plen in zip(tokens_cpu, prompt_length_list)],
                skip_special_tokens=True
            )

            if "reasoning_reward" in ADVANTAGE_METHOD:
                if "GT" in ADVANTAGE_METHOD:
                    # Give the judge models the reference answer as context.
                    prompt_strs = [
                        prompt+f"\n\nWe also provide you with the correct answer to this question.\n[Correct Answer]: {gt}"
                        for prompt, gt in zip(prompt_strs, gt_label)
                    ]
                # The two scoring servers are independent: query concurrently.
                with ThreadPoolExecutor(max_workers=2) as executor:
                    future_coherent = executor.submit(
                        get_coherent_score_sglang, prompt_strs, response_strs)
                    future_relevance = executor.submit(
                        get_relevance_score_sglang, prompt_strs, response_strs)
                    coherent_score = future_coherent.result()
                    relevance_score = future_relevance.result()
                # The scorers log and return [None] * n on failure; fail with
                # a clear message instead of an opaque TypeError in torch.tensor.
                if (any(s is None for s in coherent_score)
                        or any(s is None for s in relevance_score)):
                    raise RuntimeError(
                        "reasoning_reward scoring failed; see the logged sglang request error")
            else:
                # Constant placeholders when reasoning rewards are disabled.
                coherent_score = [0.2] * len(gt_label)
                relevance_score = [0.3] * len(gt_label)

            answer_start_list = []
            answer_end_list = []
            for response_str, prompt_len in zip(response_strs, prompt_length_list):
                extract_result = extract_answer_positions(
                    response_str, tokenizer._tokenizer)
                # Shift response-relative token indices to full-sequence
                # indices. NOTE(review): a missing answer (-1) becomes
                # prompt_len - 1 here rather than staying a -1 sentinel.
                answer_start_list.append(
                    extract_result["answer_start"] + prompt_len)
                answer_end_list.append(
                    extract_result["answer_end"] + prompt_len)

            device = torch.cuda.current_device()
            num_samples = len(gt_label)
            test1_tensor = torch.ones(num_samples,
                                      dtype=torch.float32,
                                      device=device).view(-1, 1)*0.1
            relevance_reward = torch.tensor(relevance_score,
                                            dtype=torch.float32,
                                            device=device).view(-1, 1)
            coherent_reward = torch.tensor(coherent_score,
                                           dtype=torch.float32,
                                           device=device).view(-1, 1)
            test4_tensor = torch.ones(num_samples,
                                      dtype=torch.float32,
                                      device=device).view(-1, 1)*0.4
            answer_start_tensor = torch.tensor(
                answer_start_list, dtype=torch.int32, device=device).view(-1, 1)
            answer_end_tensor = torch.tensor(
                answer_end_list, dtype=torch.int32, device=device).view(-1, 1)

            # 1.0 when the last <answer> body exactly matches gt, else 0.0.
            rule_reward = torch.tensor(
                [
                    1.0 if compute_em(
                        extract_tag(response_str, "answer", check=False), gt)
                    else 0.0
                    for response_str, gt in zip(response_strs, gt_label)
                ],
                dtype=torch.float32, device=device).view(-1, 1)

            if self.config.ppo_grpo_reward_type == "rule_only":
                return rule_reward, None, {
                    'confidence_reward': test1_tensor,
                    'relevance_reward': relevance_reward,
                    'coherent_reward': coherent_reward,
                    'critic_reward': test4_tensor,
                    'answer_start': answer_start_tensor,
                    'answer_end': answer_end_tensor,
                }
            elif self.config.ppo_grpo_reward_type == "rm_with_rule":
                combined_rewards = torch.sigmoid(self.config.ppo_rm_reward_alpha * rewards) + \
                    (self.config.ppo_rule_reward_beta * rule_reward - 1)
                return combined_rewards, per_token_rewards, {
                    'rm_rewards': rewards,
                    'acc_rewards': acc_reward_tensor,
                    'fmt_rewards': fmt_reward_tensor,
                }

        return rewards, per_token_rewards, {
            'rm_rewards': rewards,
            'acc_rewards': acc_reward_tensor,
            'fmt_rewards': fmt_reward_tensor,
        }

    @override
    def validate_samples(self, rewards, sampling_repeat=None):
        """Check the validity of the given samples.

        Parameters:
        rewards: a tensor of rewards, shape [b, 1]; must be None on ranks
            other than the model-parallel head.
        sampling_repeat: forwarded unchanged to ``validate_samples_useful``.

        Returns:
        dict | None: on the model-parallel head, a dict regarding the
            usefulness of samples, e.g.
            {'sample_useful': tensor of usefulness ([b])};
            None on every other rank.
        """
        check_result = None
        is_mp_head = (mpu.is_pipeline_first_stage()
                      and mpu.get_tensor_model_parallel_rank() == 0)
        if not is_mp_head:
            assert rewards is None
        else:
            check_result = {
                "sample_useful": validate_samples_useful(rewards, sampling_repeat)
            }
        return check_result
