import json
import time
import shutil
import re
import subprocess
import argparse
from pathlib import Path
import concurrent.futures
import threading
from tqdm import tqdm

# ========== Parameter Settings ==========
# Command-line arguments are parsed at import time, so every function below
# reads its configuration from the module-level `args` namespace.
parser = argparse.ArgumentParser()
# Source language of the input functions. It selects the temp-file extension
# and decides whether C macro patching or Java class wrapping is applied.
parser.add_argument("--language", default="cpp", type=str,
                    choices=['cpp', 'java', 'python'],
                    help="The language type.")
# Input dataset; each record is read for its "func", "idx" and "target" fields.
parser.add_argument("--data_file", required=True, type=str,
                    help="Path to the input JSONL file.")
# Receives the per-sample source copies, CPG binaries and DOT export folders.
parser.add_argument("--output_dir", required=True, type=str,
                    help="Directory for intermediate temp files.")
# Created by main(). NOTE(review): nothing in the visible code writes into it.
parser.add_argument("--save_dir", required=True, type=str,
                    help="Directory for saving output.")
# Destination of the final graphs, written as one JSON object per line.
parser.add_argument("--output_file", required=True, type=str,
                    help="Path to final output JSONL file.")
# Requested number of worker threads.
parser.add_argument("--max_workers", type=int, default=10)
# NOTE(review): this option is not referenced anywhere else in the visible
# code. The spelling ("threhold") is part of the CLI, so renaming it would
# break existing invocations.
parser.add_argument("--node_num_threhold", type=int, default=8)
args = parser.parse_args()

# Joern executables, invoked by name through subprocess (they must be
# resolvable by the OS, e.g. via PATH).
joern_parse_cmd = "joern-parse"
joern_export_cmd = "joern-export"
# Relative path: the error log is created in the current working directory,
# one file per language.
log_file_path = f"log_{args.language}.txt"

# ========== Macro-related (C/C++ only) ==========
# Names that extract_macros_from_code() searches for explicitly in every
# source file, in addition to what its generic regexes pick up.
# (Presumably Linux-kernel / FFmpeg style annotations - confirm with dataset.)
KNOWN_KERNEL_MACROS = [
    "__user", "__force", "__iomem", "__must_check", "__kernel", "__init", "noinline",
    "__exit", "__ref", "__maybe_unused", "__always_inline", "__aligned", "av_cold",
    "__attribute__", "__inline__", "likely", "unlikely", "__must_hold", "u8",
    "noinline_for_stack", "__releases", "__acquires", "__u64", "__le32", "__be32", "__CHAR_UNSIGNED__"
]
# Macros that generate_macro_defs() defines with an empty replacement
# (`#define NAME`), so they simply vanish from the source.
TYPENAME_MACROS = {
    "__releases", "__acquires", "__user", "__force", "__iomem",
    "__must_check", "__init", "__exit", "__ref", "__aligned", "__u64",
    "__le32", "__be32", "__CHAR_UNSIGNED__"
}
# Macros that generate_macro_defs() defines as identity wrappers
# (`#define NAME(x) (x)`).
FUNC_MACROS = {"likely", "unlikely"}
# Macros with a fixed replacement text. generate_macro_defs() checks this
# table before TYPENAME_MACROS and FUNC_MACROS and places the value after the
# macro name.
# NOTE(review): the AVERROR value is spelled as "<parameter list> <body>",
# i.e. it is meant to be a function-like macro; confirm that the emitted
# `#define` has no space between the name and the parameter list.
HARDCODE_MACRO_VALUES = {
    "AVERROR": "(x) (-x)", "ENOMEM": "12", "PIX_FMT_BGR24": "3",
    "AV_PIX_FMT_RGB555": "5", "AV_PIX_FMT_YUV420P": "0",
    "RU": "-169", "GU": "-331", "BU": "500",
    "RV": "500", "GV": "-419", "BV": "-81", "RGB2YUV_SHIFT": "15"
}


