import os
import json
import numpy as np
import copy
import torch
import random
from torch_geometric.utils import negative_sampling
from tqdm import tqdm

from utils import UnionFindSet, get_bfs_sub_graph, get_dfs_sub_graph,negative_sampling_from_node_pairs
from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader

def tensor_to_serializable(obj):
    """Recursively replace ``torch.Tensor`` values with plain Python lists.

    Dicts and lists are rebuilt with their members converted; any other
    object (tuples included) is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        converted = obj.tolist()
    elif isinstance(obj, dict):
        converted = {key: tensor_to_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        converted = list(map(tensor_to_serializable, obj))
    else:
        converted = obj
    return converted


class GNN_DATA:
    def __init__(self, ppi_path, exclude_protein_path=None, max_len=2000, skip_head=True, p1_index=0, p2_index=1, label_index=2, graph_undirection=True, bigger_ppi_path=None):
        """Load a tab-separated PPI file into node ids, edges and multi-hot labels.

        Args:
            ppi_path: TSV file with one interaction per line.
            exclude_protein_path: optional JSON file listing proteins; any row
                of ``ppi_path`` touching one of them is dropped.
            max_len: sequence length used later by ``embed_normal``.
            skip_head: when true, the first line of ``ppi_path`` is skipped.
            p1_index, p2_index, label_index: column positions of the two
                protein names and the interaction type.
            graph_undirection: when true, every edge is appended a second
                time with its endpoints reversed (sharing the same label list).
            bigger_ppi_path: optional second TSV merged into the same
                dictionaries. Its first line is always skipped and the
                exclusion list is not applied to it.

        Raises:
            KeyError: if an interaction type is not one of the seven known
                classes.
        """
        self.ppi_list = []
        self.ppi_dict = {}
        self.ppi_label_list = []
        self.protein_dict = {}
        self.protein_name = {}
        self.ppi_path = ppi_path
        self.bigger_ppi_path = bigger_ppi_path
        self.max_len = max_len

        self.node_num = 0
        self.edge_num = 0

        if exclude_protein_path is not None:
            with open(exclude_protein_path, 'r') as f:
                ex_protein = set(json.load(f))
        else:
            ex_protein = set()

        class_map = {'reaction':0, 'binding':1, 'ptmod':2, 'activation':3, 'inhibition':4, 'catalysis':5, 'expression':6}

        self._load_ppi_file(ppi_path, skip_head, p1_index, p2_index, label_index, class_map, ex_protein)

        if bigger_ppi_path is not None:
            # The supplementary file always has its header skipped and is not
            # filtered by the exclusion list.
            self._load_ppi_file(bigger_ppi_path, True, p1_index, p2_index, label_index, class_map, set())

        # Edge ids were handed out in insertion order, so the dict order must
        # match the label list order.
        for i, ppi in enumerate(tqdm(self.ppi_dict.keys())):
            assert self.ppi_dict[ppi] == i
            self.ppi_list.append(ppi.strip().split('__'))

        ppi_num = len(self.ppi_list)
        # Keep the name-based pairs before they are replaced by integer ids.
        self.origin_ppi_list = copy.deepcopy(self.ppi_list)
        assert len(self.ppi_list) == len(self.ppi_label_list)
        for pair in tqdm(self.ppi_list):
            pair[0] = self.protein_name[pair[0]]
            pair[1] = self.protein_name[pair[1]]

        if graph_undirection:
            for i in tqdm(range(ppi_num)):
                self.ppi_list.append(self.ppi_list[i][::-1])
                # The reversed edge shares the label list object of the original.
                self.ppi_label_list.append(self.ppi_label_list[i])

        self.node_num = len(self.protein_name)
        self.edge_num = len(self.ppi_list)

    def _load_ppi_file(self, path, skip_head, p1_index, p2_index, label_index, class_map, excluded):
        """Merge one TSV interaction file into the name, edge and label tables."""
        with open(path) as f:
            for line in tqdm(f):
                if skip_head:
                    skip_head = False
                    continue
                if not line.strip():
                    # Blank lines carry no interaction.
                    continue
                fields = line.strip().split('\t')
                p1, p2 = fields[p1_index], fields[p2_index]

                if p1 in excluded or p2 in excluded:
                    continue

                # Node ids are assigned in order of first appearance.
                for protein in (p1, p2):
                    if protein not in self.protein_name:
                        self.protein_name[protein] = len(self.protein_name)

                label_pos = class_map[fields[label_index]]

                # The edge key is order-independent: smaller name first.
                if p1 < p2:
                    key = p1 + "__" + p2
                else:
                    key = p2 + "__" + p1

                if key not in self.ppi_dict:
                    self.ppi_dict[key] = len(self.ppi_dict)
                    self.ppi_label_list.append([0] * len(class_map))
                self.ppi_label_list[self.ppi_dict[key]][label_pos] = 1
    
    def get_protein_aac(self, pseq_path):
        """Read protein amino-acid sequences from a two-column TSV file.

        Each line is ``name<TAB>sequence``; the first occurrence of a name
        wins. Fills ``self.pseq_dict`` (name -> sequence) and
        ``self.protein_len`` (sequence lengths), then prints summary stats.
        """
        self.pseq_path = pseq_path
        self.pseq_dict = {}
        self.protein_len = []

        with open(self.pseq_path) as f:
            for line in tqdm(f):
                fields = line.strip().split('\t')
                if fields[0] not in self.pseq_dict:
                    self.pseq_dict[fields[0]] = fields[1]
                    self.protein_len.append(len(fields[1]))

        print("protein num: {}".format(len(self.pseq_dict)))
        if self.protein_len:
            # np.max / np.min raise on an empty list, so only report stats
            # when at least one sequence was read.
            print("protein average length: {}".format(np.average(self.protein_len)))
            print("protein max & min length: {}, {}".format(np.max(self.protein_len), np.min(self.protein_len)))

    def embed_normal(self, seq, dim):
        """Truncate or zero-pad ``seq`` so it has exactly ``self.max_len`` rows.

        ``dim`` is the width of the zero rows appended when padding. A
        sequence that already has the target length is returned as is.
        """
        target = self.max_len
        current = len(seq)
        if current == target:
            return seq
        if current > target:
            return seq[:target]
        padding = np.zeros((target - current, dim))
        return np.concatenate((seq, padding))

    def vectorize(self, vec_path):
        """Turn every loaded protein sequence into a fixed-length matrix.

        ``vec_path`` is a TSV file whose lines are
        ``acid<TAB>space-separated floats``. The parsed table is stored in
        ``self.acid2vec`` and the width of its first vector in ``self.dim``.
        Each sequence in ``self.pseq_dict`` is then mapped acid by acid,
        passed through ``embed_normal`` and stored in ``self.pvec_dict``.

        Raises:
            KeyError: if a sequence contains an acid missing from the table.
        """
        self.acid2vec = {}
        self.dim = None
        with open(vec_path) as f:
            for line in f:
                fields = line.strip().split('\t')
                acid_vec = np.array([float(x) for x in fields[1].split()])
                self.acid2vec[fields[0]] = acid_vec
                if self.dim is None:
                    self.dim = len(acid_vec)
        print("acid vector dimension: {}".format(self.dim))

        self.pvec_dict = {}

        for p_name in tqdm(self.pseq_dict.keys()):
            temp_vec = np.array([self.acid2vec[acid] for acid in self.pseq_dict[p_name]])
            self.pvec_dict[p_name] = self.embed_normal(temp_vec, self.dim)

    def get_feature_origin(self, pseq_path, vec_path):
        """Load sequences, vectorize them, and collect features per graph node.

        After this call ``self.protein_dict`` maps every name in
        ``self.protein_name`` to its entry in ``self.pvec_dict``.
        """
        self.get_protein_aac(pseq_path)
        self.vectorize(vec_path)

        features = self.protein_dict = {}
        features.update(
            (protein, self.pvec_dict[protein])
            for protein in tqdm(self.protein_name.keys())
        )

    def get_connected_num(self):
        """Build ``self.ufs`` by uniting the two endpoints of every edge."""
        self.ufs = UnionFindSet(self.node_num)
        edge_array = np.array(self.ppi_list)
        for start, end in edge_array:
            self.ufs.union(start, end)

    def generate_data(self):
        """Assemble the ``Data`` graph object from edges, labels and node features.

        Sets ``self.edge_index``, ``self.edge_attr``, ``self.x`` and
        ``self.data``. Node features are stacked in node-id order; the labels
        are attached under the ``edge_attr_1`` key.
        """
        self.get_connected_num()

        print("Connected domain num: {}".format(self.ufs.count))

        self.edge_index = torch.tensor(np.array(self.ppi_list), dtype=torch.long)
        self.edge_attr = torch.tensor(np.array(self.ppi_label_list), dtype=torch.long)

        node_features = self.x = []
        for expected_id, (name, node_id) in enumerate(self.protein_name.items()):
            # Names must have been registered in id order.
            assert node_id == expected_id
            node_features.append(self.protein_dict[name])

        self.x = np.array(self.x)
        self.x = torch.tensor(self.x, dtype=torch.float)

        self.data = Data(x=self.x, edge_index=self.edge_index.T, edge_attr_1=self.edge_attr)
    
    def split_dataset(self, train_valid_index_path, test_size=0.2, random_new=False, mode='random'):
        """Create or load the train/valid edge-index split.

        With ``random_new`` false the split is read from the JSON file at
        ``train_valid_index_path``. Otherwise a new split over the first half
        of the edge list is produced according to ``mode`` ('random', 'bfs'
        or 'dfs'), stored in ``self.ppi_split_dict`` under 'train_index' and
        'valid_index', and written to that file. An unknown mode prints a
        message and returns without changing anything.
        """
        if not random_new:
            with open(train_valid_index_path, 'r') as handle:
                self.ppi_split_dict = json.load(handle)
            return

        if mode not in ('random', 'bfs', 'dfs'):
            print("your mode is {}, you should use bfs, dfs or random".format(mode))
            return

        if mode == 'random':
            ppi_num = int(self.edge_num // 2)
            shuffled = list(range(ppi_num))
            random.shuffle(shuffled)

            cut = int(ppi_num * (1-test_size))
            self.ppi_split_dict = {
                'train_index': shuffled[:cut],
                'valid_index': shuffled[cut:],
            }
        else:
            print("use {} methed split train and valid dataset".format(mode))
            edge_num = int(self.edge_num // 2)

            # Map every node to the indices of the edges touching it.
            node_to_edge_index = {}
            for idx in range(edge_num):
                first, second = self.ppi_list[idx][0], self.ppi_list[idx][1]
                node_to_edge_index.setdefault(first, []).append(idx)
                node_to_edge_index.setdefault(second, []).append(idx)

            node_num = len(node_to_edge_index)
            sub_graph_size = int(edge_num * test_size)

            pick_sub_graph = get_bfs_sub_graph if mode == 'bfs' else get_dfs_sub_graph
            selected_edge_index = pick_sub_graph(self.ppi_list, node_num, node_to_edge_index, sub_graph_size)

            all_edge_index = list(range(edge_num))
            unselected_edge_index = list(set(all_edge_index).difference(set(selected_edge_index)))

            self.ppi_split_dict = {
                'train_index': unselected_edge_index,
                'valid_index': selected_edge_index,
            }

            assert len(unselected_edge_index) + len(selected_edge_index) == edge_num

        with open(train_valid_index_path, 'w') as handle:
            handle.write(json.dumps(self.ppi_split_dict))



class GNN_DATA_Binding:
    def __init__(self, ppi_path, exclude_protein_path=None, max_len=2000, skip_head=True, p1_index=0, p2_index=1, label_index=2, graph_undirection=True, bigger_ppi_path=None):
        """Load the positive (binding) interactions of a TSV file as a graph.

        Args:
            ppi_path: TSV file with one interaction per line. When the path
                contains ``'yeast'``, rows whose label column equals ``'0'``
                are dropped; otherwise only rows whose label column equals
                ``'binding'`` are kept.
            exclude_protein_path: accepted for signature compatibility;
                not used by this class.
            max_len: sequence length used later by ``embed_normal``.
            skip_head: when true, the first line of ``ppi_path`` is skipped.
            p1_index, p2_index, label_index: column positions of the two
                protein names and the label.
            graph_undirection: when true, every edge is appended a second
                time with its endpoints reversed.
            bigger_ppi_path: stored on the instance but not read here.

        Every kept edge gets the label ``[1]``.
        """
        self.ppi_list = []
        self.ppi_dict = {}
        self.ppi_label_list = []
        self.protein_dict = {}
        self.protein_name = {}
        self.ppi_path = ppi_path
        self.bigger_ppi_path = bigger_ppi_path
        self.max_len = max_len

        self.node_num = 0
        self.edge_num = 0

        is_yeast = 'yeast' in ppi_path
        with open(ppi_path) as f:
            for line in tqdm(f):
                if skip_head:
                    skip_head = False
                    continue
                if not line.strip():
                    # Blank lines carry no interaction.
                    continue
                fields = line.strip().split('\t')

                # Keep only positive rows; the criterion depends on the dataset.
                label = fields[label_index]
                if is_yeast:
                    if label == '0':
                        continue
                elif label != 'binding':
                    continue

                p1, p2 = fields[p1_index], fields[p2_index]

                # Node ids are assigned in order of first appearance.
                for protein in (p1, p2):
                    if protein not in self.protein_name:
                        self.protein_name[protein] = len(self.protein_name)

                # The edge key is order-independent: smaller name first.
                if p1 < p2:
                    key = p1 + "__" + p2
                else:
                    key = p2 + "__" + p1

                if key not in self.ppi_dict:
                    self.ppi_dict[key] = len(self.ppi_dict)
                    self.ppi_label_list.append([1])

        # Edge ids were handed out in insertion order, so the dict order must
        # match the label list order.
        for i, ppi in enumerate(tqdm(self.ppi_dict.keys())):
            assert self.ppi_dict[ppi] == i
            self.ppi_list.append(ppi.strip().split('__'))

        ppi_num = len(self.ppi_list)
        # Keep the name-based pairs before they are replaced by integer ids.
        self.origin_ppi_list = copy.deepcopy(self.ppi_list)
        assert len(self.ppi_list) == len(self.ppi_label_list)
        for pair in tqdm(self.ppi_list):
            pair[0] = self.protein_name[pair[0]]
            pair[1] = self.protein_name[pair[1]]

        if graph_undirection:
            for i in tqdm(range(ppi_num)):
                self.ppi_list.append(self.ppi_list[i][::-1])
                # The reversed edge shares the label list object of the original.
                self.ppi_label_list.append(self.ppi_label_list[i])

        self.node_num = len(self.protein_name)
        self.edge_num = len(self.ppi_list)
        # Report sizes instead of dumping the whole edge list to stdout.
        print("node num: {}, edge num: {}".format(self.node_num, self.edge_num))
    
    def get_protein_aac(self, pseq_path):
        """Read protein amino-acid sequences from a two-column TSV file.

        Each line is ``name<TAB>sequence``; the first occurrence of a name
        wins. Fills ``self.pseq_dict`` (name -> sequence) and
        ``self.protein_len`` (sequence lengths), then prints summary stats.
        """
        self.pseq_path = pseq_path
        self.pseq_dict = {}
        self.protein_len = []

        with open(self.pseq_path) as f:
            for line in tqdm(f):
                fields = line.strip().split('\t')
                if fields[0] not in self.pseq_dict:
                    self.pseq_dict[fields[0]] = fields[1]
                    self.protein_len.append(len(fields[1]))

        print("protein num: {}".format(len(self.pseq_dict)))
        if self.protein_len:
            # np.max / np.min raise on an empty list, so only report stats
            # when at least one sequence was read.
            print("protein average length: {}".format(np.average(self.protein_len)))
            print("protein max & min length: {}, {}".format(np.max(self.protein_len), np.min(self.protein_len)))

    def embed_normal(self, seq, dim):
        """Return ``seq`` cut or zero-padded to exactly ``self.max_len`` rows.

        Padding rows have width ``dim``. When the length already matches,
        the input object itself is returned.
        """
        missing_rows = self.max_len - len(seq)
        if missing_rows < 0:
            return seq[:self.max_len]
        if missing_rows > 0:
            zero_rows = np.zeros((missing_rows, dim))
            return np.concatenate((seq, zero_rows))
        return seq

    def vectorize(self, vec_path):
        """Turn every loaded protein sequence into a fixed-length matrix.

        ``vec_path`` is a TSV file whose lines are
        ``acid<TAB>space-separated floats``. The parsed table is stored in
        ``self.acid2vec`` and the width of its first vector in ``self.dim``.
        Each sequence in ``self.pseq_dict`` is then mapped acid by acid,
        passed through ``embed_normal`` and stored in ``self.pvec_dict``.

        Raises:
            KeyError: if a sequence contains an acid missing from the table.
        """
        self.acid2vec = {}
        self.dim = None
        with open(vec_path) as f:
            for line in f:
                fields = line.strip().split('\t')
                acid_vec = np.array([float(x) for x in fields[1].split()])
                self.acid2vec[fields[0]] = acid_vec
                if self.dim is None:
                    self.dim = len(acid_vec)
        print("acid vector dimension: {}".format(self.dim))

        self.pvec_dict = {}

        for p_name in tqdm(self.pseq_dict.keys()):
            temp_vec = np.array([self.acid2vec[acid] for acid in self.pseq_dict[p_name]])
            self.pvec_dict[p_name] = self.embed_normal(temp_vec, self.dim)

    def get_feature_origin(self, pseq_path, vec_path):
        """Run sequence loading and vectorization, then index features by node name.

        ``self.protein_dict`` ends up holding, for each name in
        ``self.protein_name``, the matching entry of ``self.pvec_dict``.
        """
        self.get_protein_aac(pseq_path)
        self.vectorize(vec_path)

        per_node = self.protein_dict = {}
        per_node.update(
            (node_name, self.pvec_dict[node_name])
            for node_name in tqdm(self.protein_name.keys())
        )

    def get_connected_num(self):
        """Create ``self.ufs`` and union the endpoints of each edge in ``self.ppi_list``."""
        self.ufs = UnionFindSet(self.node_num)
        for source, target in np.array(self.ppi_list):
            self.ufs.union(source, target)

    def generate_data(self):
        """Assemble the ``Data`` graph object from edges, labels and node features.

        Sets ``self.edge_index``, ``self.edge_attr``, ``self.x`` and
        ``self.data``. Node features are stacked in node-id order; the labels
        are attached under the ``edge_attr_1`` key.
        """
        self.get_connected_num()

        print("Connected domain num: {}".format(self.ufs.count))

        edge_array = np.array(self.ppi_list)
        label_array = np.array(self.ppi_label_list)
        self.edge_index = torch.tensor(edge_array, dtype=torch.long)
        self.edge_attr = torch.tensor(label_array, dtype=torch.long)

        stacked = self.x = []
        for position, (name, node_id) in enumerate(self.protein_name.items()):
            # Names must have been registered in id order.
            assert node_id == position
            stacked.append(self.protein_dict[name])

        self.x = np.array(self.x)
        self.x = torch.tensor(self.x, dtype=torch.float)

        self.data = Data(x=self.x, edge_index=self.edge_index.T, edge_attr_1=self.edge_attr)
    
    def split_dataset(self, train_valid_index_path, test_size=0.3, random_new=False, mode='random', num_neg_samples=20000):
        """Split edges into train / non-train sets and build labelled pair sets.

        Requires ``generate_data`` to have been called (reads ``self.data``).

        Args:
            train_valid_index_path: JSON file. Read when ``random_new`` is
                false; always (re)written at the end with the full split.
            test_size: fraction of edges held out when creating a new split.
            random_new: create a new split instead of loading one.
            mode: 'random', 'bfs' or 'dfs'. An unknown mode with
                ``random_new`` true prints a message and returns.
            num_neg_samples: number of negative edges requested from
                ``negative_sampling``.

        Non-train edges are grouped by how many of their endpoints occur in
        a training edge: both seen (``bs``), either seen (``es``) or neither
        seen (``ns``). Sampled negative pairs are grouped the same way and
        added until each group has as many negatives as positives (or the
        sample runs out); further both-seen negatives become training
        negatives. The ``bs`` group is only produced when ``mode`` is
        'random'. Results are stored in ``self.ppi_split_dict`` under the
        ``*_pair_index`` / ``*_pair_label`` keys.
        """
        graph = self.data

        if random_new:
            if mode == 'random':
                ppi_num = int((self.edge_num)/2)
                random_list = list(range(ppi_num))
                random.shuffle(random_list)

                cut = int(ppi_num * (1 - test_size))
                train_index = random_list[:cut]
                non_train_index = random_list[cut:]

            elif mode == 'bfs' or mode == 'dfs':
                print("use {} methed split train and valid dataset".format(mode))
                node_to_edge_index = {}
                edge_num = int(self.edge_num // 2)
                for i in range(edge_num):
                    edge = self.ppi_list[i]
                    node_to_edge_index.setdefault(edge[0], []).append(i)
                    node_to_edge_index.setdefault(edge[1], []).append(i)

                node_num = len(node_to_edge_index)

                sub_graph_size = int(edge_num * test_size)
                if mode == 'bfs':
                    non_train_index = get_bfs_sub_graph(self.ppi_list, node_num, node_to_edge_index, sub_graph_size)
                else:
                    non_train_index = get_dfs_sub_graph(self.ppi_list, node_num, node_to_edge_index, sub_graph_size)

                all_edge_index = [i for i in range(edge_num)]
                train_index = list(set(all_edge_index).difference(set(non_train_index)))
                assert len(train_index) + len(non_train_index) == edge_num

            else:
                print("your mode is {}, you should use bfs, dfs or random".format(mode))
                return

            self.ppi_split_dict = {
                'train_index': train_index,
                'non_train_index': non_train_index,
            }
        else:
            with open(train_valid_index_path, 'r') as f:
                self.ppi_split_dict = json.load(f)
            train_index = self.ppi_split_dict['train_index']
            non_train_index = self.ppi_split_dict.get('non_train_index')
            if non_train_index is None:
                # Files written by GNN_DATA.split_dataset store the held-out
                # edges under 'valid_index'.
                non_train_index = self.ppi_split_dict['valid_index']

        graph.train_mask = train_index

        edge_index_train = graph.edge_index[:, train_index]
        train_pos_count = edge_index_train.shape[1]
        edge_index_non_train = graph.edge_index[:, non_train_index]

        # One [source, target] pair per edge column.
        ppi_list = graph.edge_index.transpose(0, 1).tolist()

        truth_edge_num = len(ppi_list) // 2
        node_num = graph.x.shape[0]
        # NOTE(review): the label says "non_train" but the value is half of
        # all edge columns; kept as in the original output.
        print('non_train_edge_num ',truth_edge_num )

        # 1 = node occurs in at least one training edge, 0 = it does not.
        node_vision_dict = {}
        for index in train_index:
            source, target = ppi_list[index]
            node_vision_dict[source] = 1
            node_vision_dict[target] = 1
        for node_id in range(node_num):
            node_vision_dict.setdefault(node_id, 0)

        vision_num = sum(1 for seen in node_vision_dict.values() if seen == 1)
        unvision_num = sum(1 for seen in node_vision_dict.values() if seen == 0)
        print("vision node num: {}, unvision node num: {}".format(vision_num, unvision_num))

        neg_edge_index = negative_sampling(
            edge_index=graph.edge_index,
            num_nodes=node_num,
            num_neg_samples=num_neg_samples,
        )
        ppi_list_non_train = edge_index_non_train.transpose(0, 1).tolist()
        ppi_list_neg = neg_edge_index.transpose(0, 1).tolist()

        # Group held-out positives by the number of seen endpoints.
        bs_pair_index = []  # both seen
        es_pair_index = []  # either seen
        ns_pair_index = []  # neither seen
        for pair in ppi_list_non_train:
            seen_count = node_vision_dict.get(pair[0], 0) + node_vision_dict.get(pair[1], 0)
            if seen_count == 2:
                bs_pair_index.append(pair)
            elif seen_count == 1:
                es_pair_index.append(pair)
            elif seen_count == 0:
                ns_pair_index.append(pair)

        if mode == 'random':
            bs_pos = self._pairs_to_edge_index(bs_pair_index)
            bs_pos_count = bs_pos.shape[1]
        else:
            bs_pos_count = 0
        es_pos = self._pairs_to_edge_index(es_pair_index)
        ns_pos = self._pairs_to_edge_index(ns_pair_index)
        es_pos_count = es_pos.shape[1]
        ns_pos_count = ns_pos.shape[1]

        # Fill each group with negatives up to its positive count; surplus
        # both-seen negatives go to the training set.
        bs_neg = []
        es_neg = []
        ns_neg = []
        train_neg = []
        for pair in ppi_list_neg:
            seen_count = node_vision_dict.get(pair[0], 0) + node_vision_dict.get(pair[1], 0)
            if seen_count == 2:
                if len(bs_neg) < bs_pos_count:
                    bs_neg.append(pair)
                elif len(train_neg) < train_pos_count:
                    train_neg.append(pair)
            elif seen_count == 1:
                if len(es_neg) < es_pos_count:
                    es_neg.append(pair)
            elif seen_count == 0:
                if len(ns_neg) < ns_pos_count:
                    ns_neg.append(pair)

        es_neg = self._pairs_to_edge_index(es_neg)
        ns_neg = self._pairs_to_edge_index(ns_neg)
        train_neg = self._pairs_to_edge_index(train_neg)

        if mode == 'random':
            bs_neg = self._pairs_to_edge_index(bs_neg)
            self.ppi_split_dict['bs_pair_index'] = torch.cat([bs_pos, bs_neg], dim=1)
            self.ppi_split_dict['bs_pair_label'] = self._pair_labels(bs_pos_count, bs_neg.shape[1])
            print(f"BS - Positive samples: {bs_pos_count}, Negative samples: {bs_neg.shape[1]}, Ratio: {self._positive_ratio(bs_pos_count, bs_neg.shape[1]):.3f}")

        self.ppi_split_dict['es_pair_index'] = torch.cat([es_pos, es_neg], dim=1)
        self.ppi_split_dict['ns_pair_index'] = torch.cat([ns_pos, ns_neg], dim=1)
        self.ppi_split_dict['es_pair_label'] = self._pair_labels(es_pos_count, es_neg.shape[1])
        self.ppi_split_dict['ns_pair_label'] = self._pair_labels(ns_pos_count, ns_neg.shape[1])

        self.ppi_split_dict['train_pair_index'] = torch.cat([edge_index_train, train_neg], dim=1)
        self.ppi_split_dict['train_pair_label'] = self._pair_labels(train_pos_count, train_neg.shape[1])

        print("\n=== Data Balance Diagnosis ===")
        print(f"Train - Positive: {train_pos_count}, Negative: {train_neg.shape[1]}, Ratio: {self._positive_ratio(train_pos_count, train_neg.shape[1]):.3f}")
        print(f"ES - Positive: {es_pos_count}, Negative: {es_neg.shape[1]}, Ratio: {self._positive_ratio(es_pos_count, es_neg.shape[1]):.3f}")
        print(f"NS - Positive: {ns_pos_count}, Negative: {ns_neg.shape[1]}, Ratio: {self._positive_ratio(ns_pos_count, ns_neg.shape[1]):.3f}")
        print("===================\n")

        jsobj = json.dumps(tensor_to_serializable(self.ppi_split_dict))
        with open(train_valid_index_path, 'w') as f:
            f.write(jsobj)

    @staticmethod
    def _pairs_to_edge_index(pairs):
        """Convert a list of [source, target] pairs to a (2, n) long tensor; n may be 0."""
        if not pairs:
            # torch.tensor([]) is 1-D and cannot be transposed, so build the
            # empty (2, 0) shape explicitly.
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(pairs, dtype=torch.long).transpose(0, 1)

    @staticmethod
    def _pair_labels(pos_count, neg_count):
        """Return an (n, 1) label column: ``pos_count`` ones followed by ``neg_count`` zeros."""
        labels = torch.cat([torch.ones(pos_count), torch.zeros(neg_count)], dim=0)
        # unsqueeze also works for zero elements, where view(-1, 1) is ambiguous.
        return labels.unsqueeze(1)

    @staticmethod
    def _positive_ratio(pos_count, neg_count):
        """Return pos / (pos + neg), or NaN when both counts are zero."""
        total = pos_count + neg_count
        return pos_count / total if total else float('nan')


    