{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RO4xApzF25XN"
      },
      "source": [
        "# Install & Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# !pip install tree-sitter tree-sitter-c==0.23.4 tree-sitter-cpp==0.23.4 -q\n",
        "# !pip install openai anthropic -q\n",
        "# !pip install torch==2.5.1\n",
        "# !pip install transformers==4.47.0\n",
        "# !pip install accelerate==1.3.0   # for automatic GPU distribution if using multi-GPU"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8VO1Pz9X3INt"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import re\n",
        "import json\n",
        "import time\n",
        "import argparse\n",
        "import pandas as pd\n",
        "from pathlib import Path"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use this to load api keys from files or set in environment otherwise."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "Nlksybf3A6Ay"
      },
      "outputs": [],
      "source": [
        "with open(\"creds/openai_api_key.json\") as oai_fl:\n",
        "\tos.environ[\"OPENAI_API_KEY\"] = json.load(oai_fl)[\"api_key\"]\n",
        "with open(\"creds/anthropic_api_key.json\") as ant_fl:\n",
        "\tos.environ[\"ANTHROPIC_API_KEY\"] = json.load(ant_fl)[\"api_key\"]\n",
        "with open(\"creds/hf_access_token.json\") as hf_fl:\n",
        "\tos.environ[\"HF_TOKEN\"] = json.load(hf_fl)[\"api_key\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "ZkBIZA0FAres"
      },
      "outputs": [],
      "source": [
        "if \"OPENAI_API_KEY\" not in os.environ:\n",
        "    raise ValueError(\"OPENAI_API_KEY not set\")\n",
        "if \"ANTHROPIC_API_KEY\" not in os.environ:\n",
        "    raise ValueError(\"OPENAI_API_KEY not set\")\n",
        "if \"HF_TOKEN\" not in os.environ:\n",
        "    raise ValueError(\"HF_TOKEN not set\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "DYvt1UE88NtX"
      },
      "outputs": [],
      "source": [
        "from agents.normalization_agent import NormalizationAgent\n",
        "from agents.planning_agent      import PlanningAgent\n",
        "from agents.context_agent       import ContextAgent\n",
        "from agents.symbol_backend      import SymbolBackend\n",
        "from agents.detection_agent\t import DetectionAgent\n",
        "from agents.validation_agent\t import ValidationAgent\n",
        "from load_llm import get_tokenizer, get_model, get_pipe\n",
        "from utils import get_json_file_as_dict, save_dict_as_json, get_json_lines_as_list, save_json_lines"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DghowlEs3hLj"
      },
      "source": [
        "# Vul Detect"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "#### CWE ID and Names\n",
        "Prepare CWE ID and Name for record keeping in Norm Results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "yjcP19ly53Mu"
      },
      "outputs": [],
      "source": [
        "cwe_details = pd.read_csv(\"cwe_details.csv\", index_col=False)\n",
        "cwe_details = cwe_details.set_index(\"CWE-ID\")\n",
        "cwe_details.index = cwe_details.index.astype(str).str.strip().map(lambda x: f\"CWE-{x}\")\n",
        "cwe_details = cwe_details.to_dict(orient=\"index\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Prepare Runtime Values\n",
        "Set the parameters for this experiment. This emulates a cli args interface for easy portability to `.py`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0hA7A-_pDp3z"
      },
      "outputs": [],
      "source": [
        "# parser = argparse.ArgumentParser()\n",
        "# parser.add_argument(\"--model_id\", type=str, required=True, help=\"LLM to use. Use huggingface model_id or any of the following [\\\"gpt-4.1\\\", \\\"claude-3-7-sonnet-20250219\\\", \\\"claude-3-7-sonnet-latest\\\"]\")\n",
        "# parser.add_argument(\"--norm_file\", type=str, help=\"File path to use already normalized file.\", default=None)\n",
        "# parser.add_argument(\"--run_norm\", action=\"store_true\", help=\"Whether to run normalization agent\")\n",
        "# parser.add_argument(\"--dataset\", type=str, help=\"Path to the JSON subset for vulnerability detection.\", default=None)\n",
        "\n",
        "# args = parser.parse_args()\n",
        "\n",
        "args = argparse.Namespace(\n",
        "    model_id=\"claude-3-7-sonnet-20250219\",\n",
        "    norm_file=None,\n",
        "    run_norm=True,\n",
        "    dataset=\"random_subset.json\"\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Normalization Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Use this if want to load a previous normalization output. This is a deterministic step so should be same for all models/configurations.\n",
        "\n",
        "Otherwise execute the cell below to run the Normalization Agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 144
        },
        "id": "LLUEcFcUSeL6",
        "outputId": "b6f82216-2eb7-43ff-ca1b-5592a8c7eaa7"
      },
      "outputs": [],
      "source": [
        "norm_agent_otpt = get_json_lines_as_list(args.norm_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 108
        },
        "id": "jVlux7rWDbdO",
        "outputId": "b31be515-bb77-4f3c-a9b6-08ba28b05169"
      },
      "outputs": [],
      "source": [
        "if not args.norm_file and not args.run_norm:\n",
        "\traise ValueError(\"Please provide either --norm_file or pass --run_norm option.\")\n",
        "if args.run_norm and not args.dataset:\n",
        "    raise ValueError(\"Please provide --dataset option when using --run_norm.\")\n",
        "if args.norm_file:\n",
        "    norm_agent_otpt = get_json_file_as_dict(args.norm_file)\n",
        "\n",
        "if args.run_norm:\n",
        "    norm = NormalizationAgent()\n",
        "    norm_agent_otpt = []\n",
        "    random_subset = get_json_file_as_dict(args.dataset)\n",
        "    for cwe_id, sample_list in random_subset.items():\n",
        "        if cwe_id not in cwe_details:\n",
        "            print(f\"Skipping unsupported CWE ID: {cwe_id}\")\n",
        "            continue\n",
        "        vuln_type = f\"{cwe_id} - {cwe_details[cwe_id][\"Name\"]}\"\n",
        "        for sample in sample_list:\n",
        "            repo_path = sample[\"project_repo_path\"]\n",
        "            commit_id = sample[\"commit_id\"]\n",
        "            func_body = sample[\"func_body\"]\n",
        "            filepath = sample[\"filepath\"]\n",
        "            is_vulnerable = sample[\"is_vulnerable\"]\n",
        "\n",
        "            ext = Path(filepath).suffix[1:]\n",
        "            if ext in ['c']:\n",
        "                file_type = \"c\"\n",
        "            elif ext in ['cpp', 'cc', 'cxx', 'C']:\n",
        "                file_type = \"cpp\"\n",
        "            else:\n",
        "                print(f\"Skipping unsupported file extension: {ext}\")\n",
        "                continue\n",
        "\n",
        "            norm_agent_otpt.append(dict(norm_result=dict(\n",
        "                repo_path=repo_path,\n",
        "                commit_id=commit_id,\n",
        "                filepath=filepath,\n",
        "                is_vulnerable=is_vulnerable,\n",
        "                **norm.run(func_body, file_type=file_type, vuln_type=vuln_type)\n",
        "            )))\n",
        "        os.makedirs(f\"{args.model_id.split('/')[-1]}\", exist_ok=True)\n",
        "        save_json_lines(f\"{args.model_id.split(\"/\")[-1]}/norm_agent_otpt.jsonl\", norm_agent_otpt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_WC4vVteEQPe"
      },
      "outputs": [],
      "source": [
        "# os.environ['CUDA_VISIBLE_DEVICES'] = \"2,3\"\n",
        "pipe = None\n",
        "if not (args.model_id.startswith(\"claude\") or args.model_id.startswith(\"gpt\")):\n",
        "\tpipe = get_pipe(\n",
        "\t\tget_model(args.model_id),\n",
        "\t\tget_tokenizer(args.model_id)\n",
        "\t)\n",
        "os.makedirs(args.model_id.split(\"/\")[-1], exist_ok=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Panning Agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "7kSi9YpQG-6F"
      },
      "outputs": [],
      "source": [
        "def get_plan_with_retry(norm_result, max_retries=3):\n",
        "\tfor attempt in range(1, max_retries+1):\n",
        "\t\ttry:\n",
        "\t\t\treturn plan(\n",
        "\t\t\t\tclean_code=norm_result[\"src\"],\n",
        "\t\t\t\tcompact_ast=norm_result[\"compact_ast\"][\"ast\"]\n",
        "\t\t\t)\n",
        "\t\texcept Exception as e:\n",
        "\t\t\tprint(f\"Attempt {attempt} failed: {e}\")\n",
        "\t\t\tif attempt == max_retries:\n",
        "\t\t\t\treturn None\n",
        "# --------------------- Planning ----------------------\n",
        "plan = PlanningAgent(model=args.model_id, pipe=pipe)\n",
        "plan_agent_otpt = []\n",
        "with open(f\"{args.model_id.split(\"/\")[-1]}/plan_agent_otpt.jsonl\", \"a\", encoding=\"utf-8\") as f:\n",
        "\tfor idx, norm_otpt in enumerate(norm_agent_otpt):\n",
        "\t\tplan_result = get_plan_with_retry(norm_otpt[\"norm_result\"])\n",
        "\t\tif plan_result is None:\n",
        "\t\t\tprint(f\"Planning failed for {idx} - {norm_otpt['norm_result']['repo_path']}, {norm_otpt['norm_result']['commit_id']}, {norm_otpt['norm_result']['filepath']}, {norm_otpt['norm_result']['is_vulnerable']}, {norm_otpt['norm_result']['src'][:100]}\\n\\n\")\n",
        "\t\t\tcontinue\n",
        "\t\tplan_agent_otpt.append(dict(\n",
        "\t\t\tnorm_result=norm_otpt[\"norm_result\"],\n",
        "\t\t\tplan_result=plan_result\n",
        "\t\t))\n",
        "\t\tassert isinstance(plan_agent_otpt[-1], dict)\n",
        "\t\tjson.dump(plan_agent_otpt[-1], f, ensure_ascii=False)\n",
        "\t\tf.write(\"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Context Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you want to start running from the output of previous Planning agent, load planning agent output from this cell.\n",
        "\n",
        "Otherwise go to the next cell directly to run the Context Agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "plan_agent_otpt = get_json_lines_as_list(f\"{args.model_id.split('/')[-1]}/plan_agent_otpt.jsonl\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "exz0yKbtL2gt"
      },
      "outputs": [],
      "source": [
        "# ---------------- Context Extraction -----------------\n",
        "cntxt_agent_otpt = []\n",
        "with open(f\"{args.model_id.split(\"/\")[-1]}/cntxt_agent_otpt.jsonl\", \"a\", encoding=\"utf-8\") as f:\n",
        "\tfor plan_otpt in plan_agent_otpt:\n",
        "\t\tbackend = SymbolBackend(repo=plan_otpt[\"norm_result\"][\"repo_path\"], commit=plan_otpt[\"norm_result\"][\"commit_id\"])\n",
        "\t\tcontext_agent = ContextAgent(\n",
        "\t\t\tmodel=args.model_id,\n",
        "\t\t\tbackend=backend,\n",
        "\t\t\tpipe=pipe,\n",
        "\t\t)\n",
        "\t\tctx_result = context_agent(\n",
        "\t\t\tclean_code=plan_otpt[\"norm_result\"][\"src\"],\n",
        "\t\t\tcompact_ast=plan_otpt[\"norm_result\"][\"compact_ast\"][\"ast\"],\n",
        "\t\t\tplanning_plan=plan_otpt[\"plan_result\"]\n",
        "\t\t)\n",
        "\t\tcntxt_agent_otpt.append(dict(\n",
        "\t\t\tnorm_result=plan_otpt[\"norm_result\"],\n",
        "\t\t\tplan_result=plan_otpt[\"plan_result\"],\n",
        "\t\t\tctx_result=ctx_result\n",
        "\t\t))\n",
        "\t\tassert isinstance(cntxt_agent_otpt[-1], dict)\n",
        "\t\tjson.dump(cntxt_agent_otpt[-1], f, ensure_ascii=False)\n",
        "\t\tf.write(\"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Detection and Validation Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you want to start running from the output of previous Context agent, load context agent output from this cell.\n",
        "\n",
        "Otherwise go to the next cell directly to run the Detection and Validation Agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [],
      "source": [
        "cntxt_agent_otpt = get_json_lines_as_list(f\"{args.model_id.split('/')[-1]}/cntxt_agent_otpt.jsonl\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "mdW0gmelMNpp"
      },
      "outputs": [],
      "source": [
        "def run_detection_with_validation(\n",
        "    detection_agent,                    # instance of your DetectionAgent\n",
        "    validation_agent,                   # instance of ValidationAgent\n",
        "    *,\n",
        "    det_inputs: dict,                   # kwargs for detection_agent(...)\n",
        "    val_extra: dict = {},               # kwargs *extra* for ValidationAgent\n",
        "):\n",
        "    \"\"\"\n",
        "    1) Detection → Validation\n",
        "    2) If disagree → retry Detection twice → re-validate\n",
        "    3) Return Validation's verdict (second round if retried)\n",
        "    \"\"\"\n",
        "    sleep_time = 0\n",
        "    if args.model_id.startswith(\"gpt\"):\n",
        "        sleep_time = 0.25\n",
        "    if args.model_id.startswith(\"claude\"):\n",
        "        sleep_time = 1.5\n",
        "\n",
        "    # round 1\n",
        "    time.sleep(sleep_time)\n",
        "    det = detection_agent(**det_inputs)\n",
        "    val = validation_agent(detection_out=det, **det_inputs, **val_extra)\n",
        "\n",
        "    if val[\"agree\"]:\n",
        "        return det, val                     # consensus — done!\n",
        "    print(\"Detection and Validation disagree, retrying Detection...\")\n",
        "\n",
        "    # round 2\n",
        "    time.sleep(sleep_time)\n",
        "    det = detection_agent(**det_inputs)    # retry once\n",
        "    val = validation_agent(detection_out=det, **det_inputs, **val_extra)\n",
        "\n",
        "    if val[\"agree\"]:\n",
        "        return det, val                     # consensus — done!\n",
        "    print(\"Detection and Validation disagree again, last retry...\")\n",
        "\n",
        "    # round 2\n",
        "    time.sleep(sleep_time)\n",
        "    det = detection_agent(**det_inputs)    # retry twice\n",
        "    val = validation_agent(detection_out=det, **det_inputs, **val_extra)\n",
        "    # regardless of agree flag, Validation wins third time\n",
        "    return det, val"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "RJWuB1BTNaGU"
      },
      "outputs": [],
      "source": [
        "# ---------------- Vulnerability Detection -----------------\n",
        "det = DetectionAgent(model=args.model_id, pipe=pipe)\n",
        "val = ValidationAgent(model=args.model_id, pipe=pipe)\n",
        "detection_otpt = []\n",
        "validation_otpt = []\n",
        "with open(f\"{args.model_id.split(\"/\")[-1]}/val_agent_otpt.jsonl\", \"a\", encoding=\"utf-8\") as f:\n",
        "\tfor cntxt_otpt in cntxt_agent_otpt:\n",
        "\t\tdet_res, val_res = run_detection_with_validation(det, val, det_inputs=dict(\n",
        "\t\t\t\tclean_code=cntxt_otpt[\"norm_result\"][\"src\"],\n",
        "\t\t\t\tcompact_ast=cntxt_otpt[\"norm_result\"][\"compact_ast\"][\"ast\"],\n",
        "\t\t\t\tsummary=cntxt_otpt[\"plan_result\"][\"summary\"],\n",
        "\t\t\t\tchecklist=cntxt_otpt[\"plan_result\"][\"checklist\"],\n",
        "\t\t\t\tcontext=cntxt_otpt[\"ctx_result\"]\n",
        "\t\t\t))\n",
        "\t\tvalidation_otpt.append(dict(\n",
        "\t\t\tnorm_result=cntxt_otpt[\"norm_result\"],\n",
        "\t\t\tplan_result=cntxt_otpt[\"plan_result\"],\n",
        "\t\t\tctx_result=cntxt_otpt[\"ctx_result\"],\n",
        "\t\t\tdetect_result=det_res,\n",
        "\t\t\tvalidate_result=val_res\n",
        "\t\t))\n",
        "\t\tassert isinstance(validation_otpt[-1], dict)\n",
        "\t\tjson.dump(validation_otpt[-1], f, ensure_ascii=False)\n",
        "\t\tf.write(\"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Output Formatting"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Convert final output to Excel for manual validation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# !pip install openpyxl"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {},
      "outputs": [],
      "source": [
        "from openpyxl import Workbook\n",
        "\n",
        "# Create a new workbook and select the active worksheet\n",
        "wb = Workbook()\n",
        "ws = wb.active\n",
        "\n",
        "# Define the column headers\n",
        "headers = [\n",
        "    \"commit_id\", \"filepath\", \"is_vulnerable\", \"project\", \"src_snippet\", \n",
        "    \"vuln_type\", \"validate_is_vulnerable\", \"validate_vuln_statements\"\n",
        "]\n",
        "ws.append(headers)  # Add headers to the first row\n",
        "\n",
        "# Iterate over validation_otpt and extract the required fields\n",
        "for entry in validation_otpt:\n",
        "    norm_result = entry.get(\"norm_result\")\n",
        "    validate_result = entry.get(\"validate_result\")\n",
        "    \n",
        "    # Extract fields from norm_result\n",
        "    commit_id = norm_result.get(\"commit_id\")\n",
        "    filepath = norm_result.get(\"filepath\")\n",
        "    is_vulnerable = norm_result.get(\"is_vulnerable\")\n",
        "    project = norm_result.get(\"project\")\n",
        "    src_snippet = norm_result.get(\"src\")[:50]  # First 50 characters of src\n",
        "    vuln_type = norm_result.get(\"vuln_type\")\n",
        "    \n",
        "    # Extract fields from validate_result\n",
        "    # print(validate_result)\n",
        "    # break\n",
        "    validate_is_vulnerable = validate_result.get(\"is_vulnerable\")\n",
        "    vuln_statements = validate_result.get(\"vuln_statements\")\n",
        "    \n",
        "    # Process vuln_statements to format as \"statement-reason\" pairs\n",
        "    formatted_statements = \"\\n\\n\".join(\n",
        "        f\"{item.get('statement')} - {item.get('reason')}\" for item in vuln_statements\n",
        "    )\n",
        "    \n",
        "    # Append the row to the worksheet\n",
        "    ws.append([\n",
        "        commit_id, filepath, is_vulnerable, project, src_snippet, \n",
        "        vuln_type, validate_is_vulnerable, formatted_statements\n",
        "    ])\n",
        "\n",
        "# Save the workbook to a file\n",
        "wb.save(f\"{args.model_id.split(\"/\")[-1]}_output.xlsx\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "bzytHm-r9POx",
        "RO4xApzF25XN",
        "ArWbCraV3Qaa",
        "UifMz8-x3SDo"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "lin-data-proc",
      "language": "python",
      "name": "lin-data-proc"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