def collect_sample_json(idx, filename, edges, node_info_list, edge_labels, target):
    """Assemble the JSON-serialisable record for one processed sample.

    Args:
        idx: Position of the sample in the input list; stored as ``"id"``.
        filename: Identifier used for the sample's intermediate files.
        edges: List of ``[src, dst]`` index pairs.
        node_info_list: Per-node dicts providing ``"Node_Type"``, ``"code"``
            and ``"line_number"``; node ``i`` is the i-th entry.
        edge_labels: Label for each entry of ``edges``.
        target: Ground-truth label copied through unchanged.

    Returns:
        A dict with the node attributes split into parallel lists.
    """
    node_types, node_codes, node_lines = [], [], []
    for entry in node_info_list:
        node_types.append(entry["Node_Type"])
        node_codes.append(entry["code"])
        node_lines.append(entry["line_number"])

    sample = {
        "id": idx,
        "filename": filename,
        # Nodes are identified purely by their position in node_info_list.
        "nodes": [*range(len(node_info_list))],
        "edges": edges,
        "nodes_label": node_types,
        "nodes_codes": node_codes,
        "edges_label": edge_labels,
        "code_lines": node_lines,
        "target": target,
    }
    return sample


def extract_macros_from_code(path):
    """Collect identifiers in a source file that look like macros.

    Candidates come from three regex heuristics (underscore-prefixed names not
    followed by ``(``, upper-case names followed by ``(``, upper-case names
    not followed by ``(``) plus every entry of ``KNOWN_KERNEL_MACROS`` that
    occurs as a whole word.

    Args:
        path: Path of the UTF-8 source file to scan.

    Returns:
        Set of candidate macro names with a few C keywords removed.
    """
    with open(path, "r", encoding="utf-8") as f:
        source_text = f.read()

    heuristics = (
        r"\b(__?[\w\d]{2,30})\b(?!\s*\()",
        r"\b([A-Z_][A-Z0-9_]{2,30})\s*\(",
        r"\b([A-Z_][A-Z0-9_]{2,30})\b(?!\s*\()",
    )
    candidates = set()
    for pattern in heuristics:
        candidates.update(re.findall(pattern, source_text))

    # Well-known annotations are added whenever they appear as a whole word,
    # even if the generic heuristics above missed them.
    candidates.update(
        known for known in KNOWN_KERNEL_MACROS
        if re.search(rf'\b{re.escape(known)}\b', source_text)
    )

    reserved = {"if", "else", "for", "while", "return", "switch", "case", "break", "continue",
                "goto", "sizeof", "typedef", "struct", "int", "void"}
    return candidates - reserved


# Matches a hard-coded value of the form "(<params>) <body>", e.g. "(x) (-x)".
# Such a value describes a function-like macro.
_FUNC_LIKE_MACRO_VALUE = re.compile(
    r"\(\s*(?:[A-Za-z_]\w*(?:\s*,\s*[A-Za-z_]\w*)*|\.\.\.)?\s*\)\s+\S"
)


def generate_macro_defs(macros):
    """Build the ``#define`` header that neutralises the given macros.

    Each name produces exactly one line, in sorted order:

    * names in ``HARDCODE_MACRO_VALUES`` get their fixed replacement text;
    * names in ``TYPENAME_MACROS`` are defined empty;
    * names in ``FUNC_MACROS`` become identity macros ``NAME(x) (x)``;
    * anything else becomes a variadic macro expanding to nothing.

    A hard-coded value written as ``(params) body`` (parameter list,
    whitespace, body) is emitted as a function-like macro, i.e. with no space
    between the name and the parameter list. With a space the C preprocessor
    would treat the whole value as the replacement text of an object-like
    macro, so ``AVERROR(e)`` would not expand to ``(-e)``.

    Args:
        macros: Iterable of macro names.

    Returns:
        The definitions joined by newlines, or ``""`` for no macros.
    """
    lines = []
    for m in sorted(macros):
        if m in HARDCODE_MACRO_VALUES:
            value = HARDCODE_MACRO_VALUES[m]
            # Function-like values must follow the name directly.
            separator = "" if _FUNC_LIKE_MACRO_VALUE.match(value) else " "
            lines.append(f"#define {m}{separator}{value}")
        elif m in TYPENAME_MACROS:
            lines.append(f"#define {m}")
        elif m in FUNC_MACROS:
            lines.append(f"#define {m}(x) (x)")
        else:
            lines.append(f"#define {m}(...)")
    return "\n".join(lines)


