import faulthandler
import json
import multiprocessing
import os
import platform
import re
import signal
import subprocess
import sys
import tempfile

# used for debugging to time steps
from datetime import datetime
from enum import Enum
from io import StringIO
from multiprocessing import Pool
from unittest.mock import mock_open, patch

import numpy as np
import timeout_decorator
from datasets import Dataset
from tqdm import tqdm

from .pyext2 import RuntimeModule


class CODE_TYPE(Enum):
    """How a problem's test cases invoke a candidate solution."""

    call_based = 0  # tests call a named function/method ("fn_name") with arguments
    standard_input = 1  # tests feed stdin and compare stdout


class Capturing(list):
    """Context manager that collects everything printed to ``sys.stdout``.

    On exit the captured text is split into lines and appended to this list.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op so captured text survives code that
        # closes stdout. ``close()`` is called with no arguments, so the
        # replacement must accept none.
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio  # free up some memory
        sys.stdout = self._stdout


# To run the solution files we use a timing-based approach: SIGALRM interrupts
# code that runs for too long.
class TimeoutException(Exception):
    pass


def timeout_handler(signum, frame):
    print("alarm went off")
    raise TimeoutException


signal.signal(signal.SIGALRM, timeout_handler)
TIMEOUT = 4  # seconds

EXECUTION_RESULTS = {
    1: "passed",
    0: "false",
    -1: "timeout",
    -2: "runtime_error",
    -3: "returncode:{code}",
    -4: "compile_error",
}

# Imports prepended to every candidate program so that solutions which rely on
# these modules without importing them still run.
_PRELUDE_IMPORTS = (
    "import sys\n"
    "import time\n"
    "import itertools\n"
    "from itertools import accumulate, product, permutations, combinations\n"
    "import collections\n"
    "from collections import Counter, OrderedDict, deque, defaultdict, ChainMap\n"
    "from functools import lru_cache\n"
    "import math\n"
    "from math import sqrt, sin, cos, tan, ceil, fabs, floor, gcd, exp, log, log2\n"
    "import fractions\n"
    "from typing import List, Tuple\n"
    "import numpy as np\n"
    "import random\n"
    "import heapq\n"
    "from heapq import *\n"
)


def run_test(sample, test=None, debug=False):
    """Run candidate code ``test`` against the test cases stored in ``sample``.

    Args:
        sample: Mapping whose ``"test_cases"`` entry is a JSON string (or an
            already-decoded dict) with ``"inputs"``, ``"outputs"`` and an
            optional ``"fn_name"``. With ``"fn_name"`` the problem is
            call-based, otherwise it is a standard-input problem.
        test: Source code of the candidate solution. If ``None`` nothing runs.
        debug: Print timing and diagnostic information.

    Returns:
        ``None`` when ``test`` is ``None`` or no test cases could be loaded.
        Otherwise a list with one entry per reported test: ``True`` (passed),
        ``False`` (wrong answer), ``-1`` (timeout) or ``-3`` (any other
        failure). ``[-2]`` means the code did not compile or its entry point
        was missing. The negative codes are truthy, so callers must check
        entries with ``is True``.
    """
    if debug:
        print(f"start = {datetime.now().time()}")

    raw_test_cases = sample["test_cases"]
    if isinstance(raw_test_cases, dict):
        in_outs = raw_test_cases
    else:
        try:
            in_outs = json.loads(raw_test_cases)
        except (TypeError, ValueError):
            in_outs = None

    # Without test cases there is nothing to compare against.
    if not in_outs:
        return None

    if in_outs.get("fn_name") is None:
        which_type = CODE_TYPE.standard_input  # Standard input
        method_name = None
    else:
        which_type = CODE_TYPE.call_based  # Call-based
        method_name = in_outs["fn_name"]

    inputs_list = []
    outputs_list = []
    for index, inputs in enumerate(in_outs["inputs"]):
        outputs = in_outs["outputs"][index]
        inputs, outputs = process_input_output(inputs, outputs)
        inputs_list.append(inputs)
        outputs_list.append(outputs)

    if debug:
        print(f"loaded input_output = {datetime.now().time()}")

    if test is None:
        return None

    results = []
    if debug:
        print(f"loading test code = {datetime.now().time()}")

    if which_type == CODE_TYPE.call_based:
        synthesized_code = synthesize_cb_code(test, debug)
        exec_code = None  # only standard-input problems have a separate script
    else:
        synthesized_code, exec_code = synthesize_std_code(test, debug)
    method_func = compile_and_get_func(
        synthesized_code, which_type, method_name, timeout=TIMEOUT, debug=debug
    )
    if not method_func:
        results.append(-2)
        return results

    if which_type == CODE_TYPE.call_based:
        detail_results, _ = execute_cb_code(
            method_func,
            inputs_list,
            outputs_list,
            timeout=TIMEOUT,
            early_stop=True,
            debug=debug,
        )
    else:
        detail_results = execute_std_code(
            method_func,
            exec_code,
            inputs_list,
            outputs_list,
            timeout=TIMEOUT,
            early_stop=True,
            debug=debug,
        )
        detail_results = {k: v for k, v in detail_results.items() if k != "debug"}
        # Retry with the function-wrapped variant when every run failed with
        # exit code 1. NOTE(review): execute_std_code does not currently report
        # "returncode:*" statuses, so this branch is not reached today.
        if set(detail_results.values()) == {(False, "returncode:1")}:
            detail_results = execute_std_code(
                method_func,
                synthesized_code + "\ncode()\n",
                inputs_list,
                outputs_list,
                timeout=TIMEOUT,
                early_stop=True,
                debug=debug,
            )
            detail_results = {
                k: v for k, v in detail_results.items() if k != "debug"
            }

    if isinstance(detail_results, list):
        if len(detail_results) == 1:
            detail_results = detail_results * len(inputs_list)
        detail_results = dict(zip(range(len(inputs_list)), detail_results))

    for test_result in detail_results.values():
        status = test_result[1]
        if status == "passed":
            results.append(True)
        elif status == "false":
            results.append(False)
        elif status == "timeout":
            results.append(-1)
        else:
            results.append(-3)
    return results


def process_input_output(inputs, outputs):
    """Undo JSON's string-only dict keys for singleton-dict inputs/outputs.

    JSON forces dictionaries to have string keys; when ``inputs[0]``,
    ``outputs`` or ``outputs[0]`` is a dict, its keys are converted back to
    ``int`` and the dict is wrapped in a one-element list (assuming a singleton
    list). Conversions that do not apply are skipped silently.
    """
    # Each step is best-effort: wrong shapes or non-integer keys leave the
    # value unchanged.
    skippable = (AttributeError, IndexError, KeyError, TypeError, ValueError)
    try:
        if isinstance(inputs[0], dict):
            inputs = [{int(k): v for k, v in inputs[0].items()}]
    except skippable:
        pass
    try:
        if isinstance(outputs, dict):
            outputs = [{int(k): v for k, v in outputs.items()}]
    except skippable:
        pass
    try:
        if isinstance(outputs[0], dict):
            outputs = [{int(k): v for k, v in outputs[0].items()}]
    except skippable:
        pass
    return inputs, outputs


def compile_and_get_func(program, which_type, method_name, timeout, debug):
    """Load ``program`` as a module and return the entry point to test.

    For call-based problems the entry point is ``method_name`` (looked up on a
    ``Solution()`` instance when the program defines ``class Solution``).
    For standard-input problems it is the synthesized ``code`` function.

    Returns:
        The callable, or ``False`` if loading the module or finding the
        attribute failed.
    """
    try:
        signal.alarm(timeout)
        tmp_sol = RuntimeModule.from_string("tmp_sol", "", program)
        if which_type == CODE_TYPE.call_based and "class Solution" in program:
            tmp = tmp_sol.Solution()
        else:
            tmp = tmp_sol
        signal.alarm(0)
    except (Exception, SystemExit) as e:
        # SystemExit (sys.exit() at module level) must not escape: it would
        # terminate the worker process running this check.
        signal.alarm(0)
        print("Compilation error: ")
        if debug:
            print(f"compilation error = {e}")
        return False

    if which_type == CODE_TYPE.call_based:
        assert isinstance(method_name, str)
    else:
        method_name = "code"

    try:
        signal.alarm(timeout)
        method = getattr(tmp, method_name)  # getattr's second arg must be str
        signal.alarm(0)
    except (Exception, SystemExit):
        signal.alarm(0)
        e = sys.exc_info()
        print("Unable to get function: ", method_name, program)
        if debug:
            print(f"unable to get function error = {e}")
        return False
    return method


def synthesize_cb_code(raw_code, debug=False):
    """Return call-based ``raw_code`` prefixed with the common prelude imports."""
    if debug:
        print(f"loading test code = {datetime.now().time()}")
    return _PRELUDE_IMPORTS + raw_code


def synthesize_std_code(raw_code, debug=False):
    """Build the two program variants used for standard-input problems.

    Returns:
        ``(sol, sol2)``. ``sol`` is the version compiled in-process:
        - leading plain import lines stay at module level;
        - then come the prelude, any ``import *`` lines (hoisted and
          de-indented) and ``stdin``/``stdout`` aliases;
        - finally ``def code():``, whose body is every remaining
          non-``import *`` line, tab-indented.

        ``sol2`` is the version executed as a script: the original lines,
        with the prelude and aliases inserted just before the first
        non-import line.
    """
    normal_import_lines = _PRELUDE_IMPORTS
    if debug:
        print(f"loading test code = {datetime.now().time()}")

    sol = ""  # code for compile
    sol2 = ""  # code for execute

    tmp_test = raw_code.split("\n")
    # Line type: 2 for "import *" lines (possibly indented), 1 for other
    # import lines, 0 for normal code.
    code_types = []
    for x in tmp_test:
        if "import *" in x:
            code_types.append(2)
        elif x.startswith("from ") or x.startswith("import "):
            code_types.append(1)
        else:
            code_types.append(0)

    started = False
    # "import *" is not allowed inside a function, so these go to module level.
    special_import_lines = "\n".join(
        i.lstrip("\t") for idx, i in enumerate(tmp_test) if code_types[idx] == 2
    )

    for idx, i in enumerate(tmp_test):
        code_type = code_types[idx]
        if code_type == 0 and not started:
            sol2 += normal_import_lines
            sol2 += "\nstdin = sys.stdin\nstdout = sys.stdout\n"
            sol2 += f"{i}\n"

            sol += normal_import_lines
            sol += special_import_lines
            sol += "\nstdin = sys.stdin\nstdout = sys.stdout\n"
            sol += "def code():\n"
            sol += f"\t{i}\n"
            started = True
        else:
            sol2 += f"{i}\n"
            if code_type < 2:
                if started:
                    sol += "\t"
                sol += f"{i}\n"
    if debug:
        print(f"sol = {sol}")
        print(f"sol2 = {sol2}")
    return sol, sol2


def call_method(method, inputs):
    """Call ``method()`` with stdin and ``open`` patched to serve ``inputs``.

    Patches ``builtins.open`` and ``sys.stdin`` (including its ``read``,
    ``readline`` and ``readlines``). A list of lines is joined with newlines
    first.

    Returns:
        The method's return value, or ``None`` if it raised ``SystemExit``.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    inputs_line_iterator = iter(inputs.split("\n"))

    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _inner_call_method(_method):
        try:
            return _method()
        except SystemExit:
            return None

    return _inner_call_method(method)


