import logging
import re

from researcher.rewards.utils import AbstractAgent

# One goal/learnings/reflect triple: the three tagged sections must appear in this
# order, with at least one ``<|im_end|>`` marker between goal and learnings and
# between learnings and reflect. DOTALL lets ``.`` span newlines inside a turn.
_STEP_PATTERN = re.compile(
    r"<goal>(.*?)</goal>"
    r".*?<\|im_end\|>.*?"
    r"<learnings>(.*?)</learnings>"
    r".*?<\|im_end\|>.*?"
    r"<reflect>(.*?)</reflect>",
    re.DOTALL,
)


def calculate_value(resp_str: str) -> float:
    """Score the last goal/learnings/reflect triple found in ``resp_str``.

    The final non-overlapping match of the goal/learnings/reflect pattern is
    handed to ``call_llm``, and its score is returned unchanged.

    Args:
        resp_str: Response text, possibly containing several triples.

    Returns:
        The score returned by ``call_llm`` for the last triple. Returns 0.0
        when ``resp_str`` is empty/None, when no complete triple is present, or
        when the judge call raises (best-effort: a failing judge must not
        abort reward computation).
    """
    if not resp_str:
        return 0.0

    matches = _STEP_PATTERN.findall(resp_str)
    if not matches:
        return 0.0

    # findall with three groups yields 3-tuples; only the most recent step counts.
    goal_content, learnings_content, reflect_content = matches[-1]
    try:
        return call_llm(goal_content, learnings_content, reflect_content)
    except Exception:
        # Narrower than the previous bare ``except:`` so KeyboardInterrupt /
        # SystemExit still propagate; logged so judge failures are visible.
        logging.getLogger(__name__).warning(
            "calculate_value: judge call failed; scoring step as 0.0",
            exc_info=True,
        )
        return 0.0

def call_llm(g_content: str, l_content: str, r_content: str, model_name: str='gpt-4o-mini') -> float:
    """Check whether a judge LLM's verdict agrees with the policy's reflection.

    The judge is given the search goal and the gathered materials and is told
    (via the system prompt) to answer True or False. Its reply is compared,
    case-insensitively and ignoring surrounding whitespace, with ``r_content``.

    Args:
        g_content: Text captured from the ``<goal>`` section.
        l_content: Text captured from the ``<learnings>`` section.
        r_content: Text captured from the ``<reflect>`` section; the
            policy's own verdict.
        model_name: Model identifier passed to ``AbstractAgent``.

    Returns:
        1.0 if the judge's reply equals ``r_content`` after normalization,
        else 0.0.
    """
    llm = AbstractAgent(model_name)
    llm.system_prompt = '''
        Evaluate whether the retrieved materials comprehensively satisfy the information requirements and contextual relevance of the specified search goal. If the materials demonstrate complete semantic alignment with both the explicit intent and implicit contextual nuances of the search goal, return True; otherwise, return False.
    '''
    prompt = f"""Search goal:{g_content}\nsearch materials:{l_content}"""
    # NOTE(review): assumes AbstractAgent.response returns a str — TODO confirm.
    resp = llm.response(prompt)
    # Strip both sides before comparing. Text captured between <reflect> tags
    # may include newlines or padding, and a reply may end with a newline.
    # Without stripping, an agreeing verdict would compare unequal.
    value = resp.strip().lower() == r_content.strip().lower()
    return float(value)


def calculate_step_reward(resp_str, tokenizer):
    """Compute a (token position, value) pair for every completed reflect step.

    For each ``</reflect>`` closing tag in ``resp_str``, the response prefix
    ending right after that tag is:
      * tokenized without special tokens; its token count is recorded, and
      * scored with ``calculate_value``.

    Args:
        resp_str: Full response text.
        tokenizer: Callable invoked as ``tokenizer(text, add_special_tokens=False,
            return_tensors="pt")``; assumed to be a Hugging Face-style tokenizer
            whose output exposes an ``input_ids`` tensor — TODO confirm.

    Returns:
        ``(step_inds, values)``: parallel lists with one entry per closing tag.
        Each ``step_inds`` entry is the token length of the prefix (one past its
        last token index; confirm the off-by-one convention with the caller).
    """
    close_tag = '</reflect>'
    step_inds, values = [], []

    hit = resp_str.find(close_tag)
    while hit != -1:
        end = hit + len(close_tag)
        prefix = resp_str[:end]
        # The whole prefix is re-tokenized each time, rather than only the new
        # chunk, so token counts stay consistent with tokenizing the full text.
        encoded = tokenizer(prefix, add_special_tokens=False, return_tensors="pt")
        token_ids = encoded.input_ids[0].tolist()
        step_inds.append(len(token_ids))
        values.append(calculate_value(prefix))
        hit = resp_str.find(close_tag, end)

    return step_inds, values