def extract_attributes(attr_str):
    """Parse ``KEY="value"`` pairs out of a DOT attribute block.

    Values may span several lines and may contain backslash escapes; only the
    escaped double quote (``\\"``) is unescaped in the result. When a key
    occurs more than once the last occurrence wins.

    Args:
        attr_str: Text found between the square brackets of a DOT statement.

    Returns:
        Dict mapping attribute names to their string values.
    """
    pair_re = re.compile(r'(\w+)\s*=\s*"((?:[^"\\]|\\.)*)"', re.DOTALL)
    parsed = {}
    for hit in pair_re.finditer(attr_str):
        key, raw_value = hit.group(1), hit.group(2)
        parsed[key] = raw_value.replace('\\"', '"')
    return parsed


def parse_node_line(line, line_offset):
    """Turn one DOT node statement into ``(node_id, info)``.

    Args:
        line: A complete node statement of the form ``"<id>" [ATTR="..." ...]``
            (multi-line statements must already be joined by the caller).
        line_offset: Number of lines subtracted from every line number
            greater than 1.

    Returns:
        ``(node_id, info)`` where ``info`` has the keys ``code``,
        ``line_number``, ``Node_Type``, ``arg_index`` and ``is_external``;
        ``(None, None)`` when the text is not a node statement or when it is
        a METHOD node without a signature.
    """
    # Greedy match so the attribute block extends to the last closing bracket.
    parsed = re.match(r'"(\d+)"\s+\[(.*)\]', line, re.DOTALL)
    if parsed is None:
        return None, None

    node_id = parsed.group(1)
    attrs = extract_attributes(parsed.group(2))
    kind = attrs.get("label", "")

    if kind == "METHOD":
        # Methods without a signature are dropped; the rest are represented
        # by their full name.
        if attrs.get("SIGNATURE", "") == "":
            return None, None
        text = attrs.get("FULL_NAME", "")
    else:
        text = attrs.get("CODE", "")
        if kind == "CONTROL_STRUCTURE" and text.strip():
            # Keep only the part before the first four-space run.
            head, marker, _ = text.partition("    ")
            if marker:
                text = head.strip()

    if kind == "UNKNOWN":
        kind = "OTHERS"

    raw_line = attrs.get("LINE_NUMBER")
    if raw_line is None:
        line_no = -1
    else:
        try:
            line_no = int(raw_line)
            if line_no > 1:
                line_no -= line_offset
        except ValueError:
            line_no = -1

    raw_arg = attrs.get("ARGUMENT_INDEX")
    try:
        arg_pos = None if raw_arg is None else int(raw_arg)
    except ValueError:
        arg_pos = None

    info = {
        "code": text,
        "line_number": line_no,
        "Node_Type": kind,
        "arg_index": arg_pos,
        "is_external": attrs.get("IS_EXTERNAL", "false"),
    }
    return node_id, info


