import re
import hashlib
from typing import Dict, Tuple, Optional
import subprocess
import json
import time
import random
import signal
import sys
import os
import concurrent.futures
import numpy as np

from custom_verl.rationalecode import code_judge

# Shared judge instance used by compute_score (check_ce / run). timeout=3 is
# presumably seconds per run or per test case, and early_exit=False presumably
# keeps evaluating remaining test cases after a failure so the score reflects
# the full pass fraction — both to be confirmed against code_judge.
oj = code_judge.OnlineJudge(timeout=3, early_exit=False)


def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Pull the content of the last ``<answer>...</answer>`` span out of a response.

    Args:
        solution_str: Raw model response.

    Returns:
        A ``(answer, text)`` pair. ``answer`` is the whitespace-stripped content
        of the final complete answer tag pair, or ``None`` when no complete pair
        exists (an error line is printed in that case). ``text`` is
        ``solution_str`` returned unchanged.
    """
    # With a single capture group, findall yields the group-1 text of every
    # non-overlapping match; DOTALL lets the answer span multiple lines.
    bodies = re.findall(r"<answer>(.*?)</answer>", solution_str, re.DOTALL)
    if bodies:
        return bodies[-1].strip(), solution_str

    print("[Error] No valid answer tags found")
    return None, solution_str


def validate_response_structure(processed_str: str) -> Tuple[bool, list]:
    """Check that a response follows the ``<think>...</think><answer>...</answer>`` layout.

    The checks are:

    * each of ``<think>``, ``</think>``, ``<answer>`` and ``</answer>`` appears
      exactly once;
    * their first occurrences are in that order;
    * if the order is correct, the response (ignoring surrounding whitespace)
      starts with ``<think>``.

    Args:
        processed_str: Processed response string from the model.

    Returns:
        A ``(passed, debug_lines)`` tuple. ``passed`` is True only when every
        check succeeds. ``debug_lines`` is a list of human-readable log lines
        describing each check and any failure.
    """
    debug_str = []
    debug_str.append("\n[Structure Validation]")
    validation_passed = True

    # Check required tags: tag name -> (literal tag, expected occurrence count)
    tags = {
        "think_start": ("<think>", 1),
        "think_end": ("</think>", 1),
        "answer_start": ("<answer>", 1),
        "answer_end": ("</answer>", 1),
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        # find() returns -1 for a missing tag; the count check above
        # already flags that case, so the order check below may pass spuriously.
        positions[tag_name] = pos = processed_str.find(tag_str)

        debug_str.append(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            debug_str.append(
                f"  [Error] {tag_str} appears {count} times (expected {expected_count})"
            )
            validation_passed = False

    # Verify tag order (using each tag's first occurrence)
    if (
        positions["think_start"] > positions["think_end"]
        or positions["think_end"] > positions["answer_start"]
        or positions["answer_start"] > positions["answer_end"]
    ):
        debug_str.append(
            "  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>"
        )
        validation_passed = False
    # NOTE: a check that the response ends with "</answer><|endoftext|>" was
    # previously disabled at this point.
    elif not processed_str.strip().startswith("<think>"):
        debug_str.append("  [Error] Incorrect start token: Expected <think>")
        validation_passed = False
    else:
        debug_str.append("  Tag sequence validation passed")

    return validation_passed, debug_str


def compute_score(
    solution_str: str,
    ground_truth: Dict[str, str],
    format_reward: float = 0.1,
    answer_reward: float = 1.0,
    **kwargs,
):
    """Score a model response on output format and code correctness.

    Format component: ``format_reward`` when ``validate_response_structure``
    passes, otherwise ``-abs(format_reward)``.

    Answer component:

    * ``-2`` when no ``<answer>`` block can be extracted;
    * ``-2`` when the judge reports a compile error;
    * otherwise ``5.0`` times the mean of the per-test results from ``oj.run``;
    * ``0.0`` if ``oj.run`` returns no results.

    Args:
        solution_str: Raw model response string.
        ground_truth: Test-case data forwarded to ``oj.run``. A JSON string is
            decoded first. The exact schema is defined by code_judge; presumably
            a dict with "input"/"output" lists — confirm.
        format_reward: Magnitude of the format reward/penalty.
        answer_reward: Currently unused — the answer component is scaled by a
            hard-coded 5.0. Kept for interface compatibility.
        **kwargs: Ignored; accepted so callers may pass extra fields.

    Returns:
        ``(total_score, debug_text)``: the sum of the format and answer
        components, and a multi-line debug log.
    """
    if isinstance(ground_truth, str):
        ground_truth = json.loads(ground_truth)
    debug_str = []
    debug_str.append("\n" + "=" * 80)
    debug_str.append(" Processing New Sample ".center(80, "="))

    # Extract model answer
    answer_text, processed_str = extract_solution(solution_str)
    debug_str.append(f"\n[Model Response]\n{processed_str}")

    # Validate response structure
    format_correct, debug_info = validate_response_structure(processed_str)
    debug_str.extend(debug_info)
    format_score = format_reward if format_correct else -abs(format_reward)
    debug_str.append(f"\n  Format validation: {'PASS' if format_correct else 'FAIL'}")
    debug_str.append(f"  Format score: {format_score}")

    # Validate answer content
    if answer_text is None:
        # No complete <answer> block means there is no code to judge; score it
        # like a compile failure instead of handing None to the judge.
        debug_str.append("No answer found!")
        answer_score = -2
    elif oj.check_ce(answer_text):
        debug_str.append("Compile Error!")
        answer_score = -2
    else:
        results = oj.run(answer_text, ground_truth)
        # An empty result list would otherwise produce 0/0 = NaN and poison
        # the reward signal.
        if len(results) > 0:
            answer_score = 5.0 * np.sum(results) / len(results)
        else:
            debug_str.append("No test results returned!")
            answer_score = 0.0

    total_score = format_score + answer_score
    debug_str.append("\n" + "-" * 80)
    debug_str.append(" Final Score ".center(80, "-"))
    debug_str.append(f"  Format: {format_score}")
    debug_str.append(f"  Answer: {answer_score}")
    debug_str.append(f"  Total: {total_score}")
    debug_str.append("=" * 80 + "\n")

    return total_score, "\n".join(debug_str)


if __name__ == "__main__":
    # Manual smoke test. This rebinds the module-level judge to one built
    # with OnlineJudge's default settings (not timeout=3 / early_exit=False).
    # The C++ program wraps its scan in a 10**7-iteration outer loop, so it is
    # presumably meant to exercise the judge's timeout handling — confirm.
    oj = code_judge.OnlineJudge()
    slow_cpp_source = "#include <bits/stdc++.h>\nusing namespace std;\nint main() {\n  string s;\n  cin >> s;\n  for(int j = 0; j < 10000000; ++j) for (int i = 0; i < s.length(); i++) {\n    if (s[i] == s[i + 1] && s[i] == s[i + 2]) {\n      cout << s[i];\n      return 0;\n    }\n  }\n  cout << -1;\n  return 0;\n}\n"
    num_cases = 200
    test_cases = {
        "input": [123123123129912857127437128819329319200] * num_cases,
        "output": [1] * num_cases,
    }
    verdict = oj.run(slow_cpp_source, test_cases)
    print(verdict)
