# borrowed and extended from
# https://github.com/Naman-ntc/codescratch/blob/main/evaluation/bigcode-evaluation-harness/lm_eval/tasks/custom_metrics/apps_custom_metrics/utils.py

from verl.utils.reward_score.livecodebench.lcb_runner.evaluation.pass_k_utils import compute_metrics_from_results
from verl.utils.reward_score.livecodebench.lcb_runner.evaluation.testing_util import run_test
from tqdm import tqdm
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
from collections import defaultdict
import multiprocessing
import json
import os
import sys
import time
import psutil

# Raise the int <-> str conversion digit limit to 50000 digits.
# The setter only exists on interpreters that ship the conversion limit
# (3.11+ and the patched security releases of older versions), so guard the
# call instead of failing at import time on interpreters without it.
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(50000)

# NOTE(review): presumably silences/disables HuggingFace tokenizers parallelism
# in the worker processes spawned below -- confirm against the consumers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _temp_run(sample, generation, debug, result, metadata_list, timeout, complete_evaluation=False):
    """Run ``run_test`` once and publish its outcome through shared lists.

    Args:
        sample: Problem sample handed to ``run_test`` unchanged.
        generation: Candidate solution, passed to ``run_test`` as ``test``.
        debug: Forwarded to ``run_test``.
        result: List-like container; the test results are appended to it.
        metadata_list: List-like container; the run metadata is appended to it.
        timeout: Forwarded to ``run_test``.
        complete_evaluation: Forwarded to ``run_test``.
    """
    outcome, run_metadata = run_test(
        sample,
        test=generation,
        debug=debug,
        timeout=timeout,
        complete_evaluation=complete_evaluation,
    )
    # Results are published first, metadata second.
    result.append(outcome)
    metadata_list.append(run_metadata)


def _kill_process_tree(proc):
    """Kill ``proc`` (a ``multiprocessing.Process``) and every descendant it spawned."""
    try:
        # Snapshot the descendants first: once the parent is dead they are
        # re-parented and can no longer be discovered through it.
        descendants = psutil.Process(proc.pid).children(recursive=True)
    except psutil.Error:
        # Already gone or not inspectable; nothing to enumerate.
        descendants = []
    proc.kill()
    for child in descendants:
        try:
            child.kill()
        except psutil.Error:
            continue


def check_correctness(sample, generation, timeout, debug=True, max_memory_gb=50, complete_evaluation=False):
    """Check correctness of code generation with a global timeout and memory limit.

    The tests run in a child process (``_temp_run``) that is polled every
    0.1 seconds. The child and all of its descendants are killed when their
    combined resident memory exceeds ``max_memory_gb`` or when the overall
    wall-clock budget of ``max((timeout + 1) * number_of_inputs + 5, 300)``
    seconds is used up.

    Args:
        sample: Mapping whose ``"input_output"`` entry is a JSON string
            containing an ``"inputs"`` sequence (one element per test case).
        generation: Candidate solution forwarded to ``run_test``.
        timeout: Per-test timeout in seconds, forwarded to ``run_test`` and
            used to derive the overall budget.
        debug: Print diagnostics when a limit is hit or the run fails.
        max_memory_gb: Memory ceiling (GiB) for the child plus descendants.
        complete_evaluation: Forwarded unchanged to ``run_test``.

    Returns:
        result: list of test case results; ``[-1, ...]`` (one entry per input)
            when the child produced no result
        metadata: dictionary containing execution information or error details

    Raises:
        KeyError / json.JSONDecodeError: if ``sample["input_output"]`` is
            missing, malformed, or lacks ``"inputs"``.
    """
    # Parse once, before spawning anything, so a malformed sample cannot
    # leave a child process or manager behind.
    num_tests = len(json.loads(sample["input_output"])["inputs"])
    memory_threshold = max_memory_gb * 1024 * 1024 * 1024
    timeout_seconds = max((timeout + 1) * num_tests + 5, 300)

    # The context manager shuts the manager's server process down on exit;
    # without it every call leaks one helper process.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        metadata_list = manager.list()

        p = multiprocessing.Process(
            target=_temp_run,
            args=(sample, generation, debug, result,
                  metadata_list, timeout, complete_evaluation),
        )
        p.start()
        start_time = time.time()

        try:
            while p.is_alive():
                try:
                    process = psutil.Process(p.pid)
                    total_memory = process.memory_info().rss
                    for child in process.children(recursive=True):
                        try:
                            total_memory += child.memory_info().rss
                        except (psutil.NoSuchProcess, psutil.AccessDenied):
                            continue
                except psutil.NoSuchProcess:
                    # The child exited between is_alive() and inspection.
                    break

                if total_memory > memory_threshold:
                    if debug:
                        print(
                            f"Memory limit exceeded: {total_memory / 1024 / 1024 / 1024:.2f}GB")
                    break

                if time.time() - start_time > timeout_seconds:
                    if debug:
                        print("Timeout exceeded")
                    break

                time.sleep(0.1)
        finally:
            # Single cleanup point: covers both limit breaches and unexpected
            # errors in the monitoring loop, and also reaps descendants.
            if p.is_alive():
                _kill_process_tree(p)
            p.join(1)

        # Copy out of the proxies while the manager is still running.
        results = list(result)
        metadatas = list(metadata_list)

    if not results:
        if debug:
            print("Process failed")
        # Keep the original return format
        return [-1 for _ in range(num_tests)], {"error": "Process failed"}

    if not metadatas:
        # The child was stopped after publishing results but before metadata.
        return results[0], {"error": "Process failed"}

    return results[0], metadatas[0]


