import math
import multiprocessing
import os
import pickle
import random

# Thread-pool settings for the native BLAS/OpenMP runtimes. numpy and torch
# load those runtimes when they are imported, and the runtimes generally read
# these variables only once, at initialisation. That is why they are set here,
# before the third-party imports below; set afterwards they may have no effect.
# One thread per process avoids oversubscribing cores once the work is spread
# over a multiprocessing pool. setdefault() keeps any value the caller exported.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
os.environ.setdefault("KMP_INIT_AT_FORK", "FALSE")
os.environ.setdefault("MKL_THREADING_LAYER", "GNU")

import networkx as nx  # noqa: E402  (must come after the environment setup)
import numpy as np  # noqa: E402
import torch  # noqa: E402
from tqdm import tqdm  # noqa: E402

# torch's intra-op pool is also limited explicitly at runtime.
torch.set_num_threads(1)

# ================= Dependency check and fallback =================
try:
    from BAPG import connected_subgraph
except ImportError:
    print("【提示】未检测到 BAPG.py，使用内置的子图采样函数作为替代。")

    def connected_subgraph(G, ratio=1.0):
        """Fallback sampler: return a BFS-grown subgraph with about ``ratio`` of G's nodes.

        Parameters
        ----------
        G : networkx graph
            Graph to sample from.
        ratio : float
            Fraction of nodes to keep. ``ratio >= 1.0``, or a ratio that rounds
            down to zero (or fewer) nodes, returns ``G`` itself.

        Returns
        -------
        (subG, node_list)
            ``node_list`` holds the original labels of ``subG``'s nodes in the
            same order that ``subG.nodes()`` iterates them. So position ``i``
            matches row ``i`` of ``nx.to_numpy_array(subG)``.
        """
        if ratio >= 1.0:
            return G, list(G.nodes())
        n_sub = int(G.number_of_nodes() * ratio)
        if n_sub <= 0:
            return G, list(G.nodes())
        start_node = random.choice(list(G.nodes()))
        # Simple BFS sampling from a random start node. BFS only reaches
        # start_node's component, so on a disconnected G the sample may have
        # fewer than n_sub nodes.
        sub_nodes = list(nx.bfs_tree(G, start_node))[:n_sub]
        subG = G.subgraph(sub_nodes).copy()
        # G.subgraph() does not preserve the order of ``sub_nodes``. It iterates
        # in G's node order or an internal set order. Report the nodes in subG's
        # own order so the returned list stays aligned with subG.
        return subG, list(subG.nodes())

# ================= Configuration =================
# Dataset name and the pickle of weighted graphs it is read from.
dataset_name = 'DHFR'
input_pkl_path = f'data/{dataset_name}_weighted.pkl'

# Base random seed for the whole run.
seed = 321
# Fraction of each graph kept as the source subgraph (1.0 = the whole graph).
subgraph_ratio = 1.0
# Outliers to add, as a fraction of the graph's node count
# (e.g. 0.5 adds 50% extra nodes as outliers).
outlier_ratio = 1.0
output_pkl_path = f'data/{dataset_name}_noise_weighted_{outlier_ratio}_randomP_attacked.pkl'

# Multiplier for the noise amplitude; overridable via ATTACK_NOISE_SCALE.
noise_scale = float(os.getenv("ATTACK_NOISE_SCALE", "1.0"))

# Worker process count: ATTACK_NUM_WORKERS when it is an all-digit string
# (raised to at least 1), otherwise cpu_count() - 2 clamped to [1, 8].
_workers_env = os.environ.get("ATTACK_NUM_WORKERS", "").strip()
NUM_WORKERS = (
    max(1, int(_workers_env))
    if _workers_env.isdigit()
    else max(1, min(8, multiprocessing.cpu_count() - 2))
)


# ================= 核心攻击函数 (Random Noise) =================
def random_noise_attack(G_full, outlier_ratio=0.5, noise_scale=1.0):
    """
    直接随机生成 Outlier-to-Origin/Outlier-to-Outlier 噪声，不做梯度优化。
    保持原图 (Origin-Origin) 不变，并生成对称矩阵。
    """
    D_orig = nx.to_numpy_array(G_full)
    N_orig = D_orig.shape[0]
    n_outliers = math.ceil(N_orig * outlier_ratio)
    if n_outliers <= 0:
        return nx.from_numpy_array(D_orig, create_using=nx.Graph)

    N_new = N_orig + n_outliers
    D_new = np.zeros((N_new, N_new), dtype=np.float32)
    # 保持原图区域完全不变
    D_new[:N_orig, :N_orig] = D_orig

    base_scale = float(np.max(D_orig)) if np.max(D_orig) > 0 else 1.0
    noise_max = base_scale * noise_scale
    # 只对涉及 Outlier 的区域加噪声
    out_slice = slice(N_orig, N_new)
    noise_oo = np.random.uniform(0.0, noise_max, size=(n_outliers, n_outliers)).astype(np.float32)
    noise_ox = np.random.uniform(0.0, noise_max, size=(n_outliers, N_orig)).astype(np.float32)
    D_new[out_slice, out_slice] = noise_oo
    D_new[out_slice, :N_orig] = noise_ox
    D_new[:N_orig, out_slice] = noise_ox.T

    D_new = 0.5 * (D_new + D_new.T)
    np.fill_diagonal(D_new, 0.0)
    return nx.from_numpy_array(D_new, create_using=nx.Graph)


