import argparse
import itertools
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling
import numpy as np
import torch_geometric.transforms as T
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from logger import Logger, ProductionLogger
from utils import get_dataset, do_edge_split
from models import MLP, GCN, SAGE, LinkPredictor
from torch_sparse import SparseTensor
from sklearn.metrics import *
from os.path import exists
from torch_cluster import random_walk
from torch.nn.functional import cosine_similarity
import torch_geometric
from train_teacher_gnn import test_transductive, test_production
from utils import analyze
import scipy.sparse as ssp
from torch_geometric.data import Data
from get_heuristic import CN

# Root directory that read_data() loads from: files are read from
# dir_path + '/dataset' + '/<data_name>/...' (presumably a HeaRT checkout).
# NOTE(review): hard-coded, user-specific absolute path -- adjust before running elsewhere.
dir_path  = '/home/qinzongyue/HeaRT/'

import matplotlib.pyplot as plt
import seaborn as sns

def compute_jaccard_similarity(method_outputs):
    """
    Compute the pairwise Jaccard similarity of several binary method outputs.

    Each output is read as the set of positions that hold a non-zero value. For two
    methods A and B the similarity is |A & B| / |A | B|, and 0.0 when both sets are empty.

    Args:
        method_outputs (dict): Maps a method name to its output tensor. All tensors must
            have the same shape (they are stacked); non-zero entries count as "on".

    Returns:
        torch.Tensor: (num_methods, num_methods) float32 Jaccard matrix on the CPU.
        list: Method names in the row/column order of the matrix.
    """
    names = list(method_outputs.keys())
    stacked = torch.stack(list(method_outputs.values()))
    count = len(names)

    # One row of 0/1 floats per method. The counts below are integers, which float32
    # represents exactly as long as each output has fewer than 2**24 entries.
    flags = stacked.bool().reshape(count, stacked.numel() // count).float().cpu()
    overlap = flags @ flags.T                                        # |A & B|
    sizes = flags.sum(dim=1)                                         # |A|
    combined = sizes.unsqueeze(1) + sizes.unsqueeze(0) - overlap     # |A | B|

    # Pairs with an empty union get 0.0 instead of 0/0.
    similarity = torch.where(combined > 0, overlap / combined, torch.zeros_like(overlap))
    return similarity, names

def compute_cosine_similarity(method_outputs):
    """
    Compute the pairwise cosine-similarity matrix of several method outputs.

    Each output is flattened to one vector. All-zero outputs have no direction, so their
    similarity to everything (including themselves) is reported as 0 instead of NaN.

    Args:
        method_outputs (dict): Maps a method name to its output tensor. All tensors must
            have the same shape (they are stacked). Integer/bool outputs are promoted to float.

    Returns:
        torch.Tensor: (num_methods, num_methods) cosine-similarity matrix.
        list: Method names in the row/column order of the matrix.
    """
    method_names = list(method_outputs.keys())
    method_values = torch.stack(list(method_outputs.values()))

    # One row per method, regardless of each output's own shape.
    flat = method_values.reshape(len(method_names), method_values.numel() // len(method_names))
    if not torch.is_floating_point(flat):
        # norm() is undefined for integer/bool tensors (e.g. 0/1 hit indicators).
        flat = flat.float()

    # F.normalize clamps the norm at a tiny eps, so zero rows stay zero rather than NaN.
    method_values_normalized = F.normalize(flat, p=2.0, dim=1)

    similarity_matrix = torch.mm(method_values_normalized, method_values_normalized.T)
    return similarity_matrix, method_names

def plot_heatmap(similarity_matrix, method_names, title="Cosine Similarity Heatmap"):
    """
    Draw an annotated heatmap of a square similarity matrix on a new figure.

    The figure is created but neither shown nor saved; call ``plt.show()`` or
    ``plt.savefig(...)`` afterwards.

    Args:
        similarity_matrix (torch.Tensor): Square similarity matrix; may live on any device
            or be tracked by autograd.
        method_names (list): Labels for the rows/columns, in matrix order.
        title (str): Plot title. The default keeps the original cosine title; pass
            e.g. "Jaccard Similarity Heatmap" for compute_jaccard_similarity output.

    Returns:
        matplotlib.axes.Axes: The axes that hold the heatmap.
    """
    plt.figure(figsize=(10, 8))
    # .numpy() only works on CPU tensors that do not require grad.
    values = similarity_matrix.detach().cpu().numpy()
    ax = sns.heatmap(values, xticklabels=method_names, yticklabels=method_names, annot=True, cmap="coolwarm")
    plt.title(title)
    return ax


def read_data(data_name, neg_mode):
    """
    Load a link-prediction split for *data_name* from ``dir_path + '/dataset'``.

    Reads tab-separated ``<sub>\\t<obj>`` edge lists for the train/valid/test positives and
    the valid/test negatives. It also builds a symmetric adjacency of the training edges and
    loads the node features.

    Args:
        data_name (str): Dataset directory name under ``<dir_path>/dataset``.
        neg_mode (str): ``'equal'`` reads ``<split>_{pos,neg}.txt`` from the dataset
            directory; ``'all'`` reads them from its ``allneg/`` subdirectory.

    Returns:
        dict with keys
            'A': scipy CSR (num_nodes x num_nodes) adjacency holding each training edge in
                both directions;
            'adj': the same edges as a torch_sparse ``SparseTensor``;
            'train_pos': (E, 2) LongTensor of training edges, self-loops removed;
            'train_val': random rows of 'train_pos', as many as 'valid_pos' has;
            'valid_pos', 'valid_neg', 'test_pos', 'test_neg': tensors of (sub, obj) rows
                (self-loops removed from positives only);
            'x': the 'entity_embedding' entry of the torch-saved '<data_name>/gnn_feature' file.

    Raises:
        ValueError: if *neg_mode* is neither 'equal' nor 'all'.
    """
    base_dir = dir_path + '/dataset'
    if neg_mode == 'equal':
        split_dir = '/{}/'.format(data_name)
    elif neg_mode == 'all':
        split_dir = '/{}/allneg/'.format(data_name)
    else:
        raise ValueError("neg_mode must be 'equal' or 'all', got {!r}".format(neg_mode))

    def edge_pairs(split, kind):
        """Yield (sub, obj) int pairs from '<split>_<kind>.txt'; the file is closed once exhausted."""
        with open(base_dir + split_dir + '{}_{}.txt'.format(split, kind), 'r') as handle:
            for line in handle:
                sub, obj = line.strip().split('\t')
                yield int(sub), int(obj)

    # Positive edges: every endpoint counts toward num_nodes, but self-loops are not kept.
    node_set = set()
    pos_edges = {'train': [], 'valid': [], 'test': []}
    for split in ['train', 'test', 'valid']:
        for sub, obj in edge_pairs(split, 'pos'):
            node_set.add(sub)
            node_set.add(obj)
            if sub != obj:
                pos_edges[split].append((sub, obj))

    # NOTE(review): assumes node IDs are exactly 0..num_nodes-1 and that every node appears
    # in some positive edge; otherwise the adjacency shape below is wrong.
    num_nodes = len(node_set)
    print('the number of nodes in ' + data_name + ' is: ', num_nodes)

    # Negative edges are kept verbatim (self-loops are not filtered).
    neg_edges = {split: list(edge_pairs(split, 'neg')) for split in ['test', 'valid']}

    train_pos_tensor = torch.tensor(pos_edges['train'])
    # Symmetrise: each training edge is stored in both directions.
    train_edge = train_pos_tensor.t()
    edge_index = torch.cat((train_edge, train_edge[[1, 0]]), dim=1)
    edge_weight = torch.ones(edge_index.size(1))

    A = ssp.csr_matrix((edge_weight.view(-1), (edge_index[0], edge_index[1])), shape=(num_nodes, num_nodes))
    adj = SparseTensor.from_edge_index(edge_index, edge_weight, [num_nodes, num_nodes])

    valid_pos = torch.tensor(pos_edges['valid'])
    valid_neg = torch.tensor(neg_edges['valid'])
    test_pos = torch.tensor(pos_edges['test'])
    test_neg = torch.tensor(neg_edges['test'])

    # Random subset of training edges, the same size as the validation positives.
    idx = torch.randperm(train_pos_tensor.size(0))[:valid_pos.size(0)]
    train_val = train_pos_tensor[idx]

    feature_embeddings = torch.load(base_dir + '/{}/{}'.format(data_name, 'gnn_feature'))
    feature_embeddings = feature_embeddings['entity_embedding']

    return {
        'A': A,
        'adj': adj,
        'train_pos': train_pos_tensor,
        'train_val': train_val,
        'valid_pos': valid_pos,
        'valid_neg': valid_neg,
        'test_pos': test_pos,
        'test_neg': test_neg,
        'x': feature_embeddings,
    }



def cosine_loss(s, t):
    """
    Return ``1 - mean cosine similarity`` between student *s* and teacher *t* along the last dim.

    The teacher is detached, so no gradient flows into *t*.
    """
    target = t.detach()
    per_row = cosine_similarity(s, target, dim=-1)
    return 1 - per_row.mean()

def kl_loss(s, t, T):
    """
    Temperature-scaled distillation loss KL(softmax(t / T) || softmax(s / T)).

    Args:
        s: student logits; the softmax is taken over the last dimension.
        t: teacher logits with the same shape as *s* (not detached here).
        T: softmax temperature; the result is multiplied by ``T ** 2``.

    Returns:
        Scalar tensor: the KL divergence summed over all elements, times ``T ** 2``,
        divided by ``s.size(0)`` (the batch size when *s* is 2-D).
    """
    y_s = F.log_softmax(s / T, dim=-1)
    y_t = F.softmax(t / T, dim=-1)
    # reduction='sum' is the non-deprecated equivalent of size_average=False.
    return F.kl_div(y_s, y_t, reduction='sum') * (T ** 2) / y_s.size(0)

def neighbor_samplers(row, col, sample, x, step, ps_method, ns_rate, hops, device='cuda'):
    """
    Sample nearby (positive) and uniformly random (negative) nodes for each start node.

    Args:
        row, col: edge index passed to ``torch_cluster.random_walk``.
        sample: 1-D tensor of start node IDs.
        x: node feature matrix; only ``x.size(0)`` (the node count) is used.
        step: number of walks for 'nb', or walk-length multiplier for 'rw'.
        ps_method: 'rw' -- one random walk of length ``step * hops``;
            'nb' -- ``step`` independent walks of length ``hops`` whose non-start nodes are
            concatenated after the start column of the first walk.
        ns_rate: number of negatives drawn per positive position.
        hops: walk length per step.
        device: exactly the string 'cuda' moves both outputs to the default CUDA device;
            any other value returns them where they were created.

    Returns:
        (pos_batch, neg_batch): pos_batch holds the start node in column 0 followed by
        ``step * hops`` walked nodes; neg_batch holds ``step * hops * ns_rate`` random node
        IDs per start node (created on the CPU).

    Raises:
        ValueError: for an unknown *ps_method*, or 'nb' with ``step < 1``.
    """
    start = sample

    if ps_method == 'rw':
        pos_batch = random_walk(row, col, start, walk_length=step * hops, coalesced=False)
    elif ps_method == 'nb':
        if step < 1:
            raise ValueError("ps_method 'nb' needs step >= 1, got {}".format(step))
        walks = [random_walk(row, col, start, walk_length=hops, coalesced=False) for _ in range(step)]
        # Keep the start column only once (from the first walk); concatenate once instead
        # of re-copying the growing tensor on every step.
        pos_batch = torch.cat([walks[0]] + [walk[:, 1:] for walk in walks[1:]], 1)
    else:
        raise ValueError("ps_method must be 'rw' or 'nb', got {!r}".format(ps_method))

    neg_batch = torch.randint(0, x.size(0), (start.numel(), step * hops * ns_rate),
                              dtype=torch.long)

    if device == 'cuda':
        return pos_batch.to("cuda"), neg_batch.to("cuda")
    return pos_batch, neg_batch

def train_minibatch(model, predictor, t_pos_pred, t_neg_pred, pos_nb, neg_nb, data, split_edge, optimizer, args, device):
    """
    Run one epoch of mini-batch student training with optional LLP distillation.

    Unlike ``train``, the encoder only sees the feature rows each batch needs: the sampled
    nodes of one node batch (when LLP is on) plus both endpoints of the positive and
    negative training links. Those rows are gathered from ``data.x`` and moved to *device*.

    The batch loss is ``True_label * BCE`` on training links vs. sampled negatives. When
    ``args.LLP_D`` or ``args.LLP_R`` is non-zero it adds ``LLP_D * KL`` between the student
    and teacher score distributions over the sampled nodes, and ``LLP_R * margin-ranking``
    loss on the teacher's pairwise orderings of those nodes.

    Args:
        model: student encoder.
        predictor: link predictor scoring pairs of embeddings.
        t_pos_pred, t_neg_pred: teacher scores per sample row. Column c of t_pos_pred is
            assumed to score column c + 1 of pos_nb, and column c of t_neg_pred column c
            of neg_nb (the indexing below relies on this -- confirm with the caller).
        pos_nb, neg_nb: pre-sampled node IDs per row; column 0 of pos_nb is the anchor.
        data: graph with features ``x`` and ``adj_t`` (transductive) or ``edge_index``.
        split_edge: ``split_edge['train']['edge']`` holds training links as rows
            (transductive mode only).
        optimizer: optimizer stepped once per link batch.
        args: parsed arguments (see parse_args).
        device: device the encoder/predictor run on.

    Returns:
        float: loss averaged over positive training links.
    """
    if args.transductive == "transductive":
        pos_train_edge = split_edge['train']['edge']
        row, col = data.adj_t
    else:
        pos_train_edge = data.edge_index.t()
        row, col = data.edge_index
    edge_index = torch.stack([col, row], dim=0)

    model.train()
    predictor.train()

    bce_loss = nn.BCELoss()
    margin_rank_loss = nn.MarginRankingLoss(margin=args.margin)
    use_llp = bool(args.LLP_D or args.LLP_R)

    num_nodes = data.x.size(0)
    x_device = data.x.device
    pos_cnt = t_pos_pred.size(1)
    neg_cnt = t_neg_pred.size(1)
    # Sampled columns per anchor: rw_step * hops positives and ns_rate times as many negatives.
    num_pos = args.rw_step * args.hops
    num_neg = num_pos * args.ns_rate
    num_samples = num_pos + num_neg
    # Every (i, j), i < j, pair of sampled columns, for the rank-based loss.
    dim_pairs = np.array(list(itertools.combinations(range(num_samples), 2))).T if use_llp else None

    def new_node_loader():
        return iter(DataLoader(range(num_nodes), args.node_batch_size, shuffle=True))

    total_loss = total_examples = 0
    node_loader = new_node_loader()
    for link_perm in DataLoader(range(pos_train_edge.size(0)), args.link_batch_size, shuffle=True):
        optimizer.zero_grad()

        edge = pos_train_edge[link_perm].t()
        if args.datasets != "collab":
            neg_edge = negative_sampling(edge_index, num_nodes=num_nodes,
                                         num_neg_samples=link_perm.size(0), method='dense')
        else:
            neg_edge = torch.randint(0, num_nodes, [edge.size(0), edge.size(1)], dtype=torch.long)
        train_edges = torch.cat((edge, neg_edge), dim=-1)
        num_links = train_edges.size(1)
        # Link endpoints in encoding order: all sources, then all targets.
        link_nodes = train_edges.reshape(-1).to(x_device)

        if use_llp:
            # There can be more link batches than node batches: start a new pass over the
            # nodes instead of letting StopIteration escape.
            node_perm = next(node_loader, None)
            if node_perm is None:
                node_loader = new_node_loader()
                node_perm = next(node_loader)

            pos_idx = torch.randint(pos_cnt, (num_pos,))
            neg_idx = torch.randint(neg_cnt, (num_neg,))
            batch_pos = pos_nb[node_perm]
            # Columns: anchor, sampled positives, sampled negatives.
            samples = torch.cat((batch_pos[:, 0:1], batch_pos[:, pos_idx + 1],
                                 neg_nb[node_perm][:, neg_idx]), 1)
            batch_size, width = samples.shape
            sample_nodes = samples.reshape(-1).to(x_device)
        else:
            sample_nodes = link_nodes.new_empty(0)
        n_sampled = sample_nodes.numel()

        h = model(data.x[torch.cat((sample_nodes, link_nodes), 0)].to(device))
        src_h = h[n_sampled:n_sampled + num_links]
        dst_h = h[n_sampled + num_links:]

        if use_llp:
            ### distribution-based matching loss
            sample_h = h[:n_sampled].reshape(batch_size, width, h.size(1))
            anchor_h = sample_h[:, 0:1, :].repeat(1, num_samples, 1)
            s_r = predictor(anchor_h, sample_h[:, 1:, :])
            t_r = torch.cat((t_pos_pred[node_perm][:, pos_idx],
                             t_neg_pred[node_perm][:, neg_idx]), dim=1).to(device)
            llp_d_loss = kl_loss(s_r.reshape(s_r.size(0), s_r.size(1)), t_r, 1)

            ### rank-based matching loss: +1 / -1 where the teacher separates a pair by more
            ### than the margin, 0 otherwise.
            first, second = dim_pairs
            teacher_rank = torch.zeros((len(t_r), dim_pairs.shape[1], 1), device=device)
            teacher_rank[t_r[:, first] > (t_r[:, second] + args.margin)] = 1
            teacher_rank[t_r[:, first] < (t_r[:, second] - args.margin)] = -1
            llp_r_loss = margin_rank_loss(s_r[:, first].squeeze(), s_r[:, second].squeeze(),
                                          teacher_rank.squeeze())

        train_label = torch.cat((torch.ones(edge.size(1)), torch.zeros(neg_edge.size(1))), dim=0).to(h.device)
        out = predictor(src_h, dst_h).squeeze()
        label_loss = bce_loss(out, train_label)

        loss = args.True_label * label_loss
        if use_llp:
            loss = loss + args.LLP_D * llp_d_loss + args.LLP_R * llp_r_loss

        loss.backward()

        torch.nn.utils.clip_grad_norm_(data.x, 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = edge.size(1)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples


def get_samples(model, predictor, data, split_edge,
                         args, device):
    """
    Pre-sample positive (walk-based) and random negative nodes for LLP distillation.

    One node batch is drawn per link batch of the training edges. For each node batch,
    ``neighbor_samplers`` is called with ``100 * args.rw_step`` steps (presumably a large
    pool that training later subsamples). Only sampling happens here; the original full-graph
    forward pass, whose output was never used, is gone.

    Args:
        model, predictor: switched to training mode; otherwise unused.
        data: graph with features ``x`` and ``adj_t`` (transductive) or ``edge_index``.
        split_edge: ``split_edge['train']['edge']`` holds training links as rows
            (transductive mode only).
        args: parsed arguments (see parse_args).
        device: unused; kept for interface compatibility.

    Returns:
        (list, list): per-batch positive and negative sample tensors on the CPU. Both lists
        are empty when neither ``args.LLP_R`` nor ``args.LLP_D`` is enabled.
    """
    if args.transductive == "transductive":
        pos_train_edge = split_edge['train']['edge'].to(data.x.device)
        row, col = data.adj_t
    else:
        pos_train_edge = data.edge_index.t()
        row, col = data.edge_index

    model.train()
    predictor.train()

    all_pos_samples = []
    all_neg_samples = []
    if not (args.LLP_R or args.LLP_D):
        return all_pos_samples, all_neg_samples

    num_nodes = data.x.size(0)
    sample_step = args.rw_step
    # Same number of batches as iterating the training links in link_batch_size chunks.
    num_link_batches = (pos_train_edge.size(0) + args.link_batch_size - 1) // args.link_batch_size

    def new_node_loader():
        return iter(DataLoader(range(num_nodes), args.node_batch_size, shuffle=True))

    node_loader = new_node_loader()
    for _ in range(num_link_batches):
        # More link batches than node batches: start a new pass over the nodes instead of
        # letting StopIteration escape.
        node_perm = next(node_loader, None)
        if node_perm is None:
            node_loader = new_node_loader()
            node_perm = next(node_loader)
        node_perm = node_perm.to(data.x.device)

        pos_sample, neg_sample = neighbor_samplers(row, col, node_perm, data.x, 100 * sample_step,
                                                   args.ps_method, args.ns_rate, args.hops, 'cpu')
        all_pos_samples.append(pos_sample.cpu())
        all_neg_samples.append(neg_sample.cpu())

    return all_pos_samples, all_neg_samples


def train(model, predictor, t_pos_pred, t_neg_pred, pos_nb, neg_nb, data, split_edge,
                         optimizer, args, device):
    """
    Run one epoch of full-graph student training with optional LLP distillation.

    Every link batch re-encodes all nodes with ``model(data.x)``. The batch loss is
    ``True_label * BCE`` on positive training links vs. sampled negatives. When
    ``args.LLP_D`` or ``args.LLP_R`` is non-zero it adds ``LLP_D * KL`` between the student
    and teacher score distributions over sampled nodes, and ``LLP_R * margin-ranking`` loss
    on the teacher's pairwise orderings of those nodes.

    Args:
        model: student encoder.
        predictor: link predictor scoring pairs of embeddings.
        t_pos_pred, t_neg_pred: teacher scores per sample row. Column c of t_pos_pred is
            assumed to score column c + 1 of pos_nb, and column c of t_neg_pred column c
            of neg_nb (the indexing below relies on this -- confirm with the caller).
        pos_nb, neg_nb: pre-sampled node IDs per row; column 0 of pos_nb is the anchor.
            NOTE(review): they are indexed with node_perm on data.x's device, so presumably
            they must live on that device too.
        data: graph with features ``x`` and ``adj_t`` (transductive) or ``edge_index``.
        split_edge: ``split_edge['train']['edge']`` holds training links as rows
            (transductive mode only).
        optimizer: optimizer stepped once per link batch.
        args: parsed arguments (see parse_args).
        device: unused; the student runs wherever ``data.x`` and the modules live.

    Returns:
        float: loss averaged over positive training links.
    """
    if args.transductive == "transductive":
        pos_train_edge = split_edge['train']['edge'].to(data.x.device)
        row, col = data.adj_t
    else:
        pos_train_edge = data.edge_index.t()
        row, col = data.edge_index
    edge_index = torch.stack([col, row], dim=0)

    model.train()
    predictor.train()

    bce_loss = nn.BCELoss()
    margin_rank_loss = nn.MarginRankingLoss(margin=args.margin)
    use_llp = bool(args.LLP_R or args.LLP_D)

    num_nodes = data.x.size(0)
    pos_cnt = t_pos_pred.size(1)
    neg_cnt = t_neg_pred.size(1)
    # Sampled columns per anchor: rw_step * hops positives and ns_rate times as many negatives.
    num_pos = args.rw_step * args.hops
    num_neg = num_pos * args.ns_rate
    num_samples = num_pos + num_neg
    # Every (i, j), i < j, pair of sampled columns; loop-invariant, so built once.
    dim_pairs = np.array(list(itertools.combinations(range(num_samples), 2))).T if use_llp else None

    def new_node_loader():
        return iter(DataLoader(range(num_nodes), args.node_batch_size, shuffle=True))

    total_loss = total_examples = 0
    node_loader = new_node_loader()
    for link_perm in DataLoader(range(pos_train_edge.size(0)), args.link_batch_size, shuffle=True):
        optimizer.zero_grad()

        if use_llp:
            # More link batches than node batches: start a new pass over the nodes instead
            # of letting StopIteration escape.
            node_perm = next(node_loader, None)
            if node_perm is None:
                node_loader = new_node_loader()
                node_perm = next(node_loader)
            node_perm = node_perm.to(data.x.device)

        h = model(data.x)
        edge = pos_train_edge[link_perm].t()

        if use_llp:
            pos_idx = torch.randint(pos_cnt, (num_pos,))
            neg_idx = torch.randint(neg_cnt, (num_neg,))
            # Columns: anchor, sampled positives, sampled negatives; then keep this node batch.
            samples = torch.cat((pos_nb[:, 0:1], pos_nb[:, pos_idx + 1], neg_nb[:, neg_idx]), 1)
            samples = samples[node_perm]

            ### distribution-based matching loss
            anchor_h = h[samples[:, 0]].unsqueeze(1).repeat(1, num_samples, 1)
            s_r = predictor(anchor_h, h[samples[:, 1:]])
            # Teacher scores may be stored elsewhere; compare them on the student's device.
            t_r = torch.cat((t_pos_pred[node_perm][:, pos_idx],
                             t_neg_pred[node_perm][:, neg_idx]), dim=1).to(h.device)
            llp_d_loss = kl_loss(s_r.reshape(s_r.size(0), s_r.size(1)), t_r, 1)

            ### rank-based matching loss: +1 / -1 where the teacher separates a pair by more
            ### than the margin, 0 otherwise.
            first, second = dim_pairs
            teacher_rank = torch.zeros((len(t_r), dim_pairs.shape[1], 1), device=t_r.device)
            teacher_rank[t_r[:, first] > (t_r[:, second] + args.margin)] = 1
            teacher_rank[t_r[:, first] < (t_r[:, second] - args.margin)] = -1
            llp_r_loss = margin_rank_loss(s_r[:, first].squeeze(), s_r[:, second].squeeze(),
                                          teacher_rank.squeeze())

        if args.datasets != "collab":
            neg_edge = negative_sampling(edge_index, num_nodes=num_nodes,
                                         num_neg_samples=link_perm.size(0), method='dense')
        else:
            neg_edge = torch.randint(0, num_nodes, [edge.size(0), edge.size(1)], dtype=torch.long, device=h.device)

        ### true-label loss
        train_edges = torch.cat((edge, neg_edge), dim=-1)
        train_label = torch.cat((torch.ones(edge.size(1)), torch.zeros(neg_edge.size(1))), dim=0).to(h.device)
        out = predictor(h[train_edges[0]], h[train_edges[1]]).squeeze()
        label_loss = bce_loss(out, train_label)

        loss = args.True_label * label_loss
        if use_llp:
            loss = loss + args.LLP_D * llp_d_loss + args.LLP_R * llp_r_loss

        loss.backward()

        torch.nn.utils.clip_grad_norm_(data.x, 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = edge.size(1)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples

def parse_args(argv=None):
    """
    Parse command-line options for student training / evaluation.

    Args:
        argv (list[str] | None): Arguments to parse; ``None`` (the default) reads ``sys.argv``.

    Returns:
        argparse.Namespace with the options below.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    parser.add_argument('--distill_pred_path', type=str, default='None')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--log_steps', type=int, default=1)
    parser.add_argument('--encoder', type=str, default='sage')
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--link_batch_size', type=int, default=64*1024)
    parser.add_argument('--node_batch_size', type=int, default=64*1024)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--epochs', type=int, default=20000)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--runs', type=int, default=10)
    parser.add_argument('--dataset_dir', type=str, default='../data')
    parser.add_argument('--datasets', type=str, default='collab')
    parser.add_argument('--predictor', type=str, default='mlp', choices=['inner','mlp'])
    parser.add_argument('--patience', type=int, default=100, help='number of patience steps for early stopping')
    # The default and the values main() assigns are capitalised ('Hits@20' / 'Hits@50'),
    # so accept them alongside the original lower-case spellings.
    parser.add_argument('--metric', type=str, default='Hits@20',
                        choices=['auc', 'hits@20', 'hits@50', 'Hits@20', 'Hits@50'],
                        help='main evaluation metric')
    parser.add_argument('--use_valedges_as_input', action='store_true')
    parser.add_argument('--True_label', default=0.1, type=float, help="true_label loss")
    parser.add_argument('--KD_RM', default=0, type=float, help="Representation-based matching KD")
    parser.add_argument('--KD_LM', default=0, type=float, help="logit-based matching KD")
    parser.add_argument('--LLP_D', default=1, type=float, help="distribution-based matching kd")
    parser.add_argument('--LLP_R', default=1, type=float, help="rank-based matching kd")
    parser.add_argument('--margin', default=0.1, type=float, help="margin for rank-based kd")
    parser.add_argument('--rw_step', type=int, default=3, help="nearby nodes sampled times")
    parser.add_argument('--ns_rate', type=int, default=1, help="randomly sampled rate over # nearby nodes")
    parser.add_argument('--hops', type=int, default=2, help="random_walk step for each sampling time")
    parser.add_argument('--ps_method', type=str, default='nb', help="positive sampling is rw or nb")
    parser.add_argument('--transductive', type=str, default='transductive', choices=['transductive', 'production'])
    parser.add_argument('--minibatch', action='store_true')

    return parser.parse_args(argv)


def _per_edge_metrics(pos_pred, neg_pred, ks):
    """Return a (len(ks), num_pos) float tensor of per-positive-edge metrics.

    For k = ks[i] > 0, row i is a 0/1 indicator of whether each positive score beats the
    k-th highest negative score (OGB-style Hits@k). For k == 0, row i is each positive's
    reciprocal rank among the negatives, with tied negatives counted as half.
    """
    rows = []
    for k in ks:
        if k > 0:
            if neg_pred.numel() < k:
                # Fewer than k negatives: every positive is trivially in the top k (the OGB
                # evaluator's convention), instead of crashing inside topk.
                rows.append(torch.ones_like(pos_pred, dtype=torch.float))
            else:
                kth_neg = torch.topk(neg_pred, k)[0][-1]
                rows.append((pos_pred > kth_neg).float())
        else:
            # Negated scores sorted ascending: searchsorted counts negatives scoring strictly
            # higher (right=False) or at least as high (right=True) as each positive.
            sorted_neg, _ = torch.sort(-neg_pred)
            n_higher = torch.searchsorted(sorted_neg, -pos_pred, right=False)
            n_higher_or_tied = torch.searchsorted(sorted_neg, -pos_pred, right=True)
            rank = (n_higher + n_higher_or_tied) / 2 + 1
            rows.append(1 / rank.float())
    return torch.stack(rows, dim=0)


def _print_cn_bucket(header, mask, avg_hits, ks, sep=''):
    """Print *header*, then each metric of *avg_hits* averaged over the test edges selected by *mask*."""
    print(header)
    count = mask.float().sum().int().item()
    for i, k in enumerate(ks):
        print(f'HITS@{k}: {avg_hits[i][mask].mean().item()}{sep}({count})')


def main(args, states, dataset = None):
    """
    Evaluate saved student MLP checkpoints and break per-edge test metrics down by CN score.

    For every checkpoint in *states*, the student encoder/predictor are loaded and scored
    with ``test_transductive`` on the 'equal' split returned by ``read_data``. For each
    positive test edge, the reciprocal rank and Hits@10/20/50 indicators are recorded and
    averaged over runs. The averages are then printed for test edges whose CN score is low
    (<= 30th percentile), medium, or high (> 70th percentile).

    Args:
        args: parsed arguments (see parse_args); ``args.datasets`` (when *dataset* is
            given), ``args.runs`` and ``args.metric`` are overwritten.
        states: non-empty sequence of dicts with 'model' and 'predictor' state dicts,
            one per run.
        dataset: optional dataset name overriding ``args.datasets``.

    Returns:
        torch.Tensor: reciprocal rank of every positive test edge, averaged over runs.

    Raises:
        ValueError: if *states* is empty.
        NotImplementedError: if ``args.transductive`` is not "transductive" (that path
            never produced the per-edge scores this analysis needs).
    """
    if dataset is not None:
        args.datasets = dataset
    if len(states) == 0:
        raise ValueError("main() needs at least one saved student state")
    args.runs = len(states)
    print(args)

    if args.transductive != "transductive":
        raise NotImplementedError(
            "per-edge CN analysis is only implemented for --transductive transductive")

    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    # Main metric: Hits@50 for collab, Hits@20 for every other dataset.
    args.metric = 'Hits@50' if args.datasets == "collab" else 'Hits@20'

    heart_data = read_data(args.datasets, 'equal')
    data = Data(x=heart_data['x'], adj_t=heart_data['train_pos'].T).to(device)
    input_size = data.x.size()[1]
    split_edge = {
        'train': {'edge': heart_data['train_pos']},
        'valid': {'edge': heart_data['valid_pos'], 'edge_neg': heart_data['valid_neg']},
        'test': {'edge': heart_data['test_pos'], 'edge_neg': heart_data['test_neg']},
    }
    evaluator = Evaluator(name='ogbl-ddi')

    model = MLP(args.num_layers, input_size, args.hidden_channels, args.hidden_channels, args.dropout).to(device)
    predictor = LinkPredictor(args.predictor, args.hidden_channels, args.hidden_channels, 1,
                              args.num_layers, args.dropout).to(device)

    # CN heuristic score (presumably common-neighbour count) of each positive test edge on
    # the training adjacency; edges are bucketed at its 30th and 70th percentiles.
    cn = CN(heart_data['A'], heart_data['test_pos'].T)
    print(cn.min(), cn.max(), cn.mean())
    low_thres = torch.quantile(cn, 0.3)
    high_thres = torch.quantile(cn, 0.7)

    valid_pos_num = heart_data['valid_pos'].size(0)
    valid_neg_num = heart_data['valid_neg'].size(0)
    # k == 0 stands for the reciprocal rank; the other entries are Hits@k cut-offs.
    ks = [0, 10, 20, 50]

    all_hits = []
    for run, state in enumerate(states):
        model.load_state_dict(state['model'])
        predictor.load_state_dict(state['predictor'])
        torch_geometric.seed.seed_everything(run + 1)

        _, _, pos_pred, _, neg_pred, _ = test_transductive(
            model, predictor, data, split_edge, evaluator, args.link_batch_size,
            'mlp', args.datasets, args, detail=True)
        # NOTE(review): assumes test_transductive returns validation predictions followed by
        # test predictions, so the test part starts after the validation count -- confirm.
        pos_pred_test = pos_pred[valid_pos_num:]
        neg_pred_test = neg_pred[valid_neg_num:]
        all_hits.append(_per_edge_metrics(pos_pred_test, neg_pred_test, ks))

    all_hits = torch.stack(all_hits, dim=0)
    print(all_hits.size())
    avg_hits = torch.mean(all_hits, dim=0)

    _print_cn_bucket(f"For low CN pairs (thres={low_thres}", cn <= low_thres, avg_hits, ks, sep=' ')
    _print_cn_bucket("For median CN pairs", torch.logical_and(cn > low_thres, cn <= high_thres), avg_hits, ks)
    _print_cn_bucket(f"For high CN pairs (thres={high_thres}", cn > high_thres, avg_hits, ks)

    return avg_hits[0]


select_key = 'Hits@20'  # NOTE(review): not referenced anywhere in this file.

# Students to compare: 'none' loads '<dataset>_mlp_model.pth'; every other entry loads
# '<dataset>_<teacher>_MLPs.pth' from saved_students/.
DISTILL_TEACHERS = ['none', 'GCN', 'SAGE', 'GAT', 'random_GCN', 'random_SAGE', 'random_GAT',
                    'CN', 'AA', 'RA', 'shortest_path', 'katz_close']


def _rename_legacy_mlp_keys(state):
    """Rename a checkpoint's 'model' state-dict keys from 'lins*' to 'layers*'; return *state*.

    Only the leading 'lins' of each key is replaced. The 'predictor' state dict keeps its
    keys and is copied into a plain dict. *state* is modified in place.
    """
    state['model'] = {
        (key.replace("lins", "layers", 1) if key.startswith("lins") else key): value
        for key, value in state['model'].items()
    }
    state['predictor'] = dict(state['predictor'].items())
    return state


if __name__ == "__main__":
    args = parse_args()
    for dataset in [args.datasets]:
        teacher_hits_dict = {}
        for distill_teacher in DISTILL_TEACHERS:
            if distill_teacher == 'none':
                # The baseline checkpoints store encoder weights under 'lins.*'.
                states = [_rename_legacy_mlp_keys(state)
                          for state in torch.load(f'saved_students/{dataset}_mlp_model.pth')]
            else:
                states = torch.load(f'saved_students/{dataset}_{distill_teacher}_MLPs.pth')
            print(dataset, distill_teacher)

            hits = main(args, states, dataset)
            # main() returns each test edge's reciprocal rank averaged over runs; > 0.5 means
            # rank < 2, i.e. the edge is (essentially) ranked above every negative.
            teacher_hits_dict[distill_teacher] = (hits > 0.5).float()

        # Jaccard overlap of the per-edge "ranked first" sets across students.
        # plot_heatmap(sim_mat, method_names) can visualise the result.
        sim_mat, method_names = compute_jaccard_similarity(teacher_hits_dict)
        print(method_names)
        print(sim_mat)
