import os
import re
import sys
import json
import time
import string
import random
import signal
import tempfile
import threading
import hashlib
import subprocess
import numpy as np
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Tuple, Optional, List

from custom_verl.reward_utils import RewardType
from .code_judge import OnlineJudge


def simple_extract_code(text):
    """Return the body of the last ```python fenced block in *text*.

    A block is recognised only when written as ```` ```python ```` + newline,
    body, newline + ```` ``` ````. Returns "" when *text* contains no such block.
    """
    last_body = ""
    for found in re.finditer(r"```python\n(.*?)\n```", text, re.DOTALL):
        last_body = found.group(1)
    return last_body


def simple_extract_testcase(code):
    """Extract the generated test code and the names of its test functions.

    Takes the last ```python block of *code* (via ``simple_extract_code``),
    collects the name of every ``def test_<name>(`` it defines, and removes
    each line that begins with a bare call ``test_<name>()``.

    Returns:
        ``(code_without_calls, test_fn_names)``, or ``("", [])`` when no
        python block is found.
    """
    block = simple_extract_code(code)
    if not block:
        return "", []

    names = re.findall(r"def (test_[a-zA-Z0-9_]+)\(", block)
    # str.startswith accepts a tuple of prefixes; an empty tuple never matches.
    call_prefixes = tuple(f"{name}()" for name in names)
    kept_lines = [ln for ln in block.split("\n") if not ln.startswith(call_prefixes)]
    return "\n".join(kept_lines), names


def get_testcase_type(good_bad_score):
    """Map a ``(good_result, bad_result)`` score pair to a ``RewardType``.

    - first element 0  -> ``RewardType.FailGood``
    - exactly ``(1, 0)`` -> ``RewardType.PassGoodFailBad``
    - exactly ``(1, 1)`` -> ``RewardType.PassGoodPassBad``
    - anything else    -> ``None``
    """
    if good_bad_score[0] == 0:
        return RewardType.FailGood
    if good_bad_score == (1, 0):
        return RewardType.PassGoodFailBad
    if good_bad_score == (1, 1):
        return RewardType.PassGoodPassBad
    return None


def remove_if_main(code):
    """Turn an ``if __name__ == "__main__":`` guard into an always-run function.

    Every line of the form ``if __name__ == "__main__":`` (either quote style)
    is rewritten to ``def DUMMY_MAIN():`` at the same indentation. Anything
    that followed the colon on that line, such as a trailing comment or a
    single-line body like ``main()``, is kept. A call ``DUMMY_MAIN()`` is then
    appended to the end of the code.

    Only horizontal whitespace is matched around the guard. Neighbouring blank
    lines are therefore left intact, and a nested guard keeps its indentation
    so the result stays syntactically valid.

    Args:
        code: Python source text.

    Returns:
        The rewritten source, or *code* unchanged when no guard is present.
    """
    # [^\S\n] == whitespace except newline, so a match never spans lines.
    pattern = (
        r"(?m)^(?P<indent>[^\S\n]*)"
        r'if[^\S\n]+__name__[^\S\n]*==[^\S\n]*[\'"]__main__[\'"][^\S\n]*:'
        r"(?P<rest>.*)$"
    )
    replaced, count = re.subn(pattern, r"\g<indent>def DUMMY_MAIN():\g<rest>", code)
    if count == 0:
        return code
    return replaced + "\n\nDUMMY_MAIN()\n"


def wrap_in_function_test(code_snippet):
    """Wrap a stdin/stdout program as source for a function ``foo(input_str)``.

    The snippet is first passed through ``remove_if_main``. It is then
    embedded as the body of ``foo``, inside a ``try`` block. ``foo`` feeds
    ``input_str`` to the program through ``sys.stdin``, captures what it
    writes to ``sys.stdout``, and returns that text stripped.

    The streams active before the call are restored in a ``finally`` clause.
    An exception (including ``SystemExit`` from ``exit()``) raised by the
    snippet therefore cannot leave ``sys.stdin``/``sys.stdout`` redirected for
    the rest of the test program.

    Args:
        code_snippet: Python source text of a complete program.

    Returns:
        Python source text defining ``foo``.
    """
    code_snippet = remove_if_main(code_snippet)
    # 8 spaces: 4 for the function body + 4 for the ``try:`` block.
    indented_code = "\n".join(" " * 8 + line for line in code_snippet.splitlines())
    res = f"""
def foo(input_str):
    import io
    import sys
    _foo_prev_stdin, _foo_prev_stdout = sys.stdin, sys.stdout  # Remember current streams
    _foo_captured = io.StringIO()  # Capture stdout
    sys.stdin = io.StringIO(input_str)  # Mock stdin
    sys.stdout = _foo_captured
    try:
{indented_code}
        pass
    finally:
        sys.stdin = _foo_prev_stdin  # Restore stdin
        sys.stdout = _foo_prev_stdout  # Restore stdout
    return _foo_captured.getvalue().strip()  # Get the output
"""
    return res


def preprocess_code(solution_str, test_cases, num_testcases=10):
    """Gather everything needed to score one response.

    Extracts the generated test code and at most ``num_testcases`` test
    function names from *solution_str*. Reads ``good_code``/``bad_code`` from
    *test_cases*, which is parsed with ``json.loads`` first if given as a
    string. Builds an ``OnlineJudge`` with a 3 s timeout, sequential
    execution, no early exit, and output checking disabled.

    Returns:
        ``(fullcode, good_code, bad_code, testcase_fn_names, oj)``
    """
    fullcode, fn_names = simple_extract_testcase(solution_str)
    fn_names = fn_names[:num_testcases]

    cases = json.loads(test_cases) if isinstance(test_cases, str) else test_cases
    good, bad = cases.get("good_code"), cases.get("bad_code")

    judge = OnlineJudge(timeout=3, do_parallel=False, early_exit=False, ignore_output=True)
    return fullcode, good, bad, fn_names, judge


