
import json
import os
import random
import re
from concurrent.futures import ThreadPoolExecutor
from time import sleep
from typing import Any, Dict, List, Optional, Union

import requests

# Address of the reward-model server, without a scheme; call_reward_model_api
# wraps it as "http://<address>/v1/chat/completions".
REWARD_SERVER = os.environ.get("REWARD_MODEL_SERVER", "localhost:30000")
# Value sent in the "model" field of every reward-model request.
MODEL_NAME = os.environ.get("REWARD_MODEL_NAME", "default")
# Default for compute_score's "require_thinking" keyword: when true, responses
# that lack both <think> and </think> receive a zero score.
REQUIRE_THINKING = os.environ.get("REQUIRE_THINKING", "false").lower() == "true"
# Total number of HTTP attempts per reward-model call (not "extra" retries).
MAX_RETRIES = 2
# Total number of reward-model calls made while trying to obtain a gradable
# (parseable, correctly sized) boolean list for one response.
MAX_PARSE_RETRIES = 2
# Back-off base in seconds; the sleep after failed attempt k is BASE_DELAY * 2**k.
BASE_DELAY = 1
# Thread-pool size used by compute_score_batch.
MAX_WORKERS = 32

# Prompt sent to the reward model. The "<<conversation>>" and "<<rubric_item>>"
# placeholders are substituted with str.replace in compute_openend_score; the
# reply is expected to contain a JSON list of booleans, one per rubric item.
SCORE_TEMPLATE = """Score the assistant's response against each rubric item.

## Conversation
<Conversation>
<<conversation>>
</Conversation>

## Rubrics
<Rubric_items>
<<rubric_item>>
</Rubric_items>

## Output
Return a JSON list of booleans, one for each rubric in order.
- true: criteria met
- false: criteria not met
For negative criteria (bad behavior): true if response shows bad behavior, false if avoided.""".strip()


def has_thinking_tags(text: str) -> bool:
    """Report whether *text* contains both ``<think>`` and ``</think>``.

    Only the presence of the two literal substrings is checked; their order
    and nesting are not validated. Empty or ``None`` input yields ``False``.
    """
    required_tags = ('<think>', '</think>')
    return bool(text) and all(tag in text for tag in required_tags)


def match_choice(text: str, options: Optional[Dict[str, str]] = None) -> Optional[str]:
    """Extract the final multiple-choice answer from a model response.

    Everything up to and including the last ``## Final Response`` heading and
    the last ``</think>`` tag is discarded; the remaining text is then searched
    for ``<answer>LETTERS</answer>`` tags and the last one wins.

    Args:
        text: Raw model output (may be empty or ``None``).
        options: Accepted for interface compatibility; not consulted here.

    Returns:
        The upper-cased letters of the last ``<answer>`` tag, or ``None`` when
        the input is empty or no such tag is present.
    """
    if not text:
        return None

    # rsplit(..., 1)[-1] keeps the part after the last marker, or the whole
    # string when the marker does not occur.
    tail = text.rsplit('## Final Response\n\n', 1)[-1]
    if '</think>' in tail:
        tail = tail.rsplit('</think>', 1)[-1].strip()

    last_letters = None
    for found in re.finditer(r'<answer>([A-Za-z]+)</answer>', tail):
        last_letters = found.group(1)

    if last_letters is None:
        return None
    return last_letters.upper()


