
import re
import json
from nltk.translate.bleu_score import corpus_bleu
import jsonlines
from openai import OpenAI
# JSONL file that compute_score_with_format appends a {"conversation_text", "reward"}
# record to when a conversation contains no assistant turns.
cache_file = "/home/liyishan/Search_R1_WorkSpace/cache.jsonl"
# EOS_TOKEN ="<|endoftext|>"
# BOS_TOKEN = "<|startoftext|>"
import random
# Placeholder API key; presumably the judge servers are self-hosted
# OpenAI-compatible endpoints that do not check it — TODO confirm.
openai_api_key = "EMPTY"
# Base URL the shared client is created with.
openai_api_base = "http://g3015:8700/v1"
# Candidate judge endpoints; run_llm_as_judge picks one at random per request.
openai_api_bases = ["http://g3015:8700/v1", "http://g3015:8701/v1", "http://g3015:8702/v1", "http://g3015:8703/v1", "http://g3015:8704/v1", "http://g3015:8705/v1", "http://g3015:8706/v1", "http://g3015:8707/v1"]

# Shared module-level client. NOTE: run_llm_as_judge reassigns client.base_url
# on every request, so this object is mutated at runtime.
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
def get_messages(completion_str):
    """Parse a chat-template completion into a list of role/content messages.

    The completion is treated as the continuation of an assistant turn, so an
    ``<|im_start|>assistant\\n`` header is prepended before scanning for
    ``<|im_start|>ROLE\\n ... <|im_end|>`` segments. A segment with no closing
    ``<|im_end|>`` (e.g. a truncated final turn) is not returned.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in document order,
        with surrounding whitespace stripped from each content.
    """
    full_text = "<|im_start|>assistant\n" + completion_str
    turn_re = re.compile(r"<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>", re.DOTALL)
    return [
        {"role": speaker, "content": body.strip()}
        for speaker, body in turn_re.findall(full_text)
    ]

def is_start_before_end(text: str, start: str, end: str) -> bool:
    """Return True if the first ``start`` occurs before the first ``end``.

    Returns False when ``start`` does not occur in ``text`` at all. A bare
    ``find`` comparison would yield -1 < index(end) == True in that case. It
    also returns False when ``end`` is missing but ``start`` is present.
    """
    start_idx = text.find(start)
    if start_idx == -1:
        return False
    # A missing end gives -1, which is never greater than a found start index.
    return start_idx < text.find(end)

def validate_template_format(text: str) -> bool:
    """Check whether an assistant turn follows the QA tag template.

    Rules:
      * ``<think>`` and ``</think>`` must occur the same number of times.
      * If ``<answer>`` is present, ``<answer>``/``</answer>`` counts must match.
      * Otherwise a ``<tool_call>`` must be present, and three conditions must hold:
        * its open and close tag counts match;
        * the first ``<tool_call>`` precedes the first ``</tool_call>``;
        * the first payload parses as a JSON object (see ``extract_tool_call``).
      * Text with neither ``<answer>`` nor ``<tool_call>`` is invalid.

    Tag order for ``<think>`` and ``<answer>`` is not checked, only counts.

    Returns:
        True if the text satisfies the template, False otherwise. Only a bool
        is returned; no error message is produced.
    """
    if text.count("<think>") != text.count("</think>"):
        return False
    if "<answer>" in text:
        return text.count("<answer>") == text.count("</answer>")
    if "<tool_call>" in text:
        return (
            text.count("<tool_call>") == text.count("</tool_call>")
            and is_start_before_end(text, "<tool_call>", "</tool_call>")
            and extract_tool_call(text) is not None
        )
    return False

def extract_answer(text: str):
    """Return the contents of every ``<answer>...</answer>`` span, space-joined.

    Spans are matched non-greedily and may cross newlines. If there are no
    spans, the empty string is returned.
    """
    answer_spans = re.findall(r"<answer>(.*?)</answer>", text.strip(), re.DOTALL)
    return " ".join(answer_spans)

def extract_tool_call(text: str):
    """Parse the first ``<tool_call>...</tool_call>`` payload as a JSON object.

    Returns:
        The decoded dict. Returns None if there is no tool-call span, if the
        payload is not valid JSON, or if it decodes to a non-object value
        (list, string, number, ...).
    """
    found = re.search(r"<tool_call>(.*?)</tool_call>", text.strip(), re.DOTALL)
    if found is None:
        return None
    try:
        payload = json.loads(found.group(1))
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict):
        return None
    return payload

