import os
import re

from openai import AsyncOpenAI

from slime.rollout.context_rollout.reward_utils.llm_prompt.hle_eval_prompt import evaluation_prompt as hle_evaluation_prompt
from slime.rollout.context_rollout.reward_utils.llm_prompt.gaia_eval_prompt import JUDGE_PROMPT_GAIA
from slime.rollout.context_rollout.reward_utils.llm_prompt.bc_eval_prompt import JUDGE_PROMPT_BC_en
from slime.rollout.context_rollout.reward_utils.llm_prompt.c3po_eval_prompt import c3po_evaluation_prompt
from slime.rollout.context_rollout.reward_utils.utils import extract_answers

# Judge endpoint settings, read once from the environment when this module is imported.
URL = os.environ.get('EVAL_LLM_URL')
# NOTE(review): if EVAL_LLM_KEY is unset this is None, and the openai SDK then falls back to
# OPENAI_API_KEY (raising at import time if that is also missing) — confirm this is intended.
KEY = os.environ.get('EVAL_LLM_KEY')
LLM_NAME = os.environ.get('EVAL_LLM_NAME', 'qwen2.5-72b-instruct')

# One shared async client for every judge request. The SDK performs its own transport-level
# retries (max_retries) in addition to any retry loop the caller runs; timeout is in seconds.
client = AsyncOpenAI(
    base_url=URL,
    api_key=KEY,
    timeout=1000,
    max_retries=3,
)

# Number of judge attempts per sample, and the two reward values a judged sample can receive.
MAX_RETRY = 5
POS_REWARD, NEG_REWARD = 1.0, -1.0

# A bare "A"/"B" verdict at the start of a line, optionally wrapped in markup such as "**A**",
# "(B)" or "A:", but NOT a word that merely begins with A/B ("Answer", "Based", "A careful...").
_BC_LEADING_LETTER_RE = re.compile(r"^[\W_]*([AB])(?:[^\w\s]|\s*$)")
# A standalone "A"/"B" that ends the reply, e.g. "... so the answer is A."
_BC_TRAILING_LETTER_RE = re.compile(r"\b([AB])[\W_]*\Z")


async def _ask_judge(prompt):
    """Send ``prompt`` to the judge model once and return its reply text ("" if it has no content)."""
    response = await client.chat.completions.create(
        model=LLM_NAME,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=1024,
        top_p=1,
        # Server-specific extension forwarded to the chat template to switch thinking mode off.
        extra_body={"chat_template_kwargs": {"enable_thinking": False}},
    )
    # ``content`` can be None; normalise to "" so parsers see an unparsable reply, not a TypeError.
    return response.choices[0].message.content or ""


def _parse_hle_verdict(text):
    """Read the ``correct:`` field of an HLE judge reply: True/False, or None if the field is absent."""
    # Use the LAST "correct:" — the judge's reasoning may contain "correct:" or "incorrect:"
    # before the actual verdict field.
    _, sep, tail = text.lower().rpartition('correct:')
    if not sep:
        return None
    return 'yes' in tail.split('confidence:')[0]


def _parse_bc_verdict(text):
    """Parse a BC judge reply: True/False, or None if no verdict can be found.

    Explicit "[CORRECT]"/"[INCORRECT]" tags win; otherwise a standalone A (correct) / B (incorrect)
    letter at the start of the first line or at the very end of the reply is accepted.
    NOTE(review): the letter convention is inferred from the original check; confirm it against
    JUDGE_PROMPT_BC_en.
    """
    if "[CORRECT]" in text:
        return True
    if "[INCORRECT]" in text:
        return False
    stripped = text.strip()
    first_line = stripped.splitlines()[0] if stripped else ""
    match = _BC_LEADING_LETTER_RE.match(first_line) or _BC_TRAILING_LETTER_RE.search(stripped)
    if match is None:
        return None
    return match.group(1) == "A"


