import argparse
import json
import os
from copy import deepcopy
import re

from sentence_transformers import SentenceTransformer

from LLMs import close_model, v3_batch
from logic_eval import logic_evaluation


def extract_from_boxed(answer):
    r"""Return the text inside the last ``\boxed{...}`` group of ``answer``.

    Nested braces inside the group are balanced, so ``\boxed{\frac{1}{2}}``
    yields ``\frac{1}{2}``.

    Args:
        answer: Model output to search, or ``None``.

    Returns:
        The boxed content as a string (possibly empty). ``None`` when
        ``answer`` is ``None`` or contains no ``\boxed{`` marker. If the
        group is never closed, everything after the marker is returned.
    """
    marker = '\\boxed{'
    if answer is None:
        return None

    marker_pos = answer.rfind(marker)
    if marker_pos < 0:
        return None

    body_start = marker_pos + len(marker)
    depth = 1  # the opening brace of \boxed{ is already consumed
    for offset, ch in enumerate(answer[body_start:]):
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                # Found the brace that closes \boxed{ itself.
                return answer[body_start:body_start + offset]

    # Unbalanced input: fall back to the whole tail.
    return answer[body_start:]


def scientific_notation_to_float(s):
    r"""Parse the first LaTeX ``a \times 10^{b}`` expression in ``s``.

    The coefficient may carry a leading ``+`` or ``-`` sign (written directly
    before the digits), and the exponent may be written with an explicit
    sign and surrounding spaces inside the braces.

    Args:
        s: Text that may contain a LaTeX scientific-notation number.

    Returns:
        The value as a float, or ``None`` when no such expression is found or
        when the exponent magnitude exceeds 308 (outside the range a double
        can represent).
    """
    # The sign is part of the coefficient group so that negative values are
    # not silently turned into positive ones.
    pattern = r'([+-]?\d+(?:\.\d+)?)\s*\\times\s*10\^\{\s*([+-]?\d+)\s*\}'
    match = re.search(pattern, s.strip())
    if match is None:
        return None

    coefficient = float(match.group(1))
    exponent = int(match.group(2))
    if abs(exponent) > 308:
        return None
    return coefficient * (10 ** exponent)


def extract_number(s):
    """Return the first plain or e-notation number found in ``s`` as a float.

    Recognised forms are an optional leading minus, digits, an optional
    fractional part and an optional ``e``/``E`` exponent.

    Args:
        s: Text to scan.

    Returns:
        The first matching number as a float, or ``None`` when the text
        contains no digits (or the conversion fails).
    """
    found = re.search(r'-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?', s.strip())
    if found is None:
        return None
    try:
        return float(found.group())
    except Exception:
        return None


def get_final_math_answer(content):
    """Convert an answer string to a float, if it contains a number.

    LaTeX scientific notation is tried first; if that does not match, the
    first plain number in the text is used.

    Args:
        content: Answer text, or ``None``.

    Returns:
        The parsed float, or ``None`` when ``content`` is ``None`` or no
        number can be extracted.
    """
    if content is None:
        return None
    sci_value = scientific_notation_to_float(content)
    return sci_value if sci_value is not None else extract_number(content)


def judge_choice(judge_dataset):
    r"""Grade multiple-choice predictions by comparing the boxed letter.

    The option letter is taken from the last ``\boxed{...}`` group of each
    prediction. A stand-alone ``A``-``D`` (one that is not part of a longer
    word) is preferred, so that text such as ``\text{Answer: D}`` resolves to
    ``D`` rather than to the ``A`` of "Answer". When no stand-alone letter
    exists, the first ``A``-``D`` character anywhere in the boxed text is
    used.

    Args:
        judge_dataset: Iterable of ``(question, ground_truth, prediction)``
            triples. The question is not used for grading.

    Returns:
        One dict per item with the keys ``answer_gt``, ``answer_pred``,
        ``answer_pred_extracted``, ``extracted_choice`` and ``correct``
        (1 or 0). Items with no (or empty) boxed content are marked
        incorrect with ``extracted_choice`` set to ``None``.
    """
    standalone_letter = re.compile(r'(?<![A-Za-z])[ABCD](?![A-Za-z])')
    any_letter = re.compile(r'[ABCD]')

    results = []
    for _question, gt, pred in judge_dataset:
        extracted = extract_from_boxed(pred)
        choice = None
        if extracted:
            m = standalone_letter.search(extracted) or any_letter.search(extracted)
            if m:
                choice = m.group(0)
        # A missing choice is never correct, even if the ground truth is also None.
        correct = 1 if (choice is not None and choice == gt) else 0
        results.append({
            "answer_gt": gt,
            "answer_pred": pred,
            "answer_pred_extracted": extracted,
            "extracted_choice": choice,
            "correct": correct,
        })
    return results


