import argparse
import json
import time
import re
import shutil
import pandas as pd
import subprocess
from pathlib import Path
import concurrent.futures
import threading
from tqdm import tqdm

# Joern CLI executables; subprocess.run resolves these bare names via PATH.
joern_parse_cmd, joern_export_cmd = "joern-parse", "joern-export"

# Worker progress and error messages are appended to this file
# (relative to the current working directory; change as needed).
log_file_path = "process.log"

# Macro names that extract_all_missing_macros_from_code looks up explicitly
# (by whole-word search) in addition to its regex heuristics.
KNOWN_KERNEL_MACROS = [
    "__user", "__force", "__iomem", "__must_check", "__kernel", "__init", "noinline",
    "__exit", "__ref", "__maybe_unused", "__always_inline", "__aligned", "av_cold",
    "__attribute__", "__inline__", "likely", "unlikely", "__must_hold", "u8",
    "noinline_for_stack", "__releases", "__acquires", "__u64", "__le32", "__be32", "__CHAR_UNSIGNED__"
]

# Macros used WITHOUT an argument list; they are stubbed as empty object-like
# macros (`#define NAME`).  Annotations that are always invoked with an
# argument -- `__releases(lock)`, `__acquires(lock)`, `__aligned(n)` -- must
# not be listed here: an empty object-like stub would strip only the name and
# leave the parenthesised argument behind (`f(void) __releases(l)` ->
# `f(void) (l)`).  Left out, they get the variadic `NAME(...)` stub, which
# removes the whole invocation.
TYPENAME_MACROS = {
    "__user", "__force", "__iomem",
    "__must_check", "__init", "__exit", "__ref", "__u64",
    "__le32", "__be32", "__CHAR_UNSIGNED__"
}

# Single-argument macros that are stubbed as the identity `NAME(x) (x)`.
FUNC_MACROS = {
    "likely", "unlikely"
}


def collect_sample_json(idx, filename, edges, node_info_list, edge_labels, target):
    """Assemble one JSON-serialisable sample record from a parsed graph.

    Args:
        idx: Sample identifier, stored under ``"id"``.
        filename: Name recorded for the sample.
        edges: ``[src, dst]`` pairs that already use integer node indices
            (stored as-is).
        node_info_list: Per-node dicts with ``"Node_Type"``, ``"code"`` and
            ``"line_number"`` keys, in node-index order.
        edge_labels: Labels parallel to ``edges`` (stored as-is).
        target: Sample label, copied through unchanged.

    Returns:
        A dict whose keys appear in a fixed order (this order is visible in
        the serialised JSON).
    """
    node_count = len(node_info_list)

    def column(key):
        # Project one field out of every node dict, preserving node order.
        return [node[key] for node in node_info_list]

    return {
        "id": idx,
        "filename": filename,
        "nodes": [*range(node_count)],
        "edges": edges,
        "nodes_label": column("Node_Type"),
        "nodes_codes": column("code"),
        "edges_label": edge_labels,
        "code_lines": column("line_number"),
        "target": target,
    }


def extract_attributes(attr_str):
    """Parse ``key="value"`` pairs out of a DOT attribute list.

    Values may contain newlines and backslash-escaped characters; only the
    escaped double quote (``\\"``) is unescaped.  When a key occurs more than
    once, the last occurrence wins.

    Args:
        attr_str: Text between the ``[`` and ``]`` of a DOT node/edge record.

    Returns:
        A ``dict`` mapping attribute names to their (string) values.
    """
    attr_regex = r'(\w+)\s*=\s*"((?:[^"\\]|\\.)*)"'
    attributes = {}
    for key, raw_value in re.findall(attr_regex, attr_str, re.DOTALL):
        attributes[key] = raw_value.replace('\\"', '"')
    return attributes