def check_format(
    solution_str,
    fullcode,
    testcase_fn_names,
    think_score=0.1,
    format_penalty=-2.0,
    num_testcases=10,
):
    """ Check if the code is formatted correctly """
    # Check <think> tag
    reward = 0.0
    correct = True
    if fullcode != "":
        before = solution_str[: solution_str.find(fullcode)]
        # pattern = r"<think>.*?</think>"
        pattern = r"<think>"
        matches = re.findall(pattern, before, re.DOTALL)
        if len(matches) > 2:
            reward += think_score
    # Check format
    if (fullcode == ""
        or "def foo(" in fullcode
        or len(testcase_fn_names) < num_testcases
        or fullcode.count("foo(") < len(testcase_fn_names)
        or "assert" not in fullcode
    ):
        reward += format_penalty
        correct = False
    return reward, correct


def run_test_case(
    fullcode,
    good_code,
    bad_code,
    test_fn_name,
    oj,
    good_test_reward=1.0,
    wrong_test_penalty=-0.1,
):
    """Run one generated test function against the good and the bad program.

    For each reference program, a script is assembled from the test code, the
    program wrapped as ``foo`` (via ``wrap_in_function_test``), and a call to
    ``test_fn_name``. Each script is run through ``oj.run`` with a single
    empty input/output case. An outcome of 1 at ``[0][0]`` is treated as
    "passed"; presumably that is the first case's status — confirm against
    OnlineJudge.

    Returns:
        - ``(wrong_test_penalty, RewardType.FailGood)`` if the test does not
          pass on the good program;
        - ``(0, RewardType.PassGoodPassBad)`` if it also passes on the bad
          program;
        - ``(good_test_reward, RewardType.PassGoodFailBad)`` otherwise.
    """
    def build_program(reference_code):
        return f"{fullcode}\n{wrap_in_function_test(reference_code)}\n{test_fn_name}()\n"

    # Both scripts are assembled up front, before anything is executed.
    good_program = build_program(good_code)
    bad_program = build_program(bad_code)
    dummy_case = {"inputs": [""], "outputs": [""]}

    if oj.run(good_program, dummy_case)[0][0] != 1:
        return wrong_test_penalty, RewardType.FailGood
    if oj.run(bad_program, dummy_case)[0][0] == 1:
        return 0, RewardType.PassGoodPassBad
    return good_test_reward, RewardType.PassGoodFailBad


def compute_score(
    solution_str, test_cases, num_testcases=10, wrong_test_penalty=-0.1
):
    """Score a response's generated tests by running them sequentially.

    A test that fails on the good program is wrong (penalised). A test that
    passes on the good program but fails on the bad one is rewarded, because
    it catches the bug.

    Returns:
        ``(total_reward, reward_types)``. When the format check fails, this is
        ``(format_reward, [RewardType.FormatError])``. Otherwise it is the
        format reward plus the per-test scores, together with one
        ``RewardType`` per test function in order.
    """
    fullcode, good_code, bad_code, fn_names, oj = preprocess_code(
        solution_str, test_cases, num_testcases
    )
    base_reward, well_formed = check_format(
        solution_str, fullcode, fn_names, num_testcases=num_testcases
    )
    if not well_formed:
        return base_reward, [RewardType.FormatError]

    results = [
        run_test_case(
            fullcode, good_code, bad_code, name, oj, wrong_test_penalty=wrong_test_penalty
        )
        for name in fn_names
    ]
    total = base_reward + sum([score for score, _ in results])
    return total, [kind for _, kind in results]


# Parallel variant of compute_score: runs the test functions concurrently, which should be faster.
def compute_score_parallel(
    solution_str, test_cases, num_testcases=10, wrong_test_penalty=-0.1
):
    """Thread-parallel version of ``compute_score``; same inputs and outputs.

    Each test function is run via ``run_test_case`` on its own worker thread.
    ``executor.map`` preserves input order, so scores and reward types come
    back in the same order as the sequential version. The defaults match
    ``compute_score``.

    Returns:
        ``(total_reward, reward_types)``, or
        ``(format_reward, [RewardType.FormatError])`` when the format check
        fails.
    """
    fullcode, good_code, bad_code, testcase_fn_names, oj = (
        preprocess_code(solution_str, test_cases, num_testcases)
    )
    initial_reward, format_correct = check_format(
        solution_str, fullcode, testcase_fn_names, num_testcases=num_testcases
    )

    if not format_correct:
        return initial_reward, [RewardType.FormatError]

    if not testcase_fn_names:
        # ThreadPoolExecutor(max_workers=0) raises ValueError. This branch is
        # reachable when num_testcases <= 0, and there is nothing to run.
        return initial_reward + sum([]), []

    def process_test(test_fn_name):
        return run_test_case(
            fullcode, good_code, bad_code, test_fn_name, oj, wrong_test_penalty=wrong_test_penalty
        )

    # NOTE(review): a single `oj` instance is shared by all worker threads;
    # confirm OnlineJudge.run is thread-safe.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(testcase_fn_names)) as executor:
        results = list(executor.map(process_test, testcase_fn_names))

    testcase_fn_score = [score for score, _ in results]
    testcase_fn_reward_types = [reward for _, reward in results]
    return initial_reward + sum(testcase_fn_score), testcase_fn_reward_types
