import re
import hashlib
from typing import Dict, Tuple, Optional, List
import subprocess
import json
import time
import random
import signal
import sys
import os
import concurrent.futures
import tempfile
import threading

from custom_verl.reward_utils import RewardType


# Make raw control characters inside quoted literals compile-safe by turning
# them into backslash escapes (an actual newline becomes the two characters
# "\" + "n", and so on).

# One left-to-right scan over every token that may contain quote characters.
# Comments are matched only so that quotes inside them are never mistaken for
# the start of a literal.
_COMMENT_OR_LITERAL_RE = re.compile(
    r"//[^\n]*"  # C/C++ line comment
    r"|/\*.*?\*/"  # C/C++ block comment
    r"|#[^\n]*"  # Python comment or C/C++ preprocessor line
    r'|"(?:\\.|[^"\\])*"'  # double-quoted literal, honoring backslash escapes
    r"|'(?:\\.|[^'\\])*'",  # single-quoted literal, honoring backslash escapes
    re.DOTALL,
)

# Inside a literal: either an existing escape pair (kept verbatim) or a raw
# control character that has to be rewritten.
_ESCAPE_PAIR_OR_CONTROL_RE = re.compile(r"\\.|[\n\r\t]", re.DOTALL)

_CONTROL_ESCAPES = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}


def escape_string_literals(code_str):
    """Escape raw newline, carriage-return and tab characters in quoted literals.

    Only text inside ``"..."`` and ``'...'`` literals is rewritten; everything
    else, including comments, is returned unchanged. Within a literal:

    * existing escape sequences (``\\n``, ``\\"``, ``\\\\``, a backslash-newline
      continuation, ...) are preserved exactly, so already well-formed code is
      not double-escaped;
    * non-ASCII characters are left as they are;
    * a raw newline, carriage return or tab is replaced by ``\\n``, ``\\r`` or
      ``\\t`` respectively.

    Known limitations: raw-string literals (C++ ``R"(...)"``, Python ``r"..."``)
    are treated like ordinary literals, and in Python code everything after a
    ``//`` operator up to the end of the line is skipped as if it were a C++
    comment.

    Args:
        code_str: Source code of a submission.

    Returns:
        The source code with raw control characters inside literals escaped.
    """

    def fix_control_char(match):
        piece = match.group(0)
        # Escape pairs are two characters long and therefore never in the
        # table, so they fall through unchanged.
        return _CONTROL_ESCAPES.get(piece, piece)

    def normalize_token(match):
        token = match.group(0)
        if token[0] not in "\"'":
            return token  # a comment: leave untouched
        return _ESCAPE_PAIR_OR_CONTROL_RE.sub(fix_control_char, token)

    return _COMMENT_OR_LITERAL_RE.sub(normalize_token, code_str)


