import os
import re
import json
import logging
import argparse
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Callable

try:
    import numpy as np
except Exception:
    np = None

from config import NODE_TYPE_MAP, NUM_NODE_TYPES

# Module-wide logging: timestamped INFO-level output on the root handler,
# plus a named logger for this processor (its level can be raised/lowered by main()).
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("ASTProcessor")


@dataclass
class GraphData:
    """
    One parsed AST file, represented as a standalone graph.

    Node ids in ``edges`` are local to this graph (0..num_nodes-1, with the
    synthetic ROOT node at id 0 as built by ASTProcessor._parse_file). The
    caller shifts them by a global offset when merging several graphs.
    """

    # Directed (src, dst) pairs of local node ids. Undirected relations are
    # stored as two entries, one per direction.
    edges: List[Tuple[int, int]]
    # One label per entry in ``edges`` (1 = syntactic, 2 = lookaround scope,
    # 3 = backreference, as assigned by ASTProcessor).
    edge_labels: List[int]
    # Node type id per node (values taken from config.NODE_TYPE_MAP).
    node_labels: List[int]
    # Per-node feature vector: node-type one-hot followed by [center, width, index].
    node_attributes: List[List[float]]
    # Graph id, repeated once per node.
    graph_indicator: List[int]
    # Graph-level class label (produced by ASTProcessor.label_fn).
    graph_label: int
    # Number of nodes in this graph, ROOT included.
    num_nodes: int