def extract_all_missing_macros_from_code(code):
    """Collect identifiers in *code* that look like macros needing a stub.

    Three regex heuristics are combined with an explicit whole-word lookup
    of every name in ``KNOWN_KERNEL_MACROS``:

    * identifiers starting with ``_``/``__`` that are not followed by ``(``;
    * upper-case identifiers followed by ``(`` (function-like use);
    * upper-case identifiers not followed by ``(`` (constant-like use).

    Args:
        code: Source text of a single function.

    Returns:
        A ``set`` of candidate macro names, minus a small set of C keywords.
    """
    heuristics = (
        r"\b(__?[\w\d]{2,30})\b(?!\s*\()",        # _name / __name, not a call
        r"\b([A-Z_][A-Z0-9_]{2,30})\s*\(",         # UPPER_CASE(...) call
        r"\b([A-Z_][A-Z0-9_]{2,30})\b(?!\s*\()",   # UPPER_CASE constant
    )
    candidates = set()
    for regex in heuristics:
        candidates.update(re.findall(regex, code))

    candidates.update(
        name for name in KNOWN_KERNEL_MACROS
        if re.search(r'\b' + re.escape(name) + r'\b', code)
    )

    excluded_keywords = {
        "if", "else", "for", "while", "return", "switch", "case", "break",
        "continue", "goto", "sizeof", "typedef", "struct", "int", "void",
    }
    return candidates - excluded_keywords


# Concrete expansions for macros whose value matters for parsing.  For names
# in HARDCODE_FUNC_MACROS the value starts with the parameter list, which is
# glued to the macro name when emitted.
HARDCODE_MACRO_VALUES = {
    "AVERROR": "(x) (-(x))",
    "ENOMEM": "12",
    "PIX_FMT_BGR24": "3",
    "AV_PIX_FMT_RGB555": "5",
    "AV_PIX_FMT_YUV420P": "0",
    "RU": "-169", "GU": "-331", "BU": "500",
    "RV": "500", "GV": "-419", "BV": "-81",
    "RGB2YUV_SHIFT": "15"
}

# Hard-coded macros that are function-like (their value begins with "(params)").
HARDCODE_FUNC_MACROS = {"AVERROR"}


def generate_macro_definitions(macros):
    """Build a block of ``#define`` stubs, one line per name in *macros*.

    Names are processed in sorted order and mapped as follows:

    * in ``HARDCODE_MACRO_VALUES``: the hard-coded body (function-like for
      names in ``HARDCODE_FUNC_MACROS``);
    * in ``TYPENAME_MACROS``: an empty object-like macro;
    * in ``FUNC_MACROS``: the identity ``NAME(x) (x)``;
    * otherwise: a variadic ``NAME(...)`` that expands to nothing.

    Args:
        macros: Iterable of macro names.

    Returns:
        The definitions joined with ``"\\n"`` (``""`` when *macros* is empty).
    """
    macro_lines = []
    for macro in sorted(macros):
        if macro in HARDCODE_MACRO_VALUES:
            value = HARDCODE_MACRO_VALUES[macro]
            # The C preprocessor only treats a macro as function-like when
            # "(" immediately follows the name; `#define AVERROR (x) (-x)`
            # would be object-like and AVERROR(e) would expand to garbage.
            separator = "" if macro in HARDCODE_FUNC_MACROS else " "
            macro_lines.append(f"#define {macro}{separator}{value}  // hardcoded")
        elif macro in TYPENAME_MACROS:
            macro_lines.append(f"#define {macro}  // auto-typedef")
        elif macro in FUNC_MACROS:
            macro_lines.append(f"#define {macro}(x) (x)  // auto-func")
        else:
            macro_lines.append(f"#define {macro}(...)  // auto-func-like or empty")
    return "\n".join(macro_lines)


