import time
import random
import numpy as np
import os
import torch
from collections import defaultdict
import dgl
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, Batch
from config import *
from torch_geometric.datasets import DBLP, IMDB, HGBDataset, OGB_MAG
from torch_geometric.loader import NeighborLoader
import math
import glob
import pandas as pd
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.nn import TransE
import torch.optim as optim
import json, csv
from typing import List
from transformers import AutoTokenizer, AutoModel

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: Seed value. ``0`` means "derive a seed from the current time".

    Returns:
        The seed actually used (useful for logging when ``seed`` was 0).
    """
    if seed == 0:
        seed = int(time.time())
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible CUDA device, including the current one.
    torch.cuda.manual_seed_all(seed)

    # cuDNN is switched off entirely (not merely made deterministic), trading speed
    # for run-to-run reproducibility; the two flags below only matter if re-enabled.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Only child processes see this: the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    return seed

# NOTE(review): Disabled TwiBot preprocessing (encode_texts, build_full_graph,
# preprocess_Twibot), parked inside a module-level string literal so none of it
# runs or is importable. It relies on names not defined in this file
# (`device` in encode_texts, `edges_csv` / `labels_csv` in build_full_graph) --
# presumably expected from `config`; verify before re-enabling. Note also that
# build_full_graph prints every user ID and calls exit() on the first user
# without an 'ID' key.
'''
def encode_texts(texts, batch_size, model, tokenizer) -> torch.Tensor:
    """
    Tokenize & encode a list of texts into mean-pooled embeddings.
    Returns Tensor [len(texts), hidden_size].
    """
    all_embs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            enc = tokenizer(batch,
                            padding=True,
                            truncation=True,
                            max_length=512,
                            return_tensors='pt').to(device)
            out = model(**enc).last_hidden_state  # [B, L, H]
            mask = enc.attention_mask.unsqueeze(-1)  # [B, L, 1]
            summed = (out * mask).sum(dim=1)         # [B, H]
            counts = mask.sum(dim=1).clamp(min=1)    # [B, 1]
            embs = summed / counts                   # [B, H]
            all_embs.append(embs.cpu())
    return torch.cat(all_embs, dim=0)

def build_full_graph():
    # --- Load raw users & collect input texts ---
    users = json.load(open(os.path.join('TWIBOT', 'node.json')))
    print(users[0])
    print(users[1])
    for u in users:
        try:
            print(u['ID'])
        except:
            print(u)
            exit()
    user_ids = [str(u['ID']) for u in users]
    id2idx = {uid: i for i, uid in enumerate(user_ids)}
    N_u = len(user_ids)

    # Prepare texts
    user_texts = []
    tweet_texts = []
    tweet_owner = []

    for u in users:
        uid = str(u['ID'])
        # stringify profile+domains
        profile_str = json.dumps(u['profile'], ensure_ascii=False)
        domains_str = " ".join(u.get('domain', []))
        user_texts.append(f"PROFILE: {profile_str} DOMAINS: {domains_str}")

        for tw in u.get('tweet', []):
            tweet_texts.append(tw)
            tweet_owner.append(id2idx[uid])

    MODEL_NAME = 'bert-base-uncased'
    device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model     = AutoModel.from_pretrained(MODEL_NAME).to(device)
    model.eval()

    # --- Batch‐encode user features & tweet features ---
    user_x  = encode_texts(user_texts, 16, model, tokenizer)   # [N_u, H]
    tweet_x = encode_texts(tweet_texts, 32, model, tokenizer) if tweet_texts else torch.zeros((0, model.config.hidden_size))
    N_t     = tweet_x.size(0)

    # --- Load user–user edges (friend/follow) ---
    rel2type = {'friend': 1, 'follow': 2}
    srcs, dsts, etypes = [], [], []
    with open(edges_csv) as f:
        rdr = csv.DictReader(f)
        for row in rdr:
            r = row['relation']
            if r not in rel2type: continue
            u = id2idx[row['source_id'][1:]]
            v = id2idx[row['target_id'][1:]]
            t = rel2type[r]
            srcs += [u, v]; dsts += [v, u]; etypes += [t, t]

    # --- Load labels (human=0, bot=1), unlabeled = -1 ---
    y = torch.full((N_u + N_t,), -1, dtype=torch.long)
    lbl_map = {'human':0, 'bot':1}
    with open(labels_csv) as f:
        rdr = csv.DictReader(f)
        for row in rdr:
            u = id2idx[row['id'][1:]]
            y[u] = lbl_map[row['label']]

    # --- Add post edges (type=0) between user<->their tweets ---
    for ti, ui in enumerate(tweet_owner):
        t = N_u + ti
        srcs += [ui, t]; dsts += [t, ui]; etypes += [0, 0]

    # --- Assemble final Data ---
    x = torch.cat([user_x, tweet_x], dim=0)             # [N_u+N_t, EMB_DIM]
    edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
    edge_type  = torch.tensor(etypes, dtype=torch.long)
    node_type  = torch.cat([
        torch.zeros(N_u, dtype=torch.long),
        torch.ones(N_t, dtype=torch.long)
    ], dim=0)  # 0=user, 1=tweet

    return Data(
        x=x,
        edge_index=edge_index,
        edge_type=edge_type,
        node_type=node_type,
        y=y
    ), N_u

def preprocess_Twibot(data):
    data_full, Nu = build_full_graph()
    
    k = 50000
    idx = list(range(Nu))
    labels = data_full.y[:Nu].tolist()
    # train_test_split with stratify to pick exactly k
    centers, _ = train_test_split(
        idx,
        train_size = k,
        stratify   = labels
    )

    loader = NeighborLoader(
        data_full,
        num_neighbors = [-1, -1],  # take *all* neighbors each hop
        input_nodes   = centers,        # only these user nodes
        batch_size    = 1,              # one ego‐graph per batch
        shuffle       = False
    )

    data_list = []
    for batch in loader:
        g = Data(
            x          = batch.x,
            edge_index = batch.edge_index,
            edge_type  = batch.edge_type,
            node_type  = batch.node_type,
            y          = batch.y
        )
        if hasattr(g, 'batch'):
            del g.batch
        data_list.append(g)

    graphs = Batch.from_data_list(data_list)
    #graphs = data_list
    # 1.4) Save single Batch
    torch.save(graphs, os.path.join(data, f'{data}.pt'))
    return graphs
'''

def orthogonal_projection(k, d):
    """Return ``d`` random orthonormal column vectors in ``k`` dimensions.

    A ``k x k`` standard-normal matrix is QR-factorised and the first ``d``
    columns of its orthogonal factor are kept, giving an array of shape
    ``(k, min(k, d))``. Randomness comes from the global ``np.random`` state.
    """
    orthogonal_factor = np.linalg.qr(np.random.randn(k, k))[0]
    return orthogonal_factor[:, :d]


def _first_seen_ids(codes):
    """Relabel a 1-D integer array as 0, 1, 2, ... in order of each value's first appearance."""
    _, first_pos, inverse = np.unique(codes, return_index=True, return_inverse=True)
    # np.unique numbers the distinct values in sorted order; re-rank them by where
    # each value first occurs so the ids match a first-seen dictionary numbering.
    rank = np.empty(len(first_pos), dtype=np.int64)
    rank[np.argsort(first_pos)] = np.arange(len(first_pos))
    return rank[inverse.reshape(-1)]


