import os
import pickle
from collections import defaultdict
from contextlib import nullcontext
from copy import deepcopy

import numpy as np
import torch
import torch.distributed as dist
from transformers import LlamaModel, AutoTokenizer, AutoModel, Trainer, default_data_collator, TrainerCallback, TrainingArguments

from data.dataset import NCDataset
from data.load import load_data
from data.sampling import collect_subgraphs
from models.LMs import BertClassifier, LlamaClassifier
from utils.args import Arguments
from utils.dist import is_dist, set_dist_env
from utils.metrics import accuracy
from utils.peft import create_lora_config
from utils.utils import model_id

def collect_one_hop_neighbors(edge_index, num_nodes):
    """Return a mapping from every node id to its sorted, distinct in-neighbors.

    Each column ``(src, dst)`` of ``edge_index`` records ``src`` as a one-hop
    neighbor of ``dst``. Every id in ``range(num_nodes)`` appears as a key;
    nodes with no incoming edge map to ``[]``. Edges whose target lies outside
    ``range(num_nodes)`` are ignored.

    This does a single pass over the edges (O(E)) instead of rescanning the
    whole edge list once per node.

    Args:
        edge_index: integer tensor indexed as ``edge_index[0]`` (sources) and
            ``edge_index[1]`` (targets); assumed shape ``(2, E)`` — confirm
            against ``load_data``.
        num_nodes: number of nodes whose neighbor lists are produced.

    Returns:
        dict[int, list[int]]: node id -> ascending list of unique source ids.
    """
    grouped = defaultdict(set)
    for src, dst in edge_index.t().tolist():
        grouped[dst].add(src)
    # ``.get`` avoids inserting into the defaultdict for isolated nodes.
    return {node: sorted(grouped.get(node, ())) for node in range(num_nodes)}


if __name__ == '__main__':
    rank = set_dist_env() if is_dist() else 0

    config = Arguments().parse_args()
    print(config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data, text, num_classes = load_data(config.dataset, use_text=True, seed=0)

    # Only the main process writes the neighbor files, so that ranks in a
    # distributed run do not race on the same output paths.
    if not is_dist() or rank == 0:
        # Save each node's one-hop neighbor information.
        neighbor_dir = os.path.join('out', 'neighbors', f"{config.dataset}")
        os.makedirs(neighbor_dir, exist_ok=True)

        # One-hop (incoming) neighbors for every node.
        all_neighbors = collect_one_hop_neighbors(data.edge_index, data.num_nodes)

        # Persist the full mapping in binary form for later loading.
        neighbor_file = os.path.join(neighbor_dir, 'one_hop_neighbors.pkl')
        with open(neighbor_file, 'wb') as f:
            pickle.dump(all_neighbors, f)
        print(f"Saved one-hop neighbors to {neighbor_file}")

        # Optional: also dump the adjacency as a human-readable text file.
        txt_file = os.path.join(neighbor_dir, 'neighbors.txt')
        with open(txt_file, 'w', encoding='utf-8') as f:
            f.writelines(
                f"Node  {node} neighbors: {neighbors}\n"
                for node, neighbors in all_neighbors.items()
            )
        print(f"Saved neighbor relationships to {txt_file}")