def parse_node_line(line):
    """Parse one DOT node record into ``(node_id, info)``.

    Args:
        line: A node record such as ``"12" [label="CALL" CODE="f()" ...]``.
            Records that spanned several lines in the DOT file are expected
            to have been re-joined with four spaces by the caller.

    Returns:
        ``(node_id, info)`` where *node_id* is the numeric id as a string and
        *info* has the keys ``code``, ``line_number`` (``-1`` if missing or
        not an integer), ``Node_Type`` (``UNKNOWN`` renamed to ``OTHERS``),
        ``arg_index`` (``None`` if missing or not an integer) and
        ``is_external`` (string, default ``"false"``).  Returns
        ``(None, None)`` if *line* is not a node record or its code mentions
        ``TempClass``.
    """
    match = re.match(r'"(\d+)"\s+\[(.*)\]', line, re.DOTALL)
    if not match:
        return None, None

    node_id, attr_str = match.groups()
    attr = extract_attributes(attr_str)

    node_type = attr.get("label", "")
    code = attr.get("CODE", "")

    # Control structures and methods keep only the text before the first
    # four-space run, i.e. the first physical line of a re-joined record
    # (presumably their header).  All other node types keep their full code.
    # The former test `node_type == "CONTROL_STRUCTURE" or "METHOD"` was
    # always true and truncated every node.
    split_marker = "    "
    if node_type in ("CONTROL_STRUCTURE", "METHOD") and code.strip() and split_marker in code:
        code = code.split(split_marker, 1)[0].strip()

    # Drop every node whose code mentions TempClass; this also covers the
    # "public class TempClass" declaration itself.
    if "TempClass" in code:
        return None, None

    if node_type == "UNKNOWN":
        node_type = "OTHERS"

    # Missing or non-integer line numbers become -1.  Numbers above 1 are
    # shifted down by one.  NOTE(review): presumably compensates for an extra
    # leading line in the parsed file -- confirm.
    try:
        line_number = int(attr["LINE_NUMBER"])
    except (KeyError, ValueError):
        line_number = -1
    else:
        if line_number > 1:
            line_number -= 1

    try:
        arg_index = int(attr["ARGUMENT_INDEX"])
    except (KeyError, ValueError):
        arg_index = None

    return node_id, {
        "code": code,
        "line_number": line_number,
        "Node_Type": node_type,
        "arg_index": arg_index,
        "is_external": attr.get("IS_EXTERNAL", "false"),
    }


