import re
import json
import numpy as np
from collections import Counter
import string
import os, time
from collections import defaultdict
from transformers import AutoTokenizer
from openai import OpenAI
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load tokenizer for token counting.
# NOTE(review): `tokenizer` is not referenced anywhere else in the visible part
# of this file, yet the load runs at import time — confirm it is still needed.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
# The 50 BrowseComp entries that are evaluated. The list is used in two ways:
# process_single_item skips golds whose gold["id"] is not in it, and the
# __main__ script uses the values as positions to sub-select outputs/golds.
browse_comp_indices = [0, 1, 4, 6, 9, 10, 11, 12, 22, 27, 32, 33, 38, 39, 40, 43, 45, 48, 51, 52, 53, 54, 55, 57, 58, 60, 62, 63, 66, 71, 72, 73, 74, 77, 81, 88, 94, 95, 96, 98, 105, 106, 107, 117, 121, 124, 128, 129, 133, 136]
# Prompt template sent to the judge model. The placeholders {question},
# {gt_answer} and {pred_answer} are filled in with str.replace (see
# get_mbe_result); the JSON example below contains literal braces, so the
# template cannot be passed to str.format as-is.
PROMPT='''You will be given a question and its ground truth answer. Provided a pred_answer, you need to judge if the pred_answer correctly answers the question based on the ground truth answer.
You should first give your rationale for the judgement, and then give your judgement result (i.e., correct or incorrect).

Here is the criteria for the judgement:
The pred_answer doesn't need to be exactly the same as the ground truth answer, but should be semantically same for the question.

question: {question}
ground truth answer: {gt_answer}
pred_answer: {pred_answer}

The output should in the following json format:
\'\'\'json
{
\"rationale\": \"your rationale for the judgement, as a text\",
\"judgement\": \"your judgement result, can only be \'correct\' or \'incorrect\'\"
}
\'\'\'
Your output:
'''

def extract_answer(output, baseline_name="Search-r1"):
    r"""Return the cleaned content of the last ``\boxed{...}`` in *output*.

    The closing brace is found by counting nested braces, so answers such as
    ``\boxed{\text{A} and \text{B}}`` are captured in full instead of being
    cut at the first ``}``. When the braces after a ``\boxed{`` never
    balance, the span up to the first ``}`` is used instead.

    After extraction, surrounding parentheses/braces are stripped and every
    ``\text`` marker and remaining brace is removed.

    Args:
        output: Raw model output text.
        baseline_name: Unused; kept so existing callers keep working.

    Returns:
        The cleaned answer string, or ``''`` when no ``\boxed{...}`` is found.
    """
    marker = "\\boxed{"
    boxed = []
    cursor = output.find(marker)
    while cursor != -1:
        start = cursor + len(marker)
        depth = 1
        end = -1
        # Walk forward until the brace opened by the marker is closed.
        for idx in range(start, len(output)):
            ch = output[idx]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    end = idx
                    break
        if end == -1:
            # Unbalanced braces: fall back to the first closing brace.
            end = output.find("}", start)
            if end == -1:
                # No closing brace at all, so no later marker can match either.
                break
        boxed.append(output[start:end])
        # Resume after the consumed span so matches never overlap.
        cursor = output.find(marker, end + 1)

    if not boxed:
        return ''

    # The final answer is taken to be the last boxed expression.
    pred = boxed[-1]
    pred = pred.strip("()").strip("{").strip("}")
    pred = pred.replace("\\text", "").replace("{", "").replace("}", "")
    return pred

def _parse_judgement(output):
    """Turn a judge response into a score, or None if it has no verdict.

    The value of the ``judgement`` field is preferred, because the rationale
    text may itself contain the words "correct" or "incorrect". If no such
    field is found, the whole response is searched for those words, with
    "incorrect" checked first since it contains "correct".

    Returns:
        1.0 for correct, 0.0 for incorrect, None when neither word appears.
    """
    verdicts = re.findall(
        r"""["']judgement["']\s*:\s*["']\s*(incorrect|correct)\b""",
        output,
        flags=re.IGNORECASE,
    )
    if verdicts:
        # The prompt asks for the judgement after the rationale; use the last.
        return 0.0 if verdicts[-1].lower() == "incorrect" else 1.0

    lowered = output.lower()
    if "incorrect" in lowered:
        return 0.0
    if "correct" in lowered:
        return 1.0
    return None


