# %%
import os
import json
import re
from rich.console import Console
from rich.table import Table
from rich import box
from dataclasses import dataclass
import argparse
from typing import Any, Dict, List, Set, Union


# Command-line interface. `--table` selects which results table to print:
# the integers 1-4 refer to kBenchSyz tables, "ffmpeg" to the ffmpeg run.
# `choices` makes argparse reject unsupported values immediately with a usage
# message, instead of failing later with a NameError on `file_infos`.
parser = argparse.ArgumentParser(description="Choose table type")
parser.add_argument(
    "--table",
    required=True,
    choices=["1", "2", "3", "4", "ffmpeg"],
    help="Table to print: 1-4 for the kBenchSyz tables (Table 1, Table 2, ...), "
         "or 'ffmpeg' for the ffmpeg benchmark.",
)

parser.add_argument(
    "--base_folder",
    default="patches",
    type=str,
    help="Base folder for the patches directory",
)

args = parser.parse_args()
# Numeric table ids are normalised to int so later comparisons can use 1..4.
args.table = int(args.table) if args.table.isdigit() else args.table
# Integer tables read kBenchSyz results; the string "ffmpeg" reads ffmpeg results.
benchmark = "kBenchSyz" if isinstance(args.table, int) else "ffmpeg"
# Directory holding the per-tool results JSON files and gold_json.json.
BASE_FOLDER = os.path.join(args.base_folder, benchmark)

@dataclass
class FileInfo:
    """Describe one results JSON file and how its row is labelled in a table.

    Instances are positional-friendly: ``FileInfo(name, setting, tool,
    max_calls, k)``; everything except ``file_name`` has a default.
    """

    # Name of the results JSON, joined onto BASE_FOLDER when it is read.
    file_name: str
    # Label for the "Setting" column; rows are grouped/sorted by it.
    setting: str = ""
    # Label for the "Tool" column.
    tool: str = ""
    # Value shown in the "Max calls" column.
    max_calls: int = 1
    # Candidates per bug: shown as "P@k" and used as the expected candidate
    # count when scoring edited files.
    k: int = 1
# --- Result files per table -------------------------------------------------
# Each entry: results file name, setting label, tool label, max calls, k.

# ffmpeg benchmark (selected by --table ffmpeg).
ffmpeg_file_infos = [
    FileInfo("Unassisted-Code Researcher (GPT-4o + o1)-15-1.json",
             setting="Unassisted", tool="Code Researcher (GPT-4o + o1)", max_calls=15, k=1),
]

# kBenchSyz results shared by Tables 1 and 2.
kBenchSyz_file_infos = [
    # Assisted
    FileInfo("Assisted-GPT-4o-1-5.json",
             setting="Assisted", tool="GPT-4o", max_calls=1, k=5),
    FileInfo("Assisted-o1-1-5.json",
             setting="Assisted", tool="o1", max_calls=1, k=5),

    # Stack context
    FileInfo("Stack context-GPT-4o-1-5.json",
             setting="Stack context", tool="GPT-4o", max_calls=1, k=5),
    FileInfo("Stack context-o1-1-5.json",
             setting="Stack context", tool="o1", max_calls=1, k=5),

    # Unassisted
    FileInfo("Unassisted-Agentless (GPT-4o)-4-5.json",
             setting="Unassisted", tool="Agentless (GPT-4o)", max_calls=4, k=5),
    FileInfo("Unassisted-SWE-agent (GPT-4o)-15-5.json",
             setting="Unassisted", tool="SWE-agent (GPT-4o)", max_calls=15, k=5),
    FileInfo("Unassisted-Code Researcher (GPT-4o)-15-5.json",
             setting="Unassisted", tool="Code Researcher (GPT-4o)", max_calls=15, k=5),
    FileInfo("Unassisted-Code Researcher (GPT-4o + o1)-15-5.json",
             setting="Unassisted", tool="Code Researcher (GPT-4o + o1)", max_calls=15, k=5),
    FileInfo("Unassisted-Code Researcher (Gemini 2.5-Flash)-15-5.json",
             setting="Unassisted", tool="Code Researcher (Gemini 2.5-Flash)", max_calls=15, k=5),

    # Unassisted + Scaled (more calls or more candidates)
    FileInfo("Unassisted + Scaled-SWE-agent (GPT-4o)-30-5.json",
             setting="Unassisted + Scaled", tool="SWE-agent (GPT-4o)", max_calls=30, k=5),
    FileInfo("Unassisted + Scaled-Code Researcher (GPT-4o)-30-5.json",
             setting="Unassisted + Scaled", tool="Code Researcher (GPT-4o)", max_calls=30, k=5),
    FileInfo("Unassisted + Scaled-SWE-agent (GPT-4o)-15-10.json",
             setting="Unassisted + Scaled", tool="SWE-agent (GPT-4o)", max_calls=15, k=10),
    FileInfo("Unassisted + Scaled-Code Researcher (GPT-4o)-15-10.json",
             setting="Unassisted + Scaled", tool="Code Researcher (GPT-4o)", max_calls=15, k=10),
]