def execute_cb_code(
    method, inputs_list, outputs_list, timeout, early_stop=True, debug=True
):
    """Call ``method(*inputs)`` once per test case and compare the return value.

    Args:
        method: Entry point of the candidate solution.
        inputs_list: Per-test positional-argument sequences.
        outputs_list: Per-test expected return values.
        timeout: Seconds allowed per call (enforced with ``signal.alarm``).
        early_stop: On the first exception, mark that test and every remaining
            test as failed and stop.
        debug: Print per-test diagnostics and collect them in ``debug_infos``.

    Returns:
        ``(results, debug_infos)``. ``results`` holds one
        ``(passed, status)`` tuple per test case.

    Warning:
        Calls ``reliability_guard()``, which permanently disables parts of
        ``os``, ``shutil``, ``subprocess`` and ``builtins`` in the current
        process. Run it in a disposable worker process.
    """
    # Disable functionalities that can make destructive changes to the test.
    reliability_guard()
    results = []
    debug_infos = {}
    for index, inputs in enumerate(inputs_list):
        if debug:
            debug_infos[index] = {}
        outputs = outputs_list[index]
        try:
            signal.alarm(timeout)
            faulthandler.enable()
            exec_outputs = method(*inputs)
            signal.alarm(0)
            faulthandler.disable()
        except (Exception, SystemExit) as e:
            # SystemExit (sys.exit() in the solution) must not escape: it would
            # terminate the worker process running this check.
            signal.alarm(0)
            faulthandler.disable()
            timed_out = isinstance(
                e, (TimeoutException, timeout_decorator.TimeoutError)
            )
            status = EXECUTION_RESULTS[-1] if timed_out else EXECUTION_RESULTS[-2]
            print("Unable to get function: ", method, inputs)
            if debug:
                print(f"Standard input runtime error = {e}")
            if early_stop:
                results.extend(
                    (False, status) for _ in range(index, len(inputs_list))
                )
                break
            # Keep one result per test case so indices stay aligned.
            results.append((False, status))
            continue

        try:
            # ground truth sequences are lists, not tuples
            if isinstance(exec_outputs, tuple):
                exec_outputs = list(exec_outputs)

            tmp_result = exec_outputs == outputs
            if isinstance(outputs, list) and outputs:
                tmp_result = tmp_result or (exec_outputs == outputs[0])

            # also accept a sequence of tuples matching a list of lists
            try:
                if isinstance(exec_outputs[0], tuple):
                    exec_outputs = [list(x) for x in exec_outputs]
                    tmp_result = tmp_result or (exec_outputs == outputs[0])
            except Exception:
                pass  # best-effort extra check; the results above still stand

            if tmp_result:
                results.append((True, EXECUTION_RESULTS[1]))
            else:
                results.append((False, EXECUTION_RESULTS[0]))
        except Exception as e:
            # The comparison itself failed (e.g. an ambiguous NumPy truth
            # value). The alarm is already cleared, so this is not a timeout.
            print("Error in execute_cb_code: ", e)
            if debug:
                print(f"output comparison error = {e}")
            results.append((False, EXECUTION_RESULTS[-2]))
            continue

        if debug:
            print(
                f"outputs = {exec_outputs}, test outputs = {outputs}, inputs = {inputs}, {type(inputs)}, {tmp_result}"
            )
            debug_infos[index] = {
                "inputs": inputs,
                "gt_outputs": outputs,
                "exec_outputs": exec_outputs,
            }
    return results, debug_infos


