import os
import re 
import time
from rich import print
from openai import Client
from openai import OpenAI
from futuremind.eval.utils import extract_question_from_user_block
from futuremind.eval.utils import extract_solution, normalize_answer, subem_check, save_record_to_md_append_only, truncate_after_marker

# Judge-endpoint configuration. Values already present in the environment take
# precedence; the empty-string placeholders are only a fallback, so importing
# this module no longer clobbers real credentials/endpoints set by the caller.
os.environ.setdefault("OPENAI_API_KEY", "")
os.environ.setdefault("OPENAI_API_BASE", "")

client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    # An empty base URL is not a usable endpoint; pass None instead so the
    # OpenAI client applies its own default resolution (per the openai-python
    # docs it then consults OPENAI_BASE_URL / the public API URL).
    base_url=os.environ.get("OPENAI_API_BASE") or None,
)

def generate(messages, model_name, max_retries=None, retry_delay=1.0):
    """Request one chat completion from the module-level ``client``.

    Any exception raised while creating the completion or reading its content
    is printed and the request is retried after ``retry_delay`` seconds.

    Args:
        messages: chat messages forwarded to ``client.chat.completions.create``.
        model_name: model identifier sent as ``model``.
        max_retries: maximum number of retries after the first failed attempt.
            ``None`` (the default) keeps the historical behaviour of retrying
            forever. Set it to bound the loop, so that persistent errors
            (e.g. bad credentials) surface instead of hanging the caller.
        retry_delay: seconds to sleep between attempts.

    Returns:
        ``response.choices[0].message.content`` of the first successful
        response, returned as-is (not checked for None here).

    Raises:
        Exception: the last error, re-raised once ``max_retries`` retries have
        been used up. Never raised when ``max_retries`` is None.
    """
    retry_cnt = 0
    while True:
        try:
            if retry_cnt:
                print(f"Retry: {retry_cnt}")
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0,
            )
            # Kept inside the try so a malformed response is retried too.
            return response.choices[0].message.content
        except Exception as e:
            print(f"Error: {e}")
            if max_retries is not None and retry_cnt >= max_retries:
                raise
            retry_cnt += 1
            time.sleep(retry_delay)
        
def compute_score_em(solution_str, ground_truth):
    """Substring exact-match reward computed on the final assistant turn.

    The last ``<|im_start|>assistant ... <|im_end|>`` span of ``solution_str``
    is located. Its answer is pulled out with ``extract_solution`` and compared
    against ``ground_truth`` via ``subem_check``.

    Args:
        solution_str: full decoded rollout text, or None.
        ground_truth: reference answer(s) forwarded to ``subem_check``, or None.

    Returns:
        float: ``float(subem_check(...))`` on success. 0.0 when either input is
        None, no assistant turn is found, no answer can be extracted, or any
        exception is raised (the exception is printed).
    """
    missing_input = solution_str is None or ground_truth is None
    if missing_input:
        return 0.0

    try:
        turns = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution_str, re.DOTALL)
        if not turns:
            return 0.0
        extracted = extract_solution(turns[-1])
        return 0.0 if extracted is None else float(subem_check(extracted, ground_truth))
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_em: {e}")
        return 0.0


def compute_score_format(solution_str):
    """Score how closely a rollout follows the expected assistant/tool layout.

    Assistant turns are the spans between ``<|im_start|>assistant`` (plus a
    newline) and ``<|im_end|>``.

    * Every assistant turn except the last earns 0.5 when it contains exactly
      one ``<think>``/``</think>`` pair and exactly one
      ``<tool_call>``/``</tool_call>`` pair, and consists of nothing but the
      think block, optional whitespace, then the tool-call block.
    * The last assistant turn earns 0.5 when it starts with a
      ``<think>...</think>`` block and ends with ``<answer>...</answer>``.

    The sum is not capped here, so several well-formed tool-call turns can push
    the score above 1.0.

    Args:
        solution_str: full decoded rollout text, or None.

    Returns:
        float: the summed format reward. 0.0 for None input, when no assistant
        turn is found, or when an exception is raised (the exception is printed).
    """
    if solution_str is None:
        return 0.0

    try:
        turns = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', solution_str, re.DOTALL)
        if not turns:
            return 0.0

        *tool_turns, final_turn = turns
        score = 0.0

        for turn in tool_turns:
            # Each tag must appear exactly once before the anchored regex runs.
            tags_once = all(
                turn.count(tag) == 1
                for tag in ('<think>', '</think>', '<tool_call>', '</tool_call>')
            )
            if tags_once and re.search(r'^<think>(.*?)</think>(\s*)<tool_call>(.*?)</tool_call>$', turn, re.DOTALL):
                score += 0.5

        if re.search(r'^<think>(.*?)</think>(.*?)<answer>(.*?)</answer>$', final_turn, re.DOTALL):
            score += 0.5

        return score
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format: {e}")
        return 0.0