# Table 3: ablation without search_commits.
wo_search_commits_file_infos = [
    FileInfo("Unassisted-Code Researcher (GPT-4o)-15-5.json",
             setting="Unassisted", tool="Code Researcher (GPT-4o)", max_calls=15, k=5),
    FileInfo("wo SC-Code Researcher (GPT-4o)-15-5.json",
             setting="Unassisted", tool="W/O search_commits", max_calls=15, k=5),
]

# Table 4: ablation without context_filtering.
wo_context_filtering_file_infos = [
    FileInfo("Unassisted-Code Researcher (GPT-4o)-15-5.json",
             setting="Unassisted", tool="Code Researcher (GPT-4o)", max_calls=15, k=5),
    FileInfo("wo CF-Code Researcher (GPT-4o)-15-5.json",
             setting="Unassisted", tool="W/O context_filtering", max_calls=15, k=5),
]

# Pick the result files for the requested table. Tables 1 and 2 share the
# kBenchSyz runs; build_table drops the "Assisted" rows for Table 2.
if args.table == "ffmpeg":
    file_infos = ffmpeg_file_infos
elif args.table in (1, 2):
    file_infos = kBenchSyz_file_infos
elif args.table == 3:
    file_infos = wo_search_commits_file_infos
elif args.table == 4:
    file_infos = wo_context_filtering_file_infos
else:
    # Without this branch an unsupported --table value left `file_infos`
    # undefined and the script died later with a NameError.
    # parser.error prints usage and exits.
    parser.error(f"unsupported --table value: {args.table!r}")

# Reference ("gold") patches; calculate_table2 measures edited-file recall
# against the first "non_crash" patch of each entry.
gold_info = FileInfo("gold_json.json")

def calculate_crr(info: FileInfo, subset: bool = False) -> float:
    """Return the Crash Resolution Rate (CRR, in percent) for one results file.

    A bug scores pass@k = 1 when at least one of its candidate patches is
    listed under ``"non_crash"`` (i.e. it prevents the crash), else 0. CRR is
    the sum of pass@k divided by the benchmark size for the current
    ``--table``, so bugs missing from the results file count as unresolved.

    Args:
        info: results file to score; read from ``BASE_FOLDER``.
        subset: if True, only the fixed 20-bug subset (used for Table 4) is
            counted.

    Returns:
        CRR in percent, rounded to 2 decimals.

    Raises:
        ValueError: if ``args.table`` is not 1-4 or ``"ffmpeg"``. Previously
            this surfaced as an UnboundLocalError on ``total_bugs``.
    """
    # Denominator: number of bugs in the benchmark behind the selected table.
    if args.table in (1, 2, 3):
        total_bugs = 200
    elif args.table == 4:
        total_bugs = 20
    elif args.table == "ffmpeg":
        total_bugs = 10
    else:
        raise ValueError(f"no benchmark size known for table {args.table!r}")

    # JSON text is UTF-8 by spec; don't depend on the platform locale.
    with open(os.path.join(BASE_FOLDER, info.file_name), "r", encoding="utf-8") as f:
        data = json.load(f)

    if subset:
        subset_ids = frozenset((
            '384b9a05e5d4c3e74e31427e43a08f40f71db54a', '30f3a18df370dd15cce6f725620ac001e57dbee1',
            'f55849742bdcdf58f6eaae61e747ac00e5c061f4', 'e0640a911365faa499433155c3d5d5b674b36a83',
            '9a58b47cdbbb40f7679a00991cac436716c70192', '4892aaa2ef26ab83c6b974f1db422f526f9aaec0',
            'deb061ff946b02c2f4821f91683d89a68b2f45f8', 'e58343393efd3e0657ca37cf4466fc38d1f4d4df',
            '38efa4618c3379cf98642b9379d6e340b14d2702', 'f0ec9a394925aafbdf13d0a7e6af4cff860f0ed6',
            '75b0feac9a3f9e361fd60605e261f8a4ffef1c40', 'd8e8c32d387c7b35680b035aab36efdefe253ab7',
            '1a351beaed9d438481f1fc96aa336a25f71a2ae1', '641c688b5c5a0c80d5d5832c5cd9f361c1cbb0a9',
            '7d3c28ba3d4bf4b26e89ed1f1ca146e0223a2d36', '2e1943a94647f7732dd6fc60368642d6e8dc91b1',
            '46dd655664a8b38dbf7234683b294171a4e0142b', '093e7092e01bed192b564b04528826cc6f1dbf91',
            'ee9018bd8989530a2dbdd62436efd8b1c3ecd3e5', 'bb3342477c4669ef082c7056b6e2be4f903e646e',
        ))
        entries = [d for bug_id, d in data.items() if bug_id in subset_ids]
    else:
        entries = list(data.values())

    # pass@k per bug, summed over the scored bugs.
    resolved = sum(1 for d in entries if len(d.get("non_crash", [])) > 0)
    # CRR is the average of pass@k over the whole benchmark.
    return round(100 * resolved / total_bugs, 2)