def evaluate_generations_by_problem(args):
    """Evaluate every generation of one problem and collect results and metadata.

    Args:
        args: Sequence ``(problem_generations, sample, debug, timeout,
            complete_evaluation)``; packed into one argument so the function
            can be submitted to an executor as a single-argument task.

    Returns:
        A ``(res, metadata)`` tuple of equal-length lists, one entry per
        generation: ``res[i]`` is the list of per-test results (``[-2]`` when
        the test runner itself raised before producing results) and
        ``metadata[i]`` is the matching metadata dict.
    """
    problem_generations: list[str] = args[0]  # Generated Solutions
    sample = args[1]  # all the test cases for one solution
    debug: bool = args[2]
    timeout: int = args[3]
    complete_evaluation: bool = args[4]
    res = []
    metadata = []
    for o_idx, o in enumerate(problem_generations):
        curr_res = [-2]
        try:
            curr_res, curr_metadata = check_correctness(
                sample, o, timeout=timeout, debug=debug, complete_evaluation=complete_evaluation
            )
            if debug:
                print(f"\nSuccessful compilation of task {o_idx}!")
            # Convert numpy scalars/arrays into plain Python values.
            fixed = []
            for e in curr_res:
                if isinstance(e, np.ndarray):
                    e = e.item(0)
                if isinstance(e, np.bool_):
                    e = bool(e)
                fixed.append(e)
            curr_res = fixed
            # NOTE(review): negative error codes such as -1 are truthy, so this
            # message only fires for False/0 entries.
            if not np.all(curr_res):
                if debug:
                    print(
                        f"Results were not True for all test cases {curr_res=}\n")
        except Exception as e:
            if debug:
                print(
                    f"Compilation failed, test framework exception = {repr(e)}{e}\n")
            curr_metadata = {
                "error": repr(e),
                "error_code": -5,
                "error_message": "TestRunnerError",
            }
        # Bookkeeping happens after the try/except rather than in ``finally``:
        # at this point ``curr_metadata`` is always bound for this iteration,
        # and an escaping KeyboardInterrupt/SystemExit is no longer masked by
        # an UnboundLocalError (or paired with a stale value).
        assert isinstance(curr_res, list), curr_res
        assert isinstance(curr_metadata, dict), curr_metadata
        res.append(curr_res)
        metadata.append(curr_metadata)
    if debug:
        for i, r in enumerate(problem_generations):
            print("Sample\n")
            print(r)
            print("\n")
            print("Result\n")
            print(res[i])
            print("*" * 30 + "\n\n")
    return res, metadata