def compute_mc_score(
    solution_str: str,
    ground_truth: Dict[str, Any],
    do_print: bool = False
) -> Dict[str, Any]:
    """Score a multiple-choice response by exact match on the answer letter.

    Args:
        solution_str: Raw model output; the answer is extracted by
            ``match_choice``.
        ground_truth: Mapping that may contain ``answer_idx`` (the correct
            option label) and ``options`` (label -> text). When ``options`` is
            missing or empty, labels ``A``-``J`` are treated as valid.
        do_print: Currently unused; kept for interface compatibility.

    Returns:
        A result dict with ``score``/``acc`` of 1.0 only when the extracted
        answer is a valid option label equal to the correct answer, else 0.0.
        ``success`` is always ``True`` and ``grading_str`` records the
        extracted and correct labels.
    """
    # A missing/None/non-string answer_idx must not crash the scorer: one bad
    # row would otherwise abort a whole batch.
    raw_answer = ground_truth.get('answer_idx')
    correct_answer = '' if raw_answer is None else str(raw_answer).strip().upper()
    options = ground_truth.get('options') or {}

    extracted_answer = match_choice(solution_str, options)

    if options:
        valid_options = {str(label).upper() for label in options}
    else:
        valid_options = set('ABCDEFGHIJ')

    is_correct = (
        extracted_answer is not None
        and extracted_answer in valid_options
        and extracted_answer == correct_answer
    )
    score = 1.0 if is_correct else 0.0

    return {
        "score": score,
        "acc": score,
        "success": True,
        "error": "",
        "grading_str": f"extracted={extracted_answer}, correct={correct_answer}",
    }


def remove_thinking_part(text: str) -> str:
    """Strip reasoning preamble from a model response.

    Keeps only what follows the last ``## Final Response`` heading (if any),
    then only what follows the last ``</think>`` tag (if any). Surrounding
    whitespace is trimmed only in the ``</think>`` case. Falsy input is
    returned unchanged.
    """
    if not text:
        return text

    # rpartition()[2] is the text after the last separator, or the whole
    # string when the separator is absent.
    remainder = text.rpartition('## Final Response\n\n')[2]
    _, closing_tag, after_think = remainder.rpartition('</think>')
    if closing_tag:
        return after_think.strip()
    return remainder

def build_conversation_str(prompt: list, response_text: str) -> str:
    """Render the prompt messages plus the assistant reply as ``role: content`` lines.

    Each message in *prompt* must be a mapping with ``role`` and ``content``
    keys; *response_text* is appended as a final ``assistant`` turn.
    """
    turns = [*prompt, {"role": "assistant", "content": response_text}]
    rendered = []
    for turn in turns:
        rendered.append(f"{turn['role']}: {turn['content']}")
    return "\n".join(rendered)


def build_rubrics_str(rubric_items: list) -> str:
    """Format rubric items as a numbered list, one per line.

    Each line reads ``N. (<signed points>pts) <criterion>``; a missing
    ``points`` defaults to 0 and a missing ``criterion`` to an empty string.
    """
    return "\n".join(
        f"{number}. ({item.get('points', 0):+g}pts) {item.get('criterion', '')}"
        for number, item in enumerate(rubric_items, start=1)
    )


def parse_bool_list(text: str, expected_len: int) -> Optional[List[bool]]:
    """Parse the reward model's reply into a list of booleans.

    Text up to the last ``</think>`` tag is ignored. The first ``[...]`` span
    is tried as a JSON list (case-insensitively, so ``True``/``FALSE`` are
    accepted). If that span is a list of booleans it is authoritative: it is
    returned when its length matches, otherwise the parse fails. When no such
    list is found, every standalone ``true``/``false`` word in the text is
    collected instead.

    Args:
        text: Raw reward-model output (may be empty or ``None``).
        expected_len: Number of booleans required.

    Returns:
        A list of exactly ``expected_len`` booleans, or ``None`` when the
        text is empty or does not yield that many values.
    """
    if not text:
        return None

    text = text.strip()
    if '</think>' in text:
        text = text.split('</think>')[-1].strip()

    match = re.search(r'\[.*?\]', text, re.DOTALL)
    if match:
        try:
            result = json.loads(match.group(0).lower())
        except (ValueError, RecursionError):
            # Not valid JSON (json.JSONDecodeError is a ValueError) or
            # pathologically nested; fall through to the word-based scan.
            result = None
        if isinstance(result, list) and all(isinstance(x, bool) for x in result):
            return result if len(result) == expected_len else None

    bools = re.findall(r'\b(true|false)\b', text.lower())
    if len(bools) != expected_len:
        return None
    return [b == 'true' for b in bools]


