from transformers import LlamaModel, AutoTokenizer, AutoModel, Trainer, default_data_collator, TrainerCallback, TrainingArguments
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils.args import Arguments
from utils.dist import is_dist, set_dist_env
from utils.metrics import accuracy
from utils.peft import create_lora_config
from utils.utils import model_id
from models.LMs import BertClassifier, LlamaClassifier
from data.load import load_data
from data.dataset import NCDataset
from data.sampling import collect_subgraphs
import os
from copy import deepcopy
from models.EdgeEncoder.attention import EdgeEncoder
from models.GNNs.GTN import GTN
from torch.utils.data  import Dataset, DataLoader
import pickle
from tqdm import tqdm
 
class EdgeDataset(Dataset):
    """Edge-level dataset yielding one positive edge plus corrupted negative edges.

    Each item is built from a single directed edge ``(src, dst)`` listed in
    ``neighbor_dict``. Every negative replaces exactly one endpoint (head or
    tail, 50/50) with a uniformly random node id; the other endpoint and its
    sampled neighbours are kept identical to the positive pair. Negatives are
    not filtered, so one can occasionally coincide with a real edge.
    """

    def __init__(self, node_features, neighbor_dict, num_neg_samples=5):
        """
        Args:
            node_features: pre-loaded node features indexed by node id
                (presumably [num_nodes, node_dim]; confirm against the caller).
            neighbor_dict: adjacency dict ``{node_id: [neighbor_ids]}``.
            num_neg_samples: number of corrupted edges generated per item.
        """
        self.node_features = node_features
        # Keep our own reference: neighbour sampling must not depend on a
        # module-level ``neighbor_dict`` that only exists when run as a script.
        self.neighbor_dict = neighbor_dict
        self.edges = self._generate_edges(neighbor_dict)
        self.num_neg_samples = num_neg_samples
        self.num_nodes = len(node_features)

    def _generate_edges(self, neighbor_dict):
        """Return the positive ``(src, dst)`` edge list, in dict iteration order."""
        return [(src, dst) for src in neighbor_dict for dst in neighbor_dict[src]]

    def __len__(self):
        return len(self.edges)

    def __getitem__(self, idx):
        """Return one positive edge and its corrupted negatives.

        Returns:
            dict with
              ``'pos_pair'``: ``(h_i, h_j, neighbors_i, neighbors_j)`` for the edge;
              ``'neg_pairs'``: list of ``num_neg_samples`` tuples of the same
              layout, each describing a corrupted edge.
        """
        src, dst = self.edges[idx]

        # Positive edge features.
        h_i = self.node_features[src]
        h_j = self.node_features[dst]
        neighbors_i = self._sample_neighbors(src)
        neighbors_j = self._sample_neighbors(dst)

        # Negative edges: corrupt either the tail or the head. Each negative is
        # built from its OWN endpoints, so a head-corrupted negative really uses
        # the random head instead of silently duplicating the positive pair.
        neg_pairs = []
        for _ in range(self.num_neg_samples):
            if np.random.rand() > 0.5:
                neg_dst = np.random.randint(0, self.num_nodes)
                neg_pairs.append((h_i, self.node_features[neg_dst],
                                  neighbors_i, self._sample_neighbors(neg_dst)))
            else:
                neg_src = np.random.randint(0, self.num_nodes)
                neg_pairs.append((self.node_features[neg_src], h_j,
                                  self._sample_neighbors(neg_src), neighbors_j))

        return {
            'pos_pair': (h_i, h_j, neighbors_i, neighbors_j),
            'neg_pairs': neg_pairs,
        }

    def _sample_neighbors(self, node_id, k=5):
        """Return feature rows of up to ``k`` neighbours of ``node_id``.

        Sampling is without replacement and happens only when the node has
        more than ``k`` neighbours. Nodes that are missing from the adjacency
        dict, or that have no neighbours, yield the empty slice
        ``node_features[:0]``. The row count therefore varies with degree, so
        items cannot be stacked by the default DataLoader collate.
        """
        neighbors = self.neighbor_dict.get(node_id, [])
        if len(neighbors) == 0:
            return self.node_features[:0]
        if len(neighbors) > k:
            neighbors = np.random.choice(neighbors, k, replace=False)
        return self.node_features[neighbors]