def match_reference(text: str) -> list[str]:
    """Collect BibTeX keys cited in LaTeX ``text``, minus those only ``\\nocite``d.

    Any ``\\...cite...{k1, k2}`` command except ``\\...citestyle`` contributes
    its comma-separated keys. Keys are whitespace-stripped. Empty keys, the
    ``*`` wildcard and ``#N`` placeholders are ignored. Because the cite
    pattern also matches ``\\nocite``, keys listed in any ``\\nocite{...}``
    are then removed from the result.

    Returns:
        The remaining keys, de-duplicated and sorted so the output is
        deterministic across runs.
    """
    cite_reg = re.compile(r"\\\w*cite(?!style)\w*\{(.+?)\}")
    nocite_reg = re.compile(r"\\nocite\{(.+?)\}")
    placeholder_reg = re.compile(r"^#\d+$")

    def _keys(groups):
        # Yield cleaned keys from each brace group. The placeholder test runs
        # after stripping so that "a, #1" also drops the " #1" entry.
        for group in groups:
            for raw in group.split(","):
                key = raw.strip()
                if key and key != "*" and not placeholder_reg.match(key):
                    yield key

    bibkeys = set(_keys(cite_reg.findall(text)))
    # difference_update tolerates keys listed twice in \nocite, whereas
    # set.remove raised KeyError on the second occurrence.
    bibkeys.difference_update(_keys(nocite_reg.findall(text)))
    return sorted(bibkeys)
def split_conversion_to_list(conversation_text: str):
    """Parse a conversation and also pull out the assistant turns' contents.

    Returns:
        A pair ``(assistant_contents, all_messages)``. ``all_messages`` is the
        full list produced by ``get_messages``. ``assistant_contents`` holds
        the content strings of its assistant-role entries, in order.
    """
    all_messages = get_messages(conversation_text)
    assistant_contents = []
    for msg in all_messages:
        if msg["role"] == "assistant":
            assistant_contents.append(msg["content"])
    return assistant_contents, all_messages

# Multiplier applied when the last assistant turn is not a `finalize` tool call.
_NOT_FINALIZED_PENALTY = 0.8
# Multiplier applied when an answer cites a key absent from earlier user text.
_WRONG_CITATION_PENALTY = 0.9


def _has_wrong_citation(messages) -> bool:
    """Return True if any assistant answer cites a key unseen in prior user text.

    Keys come from ``match_reference`` applied to each assistant turn's
    ``<answer>`` content. A key counts as valid only if it occurs as a
    substring of the concatenated user-role messages that precede that turn.
    """
    user_text_before = ""
    for message in messages:
        if message["role"] == "user":
            user_text_before += message["content"]
        elif message["role"] == "assistant":
            answer = extract_answer(message["content"])
            if any(ref not in user_text_before for ref in match_reference(answer)):
                return True
    return False


def compute_score_with_format(conversation_text, ground_truth) -> float:
    """Score a rollout with the LLM judge, scaled by format/citation penalties.

    Args:
        conversation_text: Completion in chat-template form. It is parsed as
            the continuation of an assistant turn (see ``get_messages``).
        ground_truth: Currently unused; kept for the reward-function signature.

    Returns:
        0.0 if there are no assistant turns; that case is also appended to
        ``cache_file``. 0.0 if no assistant turn contains ``<answer>`` content.
        Otherwise the judge score for the space-joined answers, multiplied by
        0.8 if the last assistant turn is not a ``finalize`` tool call and by
        0.9 if any answer cites a key not present in earlier user text.
    """
    assistant_texts, all_messages = split_conversion_to_list(conversation_text)
    if not assistant_texts:
        with jsonlines.open(cache_file, "a") as writer:
            writer.write({"conversation_text": conversation_text, "reward": 0})
        return 0.0

    answers = [answer for answer in map(extract_answer, assistant_texts) if answer]
    if not answers:
        return 0.0

    last_tool_call = extract_tool_call(assistant_texts[-1])
    finalize_coefficient = 1.0
    if last_tool_call is None or last_tool_call.get("name") != "finalize":
        finalize_coefficient = _NOT_FINALIZED_PENALTY

    citation_coefficient = 1.0
    if _has_wrong_citation(all_messages):
        citation_coefficient = _WRONG_CITATION_PENALTY

    llm_as_judge_score = run_llm_as_judge(" ".join(answers))
    return llm_as_judge_score * finalize_coefficient * citation_coefficient
max_retry = 3

# Score returned when no attempt yields a parseable, in-range judge score.
_DEFAULT_JUDGE_SCORE = 60.0
_JUDGE_SCORE_RE = re.compile(r'<SCORE>\s*(\d+(\.\d+)?)\s*</SCORE>')


def _parse_judge_score(response):
    """Return the ``<SCORE>`` value in ``response`` if it lies in [0, 100], else None."""
    if not response:
        # The API may return None content; treat it as an unparseable reply.
        return None
    match = _JUDGE_SCORE_RE.search(response)
    if match is None:
        return None
    score = float(match.group(1))
    return score if 0 <= score <= 100 else None