def get_files_in_patch(patch: str) -> Set[str]:
    """Collect the paths edited by a git-style unified diff.

    Every ``diff --git a/<path> b/<path>`` header contributes its ``a/``
    path; other lines are ignored. Duplicate headers collapse into one entry.

    Args:
        patch: the diff text.

    Returns:
        The set of edited file paths (empty if no header matches).
    """
    header = re.compile(r'^diff --git a/(.+?) b/.*$')
    candidates = (
        header.match(line)
        for line in patch.strip().splitlines()
        if line.startswith("diff --git")
    )
    return {m.group(1) for m in candidates if m}

def get_all_patches(patches: Dict[str, Union[List[str], List[Dict[str, Any]]]]) -> List[str]:
    """Flatten one bug's candidate patches into a single list of patch strings.

    Every category except ``'error'`` maps directly to a list of patch strings.
    ``'error'`` maps to dicts whose ``'patch'`` key holds the patch text.
    Categories are emitted in the dict's iteration order.

    Args:
        patches: category name -> patches for that category.

    Returns:
        All patch strings, in order.
    """
    collected: List[str] = []
    for category, entries in patches.items():
        if category == 'error':
            collected.extend(entry['patch'] for entry in entries)
        else:
            collected.extend(entries)
    return collected

def calculate_table2(info: FileInfo, goldInfo: FileInfo, subset: bool = False) -> Dict[str, float]:
    """Compare the files a tool's patches edit against the gold patch's files.

    For each bug in the gold file (optionally only the fixed 20-bug subset),
    the files edited by the first gold ``"non_crash"`` patch are compared
    with the files edited by each candidate patch of ``info``. Candidates come
    from every category returned by ``get_all_patches``.

    * ``recall``: mean fraction of gold files that a candidate also edits.
    * ``all``:    % of candidates editing every gold file.
    * ``any``:    % of candidates editing some, but not all, gold files.
    * ``none``:   % of candidates editing no gold file.

    When a tool produced fewer than ``info.k`` candidates for a bug, each
    missing candidate counts as an empty edit set (recall 0, ``none``).

    Args:
        info: results file of the tool to score (read from ``BASE_FOLDER``).
        goldInfo: file holding the gold patches (read from ``BASE_FOLDER``).
        subset: if True, only the fixed 20-bug subset (used for Table 4) is scored.

    Returns:
        Dict with ``recall`` (rounded to 2 decimals) and ``all``/``any``/``none``
        (percentages, rounded to 2 decimals). All values are 0.0 when no
        candidate was scored, instead of raising ZeroDivisionError.

    Raises:
        ValueError: if a gold patch has no ``diff --git`` header, which would
            make recall undefined.
    """
    subset_ids = frozenset((
        '384b9a05e5d4c3e74e31427e43a08f40f71db54a', '30f3a18df370dd15cce6f725620ac001e57dbee1',
        'f55849742bdcdf58f6eaae61e747ac00e5c061f4', 'e0640a911365faa499433155c3d5d5b674b36a83',
        '9a58b47cdbbb40f7679a00991cac436716c70192', '4892aaa2ef26ab83c6b974f1db422f526f9aaec0',
        'deb061ff946b02c2f4821f91683d89a68b2f45f8', 'e58343393efd3e0657ca37cf4466fc38d1f4d4df',
        '38efa4618c3379cf98642b9379d6e340b14d2702', 'f0ec9a394925aafbdf13d0a7e6af4cff860f0ed6',
        '75b0feac9a3f9e361fd60605e261f8a4ffef1c40', 'd8e8c32d387c7b35680b035aab36efdefe253ab7',
        '1a351beaed9d438481f1fc96aa336a25f71a2ae1', '641c688b5c5a0c80d5d5832c5cd9f361c1cbb0a9',
        '7d3c28ba3d4bf4b26e89ed1f1ca146e0223a2d36', '2e1943a94647f7732dd6fc60368642d6e8dc91b1',
        '46dd655664a8b38dbf7234683b294171a4e0142b', '093e7092e01bed192b564b04528826cc6f1dbf91',
        'ee9018bd8989530a2dbdd62436efd8b1c3ecd3e5', 'bb3342477c4669ef082c7056b6e2be4f903e646e',
    ))

    # JSON text is UTF-8 by spec; don't depend on the platform locale.
    with open(os.path.join(BASE_FOLDER, info.file_name), "r", encoding="utf-8") as f:
        data = json.load(f)
    with open(os.path.join(BASE_FOLDER, goldInfo.file_name), "r", encoding="utf-8") as f:
        gold_data = json.load(f)

    recalls: List[float] = []
    counts = {"all": 0, "any": 0, "none": 0}
    for idx, gold_patch_data in gold_data.items():
        if subset and idx not in subset_ids:
            continue
        gold_files = get_files_in_patch(gold_patch_data['non_crash'][0])
        if not gold_files:
            raise ValueError(
                f"gold patch for {idx!r} contains no 'diff --git' header; "
                "file-level recall is undefined"
            )
        patches = get_all_patches(data.get(idx, {}))
        for patch in patches:
            overlap = get_files_in_patch(patch) & gold_files
            recalls.append(len(overlap) / len(gold_files))
            if overlap == gold_files:  # candidate edits a superset of the gold files
                counts["all"] += 1
            elif overlap:
                counts["any"] += 1
            else:
                counts["none"] += 1
        # A missing candidate is treated as an empty set of edited files:
        # recall 0 and category "none".
        missing = info.k - len(patches)
        if missing > 0:
            recalls.extend([0.0] * missing)
            counts["none"] += missing

    total = sum(counts.values())
    results: Dict[str, float] = {
        "recall": round(sum(recalls) / len(recalls), 2) if recalls else 0.0,
    }
    for key, count in counts.items():
        results[key] = round(100 * count / total, 2) if total else 0.0
    return results