def preprocess_kaggleRCDD(data, k=50000):
    """Build ego-graphs around sampled labelled nodes of the AliRCD (ICDM) graph.

    Reads ``AliRCD_ICDM_nodes.csv`` (node id, node type, ':'-separated features),
    ``AliRCD_ICDM_edges.csv`` (the first two columns are the endpoints) and the
    train/test label files from directory ``data``. Edges are made undirected and
    each edge's type is the id of its (src node type, dst node type) pair,
    numbered in order of first appearance. Up to ``k`` labelled nodes are drawn
    per class (via the global ``np.random`` state), and a 2-hop ego-graph (all
    1-hop neighbours, up to 10 per node at hop 2) is extracted around each.

    Args:
        data: Dataset directory; also used as the output file stem.
        k: Maximum number of centre nodes sampled from each class.

    Returns:
        A ``Batch`` of ego-graphs whose ``y`` is the centre node's label; it is
        also saved to ``<data>/<data>.pt``.
    """
    # 1) Load nodes and parse the ':'-separated feature strings
    nodes_df = pd.read_csv(
        os.path.join(data, 'AliRCD_ICDM_nodes.csv'),
        header=None, names=['nid', 'ntype', 'feat_str']
    )
    feats = np.stack([
        np.fromstring(s, dtype=np.float32, sep=':')
        for s in nodes_df['feat_str'].values
    ])
    X = torch.from_numpy(feats)
    N = X.size(0)

    # map node_id -> row index 0..N-1
    id2idx = {nid: i for i, nid in enumerate(nodes_df['nid'].tolist())}

    # encode node types as 0..T-1 in sorted order of the raw type values
    unique_nt = sorted(nodes_df['ntype'].unique())
    ntype2idx = {nt: i for i, nt in enumerate(unique_nt)}
    node_type = torch.tensor(
        nodes_df['ntype'].map(ntype2idx).tolist(),
        dtype=torch.long
    )  # [N]

    # 2) Load edges and map endpoints to row indices
    edges_df = pd.read_csv(
        os.path.join(data, 'AliRCD_ICDM_edges.csv'),
        header=None, names=['src', 'dst', '_', '__', '___']
    )
    src = edges_df['src'].map(id2idx)
    dst = edges_df['dst'].map(id2idx)
    # Endpoints absent from the node file map to NaN, which would become garbage
    # indices after the int64 cast, so drop those edges explicitly.
    known = src.notna() & dst.notna()
    if not known.all():
        print(f'Dropping {int((~known).sum())} edges with unknown endpoints')
    src = src[known].to_numpy(dtype=np.int64)
    dst = dst[known].to_numpy(dtype=np.int64)
    # build undirected edges: all forward edges, then all reversed ones
    edge_index = torch.from_numpy(
        np.vstack([np.concatenate([src, dst]), np.concatenate([dst, src])])
    )

    # Edge type = first-seen id of the (src node type, dst node type) pair.
    # s * T + d encodes each pair injectively, so it can be vectorised.
    num_nt = len(unique_nt)
    pair_code = node_type[edge_index[0]] * num_nt + node_type[edge_index[1]]
    edge_type = torch.from_numpy(_first_seen_ids(pair_code.numpy()))

    # 3) Load labels; test labels override train labels for the same node id
    def load_lbl(fn):
        df = pd.read_csv(os.path.join(data, fn),
                         header=None, names=['nid', 'lbl'])
        return dict(zip(df['nid'], df['lbl']))
    labels = {
        **load_lbl('AliRCD_ICDM_train_labels.csv'),
        **load_lbl('AliRCD_ICDM_test_labels.csv'),
    }
    # Only labelled nodes that exist in the node table are candidates
    pairs = [(id2idx[nid], lbl) for nid, lbl in labels.items() if nid in id2idx]
    labeled_idx = torch.tensor([i for i, _ in pairs], dtype=torch.long)
    y_full = torch.full((N,), -1, dtype=torch.long)  # -1 = unlabelled
    y_full[labeled_idx] = torch.tensor([int(lbl) for _, lbl in pairs], dtype=torch.long)
    y_labeled = y_full[labeled_idx]

    # 4) Sample up to k centres per class
    sampled_centers = []
    for c in y_labeled.unique().tolist():
        idxs = labeled_idx[y_labeled == c].tolist()
        # Cap per class without overwriting k, so a small class does not
        # shrink the sample drawn from the classes after it.
        n_c = min(k, len(idxs))
        sampled_centers.extend(np.random.choice(idxs, size=n_c, replace=False).tolist())

    full_data = Data(
        x          = X.float(),
        edge_index = edge_index,
        node_type  = node_type,
        edge_type  = edge_type
    )

    center_global_idx = torch.tensor(sampled_centers, dtype=torch.long)

    loader = NeighborLoader(
        data           = full_data,
        input_nodes    = center_global_idx,
        num_neighbors  = [-1, 10],      # all 1-hop neighbours, up to 10 per node at hop 2
        batch_size     = 1,
        shuffle        = False          # keeps batches aligned with center_global_idx
    )

    # 5) Assign the centre's label to each ego-graph
    data_list = []
    for idx_global, sub_data in zip(center_global_idx.tolist(), loader):
        sub_data.y = torch.tensor([int(y_full[idx_global])], dtype=torch.long)
        # drop loader bookkeeping so graphs collate cleanly
        for attr in ['batch', 'n_id', 'e_id']:
            if hasattr(sub_data, attr):
                delattr(sub_data, attr)
        data_list.append(sub_data)

    graphs = Batch.from_data_list(data_list)
    torch.save(graphs, os.path.join(data, f'{data}.pt'))
    return graphs

