import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from core.data_utils.data_loader import ListDataset
from torch.nn import MultiheadAttention
from core.llm.lm import TextEncoder
from torch_geometric.utils import negative_sampling
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score, precision_recall_curve
import numpy as np
import networkx as nx
import torch
import json
from torch_sparse import SparseTensor
from torch.cuda.amp import autocast, GradScaler


class EdgeDataset(Dataset):
    """Map-style dataset that pairs each edge with its label.

    Indexing returns ``(edge, label)``; both items are moved to ``device``
    at access time.
    """

    def __init__(self, edges, labels, device='cpu'):
        self.edges, self.labels = edges, labels
        self.device = device

    def __len__(self):
        # One sample per stored edge.
        return len(self.edges)

    def __getitem__(self, idx):
        target = self.device
        return self.edges[idx].to(target), self.labels[idx].to(target)


class NodeGCN(nn.Module):
    """Stack of ``GCNConv`` layers with a learnable per-node embedding table.

    Reads ``cfg.gnn.vocab_size``, ``cfg.gcn.source_in_channels``,
    ``cfg.gnn.hidden_channels``, ``cfg.gnn.num_layers`` and
    ``cfg.gnn.out_channels``.
    """

    def __init__(self, cfg):
        super(NodeGCN, self).__init__()
        self.cfg = cfg
        self.embedding = nn.Embedding(cfg.gnn.vocab_size, cfg.gcn.source_in_channels)
        self.convs = nn.ModuleList()
        in_channels = cfg.gcn.source_in_channels
        hidden_channels = cfg.gnn.hidden_channels
        num_layers = cfg.gnn.num_layers
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(in_channels, hidden_channels))
            in_channels = hidden_channels
        # Feed the last layer from the running width: with a single layer the
        # loop above never runs, so the input is still the embedding width
        # rather than ``hidden_channels``.
        self.convs.append(GCNConv(in_channels, cfg.gnn.out_channels))

    def forward(self, edge_index, x=None):
        """Run all layers over ``edge_index``.

        Args:
            edge_index: graph connectivity passed to every conv layer.
            x: optional node features; when ``None`` the embedding table
                weights are used as features.

        Returns:
            Output of the last conv layer (no activation applied to it).
        """
        if x is None:
            x = self.embedding.weight
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # ReLU between layers only; the final layer stays linear.
            if i < len(self.convs) - 1:
                x = F.relu(x)
        return x


class NodeGAT(nn.Module):
    """Multi-layer GAT encoder with a learnable per-node embedding table.

    Hidden layers use ``cfg.gnn.heads`` concatenated heads; the output layer
    uses a single non-concatenated head of width ``cfg.gnn.out_channels``.
    """

    def __init__(self, cfg):
        super(NodeGAT, self).__init__()
        self.cfg = cfg
        gnn = cfg.gnn
        self.embedding = nn.Embedding(gnn.vocab_size, gnn.source_in_channels)
        width = gnn.source_in_channels
        hidden, n_heads, depth = gnn.hidden_channels, gnn.heads, gnn.num_layers
        layers = []
        for _ in range(depth - 1):
            layers.append(GATConv(width, hidden, heads=n_heads, concat=True))
            # Concatenated heads widen the next layer's input.
            width = hidden * n_heads
        layers.append(GATConv(width, gnn.out_channels, heads=1, concat=False))
        self.convs = nn.ModuleList(layers)

    def forward(self, edge_index, x=None):
        """Apply every GAT layer, with ELU between layers but not after the last.

        ``x`` defaults to the embedding table weights when omitted.
        """
        h = self.embedding.weight if x is None else x
        last = len(self.convs) - 1
        for depth, layer in enumerate(self.convs):
            h = layer(h, edge_index)
            if depth != last:
                h = F.elu(h)
        return h


class GAT(torch.nn.Module):
    """Four-layer GAT.

    Three multi-head concatenated layers are followed by a single-head output
    layer whose width is ``label_num`` when ``is_classification`` is true and
    ``cfg.gnn.out_channels`` otherwise.

    Args:
        cfg: config providing ``cfg.gnn.hidden_channels``, ``cfg.gnn.heads``
            and ``cfg.gnn.out_channels``.
        label_num: output width used when ``is_classification`` is true.
        in_channels: input feature width (also the embedding width).
        vocab_size: number of rows of the optional embedding table.
        init_embedding: create a learnable embedding table used when
            ``forward`` is called without ``x``.
        is_classification: select ``label_num`` as the output width.
    """

    def __init__(self, cfg, label_num, in_channels, vocab_size, init_embedding=False, is_classification=False):
        super(GAT, self).__init__()
        if init_embedding:
            self.embedding = nn.Embedding(vocab_size, in_channels)
        self.conv1 = GATConv(in_channels, cfg.gnn.hidden_channels, heads=cfg.gnn.heads, concat=True)
        self.conv2 = GATConv(cfg.gnn.hidden_channels * cfg.gnn.heads, cfg.gnn.hidden_channels, heads=cfg.gnn.heads,
                             concat=True)
        self.conv3 = GATConv(cfg.gnn.hidden_channels * cfg.gnn.heads, cfg.gnn.hidden_channels, heads=cfg.gnn.heads,
                             concat=True)
        if is_classification:
            self.conv4 = GATConv(cfg.gnn.hidden_channels * cfg.gnn.heads, label_num, heads=1)
        else:
            self.conv4 = GATConv(cfg.gnn.hidden_channels * cfg.gnn.heads, cfg.gnn.out_channels, heads=1)

    def forward(self, edge_index, return_attention_weights=False, x=None):
        """Run the four layers with ELU between them.

        Args:
            edge_index: graph connectivity passed to every layer.
            return_attention_weights: also return attention weights.
            x: node features; defaults to the embedding table weights.

        Returns:
            The output of ``conv4``; when ``return_attention_weights`` is true,
            a pair ``(output, (attn_conv1, attn_conv2, attn_conv4))``.

        Raises:
            ValueError: if ``x`` is omitted and no embedding table was created.
        """
        if x is None:
            embedding = getattr(self, "embedding", None)
            if embedding is None:
                raise ValueError("x must be provided when the model was built with init_embedding=False")
            x = embedding.weight
        if return_attention_weights:
            x, attn_weights1 = self.conv1(x, edge_index, return_attention_weights=True)
            x = F.elu(x)
            x, attn_weights2 = self.conv2(x, edge_index, return_attention_weights=True)
            x = F.elu(x)
            # conv3 must run here too, otherwise the returned output differs
            # from the plain forward pass. Its attention weights are not
            # returned so that the 3-element tuple callers unpack is unchanged.
            x = F.elu(self.conv3(x, edge_index))
            x, attn_weights3 = self.conv4(x, edge_index, return_attention_weights=True)
            return x, (attn_weights1, attn_weights2, attn_weights3)
        x1 = F.elu(self.conv1(x, edge_index))
        x2 = F.elu(self.conv2(x1, edge_index))
        x3 = F.elu(self.conv3(x2, edge_index))
        return self.conv4(x3, edge_index)