def calculate_rubric_score(rubric_items: list, grading_list: list) -> Optional[float]:
    """Compute the normalised rubric score for one graded response.

    The score is the sum of points of all rubric items graded ``True``
    (negative-point items subtract) divided by the sum of positive points,
    floored at 0.0.

    Args:
        rubric_items: Rubric dicts; a missing or ``None`` ``points`` counts
            as 0, matching how ``build_rubrics_str`` renders such items.
        grading_list: One grade per rubric item; only values that are
            exactly ``True`` count as met.

    Returns:
        A float in ``[0.0, 1.0]``, or ``None`` when the grading list is empty,
        its length differs from the rubric, or no item has positive points.
    """
    if not grading_list or len(grading_list) != len(rubric_items):
        return None

    points = [r.get('points') or 0 for r in rubric_items]

    total_possible = sum(p for p in points if p > 0)
    if total_possible == 0:
        return None

    achieved = sum(p for p, g in zip(points, grading_list) if g is True)
    # Floor with a float literal so the declared return type holds even when
    # penalties push the raw score below zero.
    return max(0.0, achieved / total_possible)


def call_reward_model_api(prompt: str, server: str = None, temperature: float = 0.0) -> Optional[str]:
    """Send *prompt* to the reward model and return the reply text.

    A single-turn chat-completion request is POSTed to
    ``http://<server>/v1/chat/completions``. Any failure (transport error,
    non-2xx status, unexpected response shape) triggers a retry with
    exponential back-off, up to ``MAX_RETRIES`` attempts in total.

    Args:
        prompt: User message content.
        server: ``host:port`` of the server; falls back to ``REWARD_SERVER``.
        temperature: Sampling temperature forwarded to the server.

    Returns:
        The content of the first choice's message, or ``None`` once all
        attempts have failed.
    """
    endpoint = f"http://{server or REWARD_SERVER}/v1/chat/completions"
    request_headers = {"Content-Type": "application/json"}
    # NOTE(review): "extra_body" is sent here as a literal top-level JSON key.
    # Some OpenAI-compatible servers may expect "chat_template_kwargs" directly
    # at the top level of the body instead - confirm against the deployment.
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 256,
        "temperature": temperature,
        "extra_body": {
            "chat_template_kwargs": {"enable_thinking": False}
        }
    }

    final_attempt = MAX_RETRIES - 1
    for attempt in range(MAX_RETRIES):
        try:
            reply = requests.post(endpoint, headers=request_headers, json=payload, timeout=60)
            reply.raise_for_status()
            return reply.json()["choices"][0]["message"]["content"]
        except Exception:
            # Best-effort call: the caller treats None as "no grading".
            if attempt >= final_attempt:
                return None
            sleep(BASE_DELAY * (2 ** attempt))
    return None


def compute_openend_score(
    solution_str: str,
    ground_truth: Dict[str, Any],
    server: str = None,
    do_print: bool = False
) -> Dict[str, Any]:
    """Score an open-ended response against rubric items via the reward model.

    The response (thinking removed, truncated to 10000 characters) is combined
    with the prompt and rubrics into ``SCORE_TEMPLATE`` and sent to the reward
    model. The call is repeated, at increasing temperature, until the reply
    parses into one boolean per rubric item or ``MAX_PARSE_RETRIES`` calls
    have been made.

    Args:
        solution_str: Raw model output; ``None`` is treated as an empty
            response.
        ground_truth: Mapping with ``rubrics`` (list of rubric dicts) and
            ``prompt`` (list of chat messages).
        server: ``host:port`` of the reward-model server; falls back to
            ``REWARD_SERVER``.
        do_print: Currently unused; kept for interface compatibility.

    Returns:
        A result dict. ``success`` is ``False`` with error ``"No rubrics"``
        when no rubrics are given, or ``"ParseFailed"`` when no usable grading
        was obtained (including when every API call failed). Otherwise
        ``score``/``acc`` hold the rubric score, 0.0 if it cannot be computed.
    """
    rubrics = ground_truth.get("rubrics", []) or []
    prompt = ground_truth.get("prompt", []) or []

    if not rubrics:
        return {"score": 0.0, "acc": 0.0, "success": False, "error": "No rubrics", "grading_str": "[]"}

    server = server or REWARD_SERVER

    # remove_thinking_part passes None through unchanged; coerce to "" so a
    # missing response is graded as empty rather than crashing on len()/slice.
    response_text = (remove_thinking_part(solution_str) or "")[:10000]

    conv_str = build_conversation_str(prompt, response_text)
    rubrics_str = build_rubrics_str(rubrics)
    input_text = SCORE_TEMPLATE.replace("<<conversation>>", conv_str).replace("<<rubric_item>>", rubrics_str)
    expected_len = len(rubrics)

    grading_list = None
    # First call is greedy; later calls sample with more randomness so that a
    # malformed reply is not simply reproduced.
    temperatures = [0.0, 0.3, 0.5]

    for retry in range(MAX_PARSE_RETRIES):
        temp = temperatures[retry] if retry < len(temperatures) else 0.7
        rm_response = call_reward_model_api(input_text, server, temperature=temp)
        if not rm_response:
            continue
        grading_list = parse_bool_list(rm_response, expected_len)
        if grading_list is not None:
            break

    if grading_list is None:
        return {
            "score": 0.0,
            "acc": 0.0,
            "success": False,
            "error": "ParseFailed",
            "grading_str": str([False] * expected_len),
        }

    # calculate_rubric_score returns None when no item has positive points.
    score = float(calculate_rubric_score(rubrics, grading_list) or 0.0)

    return {
        "score": score,
        "acc": score,
        "success": True,
        "error": "",
        "grading_str": str(grading_list),
    }