# ================= Worker =================
def process_single_graph(args):
    """Build one (source, attacked target, ground truth) triple inside a pool worker.

    Parameters
    ----------
    args : tuple
        ``(idx_in_list, G_full, config)``. ``config`` must provide the keys
        ``'seed'``, ``'subgraph_ratio'``, ``'outlier_ratio'`` and ``'noise_scale'``.

    Returns
    -------
    tuple or None
        ``(subG, G_target_attacked, gt_idx)``. Returns ``None`` when the graph
        has fewer than 5 nodes or generation raised an exception; the failure
        is reported on stderr.
    """
    idx_in_list, G_full, config = args

    # Seed every RNG from the graph's index so a graph's output does not depend
    # on which worker process handles it.
    local_seed = config['seed'] + idx_in_list
    random.seed(local_seed)
    np.random.seed(local_seed)
    torch.manual_seed(local_seed)

    if G_full.number_of_nodes() < 5:
        return None

    try:
        subG, gt_idx = connected_subgraph(G_full, config['subgraph_ratio'])
        G_target_attacked = random_noise_attack(
            G_full=G_full,
            outlier_ratio=config['outlier_ratio'],
            noise_scale=config['noise_scale'],
        )
    except Exception as exc:
        # Best effort: skip this graph, but say why instead of failing silently.
        # One short stderr line keeps the tqdm progress bar readable.
        import sys
        print(
            f"[skip] graph #{idx_in_list}: {type(exc).__name__}: {exc}",
            file=sys.stderr,
        )
        return None

    return (subG, G_target_attacked, gt_idx)


# ================= Main flow =================
def generate_data_parallel():
    """Load graphs, attack each one in a process pool, and pickle the results.

    Reads the list of graphs from ``input_pkl_path`` and calls
    ``process_single_graph`` for each one with ``NUM_WORKERS`` processes. It
    writes the non-``None`` results, i.e. ``(subG, G_target_attacked, gt_idx)``
    tuples in input order, to ``output_pkl_path``. Returns early with a message
    if the input file does not exist.
    """
    print(f"Dataset: {dataset_name}")
    print(f"Workers: {NUM_WORKERS}")
    print(f"Loading raw graphs from {input_pkl_path}...")

    if not os.path.exists(input_pkl_path):
        print(f"Error: 文件不存在 {input_pkl_path}")
        return

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(input_pkl_path, 'rb') as f:
        full_graphs = pickle.load(f)

    print(f"Total graphs: {len(full_graphs)}")
    print("Strategy: Random Noise Injection (no gradient optimization)")
    print(f"Params: Outlier={outlier_ratio}, NoiseScale={noise_scale}")

    config = {
        'seed': seed,
        'subgraph_ratio': subgraph_ratio,
        'outlier_ratio': outlier_ratio,
        'noise_scale': noise_scale,
    }

    tasks = [(i, G, config) for i, G in enumerate(full_graphs)]
    attacked_dataset = []

    with multiprocessing.Pool(processes=NUM_WORKERS) as pool:
        # imap (not imap_unordered) yields results in input order. That keeps
        # the saved list reproducible across runs and aligned with the input
        # graph order. A chunksize above 1 reduces IPC overhead for these
        # CPU-bound tasks.
        results_iter = pool.imap(process_single_graph, tasks, chunksize=10)

        for result in tqdm(results_iter, total=len(tasks), desc="Attacking"):
            if result is not None:
                attacked_dataset.append(result)

    n_skipped = len(tasks) - len(attacked_dataset)
    if n_skipped:
        print(f"Skipped {n_skipped} graphs (too small or failed).")

    out_dir = os.path.dirname(output_pkl_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    print(f"Saving {len(attacked_dataset)} pairs to {output_pkl_path}...")
    with open(output_pkl_path, 'wb') as f:
        pickle.dump(attacked_dataset, f)

    print("Done! 数据生成完毕。")

if __name__ == '__main__':
    # Prefer the 'fork' start method where it exists (Linux): it is faster to
    # start workers. On Windows there is no 'fork' and set_start_method raises
    # ValueError. RuntimeError means a start method was already fixed. In both
    # cases keep the platform default (e.g. 'spawn' on Windows).
    try:
        multiprocessing.set_start_method('fork')
    except (RuntimeError, ValueError):
        pass

    generate_data_parallel()