def get_mbe_result(question, gts, pred_answer):
    """Ask the judge model whether *pred_answer* answers *question*.

    The PROMPT template is filled in and sent through ``call_api``. The call
    is attempted up to three times, retrying both on exceptions and on
    responses that contain no verdict.

    Args:
        question: The question text.
        gts: Ground-truth answer; converted with ``str()``.
        pred_answer: Predicted answer; converted with ``str()`` so a
            non-string prediction does not burn all retries on a TypeError.

    Returns:
        1.0 if judged correct, 0.0 if judged incorrect or if no verdict could
        be obtained within the retry budget.
    """
    max_retries = 3

    # The prompt does not change between attempts, so build it once.
    prompt = (
        PROMPT.replace("{question}", str(question))
        .replace("{gt_answer}", str(gts))
        .replace("{pred_answer}", str(pred_answer))
    )

    for attempt in range(1, max_retries + 1):
        is_last_attempt = attempt == max_retries
        try:
            score = _parse_judgement(call_api(prompt))
        except Exception as e:
            # Best-effort scoring: any failure of the call (or of parsing a
            # malformed/None response) is reported and retried.
            print(f"API call failed (attempt {attempt}/{max_retries}): {e}")
            if is_last_attempt:
                print(f"Failed to get result after {max_retries} attempts")
                return 0.0
            time.sleep(2)  # Longer delay for API errors
            continue

        if score is not None:
            return score

        # If no clear judgement, try again
        if is_last_attempt:
            print(f"Warning: No clear judgement after {max_retries} attempts for question: {str(question)[:100]}...")
            return 0.0
        time.sleep(1)  # Brief delay before retry

    return 0.0


def call_api(prompt):
    """Send *prompt* as a single user message and return the reply text.

    A new OpenAI client pointed at ``http://0.0.0.0:8000/v1`` is created on
    every call, and the request targets the ``Qwen/Qwen3-32B`` model with
    temperature 0.
    NOTE(review): the "EMPTY" API key suggests a local OpenAI-compatible
    server without authentication — confirm against the deployment.

    Returns:
        The ``content`` of the first choice's message.
    """
    request_kwargs = {
        "model": "Qwen/Qwen3-32B",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.,
        # Extra fields forwarded in the request body as-is.
        "extra_body": {
            "top_k": 20,
            "chat_template_kwargs": {"enable_thinking": False},
        },
    }
    judge_client = OpenAI(api_key="EMPTY", base_url="http://0.0.0.0:8000/v1")
    completion = judge_client.chat.completions.create(**request_kwargs)
    first_choice = completion.choices[0]
    return first_choice.message.content

def _resolve_gold_answer(output_path, gold):
    """Pick the gold answer from *gold* based on the dataset named in the path.

    Substrings are tested in a fixed order; "med_browse_comp" must come
    before "browse_comp" because the latter is a substring of the former.

    Raises:
        ValueError: If *output_path* contains none of the known dataset names.
    """
    if "med_browse_comp" in output_path:
        # A missing gold answer is represented by this placeholder string.
        return "DATE: NA" if gold["answer"] is None else gold["answer"]
    if "browse_comp" in output_path:
        return gold["Answer"]
    if "musique" in output_path or "2wiki" in output_path:
        return gold["answer"]
    if "cwq" in output_path:
        # Only the first listed answer is used for CWQ.
        return gold["answer"][0]
    if "hotpotqa" in output_path:
        return gold["answer"]
    if "fanoutqa" in output_path:
        return str(gold["answer"])
    if "frames" in output_path:
        return gold["answer"]
    raise ValueError(
        f"Cannot determine gold answer: no known dataset name in output path {output_path!r}"
    )