class EdgeEncoderTrainer:
    """Train an ``EdgeEncoder`` to push corrupted-edge encodings away from real ones.

    Each dataloader item is expected to be a single, un-batched sample, i.e. a
    dict with ``'pos_pair'`` and ``'neg_pairs'`` as produced by ``EdgeDataset``
    (node vectors get a leading batch dim of 1 here; neighbour matrices are
    passed through as-is).
    """

    def __init__(self, node_dim=768, edge_dim=256, lr=1e-4, margin=1.0, device='cuda'):
        """
        Args:
            node_dim: input node feature size passed to ``EdgeEncoder``.
            edge_dim: output edge encoding size passed to ``EdgeEncoder``.
            lr: Adam learning rate.
            margin: margin of the triplet loss.
            device: device the model and inputs are placed on.
        """
        self.device = torch.device(device)
        self.model = EdgeEncoder(node_dim, edge_dim).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.TripletMarginLoss(margin=margin)  # triplet loss

    def _encode(self, pair):
        """Run the encoder on one ``(h_i, h_j, neighbors_i, neighbors_j)`` tuple."""
        h_i, h_j, n_i, n_j = pair
        return self.model(
            h_i.unsqueeze(0).to(self.device),
            h_j.unsqueeze(0).to(self.device),
            n_i.to(self.device),
            n_j.to(self.device),
        )

    def train(self, dataloader, epochs=50):
        """Optimise the encoder for ``epochs`` passes over ``dataloader``.

        Anchor and positive are both the positive-edge encoding, so the
        positive distance is ~0. The loss therefore reduces to the hinge
        ``max(0, margin - d(pos, neg))``, averaged over the negatives of each
        item. Items without negatives are skipped. The printed epoch loss is
        the mean over optimisation steps actually taken (0 when none).
        """
        self.model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            num_steps = 0
            for batch in dataloader:
                h_pos = self._encode(batch['pos_pair'])
                losses = [
                    self.criterion(h_pos, h_pos, self._encode(neg_pair))  # anchor=pos, positive=pos, negative=neg
                    for neg_pair in batch['neg_pairs']
                ]
                if not losses:
                    continue

                # Backpropagation.
                loss = torch.stack(losses).mean()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_steps += 1

            avg_loss = total_loss / num_steps if num_steps else 0.0
            print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    def save_model(self, path):
        """Save the encoder's ``state_dict`` to ``path`` with ``torch.save``."""
        torch.save(self.model.state_dict(), path)

def generate_and_save_edge_encodings(model, node_features, neighbor_dict, save_path, device=None):
    """
    Encode every edge of the graph and pickle the result.

    Args:
        model: trained EdgeEncoder instance, called as
            ``model(h_i, h_j, neighbors_i, neighbors_j)``.
        node_features: node feature matrix indexed by node id
            (presumably [num_nodes, node_dim]; confirm against the caller).
        neighbor_dict: adjacency dict ``{node_id: [neighbor_ids]}``; every
            ``(src, dst)`` listed in it is encoded. Nodes without an entry
            (e.g. sinks of a directed graph) get an empty neighbour matrix.
        save_path: output file. The dict ``{(src, dst): encoding}`` is always
            written with ``pickle``, regardless of extension. Missing parent
            directories are created.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter, or ``'cuda'`` if the model has none.

    Returns:
        The ``{(src, dst): tensor}`` dict that was saved (tensors on CPU).

    Side effects: switches ``model`` to eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    # Create the output directory up front so a bad path fails before the
    # (potentially long) encoding pass rather than after it.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    def neighbor_features(node_id):
        # All neighbours (no sampling); empty slice keeps the trailing dims.
        ids = neighbor_dict.get(node_id, [])
        if len(ids) == 0:
            return node_features[:0]
        return node_features[ids]

    model.eval()
    edge_encodings = {}
    edges = [(src, dst) for src in neighbor_dict for dst in neighbor_dict[src]]

    with torch.no_grad():
        for src, dst in tqdm(edges, desc="Generating edge encodings"):
            h_i = node_features[src].unsqueeze(0).to(device)
            h_j = node_features[dst].unsqueeze(0).to(device)
            neighbors_i = neighbor_features(src).to(device)
            neighbors_j = neighbor_features(dst).to(device)

            # The leading batch dimension of size 1 is squeezed away before storing.
            h_eij = model(h_i, h_j, neighbors_i, neighbors_j).squeeze(0)
            edge_encodings[(src, dst)] = h_eij.cpu()

    # Pickle keeps the tuple-keyed dict structure intact.
    with open(save_path, 'wb') as f:
        pickle.dump(edge_encodings, f)

    print(f"Edge encodings saved to {save_path}")
    return edge_encodings

if __name__ == '__main__':
    bert_x_path = os.path.join('out', 'lm_emb', 'roberta-base', 'cora.pt')
    node_features = torch.load(bert_x_path)
    neighbor_path = os.path.join('out', 'neighbors', 'cora', 'one_hop_neighbors.pkl')
    # Local artifact produced by the project's own preprocessing; never point
    # pickle.load at untrusted files.
    with open(neighbor_path, 'rb') as f:
        neighbor_dict = pickle.load(f)

    # Build the edge dataset.
    dataset = EdgeDataset(node_features, neighbor_dict)
    # batch_size=None disables automatic batching. The neighbour matrices have
    # a degree-dependent row count, so the default collate cannot stack them.
    # The trainer also treats each item as a single un-batched sample
    # (it unsqueezes a batch dim of 1 itself).
    dataloader = DataLoader(dataset, batch_size=None, shuffle=True)

    # Train.
    trainer = EdgeEncoderTrainer()
    trainer.train(dataloader, epochs=50)

    generate_and_save_edge_encodings(
        model=trainer.model,
        node_features=node_features,
        neighbor_dict=neighbor_dict,
        save_path=os.path.join('out', 'edge_encodings', 'cora.pkl'),
    )