def remove_tmp_files():
    """Delete ``input.txt``/``output.txt`` left in the working directory."""
    present = set(os.listdir("."))
    for tmp_file in ("input.txt", "output.txt"):
        if tmp_file in present:
            os.remove(tmp_file)


def execute_std_code(
    method,
    synthesized_code,
    inputs_list,
    outputs_list,
    timeout,
    early_stop=False,
    debug=False,
):
    """Run ``synthesized_code`` as a Python script once per test case.

    Each run receives the test input on stdin. Its stdout is compared with the
    expected output via ``compare_std_results``. Lists of lines (input or
    expected output) are joined with newlines first.

    Args:
        method: Unused; kept for interface compatibility.
        synthesized_code: Complete program source to execute.
        inputs_list: Per-test stdin contents.
        outputs_list: Per-test expected stdout.
        timeout: Seconds allowed per run.
        early_stop: Stop after the first test that does not pass.
        debug: Print diagnostics and record them under the ``"debug"`` key.

    Returns:
        Dict mapping test index to ``(passed, status)``. When ``debug`` is
        set it also contains a ``"debug"`` entry.
    """
    assert isinstance(inputs_list, list) and isinstance(outputs_list, list)
    assert len(inputs_list) == len(outputs_list)

    temp_program_path = create_temp_file(synthesized_code)
    if debug:
        print("Test program:", temp_program_path)

    # Run with the current interpreter so the prelude's imports (numpy, ...)
    # resolve; a bare "python" on PATH may be a different environment.
    python_executable = sys.executable or "python"

    exec_results = {}
    if debug:
        exec_results["debug"] = {}
    try:
        for i, inputs in enumerate(inputs_list):
            remove_tmp_files()
            outputs = outputs_list[i]
            if isinstance(inputs, list):
                inputs = "\n".join(inputs)
            if isinstance(outputs, list):
                outputs = "\n".join(outputs)

            try:
                result = subprocess.run(
                    [python_executable, temp_program_path],
                    input=inputs,
                    text=True,
                    capture_output=True,
                    timeout=timeout,
                )
                exec_code = 999  # ran to completion; verdict decided below
            except subprocess.TimeoutExpired:
                exec_code = -1
            except Exception as e:
                print("Error in execute_std_code: ", e)
                exec_code = -2

            if exec_code > 0:
                if compare_std_results(result.stdout, outputs, debug):
                    exec_code = 1
                else:
                    exec_code = 0

            exec_results[i] = (exec_code == 1, EXECUTION_RESULTS[exec_code])
            if debug and exec_code >= 0:
                print_debug_info(
                    inputs=inputs, outputs=outputs, exec_outputs=result.stdout
                )
                exec_results["debug"][i] = {
                    "inputs": inputs,
                    "gt_outputs": outputs,
                    "exec_outputs": result.stdout,
                }
            if early_stop and exec_code <= 0:
                break
    finally:
        try:
            os.remove(temp_program_path)
        except (OSError, TypeError):
            # Best effort: the file may be gone already, or os.remove may have
            # been set to None by reliability_guard in this process.
            pass
    return exec_results