def process_single_item(item, gold, dataset_name, baseline_name, args):
    """Process a single item to get MBE score.

    Args:
        item: One model-output record; mutated in place by adding the
            ``"mbe_score"`` key.
        gold: The matching gold record.
        dataset_name: Dataset identifier; only ``"browse_comp"`` triggers the
            index filter.
        baseline_name: Selects which field of *item* holds the prediction.
        args: Object with an ``output_path`` attribute; the dataset whose
            gold-answer field is read is inferred from that path.

    Returns:
        *item* with ``"mbe_score"`` set, or ``None`` when the item is a
        browse_comp entry whose ``gold["id"]`` is not in
        ``browse_comp_indices``.

    Raises:
        ValueError: If ``args.output_path`` names no known dataset.
        KeyError: If an expected field is missing from *item* or *gold*.
    """
    if dataset_name == "browse_comp" and gold["id"] not in browse_comp_indices:
        return None

    # The question may live under different keys depending on the producer.
    if "question" in item:
        question = item["question"]
    elif "Question" in item:
        question = item["Question"]
    else:
        question = item["item"]["Question"]

    answer = _resolve_gold_answer(args.output_path, gold)

    # Baseline-specific fields take precedence; otherwise probe known keys.
    if baseline_name == "RAG-R1":
        pred = item["predict_answer"]
    elif baseline_name == "Ours":
        pred = item["generated_answer"]
    elif baseline_name == "DeepResearcher":
        pred = extract_answer(item["message_str"], baseline_name)
    elif "generated_answer" in item:
        pred = item["generated_answer"]
    elif "pred_ans" in item:
        pred = item["pred_ans"]
    elif "output" in item and isinstance(item["output"], dict):
        pred = item["output"]["pred"]
    elif "Output" in item:
        pred = extract_answer(item["Output"], baseline_name)
    else:
        pred = extract_answer(item["output"], baseline_name)

    item["mbe_score"] = get_mbe_result(question, answer, pred)
    return item