class OnlineJudge(object):
    """Compile and run a code submission against input/output test cases.

    Each test case feeds one input to the program on stdin and compares the
    stripped stdout with the stripped expected output. Every test yields a
    numeric result (1 for pass, 0 otherwise) and a ``RewardType`` describing
    the outcome.

    Attributes:
        timeout: Seconds a single test-case execution may take.
        do_parallel: Run the test cases of one submission in a thread pool.
        early_exit: Stop at the first failing test; tests without a recorded
            result are reported as ``RewardType.EarlyStopped``.
        ignore_output: Count any run that writes nothing to stderr as correct,
            without comparing stdout.
        execute_error, eval_error: Diagnostic counters reset by ``run``. They
            are updated from worker threads without synchronization, so treat
            them as approximate.
    """

    # Seconds allowed for one C++ compilation.
    _COMPILE_TIMEOUT = 10

    def __init__(self, timeout=1, do_parallel=True, early_exit=True, ignore_output=False):
        self.timeout = timeout
        self.do_parallel = do_parallel
        self.early_exit = early_exit
        self.ignore_output = ignore_output
        # Initialised here so run_test_case can be called without a prior run().
        self.execute_error = 0
        self.eval_error = 0

    def check_language(self, code_string):
        """Guess the language of a submission from marker substrings.

        Returns:
            ``"c++"`` if the code contains ``#include``; ``"python2.7"`` if it
            contains ``raw_input`` or ``"print "``; otherwise ``"python3.10"``.
        """
        # NOTE(review): the "print " marker also matches Python 3 code such as
        # `print (x)` or a comment mentioning "print "; `compile` has no
        # python2.7 branch, so such code is reported as a compile error.
        if "#include" in code_string:
            return "c++"
        if "raw_input" in code_string or "print " in code_string:
            return "python2.7"
        return "python3.10"

    def _kill_process_group(self, proc):
        """SIGKILL the process group led by ``proc`` (if still unreaped) and reap it.

        The children are started with ``os.setsid``, so the group id equals
        ``proc.pid``. The signal is only sent while the leader has not been
        reaped: until then its pid cannot be recycled, so no unrelated process
        group can be hit.
        """
        if proc.returncode is None:
            try:
                os.killpg(proc.pid, signal.SIGKILL)
            except (ProcessLookupError, PermissionError):
                pass  # Group already gone (or not ours to signal).
        proc.wait()

    def _unique_path(self, directory, code_string):
        """Return a path inside ``directory`` derived from the code and the current time."""
        digest = hashlib.md5(
            code_string.encode("utf-8") + str(time.time()).encode("utf-8")
        ).hexdigest()
        return os.path.join(directory, digest)

    def _compile_cpp(self, cpp_file, out_file):
        """Compile ``cpp_file`` into ``out_file`` with ``c++``.

        Returns:
            True if the compiler exited with status 0 within the time limit.
            Warnings printed to stderr do not count as failure. A missing
            compiler or a timeout yields False.
        """
        try:
            with subprocess.Popen(
                ["c++", "-o", out_file, cpp_file],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                preexec_fn=os.setsid,
            ) as p:
                try:
                    p.communicate(timeout=self._COMPILE_TIMEOUT)
                except subprocess.TimeoutExpired:
                    return False
                finally:
                    # Must happen inside the `with`: Popen.__exit__ waits for
                    # the child, so a still-running compiler would block there.
                    self._kill_process_group(p)
                return p.returncode == 0
        except OSError:
            # e.g. the `c++` executable is not installed.
            return False

    def run_test_case(self, file_name, lan_bin, code_string, one_in, one_out):
        """Run the prepared program on one input and grade its output.

        Args:
            file_name: Path returned by ``compile`` (binary or ``.py`` file).
            lan_bin: Language tag from ``check_language``; for non-C++ code it
                is also used as the interpreter executable name.
            code_string: Unused; kept for interface compatibility.
            one_in: Text written to the program's stdin.
            one_out: Expected stdout (compared after ``strip()``).

        Returns:
            ``(result, reward_type)`` where result is 1 or 0. Any stderr output
            gives ``ExecutionError``; exceeding ``self.timeout`` gives
            ``TimeoutError``; any other failure to run gives ``ExecutionError``.
        """
        if lan_bin == "c++":
            # Run through the shell so core dumps can be disabled first.
            cmd = f"ulimit -c 0 && {file_name}"
        else:
            cmd = [lan_bin, file_name]

        try:
            with subprocess.Popen(
                cmd,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                preexec_fn=os.setsid,
                shell=isinstance(cmd, str),  # shell=True only for C++
            ) as proc:
                try:
                    output, errors = proc.communicate(
                        one_in.encode(), timeout=self.timeout
                    )
                except subprocess.TimeoutExpired:
                    return (0, RewardType.TimeoutError)
                finally:
                    # SIGKILL rather than SIGTERM: a submission that ignores
                    # SIGTERM must not be able to hang the judge.
                    self._kill_process_group(proc)

                if errors:
                    self.eval_error += 1
                    return (0, RewardType.ExecutionError)

                passed = (
                    self.ignore_output
                    or output.decode("utf-8").strip() == one_out.strip()
                )
                if passed:
                    return (1, RewardType.Correct)
                return (0, RewardType.Wrong)
        except Exception:
            self.execute_error += 1
            return (0, RewardType.ExecutionError)

    def check_ce(self, code_string):
        """Return True if ``code_string`` is empty or fails to compile as C++."""
        if not code_string:
            return True

        with tempfile.TemporaryDirectory() as tmpdirname:
            file_name = self._unique_path(tmpdirname, code_string)
            cpp_file = f"{file_name}.cpp"
            with open(cpp_file, "w", encoding="utf-8") as f:
                f.write(escape_string_literals(code_string))
            return not self._compile_cpp(cpp_file, file_name)

    def compile(self, code_string, lang_bin, tmpdirname):
        """Materialise a submission inside ``tmpdirname`` so it can be executed.

        Args:
            code_string: Source code of the submission.
            lang_bin: Language tag from ``check_language``.
            tmpdirname: Existing directory that receives the generated files.

        Returns:
            Path of the compiled binary (``"c++"``) or of the written ``.py``
            file (``"python3.10"``). None if the code is empty, compilation
            fails, or the language is not supported.
        """
        if not code_string:
            return None

        file_name = self._unique_path(tmpdirname, code_string)

        if lang_bin == "c++":
            cpp_file = file_name + ".cpp"
            try:
                with open(cpp_file, "w", encoding="utf-8") as f:
                    f.write(escape_string_literals(code_string))
                compiled = self._compile_cpp(cpp_file, file_name)
            finally:
                # The source is no longer needed once compilation was attempted.
                try:
                    os.remove(cpp_file)
                except FileNotFoundError:
                    pass
            return file_name if compiled else None

        if lang_bin == "python3.10":
            py_file = file_name + ".py"
            with open(py_file, "w", encoding="utf-8") as f:
                f.write(escape_string_literals(code_string))
            return py_file

        return None

    def run(self, code_string, test_cases) -> "Tuple[List[int], List[RewardType]]":
        """Run one submission against all of its test cases.

        Args:
            code_string: Source code of the submission (may be empty or None).
            test_cases: Mapping with parallel sequences under the keys
                ``"inputs"`` and ``"outputs"``.

        Returns:
            ``(results, reward_types)``, each with one entry per input. When the
            submission is empty or cannot be prepared, every entry is
            ``0`` / ``RewardType.CompileError``. Tests without a recorded result
            (early exit, or an input with no matching output) are
            ``0`` / ``RewardType.EarlyStopped``.
        """
        self.execute_error = 0
        self.eval_error = 0

        inputs = test_cases["inputs"]
        num_cases = len(inputs)
        if num_cases == 0:
            # Nothing to run; also avoids creating a thread pool with 0 workers.
            return [], []

        if not code_string:
            return [0] * num_cases, [RewardType.CompileError] * num_cases

        lan_bin = self.check_language(code_string)
        # Results are stored by index because parallel runs finish out of order.
        results = [None] * num_cases
        reward_types = [None] * num_cases

        with tempfile.TemporaryDirectory() as tmpdirname:
            file_name = self.compile(code_string, lan_bin, tmpdirname)
            if not file_name:
                return [0] * num_cases, [RewardType.CompileError] * num_cases

            cases = list(zip(range(num_cases), inputs, test_cases["outputs"]))

            if self.do_parallel:
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=max(1, min(32, len(cases)))
                ) as executor:
                    future_to_index = {
                        executor.submit(
                            self.run_test_case,
                            file_name,
                            lan_bin,
                            code_string,
                            str(one_in),
                            str(one_out),
                        ): idx
                        for idx, one_in, one_out in cases
                    }

                    for future in concurrent.futures.as_completed(future_to_index):
                        idx = future_to_index[future]
                        try:
                            result, reward_type = future.result()
                        except Exception:
                            # Treat any exception as a failed test case.
                            result, reward_type = 0, RewardType.ExecutionError

                        self.execute_error += 1 if result == 0 else 0
                        results[idx] = result
                        reward_types[idx] = reward_type

                        if self.early_exit and result == 0:
                            # Drop tests that have not started yet; tests that
                            # are already running finish before the pool closes.
                            for pending in future_to_index:
                                pending.cancel()
                            break
            else:
                for idx, one_in, one_out in cases:
                    try:
                        result, reward_type = self.run_test_case(
                            file_name, lan_bin, code_string, str(one_in), str(one_out)
                        )
                    except Exception:
                        result, reward_type = 0, RewardType.ExecutionError

                    self.execute_error += 1 if result == 0 else 0
                    results[idx] = result
                    reward_types[idx] = reward_type

                    if self.early_exit and result == 0:
                        break

        # Whatever has no recorded outcome at this point never ran to completion.
        results = [0 if r is None else r for r in results]
        reward_types = [
            RewardType.EarlyStopped if r is None else r for r in reward_types
        ]
        return results, reward_types

    def run_parallel_code(
        self, code_strings: List[str], testcases_list: List[dict]
    ) -> "List[Tuple[List[int], List[RewardType]]]":
        """
        Executes multiple code submissions concurrently.

        Args:
            code_strings (List[str]): A list of code strings to be executed.
            testcases_list (List[dict]): A list of test cases corresponding to each code string.
                                         Each test case dictionary should have keys "inputs" and "outputs".

        Returns:
            List[Tuple[List[int], List[RewardType]]]: A list where each element is a tuple containing
            the list of results and the list of reward types for the corresponding code string.
            An empty input yields an empty list.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(code_strings) != len(testcases_list):
            raise ValueError(
                "The number of code strings must match the number of testcases lists."
            )
        if not code_strings:
            # A thread pool cannot be created with zero workers.
            return []

        parallel_results = [None] * len(code_strings)
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=min(8, len(code_strings))
        ) as executor:
            future_to_index = {
                executor.submit(self.run, code, testcases): idx
                for idx, (code, testcases) in enumerate(
                    zip(code_strings, testcases_list)
                )
            }

            for future in concurrent.futures.as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    res = future.result()
                except Exception:
                    # If an exception occurs, mark all test cases for this submission as failed.
                    num_cases = len(testcases_list[idx]["inputs"])
                    res = (
                        [0] * num_cases,
                        [RewardType.ExecutionError] * num_cases,
                    )
                parallel_results[idx] = res

        return parallel_results

    def score(self, pid, code_string, test_cases=None):
        """Grade a submission on an integer 0-5 scale by fraction of tests passed.

        Args:
            pid: Problem identifier. Not used by this implementation; kept so
                existing positional calls keep their shape.
            code_string: Source code of the submission.
            test_cases: Mapping with ``"inputs"`` and ``"outputs"`` to run
                against. Required, because this class has no way to look up
                test cases from ``pid``.

        Returns:
            ``int(5.0 * passed / total)``, or 0 when there are no test cases.

        Raises:
            TypeError: If ``test_cases`` is not supplied.
        """
        if test_cases is None:
            raise TypeError(
                "score() requires test_cases with 'inputs' and 'outputs'"
            )
        results, _ = self.run(code_string, test_cases)
        if not results:
            return 0
        return int(5.0 * sum(results) / len(results))


def extract_solution(solution_str: str) -> Tuple[Optional[str], str, str]:
    """Split a full transcript into question and response and pull out the answer.

    The response starts after the first ``"Assistant:"`` marker or, if that is
    absent, after the first ``"<|im_start|>assistant"`` marker. The answer is
    the stripped content of the last ``<answer>...</answer>`` pair in the
    response.

    Args:
        solution_str: Prompt and model output concatenated in one string.

    Returns:
        ``(final_answer, processed_str, question_str)``:
        * no response marker found: ``(None, solution_str, "")``;
        * marker found but no answer tags: ``(None, response, question)``;
        * otherwise: ``(answer, response, question)``.
    """
    # Split response to isolate assistant output
    for header in ("Assistant:", "<|im_start|>assistant"):
        if header in solution_str:
            question_str, _, processed_str = solution_str.partition(header)
            break
    else:
        print("[Error] Failed to locate model response header")
        return None, solution_str, ""

    # Extract final answer using XML-style tags; the last pair wins.
    answers = re.findall(r"<answer>(.*?)</answer>", processed_str, re.DOTALL)
    if not answers:
        print("[Error] No valid answer tags found")
        return None, processed_str, question_str

    return answers[-1].strip(), processed_str, question_str


def validate_response_structure(processed_str: str) -> Tuple[bool, List[str]]:
    """Check that a response has the form ``<think>...</think><answer>...</answer><|endoftext|>``.

    Each of the four tags must occur exactly once. In addition, exactly one of
    the following positional checks is evaluated, in this order, and the first
    one that fails is reported: tag order, the stripped response ending with
    ``</answer><|endoftext|>``, the stripped response starting with ``<think>``.

    Args:
        processed_str: Processed response string from the model

    Returns:
        ``(validation_passed, debug_str)``: whether all formatting requirements
        are met, and the list of diagnostic lines collected along the way.
    """
    debug_str = ["\n[Structure Validation]"]
    validation_passed = True

    # Check required tags
    required_tags = (
        ("think_start", "<think>", 1),
        ("think_end", "</think>", 1),
        ("answer_start", "<answer>", 1),
        ("answer_end", "</answer>", 1),
    )

    positions = {}
    for tag_name, tag_str, expected_count in required_tags:
        count = processed_str.count(tag_str)
        pos = processed_str.find(tag_str)  # -1 when the tag is missing
        positions[tag_name] = pos

        debug_str.append(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            debug_str.append(
                f"  [Error] {tag_str} appears {count} times (expected {expected_count})"
            )
            validation_passed = False

    # Verify tag order, then the closing and opening tokens.
    stripped = processed_str.strip()
    in_order = (
        positions["think_start"]
        <= positions["think_end"]
        <= positions["answer_start"]
        <= positions["answer_end"]
    )
    if not in_order:
        debug_str.append(
            "  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>"
        )
        validation_passed = False
    elif not stripped.endswith("</answer><|endoftext|>"):
        debug_str.append(
            "  [Error] Incorrect end token: Expected </answer><|endoftext|>"
        )
        validation_passed = False
    elif not stripped.startswith("<think>"):
        debug_str.append("  [Error] Incorrect start token: Expected <think>")
        validation_passed = False
    else:
        debug_str.append("  Tag sequence validation passed")

    return validation_passed, debug_str