def parse_dot_to_adj_matrix_with_node_info(dot_file_path, line_offset):
    """Read a DOT export and build the filtered, re-indexed graph.

    Nodes are parsed with ``parse_node_line``, uninformative nodes are
    discarded, and the survivors are renumbered ``0..n-1`` in ascending order
    of their original numeric id. Only edges whose two endpoints both survive
    are kept. For every kept edge whose endpoints both have a known line
    number, an extra ``SAME_ROW`` (same line) or ``NEAR`` (adjacent lines)
    edge is emitted right after it.

    Args:
        dot_file_path: Path of the UTF-8 DOT file.
        line_offset: Forwarded to ``parse_node_line``.

    Returns:
        ``(edges, node_info_list, edges_label)``: ``edges`` is a list of
        ``[src_idx, dst_idx]`` pairs, ``edges_label`` the label of each pair,
        and ``node_info_list`` one dict per kept node with ``node_id``,
        ``code``, ``line_number`` and ``Node_Type``.
    """
    edge_pattern = re.compile(r'"(\d+)"\s*->\s*"(\d+)"\s*\[(.*?)\]')
    label_re = re.compile(r'label="(.*?)"')

    node_info_map = {}
    raw_edges = []

    def remember_node(statement):
        # Store the node described by `statement`, if it is one.
        node_id, info = parse_node_line(statement, line_offset)
        if node_id is not None:
            node_info_map[node_id] = info

    # Pieces of a node statement that spans several lines, or None while no
    # such statement is open.
    pending = None

    with open(dot_file_path, "r", encoding="utf-8") as f:
        for raw in f:
            text = raw.strip()
            if not text:
                continue

            # Edge statements are recognised first.
            edge_hit = edge_pattern.search(text)
            if edge_hit is not None:
                src, dst, attrs = edge_hit.groups()
                label_hit = label_re.search(attrs)
                raw_edges.append((src, dst, label_hit.group(1) if label_hit else ''))
                continue

            # An opening bracket without a closing one starts a multi-line node.
            if "[" in text and "]" not in text:
                pending = [text]
                continue

            if pending is None:
                # Single-line node (or a line that is not a node at all).
                remember_node(text)
                continue

            # Continuation of a multi-line node; pieces are joined with four
            # spaces, and the first line holding a "]" closes the statement.
            pending.append(text)
            if "]" in text:
                remember_node("    ".join(pending))
                pending = None

    dropped_kinds = ("BLOCK", "NAMESPACE", "NAMESPACE_BLOCK", "FIELD_IDENTIFIER",
                     "FILE", "IDENTIFIER", "LITERAL", "LOCAL", "METHOD_RETURN")

    def is_discarded(info):
        kind = info.get("Node_Type", "")
        line_no = info.get("line_number")

        # Nodes that are always kept.
        if kind in ("META_DATA", "BINDING"):
            return False
        if kind == "TYPE_DECL" and line_no != -1:
            return False

        # Nodes that are filtered out.
        arg_pos = info.get("arg_index")
        return (
            info.get("code", "").strip() in ("", "<empty>")
            or kind in dropped_kinds
            or (arg_pos is not None and arg_pos != -1)
            or info.get("is_external") == "true"
            or line_no == -1
        )

    kept_ids = {nid for nid, info in node_info_map.items() if not is_discarded(info)}
    ordered_ids = sorted(kept_ids, key=int)
    position = {nid: pos for pos, nid in enumerate(ordered_ids)}

    last_edges = []
    edges_label = []
    for src, dst, label in raw_edges:
        if src not in kept_ids or dst not in kept_ids:
            continue
        src_idx, dst_idx = position[src], position[dst]
        last_edges.append([src_idx, dst_idx])
        edges_label.append(label)  # String type

        line_src = node_info_map[src].get("line_number")
        line_dst = node_info_map[dst].get("line_number")
        if not (isinstance(line_src, int) and isinstance(line_dst, int)):
            continue
        if line_src == -1 or line_dst == -1:
            continue

        gap = abs(line_src - line_dst)
        if gap == 0:
            proximity = "SAME_ROW"
        elif gap == 1:
            proximity = "NEAR"
        else:
            continue
        last_edges.append([src_idx, dst_idx])
        edges_label.append(proximity)

    node_info_list = [
        {
            "node_id": nid,
            "code": node_info_map[nid]["code"],
            "line_number": node_info_map[nid]["line_number"],
            "Node_Type": node_info_map[nid]["Node_Type"],
        }
        for nid in ordered_ids
    ]

    return last_edges, node_info_list, edges_label

# Lock shared by the worker threads: process_single_function() holds it while
# appending to the log file so that concurrent entries do not interleave.
print_lock = threading.Lock()


def _append_log(message):
    """Append ``message`` to the error log while holding ``print_lock``."""
    with print_lock:
        with open(log_file_path, "a", encoding="utf-8") as log_file:
            log_file.write(message)