# Script entry point: for every (baseline, dataset) pair, load the model
# outputs and gold data, score each item with the judge in a thread pool, and
# write the scored items to a "*.mbe.json" file.
if __name__ == "__main__":
    import argparse

    # Parse command-line arguments for flexibility
    # NOTE(review): output_path, output_metrics_path and dataset_name are
    # overwritten inside the loops below, so the values given on the command
    # line have no effect; apply_backoff and type are not read in this block.
    parser = argparse.ArgumentParser(description="Evaluate model outputs with optional backoff.")
    parser.add_argument('--output_path', '-o', type=str, help='Path to the model output JSON file.')
    parser.add_argument('--output_metrics_path', '-s', type=str, help='Path to save the evaluation metrics.')
    parser.add_argument('--apply_backoff', action='store_true', help='Enable backoff to normal outputs if main output is invalid.')
    parser.add_argument('--type', type=str, help='HotpotQA comparison vs bridge')
    parser.add_argument('--dataset_name', type=str, help='Dataset name')
    parser.add_argument('--max_workers', type=int, default=10, help='Maximum number of worker threads for parallel processing (default: 10)')
    args = parser.parse_args()
    
    for baseline_name in ["results_naive"]: # alternatives: results_naive, results_zeroshot
        for dataset_name in ["musique", "fanoutqa", "frames", "med_browse_comp", "browse_comp"]: 
            # Input and output paths are derived from the loop variables.
            args.output_path = f"results/{baseline_name}/{dataset_name}_qwen3-8b.json"
            args.output_metrics_path = f"results/{baseline_name}/{dataset_name}_qwen3-8b.mbe.json"
            args.dataset_name = dataset_name

            # "jsonl" is tested first because "json" is a substring of it.
            if "jsonl" in args.output_path:
                with open(args.output_path, "r", encoding='utf-8') as f:
                    outputs = [json.loads(line) for line in f]  
            elif "json" in args.output_path:
                with open(args.output_path, "r", encoding='utf-8') as f:
                    outputs = json.load(f)
                
            # Pick the gold file by substring of the output path.
            # "med_browse_comp" must be tested before "browse_comp" because the
            # latter is a substring of the former.
            if "med_browse_comp" in args.output_path:
                args.dataset_name = "med_browse_comp"
                with open("reasoning_rag/datasets/MEDBROWSECOMP/test.json") as f:
                    golds = json.load(f)
            elif "browse_comp" in args.output_path:
                args.dataset_name = "browse_comp"
                with open("reasoning_rag/datasets/BROWSE_COMP/test.json") as f:
                    golds = json.load(f)
            elif "musique" in args.output_path:
                args.dataset_name = "musique"
                with open("reasoning_rag/datasets/MUSIQUE/dev_sub.json") as f:
                    golds = json.load(f)
            elif "2wiki" in args.output_path:
                args.dataset_name = "2wiki"
                with open("reasoning_rag/datasets/2WIKI/dev.json") as f:
                    golds = json.load(f)
            elif "cwq" in args.output_path:
                args.dataset_name = "cwq"
                with open("reasoning_rag/datasets/CWQ/dev.json") as f:
                    golds = json.load(f)
            elif "hotpotqa" in args.output_path:
                args.dataset_name = "hotpotqa"
                with open("reasoning_rag/datasets/HOTPOTQA/dev_sub.json") as f:
                    golds = json.load(f)
            elif "fanoutqa" in args.output_path:
                args.dataset_name = "fanoutqa"
                with open("reasoning_rag/datasets/FANOUTQA/dev.json") as f:
                    golds = json.load(f)
            elif "frames" in args.output_path:
                args.dataset_name = "frames"
                with open("reasoning_rag/datasets/FRAMES/test.json") as f:
                    golds = json.load(f)
            
            print("=="*10, args.dataset_name, baseline_name, "=="*10)
            
            # For these baselines the outputs are re-ordered to follow the gold
            # order, matching records by question text. With the baseline list
            # above, neither branch is taken.
            if baseline_name == "R1-Searcher":
                outputs_dict = {item["question"]: item for item in outputs}
                outputs = [outputs_dict[g["question"] if "question" in g else g["Question"]] for g in golds]
            
            if baseline_name == "DeepResearcher" and dataset_name == "musique":
                outputs_dict = {item["question"]: item for item in outputs}
                outputs = [outputs_dict[g["question"] if "question" in g else g["Question"]] for g in golds]
            
            # Restrict BrowseComp to the selected subset by position, unless a
            # list already has exactly as many entries as the subset.
            if dataset_name == "browse_comp":
                if len(outputs) != len(browse_comp_indices):
                    outputs = [outputs[i] for i in browse_comp_indices]
                if len(golds) != len(browse_comp_indices):
                    golds = [golds[i] for i in browse_comp_indices]
            
            outputs_with_mbe_score = []
            with ThreadPoolExecutor(max_workers=args.max_workers) as executor:
                # Outputs and golds are paired positionally; zip stops at the
                # shorter list. Each future maps back to its index and inputs.
                futures = {executor.submit(process_single_item, item, gold, dataset_name, baseline_name, args): (i, item, gold) for i, (item, gold) in enumerate(zip(outputs, golds))}
                
                # Track progress with tqdm
                results = {}
                with tqdm(total=len(futures), desc=f"Processing {dataset_name} with {baseline_name}") as pbar:
                    for future in as_completed(futures):
                        i, item, gold = futures[future]
                        try:
                            processed_item = future.result()
                            # None means the item was filtered out; drop it.
                            if processed_item is not None:
                                results[i] = processed_item
                        except Exception as e:
                            print(f"Error processing item {i}: {item.get('question', 'Unknown')}. Error: {e}")
                            # Keep original item without MBE score
                            results[i] = item
                        pbar.update(1)
                
                # Futures complete in arbitrary order; restore the input order.
                outputs_with_mbe_score = [results[i] for i in sorted(results.keys())]
            
            with open(args.output_metrics_path, "w", encoding="utf-8") as f:
                json.dump(outputs_with_mbe_score, f, indent=4)
            
            
            print(f'Results are saved to {args.output_metrics_path}')