def judge_comp(judge_dataset, judger_llm, batch_size):
    r"""Grade computation-type predictions numerically, falling back to an LLM judge.

    Every item is sent to the judge model. For each item whose prediction has
    a ``\boxed{...}`` group, the result is decided as follows:

    * If both the ground truth and the boxed prediction parse as numbers, the
      prediction is correct when its relative error is at most 5% (or, for a
      ground truth of exactly 0, when its absolute value is below 0.001).
    * Otherwise it is correct when the judge's answer contains ``"A"``.
      NOTE(review): presumably the prompt in ``prompt/LLM_judge.md`` asks the
      judge to reply with a letter where ``A`` means "correct" -- confirm.

    Args:
        judge_dataset: Sequence of ``[question, ground_truth, prediction]``
            triples.
        judger_llm: Object exposing ``batch_generate(messages_list=...,
            num_workers=..., num_messages_per_worker=...)`` whose result can
            be indexed per message and provides ``'answer'`` and
            ``'raw_outputs'``.
        batch_size: Unused; kept so existing callers keep working. Judging
            always runs with a fixed number of workers.

    Returns:
        One dict per item of ``judge_dataset`` with the keys ``answer_gt``,
        ``answer_pred``, ``answer_pred_extracted``, ``correct``,
        ``judge_raw_outputs`` and ``use_numerical_judge``. An empty dataset
        yields ``[]`` without reading the prompt or calling the judge.
    """
    raw_len = len(judge_dataset)
    if raw_len == 0:
        return []

    # Use fixed parallelism of 10 threads for judging
    num_parallel = 10

    # Pad to a multiple of the worker count (not of batch_size) so that
    # num_parallel * num_messages_per_worker covers every message.
    # NOTE(review): assumes batch_generate hands num_messages_per_worker
    # messages to each worker -- confirm against the LLMs module.
    padded = list(judge_dataset)
    remainder = raw_len % num_parallel
    if remainder != 0:
        padding_item = ["1+1", "2", "\\boxed{2.0}"]
        padded.extend([padding_item] * (num_parallel - remainder))
    assert len(padded) % num_parallel == 0

    prompt_path = os.path.join(os.path.dirname(__file__), "prompt", "LLM_judge.md")
    with open(prompt_path, "r", encoding="utf-8") as f:
        judge_prompt = f.read()

    # Fill all placeholders in one pass so that placeholder-like text inside
    # a question or prediction is never substituted a second time.
    placeholder_re = re.compile(r'\{(question|pred|gold)\}')
    messages_list = []
    for q, gt, pred in padded:
        fields = {
            "question": q,
            "pred": extract_from_boxed(pred) or "no answer",
            "gold": gt,
        }
        data_prompt = placeholder_re.sub(
            lambda m, fields=fields: fields[m.group(1)], judge_prompt
        )
        messages_list.append([
            {"role": "system", "content": "You are a helpful assistant that judges the correctness of answers."},
            {"role": "user", "content": data_prompt},
        ])

    llm_judge_results = judger_llm.batch_generate(
        messages_list=messages_list,
        num_workers=num_parallel,
        num_messages_per_worker=len(messages_list) // num_parallel,
    )

    results = []
    for i, (_q, gt, pred) in enumerate(judge_dataset):
        extracted = extract_from_boxed(pred)
        if extracted is None:
            # No boxed answer at all: incorrect, and the judge output is ignored.
            results.append({
                "answer_gt": gt,
                "answer_pred": pred,
                "answer_pred_extracted": extracted,
                "correct": 0,
                "judge_raw_outputs": None,
                "use_numerical_judge": None,
            })
            continue

        llm_judgement = llm_judge_results[i]['answer']
        judge_raw_outputs = llm_judge_results[i]['raw_outputs']

        gt_num = get_final_math_answer(gt)
        pred_num = get_final_math_answer(extracted)
        use_numerical = (gt_num is not None) and (pred_num is not None)
        if use_numerical:
            if gt_num == 0:
                # Relative error is undefined for a zero target.
                correct = 1 if abs(pred_num) < 0.001 else 0
            else:
                correct = 1 if abs(gt_num - pred_num) / abs(gt_num) <= 0.05 else 0
        else:
            correct = 1 if (llm_judgement and "A" in llm_judgement) else 0

        results.append({
            "answer_gt": gt,
            "answer_pred": pred,
            "answer_pred_extracted": extracted,
            "correct": correct,
            "judge_raw_outputs": judge_raw_outputs,
            "use_numerical_judge": use_numerical,
        })
    assert len(results) == raw_len
    return results