def process_single_function(idx, data):
    """Run one sample through Joern and convert the export into a graph record.

    Steps: write the function to ``args.output_dir`` (wrapped in a class for
    Java, prefixed with neutralising ``#define`` lines for C/C++), run
    ``joern-parse`` to build the CPG, run ``joern-export`` to dump it as DOT,
    then parse the DOT file.

    Args:
        idx: Position of the sample in the input list; used as the record id
            and in log messages.
        data: Input record providing ``"func"`` (source text), ``"idx"``
            (used to name the intermediate files) and ``"target"``.

    Returns:
        The dict built by ``collect_sample_json``, or ``None`` when
        ``joern-parse`` or ``joern-export`` fails (the failure is logged).
    """
    code = data["func"]
    if args.language == "java":
        # NOTE(review): the wrapper adds one line before the function, but no
        # line offset is applied for Java below - confirm whether reported
        # line numbers are meant to be relative to the wrapped file.
        code = f"public class Temp {{\n{code}\n}}"
    id_str = str(data["idx"])
    ext = {"cpp": ".c", "python": ".py", "java": ".java"}[args.language]
    temp_file = Path(args.output_dir) / f"{id_str}{ext}"
    # Written as UTF-8 because extract_macros_from_code() reads it back as UTF-8.
    with open(temp_file, "w", encoding="utf-8") as f:
        f.write(code)

    patched_file = temp_file
    macro_lines_count = 0

    if args.language == "cpp":
        macros = extract_macros_from_code(temp_file)
        macro_header = generate_macro_defs(macros)
        # Number of header lines prepended to the source; passed on so node
        # line numbers can be mapped back to the original function.
        macro_lines_count = macro_header.count("\n") + 1 if macro_header else 0
        patched_file = Path(args.output_dir) / f"{id_str}_patched.c"
        with open(patched_file, "w", encoding="utf-8") as f:
            if macro_header:
                f.write(macro_header + "\n" + code.lstrip("\n"))
            else:
                f.write(code.lstrip("\n"))

    bin_file = Path(args.output_dir) / f"{id_str}_cpg.bin"
    dot_dir = Path(args.output_dir) / f"{id_str}_dot_ast"

    # Step 1: generate CPG
    parse_result = subprocess.run([joern_parse_cmd, str(patched_file), "--output", str(bin_file)])
    if parse_result.returncode != 0:
        _append_log(f"[ERROR] joern-parse failed on index {idx}\n")
        return None

    # Step 2: export the CPG as DOT. A directory left over from an earlier run
    # is removed first so that a stale export.dot can never be parsed in place
    # of a fresh one (joern-export is understood to refuse an existing output
    # directory - TODO confirm for the installed Joern version).
    shutil.rmtree(dot_dir, ignore_errors=True)
    export_result = subprocess.run([
        joern_export_cmd,
        str(bin_file),
        "--out", str(dot_dir),
        "--repr", "all",
        "--format", "dot"
    ])

    dot_file = dot_dir / "export.dot"
    if export_result.returncode != 0 or not dot_file.is_file():
        _append_log(f"[ERROR] joern-export failed on index {idx}\n")
        return None

    edges, node_list, labels = parse_dot_to_adj_matrix_with_node_info(dot_file, macro_lines_count)

    # Intermediate files are left on disk; the cleanup below is disabled.
    # temp_file.unlink(missing_ok=True)
    # if patched_file != temp_file:
    #     patched_file.unlink(missing_ok=True)
    # bin_file.unlink(missing_ok=True)
    # shutil.rmtree(dot_dir, ignore_errors=True)

    return collect_sample_json(idx, id_str, edges, node_list, labels, data["target"])


def _load_samples(path):
    """Load the input records from a JSON array file or a JSONL file.

    The whole file is first tried as a single JSON document (the format the
    script originally required). If that fails, every non-blank line is parsed
    as its own JSON document. A single top-level object is treated as a
    one-record dataset.
    """
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    try:
        loaded = json.loads(text)
    except json.JSONDecodeError:
        return [json.loads(row) for row in text.splitlines() if row.strip()]
    return [loaded] if isinstance(loaded, dict) else loaded


def main():
    """Process every input sample in a thread pool and write the graphs.

    Results are written to ``args.output_file`` as one JSON object per line,
    in input order; samples that failed (logged to the error log) are
    skipped.
    """
    start = time.time()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.save_dir).mkdir(parents=True, exist_ok=True)
    # Make sure the final write cannot fail on a missing directory after all
    # samples have been processed.
    Path(args.output_file).parent.mkdir(parents=True, exist_ok=True)

    data_list = _load_samples(args.data_file)

    results = [None] * len(data_list)
    # Honour --max_workers; ThreadPoolExecutor rejects values below 1.
    worker_count = max(1, args.max_workers)
    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        future_to_idx = {
            executor.submit(process_single_function, idx, data): idx
            for idx, data in enumerate(data_list)
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_idx), total=len(data_list)):
            idx = future_to_idx[future]
            try:
                result = future.result()
            except Exception as e:
                # One broken sample must not abort the whole run. The lock is
                # shared with the workers, which append to the same log file.
                with print_lock:
                    with open(log_file_path, "a") as log_file:
                        log_file.write(f"[ERROR] Exception in idx {idx}: {e}\n")
            else:
                if result:
                    results[idx] = result

    with open(args.output_file, "w", encoding="utf-8") as f:
        for r in results:
            if r:
                f.write(json.dumps(r) + "\n")

    print(f"✅ complete, elapsed time {time.time() - start:.2f} seconds")


# Start the pipeline only when the file is executed as a script. Note that the
# command-line arguments are parsed at module level, i.e. also on import.
if __name__ == "__main__":
    main()