def preprocess_kagglePDNS(data, k=50000, embed_dim=32):
    """Build ego-graphs around sampled domain nodes of the Kaggle passive-DNS graph.

    Reads ``domain_features.csv``, ``ip_features.csv`` and every
    ``timestamp_*_edges.csv`` (in sorted filename order) from directory
    ``data``. Domain and IP feature matrices are zero-padded to a common width
    and mapped to ``embed_dim`` dimensions with one shared Gaussian random
    projection. ``k`` domain centres are sampled with the class proportions of
    the ``suspicious_tld`` column preserved, and a 2-hop ego-graph (all 1-hop
    neighbours, up to 20 per node at hop 2) is extracted around each.

    Args:
        data: Dataset directory; also used as the output file stem.
        k: Number of domain centres to sample (clamped to the number of domains).
        embed_dim: Width of the projected node features.

    Returns:
        A ``Batch`` of ego-graphs, also saved to ``<data>/<data>.pt``.
        Node type 0 = domain, 1 = IP; edge type 0 = resolves, 1 = similar,
        2 = apex. Each graph's ``y`` is the centre domain's encoded label.

    Raises:
        KeyError: If an edge has an unknown type or references an unknown node.
    """
    # 1) Load node tables
    dom_df = pd.read_csv(os.path.join(data, 'domain_features.csv'))
    ip_df  = pd.read_csv(os.path.join(data, 'ip_features.csv'))

    domain_ids = dom_df['domain_node'].to_numpy()
    labels_raw = dom_df['suspicious_tld'].to_numpy()
    # encode labels as 0, 1, 2, ... in sorted order of the raw values
    label2int = {lab: i for i, lab in enumerate(np.unique(labels_raw))}
    domain_labels = np.array([label2int[lab] for lab in labels_raw], dtype=np.int64)

    # Cast to float32 up front so mixed bool/int/float columns cannot yield an
    # object array that torch.from_numpy rejects.
    dom_feats = dom_df.drop(columns=['domain_node', 'suspicious_tld']).to_numpy(dtype=np.float32)
    ip_feats  = ip_df.drop(columns=['ip_node']).to_numpy(dtype=np.float32)

    # zero-pad both feature matrices to the same width
    D_pad = max(dom_feats.shape[1], ip_feats.shape[1])
    def pad(feats, D_out):
        P = D_out - feats.shape[1]
        if P > 0:
            return np.hstack([feats, np.zeros((feats.shape[0], P), dtype=feats.dtype)])
        return feats

    # shared Gaussian projection to embed_dim
    proj = torch.randn(D_pad, embed_dim)
    dom_emb = torch.from_numpy(pad(dom_feats, D_pad)) @ proj
    ip_emb  = torch.from_numpy(pad(ip_feats, D_pad)) @ proj

    # Global node indices: domains [0 .. N_dom-1], IPs [N_dom .. N_dom+N_ip-1]
    N_dom = dom_emb.size(0)
    N_ip  = ip_emb.size(0)
    node_features = torch.cat([dom_emb, ip_emb], dim=0)
    node_type = torch.cat([
        torch.zeros(N_dom, dtype=torch.long),    # 0=domain
        torch.ones(N_ip, dtype=torch.long)       # 1=IP
    ], dim=0)

    # original node ID -> global index
    dom_map = {nid: i for i, nid in enumerate(domain_ids)}
    ip_map  = {nid: i + N_dom for i, nid in enumerate(ip_df['ip_node'].to_numpy())}

    # 2) Load edges; each one is stored in both directions
    etype_map = {'resolves': 0, 'similar': 1, 'apex': 2}
    src_list, tgt_list, et_list = [], [], []
    # sorted() pins the file order; raw glob order depends on the filesystem
    for f in sorted(glob.glob(os.path.join(data, 'timestamp_*_edges.csv'))):
        df = pd.read_csv(f)
        for src, tgt, et in zip(df['source'], df['target'], df['type']):
            t = etype_map[et]
            i = dom_map[src]
            # 'resolves' points domain->IP; 'similar' / 'apex' point domain->domain
            j = ip_map[tgt] if et == 'resolves' else dom_map[tgt]
            src_list += [i, j]
            tgt_list += [j, i]
            et_list += [t, t]

    edge_index = torch.tensor([src_list, tgt_list], dtype=torch.long)
    edge_type  = torch.tensor(et_list, dtype=torch.long)

    def stratified_sample(k, domain_ids, labels):
        """
        Sample exactly k elements from domain_ids (len N) in proportion to
        the class counts in labels (largest-remainder rounding of the quotas).
        """
        N = len(labels)
        classes, counts = np.unique(labels, return_counts=True)
        target = []
        quotas = counts / N * k
        floors = np.floor(quotas).astype(int)
        residual = k - floors.sum()
        # distribute the remainder by largest fractional parts
        fracs = quotas - floors
        order = np.argsort(-fracs)
        for i in order[:residual]:
            floors[i] += 1
        # now floors.sum() == k
        for cls, q in zip(classes, floors):
            candidates = domain_ids[labels == cls]
            if q > len(candidates):
                raise ValueError(f"Not enough nodes in class {cls} to sample {q}")
            choice = np.random.choice(candidates, size=q, replace=False)
            target.extend(choice.tolist())
        np.random.shuffle(target)
        return target

    # 3) Sample k domain centres preserving the label distribution; asking for
    # more centres than there are domains would otherwise raise.
    k = min(k, len(domain_ids))
    center_nodes = stratified_sample(k, domain_ids, domain_labels)

    full_data = Data(
        x          = node_features.float(),
        edge_index = edge_index,
        node_type  = node_type,
        edge_type  = edge_type
    )

    center_global_idx = torch.tensor(
        [dom_map[n] for n in center_nodes],
        dtype=torch.long
    )

    loader = NeighborLoader(
        data           = full_data,
        input_nodes    = center_global_idx,
        num_neighbors  = [-1, 20],      # all 1-hop neighbours, up to 20 per node at hop 2
        batch_size     = 1,
        shuffle        = False          # keeps batches aligned with center_global_idx
    )

    # 4) Assign the centre's label to each ego-graph
    data_list = []
    for idx_global, sub_data in zip(center_global_idx.tolist(), loader):
        # Domain global indices equal their row in dom_df, so they index domain_labels directly
        sub_data.y = torch.tensor([int(domain_labels[idx_global])], dtype=torch.long)
        # drop loader bookkeeping so graphs collate cleanly
        for attr in ['batch', 'n_id', 'e_id']:
            if hasattr(sub_data, attr):
                delattr(sub_data, attr)
        data_list.append(sub_data)

    graphs = Batch.from_data_list(data_list)
    torch.save(graphs, os.path.join(data, f'{data}.pt'))

    return graphs