def inference(evaluated_llm, benchmark_dataset, batch_size):
    r"""Query the evaluated model with every question and collect its answers.

    The question list is padded with dummy ``"hello"`` prompts up to a
    multiple of ``batch_size`` so it divides evenly across ``batch_size``
    workers; the padded answers are discarded. Each question is sent with an
    instruction to wrap the final answer in ``\boxed{}``.

    Args:
        evaluated_llm: Object exposing ``batch_generate(messages_list=...,
            num_workers=..., num_messages_per_worker=...)`` whose result can
            be indexed per message and provides ``'answer'`` and
            ``'raw_outputs'``.
        benchmark_dataset: Sequence of question texts.
        batch_size: Number of workers; also the padding granularity.

    Returns:
        One dict per question, in order, with the keys ``answer_pred`` and
        ``inference_raw_outputs``. An empty dataset yields ``[]`` without
        calling the model.
    """
    raw_len = len(benchmark_dataset)
    if raw_len == 0:
        # Nothing to ask; avoids a batch call with zero messages per worker.
        return []

    # A shallow copy is enough: items are only formatted into prompts.
    padded = list(benchmark_dataset)
    remainder = raw_len % batch_size
    if remainder != 0:
        padded.extend(["hello"] * (batch_size - remainder))
    assert len(padded) % batch_size == 0

    messages_list = [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"{item}\nPlease wrap the choice of answer at the end of the solving process using `\\boxed{{}}` to highlight it."},
        ]
        for item in padded
    ]

    results = evaluated_llm.batch_generate(
        messages_list=messages_list,
        num_workers=batch_size,
        num_messages_per_worker=len(messages_list) // batch_size,
    )

    # Keep only the answers that belong to real questions (drop the padding).
    inference_results = []
    for i in range(raw_len):
        entry = results[i]
        inference_results.append({
            "answer_pred": entry['answer'],
            "inference_raw_outputs": entry['raw_outputs'],
        })
    assert len(inference_results) == raw_len
    return inference_results