class ASTProcessor:
    """
    Parse raw AST text files into HRG intermediate format:
      - A.txt (edge list)
      - edge_labels.txt
      - node_labels.txt
      - node_attributes.txt
      - graph_indicator.txt
      - graph_labels.txt

    Each input file becomes one graph with an explicit ROOT node (local id 0).
    Node ids written to A.txt are 0-based global ids: each graph's local ids
    are offset by the number of nodes in all previously merged graphs. Graph
    ids in graph_indicator.txt are 0-based as well.

    Edge labels:
      1 = syntactic parent/child edge (stored in both directions)
      2 = lookaround scope edge (parent -> child, added when the parent's
          type name contains "LOOK")
      3 = backreference edge (stored in both directions, between a
          BACKREFERENCE node and the CAPTURELEFT node with the same index)

    Enhancements over original:
      1) Parse per-file into a GraphData, then merge with global offset
      2) Robust indent parsing (count leading '|' as hierarchy level)
      3) Validation (edge/label alignment, node feature length, edge range)
      4) CLI support & optional NPZ export + stats JSON
    """

    # Leading indentation of an AST dump line: any mix of '|' and whitespace.
    _INDENT_RE = re.compile(r"[|\s]*")
    # "CHARCLASS [low, high]"
    _CHARCLASS_RE = re.compile(r"\[(\d+)\s*,\s*(\d+)\]")
    # "CAPTURELEFT/CAPTURERIGHT ... Index: k"
    _CAPTURE_INDEX_RE = re.compile(r"Index:\s*(\d+)")
    # "BACKREFERENCE ... Refers to: k"
    _BACKREF_RE = re.compile(r"Refers to:\s*(\d+)")

    def __init__(
        self,
        raw_dir: str,
        out_dir: str,
        label_fn: Optional[Callable[[str], int]] = None,
        file_suffix: str = ".txt",
        export_npz: bool = False,
    ):
        """
        Args:
            raw_dir: directory scanned (non-recursively) for input files.
            out_dir: directory the output files are written to (created if missing).
            label_fn: maps a file name to its integer graph label. Defaults to
                1 if the substring "_1" occurs anywhere in the name, else 0
                (substring match, so e.g. "x_10.txt" is also labeled 1).
            file_suffix: only files whose names end with this suffix are parsed.
            export_npz: also write a compressed ``hrg.npz`` (requires numpy).
        """
        self.raw_dir = raw_dir
        self.out_dir = out_dir
        self.file_suffix = file_suffix
        self.export_npz = export_npz

        self.label_fn = label_fn if label_fn is not None else (lambda fname: 1 if "_1" in fname else 0)

        # Global merged storage (node ids already offset per graph).
        self.edges: List[Tuple[int, int]] = []
        self.edge_labels: List[int] = []
        self.node_labels: List[int] = []
        self.graph_indicator: List[int] = []
        self.node_attributes: List[List[float]] = []
        self.graph_labels: List[int] = []

        self._global_node_offset = 0
        self._graph_id = 0

    # --------------------------
    # Feature extraction helpers
    # --------------------------
    @staticmethod
    def _normalize_charclass(low: int, high: int) -> List[float]:
        """Map a character range [low, high] (0..255 scale) to [center, width] in [0, 1]."""
        center = (low + high) / 2.0 / 255.0
        width = (high - low) / 255.0
        return [center, width]

    @classmethod
    def _split_line(cls, line: str) -> Tuple[int, str]:
        """
        Split an AST dump line into (level, content).

        The level is the number of '|' characters in the leading indentation
        prefix (any run of '|' and whitespace). The content is the rest of the
        line with surrounding whitespace removed.
        """
        prefix = cls._INDENT_RE.match(line).group(0)  # '*' always matches, possibly ""
        return prefix.count("|"), line[len(prefix):].strip()

    @staticmethod
    def _count_level(line: str) -> int:
        """
        Return the hierarchy level of an AST dump line.

        Many AST dumps indent children with leading pipes, e.g. '| |  CONCAT'.
        Only pipes in the leading indentation prefix are counted, so a '|'
        inside the node text itself (e.g. a literal pipe character) does not
        inflate the level.
        """
        return ASTProcessor._split_line(line)[0]

    @staticmethod
    def _root_feature() -> List[float]:
        """Feature vector of the synthetic ROOT node: ROOT one-hot + [0, 0, -1]."""
        feat = [0.0] * NUM_NODE_TYPES + [0.0, 0.0, -1.0]
        feat[NODE_TYPE_MAP["ROOT"]] = 1.0
        return feat

    def _parse_node_features(
        self,
        content: str,
        capture_map: Dict[int, int],
        current_node: int,
        edges: List[Tuple[int, int]],
        edge_labels: List[int],
    ) -> Tuple[List[float], int, str]:
        """
        Build the features of one node from its text.

        Side effects:
          - CAPTURELEFT records ``capture_map[index] = current_node``.
          - BACKREFERENCE whose index is already in ``capture_map`` appends
            two label-3 edges (both directions) to ``edges``/``edge_labels``.

        Return:
          - node feature vector (one-hot + [center, width, index])
          - node_label int
          - matched_type string ("ROOT" when no type name matches)
        """
        type_vec = [0.0] * NUM_NODE_TYPES
        extra_feat = [0.0, 0.0, -1.0]  # [center, width, index]

        # Longest matching type name wins, so a name that is a prefix of
        # another one (e.g. "CHAR" vs "CHARCLASS", if both are configured)
        # cannot shadow the more specific type.
        candidates = [name for name in NODE_TYPE_MAP if content.startswith(name)]
        if not candidates:
            # Unrecognized line: labeled ROOT but with an all-zero one-hot,
            # exactly as the original implementation did.
            return type_vec + extra_feat, NODE_TYPE_MAP["ROOT"], "ROOT"

        matched_type = max(candidates, key=len)
        type_vec[NODE_TYPE_MAP[matched_type]] = 1.0

        if matched_type == "CHARCLASS":
            m = self._CHARCLASS_RE.search(content)
            if m:
                extra_feat[:2] = self._normalize_charclass(int(m.group(1)), int(m.group(2)))

        elif matched_type in ("CAPTURELEFT", "CAPTURERIGHT"):
            m = self._CAPTURE_INDEX_RE.search(content)
            if m:
                idx = int(m.group(1))
                extra_feat[2] = float(idx)
                if matched_type == "CAPTURELEFT":
                    capture_map[idx] = current_node

        elif matched_type == "BACKREFERENCE":
            m = self._BACKREF_RE.search(content)
            if m:
                ref_idx = int(m.group(1))
                extra_feat[2] = float(ref_idx)
                # Only backreferences to captures already seen in this file
                # are linked; forward references get no semantic edge.
                target = capture_map.get(ref_idx)
                if target is not None:
                    edges.extend([(current_node, target), (target, current_node)])
                    edge_labels.extend([3, 3])

        return type_vec + extra_feat, NODE_TYPE_MAP[matched_type], matched_type

    # --------------------------
    # Core parsing (per file)
    # --------------------------
    def _parse_file(self, path: str, graph_id: int, graph_label: int) -> "GraphData":
        """
        Parse one AST text file into a standalone GraphData with local node ids [0..n-1].

        Node 0 is a synthetic ROOT; every non-blank line becomes one node whose
        parent is the nearest preceding node with a smaller level. An empty
        file yields a graph consisting of the ROOT node only. The caller merges
        the result with a global offset.
        """
        with open(path, "r", encoding="utf-8") as f:
            # Blank / whitespace-only lines carry no node.
            lines = [ln.rstrip("\n") for ln in f if ln.strip()]

        root = 0
        edges: List[Tuple[int, int]] = []
        edge_labels: List[int] = []
        node_labels: List[int] = [NODE_TYPE_MAP["ROOT"]]
        node_attributes: List[List[float]] = [self._root_feature()]
        graph_indicator: List[int] = [graph_id]
        capture_map: Dict[int, int] = {}

        # Stack entries: (node_id, level, matched_type). ROOT sits at level -1,
        # below every real level (>= 0), so it is never popped.
        stack: List[Tuple[int, int, str]] = [(root, -1, "ROOT")]

        for cur, ln in enumerate(lines, start=1):
            level, content = self._split_line(ln)

            feat, nlabel, matched_type = self._parse_node_features(
                content=content,
                capture_map=capture_map,
                current_node=cur,
                edges=edges,
                edge_labels=edge_labels,
            )

            node_attributes.append(feat)
            node_labels.append(nlabel)
            graph_indicator.append(graph_id)

            # Pop siblings / deeper nodes until the top is a strict ancestor.
            while stack[-1][1] >= level:
                stack.pop()

            parent, _, parent_type = stack[-1]

            # Syntactic edge (undirected via two directed entries), label=1.
            edges.extend([(parent, cur), (cur, parent)])
            edge_labels.extend([1, 1])

            # Lookaround scope edge (directed parent -> child), label=2.
            if "LOOK" in parent_type:
                edges.append((parent, cur))
                edge_labels.append(2)

            stack.append((cur, level, matched_type))

        return GraphData(
            edges=edges,
            edge_labels=edge_labels,
            node_labels=node_labels,
            node_attributes=node_attributes,
            graph_indicator=graph_indicator,
            graph_label=graph_label,
            num_nodes=len(node_labels),
        )

    # --------------------------
    # Merge + validation + save
    # --------------------------
    def _merge_graph(self, g: "GraphData"):
        """Append one graph to the global storage, shifting its node ids by the current offset."""
        off = self._global_node_offset

        # merge nodes
        self.node_attributes.extend(g.node_attributes)
        self.node_labels.extend(g.node_labels)
        self.graph_indicator.extend(g.graph_indicator)
        self.graph_labels.append(g.graph_label)

        # merge edges with offset
        self.edges.extend((u + off, v + off) for (u, v) in g.edges)
        self.edge_labels.extend(g.edge_labels)

        self._global_node_offset += g.num_nodes
        self._graph_id += 1

    def _validate(self):
        """
        Check internal consistency of the merged dataset.

        Raises:
            ValueError: on any length mismatch, bad feature length, or edge
                endpoint outside the range of merged node ids.
        """
        # edge-label alignment
        if len(self.edges) != len(self.edge_labels):
            raise ValueError(f"edges ({len(self.edges)}) != edge_labels ({len(self.edge_labels)})")

        num_nodes = len(self.node_labels)

        # one feature vector per node
        if len(self.node_attributes) != num_nodes:
            raise ValueError(
                f"node_attributes ({len(self.node_attributes)}) != node_labels ({num_nodes})"
            )

        # feature length alignment
        feat_len = NUM_NODE_TYPES + 3
        bad = [i for i, x in enumerate(self.node_attributes) if len(x) != feat_len]
        if bad:
            raise ValueError(f"Found {len(bad)} nodes with invalid feature length != {feat_len}, e.g. {bad[:5]}")

        # graph_indicator length equals num_nodes
        if len(self.graph_indicator) != num_nodes:
            raise ValueError("graph_indicator length != node_labels length")

        # every edge endpoint must be an existing global node id
        bad_edges = [
            i for i, (u, v) in enumerate(self.edges)
            if not (0 <= u < num_nodes and 0 <= v < num_nodes)
        ]
        if bad_edges:
            raise ValueError(
                f"Found {len(bad_edges)} edges with endpoints outside [0, {num_nodes}), "
                f"e.g. edge indices {bad_edges[:5]}"
            )

        # graph_labels length equals number of graphs
        # (graph_id increments as we merge)
        if len(self.graph_labels) != self._graph_id:
            raise ValueError("graph_labels length != number of merged graphs")

    def _save_output(self):
        """Write all HRG text files, stats.json and (optionally) hrg.npz into out_dir."""
        os.makedirs(self.out_dir, exist_ok=True)

        def save_txt(name: str, data: List):
            # one value per line
            with open(os.path.join(self.out_dir, name), "w", encoding="utf-8") as f:
                f.writelines(f"{x}\n" for x in data)

        # A.txt: "u, v" per line
        with open(os.path.join(self.out_dir, "A.txt"), "w", encoding="utf-8") as f:
            f.writelines(f"{u}, {v}\n" for u, v in self.edges)

        save_txt("edge_labels.txt", self.edge_labels)
        save_txt("node_labels.txt", self.node_labels)
        save_txt("graph_indicator.txt", self.graph_indicator)
        save_txt("graph_labels.txt", self.graph_labels)

        # node_attributes.txt: comma-separated features, 6 significant digits
        with open(os.path.join(self.out_dir, "node_attributes.txt"), "w", encoding="utf-8") as f:
            f.writelines(",".join(f"{z:.6g}" for z in feat) + "\n" for feat in self.node_attributes)

        # stats.json
        stats = self._compute_stats()
        with open(os.path.join(self.out_dir, "stats.json"), "w", encoding="utf-8") as f:
            json.dump(stats, f, ensure_ascii=False, indent=2)

        # optional NPZ (convenient for loading directly with np.load for later training/analysis)
        if self.export_npz:
            if np is None:
                logger.warning("numpy not available; skip NPZ export.")
            else:
                feat_dim = NUM_NODE_TYPES + 3
                np.savez_compressed(
                    os.path.join(self.out_dir, "hrg.npz"),
                    # reshape keeps 2-D shapes ((0, 2) / (0, feat_dim)) for an empty dataset
                    edges=np.array(self.edges, dtype=np.int64).reshape(-1, 2),
                    edge_labels=np.array(self.edge_labels, dtype=np.int64),
                    node_labels=np.array(self.node_labels, dtype=np.int64),
                    graph_indicator=np.array(self.graph_indicator, dtype=np.int64),
                    graph_labels=np.array(self.graph_labels, dtype=np.int64),
                    node_attributes=np.array(self.node_attributes, dtype=np.float32).reshape(-1, feat_dim),
                )

        logger.info("Saved HRG files to: %s", self.out_dir)

    def _compute_stats(self) -> Dict:
        """Summarize graph/node/edge counts and label distributions for stats.json."""
        num_nodes = len(self.node_labels)
        num_edges = len(self.edges)
        num_graphs = len(self.graph_labels)

        # label distribution: 0 = benign, any non-zero label counts as vulnerable
        # (a plain sum would over-count if label_fn returns values > 1)
        pos = sum(1 for y in self.graph_labels if y != 0)
        neg = num_graphs - pos

        # node type distribution (labels not in NODE_TYPE_MAP are counted as "UNK")
        inv_map = {v: k for k, v in NODE_TYPE_MAP.items()}
        type_count = dict.fromkeys(NODE_TYPE_MAP, 0)
        for nl in self.node_labels:
            name = inv_map.get(nl, "UNK")
            type_count[name] = type_count.get(name, 0) + 1

        # edge label distribution
        ecount = {1: 0, 2: 0, 3: 0}
        for el in self.edge_labels:
            ecount[el] = ecount.get(el, 0) + 1

        return {
            "num_graphs": num_graphs,
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "graph_label_dist": {"benign(0)": neg, "vulnerable(1)": pos},
            "edge_label_dist": ecount,
            "node_type_dist": type_count,
            "feature_dim": NUM_NODE_TYPES + 3,
        }

    # --------------------------
    # Public API
    # --------------------------
    def process_all(self, limit: Optional[int] = None):
        """
        Parse every regular file in raw_dir ending with file_suffix (sorted by
        name, optionally truncated to ``limit`` files), then validate and save.

        Raises:
            FileNotFoundError: if raw_dir does not exist or is not a directory.
            ValueError: if validation of the merged data fails.
        """
        if not os.path.isdir(self.raw_dir):
            raise FileNotFoundError(f"raw_dir not found or not a directory: {self.raw_dir}")

        files = sorted(
            f for f in os.listdir(self.raw_dir)
            if f.endswith(self.file_suffix) and os.path.isfile(os.path.join(self.raw_dir, f))
        )
        if limit is not None:
            files = files[: max(0, int(limit))]

        total = len(files)
        logger.info("Parsing %d files from %s ...", total, self.raw_dir)

        for i, fname in enumerate(files, start=1):
            g_label = int(self.label_fn(fname))
            fpath = os.path.join(self.raw_dir, fname)

            try:
                g = self._parse_file(fpath, graph_id=self._graph_id, graph_label=g_label)
            except Exception:
                # Name the offending file before propagating; the error itself is unchanged.
                logger.error("Failed to parse %s", fpath)
                raise
            self._merge_graph(g)

            if i % 200 == 0 or i == total:
                logger.info(
                    "Progress: %d/%d graphs, nodes=%d, edges=%d",
                    i, total, len(self.node_labels), len(self.edges),
                )

        self._validate()
        self._save_output()