def unify_node_features_gaussian(hetero_data, max_dim=64):
    """Give every node type a feature matrix of one shared width, in place.

    Each node type's ``x`` is cast to float, zero-padded to the largest feature
    width ``D_max`` found across node types, and multiplied by one shared
    Gaussian matrix ``W`` of shape ``(D_max, D_out)`` scaled by
    ``1 / sqrt(D_out)``, where ``D_out = min(smallest feature width, max_dim)``.
    Node types without ``x`` get standard-normal ``D_max``-dim vectors pushed
    through the same ``W``. Randomness comes from the global torch RNG.

    Args:
        hetero_data: Heterogeneous graph exposing ``node_types`` and per-type
            storages via ``hetero_data[ntype]`` (e.g. a PyG ``HeteroData``).
        max_dim: Upper bound on the output feature width (default 64).

    Returns:
        ``hetero_data`` itself, with every node type's ``x`` replaced.

    Raises:
        ValueError: If no node type carries an ``x`` attribute.
    """
    has_x_types = [nt for nt in hetero_data.node_types if hasattr(hetero_data[nt], 'x')]
    miss_x_types = [nt for nt in hetero_data.node_types if nt not in has_x_types]
    if not has_x_types:
        raise ValueError('unify_node_features_gaussian: no node type has features `x`')

    dims = [hetero_data[nt].x.size(1) for nt in has_x_types]
    D_max = max(dims)
    D_out = min(min(dims), max_dim)

    # Shared random projection D_max -> D_out
    W = torch.randn(D_max, D_out) / math.sqrt(D_out)

    for ntype in has_x_types:
        # Cast first: integer features cannot be multiplied by the float W
        X = hetero_data[ntype].x.float()
        n, d = X.size()
        if d < D_max:
            X = torch.cat([X, X.new_zeros((n, D_max - d))], dim=1)
        hetero_data[ntype].x = X @ W.to(X.device)   # [N, D_out]

    for ntype in miss_x_types:
        n_nodes = hetero_data[ntype].num_nodes
        hetero_data[ntype].x = torch.randn(n_nodes, D_max) @ W   # [n_nodes, D_out]

    return hetero_data