def evaluate_generations(
    samples_list: list,
    generations_list: list[list[str]],
    debug: bool = False,
    num_process_evaluate: int = 16,
    timeout=6,
    complete_evaluation: bool = False,
):
    """Evaluate all generations of all problems in a process pool.

    One task per problem is submitted to a ``ProcessPoolExecutor``; each task
    runs ``evaluate_generations_by_problem`` on that problem's generations.

    Args:
        samples_list: Problem samples, index-aligned with ``generations_list``.
        generations_list: For each problem, the list of generated solutions.
        debug: Forwarded to every task; also restricts the pool to a single
            worker.
        num_process_evaluate: Pool size when ``debug`` is false.
        timeout: Forwarded to every task.
        complete_evaluation: Forwarded to every task.

    Returns:
        results: dictionary of results, key is the problem index, value is a list of lists where each inner list contains the test case results of one generation
        metadata: dictionary of metadata, key is the problem index, value is a list of dictionaries containing execution information or error details
    """
    # One [task_arguments, problem_index] entry per problem, in dataset order.
    inputs = []
    for index, problem_generations in enumerate(generations_list):
        task_args = (
            problem_generations,
            samples_list[index],
            debug,
            timeout,
            complete_evaluation,
        )
        inputs.append([task_args, index])

    worker_count = 1 if debug else num_process_evaluate

    with tqdm(total=len(inputs)) as pbar, ProcessPoolExecutor(max_workers=worker_count) as executor:
        futures = {}
        for arg, index in inputs:
            futures[executor.submit(evaluate_generations_by_problem, arg)] = index

        results = {}
        metadata = {}
        # Tasks finish in arbitrary order; the future -> index map restores
        # the association with the originating problem.
        for future in as_completed(futures):
            index = futures[future]
            results[index], metadata[index] = future.result()
            pbar.update(1)

    assert len(results) == len(
        inputs
    ), f"results = {len(results)} inputs = {len(inputs)} {results=}"

    return results, metadata


def codegen_metrics(
    samples_list: list[dict],
    generations_list: list[list[str]],
    k_list: list = [1, 5, 10, 20, 40, 50, 75, 100, 125, 150, 200, 500, 1000],
    num_process_evaluate: int = 16,
    timeout: int = 6,
    debug: bool = False,
    complete_evaluation: bool = False,
):
    """
    Evaluates code generation performance by testing generated code against test cases.

    Every (problem, generation) pair is flattened into its own evaluation task
    so generations of the same problem can run in parallel; the outcomes are
    then regrouped per problem.

    Args:
        samples_list: One sample (test cases with input/output pairs) per problem
        generations_list: For each problem, the list of generated code solutions
        k_list: List of k values for pass@k metrics (not mutated by this function)
        num_process_evaluate: Number of parallel processes for evaluation
        timeout: Maximum execution time per test case in seconds
        debug: Whether to enable debug mode
        complete_evaluation: Forwarded unchanged to ``evaluate_generations``

    Returns:
        Dictionary with the keys:
        - "metrics": value returned by ``compute_metrics_from_results``
        - "results": mapping problem index -> list of per-generation test results
        - "metadata": list, ordered by problem index, of lists of JSON-encoded
          metadata (one string per generation); problems without any
          generation have no entry
    """

    samples_linear = []
    generations_linear = []
    remap_index = []  # flattened task index -> originating problem index
    for idx, (sample, generation_list) in enumerate(
        zip(samples_list, generations_list)
    ):
        for generation in generation_list:
            assert isinstance(generation, str), generations_list[0]
            samples_linear.append(sample)
            generations_linear.append([generation])
            remap_index.append(idx)

    results_linear, metadatas_linear = evaluate_generations(
        samples_linear,  # all the test case for a problem
        generations_linear,
        debug=debug,
        num_process_evaluate=num_process_evaluate,
        timeout=timeout,
        complete_evaluation=complete_evaluation,
    )

    # Regroup by problem. Walking the flattened indices in ascending order
    # keeps each problem's generations in their original order. Every task
    # carried exactly one generation, hence the [0].
    results = defaultdict(list)
    metadatas = defaultdict(list)
    for idx in sorted(results_linear):
        results[remap_index[idx]].append(results_linear[idx][0])
    for idx in sorted(metadatas_linear):
        metadatas[remap_index[idx]].append(metadatas_linear[idx][0])

    metrics = compute_metrics_from_results(results, k_list=k_list)

    problem_keys = sorted(metadatas)
    final_metadata = []
    for i, key in enumerate(problem_keys):
        final_metadata.append([json.dumps(x) for x in metadatas[key]])
        # Compare against this problem's own generation count; comparing with
        # the first problem's count fails spuriously when problems have
        # different numbers of generations.
        assert len(final_metadata[i]) == len(
            generations_list[key]
        ), f"{len(final_metadata[i])=}"

    return {
        "metrics": metrics,
        "results": results,
        "metadata": final_metadata,
    }


