{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Jackal JQL Query Generator (Anonymized Version)\n",
    "\n",
    "This script constructs valid compound JQL queries from a Jira schema,\n",
    "validates them against a Jira instance, and saves subsets of queries\n",
    "for use in Jackal benchmarking. No LLM calls or API keys are required.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import random\n",
    "import asyncio\n",
    "import aiohttp\n",
    "import ssl\n",
    "import urllib\n",
    "import pandas as pd\n",
    "#import time\n",
    "from pathlib import Path\n",
    "from typing import List, Dict, Any, Tuple\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()\n",
    "\n",
    "BASE_URL = \"https://<jira-instance-url>\"\n",
    "# Note: for anonymity, actual Jira endpoints are omitted.\n",
    "\n",
    "SEARCH_URL = f\"{BASE_URL}/rest/api/2/search\"\n",
    "HEADERS = {\"Accept\": \"application/json\"}\n",
    "PARAMS_BASE = {\"maxResults\": 0, \"fields\": \"\"}\n",
    "\n",
    "MAX_CLAUSES = 5\n",
    "MIN_CLAUSES = 2\n",
    "BATCH_SIZE = 300\n",
    "TARGET_VALID_QUERY_COUNT = 100\n",
    "CERT_FILE = Path(\"utils/cacert.pem\")  # optional; replace with your local cert if needed\n",
    "FIELD_SCHEMA_PATH = Path(\"input/input_field_values.json\") # replace with your local path to field schema\n",
    "OUTPUT_PATH = Path(\"output/new_25k_compound.parquet\")\n",
    "\n",
    "multi_value_groups = {\n",
    "    \"Issue Type\": [[\"Epic\", \"User Story\", \"Task\", \"Sub-task\"],\n",
    "                   [\"Improvement\", \"Change Request\", \"New Feature\"], [\"Research\", \"Suggestion\"],\n",
    "                   [\"Requirement\", \"Risk\", \"Test\"]],\n",
    "    \"Priority\": [[\"P0: Blocker\", \"P1: Critical\"], [\"P2: Important\", \"P3: Somewhat important\"],\n",
    "                 [\"P4: Low\", \"P5: Not important\"], [\"P0: Blocker\", \"P1: Critical\", \"P2: Important\"],\n",
    "                 [\"P3: Somewhat important\", \"P4: Low\", \"P5: Not important\"]],\n",
    "    \"Resolution\": [[\"Fixed\", \"Done\"], [\"Won't Do\", \"Out of scope\", \"Invalid\"],\n",
    "                   [\"Duplicate\", \"Incomplete\", \"Cannot Reproduce\"]],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ssl_context(cert_path: Path) -> ssl.SSLContext:\n",
    "    return ssl.create_default_context(cafile=str(cert_path))\n",
    "\n",
    "ssl_context = get_ssl_context(CERT_FILE)\n",
    "\n",
    "def load_field_schema(path: Path) -> Dict[str, Any]:\n",
    "    with open(path) as f:\n",
    "        return json.load(f)\n",
    "\n",
    "def get_clause_name(field_data: Dict[str, Any]) -> str:\n",
    "    return field_data['clauseNames'][0]\n",
    "\n",
    "\n",
    "def make_clause(field_name, field_data):\n",
    "    clause = get_clause_name(field_data)\n",
    "    ftype = field_data[\"type\"]\n",
    "    if field_data.get(\"emptyAllowed\", False) and random.random() < 0.2:\n",
    "        return f'{clause} is not EMPTY' if random.random() < 0.5 else f'{clause} is EMPTY'\n",
    "    if field_name in multi_value_groups:\n",
    "        if random.random() < 0.2:\n",
    "            values = random.choice(multi_value_groups[field_name])\n",
    "        else:\n",
    "            values = [random.choice(field_data[\"values\"])]\n",
    "    else:\n",
    "        values = [random.choice(field_data[\"values\"])]\n",
    "\n",
    "    if ftype in [\"categorical\", \"special/categorical\"]:\n",
    "        value_clause = \", \".join(f'\"{v}\"' for v in values)\n",
    "        return f'{clause} in ({value_clause})'\n",
    "    elif ftype == \"text_search\":\n",
    "        return f'{clause} ~ \"{values[0]}\"'\n",
    "    elif ftype == \"date/relative_date\":\n",
    "        return values[0]\n",
    "\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_batch(\n",
    "    fields: List[Tuple[str, Dict[str, Any]]],\n",
    "    valid_queries: Dict[int, List[Dict[str, Any]]],\n",
    "    batch_size: int,\n",
    "    per_clause_target: int,\n",
    "    seen_jqls: set,\n",
    ") -> List[Dict[str, Any]]:\n",
    "    batch: List[Dict[str, Any]] = []\n",
    "\n",
    "    while len(batch) < batch_size:\n",
    "        for clause_count in range(2, 6):\n",
    "            if len(valid_queries[clause_count]) >= per_clause_target:\n",
    "                continue\n",
    "\n",
    "            random.shuffle(fields)\n",
    "            selected_fields = fields[:clause_count]\n",
    "\n",
    "            clauses = []\n",
    "            for fname, fdata in selected_fields:\n",
    "                clause = make_clause(fname, fdata)\n",
    "                if clause:\n",
    "                    clauses.append({\"field_name\": fname, \"jql\": clause})\n",
    "\n",
    "\n",
    "            if not is_valid_clause_set(clauses):\n",
    "                continue\n",
    "\n",
    "            jql_str = \" AND \".join(c[\"jql\"] for c in clauses)\n",
    "\n",
    "            if jql_str in seen_jqls:\n",
    "                continue\n",
    "\n",
    "            seen_jqls.add(jql_str)\n",
    "            batch.append({\n",
    "                \"jql\": jql_str,\n",
    "                \"field_name\": clauses[0]['field_name'],\n",
    "                \"field_names\": [c[\"field_name\"] for c in clauses],\n",
    "                \"clause_count\": clause_count,\n",
    "            })\n",
    "\n",
    "            if len(batch) >= batch_size:\n",
    "                break\n",
    "\n",
    "    return batch\n",
    "\n",
    "def is_valid_clause_set(clauses: List[Dict[str, str]]) -> bool:\n",
    "    field_names = [c['field_name'] for c in clauses]\n",
    "    date_fields = {'Created', 'Updated', 'Resolved'}\n",
    "    text_fields = {'Summary', 'Description'}\n",
    "    important_fields = {\"Issue Type\", \"Summary\", \"Description\", \"Priority\", \"Components\", \"Labels\"}\n",
    "\n",
    "    if len(field_names) != len(set(field_names)):\n",
    "        return False\n",
    "    if field_names.count(\"Project\") > 1:\n",
    "        return False\n",
    "    if len(set(field_names) & date_fields) > 1:\n",
    "        return False\n",
    "    if 'Resolution' in field_names and 'Priority' in field_names:\n",
    "        res = next(c['jql'] for c in clauses if c['field_name'] == 'Resolution')\n",
    "        prio = next(c['jql'] for c in clauses if c['field_name'] == 'Priority')\n",
    "        if any(p in prio for p in ['P0', 'P1']) and any(r in res for r in ['Done', 'Duplicate', 'Invalid']):\n",
    "            return False\n",
    "    if 'Issue Type' in field_names and 'Priority' in field_names:\n",
    "        iss = next(c['jql'] for c in clauses if c['field_name'] == 'Issue Type')\n",
    "        prio = next(c['jql'] for c in clauses if c['field_name'] == 'Priority')\n",
    "        if 'Bug' in iss and any(p in prio for p in ['P4', 'P5']):\n",
    "            return False\n",
    "    if 'Assignee' in field_names and 'Resolution' in field_names:\n",
    "        assignee = next(c['jql'] for c in clauses if c['field_name'] == 'Assignee')\n",
    "        res = next(c['jql'] for c in clauses if c['field_name'] == 'Resolution')\n",
    "        if 'Unassigned' in assignee and 'Fixed' in res:\n",
    "            return False\n",
    "    if len(set(field_names) & text_fields) > 1:\n",
    "        return False\n",
    "    if any(f in field_names for f in date_fields) and 'Project' not in field_names and 'Issue Type' not in field_names:\n",
    "        return False\n",
    "    if not any(f in field_names for f in important_fields):\n",
    "        return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sem = asyncio.Semaphore(50)\n",
    "\n",
    "async def run_jql_query(session, query_data, ssl_context, sem):\n",
    "    jql = query_data['jql']\n",
    "    field_name = query_data['field_name']\n",
    "\n",
    "    params = PARAMS_BASE | {\"jql\": jql}\n",
    "    querystring = urllib.parse.urlencode(params, safe='(),\" ')\n",
    "    url = f\"{SEARCH_URL}?{querystring}\"\n",
    "\n",
    "    try:\n",
    "        async with sem:\n",
    "           # time.sleep(3) lightweight request\n",
    "            async with session.get(url, headers=HEADERS, ssl=ssl_context) as resp:\n",
    "                data = await resp.json()\n",
    "                count = int(data.get(\"total\", -1))\n",
    "    except Exception as exc:\n",
    "        print(f\"[WARN] {jql} failed – {exc}\", file=sys.stderr)\n",
    "        count = -1\n",
    "\n",
    "    return {\"jql\": jql, \"field_name\": field_name, \"count\": count}\n",
    "\n",
    "async def run_jql_queries(queries, ssl_context):\n",
    "    connector = aiohttp.TCPConnector(limit=50)\n",
    "    async with aiohttp.ClientSession(connector=connector) as session:\n",
    "        tasks = [run_jql_query(session, q, ssl_context, sem) for q in queries]\n",
    "        return await asyncio.gather(*tasks)\n",
    "\n",
    "async def generate_valid_queries(field_schema: Dict[str, Any]) -> List[Dict[str, Any]]:\n",
    "    valid_queries = []\n",
    "    fields = list(field_schema.items())\n",
    "\n",
    "    while len(valid_queries) < TARGET_VALID_QUERY_COUNT:\n",
    "        batch = []\n",
    "\n",
    "        while len(batch) < BATCH_SIZE:\n",
    "            random.shuffle(fields)\n",
    "            num_clauses = random.randint(MIN_CLAUSES, MAX_CLAUSES)\n",
    "            selected_fields = fields[:num_clauses]\n",
    "\n",
    "            clauses = []\n",
    "            for fname, fdata in selected_fields:\n",
    "                clause = make_clause(fname, fdata)\n",
    "                if clause:\n",
    "                    clauses.append({\"field_name\": fname, \"jql\": clause})\n",
    "\n",
    "\n",
    "            if not is_valid_clause_set(clauses):\n",
    "                continue\n",
    "\n",
    "            jql_str = \" AND \".join(c[\"jql\"] for c in clauses)\n",
    "            batch.append({\"jql\": jql_str, \"field_name\": clauses[0]['field_name'], \"field_names\": [c[\"field_name\"] for c in clauses]})\n",
    "\n",
    "        results = await run_jql_queries(batch, ssl_context)\n",
    "\n",
    "        for i, result in enumerate(results):\n",
    "            if result[\"count\"] > 0:\n",
    "                result[\"field_names\"] = batch[i][\"field_names\"]\n",
    "                valid_queries.append(result)\n",
    "                print(f\"✅ {len(valid_queries)} valid queries found\")\n",
    "                if len(valid_queries) >= TARGET_VALID_QUERY_COUNT:\n",
    "                    break\n",
    "\n",
    "    return valid_queries\n",
    "\n",
    "TARGET_PER_CLAUSE = 6250\n",
    "SAVE_INTERVAL = 1000  # Save every 1000 valid queries\n",
    "\n",
    "def save_progress(valid_queries: Dict[int, List[Dict[str, Any]]], output_path: Path):\n",
    "    all_queries = sum(valid_queries.values(), [])\n",
    "    if all_queries:\n",
    "        df = pd.DataFrame(all_queries)\n",
    "        temp_path = output_path.with_suffix('.temp.parquet')\n",
    "        df.to_parquet(temp_path, index=False)\n",
    "        temp_path.rename(output_path)\n",
    "        print(f\"💾 Saved {len(all_queries)} queries to {output_path}\")\n",
    "\n",
    "async def generate_subset_data(field_schema: Dict[str, Any]) -> List[Dict[str, Any]]:\n",
    "    valid_queries = {2: [], 3: [], 4: [], 5: []}\n",
    "    fields = list(field_schema.items())\n",
    "    seen_jqls = set()\n",
    "    total_saved = 0\n",
    "\n",
    "    while any(len(valid_queries[k]) < TARGET_PER_CLAUSE for k in valid_queries):\n",
    "        batch = create_batch(fields, valid_queries, BATCH_SIZE, TARGET_PER_CLAUSE, seen_jqls)\n",
    "        results = await run_jql_queries(batch, ssl_context)\n",
    "\n",
    "        for i, result in enumerate(results):\n",
    "            clause_count = batch[i][\"clause_count\"]\n",
    "            if result[\"count\"] > 0 and len(valid_queries[clause_count]) < TARGET_PER_CLAUSE:\n",
    "                result[\"field_names\"] = batch[i][\"field_names\"]\n",
    "                result[\"clause_count\"] = clause_count\n",
    "                valid_queries[clause_count].append(result)\n",
    "                print(f\"✅ [{clause_count} clauses] → {len(valid_queries[clause_count])}/{TARGET_PER_CLAUSE}\")\n",
    "        \n",
    "        # Periodic saving\n",
    "        current_total = sum(len(v) for v in valid_queries.values())\n",
    "        if current_total >= total_saved + SAVE_INTERVAL:\n",
    "            save_progress(valid_queries, OUTPUT_PATH)\n",
    "            total_saved = current_total\n",
    "\n",
    "    return sum(valid_queries.values(), [])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    field_schema = load_field_schema(FIELD_SCHEMA_PATH)\n",
    "    valid_queries = asyncio.run(generate_subset_data(field_schema))\n",
    "    df = pd.DataFrame(valid_queries)\n",
    "    df.to_parquet(OUTPUT_PATH, index=False)\n",
    "    print(f\"✅ Saved {len(df)} queries to {OUTPUT_PATH}\")\n",
    "    print(df.head())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