def hetero_to_homo(sub_data, graph_label):
    """Flatten a heterogeneous (sub)graph into one homogeneous ``Data`` graph.

    Nodes of each type are laid out contiguously in ``sub_data.node_types``
    order. ``node_type`` holds each node's type position and ``edge_type`` each
    edge's position in ``sub_data.edge_types``. Edge types with no edges are
    skipped, but they keep their position so ids stay stable across subgraphs.
    The whole graph carries the single label ``graph_label`` as ``y``.
    """
    ntypes = sub_data.node_types

    type_pos = {}
    counts = []
    for pos, nt in enumerate(ntypes):
        type_pos[nt] = pos
        counts.append(sub_data[nt].num_nodes)
    # first global row of each node type
    starts = torch.tensor([0] + counts[:-1]).cumsum(0)

    features = torch.cat([sub_data[nt].x for nt in ntypes], dim=0)
    type_column = torch.cat(
        [torch.full((n,), type_pos[nt], dtype=torch.long) for nt, n in zip(ntypes, counts)],
        dim=0,
    )

    etype_pos = {et: pos for pos, et in enumerate(sub_data.edge_types)}
    index_chunks = []
    type_chunks = []
    for et, et_id in etype_pos.items():
        local_ei = sub_data[et].edge_index
        if local_ei.numel() == 0:
            continue
        head_t, _, tail_t = et
        # shift row 0 by the source type's start and row 1 by the target type's start
        shift = torch.tensor([[int(starts[type_pos[head_t]])],
                              [int(starts[type_pos[tail_t]])]])
        global_ei = local_ei + shift
        index_chunks.append(global_ei)
        type_chunks.append(torch.full((global_ei.size(1),), et_id, dtype=torch.long))

    if not index_chunks:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_type = torch.empty((0,), dtype=torch.long)
    else:
        edge_index = torch.cat(index_chunks, dim=1)
        edge_type = torch.cat(type_chunks, dim=0)

    return Data(
        x=features.float(),
        edge_index=edge_index,
        node_type=type_column,
        edge_type=edge_type,
        y=torch.tensor([graph_label], dtype=torch.long)
    )