def print_debug_info(inputs, outputs, exec_outputs):
    """Print produced vs expected output for one test case."""
    nl = "\n"
    if not isinstance(inputs, list):
        print(
            f"exec output = {exec_outputs}, test outputs = {outputs}, inputs = {inputs.replace(nl, ' new-line ')}, {type(inputs)}, {exec_outputs == [outputs]}"
        )
    else:
        print(
            f"exec output = {exec_outputs}, test outputs = {outputs}, inputs = {inputs}, {type(inputs)}, {exec_outputs == [outputs]}"
        )


def create_temp_file(content):
    """Write ``content`` to a new temporary file and return its path.

    The file is not deleted automatically; the caller must remove it.
    """
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", encoding="utf-8"
    ) as temp_file:
        temp_file.write(content)
        temp_file_path = temp_file.name
    return temp_file_path


def compare_std_results(exec_outputs, outputs, debug=False):
    """Return True if ``exec_outputs`` matches ``outputs`` under a lenient rule.

    The checks run in this order:
    1. whitespace-stripped string equality;
    2. list/line-wise equality;
    3. equality after splitting the expected output into non-empty stripped
       lines;
    4. ``np.allclose`` on float conversions;
    5. order-insensitive token-set comparisons, optionally rounded to 3
       decimals.

    Neither argument is mutated.

    NOTE(review): the first check calls ``str.lstrip`` on both arguments, so
    both must be strings (``execute_std_code`` passes raw stdout). For a string
    ``exec_outputs`` the later element-wise checks iterate over its characters,
    not its lines. In practice the stripped-string check is the one that
    matters.
    """
    if stripped_string_compare(exec_outputs, outputs):
        return True

    if isinstance(exec_outputs, list):
        output_1 = "\n".join(exec_outputs)
        if stripped_string_compare(output_1, outputs):
            return True
        output_2 = "\n".join(o.strip() for o in exec_outputs)
        if stripped_string_compare(output_2, outputs):
            return True

    tmp_result = False

    # ground truth sequences are expressed as lists not tuples
    if isinstance(outputs, tuple):
        outputs = list(outputs)

    try:
        tmp_result = exec_outputs == [outputs]
        if isinstance(outputs, list):
            tmp_result = tmp_result or (exec_outputs == outputs)
            if isinstance(exec_outputs[0], str):
                tmp_result = tmp_result or (
                    [e.strip() for e in exec_outputs] == outputs
                )
    except Exception as e:
        if debug:
            print(f"Failed check1 exception = {e}")
    if tmp_result:
        return True

    # try one more time without \n (new lists, so the caller's data is untouched)
    if isinstance(outputs, list):
        outputs = [[x.strip() for x in i.split("\n") if x] for i in outputs]
    else:
        outputs = [x.strip() for x in outputs.split("\n") if x]

    try:
        tmp_result = exec_outputs == [outputs]
        if isinstance(outputs, list):
            tmp_result = tmp_result or (exec_outputs == outputs)
    except Exception as e:
        if debug:
            print(f"Failed check2 exception = {e}")
    if tmp_result:
        return True

    # try by converting the output into a split up list too
    if isinstance(exec_outputs, list):
        exec_outputs = [e for e in exec_outputs if len(e)]
    try:
        tmp_result = exec_outputs == [outputs]
        if isinstance(outputs, list):
            tmp_result = tmp_result or (exec_outputs == outputs)
    except Exception as e:
        if debug:
            print(f"Failed check3 exception = {e}")
    if tmp_result:
        return True

    try:
        output_float = [float(e) for e in exec_outputs]
        gt_float = [float(e) for e in outputs]
        tmp_result = tmp_result or (
            (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)
        )
    except Exception:
        pass
    try:
        if isinstance(exec_outputs[0], list):
            output_float = [float(e) for e in exec_outputs[0]]
            gt_float = [float(e) for e in outputs[0]]
            tmp_result = tmp_result or (
                (len(output_float) == len(gt_float))
                and np.allclose(output_float, gt_float)
            )
    except Exception:
        pass
    if tmp_result:
        return True

    if isinstance(outputs, list):
        outputs = [set(i.split()) for i in outputs]
    else:
        outputs = set(outputs.split())

    try:
        tmp_result = exec_outputs == outputs
    except Exception as e:
        if debug:
            print(f"Failed check4 exception = {e}")
    if tmp_result:
        return True

    # try by converting the output into a split up list too
    if isinstance(exec_outputs, list):
        exec_outputs = [
            set(tokens) for tokens in (i.split() for i in exec_outputs) if tokens
        ]
    else:
        exec_outputs = set(exec_outputs.split())

    try:
        tmp_result = set(frozenset(s) for s in exec_outputs) == set(
            frozenset(s) for s in outputs
        )
    except Exception as e:
        if debug:
            print(f"Failed check5 exception = {e}")

    # if they are all numbers, round so that similar numbers are treated as identical
    try:
        tmp_result = tmp_result or (
            set(frozenset(round(float(t), 3) for t in s) for s in exec_outputs)
            == set(frozenset(round(float(t), 3) for t in s) for s in outputs)
        )
    except Exception as e:
        if debug:
            print(f"Failed check6 exception = {e}")

    return bool(tmp_result)