class EdgeGCN(nn.Module):
    """GCN encoder used for the edge task, with dropout between layers.

    Reads ``cfg.gnn.vocab_size`` and the ``cfg.gcn`` fields ``in_channels``,
    ``hidden_channels``, ``num_layers`` and ``out_channels``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embedding = nn.Embedding(cfg.gnn.vocab_size, cfg.gcn.in_channels)
        self.convs = nn.ModuleList()
        in_channels = cfg.gcn.in_channels
        hidden_channels = cfg.gcn.hidden_channels
        num_layers = cfg.gcn.num_layers
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(in_channels, hidden_channels))
            in_channels = hidden_channels
        # Use the running width so a single-layer model is fed the embedding
        # width instead of ``hidden_channels``.
        self.convs.append(GCNConv(in_channels, cfg.gcn.out_channels))

    def forward(self, edge_index, x=None):
        """Apply all layers; ReLU and dropout (p=0.5) follow every layer but the last.

        ``x`` defaults to the embedding table weights when omitted.
        """
        if x is None:
            x = self.embedding.weight
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = F.relu(x)
                # Dropout is only active in training mode.
                x = F.dropout(x, p=0.5, training=self.training)
        return x


class EdgeGAT(nn.Module):
    """GAT encoder used for the edge task.

    The dropout rate is ``cfg.gnn.dropout`` when present, else 0.5. It is
    passed to every ``GATConv`` and also applied to hidden activations.
    """

    def __init__(self, cfg):
        super(EdgeGAT, self).__init__()
        self.cfg = cfg
        gnn = cfg.gnn
        self.embedding = nn.Embedding(gnn.vocab_size, gnn.source_in_channels)
        self.dropout_rate = getattr(gnn, 'dropout', 0.5)
        width = gnn.source_in_channels
        hidden, n_heads, depth = gnn.hidden_channels, gnn.heads, gnn.num_layers
        layers = []
        for _ in range(depth - 1):
            layers.append(GATConv(width, hidden, heads=n_heads, concat=True, dropout=self.dropout_rate))
            # Heads are concatenated, so the next layer sees hidden * heads features.
            width = hidden * n_heads
        layers.append(GATConv(width, gnn.out_channels, heads=1, concat=False, dropout=self.dropout_rate))
        self.convs = nn.ModuleList(layers)

    def forward(self, edge_index, x=None):
        """Apply all layers; ELU and dropout follow every layer except the last.

        ``x`` defaults to the embedding table weights when omitted.
        """
        h = self.embedding.weight if x is None else x
        last = len(self.convs) - 1
        for depth, layer in enumerate(self.convs):
            h = layer(h, edge_index)
            if depth == last:
                continue
            h = F.dropout(F.elu(h), p=self.dropout_rate, training=self.training)
        return h


class FeatureQueue:
    """Fixed-size ring buffer of feature vectors and their integer labels.

    Args:
        dim: width of each stored feature vector.
        size: number of slots in the buffer.
        device: where the buffers live. ``None`` selects ``"cuda"`` when
            available and falls back to ``"cpu"`` otherwise.

    Empty slots carry the label ``-1``.
    """

    def __init__(self, dim, size=4096, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.size = size
        self.dim = dim
        self.device = torch.device(device)
        self.features = torch.zeros(size, dim, device=self.device)
        self.labels = torch.full((size,), -1, dtype=torch.long, device=self.device)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, feats, labels):
        """Write a batch at the current pointer, wrapping around at the end.

        ``labels`` may be ``(B,)`` or ``(B, 1)``; it is flattened before being
        stored. When the batch is larger than the queue only the newest
        ``size`` entries are kept, at the slots sequential writes would have
        left them in.
        """
        feats = feats.detach().to(device=self.features.device, dtype=self.features.dtype)
        labels = labels.detach().reshape(-1).to(device=self.labels.device, dtype=self.labels.dtype)
        B = feats.shape[0]
        if B == 0:
            return
        start = self.ptr
        if B > self.size:
            # Older entries would be overwritten within this same call.
            skipped = B - self.size
            feats, labels = feats[skipped:], labels[skipped:]
            start = (start + skipped) % self.size
            B = self.size
        # Modular slot indices handle the wrap-around case uniformly.
        slots = (start + torch.arange(B, device=self.features.device)) % self.size
        self.features[slots] = feats
        self.labels[slots] = labels
        self.ptr = (start + B) % self.size

    def get_negatives(self, label, K):
        """Return up to ``K`` stored features whose label differs from ``label``.

        Returns ``None`` when no slot qualifies, every qualifying feature when
        fewer than ``K`` exist, and a random subset of ``K`` otherwise.

        NOTE(review): empty slots (label ``-1``) also satisfy the mask, so
        all-zero vectors are returned as negatives until the queue fills up.
        """
        mask = self.labels != label
        if not mask.any():
            return None
        neg_feats = self.features[mask]
        if neg_feats.size(0) < K:
            return neg_feats
        idx = torch.randperm(neg_feats.size(0), device=neg_feats.device)[:K]
        return neg_feats[idx]


class CrossModalAttention(nn.Module):
    """Bidirectional attention between text tokens and graph node features.

    Both inputs are projected to ``max(text_dim, graph_dim)`` before attention.

    Args:
        text_dim: feature width of the text input.
        graph_dim: feature width of the graph input.
        num_heads: number of attention heads in both attention modules.
    """

    def __init__(self, text_dim, graph_dim, num_heads=4):
        super().__init__()
        self.hidden_dim = max(text_dim, graph_dim)
        self.text_to_graph_proj = nn.Linear(text_dim, self.hidden_dim)
        self.graph_to_text_proj = nn.Linear(graph_dim, self.hidden_dim)
        self.text_to_graph_attn = MultiheadAttention(embed_dim=self.hidden_dim, num_heads=num_heads, batch_first=False)
        self.graph_to_text_attn = MultiheadAttention(embed_dim=self.hidden_dim, num_heads=num_heads, batch_first=True)
        # NOTE(review): the two LayerNorm modules are never applied in
        # forward(); they are kept so existing checkpoints keep loading.
        self.norm_text = nn.LayerNorm(self.hidden_dim)
        self.norm_graph = nn.LayerNorm(self.hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_feats, graph_feats):
        """Attend in both directions and return one vector per sample for each.

        Args:
            text_feats: batch-first tensor ``(batch, tokens, text_dim)``.
            graph_feats: batch-first tensor ``(batch, nodes, graph_dim)``.

        Returns:
            ``(text_out, graph_out)``, both ``(batch, hidden_dim)``.
            ``text_out`` is the last graph element attending over the text
            tokens. ``graph_out`` is the text tokens attending over the graph
            elements, averaged over tokens.
        """
        text_proj = self.text_to_graph_proj(text_feats)
        graph_proj = self.graph_to_text_proj(graph_feats)
        # This attention module is sequence-first, hence the transposes and the
        # leading length-1 query dimension.
        text_out, _ = self.text_to_graph_attn(query=graph_proj[:, -1, :].unsqueeze(0), key=text_proj.transpose(0, 1),
                                              value=text_proj.transpose(0, 1))
        graph_out, _ = self.graph_to_text_attn(query=text_proj, key=graph_proj, value=graph_proj)
        graph_out = graph_out.mean(dim=1)
        # Drop only the length-1 query dimension; a bare squeeze() would also
        # remove the batch dimension when the batch holds a single sample.
        text_out = text_out.squeeze(0)
        graph_out = self.dropout(graph_out)
        text_out = self.dropout(text_out)
        return text_out, graph_out


import torch.nn as nn
import torch.nn.functional as F


class NeighborAttention(nn.Module):
    """Pool a set of neighbor features with one attention query.

    The query is the mean of the neighbor features; keys and values are the
    neighbors themselves. Scores are scaled by ``sqrt(hidden_dim)``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Same creation order as the attribute names: query, key, value.
        for name in ("q_proj", "k_proj", "v_proj"):
            setattr(self, name, nn.Linear(input_dim, hidden_dim))
        self.out_proj = nn.Linear(hidden_dim, input_dim)
        self.scale = hidden_dim ** 0.5

    def forward(self, neighbor_feats):
        """Return one pooled vector of width ``input_dim`` per batch row.

        ``neighbor_feats`` is averaged over dim 1 to build the query.
        """
        pooled_query = neighbor_feats.mean(dim=1, keepdim=True)
        queries = self.q_proj(pooled_query)
        keys = self.k_proj(neighbor_feats)
        values = self.v_proj(neighbor_feats)
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / self.scale
        weights = F.softmax(scores, dim=-1)
        pooled = torch.matmul(weights, values).squeeze(1)
        return self.out_proj(pooled)