def generate_data_list(dataset, center_name, num_neighbors, y0, y1):
    """Extract one homogeneous ego-graph per ``center_name`` node labelled ``y0`` or ``y1``.

    The node features of ``dataset`` are first unified in place by
    ``unify_node_features_gaussian``. Every qualifying centre node is used; there
    is no subsampling. Each centre is expanded with a ``NeighborLoader`` that
    uses ``num_neighbors`` as the per-hop fan-out, one centre per batch in
    ascending index order. The resulting heterogeneous subgraph is flattened by
    ``hetero_to_homo``, with the centre's label as the graph label.

    Returns:
        A list of ``Data`` objects, one per qualifying centre node.
    """
    hd = unify_node_features_gaussian(dataset)

    center_labels = hd[center_name].y
    wanted = torch.logical_or(center_labels == y0, center_labels == y1)
    center_ids = torch.nonzero(wanted, as_tuple=True)[0]

    loader = NeighborLoader(
        hd,
        num_neighbors=num_neighbors,
        input_nodes=(center_name, center_ids),
        batch_size=1,
        shuffle=False   # keeps each batch aligned with its entry in center_ids
    )

    graphs = []
    for center, sub_data in zip(center_ids, loader):
        label = center_labels[center].item()
        # strip loader bookkeeping before flattening
        for nt in sub_data.node_types:
            store = sub_data[nt]
            for key in ('batch', 'n_id'):
                store.pop(key, None)
        graphs.append(hetero_to_homo(sub_data, label))

    return graphs