def stripped_string_compare(s1, s2):
    """Compare two strings ignoring leading and trailing whitespace."""
    return s1.strip() == s2.strip()


def reliability_guard(maximum_memory_bytes=None):
    """Disable destructive functions so generated code cannot interfere with the test.

    Targets things like fork bombs, killing other processes and removing
    filesystem files. The changes are permanent for the current process.

    Args:
        maximum_memory_bytes: If given, also cap the address-space, data and
            (except on macOS) stack limits via ``resource.setrlimit``.

    Warning:
        This function is NOT a security sandbox. Untrusted code, including
        model-generated code, should not be blindly executed outside of one.
        See the Codex paper for more information about OpenAI's code sandbox,
        and proceed with caution.
    """
    if maximum_memory_bytes is not None:
        import resource

        resource.setrlimit(
            resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
        )
        resource.setrlimit(
            resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
        )
        if not platform.uname().system == "Darwin":
            resource.setrlimit(
                resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
            )

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None
    # Set via the builtins module: ``__builtins__`` is a module (not a dict)
    # when this file runs as __main__.
    builtins.help = None

    os.kill = None
    os.system = None
    os.putenv = None
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    subprocess.Popen = None  # type: ignore

    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None


@timeout_decorator.timeout(10)
def run_test_with_timeout(problem, generation):
    """Run ``generation`` on ``problem``'s tests; True only if every test passed.

    NOTE(review): timeout_decorator's signal mode presumably arms the same
    SIGALRM timer that the ``signal.alarm`` calls in ``compile_and_get_func``
    and ``execute_cb_code`` re-arm and clear, so this 10 s budget may be
    cancelled by them. Confirm; the per-step timeouts still apply.
    """
    try:
        result = run_test(problem, test=generation, debug=False)
        # run_test reports failures as -1/-2/-3, which are truthy; only an
        # explicit True counts as a pass.
        return bool(result) and all(r is True for r in result)
    except Exception as e:
        print(f"Exception in run_test_with_timeout: {e}")
        return False


