# data_loader.py
import os
import logging
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np
import networkx as nx

logger = logging.getLogger(__name__)


def _read_int_lines(path: str) -> List[int]:
    """Parse a UTF-8 text file that holds one integer per line.

    Empty and whitespace-only lines are skipped, so the returned list may be
    shorter than the number of physical lines in the file. Any other line
    that ``int`` cannot parse raises ``ValueError``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [int(text) for text in stripped if text]


def _read_edge_list(path: str) -> List[Tuple[int, int]]:
    """Parse a comma-separated edge list file into ``(u, v)`` integer pairs.

    Blank lines and lines with fewer than two comma-separated fields are
    skipped; fields after the second are ignored. A non-integer value in
    either of the first two fields raises ``ValueError``.
    """
    pairs: List[Tuple[int, int]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            # A blank line splits into a single empty field, so this one
            # guard rejects both empty and single-field lines.
            fields = raw.strip().split(",")
            if len(fields) < 2:
                continue
            src, dst = (int(token.strip()) for token in fields[:2])
            pairs.append((src, dst))
    return pairs


def load_node_attributes(path: str) -> Dict[int, List[float]]:
    """
    Reads per-node feature vectors from a comma-separated text file.

    The zero-based line index is used as the node id. Blank lines produce no
    entry but still consume an index, so later nodes keep their line number.
    A line containing a value that cannot be parsed as a float maps to an
    empty list. If the file does not exist, a warning is logged and an empty
    dict is returned.
    """
    if not os.path.exists(path):
        logger.warning("node_attributes.txt not found: %s", path)
        return {}

    features: Dict[int, List[float]] = {}
    with open(path, "r", encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle):
            text = raw.strip()
            if not text:
                continue
            try:
                vector = [float(token) for token in text.split(",")]
            except ValueError:
                # Keep the node present, but with no usable features.
                vector = []
            features[line_no] = vector
    return features


def load_graph_structure(data_dir: str) -> List[nx.Graph]:
    """
    Reconstructs individual NetworkX graphs from global intermediate HRG files.

    Required files:
      - A.txt                : global edge list "u, v" (node ids are global)
      - edge_labels.txt      : edge types aligned with A.txt lines
      - graph_indicator.txt  : graph id for each global node id
      - graph_labels.txt     : label per graph id
      - node_attributes.txt  : feature vector per global node id

    Output:
      List of per-sample NetworkX graphs:
        - nodes: local ids [0..n-1], with node['attr'] = feature vector
        - edges: undirected, with edge['type'] = edge label
        - g.graph['label'] = graph label (0 benign / 1 vulnerable)
        - g.graph['gid']   = graph id
      An empty list is returned when graph_indicator.txt has no entries.

    Notes:
      - graph_indicator.txt is treated as 0-based when its minimum value is
        0; otherwise every id is shifted down by one.
      - Local node ids are assigned in order of appearance within each
        graph, so graph_indicator.txt does not have to be grouped by graph.
      - Edges that reference an out-of-range node id, or whose endpoints
        belong to different graphs, are skipped.
      - A graph without a label gets label 0 and an edge without a label
        gets type 1; both situations are logged as warnings.

    Raises:
      FileNotFoundError: if any required file is missing from data_dir.
      ValueError: if graph_indicator.txt contains a negative graph id, or
        if one of the integer/edge files holds a non-integer value.
    """
    req = (
        "A.txt",
        "edge_labels.txt",
        "graph_indicator.txt",
        "graph_labels.txt",
        "node_attributes.txt",
    )
    for fn in req:
        p = os.path.join(data_dir, fn)
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing required file: {p}")

    # graph_indicator may be 1-based (older code) or 0-based (new ast_processor)
    indic_raw = _read_int_lines(os.path.join(data_dir, "graph_indicator.txt"))
    if not indic_raw:
        return []

    indic_min = min(indic_raw)
    if indic_min < 0:
        # A negative id would silently wrap around when indexing `graphs`.
        raise ValueError(
            f"graph_indicator.txt contains a negative graph id: {indic_min}"
        )
    indic = list(indic_raw) if indic_min == 0 else [x - 1 for x in indic_raw]

    graph_labels = _read_int_lines(os.path.join(data_dir, "graph_labels.txt"))
    edges = _read_edge_list(os.path.join(data_dir, "A.txt"))
    edge_types = _read_int_lines(os.path.join(data_dir, "edge_labels.txt"))
    node_attrs = load_node_attributes(os.path.join(data_dir, "node_attributes.txt"))

    num_graphs = max(indic) + 1
    num_nodes = len(indic)

    # Surface silent defaulting so misaligned inputs are noticed.
    if len(graph_labels) < num_graphs:
        logger.warning(
            "graph_labels.txt has %d labels for %d graphs; missing labels default to 0",
            len(graph_labels),
            num_graphs,
        )
    if len(edge_types) != len(edges):
        logger.warning(
            "edge_labels.txt has %d labels for %d edges; missing labels default to 1",
            len(edge_types),
            len(edges),
        )

    graphs: List[nx.Graph] = []
    for gid in range(num_graphs):
        g = nx.Graph()
        g.graph["gid"] = gid
        g.graph["label"] = int(graph_labels[gid]) if gid < len(graph_labels) else 0
        graphs.append(g)

    # Add nodes. Local ids come from a running per-graph counter rather than
    # from "global id minus graph offset": the two agree when the indicator
    # file is grouped by graph id, but only the counter stays correct when
    # nodes of different graphs are interleaved.
    local_ids: List[int] = []
    next_local = [0] * num_graphs
    for global_nid, gid in enumerate(indic):
        local_id = next_local[gid]
        next_local[gid] += 1
        local_ids.append(local_id)
        graphs[gid].add_node(local_id, attr=node_attrs.get(global_nid, []))

    # Add edges (skip out-of-range and cross-graph edges)
    skipped = 0
    for ei, (u, v) in enumerate(edges):
        if not (0 <= u < num_nodes and 0 <= v < num_nodes):
            skipped += 1
            continue
        gid = indic[u]
        if gid != indic[v]:
            skipped += 1
            continue

        et = int(edge_types[ei]) if ei < len(edge_types) else 1
        graphs[gid].add_edge(local_ids[u], local_ids[v], type=et)

    if skipped:
        logger.debug("Skipped %d out-of-range or cross-graph edges", skipped)

    logger.info("Loaded %d graphs from %s", len(graphs), data_dir)
    return graphs