def build_argparser():
    """Build the command-line parser for the AST -> HRG conversion script."""
    parser = argparse.ArgumentParser(
        description="Parse raw AST txt dumps into HRG intermediate files."
    )
    add = parser.add_argument

    # Required input/output locations.
    for flag, help_text in (
        ("--raw_dir", "Directory containing raw *.txt AST files"),
        ("--out_dir", "Output directory to save HRG files"),
    ):
        add(flag, type=str, required=True, help=help_text)

    # Optional knobs.
    add("--suffix", type=str, default=".txt", help="File suffix to parse (default: .txt)")
    add("--limit", type=int, default=None, help="Optional limit on number of files")
    add("--export_npz", action="store_true", help="Also export a compressed hrg.npz")
    add("--log_level", type=str, default="INFO", help="DEBUG/INFO/WARNING/ERROR")

    return parser


def main():
    """CLI entry point: parse arguments, set this module's log verbosity, run the conversion."""
    cli = build_argparser().parse_args()

    # Unknown level names fall back to INFO.
    wanted_level = getattr(logging, cli.log_level.upper(), logging.INFO)
    logger.setLevel(wanted_level)

    processor = ASTProcessor(
        raw_dir=cli.raw_dir,
        out_dir=cli.out_dir,
        file_suffix=cli.suffix,
        export_npz=cli.export_npz,
    )
    processor.process_all(limit=cli.limit)


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())