# TODO: Add user instructions
def run_llm_as_judge(text):
    """Ask the judge model to grade ``text`` and return its score in [0, 100].

    Makes up to ``max_retry`` attempts, each against a randomly chosen entry
    of ``openai_api_bases``. An API error is retried at the same temperature.
    A reply without a valid ``<SCORE>`` in [0, 100] is retried with the
    temperature raised by 0.2. Out-of-range scores previously stopped the
    loop immediately; they now retry like unparseable replies. If every
    attempt fails, the function returns 60.0.
    """
    prompt = LANGUAGE_CRITICAL_EVALUATION_PROMPT.replace("[SECTION]", text)
    messages = [{"role": "system", "content": "You are an assistant."},
                {"role": "user", "content": prompt}]
    temperature = 0.0
    for _ in range(max_retry):
        try:
            # NOTE: mutates the shared module-level client to pick an endpoint.
            client.base_url = random.choice(openai_api_bases)
            response = client.chat.completions.create(
                model="/home/liyishan/Search_R1_WorkSpace/model/Qwen2.5-14B-Instruct",
                messages=messages,
                temperature=temperature,
                max_tokens=1024,
            ).choices[0].message.content
        except Exception:
            # Request failed (network/server error); try again.
            continue
        score = _parse_judge_score(response)
        if score is not None:
            return score
        # Bad reply: sample more diversely on the next attempt.
        temperature += 0.2
    return _DEFAULT_JUDGE_SCORE

# Judge prompt used by run_llm_as_judge. The literal "[SECTION]" placeholder is
# replaced with the text under review. The reply must contain one numeric
# <SCORE>...</SCORE> tag, because an arithmetic expression inside the tag
# cannot be parsed.
LANGUAGE_CRITICAL_EVALUATION_PROMPT = """

[Task]
Rigorously evaluate the quality of an academic survey by scoring six dimensions (each 0-100) and calculating their average as the final score.

[Evaluation Criteria]
Evaluate each dimension on a 0-100 scale based strictly on the highest standards below. The final score is the average of the six dimension scores.

1. **Academic Formality** (100 points):
   - Demonstrates *flawless* academic rigor. Uses precise terminology consistently, avoids colloquial language entirely, and maintains a strictly scholarly tone. Sentence structures are sophisticated and purposefully crafted to enhance analytical depth. **Even a single instance of informal phrasing or imprecise terminology disqualifies a perfect score.**
2. **Clarity & Readability** (100 points):
   - Writing is *exceptionally* clear and concise. Sentences are logically structured, with no ambiguity. Transitions between ideas are seamless, and the argument progresses with precision. **Any unnecessary complexity or minor ambiguity precludes full marks.**
3. **Redundancy** (100 points):
   - **Uniqueness**: Each sentence must contribute unique value and must not be repeated. Repetition is allowed only when it maintains structural coherence, such as uniform terminology or necessary transitional phrases. Restating a key concept definition in a new context to aid the reader counts as a structural requirement.
   - **Efficient argumentation**: Arguments must be efficient, with logically coherent viewpoints and no unnecessary repetition. Even minor repetitions that serve no structural purpose lose points. For example, restating a nearly identical finding within the same paragraph without new insight or perspective loses points.
4. **Critical Analysis** (100 points):
   - Offers a deep, incisive critique of methodologies, results, and underlying assumptions. Provides a clear identification of significant gaps, weaknesses, and areas for improvement. Challenges assumptions with well-supported arguments, offering clear alternatives or improvements.
5. **Original Insights** (100 points):
   - Proposes novel, well-supported interpretations or frameworks based on the reviewed literature. Demonstrates a strong understanding of the subject matter and provides genuinely original contributions that challenge the status quo. Insights are clearly connected to existing research, offering fresh perspectives or unique ways forward.
6. **Future Directions** (100 points):
   - Clearly identifies specific, promising research directions with strong justification. Suggests actionable, concrete ideas for future research that are rooted in the gaps identified within the reviewed literature. Demonstrates foresight in proposing innovative approaches and methodologies.

[Section]
[SECTION]

[Output Format]
Rationale:
<Provide a detailed reason for the score, considering all six dimensions step by step. Highlight specific strengths and weaknesses, such as the consistency of academic tone, the clarity of sentence structure, or the presence of redundancy.>
Final Score:
<SCORE>S</SCORE>
(S is the average of the six dimension scores, written as a single number between 0 and 100 with at most two decimal places, not as an expression. Example: <SCORE>73.25</SCORE>)
"""