def check_correctness(problem, generation):
    """Check if the code is correct, treating timeouts and errors as incorrect."""
    try:
        return run_test_with_timeout(problem, generation)
    except timeout_decorator.TimeoutError:
        print("Test execution timed out")
        return False
    except Exception as e:
        print(f"Error in check_correctness: {e}")
        return False


def has_code(response: str) -> list:
    """Check if the response contains code blocks.

    Args:
        response (str): The text response to check

    Returns:
        list: Bodies of the fenced (```) code blocks found in the response
    """
    pattern = r"```(?:[a-zA-Z]*)\n(.*?)```"
    return re.findall(pattern, response, re.DOTALL)


def process_single_row(row: dict) -> dict:
    """Evaluate the last code block of a row's ``deepseek_solution``.

    Args:
        row (dict): Dataset row containing the solution and its test cases

    Returns:
        dict: The row with ``correctness`` (bool) and ``reason`` (str) set
    """
    try:
        code_blocks = has_code(row.get("deepseek_solution", ""))

        if not code_blocks:
            return {
                **row,
                "correctness": False,
                "reason": "Does not contain code component.",
            }

        last_code = code_blocks[-1]
        if check_correctness(row, last_code):
            row["correctness"] = True
            row["reason"] = ""
        else:
            row["correctness"] = False
            row["reason"] = "Code is incorrect."

        return row
    except Exception as e:
        return {**row, "correctness": False, "reason": f"Processing error: {str(e)}"}