def benchmarking(evaluated_llm, judger_llm, encoder_model, benchmark_dataset, batch_size, question_type):
    """Run inference, answer judging and logical-nexus scoring for one question type.

    Args:
        evaluated_llm: Model under evaluation, passed on to ``inference``.
        judger_llm: Judge model, passed on to ``judge_comp`` for ``comp_n``.
        encoder_model: Sentence encoder passed on to ``logic_evaluation``.
        benchmark_dataset: List of dicts providing ``question``,
            ``logical_nexuses`` and ``logical_nexus_weights`` (and
            ``final_answer`` for the ``comp_n`` / ``choice`` types).
        batch_size: Batch size forwarded to inference and judging.
        question_type: Question type label; answers are judged only for
            ``"comp_n"`` and ``"choice"``.

    Returns:
        A tuple ``(inference_results, judge_results, logical_nexus_results)``.
        ``judge_results`` is ``None`` for question types that are not judged.
    """
    n = len(benchmark_dataset)
    banner = "\n\n================================================================================="

    print(banner)
    print(f"Running benchmarking with question type [{question_type}], questions num: {n}", flush=True)

    questions = []
    for sample in benchmark_dataset:
        questions.append(sample['question'])

    print(banner)
    print(f"Begin inference with question type [{question_type}], questions num: {n}", flush=True)
    inference_results = inference(evaluated_llm, questions, batch_size)

    print(banner)
    print(f"Begin judge answer with question type [{question_type}], questions num: {n}", flush=True)
    judge_results = None
    if question_type in ("comp_n", "choice"):
        # One [question, ground truth, prediction] triple per sample.
        judge_dataset = [
            [sample['question'], sample['final_answer'], outcome['answer_pred']]
            for sample, outcome in zip(benchmark_dataset, inference_results)
        ]
        if question_type == "comp_n":
            judge_results = judge_comp(judge_dataset, judger_llm, batch_size)
        else:
            judge_results = judge_choice(judge_dataset)

    print(banner)
    print(f"Begin judge logical_nexuses with question type [{question_type}], questions num: {n}", flush=True)
    logical_nexus_results = []
    for sample, outcome in zip(benchmark_dataset, inference_results):
        nexuses = sample['logical_nexuses']
        nexus_weights = sample['logical_nexus_weights']
        prediction = outcome['answer_pred']
        # Score only the text after the last "</think>" marker, if any.
        reasoning = prediction.split("</think>")[-1] if prediction else None
        score = logic_evaluation(
            logical_nexuses=nexuses,
            logical_nexus_weights=nexus_weights,
            reasoning=reasoning,
            encoder_model=encoder_model,
        )
        logical_nexus_results.append(score)

    print(banner)
    print(f"End judge logical_nexuses with question type [{question_type}], questions num: {n}", flush=True)
    return inference_results, judge_results, logical_nexus_results


if __name__ == "__main__":
    # Script entry point: evaluate one model on every question type of the
    # PhysLogic dataset and write one JSON result file per type.
    parser = argparse.ArgumentParser(description='Benchmarking parameters')
    parser.add_argument('--model_id', type=str, help='Model ID', default=None)
    parser.add_argument('--batch_size', type=int, help='Batch size', default=12)
    args = parser.parse_args()

    # Fall back to the default model when no (or an empty) id is given.
    model_id = args.model_id if args.model_id else "o4-mini"
    evaluated_llm = close_model(debug=True, model_id=model_id)
    judger_llm = v3_batch(debug=True)

    # The sentence encoder is fixed for all runs.
    encoder_model = SentenceTransformer('all-MiniLM-L6-v2')

    base_dir = os.path.dirname(__file__)

    # All question types live in one unified dataset file.
    data_path = os.path.join(base_dir, "Datas", "PhysLogic.json")
    with open(data_path, 'r', encoding='utf-8') as f:
        full_dataset = json.load(f)

    # Results go to results/<model_id>/ next to this script.
    output_path = os.path.join(base_dir, "results", model_id)
    os.makedirs(output_path, exist_ok=True)

    # Every type gets inference and logical-nexus scoring; answer accuracy is
    # judged only for "choice" and "comp_n".
    for qtype in ("choice", "comp_n", "comp_e", "proof"):
        benchmark_dataset = [
            entry for entry in full_dataset
            if entry.get("question_type") == qtype
        ]
        inference_results, judge_results, logical_nexus_results = benchmarking(
            evaluated_llm=evaluated_llm,
            judger_llm=judger_llm,
            encoder_model=encoder_model,
            benchmark_dataset=benchmark_dataset,
            batch_size=args.batch_size,
            question_type=qtype,
        )

        # Sanity-check that every result list lines up with the dataset.
        if qtype in ("choice", "comp_n"):
            assert len(inference_results) == len(judge_results) == len(logical_nexus_results) == len(benchmark_dataset)
        else:
            assert judge_results is None
            assert len(inference_results) == len(logical_nexus_results) == len(benchmark_dataset)

        benchmarking_results = {
            "benchmark_dataset": benchmark_dataset,
            "inference_results": inference_results,
            "judge_results": judge_results,
            "logical_nexus_results": logical_nexus_results,
        }
        result_file = os.path.join(output_path, f"{qtype}.json")
        with open(result_file, 'w', encoding='utf-8') as f:
            json.dump(benchmarking_results, f, ensure_ascii=False, indent=4)
        print(f"Benchmarking results saved to results/{model_id}/{qtype}.json")


