import argparse
import copy
import json
import os
import traceback

from vllm import LLM, SamplingParams
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer

# "No upper bound" sentinel: default for --end and open-ended slice ends.
MAX_INT: int = sys.maxsize
# Marker string for an invalid/unparseable answer (not referenced by the code shown here).
INVALID_ANS: str = "[invalid]"

def remove_boxed(raw_response):
    """Return the contents of the last ``\\boxed{...}`` in *raw_response*.

    Uses ``last_boxed_only_string`` to locate the last ``\\boxed{...}`` (or,
    failing that, ``\\fbox{...}``) span, then strips the ``\\boxed{`` prefix
    and the closing ``}``.

    Returns:
        The text inside the braces, or ``None`` when no complete
        ``\\boxed{...}`` span is found. An ``\\fbox{...}`` match is not
        unwrapped and also yields ``None``.
    """
    s = last_boxed_only_string(raw_response)
    left = "\\boxed{"
    # Explicit checks rather than ``assert`` inside a bare ``except``: asserts
    # are stripped under ``python -O``, which would return a mangled slice.
    if s is None or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]

def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` (or ``\\fbox{...}``) span in *string*.

    The last ``\\boxed`` occurrence is preferred; the last ``\\fbox`` is used
    only when there is no ``\\boxed`` at all. Starting at that marker, braces
    are counted so nested groups such as ``\\boxed{\\frac{1}{2}}`` are
    returned whole.

    Returns:
        The substring from the marker through its matching closing brace, or
        ``None`` when no marker exists or the braces never balance.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            # Only a closing brace can bring the depth back to zero.
            if depth == 0:
                return string[start:pos + 1]
    return None

def prompt_format(messages, tokenizer):
    """Render *messages* into a generation-ready prompt string.

    Args:
        messages: Chat messages forwarded unchanged to the tokenizer's chat
            template (in this file, a list of ``{"content", "role"}`` dicts).
        tokenizer: Object exposing ``apply_chat_template`` (a Hugging Face
            tokenizer in this script).

    Returns:
        Whatever ``apply_chat_template`` returns when called with
        ``tokenize=False``, ``add_generation_prompt=True`` and
        ``enable_thinking=False``.
    """
    template_options = dict(
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    return tokenizer.apply_chat_template(messages, **template_options)

def batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive chunks of at most *batch_size* items.

    Order is preserved, so concatenating the chunks reproduces *data_list*.
    The final chunk holds whatever remains and may be shorter than
    *batch_size*. An empty input yields an empty list (no batches).

    Args:
        data_list: Sequence supporting ``len`` and slicing.
        batch_size: Maximum number of items per chunk; must be >= 1.

    Returns:
        A list of slices of *data_list*.

    Raises:
        ValueError: If *batch_size* is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    # Step by batch_size so every chunk (including the last) is bounded, and
    # no negative start index arises when the input is shorter than one batch.
    return [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]

def math_verify_reward_function(solution_str, ground_truth):
    """Score *solution_str* against one or more reference answers via math_verify.

    Args:
        solution_str: Model output to grade.
        ground_truth: One reference answer string, or an iterable of
            acceptable reference answers.

    Returns:
        ``1.0`` if the solution matches any reference, otherwise ``0.0``.
        A match is either direct membership of ``parse(...)[1]`` in the
        references, or ``verify`` succeeding against ``\\boxed{reference}``.
        A failure to parse the solution scores ``0.0``.
    """
    from math_verify import parse, verify

    references = [ground_truth] if isinstance(ground_truth, str) else ground_truth

    try:
        parsed_solution = parse(solution_str, parsing_timeout=5)
    except Exception:
        return 0.0

    # Fewer than two parsed entries is treated as "nothing usable extracted".
    # NOTE(review): element [1] is presumably math_verify's string form of the
    # answer; confirm against the math_verify docs.
    if len(parsed_solution) < 2:
        return 0.0
    if parsed_solution[1] in references:
        return 1.0

    def _matches(reference):
        # Any parse/verify failure for one reference just means "no match".
        try:
            gold = parse(f"\\boxed{{{reference}}}", parsing_timeout=5)
            return bool(verify(gold, parsed_solution, timeout_seconds=5))
        except Exception:
            return False

    return 1.0 if any(_matches(ref) for ref in references) else 0.0

def compute_score_mathverify_judge(solution_str: str, ground_truth: str):
    """Grade *solution_str* against *ground_truth*, ignoring any reasoning trace.

    If the text contains ``</think>``, only the part after the LAST
    ``</think>`` is graded. Repeated closing tags therefore cannot hide the
    final answer.

    Returns:
        The score from ``math_verify_reward_function`` (``1.0`` or ``0.0``).
        Returns ``0.0`` if grading raises unexpectedly; the traceback is
        printed.
    """
    if "</think>" in solution_str:
        # rpartition keeps the text after the final tag; split(...)[1] would
        # grab only the segment between the first and second tags.
        solution_str = solution_str.rpartition("</think>")[2]

    try:
        return math_verify_reward_function(solution_str, ground_truth)
    except Exception:
        # Best-effort grading: report the failure but keep scoring the rest.
        traceback.print_exc(10)
        return 0.0

# Judge prompt; ``Text_A`` / ``Text_B`` are filled in with str.format, so
# literal braces are doubled (``\\boxed{{True}}`` renders as ``\boxed{True}``).
_JUDGE_PROMPT_TEMPLATE = (
    """## **Role and Objective**

You are a meticulous scientific analyst. Your mission is to determine if **Text A** and **Text B** are substantively equivalent in their core results.

Focus strictly on **numerical values, mathematical expressions, and final conclusions.** Ignore all differences in wording, formatting, and sentence structure.

## **Core Principles**

1.  **Content Over Form:** Prioritize the core message and data. Completely ignore stylistic choices like wording, formatting (bolding, spacing), and sentence construction.
2.  **Standardize Before Comparing:** Before comparison, you must normalize all values to a common standard. Convert percentages to decimals (`50%` → `0.5`), unify units (`1m` → `100cm`), and resolve scientific notation.
3.  **Compare Overlapping Information Only:** If one text contains data information absent in the other, ignore the missing parts. Base your comparison solely on the data and claims present in *both* texts. Asymmetry is not a basis for inequivalence.

## **Evaluation Criteria**

### **1. Numerical Equivalence**

*   **Rule:** Two numbers, `A` and `B`, are equivalent if their absolute difference is less than or equal to `1.0` after normalization.
*   **Formula:** `|A - B| ≤ 1.0`
*   **Examples:**
    *   **Equivalent**: `|1.453125 - 1.390625| = 0.0625` (≤ 1.0)
    *   **Equivalent**: `|-32.015744 - (-32.515744)| = 0.5` (≤ 1.0)
    *   **Inequivalent**: `|12.91829 - 11.78172| = 1.136` (> 1.0)
*   **Clarifications:**
    *   Differences in significant figures or decimal places are acceptable as long as the absolute difference rule is met.
    *   Numerical signs must match (e.g., `5` and `-5` are not equivalent).

### **2. Mathematical Expression Equivalence**

*   **Rule:** Assess if expressions are mathematically equivalent, not just if they are written identically.
*   **Examples of Equivalence:**
    *   **Commutativity/Associativity:** `a + b` is equivalent to `b + a`
    *   **Alternative Forms:** `x/2` is equivalent to `0.5x`
    *   **Factoring/Expansion:** `x² - 1` is equivalent to `(x - 1)(x + 1)`
    *   **Variable Renaming:** `f(x) = x²` is equivalent to `g(y) = y²`
*   **Important Caveat:** If a transformation alters the domain in a way that impacts the conclusion (e.g., introduces division by zero), the expressions are not equivalent.

### **3. Conclusion Equivalence**

*   **Rule:** The final answers or main conclusions must align.
    *   **Numerical Conclusions:** Must meet the numerical equivalence standard defined above.
    *   **Categorical Conclusions:** Must be identical (e.g., "Positive" vs. "Positive"; "Category A" vs. "Category A").

## **Strict Output Format**

1.  **Reasons:** Provide a concise, objective, bulleted list explaining your rationale.
2.  **Verdict:** State the final decision: `True` (for equivalent) or `False` (for inequivalent).
3.  **Boxed:** Wrap the final boolean value in `\\boxed{{}}`, for example, `\\boxed{{True}}` or `\\boxed{{False}}`.

## **Texts to Evaluate**

[Text A]
{Text_A}

[Text B]
{Text_B}"""
    # Trailing spaces and blank lines after Text_B kept explicitly.
    "    \n\n\n"
)


def _build_judge_prompts(code_results_lists, dv3_results_lists, tokenizer):
    """Render one chat-formatted judge prompt per (Text_A, Text_B) pair (zip semantics)."""
    prompts = []
    for text_a, text_b in zip(code_results_lists, dv3_results_lists):
        instruct = _JUDGE_PROMPT_TEMPLATE.format(Text_A=text_a, Text_B=text_b)
        prompts.append(prompt_format([{"content": instruct, "role": "user"}], tokenizer))
    return prompts


def _parse_judge_verdict(completion):
    """Return True iff the last ``\\boxed{...}`` in *completion* contains "true" (case-insensitive)."""
    boxed = last_boxed_only_string(completion)
    return boxed is not None and "true" in boxed.lower()


def judge_two_texts_same(model, code_results_lists, dv3_results_lists, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1, max_tokens=4096):
    """Ask an LLM judge whether paired texts agree on their core results.

    Each ``(code_results_lists[i], dv3_results_lists[i])`` pair is inserted
    into ``_JUDGE_PROMPT_TEMPLATE`` as ``Text_A`` / ``Text_B``, rendered with
    the model's chat template, and generated with vLLM. Pairs are formed with
    ``zip``, so surplus items in the longer list are ignored.

    Args:
        model: Model name or path, used for both the tokenizer and ``vllm.LLM``.
        code_results_lists: Texts used as ``Text_A``.
        dv3_results_lists: Texts used as ``Text_B``.
        start: First prompt index to judge.
        end: One past the last prompt index; ``MAX_INT`` means "to the end".
        batch_size: Prompts per ``llm.generate`` call.
        tensor_parallel_size: Forwarded to ``vllm.LLM``.
        max_tokens: Generation budget per completion (default 4096).

    Returns:
        ``(res_completions, judge_res_list)``: the completion text for each
        judged prompt, and the boolean verdict parsed from each.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)
    prompts = _build_judge_prompts(code_results_lists, dv3_results_lists, tokenizer)

    print('total length ===', len(prompts))
    if end == MAX_INT:
        end = len(prompts)
    prompts = prompts[start:end]
    print('lenght ====', len(prompts))

    batches = batch_data(prompts, batch_size=batch_size)

    stop_tokens = ["<|im_end|>", "<|endoftext|>"]
    sampling_params = SamplingParams(temperature=0.7, top_p=0.8, max_tokens=max_tokens, stop=stop_tokens, ignore_eos=False, skip_special_tokens=False)

    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)

    res_completions = []
    for batch in batches:
        batch = batch if isinstance(batch, list) else [batch]
        if not batch:
            # An empty input can produce an empty batch; nothing to generate.
            continue
        # vLLM documents that outputs come back in input order, so completions
        # stay aligned with the sliced prompt list.
        for output in llm.generate(batch, sampling_params):
            res_completions.append(output.outputs[0].text)

    judge_res_list = [_parse_judge_verdict(text) for text in res_completions]

    print("==================end=====================")
    return res_completions, judge_res_list



def combin(save_vllm_dir, all_files=None):
    """Merge per-shard judge results in *save_vllm_dir* into ``all_results.json``.

    Each shard is a JSON object with ``res_completions`` and
    ``judge_res_list`` lists. The lists are concatenated in *all_files*
    order. Missing shards are reported and skipped.

    Args:
        save_vllm_dir: Directory holding the shard files. The merged result is
            written there as ``all_results.json``.
        all_files: Shard file names relative to *save_vllm_dir*. Defaults to
            ``["start_0_to_end_11000.json"]``.

    Returns:
        The merged dict that was written to disk.
    """
    if all_files is None:
        all_files = ["start_0_to_end_11000.json"]

    all_res_completions = []
    all_judge_res_list = []
    for fi in all_files:
        shard_path = os.path.join(save_vllm_dir, fi)
        if not os.path.exists(shard_path):
            print(f"WARNING: shard not found, skipping: {shard_path}")
            continue
        with open(shard_path, 'r', encoding='utf-8') as f:
            shard = json.load(f)
        all_res_completions.extend(shard['res_completions'])
        all_judge_res_list.extend(shard['judge_res_list'])

    save_vllm_results_path = os.path.join(save_vllm_dir, "all_results.json")
    python_exe_results_dict = {"res_completions": all_res_completions, "judge_res_list": all_judge_res_list}
    # Report the number of merged records; len() of the dict is always 2 (its keys).
    print(f"save_vllm_results_path ==== {save_vllm_results_path}, num_results === {len(all_res_completions)}")
    with open(save_vllm_results_path, 'w', encoding='utf-8') as f:
        json.dump(python_exe_results_dict, f, ensure_ascii=False)

    return python_exe_results_dict


def parse_args():
    """Parse the command-line options for the LLM-judge run.

    Returns:
        An ``argparse.Namespace`` with ``model``, ``data_file``, ``save_dir``,
        ``start``, ``end``, ``batch_size``, ``tensor_parallel_size`` and
        ``max_tokens``.
    """
    # (flag, type, default), registered in this order.
    option_table = (
        ("--model", str, "Qwen/Qwen3-32B"),
        ("--data_file", str, 'data/MATH_test.jsonl'),
        ("--save_dir", str, '/root'),
        ("--start", int, 0),
        ("--end", int, MAX_INT),
        ("--batch_size", int, 50),
        ("--tensor_parallel_size", int, 1),
        ("--max_tokens", int, 16384),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry: load the paired texts, judge the [start, end) slice, and
    # save completions plus verdicts to <save_dir>/start_<start>_to_end_<end>.json.
    args = parse_args()
    start = args.start
    end = args.end

    dv3_python_consistency_datas_path = args.data_file

    save_vllm_dir = args.save_dir

    os.makedirs(save_vllm_dir, exist_ok=True)
    save_vllm_results_path = f"{save_vllm_dir}/start_{start}_to_end_{end}.json"
    print(f"save_vllm_results_path ==== {save_vllm_results_path}")

    # Fail fast: without the input file there is nothing to judge, and the
    # lists below would otherwise be undefined (NameError).
    if not os.path.exists(dv3_python_consistency_datas_path):
        sys.exit(f"data file not found: {dv3_python_consistency_datas_path}")
    with open(dv3_python_consistency_datas_path, 'r', encoding='utf-8') as f:
        data_results_dict = json.load(f)
    dv3_results_lists = data_results_dict['dv3_results_lists']
    all_tool_responses = data_results_dict['all_tool_responses']

    # Use the CLI values so --batch_size / --tensor_parallel_size take effect.
    res_completions, judge_res_list = judge_two_texts_same(
        model=args.model,
        code_results_lists=all_tool_responses,
        dv3_results_lists=dv3_results_lists,
        start=start,
        end=end,
        batch_size=args.batch_size,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    python_exe_results_dict = {"res_completions": res_completions, "judge_res_list": judge_res_list}
    # Report record counts; len() of the dict itself is always 2 (its keys).
    print(f"save_vllm_results_path ==== {save_vllm_results_path}, num_results === {len(res_completions)}, num_judged_false === {judge_res_list.count(False)}")
    with open(save_vllm_results_path, 'w', encoding='utf-8') as f:
        json.dump(python_exe_results_dict, f, ensure_ascii=False)


'''

CUDA_VISIBLE_DEVICES=0,1,2,3 python agentmath/data_synthesis/tool/qwen3_32b_judge_vllm.py --model Qwen/Qwen3-32B --data_file /root/vllm_judge_dv3_python_consistency.json --save_dir /root/data_path/consistency_judge_result --start 0 --end 11000 --batch_size 1000000 --tensor_parallel_size 4

'''