def process_dataset_parallel(
    df: Dataset, num_cpus: int = None, batch_size: int = 2048
) -> Dataset:
    """Process the dataset in parallel using multiple CPUs.

    Args:
        df (Dataset): Input dataset to process
        num_cpus (int, optional): Number of CPUs to use. Defaults to max CPUs - 1
        batch_size (int, optional): Size of each processing batch. Defaults to 2048

    Returns:
        Dataset: Processed dataset with correctness evaluations
    """
    if num_cpus is None:
        num_cpus = max(1, multiprocessing.cpu_count() - 1)

    data = df.to_list()
    total_rows = len(data)
    print(f"Processing {total_rows} rows using {num_cpus} CPUs...")

    all_results = []
    total_correct = 0
    # maxtasksperchild=1: each row gets a fresh worker, because
    # execute_cb_code calls reliability_guard(), which permanently disables
    # os/subprocess/shutil functions in the process that runs it.
    with Pool(processes=num_cpus, maxtasksperchild=1) as pool:
        for i in range(0, total_rows, batch_size):
            batch = data[i : i + batch_size]
            batch_number = i // batch_size + 1
            # imap (unlike map) yields as rows finish, so tqdm shows real
            # progress; result order still matches the batch order.
            batch_results = list(
                tqdm(
                    pool.imap(process_single_row, batch),
                    total=len(batch),
                    desc=f"Processing batch {batch_number}",
                )
            )
            all_results.extend(batch_results)

            # Calculate and print statistics for this batch
            batch_correct = sum(
                1 for r in batch_results if r.get("correctness", False)
            )
            total_correct += batch_correct
            print(f"\nBatch {batch_number} Results:")
            print(f"Processed examples: {len(all_results)}/{total_rows}")
            print(
                f"Correct in this batch: {batch_correct}/{len(batch_results)} ({batch_correct / len(batch_results) * 100:.2f}%)"
            )
            print(f"Total correct so far: {total_correct}/{len(all_results)}\n")

    return Dataset.from_list(all_results)