if __name__ == "__main__":
    # Manual smoke tests for check_correctness.
    #
    # Disabled example using the function-call format ("fn_name" present):
    # print(
    #     check_correctness(
    #         {
    #             "input_output": json.dumps(
    #                 {
    #                     "inputs": [
    #                         json.dumps([1] * 100000)
    #                         + "\n"
    #                         + json.dumps([100000, -100000] * (100000 // 2))
    #                     ],
    #                     "outputs": [json.dumps([100000, 0] * (100000 // 2))],
    #                     "fn_name": "mostFrequentIDs",
    #                 }
    #             )
    #         },
    #         "class Solution:\n    def mostFrequentIDs(self, nums: List[int], freq: List[int]) -> List[int]:\n        from collections import defaultdict\n        \n        # Count of each ID\n        count = defaultdict(int)\n        # How many IDs exist for a given frequency\n        freq_of_count = defaultdict(int)\n        \n        max_freq = 0\n        ans = []\n        \n        for i in range(len(nums)):\n            x = nums[i]\n            change = freq[i]\n            \n            old_freq = count[x]\n            new_freq = old_freq + change\n            \n            # If there was an old frequency, decrease its usage\n            if old_freq > 0:\n                freq_of_count[old_freq] -= 1\n                if freq_of_count[old_freq] == 0:\n                    del freq_of_count[old_freq]\n            \n            # Update with the new frequency\n            count[x] = new_freq\n            freq_of_count[new_freq] += 1\n            \n            # Update max_freq if needed\n            if new_freq > max_freq:\n                max_freq = new_freq\n            \n            # If the collection at max_freq is empty, reduce max_freq until we find a non-empty bin\n            while max_freq > 0 and max_freq not in freq_of_count:\n                max_freq -= 1\n            \n            # If the collection is empty, max_freq will be 0\n            ans.append(max_freq)\n        \n        return ans",
    #         6,
    #         debug=True,
    #     )
    # )

    # Active example: a solution that reads its input from stdin.
    # "inputs"/"outputs" are one-element lists (one entry per test case):
    # check_correctness uses len(inputs) as the number of test cases, so a
    # bare string would be counted as one test case per character.
    print(
        check_correctness(
            {
                "input_output": json.dumps(
                    {
                        "inputs": [")))))"],
                        "outputs": ["0"],
                    },
                )
            },
            "\nMOD = 998244353\n\nS = input().strip()\nn = len(S)\n\nif n % 2 != 0:\n    print(0)\n    exit()\n\n# Initialize DP table\ndp = [[0] * (n + 2) for _ in range(n + 1)]\ndp[0][0] = 1\n\nfor i in range(1, n + 1):\n    c = S[i-1]\n    for b in range(n + 1):\n        if dp[i-1][b] == 0:\n            continue\n        if c == '(':\n            new_b = b + 1\n            if new_b <= n:\n                dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n        elif c == ')':\n            if b > 0:\n                new_b = b - 1\n                dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n        else:  # '?'\n            # Replace with '('\n            new_b = b + 1\n            if new_b <= n:\n                dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n            # Replace with ')'\n            if b > 0:\n                new_b = b - 1\n                dp[i][new_b] = (dp[i][new_b] + dp[i-1][b]) % MOD\n\nprint(dp[n][0] % MOD)\n",
            6,
            debug=True,
        )
    )