class GATWithCrossAttention(nn.Module):
    def __init__(self, cfg):
        """Build the text encoder, node/edge GNNs, fusion module and heads.

        Reads ``cfg.model.node_type`` / ``cfg.model.edge_type`` to choose the
        GNN variants, ``cfg.gnn`` / ``cfg.gcn`` for layer sizes,
        ``cfg.dataset.source_label_num`` for the classifier width and
        ``cfg.training.device`` for placement.

        NOTE(review): the literal widths (384, 768, 769, 772) are presumably
        tied to the text encoder's embedding size — confirm against
        ``TextEncoder``.
        """
        super().__init__()
        self.neigh_attn = NeighborAttention(input_dim=cfg.gnn.out_channels, hidden_dim=cfg.gnn.out_channels // 2)
        self.text_encoder = TextEncoder(cfg)
        if cfg.model.node_type == "gat":
            self.node_gnn = NodeGAT(cfg)
        else:
            self.node_gnn = NodeGCN(cfg)
        # Mirror the node branch: "gat" selects the GAT encoder and anything
        # else the GCN encoder. The previous test was inverted
        # ("gcn" built EdgeGAT).
        if cfg.model.edge_type == "gat":
            self.edge_gnn = EdgeGAT(cfg)
        else:
            self.edge_gnn = EdgeGCN(cfg)
        # Optional adapter from the target feature width to the source width.
        if cfg.gcn.source_in_channels != cfg.gcn.target_in_channels:
            self.cross_fc = nn.Linear(cfg.gcn.target_in_channels, cfg.gcn.source_in_channels)
        else:
            self.cross_fc = None
        self.node_num = cfg.gnn.vocab_size
        self.cross_attention = CrossModalAttention(text_dim=384, graph_dim=cfg.gnn.out_channels, num_heads=4)
        # 772 = 2 * 384 embedding features + the 4 scalar features that
        # edge_predict concatenates; the one-hot label is appended after them.
        self.edgefc = nn.Linear(772 + cfg.dataset.source_label_num, 384)
        self.node_predictor = nn.Sequential(
            nn.LayerNorm(384 * 2),
            nn.Linear(384 * 2, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, cfg.dataset.source_label_num)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        self.graph_proj = nn.Sequential(
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        self.edge_predictor = nn.Sequential(
            nn.Linear(384 * 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

        self.projection_head = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 256)
        )

        self.projection_head_g = nn.Sequential(
            nn.Linear(769, 512),
            nn.ReLU(),
            nn.Linear(512, 256))
        self.residual = nn.Sequential(
            nn.Linear(768 + 768, 256),
            nn.LayerNorm(256)
        )
        self.temperature = nn.Parameter(torch.tensor(0.07))
        self.contrastive_scale = 0.1
        self.device = torch.device(cfg.training.device)
        self.to(self.device)
        for param in self.text_encoder.parameters():
            param.requires_grad = True
        # Negative-sample queues for the two contrastive directions used in
        # node_predict; 128 matches the text_proj / graph_proj output width.
        self.queue1 = FeatureQueue(dim=128)
        self.queue2 = FeatureQueue(dim=128)

    def jaccard_similarity(self, n_id1, n_id2, edge_index):
        neighbors1 = self._get_neighbors(n_id1, edge_index)
        neighbors2 = self._get_neighbors(n_id2, edge_index)
        batch_size = n_id1.size(0)
        jaccard_scores = torch.zeros(batch_size, dtype=torch.float)
        for i in range(batch_size):
            neighbors1_i = neighbors1[i]
            neighbors2_i = neighbors2[i]

            neighbors1_i = neighbors1_i[neighbors1_i != -1]
            neighbors2_i = neighbors2_i[neighbors2_i != -1]

            intersection = torch.sum(torch.isin(neighbors1_i, neighbors2_i)).item()

            union = len(torch.unique(torch.cat((neighbors1_i, neighbors2_i))))

            jaccard_scores[i] = intersection / union if union > 0 else 0

        return jaccard_scores.to(self.device)

    def _get_common_neighbors(self, n_id1, n_id2, edge_index):
        neighbors1 = self._get_neighbors(n_id1, edge_index)
        neighbors2 = self._get_neighbors(n_id2, edge_index)

        neighbors1_np = neighbors1.numpy()
        neighbors2_np = neighbors2.numpy()

        common_neighbors_np = np.intersect1d(neighbors1_np, neighbors2_np)

        common_neighbors = torch.tensor(common_neighbors_np, dtype=torch.long)

        return common_neighbors

    def adamic_adar_index(self, n_id1, n_id2, edge_index):
        batch_size = n_id1.size(0)
        adamic_adar_scores = torch.zeros(batch_size, dtype=torch.float)

        for i in range(batch_size):
            common_neighbors = self._get_common_neighbors(n_id1[i], n_id2[i], edge_index)
            score = 0

            for neighbor in common_neighbors:
                degree = len(self._get_neighbors(neighbor, edge_index))
                if degree > 1:
                    score += 1 / np.log(degree)

            adamic_adar_scores[i] = score

        return adamic_adar_scores.to(self.device)

    def node_degree(self, n_id, edge_index):
        batch_size = n_id.size(0)
        degrees = torch.zeros(batch_size, dtype=torch.long)

        for i in range(batch_size):
            neighbors = self._get_neighbors(n_id[i], edge_index)
            degrees[i] = len(neighbors)

        return degrees.to(self.device)

    def _get_neighbors(self, n_id1, edge_index):
        if isinstance(n_id1, torch.Tensor):
            if len(n_id1.shape) == 0:
                n_id1 = n_id1.unsqueeze(0)
            batch_size = n_id1.size(0)
        else:
            batch_size = 1

        neighbors_list = []

        for i in range(batch_size):
            n_id = n_id1[i]
            neighbors_from_n_id = edge_index[1][edge_index[0] == n_id]
            neighbors_to_n_id = edge_index[0][edge_index[1] == n_id]
            neighbors = torch.cat([neighbors_from_n_id, neighbors_to_n_id])
            neighbors = torch.unique(neighbors)
            neighbors_list.append(neighbors)

        max_neighbors = max(len(neighbors) for neighbors in neighbors_list)
        neighbors_padded = torch.full((batch_size, max_neighbors), fill_value=-1, dtype=torch.long)

        for i, neighbors in enumerate(neighbors_list):
            neighbors_padded[i, :len(neighbors)] = neighbors

        return neighbors_padded

    def one_hot_encoding(self, labels, num_classes):
        return torch.eye(num_classes).to(self.device)[labels]

    def edge_predict(self, x_embed, edge_index, n_id1, n_id2, x_texts1, x_texts2, warmup=True, seqlen=64,
                     fixed_length=10, y=None, num_classes=10, user_cross=True):
        src_text_emb, src_graph_emb = self.get_node_feature(x_embed, x_texts1, n_id1, edge_index, warmup=warmup,
                                                            seqlen=seqlen,
                                                            fixed_length=fixed_length, type="edge",
                                                            user_cross=user_cross)
        neighbor_counts = self._count_common_neighbors(n_id1, n_id2, edge_index)
        jac_sim = self.jaccard_similarity(n_id1, n_id2, edge_index)
        adamic_adar_score = self.adamic_adar_index(n_id1, n_id2, edge_index)
        degree1 = self.node_degree(n_id1, edge_index)
        degree2 = self.node_degree(n_id2, edge_index)
        label1 = self.one_hot_encoding(y[n_id1], num_classes)
        label2 = self.one_hot_encoding(y[n_id2], num_classes)
        src_features = torch.cat([src_text_emb, src_graph_emb,
                                  neighbor_counts.unsqueeze(1), jac_sim.unsqueeze(1),
                                  adamic_adar_score.unsqueeze(1),
                                  degree1.unsqueeze(1),
                                  label1], dim=-1)
        src_emb = self.edgefc(src_features)
        dst_text_emb, dst_graph_emb = self.get_node_feature(x_embed, x_texts2, n_id2, edge_index, warmup=warmup,
                                                            seqlen=seqlen,
                                                            fixed_length=fixed_length, type="edge",
                                                            user_cross=user_cross)
        dst_features = torch.cat([dst_text_emb, dst_graph_emb,
                                  neighbor_counts.unsqueeze(1), jac_sim.unsqueeze(1),
                                  adamic_adar_score.unsqueeze(1),
                                  degree2.unsqueeze(1),
                                  label2], dim=-1)
        dst_emb = self.edgefc(dst_features)
        pred = F.cosine_similarity(src_emb, dst_emb, dim=-1)
        return pred

    def contrastive_with_queue(self, x_text, x_img, labels, queue: FeatureQueue, temperature=0.07, neg_k=64):
        B, D = x_text.shape
        x_text = F.normalize(x_text, dim=1)
        x_img = F.normalize(x_img, dim=1)
        pos_sim = torch.sum(x_text * x_img, dim=1) / temperature

        neg_sims = []
        for i in range(B):
            neg_feats = queue.get_negatives(labels[i].item(), neg_k)
            if neg_feats is None:
                continue
            neg_feats = F.normalize(neg_feats, dim=1)
            sim = torch.mm(x_text[i:i + 1], neg_feats.T) / temperature
            neg_sims.append(sim)

        final_logits = []
        final_labels = []
        for i, neg in enumerate(neg_sims):
            logits_i = torch.cat([pos_sim[i:i + 1], neg.squeeze(0)], dim=0)
            final_logits.append(logits_i.unsqueeze(0))
            final_labels.append(torch.tensor(0).to(x_text.device))

        final_logits = torch.cat(final_logits, dim=0)
        final_labels = torch.stack(final_labels)

        loss = F.cross_entropy(final_logits, final_labels)
        return loss

    def get_node_feature(self, x_embed, x_texts, n_id, edge_index, warmup=True, seqlen=64, fixed_length=10,
                         type="node", user_cross=True):
        if warmup:
            for param in self.text_encoder.parameters():
                param.requires_grad = False
        else:
            for param in self.text_encoder.parameters():
                param.requires_grad = True
        x_text = self.text_encoder(x_texts, seqlen, False)
        edge_index = edge_index.long().to(self.device)
        if type == "node":
            x_graph = self.node_gnn(edge_index, x_embed)
        else:
            x_graph = self.edge_gnn(edge_index, x_embed)
        center_node = x_graph[n_id]
        neighbors_list = [edge_index[1, edge_index[0] == id] for id in n_id]
        processed_neighbors = []
        for neighbors in neighbors_list:
            if len(neighbors) > fixed_length:
                neighbors = neighbors[:fixed_length]
            elif len(neighbors) < fixed_length:
                padding_size = fixed_length - len(neighbors)
                neighbors = torch.cat([neighbors, torch.full((padding_size,), -1).to(self.device)])
            processed_neighbors.append(neighbors)
        dummy = torch.zeros(1, x_graph.size(1), device=x_graph.device)
        x_graph = torch.cat([x_graph, dummy], dim=0)
        neighbors = torch.stack(processed_neighbors, dim=0)
        neighbors = neighbors.clone()
        neighbors[neighbors == -1] = x_graph.size(0) - 1
        neighbour_node = x_graph[neighbors]
        if user_cross:
            text_out, graph_out = self.cross_attention(x_text,
                                                       torch.concat(
                                                           [neighbour_node, torch.unsqueeze(center_node, dim=1)],
                                                           dim=1))
        else:
            text_out, graph_out = x_text.mean(dim=1), torch.concat(
                [neighbour_node, torch.unsqueeze(center_node, dim=1)], dim=1).mean(dim=1)
        if text_out.dim() == 1:
            text_out = text_out.unsqueeze(0)
        return text_out, graph_out

    def node_predict(self, x_embed, x_texts, x_prompts, n_id, edge_index, warmup=True, labels=None, reg=0, seqlen=64,
                     temperature=0.07, neg_k=512, fixed_length=10, user_cross=True, text_label=None):
        text_emb, graph_emb = self.get_node_feature(x_embed, x_texts, n_id, edge_index, warmup=warmup, seqlen=seqlen,
                                                    fixed_length=fixed_length, type="node", user_cross=user_cross)
        if text_label is None:
            x = torch.cat([text_emb, graph_emb], dim=1)
            predict = self.node_predictor(x)
        else:
            x = (text_emb + graph_emb) * 0.5
            predict = torch.matmul(x, self.text_encoder(text_label, seqlen, True).T)
        if reg > 0:
            x_prompt_clean = list()
            for prompt in x_prompts:
                try:
                    data = json.loads(prompt.strip().strip("```json").strip("```"))
                    text_clean = f"""
                                      Classification Prediction: {data['Classification Prediction']}.
                                      Keywords: {', '.join(data['Feature Extraction']['keywords'])}.
                                      Methods: {', '.join(data['Feature Extraction']['methods'])}.
                                      Novelty: {data['Feature Extraction']['novelty']}
                                      Summary: {data['Summary']}
                                      """.strip()
                    x_prompt_clean.append(text_clean)
                except:
                    x_prompt_clean.append(prompt)
            x_prompt = self.text_encoder(x_prompt_clean, seqlen, True)
            x_text = self.text_encoder(x_texts, seqlen, True)
            text_pro, x_prompt = self.text_proj(x_text), self.graph_proj(x_prompt)
            self.queue1.enqueue(x_prompt.detach(), labels.detach())
            self.queue2.enqueue(text_pro.detach(), labels.detach())
            loss1 = self.contrastive_with_queue(text_pro, x_prompt, labels, self.queue1, temperature=temperature,
                                                neg_k=neg_k)
            loss2 = self.contrastive_with_queue(x_prompt, text_pro, labels, self.queue2, temperature=temperature,
                                                neg_k=neg_k)
            loss = reg * (loss1 + loss2) / 2
        else:
            loss = 0.0
        return predict, loss

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        if ptr + batch_size > self.queue_size:
            self.queue[ptr:] = keys[:self.queue_size - ptr]
            self.queue[:batch_size - (self.queue_size - ptr)] = keys[self.queue_size - ptr:]
        else:
            self.queue[ptr:ptr + batch_size] = keys

        ptr = (ptr + batch_size) % self.queue_size
        self.queue_ptr[0] = ptr

    def _count_common_neighbors(self, u, v, edge_index):
        def cal(u_, v_):
            try:
                return len(list(nx.common_neighbors(G, u_, v_)))
            except:
                return 0

        G = nx.Graph()
        G.add_edges_from(edge_index.cpu().t().tolist())
        u = u.cpu().tolist()
        v = v.cpu().tolist()
        return torch.tensor([cal(u[i], v[i]) for i in range(len(u))]).to(self.device)


class GATWithPrompt(torch.nn.Module):
    """Text encoder followed by a ``GAT`` over the encoded node texts.

    Args:
        text_encoder: module that encodes node texts and exposes ``.device``.
        cfg: config forwarded to ``GAT``.
        label_num: forwarded to ``GAT``; defaults to
            ``cfg.dataset.source_label_num``.
        in_channels: input width of the GAT; defaults to
            ``cfg.gnn.source_in_channels``.
        vocab_size: forwarded to ``GAT``; defaults to ``cfg.gnn.vocab_size``.

    NOTE(review): the defaults are taken from the config fields the other
    models in this file read. ``in_channels`` must match the text encoder's
    output width — confirm before relying on the default.
    """

    def __init__(self, text_encoder, cfg, label_num=None, in_channels=None, vocab_size=None):
        super(GATWithPrompt, self).__init__()
        self.text_encoder = text_encoder
        # GAT requires these three arguments; calling it with cfg alone raised
        # a TypeError, so this class could never be constructed.
        if label_num is None:
            label_num = cfg.dataset.source_label_num
        if in_channels is None:
            in_channels = cfg.gnn.source_in_channels
        if vocab_size is None:
            vocab_size = cfg.gnn.vocab_size
        self.gat = GAT(cfg, label_num, in_channels, vocab_size)
        self.gat.to(text_encoder.device)

    def forward(self, x_text, edge_index):
        """Encode ``x_text`` and run the GAT on the result over ``edge_index``."""
        x = self.text_encoder(x_text)
        # GAT.forward takes edge_index first and the features as the keyword
        # ``x``; passing them positionally in the other order put the features
        # in the edge_index slot.
        return self.gat(edge_index, x=x)


def train(model, data, cfg, warmup, vocab_size, cross_data=False):
    """Run one pass over the training data for the configured task.

    For ``cfg.training.task == "node"`` the node ids ``0..vocab_size-1`` are
    batched, restricted to ``data.train_mask``, and the model is optimised on
    cross-entropy plus the contrastive loss returned by ``node_predict``.
    Otherwise the labelled training edges are batched and optimised with
    binary cross-entropy plus a pairwise contrastive term against randomly
    sampled negative endpoints.

    A fresh Adam optimizer is created on every call.

    Args:
        model: model exposing ``node_predict`` / ``edge_predict`` / ``cross_fc``.
        data: graph data object (texts, prompts, labels, masks, edges).
        cfg: run configuration.
        warmup: forwarded to the model (freezes the text encoder when true).
        vocab_size: number of node ids to iterate in the node task.
        cross_data: apply ``model.cross_fc`` to ``data.x`` when available.

    Returns:
        ``(node_acc, edge_auc, node_f1, edge_ap, edge_acc, edge_f1, node_loss,
        edge_loss)``; metrics of the task that was not run are 0.
    """
    model.train()
    if cfg.training.task == "node":
        # Use the configured device, as the edge branch does, instead of a
        # hard-coded "cuda".
        device = cfg.training.device
        dataset = ListDataset(list(range(vocab_size)))
        dataloader = DataLoader(dataset, batch_size=cfg.training.batch_size, shuffle=True)
        data.edge_index = data.edge_index.to(torch.long)
        node_criterion = nn.CrossEntropyLoss()
        node_loss = 0
        all_preds = []
        all_targets = []
        nodeoptimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr, weight_decay=cfg.training.weight_decay)
        edge_index = data.edge_index.long().to(device)
        use_id_embedding = cfg.model.node_embed_type == "id"
        # Label texts are only needed when source and target datasets differ.
        if cfg.dataset.source_name == cfg.dataset.target_name:
            text_label = None
        else:
            text_label = data.label
        for n_id in dataloader:
            n_id = n_id.to(device)
            mask = data.train_mask[n_id]
            if mask.sum() == 0:
                continue
            filtered_n_id = n_id[mask]
            x_text = [data.x_text[i] for i in filtered_n_id]
            x_prompts = [data.x_prompts[i] for i in filtered_n_id]
            # Only the node features differ between the embedding modes, so a
            # single node_predict call covers both.
            if use_id_embedding:
                x = None
            elif model.cross_fc is not None and cross_data:
                x = model.cross_fc(data.x)
            else:
                x = data.x
            predict, cons_loss = model.node_predict(x,
                                                    x_text, x_prompts, filtered_n_id, edge_index, warmup,
                                                    data.y[filtered_n_id], cfg.training.reg,
                                                    cfg.training.seqlen, cfg.training.temperature,
                                                    cfg.training.neg_k,
                                                    cfg.training.fixed_length,
                                                    cfg.training.user_cross,
                                                    text_label
                                                    )
            target_filtered = data.y[filtered_n_id]
            if cfg.dataset.source_name in ["ogbn-products", "ogbn-arxiv"]:
                target_filtered = target_filtered.squeeze()
            loss = cons_loss + node_criterion(predict, target_filtered)
            all_preds.append(predict.argmax(dim=1).detach())
            all_targets.append(target_filtered.detach())
            nodeoptimizer.zero_grad()
            loss.backward()
            nodeoptimizer.step()
            node_loss += loss.item()
        if all_preds:
            all_preds = torch.cat(all_preds).cpu().numpy()
            all_targets = torch.cat(all_targets).cpu().numpy()
            node_acc = accuracy_score(all_targets, all_preds)
            node_f1 = f1_score(all_targets, all_preds, average='micro')
        else:
            # No batch contained a training node; torch.cat on an empty list
            # would raise, so report zero metrics instead.
            node_acc, node_f1 = 0, 0
        node_loss = node_loss / max(len(dataloader), 1)
        edge_auc, edge_ap, edge_acc, edge_f1, edge_loss = 0, 0, 0, 0, 0
    else:
        dataset = EdgeDataset(data.train_edge_label_index.t(), data.train_edge_label, device=cfg.training.device)
        edgeoptimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=cfg.training.weight_decay)
        dataloader = DataLoader(dataset, batch_size=cfg.training.edgebatch_size, shuffle=True)
        edge_loss = 0
        all_preds = []
        all_labels = []
        for batch_edges, batch_labels in dataloader:
            n_id1 = batch_edges[:, 0].to(cfg.training.device)
            n_id2 = batch_edges[:, 1].to(cfg.training.device)
            x_texts1 = [data.x_text[i] for i in n_id1]
            x_texts2 = [data.x_text[i] for i in n_id2]
            batch_labels = batch_labels.to(cfg.training.device)

            # Random endpoints used as negatives for the contrastive term.
            neg_n_id1 = torch.randint(0, cfg.gnn.vocab_size, size=(len(n_id1),))
            x_texts3 = [data.x_text[i] for i in neg_n_id1]

            pred = model.edge_predict(data.x, data.train_edge_index, n_id1, n_id2, x_texts1, x_texts2,
                                      warmup,
                                      cfg.training.seqlen,
                                      cfg.training.fixed_length,
                                      data.y,
                                      cfg.dataset.label_num,
                                      cfg.training.user_cross)
            loss = F.binary_cross_entropy_with_logits(pred, batch_labels)
            negpred = model.edge_predict(data.x, data.train_edge_index, n_id1, neg_n_id1, x_texts1, x_texts3,
                                         warmup,
                                         cfg.training.seqlen,
                                         cfg.training.fixed_length,
                                         data.y,
                                         cfg.dataset.label_num,
                                         cfg.training.user_cross
                                         )
            # Two-way softmax (real pair vs. random pair); it only contributes
            # for samples whose label is non-zero.
            logits = torch.cat([pred.unsqueeze(1), negpred.unsqueeze(1)], dim=1) / 0.07
            labels = torch.zeros(logits.size(0), dtype=torch.long).to(logits.device)
            consloss = (F.cross_entropy(logits, labels, reduction='none') * batch_labels).mean()
            loss += consloss * 0.5
            edgeoptimizer.zero_grad()
            loss.backward()
            edgeoptimizer.step()
            edge_loss += loss.item()
            pred = pred.sigmoid()
            all_preds.append(pred.detach().cpu())
            all_labels.append(batch_labels.detach().cpu())
        all_preds = torch.cat(all_preds).numpy()
        all_labels = torch.cat(all_labels).numpy()
        # Only the thresholds are used: their mean is the binarisation cut-off.
        _, _, thresholds = precision_recall_curve(all_labels, all_preds)
        edge_loss = edge_loss / len(dataloader)
        edge_auc = roc_auc_score(all_labels, all_preds)
        edge_ap = average_precision_score(all_labels, all_preds)
        pred_bin = (all_preds > thresholds.mean()).astype(np.int32)
        edge_acc = accuracy_score(all_labels, pred_bin)
        edge_f1 = f1_score(all_labels, pred_bin, average='macro')
        node_acc, node_f1, node_loss = 0, 0, 0
    return node_acc, edge_auc, node_f1, edge_ap, edge_acc, edge_f1, node_loss, edge_loss


def test(model, data, cfg, warmup, mode):
    """Evaluate ``model`` on a held-out split without tracking gradients.

    ``cfg.testing.task == "node"`` runs node classification; any other value
    runs link prediction. Metrics belonging to the task that was *not* run
    are returned as ``0``.

    Args:
        model: Module providing ``node_predict`` / ``edge_predict``; its
            ``cross_fc`` attribute is read when ``cfg.model.node_embed_type``
            is not ``"id"``.
        data: Graph data object. The node task reads ``edge_index``,
            ``val_mask`` / ``test_mask``, ``x``, ``y``, ``x_text``,
            ``x_prompts`` (and ``label`` when source and target datasets
            differ). The edge task reads ``{val,test}_edge_label_index``,
            ``{val,test}_edge_label``, ``{val,test}_edge_index``, ``x``,
            ``y`` and ``x_text``.
        cfg: Experiment configuration (``testing``, ``training``, ``dataset``,
            ``gnn`` and ``model`` sections are accessed).
        warmup: Forwarded unchanged to the model's predict methods.
        mode: ``'val'`` selects the validation split. For the node task every
            other value selects the test mask. For the edge task the labelled
            edges come from the validation split only when ``mode == 'val'``,
            while the graph passed to the model is ``data.test_edge_index``
            only when ``mode == 'test'``.

    Returns:
        Tuple ``(node_acc, edge_auc, node_f1, edge_ap, edge_acc, edge_f1,
        node_loss, edge_loss, wrong_node_id, wrong_label, wrong_pred)``.
        The three ``wrong_*`` entries are parallel lists describing the
        misclassified nodes for the node task and ``None`` for the edge task.
    """
    model.eval()
    wrong_node_id, wrong_label, wrong_pred = None, None, None
    if cfg.testing.task == "node":
        all_correct = 0
        total = 0
        node_loss = 0
        node_criterion = torch.nn.CrossEntropyLoss()
        dataset = ListDataset(list(range(cfg.gnn.target_vocab_size)))
        dataloader = DataLoader(dataset, batch_size=cfg.training.batch_size, shuffle=True)
        preds = []
        labels = []
        filtered_n_ids = []
        if cfg.dataset.source_name == cfg.dataset.target_name:
            text_label = None
        else:
            text_label = data.label
        use_id_embedding = cfg.model.node_embed_type == "id"
        squeeze_targets = cfg.dataset.target_name in ["ogbn-products", "ogbn-arxiv"]
        with torch.no_grad():
            # The graph connectivity is identical for every batch, so move it once.
            # NOTE(review): the node task hard-codes "cuda" while the edge task
            # uses cfg.training.device - confirm whether these should be unified.
            edge_index = data.edge_index.to(torch.long).to("cuda")
            for n_id in dataloader:
                n_id = n_id.to("cuda")
                if mode == 'val':
                    mask = data.val_mask[n_id]
                else:
                    mask = data.test_mask[n_id]
                if mask.sum() == 0:
                    # No node of this batch belongs to the evaluated split.
                    continue
                filtered_n_id = n_id[mask]
                # One device->host transfer instead of one per element.
                node_ids = filtered_n_id.tolist()
                x_text = [data.x_text[i] for i in node_ids]
                x_prompts = [data.x_prompts[i] for i in node_ids]
                if use_id_embedding:
                    # The model falls back to its own embeddings when x is None.
                    x = None
                elif model.cross_fc is not None:
                    x = model.cross_fc(data.x)
                else:
                    x = data.x
                predict, cons_loss = model.node_predict(x, x_text, x_prompts, filtered_n_id, edge_index,
                                                        warmup,
                                                        data.y[filtered_n_id], cfg.training.reg,
                                                        cfg.training.seqlen, cfg.training.temperature,
                                                        cfg.training.neg_k,
                                                        cfg.training.fixed_length,
                                                        cfg.training.user_cross,
                                                        text_label
                                                        )
                # Record ids for *both* embedding modes: they must stay aligned
                # with `labels` / `preds` so the wrong_* lookup below is valid.
                filtered_n_ids.extend(node_ids)
                target_filtered = data.y[filtered_n_id]
                if squeeze_targets:
                    target_filtered = target_filtered.squeeze()
                pred = predict.argmax(dim=1)
                preds.extend(pred.cpu().tolist())
                labels.extend(target_filtered.cpu().tolist())
                all_correct += (pred == target_filtered).sum().item()
                total += len(target_filtered)
                loss = node_criterion(predict, target_filtered)
                torch.cuda.empty_cache()
                node_loss += loss.item()
                node_loss += cons_loss
        # Averaged over all batches of the loader, including skipped ones.
        node_loss = node_loss / len(dataloader)
        node_acc = all_correct / total if total > 0 else 0
        # Same guard as node_acc: nothing to score when no node was evaluated.
        node_f1 = f1_score(labels, preds, average='macro') if total > 0 else 0
        incorrect_indices = [i for i, (label, pred) in enumerate(zip(labels, preds)) if label != pred]
        edge_auc, edge_ap, edge_acc, edge_f1, edge_loss = 0, 0, 0, 0, 0
        wrong_node_id = [filtered_n_ids[i] for i in incorrect_indices]
        wrong_label = [labels[i] for i in incorrect_indices]
        wrong_pred = [preds[i] for i in incorrect_indices]
    else:
        with torch.no_grad():
            if mode == 'val':
                dataset = EdgeDataset(data.val_edge_label_index.t(), data.val_edge_label, device=cfg.training.device)
            else:
                dataset = EdgeDataset(data.test_edge_label_index.t(), data.test_edge_label, device=cfg.training.device)
            # Graph handed to the model; only an explicit 'test' uses the test graph.
            eval_edge_index = data.test_edge_index if mode == 'test' else data.val_edge_index
            dataloader = DataLoader(dataset, batch_size=cfg.training.edgebatch_size, shuffle=True)
            all_preds = []
            all_labels = []
            edge_loss = 0
            for batch_edges, batch_labels in dataloader:
                n_id1 = batch_edges[:, 0].to(cfg.training.device)
                n_id2 = batch_edges[:, 1].to(cfg.training.device)
                x_texts1 = [data.x_text[i] for i in n_id1.tolist()]
                x_texts2 = [data.x_text[i] for i in n_id2.tolist()]
                batch_labels = batch_labels.to(cfg.training.device)
                pred = model.edge_predict(data.x, eval_edge_index, n_id1, n_id2, x_texts1, x_texts2,
                                          warmup,
                                          cfg.training.seqlen,
                                          cfg.training.fixed_length,
                                          data.y,
                                          cfg.dataset.target_label_num,
                                          cfg.training.user_cross
                                          )

                # `pred` holds raw logits here; probabilities are only needed for metrics.
                loss = F.binary_cross_entropy_with_logits(pred, batch_labels)
                pred = pred.sigmoid()
                edge_loss += loss.item()
                all_preds.append(pred.detach().cpu())
                all_labels.append(batch_labels.detach().cpu())
            all_preds = torch.cat(all_preds).numpy()
            all_labels = torch.cat(all_labels).numpy()
            _, _, thresholds = precision_recall_curve(all_labels, all_preds)
            edge_auc = roc_auc_score(all_labels, all_preds)
            edge_ap = average_precision_score(all_labels, all_preds)
            # Binarise at the mean of the PR-curve thresholds (not a fixed 0.5).
            pred_bin = (all_preds > thresholds.mean()).astype(np.int32)
            edge_acc = accuracy_score(all_labels, pred_bin)
            edge_f1 = f1_score(all_labels, pred_bin, average='macro')
            edge_loss /= len(dataloader)
            node_acc, node_f1, node_loss = 0, 0, 0
    return node_acc, edge_auc, node_f1, edge_ap, edge_acc, edge_f1, node_loss, edge_loss, wrong_node_id, wrong_label, wrong_pred