def compute_score(
    data_source: str,
    solution_str: str,
    ground_truth: Union[str, Dict[str, Any]],
    extra_info: Optional[Dict[str, Any]] = None,
    **kwargs
) -> Dict[str, Any]:
    """Score one response, dispatching on the ground truth's ``type``.

    Args:
        data_source: Not used by this function.
        solution_str: Raw model output.
        ground_truth: Dict describing the item. ``type == 'openend'`` selects
            rubric grading via the reward model; any other value (default
            ``'mc'``) selects multiple-choice matching. Non-dict values are
            rejected.
        extra_info: Not used by this function.
        **kwargs: ``require_thinking`` overrides ``REQUIRE_THINKING``;
            ``reward_router_address`` overrides ``REWARD_SERVER`` for
            open-ended items.

    Returns:
        A dict with keys ``score``, ``acc``, ``success``, ``error`` and
        ``grading_str``.
    """
    # Sampled 1-in-50 flag forwarded to the scorers (neither reads it today).
    do_print = random.randint(1, 50) == 1

    require_thinking = kwargs.get("require_thinking", REQUIRE_THINKING)

    if require_thinking and not has_thinking_tags(solution_str):
        return {
            "score": 0.0,
            "acc": 0.0,
            "success": True,
            "error": "no_thinking_tags",
            "grading_str": "missing <think></think>",
        }

    if not isinstance(ground_truth, dict):
        return {
            "score": 0.0,
            "acc": 0.0,
            "success": False,
            "error": "Invalid ground_truth format",
            # Keep the key set identical to every other result dict so
            # consumers can index results uniformly.
            "grading_str": "",
        }

    question_type = ground_truth.get('type', 'mc')

    if question_type == 'openend':
        server = kwargs.get("reward_router_address") or REWARD_SERVER
        return compute_openend_score(solution_str, ground_truth, server, do_print)

    return compute_mc_score(solution_str, ground_truth, do_print)


def compute_score_batch(
    data_sources: List[str],
    solution_strs: List[str],
    ground_truths: List[Union[str, Dict[str, Any]]],
    extra_infos: List[Optional[Dict[str, Any]]],
    **kwargs
) -> List[Dict[str, Any]]:
    """Score a batch of responses concurrently on a thread pool.

    The four input lists are consumed in lockstep and must have equal length
    (``zip(..., strict=True)`` raises ``ValueError`` otherwise). ``kwargs`` are
    forwarded unchanged to every ``compute_score`` call.

    Returns:
        One result dict per input, in input order. An exception raised by any
        individual ``compute_score`` call propagates to the caller.
    """
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        rows = zip(data_sources, solution_strs, ground_truths, extra_infos, strict=True)
        pending = [pool.submit(compute_score, *row, **kwargs) for row in rows]
        # Collect in submission order so results line up with the inputs.
        return [task.result() for task in pending]