def compute_score_answer(solution_str, ground_truth):
    """LLM-judged answer reward for the final assistant turn.

    Steps:

    1. Recover the question from the user turn with
       ``extract_question_from_user_block``, then pass it through
       ``truncate_after_marker(..., "<<<")`` (presumably drops text after that
       marker; confirm against the helper).
    2. Extract the predicted answer from the last assistant turn with
       ``extract_solution``.
    3. Ask the judge model (``generate(..., 'default')``) whether it matches
       ``ground_truth``.
    4. Append every judged sample to a markdown log via
       ``save_record_to_md_append_only``.

    Only the answer reward is returned. Format reward is computed separately by
    ``compute_score_format``.

    Args:
        solution_str: full decoded rollout text, or None.
        ground_truth: reference answer(s), stringified into the judge prompt.

    Returns:
        float: 1.0 if the judge replies "true" (case-insensitive, surrounding
        whitespace and a trailing period ignored), otherwise 0.0. Also 0.0 for
        None input, no assistant turn, no extractable answer, or any exception
        (the exception is printed).
    """
    if solution_str is None:
        return 0.0

    try:
        whole_solution_str = solution_str
        question_contain_otherstr = extract_question_from_user_block(whole_solution_str)
        question = truncate_after_marker(question_contain_otherstr, "<<<")
        assistant_blocks = re.findall(r'<\|im_start\|>assistant\n(.*?)<\|im_end\|>', whole_solution_str, re.DOTALL)
        if not assistant_blocks:
            return 0.0

        answer = extract_solution(assistant_blocks[-1])
        answer_reward = 0.0
        if answer is not None:
            prompt = '''Given a Question and its Golden Answer, verify whether the Predicted Answer is correct. The prediction is correct if it fully aligns with the meaning and key information of the Golden Answer. Respond with True if the prediction is correct and False otherwise.
            Golden Answer may have multiple options, and matching any one of them is considered correct.

            Question: {question}
            Golden Answer: {reference}
            Predicted Answer: {prediction}
            '''
            record = {
                "question": str(question),
                "reference": str(ground_truth),
                "prediction": str(answer),
            }

            judge_input = prompt.format(question=str(question), reference=str(ground_truth), prediction=str(answer))
            model_output = generate([{'role': 'user', 'content': judge_input}], 'default')
            # Raw judge output is logged unchanged.
            record["gpt4o_eval"] = model_output

            # The judge is told to answer True/False, but chat models often add
            # whitespace, a trailing period or different casing. Normalise
            # before comparing so such replies are not scored as incorrect.
            verdict = (model_output or "").strip().rstrip(".").strip().lower()
            answer_reward = 1.0 if verdict == "true" else 0.0
            record["score"] = int(answer_reward)
            record["whole_solution"] = whole_solution_str

            save_record_to_md_append_only(question, ground_truth, answer, record, answer_reward)

    except Exception as e:
        print(f"[DEBUG] Error in compute_score_answer: {e}")
        return 0.0

    return answer_reward


def compute_score_format_answer(solution_str, ground_truth):
    """Combine the format reward and the LLM-judged answer reward.

    The format reward from ``compute_score_format`` is capped at 1.0. The
    answer reward from ``compute_score_answer`` is only added to the total when
    the capped format reward is at least 0.5.

    Args:
        solution_str: full decoded rollout text, or None.
        ground_truth: reference answer(s), or None.

    Returns:
        tuple[float, float]: ``(total_reward, answer_reward)``.

        * Normally the total is ``-1.0 + format_reward (+ answer_reward)``.
        * On an unexpected exception the total is ``-1.0`` and the answer
          reward is whatever was computed before the failure (0.0 if none).
        * For None input the result is ``(0.0, 0.0)``.
    """
    if solution_str is None or ground_truth is None:
        # Keep the historical 0.0 score, but return the same (total, answer)
        # shape as every other path so callers can always unpack the result.
        # NOTE(review): a missing rollout scores higher than a badly formatted
        # one (-1.0). Confirm that is intended.
        return 0.0, 0.0

    # Bound before the try so the except branch can always reference it.
    answer_reward = 0.0
    try:
        format_reward = min(compute_score_format(solution_str), 1.0)
        answer_reward = compute_score_answer(solution_str, ground_truth)

        if format_reward >= 0.5:
            return -1.0 + format_reward + answer_reward, answer_reward
        return -1.0 + format_reward, answer_reward
    except Exception as e:
        print(f"[DEBUG] Error in compute_score_format_answer: {e}")
        return -1.0, answer_reward