def generate_node_features(dataset, epochs=100, hidden_channels=16, lr=0.01, batch_size=1000):
    """Train TransE on a typed homogeneous graph and use its node embeddings as ``x``.

    Every edge ``(edge_index[0], edge_type, edge_index[1])`` is treated as a
    (head, relation, tail) triple. After training, row ``i`` of the node
    embedding table becomes node ``i``'s feature vector, and the graph is split
    back into node/edge types with ``to_heterogeneous(node_type, edge_type)``.

    Args:
        dataset: Homogeneous graph with ``edge_index``, ``edge_type``,
            ``node_type`` and ``num_nodes``.
        epochs, hidden_channels, lr, batch_size: TransE training settings; the
            defaults reproduce the previously hard-coded values.

    Returns:
        ``(hetero_dataset, x)``, where ``x`` is the detached CPU embedding
        matrix of shape ``[num_nodes, hidden_channels]``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = dataset.to(device)
    # TransE looks relations up by their raw edge_type value, so the table must
    # cover the largest id; counting distinct ids undersizes it when ids have gaps.
    num_relations = int(dataset.edge_type.max()) + 1

    model = TransE(
        num_nodes=dataset.num_nodes,
        num_relations=num_relations,
        hidden_channels=hidden_channels
    ).to(device)

    loader = model.loader(
        head_index=dataset.edge_index[0],
        rel_type=dataset.edge_type,
        tail_index=dataset.edge_index[1],
        batch_size=batch_size,
        shuffle=True,
    )

    optimizer = optim.Adam(model.parameters(), lr=lr)

    def train():
        """Run one epoch and return the example-weighted mean loss."""
        model.train()
        total_loss = total_examples = 0
        for head_index, rel_type, tail_index in loader:
            optimizer.zero_grad()
            loss = model.loss(head_index, rel_type, tail_index)
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * head_index.numel()
            total_examples += head_index.numel()
        return total_loss / max(total_examples, 1)

    for epoch in range(epochs):
        loss = train()
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    # The embedding table rows equal node_emb(arange(num_nodes)); copy them out
    # detached so the returned / saved tensor carries no autograd graph.
    x = model.node_emb.weight.detach().cpu().clone()
    dataset = dataset.cpu()
    dataset.x = x
    dataset = dataset.to_heterogeneous(dataset.node_type, dataset.edge_type)
    return dataset, x

def preprocess_PyG(data):
    """Build and save ego-graph datasets for the PyG heterogeneous benchmarks.

    ``data`` selects the dataset (one of the ``PYGDBLP`` / ``PYGIMDB`` /
    ``PYGMAG`` / ``PYGFREEBASE`` constants, presumably from ``config``). It is
    also the dataset root (except Freebase, which is stored under ``'HGB'``)
    and the output file stem. Every centre node of the dataset's target type
    whose label is 0 or 1 becomes an ego-graph (see ``generate_data_list``),
    using the per-hop fan-out configured below.

    For Freebase, labels are shifted by +1 and node features are TransE
    embeddings cached in ``<data>/x.pt``; they are computed on the first run.

    Returns:
        A ``Batch`` of the ego-graphs, also saved to ``<data>/<data>.pt``.

    Raises:
        ValueError: If ``data`` names none of the supported datasets.
    """
    # `dataset[0]` is the supported accessor for single-graph InMemoryDatasets;
    # the `.data` attribute is internal storage and emits a warning in recent PyG.
    if data == PYGDBLP:
        dataset = DBLP(data)[0]
        num_neighbors = {et: [-1, -1, 5] for et in dataset.edge_types}
        center = 'author'

    elif data == PYGIMDB:
        dataset = IMDB(data)[0]
        num_neighbors = {et: [-1] * 3 for et in dataset.edge_types}
        center = 'movie'

    elif data == PYGMAG:
        dataset = OGB_MAG(data, 'TransE')[0]
        num_neighbors = {et: [-1, 3] for et in dataset.edge_types}
        center = 'paper'

    elif data == PYGFREEBASE:
        dataset = HGBDataset('HGB', 'freebase')[0]
        dataset = dataset.to_homogeneous(add_node_type=True, add_edge_type=True)
        dataset.y += 1
        x_path = os.path.join(data, 'x.pt')
        if os.path.exists(x_path):
            # reuse TransE embeddings saved by an earlier run
            dataset.x = torch.load(x_path)
            dataset = dataset.to_heterogeneous(dataset.node_type, dataset.edge_type)
        else:
            dataset, x = generate_node_features(dataset)
            torch.save(x, x_path)
        num_neighbors = {et: [-1, 1] for et in dataset.edge_types}
        center = 'book'

    else:
        # previously fell through to an UnboundLocalError on `data_list`
        raise ValueError(f'Unsupported PyG dataset: {data!r}')

    data_list = generate_data_list(dataset, center, num_neighbors, 0, 1)

    graphs = Batch.from_data_list(data_list)
    torch.save(graphs, os.path.join(data, f'{data}.pt'))

    return graphs
    

def preprocess_TUDataset(data):
    """Convert a TU-format dataset in directory ``data`` into one saved ``Batch``.

    Reads ``<data>_graph_indicator.txt``, ``<data>_node_labels.txt``,
    ``<data>_graph_labels.txt`` and ``<data>_A.txt`` from ``data``. The
    indicator and edge files are shifted from 1-based to 0-based ids. Each
    node's integer label is its type, and its feature vector is that type's row
    of a random orthogonal matrix. Each edge's type is the index of its
    (src type, dst type) pair when the node types present in the file are
    enumerated as nested (src, dst) loops.

    Returns:
        A ``Batch`` with one ``Data`` per graph (fields ``x``, ``edge_index``,
        ``node_type``, ``edge_type``, ``y``), also saved to ``<data>/<data>.pt``.
    """
    def _load(suffix, **kwargs):
        return np.loadtxt(os.path.join(data, f'{data}_{suffix}.txt'), dtype=int, **kwargs)

    # ndmin guards keep single-row files at the expected rank
    graph_ind = _load('graph_indicator', ndmin=1) - 1
    n_nodes = len(graph_ind)

    node_types = _load('node_labels', ndmin=1)
    feature_dim = int(node_types.max()) + 1
    type_vectors = orthogonal_projection(feature_dim, feature_dim)

    graph_labels = _load('graph_labels', ndmin=1)
    n_graphs = len(graph_labels)

    edges = _load('A', delimiter=',', ndmin=2) - 1
    src_all, dst_all = edges[:, 0], edges[:, 1]
    edge_to_graph = graph_ind[src_all]   # an edge belongs to its source node's graph

    # Edge type = pos(src type) * U + pos(dst type), where pos is a type's index
    # among the U distinct node types: the same ids as enumerating (src, dst)
    # pairs of np.unique(node_types) in nested-loop order, computed in one pass.
    uni_ntypes = np.unique(node_types)
    type_pos = np.searchsorted(uni_ntypes, node_types)
    edge_type_all = type_pos[src_all] * len(uni_ntypes) + type_pos[dst_all]

    # Group node and edge ids by graph once. The stable sort keeps the original
    # order within each graph, and the bounds mark where each graph's run starts.
    # This avoids rescanning every array once per graph.
    node_order = np.argsort(graph_ind, kind='stable')
    node_bounds = np.searchsorted(graph_ind[node_order], np.arange(n_graphs + 1))
    edge_order = np.argsort(edge_to_graph, kind='stable')
    edge_bounds = np.searchsorted(edge_to_graph[edge_order], np.arange(n_graphs + 1))

    g2l = np.full(n_nodes, -1, dtype=np.int64)   # global node id -> id within its graph
    data_list = []
    for gid in range(n_graphs):
        nodes_g = node_order[node_bounds[gid]:node_bounds[gid + 1]]
        g2l[nodes_g] = np.arange(len(nodes_g))

        edges_g = edge_order[edge_bounds[gid]:edge_bounds[gid + 1]]
        edge_index = torch.from_numpy(
            np.stack([g2l[src_all[edges_g]], g2l[dst_all[edges_g]]])
        ).long()

        nt = node_types[nodes_g]
        data_list.append(Data(
            x         = torch.from_numpy(type_vectors[nt]).float(),
            edge_index= edge_index,
            node_type = torch.from_numpy(nt),
            edge_type = torch.from_numpy(edge_type_all[edges_g]).long(),
            y         = torch.tensor([graph_labels[gid]], dtype=torch.long),
        ))

        g2l[nodes_g] = -1   # reset for the next graph

    graphs = Batch.from_data_list(data_list)
    torch.save(graphs, os.path.join(data, f'{data}.pt'))

    return graphs


def stratified_split(labels, trainsz, testsz, seed):
    """Split binary-labelled samples into train/val/test index arrays per class.

    Within each class (0 = normal, 1 = abnormal), indices keep their original
    order. The first ``trainsz`` fraction goes to train, the final ``testsz``
    fraction to test, and the rest to validation. Each resulting array is then
    shuffled in place with the global ``random`` module.

    Args:
        labels: Sequence of labels, each equal to 0 or 1.
        trainsz: Fraction of each class assigned to the training split.
        testsz: Fraction of each class assigned to the test split.
        seed: Unused; kept for call compatibility. Shuffling uses the global
            ``random`` state (e.g. as seeded by ``set_seed``).

    Returns:
        ``(train_index, val_index, test_index)`` as int64 NumPy arrays.

    Raises:
        AssertionError: If any label is neither 0 nor 1.
    """
    NORMAL = 0
    ABNORMAL = 1

    normalinds, abnormalinds, errors = [], [], []
    for i, label in enumerate(labels):
        if label == NORMAL:
            normalinds.append(i)
        elif label == ABNORMAL:
            abnormalinds.append(i)
        else:
            errors.append(i)

    # guard the rate against an empty label list
    rate = len(abnormalinds) / len(labels) if len(labels) else 0.0
    print("Normal graphs: {}, abnormal graphs: {}, abnormal rate: {:.4f}".format(len(normalinds), len(abnormalinds), rate))

    assert len(errors) == 0, "invalid labels"

    def _split(inds):
        # Explicit int64: np.array([]) would default to float64 and turn the
        # concatenated index arrays into floats whenever a piece is empty.
        arr = np.asarray(inds, dtype=np.int64)
        cut_train = int(trainsz * len(inds))
        cut_test = int((1 - testsz) * len(inds))
        return arr[:cut_train], arr[cut_train:cut_test], arr[cut_test:]

    train_normal, val_normal, test_normal = _split(normalinds)
    train_abnormal, val_abnormal, test_abnormal = _split(abnormalinds)

    train_index = np.concatenate((train_normal, train_abnormal))
    val_index = np.concatenate((val_normal, val_abnormal))
    test_index = np.concatenate((test_normal, test_abnormal))

    random.shuffle(train_index)
    random.shuffle(val_index)
    random.shuffle(test_index)

    print("Train size: {}, normal size: {}, abnormal size: {}".format(len(train_index), len(train_normal), len(train_abnormal)))
    print("Val size: {}, normal size: {}, abnormal size: {}".format(len(val_index), len(val_normal), len(val_abnormal)))
    print("Test size: {}, normal size: {}, abnormal size: {}".format(len(test_index), len(test_normal), len(test_abnormal)))

    print("Total size: {}, generate size: {}".format(len(labels), len(train_index) + len(val_index) + len(test_index)))

    return train_index, val_index, test_index
