{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "import json \n",
    "import glob\n",
    "from tqdm import tqdm \n",
    "import re\n",
    "import ast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 7115 in /largespace/tydata/code_optimization/cpp/dataset/cpp_binary/val_out/\n"
     ]
    }
   ],
   "source": [
    "test_out_store_path = \"/largespace/tydata/code_optimization/cpp/dataset/cpp_binary/val_out/\"\n",
    "cpp_out_files = glob.glob(test_out_store_path + \"*.out\")\n",
    "print(f\"There are {len(cpp_out_files)} in {test_out_store_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 117 problems in val out.\n"
     ]
    }
   ],
   "source": [
    "test_benchmark_store_path = \"/largespace/tydata/code_optimization/cpp/dataset/benchmark_gem5_testcases3/val_out\"\n",
    "problems_path = os.listdir(test_benchmark_store_path)\n",
    "print(f\"There are {len(problems_path)} problems in val out.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_answer_correct(input_case_path, stdout_bin):\n",
    "    \"\"\"\n",
    "    check stdout_bin is correct \n",
    "    \"\"\"\n",
    "    output_case_path = input_case_path.replace(\"input\", \"output\")\n",
    "    with open(output_case_path, 'r') as g:\n",
    "        truth = g.read().strip()\n",
    "    \n",
    "    ground_truth_lines = truth.strip().splitlines()\n",
    "    output_lines = stdout_bin.strip().splitlines()\n",
    "\n",
    "    IsCorrect = True\n",
    "    for gen_output, ground_truth_output in zip(output_lines, ground_truth_lines):\n",
    "        is_corr = gen_output == ground_truth_output\n",
    "        if not is_corr:\n",
    "            try:\n",
    "                gen_output = float(gen_output)\n",
    "                ground_truth_output = float(ground_truth_output)\n",
    "                is_corr = abs(gen_output - ground_truth_output) < 1e-3\n",
    "            except:\n",
    "                pass\n",
    "        \n",
    "        if not is_corr:\n",
    "            IsCorrect = False\n",
    "    \n",
    "    return IsCorrect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_postprocess(test_binary_benchmark_dir):\n",
    "    # for each test binary benchmark dir\n",
    "    benchmark_results_path = \"testcases_3_benchmark_results.json\"\n",
    "    with open(os.path.join(test_binary_benchmark_dir, benchmark_results_path), 'r') as f:\n",
    "        results = json.load(f)\n",
    "    \n",
    "    final_result = {\n",
    "        \"binary_exec_right_and_gem5_right\":{\"testcases_id\": [],\"binary_exec_right_and_gem5_right_and_answer_correct\":[], \"binary_exec_right_and_gem5_right_and_answer_wrong\":[]},\n",
    "        \"binary_exec_right_and_gem5_wrong\":[],\n",
    "        \"binary_exec_wrong_and_gem5_right\":[],\n",
    "        \"binary_exec_wrong_and_gem5_wrong\":[],\n",
    "        \"testcases_number\": 0\n",
    "    }\n",
    "\n",
    "    for testcase_result in results:\n",
    "        returncode_bin = testcase_result[\"returncode_bin\"]\n",
    "        returncode_gem5 = testcase_result[\"returncode_gem5\"]\n",
    "        test_case_id = int(testcase_result[\"test_case_id\"])\n",
    "        \n",
    "        if returncode_bin == 0 and returncode_gem5 == 0:\n",
    "            final_result[\"binary_exec_right_and_gem5_right\"][\"testcases_id\"].append(test_case_id)\n",
    "            Is_answer_correct = is_answer_correct(testcase_result[\"input_case_path\"], testcase_result[\"stdout_bin\"])\n",
    "            if Is_answer_correct:\n",
    "                final_result[\"binary_exec_right_and_gem5_right\"][\"binary_exec_right_and_gem5_right_and_answer_correct\"].append(test_case_id)\n",
    "            else:\n",
    "                final_result[\"binary_exec_right_and_gem5_right\"][\"binary_exec_right_and_gem5_right_and_answer_wrong\"].append(test_case_id)\n",
    "        elif returncode_bin == 0 and returncode_gem5 != 0:\n",
    "            final_result[\"binary_exec_right_and_gem5_wrong\"].append(test_case_id)\n",
    "        elif returncode_bin != 0 and returncode_gem5 == 0:\n",
    "            final_result[\"binary_exec_wrong_and_gem5_right\"].append(test_case_id)\n",
    "        else:\n",
    "            final_result[\"binary_exec_wrong_and_gem5_wrong\"].append(test_case_id)\n",
    "        \n",
    "        final_result[\"testcases_number\"] += 1\n",
    "\n",
    "    with open(os.path.join(test_binary_benchmark_dir, \"analysis_result.json\"), 'w') as gg:\n",
    "        json.dump(final_result, gg, indent=4)\n",
    "\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 117/117 [00:12<00:00,  9.03it/s]\n"
     ]
    }
   ],
   "source": [
    "test_benchmark_store_path = \"/largespace/tydata/code_optimization/cpp/dataset/benchmark_gem5_testcases3/val_out\"\n",
    "problems = os.listdir(test_benchmark_store_path)\n",
    "# test_binary_benchmark_dir = os.path.join(test_benchmark_store_path, \"p01717\", \"u923320778\", \"s651875974\")\n",
    "# benchmark_postprocess(test_binary_benchmark_dir)\n",
    "\n",
    "for problem in tqdm(problems):\n",
    "    problem_path = os.path.join(test_benchmark_store_path, problem)\n",
    "    users = os.listdir(problem_path)\n",
    "    for user in users:\n",
    "        user_path = os.path.join(test_benchmark_store_path, problem, user)\n",
    "        submissions = os.listdir(user_path)\n",
    "        for submission in submissions:\n",
    "            submission_path = os.path.join(test_benchmark_store_path, problem, user, submission)\n",
    "            test_binary_benchmark_dir = submission_path\n",
    "            benchmark_postprocess(test_binary_benchmark_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def result_statistics(analysis_result):\n",
    "    with open(analysis_result, 'r') as f:\n",
    "        data = json.load(f)\n",
    "    \n",
    "    count = len(data[\"binary_exec_right_and_gem5_right\"][\"binary_exec_right_and_gem5_right_and_answer_correct\"])\n",
    "    if count == data[\"testcases_number\"]:\n",
    "        return True\n",
    "    else:\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 117/117 [00:00<00:00, 590.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "count = 7115\n",
      "all right count = 5706\n"
     ]
    }
   ],
   "source": [
    "count = 0\n",
    "count_all_right = 0\n",
    "all_right_binary = []\n",
    "for problem in tqdm(problems):\n",
    "    problem_path = os.path.join(test_benchmark_store_path, problem)\n",
    "    users = os.listdir(problem_path)\n",
    "    for user in users:\n",
    "        user_path = os.path.join(test_benchmark_store_path, problem, user)\n",
    "        submissions = os.listdir(user_path)\n",
    "        for submission in submissions:\n",
    "            count += 1\n",
    "            submission_path = os.path.join(test_benchmark_store_path, problem, user, submission)\n",
    "            result = result_statistics(os.path.join(submission_path, \"analysis_result.json\"))\n",
    "            if result:\n",
    "                count_all_right += 1\n",
    "                binary = os.path.join(\"/largespace/tydata/code_optimization/cpp/dataset/cpp_binary/val_out\", f\"{problem}_{submission}_{user}.out\")\n",
    "                all_right_binary.append(binary)\n",
    "\n",
    "print(f\"count = {count}\")\n",
    "print(f\"all right count = {count_all_right}\")\n",
    "with open(os.path.join(\"/largespace/tydata/code_optimization/cpp/dataset/benchmark_gem5_testcases3/val_out_all_right.txt\"), 'w') as f:\n",
    "    for item in all_right_binary:\n",
    "        f.write(item + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_sim_seconds(stats):\n",
    "    # more accurate than sim_seconds\n",
    "    return float(stats[\"sim_ticks\"]) / float(stats[\"sim_freq\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_stats_txt(gem5_stats_path):\n",
    "    with open(gem5_stats_path, 'r') as f:\n",
    "        stats_lines = f.readlines()\n",
    "    \n",
    "    stats = {}\n",
    "    for line in stats_lines:\n",
    "        if line.strip() == '':\n",
    "            continue \n",
    "        if \"Begin\" in line:\n",
    "            continue\n",
    "        if \"End\" in line:\n",
    "            continue\n",
    "        line = re.sub(\"#.*\", \"\", line).strip() # remove comments\n",
    "        parts = line.split()\n",
    "        parts = [part.strip() for part in parts]\n",
    "        if len(parts) > 2:\n",
    "            value = parts[1:]\n",
    "        elif len(parts) == 2:\n",
    "            value = parts[1]\n",
    "        else:\n",
    "            print(f\"could not parse line {line}\")\n",
    "            continue\n",
    "        key = parts[0]\n",
    "        if isinstance(value, str):\n",
    "            try:\n",
    "                value = value.replace(\"%\", \"\").replace(\"nan\", \"None\").replace(\"inf\", \"None\").replace(\"-inf\", \"None\")\n",
    "                value = ast.literal_eval(value) if value != \"None\" else None\n",
    "            except:\n",
    "                print(f\"could not parse value {value} for key {key}\")\n",
    "        elif isinstance(value, list):\n",
    "            try:\n",
    "                value = [v.replace(\"%\", \"\").replace(\"nan\", \"None\").replace(\"inf\", \"None\").replace(\"-inf\", \"None\") for v in value]\n",
    "                value = [ast.literal_eval(v) if v!= \"None\" else None for v in value]\n",
    "            except:\n",
    "                print(f\"could not parse value {value} for key {key}\")\n",
    "        \n",
    "        stats[key] = value\n",
    "    stats[\"sim_seconds_precise\"] = calculate_sim_seconds(stats)\n",
    "    return stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_average_time(binary):\n",
    "    bin_file_path = os.path.join(binary)\n",
    "\n",
    "    pattern_problem_id = r\"out/(p\\d+)\"\n",
    "    pattern_submission_id = r's\\d+'\n",
    "    pattern_user_id = r'u\\d+'\n",
    "\n",
    "    match_problem_id = re.findall(pattern_problem_id, bin_file_path)[0]\n",
    "    match_submission_id = re.findall(pattern_submission_id, bin_file_path)[0]\n",
    "    match_user_id = re.findall(pattern_user_id, bin_file_path)[0]\n",
    "\n",
    "    base_path = \"/largespace/tydata/code_optimization/cpp/dataset/benchmark_gem5_testcases3/val_out\"\n",
    "    testcases_path = os.path.join(base_path, f\"{match_problem_id}\", f\"{match_user_id}\", f\"{match_submission_id}/\")\n",
    "    \n",
    "    sim_seconds_precise_all = []\n",
    "    gem5_stats = glob.glob(testcases_path + \"gem5_stats.*.txt\")\n",
    "    for gem5_stat in gem5_stats:\n",
    "        stats = parse_stats_txt(gem5_stat)\n",
    "        sim_seconds_precise = stats[\"sim_seconds_precise\"]\n",
    "        sim_seconds_precise_all.append(sim_seconds_precise)\n",
    "\n",
    "    if len(sim_seconds_precise_all) != 0:\n",
    "        return sum(sim_seconds_precise_all) / len(sim_seconds_precise_all)\n",
    "    else:\n",
    "        return 820"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5706/5706 [04:29<00:00, 21.16it/s]\n"
     ]
    }
   ],
   "source": [
    "with open(os.path.join(\"/largespace/tydata/code_optimization/cpp/dataset/benchmark_gem5_testcases3/val_out_all_right.txt\"), 'r') as f:\n",
    "    binarys = f.read().splitlines()\n",
    "\n",
    "results = {}\n",
    "for binary in tqdm(binarys):\n",
    "    average_sim_seconds_precise = get_average_time(binary)\n",
    "\n",
    "    pattern_problem_id = r\"out/(p\\d+)\"\n",
    "    pattern_submission_id = r's\\d+'\n",
    "    pattern_user_id = r'u\\d+'\n",
    "\n",
    "    match_problem_id = re.findall(pattern_problem_id, binary)[0]\n",
    "    match_submission_id = re.findall(pattern_submission_id, binary)[0]\n",
    "    match_user_id = re.findall(pattern_user_id, binary)[0]\n",
    "\n",
    "    if match_problem_id in results:\n",
    "        if match_user_id in results[match_problem_id]:\n",
    "            if match_submission_id in results[match_problem_id][match_user_id]:\n",
    "                print(\"Error.\")\n",
    "            else:\n",
    "                results[match_problem_id][match_user_id][match_submission_id] = average_sim_seconds_precise\n",
    "        else:\n",
    "            results[match_problem_id][match_user_id] = {}\n",
    "            results[match_problem_id][match_user_id][match_submission_id] = average_sim_seconds_precise\n",
    "    else:\n",
    "        results[match_problem_id] = {}\n",
    "        results[match_problem_id][match_user_id] = {}\n",
    "        results[match_problem_id][match_user_id][match_submission_id] = average_sim_seconds_precise\n",
    "\n",
    "\n",
    "with open(os.path.join(\"/largespace/tydata/code_optimization/cpp/dataset/benchmark_gem5_testcases3/val_out_times.json\"), 'w') as f:\n",
    "    json.dump(results, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