def _parse_gaia_verdict(text):
    """Parse a GAIA judge reply: False if it mentions "incorrect", True if "correct", else None."""
    # "incorrect" must be tested first because it contains "correct".
    # NOTE(review): plain substring tests, so e.g. "not correct" counts as correct.
    lowered = text.lower()
    if "incorrect" in lowered:
        return False
    if "correct" in lowered:
        return True
    return None


def _parse_c3po_verdict(text):
    """Parse a C3PO (``train``) judge reply: True if it mentions "true", False if "false", else None."""
    # NOTE(review): plain substring tests, so e.g. "untrue" counts as true.
    lowered = text.lower()
    if "true" in lowered:
        return True
    if "false" in lowered:
        return False
    return None


async def get_llm_reward(args, sample, **kwargs) -> float:
    """
    Get reward from LLM-as-Judge

    The judge prompt and verdict parser are chosen by ``sample.metadata['data_source']``
    (case-insensitive: ``hle``, ``bc``, ``gaia``, ``train``). The candidate answer is pulled
    out of ``sample.response`` with ``extract_answers``; if none is found the sample scores
    NEG_REWARD without calling the judge.

    Args:
        args: the arguments (accepted but unused here)
        sample: the sample to evaluate; reads ``metadata['data_source']``, ``prompt``,
            ``label`` and ``response``
        **kwargs: ignored

    Returns:
        POS_REWARD if the judge says the answer is correct, otherwise NEG_REWARD (including
        when none of the MAX_RETRY replies contains a parsable verdict).

    Raises:
        NotImplementedError: the data source is not supported.
        Exception: the judge call failed on the last of the MAX_RETRY attempts.
    """
    data_source = sample.metadata['data_source']
    question = sample.prompt
    # A single-element label list is unwrapped to its only element.
    answer = sample.label[0] if isinstance(sample.label, list) and len(sample.label) == 1 else sample.label
    pred_str = extract_answers(sample.response)
    if pred_str is None:
        return NEG_REWARD

    # Pick prompt + parser; ``log_tag`` keeps the per-source wording of the parse-failure log.
    source = data_source.lower()
    if source == 'hle':
        prompt = hle_evaluation_prompt.format(question=question, response=pred_str, correct_answer=str(answer))
        parse_verdict, log_tag = _parse_hle_verdict, ""
    elif source == 'bc':
        prompt = JUDGE_PROMPT_BC_en.format(question=question, correct_answer=str(answer), response=pred_str)
        parse_verdict, log_tag = _parse_bc_verdict, "BC "
    elif source == 'gaia':
        prompt = JUDGE_PROMPT_GAIA.format(question=question, correct_answer=str(answer), response=pred_str)
        parse_verdict, log_tag = _parse_gaia_verdict, "GAIA "
    elif source == 'train':
        prompt = c3po_evaluation_prompt.format(question=question, correct_answer=str(answer), response=pred_str)
        parse_verdict, log_tag = _parse_c3po_verdict, "C3PO "
    else:
        raise NotImplementedError(f"{data_source=} not supported")

    for attempt in range(MAX_RETRY):
        # Only the network call is guarded: parser bugs must not masquerade as server errors.
        try:
            response_text = await _ask_judge(prompt)
        except Exception as e:
            print(f"Attempt {attempt + 1}/{MAX_RETRY}: LLM Judge server call failed with error: {str(e)}")
            if attempt == MAX_RETRY - 1:
                raise Exception(f"LLM Server Error after all retries: {str(e)}") from e
            continue

        verdict = parse_verdict(response_text)
        if verdict is None:
            print(
                f"Attempt {attempt + 1}/{MAX_RETRY}: Parsing failed for {log_tag}response: '{response_text}'. Retrying..."
            )
            continue
        return POS_REWARD if verdict else NEG_REWARD

    # No attempt produced a usable verdict: fall back to the negative reward.
    return NEG_REWARD