



import os
import os.path as osp
import random
from tqdm import tqdm

import numpy as np
import networkx as nx
import torch


from molecular import generate_real_dataset
from BA3_loc import find_gd


# Per-dimension mean (_MEAN) and standard deviation (_STD) of the Gaussian used
# to synthesize a 5-dim feature vector for nodes that arrive without any
# 'feature' / 'features' / 'feat' attribute (see generate_motif_easy_with_noise).
# NOTE(review): _STD is identical to _MEAN — confirm this is intended and not a
# copy-paste of the mean values.
_MEAN = [1.5, 2.0, 1.2, 1.3, 1.8]
_STD  = [1.5, 2.0, 1.2, 1.3, 1.8]


def _get_node_feature(attrs):
    """Return a node's feature from the first of 'feature', 'features', 'feat' present.

    Raises:
        KeyError: if the node has none of those attributes.
    """
    for key in ('feature', 'features', 'feat'):
        if key in attrs:
            return attrs[key]
    raise KeyError("node has none of the attributes 'feature', 'features', 'feat'")


def _perturb_edges(G, edge_del_prob, edge_add_prob):
    """Randomly delete, then randomly add, edges of ``G`` in place.

    ``int(edge_del_prob * |E|)`` distinct existing edges are removed. Then
    ``int(edge_add_prob * |E'|)`` random node pairs are drawn, where ``|E'|``
    is the edge count after deletion. Each pair is added if not already an
    edge, so fewer edges may actually be added.
    """
    edges = list(G.edges())
    num_del = int(edge_del_prob * len(edges))
    if num_del > 0:
        G.remove_edges_from(random.sample(edges, num_del))

    nodes = list(G.nodes())
    num_add = int(edge_add_prob * G.number_of_edges())
    if len(nodes) < 2:
        # random.sample(nodes, 2) would raise ValueError; nothing to connect.
        return
    for _ in range(num_add):
        u, v = random.sample(nodes, 2)
        if not G.has_edge(u, v):
            G.add_edge(u, v)


def _edge_index_array(G):
    """Return ``G.edges()`` as a contiguous int64 array of shape ``(2, num_edges)``.

    Each pair yielded by ``G.edges()`` appears once (for an undirected
    networkx Graph that means one direction only). An edgeless graph yields
    shape ``(2, 0)`` instead of crashing.
    """
    pairs = np.asarray(list(G.edges()), dtype=np.int64).reshape(-1, 2)
    return np.ascontiguousarray(pairs.T)


def generate_motif_easy_with_noise(
    num_samples: int = 1000,
    output_filename: str = 'motif_test.npy',
    feature_noise: float = 0.05,
    edge_del_prob: float = 0.02,
    edge_add_prob: float = 0.02
):
    """Generate ``num_samples`` noisy motif graphs and save them under ./data/motif/.

    For each sample, a graph is obtained from ``generate_real_dataset()`` and
    processed as follows:

    * Nodes without a feature attribute get a Gaussian feature drawn with
      ``_MEAN`` / ``_STD``.
    * Every node feature gets additive N(0, ``feature_noise``) noise.
    * A fraction of edges is deleted and random edges are added
      (see ``_perturb_edges``).
    * ``find_gd`` is evaluated on the perturbed edge index and the role ids.

    Args:
        num_samples: number of graphs to generate.
        output_filename: file name written inside ./data/motif/.
        feature_noise: std of the Gaussian noise added to node features.
        edge_del_prob: fraction of existing edges to delete.
        edge_add_prob: fraction (of the post-deletion edge count) of random
            node pairs to try adding as edges.

    The result is a dict of per-graph lists with the keys 'node_features',
    'edge_index', 'label', 'ground_truth', 'role_id' and 'pos'. It is written
    via ``np.save``, which stores it as a pickled 0-d object array. Read it
    back with ``np.load(path, allow_pickle=True).item()``.

    Row ``i`` of each graph's 'node_features' and 'pos' refers to the i-th
    node in sorted node order.
    """
    data_dir = './data/motif/'
    os.makedirs(data_dir, exist_ok=True)

    node_features_list = []
    edge_index_list    = []
    label_list         = []
    ground_truth_list  = []
    role_id_list       = []
    pos_list           = []

    for _ in tqdm(range(num_samples), desc='Generating Noisy Motif Dataset'):
        # assumes generate_real_dataset returns (nx graph, role ids, label, edge_index, ...)
        G, role_id, label, _edge_index, *_ = generate_real_dataset()

        # Synthesize a feature for nodes that do not carry one.
        for n, attrs in G.nodes(data=True):
            if not any(k in attrs for k in ('feature', 'features', 'feat')):
                G.nodes[n]['feature'] = np.random.normal(
                    loc=_MEAN, scale=_STD
                ).astype(np.float32)

        # Add Gaussian noise; the result is stored under 'feature', which takes
        # priority in _get_node_feature. asarray tolerates list-valued features.
        for n, attrs in G.nodes(data=True):
            orig = np.asarray(_get_node_feature(attrs))
            noise = np.random.normal(0.0, feature_noise, size=orig.shape).astype(np.float32)
            G.nodes[n]['feature'] = orig + noise

        _perturb_edges(G, edge_del_prob, edge_add_prob)

        edge_idx = _edge_index_array(G)
        ids_arr = np.array(role_id, dtype=np.int64)
        gt = find_gd(edge_idx, ids_arr)

        # Use ONE node order for both features and positions so their rows line up.
        node_order = sorted(G.nodes())
        node_feats = np.vstack(
            [_get_node_feature(G.nodes[n]) for n in node_order]
        ).astype(np.float32)
        layout = nx.spring_layout(G)
        pos_arr = np.array([layout[n] for n in node_order], dtype=np.float32)

        node_features_list.append(node_feats)
        edge_index_list.append(edge_idx)
        label_list.append(int(label))
        ground_truth_list.append(gt)
        role_id_list.append(ids_arr)
        pos_list.append(pos_arr)

    # Guard against np.mean([]) (RuntimeWarning + nan) when num_samples == 0.
    avg_nodes = (
        np.mean([nf.shape[0] for nf in node_features_list])
        if node_features_list else 0.0
    )
    print(f"#Graphs: {num_samples}    Avg nodes: {avg_nodes:.2f}")
    save_path = osp.join(data_dir, output_filename)
    np.save(save_path, {
        'node_features': node_features_list,
        'edge_index'   : edge_index_list,
        'label'        : label_list,
        'ground_truth' : ground_truth_list,
        'role_id'      : role_id_list,
        'pos'          : pos_list
    })
    print(f"Saved noisy motif dataset to: {save_path}")


if __name__ == '__main__':
    # Script entry point: generate the dataset with the function's default
    # arguments (1000 samples, saved as ./data/motif/motif_test.npy).
    generate_motif_easy_with_noise()