def build_table(file_infos) -> Table:
    """Build the rich ``Table`` for the table selected by ``--table``.

    Column layout by ``args.table``:
      * 1        -> Setting, Tool, Max calls, P@k, CRR
      * 2        -> Setting, Tool, Max calls, P@k, Avg. Recall, All/Any/None
      * 3, 4     -> Tool, Max calls, P@k, CRR, Avg. Recall, All/Any/None
      * "ffmpeg" -> Setting, Tool, Max calls, P@k, CRR, Avg. Recall, All/Any/None

    Only the metrics that are displayed are computed. Table 1 therefore no
    longer parses the gold file, and Table 2 no longer computes CRR. Rows are
    sorted by setting, with a section rule between settings and the setting
    label shown only on the first row of each group. Table 2 omits the
    "Assisted" setting; Table 4 scores only the fixed bug subset.

    Args:
        file_infos: the ``FileInfo`` entries to render.

    Returns:
        The populated ``rich.table.Table``.
    """
    show_setting = args.table not in (3, 4)
    show_crr = args.table in (1, 3, 4, "ffmpeg")
    show_file_metrics = args.table in (2, 3, 4, "ffmpeg")
    use_subset = args.table == 4

    table = Table(show_header=True, header_style="bold white", box=box.SIMPLE_HEAVY)
    if show_setting:
        table.add_column("Setting", style="cyan", no_wrap=True)
    table.add_column("Tool", style="magenta")
    table.add_column("Max calls", justify="center")
    table.add_column("P@k", justify="center")
    if show_crr:
        table.add_column("CRR (%)", justify="center", style="bold green")
    if show_file_metrics:
        table.add_column("Avg. Recall", justify="center", style="magenta")
        table.add_column("All/Any/None (%)", justify="center", style="green")

    if args.table == 2:
        file_infos = [info for info in file_infos if info.setting != "Assisted"]
    # sorted() is stable, so rows within a setting keep their list order.
    file_infos_sorted = sorted(file_infos, key=lambda info: info.setting)

    last_setting = None
    for info in file_infos_sorted:
        if last_setting is not None and info.setting != last_setting:
            table.add_section()  # horizontal rule between settings

        row = []
        if show_setting:
            # Print the setting label only on the first row of its group.
            row.append(info.setting if info.setting != last_setting else "")
        row.append(f"[bold]{info.tool}[/bold]" if "Code Researcher" in info.tool else info.tool)
        row.append(str(info.max_calls))
        row.append(f"P@{info.k}")
        if show_crr:
            crr = calculate_crr(info, use_subset)
            row.append(f"{crr:.2f}")
        if show_file_metrics:
            result = calculate_table2(info, gold_info, use_subset)
            row.append(f"{result['recall']:.2f}")
            row.append(f"{result['all']:.1f}/{result['any']:.1f}/{result['none']:.1f}")
        table.add_row(*row)

        last_setting = info.setting

    return table

if __name__ == "__main__":
    # Build the table selected on the command line, then render it with rich.
    table = build_table(file_infos)
    console = Console()
    console.print(table)