def parse_dot_to_adj_matrix_with_node_info(dot_file_path):
    """Parse a Joern ``export.dot`` file into a filtered, re-indexed graph.

    Node records may span several physical lines; they are re-joined with
    four spaces and passed to ``parse_node_line``.  Nodes rejected by
    ``should_filter`` are dropped.  Surviving nodes are renumbered ``0..n-1``
    in ascending order of their numeric Joern id.  Only edges whose two
    endpoints both survive are kept, in file order.  Right after each kept
    edge, an extra edge labelled ``SAME_ROW`` (same source line) or ``NEAR``
    (adjacent source lines) is emitted when applicable.

    Args:
        dot_file_path: Path of the DOT file (read as UTF-8).

    Returns:
        ``(edges, node_info_list, edge_labels)``:
        *edges* is a list of ``[src_idx, dst_idx]`` pairs and *edge_labels*
        the parallel list of label strings.  *node_info_list* has one dict
        per kept node (``node_id``, ``code``, ``line_number``, ``Node_Type``)
        in index order.
    """
    edge_pattern = re.compile(r'"(\d+)"\s*->\s*"(\d+)"\s*\[(.*?)\]')
    label_re = re.compile(r'label="(.*?)"')

    node_info_map = {}
    edges = []

    def record_node(text):
        # Store the node if `text` parses as a (non-TempClass) node record.
        node_id, info = parse_node_line(text)
        if node_id is not None:
            node_info_map[node_id] = info

    buffer = ""
    inside_node = False

    with open(dot_file_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            if inside_node:
                # Continuation of a multi-line node record.  This is checked
                # first: a continuation line is part of an attribute value, so
                # it must not be taken for an edge, nor -- when it contains
                # "[" without "]" (e.g. `buf[`) -- restart the buffer and
                # discard the record's first lines.  Lines are re-joined with
                # four spaces, the marker parse_node_line splits on when it
                # trims code.
                buffer += "    " + line
                if "]" in line:
                    inside_node = False
                    record_node(buffer)
                    buffer = ""
                continue

            match_edge = edge_pattern.search(line)
            if match_edge:
                src, dst, attrs = match_edge.groups()
                label_match = label_re.search(attrs)
                edges.append((src, dst, label_match.group(1) if label_match else ''))
            elif "[" in line and "]" not in line:
                # First line of a multi-line node record.
                buffer = line
                inside_node = True
            else:
                # Single-line node record (non-node lines yield no node).
                record_node(line)

    # Node types that are always dropped (unless kept by the rules below).
    filtered_labels = {
        "BLOCK", "NAMESPACE", "NAMESPACE_BLOCK", "FIELD_IDENTIFIER",
        "FILE", "IDENTIFIER", "LITERAL", "LOCAL", "METHOD_RETURN",
    }

    def should_filter(info):
        """Return True if the node described by *info* is dropped."""
        label = info.get("Node_Type", "")
        arg_index = info.get("arg_index")
        line_number = info.get("line_number")
        code = info.get("code", "").strip()

        # Always keep metadata/binding nodes and located type declarations.
        if label in ("META_DATA", "BINDING"):
            return False
        if label == "TYPE_DECL" and line_number != -1:
            return False

        # Drop nodes without code, excluded node types, argument
        # sub-expressions (ARGUMENT_INDEX other than -1), external nodes and
        # nodes without a source line.
        return (
            code in ("", "<empty>")
            or label in filtered_labels
            or (arg_index is not None and arg_index != -1)
            or info.get("is_external") == "true"
            or line_number == -1
        )

    kept_ids = sorted(
        (nid for nid, info in node_info_map.items() if not should_filter(info)),
        key=int,
    )
    node_to_idx = {node_id: idx for idx, node_id in enumerate(kept_ids)}

    last_edges = []
    edges_label = []
    for src, dst, orig_label in edges:
        # Skip edges touching a filtered (or never-parsed) node.
        if src not in node_to_idx or dst not in node_to_idx:
            continue
        src_idx, dst_idx = node_to_idx[src], node_to_idx[dst]

        last_edges.append([src_idx, dst_idx])
        edges_label.append(orig_label)

        line_src = node_info_map[src].get("line_number")
        line_dst = node_info_map[dst].get("line_number")
        aux_label = None
        if (isinstance(line_src, int) and isinstance(line_dst, int)
                and line_src != -1 and line_dst != -1):
            if line_src == line_dst:
                aux_label = "SAME_ROW"
            elif abs(line_src - line_dst) == 1:
                aux_label = "NEAR"
        if aux_label:
            last_edges.append([src_idx, dst_idx])
            edges_label.append(aux_label)

    node_info_list = [
        {
            "node_id": node_id,
            "code": node_info_map[node_id]["code"],
            "line_number": node_info_map[node_id]["line_number"],
            "Node_Type": node_info_map[node_id]["Node_Type"],
        }
        for node_id in kept_ids
    ]

    return last_edges, node_info_list, edges_label


# Lock serialising writes to the shared log file from worker threads so that
# concurrent messages do not interleave.
print_lock = threading.Lock()


def _append_log(message):
    """Append *message* to ``log_file_path`` as one newline-terminated line.

    The write happens under ``print_lock`` so that concurrent workers do not
    interleave their output.
    """
    with print_lock:
        with open(log_file_path, "a", encoding="utf-8") as log_file:
            log_file.write(message + "\n")


def _stderr_suffix(result):
    """Return ``": <stderr>"`` for a finished subprocess, or ``""`` if it printed none."""
    text = (result.stderr or b"").decode("utf-8", errors="replace").strip()
    return f": {text}" if text else ""


def process_single_function(task, output_dir, macro_lines_count=0):
    """Build the filtered CPG sample for one source file with Joern.

    1. ``joern-parse`` turns ``task["filepath"]`` into ``<output_dir>/<idx>_cpg.bin``.
    2. ``joern-export --repr all --format dot`` writes
       ``<output_dir>/<idx>_dot_all`` (any previous export there is removed).
    3. ``export.dot`` is parsed into a sample record.

    The .bin file and the DOT directory are left on disk.

    Args:
        task: Dict with ``idx``, ``filepath``, ``filename`` and ``target``.
        output_dir: ``pathlib.Path`` directory for Joern artefacts.
        macro_lines_count: Currently unused; kept for call compatibility.

    Returns:
        The sample dict built by ``collect_sample_json``, or ``None`` if a
        Joern step failed or produced no ``export.dot`` (the reason is
        written to the log file, including Joern's stderr when present).
    """
    idx = task["idx"]
    filepath = task["filepath"]
    filename = task["filename"]
    target = task["target"]

    _append_log(f"[INFO] Processing task: idx={idx}, target={target}")

    # Step 1: build the CPG binary.
    bin_file = output_dir / f"{idx}_cpg.bin"
    parse_result = subprocess.run(
        [joern_parse_cmd, str(filepath), "--output", str(bin_file)],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )
    if parse_result.returncode != 0:
        _append_log(f"[ERROR] joern-parse failed on index {idx}{_stderr_suffix(parse_result)}")
        return None

    # Step 2: export every representation ("all") as DOT into a fresh directory.
    dot_ast_dir = output_dir / f"{idx}_dot_all"
    if dot_ast_dir.exists():
        shutil.rmtree(dot_ast_dir)
    export_result_dot = subprocess.run(
        [
            joern_export_cmd,
            str(bin_file),
            "--out", str(dot_ast_dir),
            "--repr", "all",
            "--format", "dot",
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )
    if export_result_dot.returncode != 0:
        # Logged (not printed) like the other failures, so it does not
        # corrupt the progress bar and is kept with the rest of the errors.
        _append_log(f"[ERROR] joern-export DOT failed on index {idx}{_stderr_suffix(export_result_dot)}")
        return None

    # Step 3: parse the exported graph into a sample record.
    main_dot_file = dot_ast_dir / "export.dot"
    if not main_dot_file.is_file():
        _append_log(f"[ERROR] {main_dot_file} not found for index {idx}")
        return None

    edges, node_list, edge_types = parse_dot_to_adj_matrix_with_node_info(main_dot_file)
    return collect_sample_json(idx, filename, edges, node_list, edge_types, target)


def _build_tasks(df, output_dir, need_reprocess):
    """Write each selected function to ``<output_dir>/<idx>.c`` and describe it as a task.

    Args:
        df: DataFrame with ``func`` (source code) and ``target`` columns,
            indexed by sample id.
        output_dir: ``pathlib.Path`` receiving the ``.c`` files.
        need_reprocess: Collection of sample ids to process; other rows are
            skipped.

    Returns:
        A list of task dicts (``idx``, ``filename``, ``filepath``, ``target``,
        ``macro_lines_count``) for ``process_single_function``.
    """
    tasks = []
    for idx, row in df.iterrows():
        if idx not in need_reprocess:
            continue
        method_code = row["func"]
        target = row["target"]
        filename = f"{idx}"

        # Per the original note, Joern fails to parse a lone function whose
        # first line contains 'override'.  So drop that keyword (with the
        # blanks around it, replaced by one space) from the first line only.
        lines = method_code.splitlines(keepends=True)
        if lines:
            lines[0] = re.sub(r'[ \t]*\boverride\b[ \t]*', ' ', lines[0], count=1)
            method_code = ''.join(lines)

        # Macro stubbing is meant for C input only; type_c is currently
        # fixed to '' so this branch never runs.
        type_c = ''
        macro_lines_count = 0
        if type_c == 'c':
            missing_macros = extract_all_missing_macros_from_code(method_code)
            macro_header = generate_macro_definitions(missing_macros)
            macro_lines_count = macro_header.count("\n") + 1 if macro_header else 0
            method_code = macro_header + "\n" + method_code

        temp_path = output_dir / f"{filename}.c"
        with open(temp_path, "w", encoding="utf-8") as f:
            f.write(method_code)

        tasks.append({
            "idx": idx,
            "filename": filename,
            "filepath": temp_path,
            "target": target,
            "macro_lines_count": macro_lines_count,
        })
    return tasks


def _merge_reprocessed(process_file, process_file2, results, need_reprocess):
    """Copy the JSONL *process_file* to *process_file2*, swapping in reprocessed samples.

    Samples whose ``id`` is in *need_reprocess* are replaced by the result
    with the same id.  If there is no such result (Joern failed or the graph
    was at or below the node threshold), the sample is dropped.  All other
    samples are copied unchanged.  Blank lines in *process_file* are ignored.
    """
    with open(process_file, 'r', encoding="utf-8") as f:
        original_data = [json.loads(line) for line in f if line.strip()]

    results_by_id = {result['id']: result for result in results}

    updated_data = []
    for item in original_data:
        item_id = item['id']
        if item_id not in need_reprocess:
            updated_data.append(item)
        elif item_id in results_by_id:
            updated_data.append(results_by_id[item_id])

    with open(process_file2, 'w', encoding="utf-8") as f:
        for item in updated_data:
            f.write(json.dumps(item) + '\n')

    print(f"✅ Data update completed, saved to {process_file2}, original data count: {len(original_data)}, updated data count: {len(updated_data)}")


def main():
    """Re-run Joern CPG extraction for selected samples and merge the results.

    The steps are:

    1. Read the JSON list ``--data_file``.
    2. Regenerate the graph (in parallel) for every sample id in
       ``need_reprocess_demo_idx``.
    3. Keep the samples with more than ``node_num_threhold`` nodes.
    4. If ``--output_file`` is given, write them to ``<save_dir>/<output_file>``.
    5. If both ``--process_file`` and ``--process_file2`` are given, write a
       copy of ``--process_file`` to ``--process_file2`` with those samples
       replaced (or dropped).
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--language", default="cpp", type=str,
                        help="The language type.", choices=['cpp', 'java'])
    parser.add_argument("--data_file",
                        default="../temp/train.json",
                        type=str,
                        help="The input training data file (a jsonl file).")
    parser.add_argument("--output_dir", default="train_cpg", type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--save_dir", default="./Joern-Data/data_preprocess", type=str,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--process_file', type=str, help='Path to the file being processed')
    parser.add_argument('--process_file2', type=str, help='Path to the original processed file after reprocessing')
    parser.add_argument('--output_file', type=str, help='Path to the final reprocessed output JSONL file (optional)')

    args = parser.parse_args()

    start_time = time.time()
    args.max_workers = 10
    args.node_num_threhold = 8

    # Sample ids to reprocess (chosen from earlier inspection results).
    need_reprocess_demo_idx = [0, 1]
    print(f'data_file: {args.data_file}')
    print(f'output_dir: {args.output_dir}')
    print(f'save_dir: {args.save_dir}')
    print(f'output_file: {args.output_file}')
    print(f'need reprocess demo length: {len(need_reprocess_demo_idx)}')
    print(f'node number threhold: {args.node_num_threhold}')

    output_dir = Path(args.output_dir)
    save_dir = Path(args.save_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # --output_file is optional: without it the standalone JSONL is skipped
    # (previously `save_dir / None` raised TypeError).
    output_json_path = save_dir / args.output_file if args.output_file else None
    print(f'processing {args.data_file}')

    with open(args.data_file, 'r', encoding="utf-8") as f:
        # A JSON list of dicts; parse_int=str keeps integer fields (such as
        # target) as strings.
        data = json.load(f, parse_int=str)

    df = pd.DataFrame(data)

    start_df = 0
    end_df = len(df)
    df = df[start_df:end_df]
    print(f'start: {start_df}')
    print(f'end: {end_df}')

    need_reprocess_set = set(need_reprocess_demo_idx)
    tasks = _build_tasks(df, output_dir, need_reprocess_set)

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        # Each task carries its own macro_lines_count.  Previously every task
        # received whatever value the last loop iteration left behind.
        future_to_idx = {
            executor.submit(process_single_function, task, output_dir,
                            task["macro_lines_count"]): task["idx"]
            for task in tasks
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_idx), total=len(tasks)):
            idx = future_to_idx[future]
            try:
                result = future.result()
                # Keep successful samples whose graph is large enough.
                if result and len(result['nodes']) > args.node_num_threhold:
                    results.append(result)
            except Exception as e:
                # Boundary handler: one failing sample must not abort the run.
                # Workers write the same log file, so take the shared lock.
                with print_lock:
                    with open(log_file_path, "a", encoding="utf-8") as log_file:
                        log_file.write(f"[ERROR] Exception in idx {idx}: {e}\n")

    result_len = len(results)
    results.sort(key=lambda result: result['id'])

    if output_json_path is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
        with open(output_json_path, "w", encoding="utf-8") as f:
            for row in results:
                f.write(json.dumps(row) + "\n")
        run_time = time.time() - start_time
        print(
            f"✅ All samples processed successfully and written in order to {output_json_path}, total {result_len} records, runtime: {run_time:.2f} seconds")
    else:
        print(f"No --output_file given; {result_len} reprocessed records were not written to a standalone file.")

    if args.process_file and args.process_file2:
        _merge_reprocessed(args.process_file, args.process_file2, results, need_reprocess_set)
    else:
        print("--process_file/--process_file2 not given; skipping the data update step.")


if __name__ == "__main__":
    main()
