import os
# Expose only GPU index 4 to this process. This is set before `torch` is
# imported further down, so it is already in the environment when CUDA is
# first touched.
# NOTE(review): hard-coded, machine-specific GPU index — confirm before running
# on another host. With a single visible GPU, presumably only `--device 0`
# is a valid CUDA index for this script; verify.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import sys
# Extend the import path with the HeaRT benchmarking directory.
# NOTE(review): absolute, user-specific path; presumably the project-local
# imports that follow (gnn_model, scoring, get_heuristic, utils) are resolved
# through it — verify.
sys.path.append('/home/qinzongyue/HeaRT/benchmarking/')
from gnn_model import GCN, GAT, SAGE
from scoring import mlp_score
from models import MLP, LinkPredictor
from get_heuristic import *
from utils import get_dataset, do_edge_split, read_data
import torch_geometric.transforms as T
import torch
from torch_geometric.data import Data
import argparse
from os.path import exists
import scipy.sparse as ssp

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

def load_model(dataset, teacher, is_mlp):
    """Unfinished loader for a saved teacher GNN (student branch still empty).

    NOTE(review): this function looks like work in progress and is not called
    anywhere in the visible file. It never returns a model (the intended
    ``return`` at the bottom is commented out), and the GNN branch uses a name
    ``device`` that is neither a parameter nor assigned in this function, so
    it raises ``NameError`` unless a module-level ``device`` exists elsewhere
    — none is visible here.

    Args:
        dataset: Dataset name used to build the checkpoint path.
        teacher: Name of the GNN class to instantiate. It is resolved with
            ``eval``, so it must be a trusted identifier (presumably one of
            the imported ``GCN`` / ``GAT`` / ``SAGE`` — confirm).
        is_mlp: When equal to ``True`` the (empty) student branch is taken;
            otherwise the GNN branch runs.
    """
    if is_mlp == True:
        # load student
        pass
    else:
        # load GNN
        path = f'saved_students/{dataset}_{teacher}.pth'
        states = torch.load(path, map_location='cpu')
        # Only the last entry of the loaded sequence is used.
        state = states[-1]
        # NOTE(review): the whole checkpoint entry is passed on as the input
        # width, and every size below is a zero/None placeholder — presumably
        # these were meant to be derived from the checkpoint; confirm before
        # relying on this function.
        input_channel = state
        hidden_channels = 0 
        num_layers = 0
        num_layers_predictor = 0
        dropout = 0
        gin_mlp_layer = None
        gat_head = 1
        node_num = None
        cat_node_feat_mf = None
        # `eval` turns the teacher name into the class object; never feed it
        # untrusted input.
        model = eval(teacher)(input_channel, hidden_channels,
                    hidden_channels, num_layers, dropout, gin_mlp_layer, gat_head, node_num, cat_node_feat_mf)
    
        # NOTE(review): `device` is undefined in this scope (see docstring).
        score_func = mlp_score(hidden_channels, hidden_channels,
                    1, num_layers_predictor, dropout).to(device)

    # NOTE(review): the return is disabled and names `link_predictor`, which
    # is not defined above (the local is called `score_func`).
    #return model, link_predictor


def get_data(args, device = 'cpu'):
    """Load the graph and its link-prediction edge split for ``args.datasets``.

    Three loading paths exist, chosen by the dataset name:

    * ``collab`` / ``citation2``: OGB ``ogbl-<name>`` dataset and its own split.
    * ``cora`` / ``citeseer`` / ``pubmed``: ``read_data(name, 'equal')``; the
      split is rebuilt from its ``train_pos`` / ``valid_*`` / ``test_*`` entries.
    * anything else: ``get_dataset(args.dataset_dir, name)`` plus a pre-computed
      split stored at ``../data/<name>.pkl``.

    In every path ``data.adj_t`` ends up holding the training edges as a dense
    edge-index tensor (transposed so that row 0 / row 1 are the two endpoints),
    and ``data.full_adj_t`` is that tensor, optionally extended with the
    validation edges.

    Side effect: ``args.metric`` is set ('Hits@20', 'Hits@50' for collab,
    'MRR' for citation2).

    Args:
        args: Parsed options; reads ``transductive``, ``datasets``,
            ``dataset_dir`` and ``use_valedges_as_input``.
        device: Device the returned ``data`` object is moved to.

    Returns:
        ``(data, split_edge)``; ``split_edge`` is indexed as
        ``split_edge[split]['edge']`` / ``['edge_neg']`` and is left on
        whatever device it was loaded on.

    Raises:
        NotImplementedError: for any setting other than ``'transductive'``
            (previously this fell through to an unbound-variable error).
        FileNotFoundError: if the pre-computed split file is missing.
    """
    if args.transductive != "transductive":
        raise NotImplementedError(
            f"get_data only supports the 'transductive' setting, "
            f"got {args.transductive!r}")

    if args.datasets in ["collab", "citation2"]:
        #TODO special treatment for citation2, including data loading and testing
        dataset = PygLinkPropPredDataset(name=f'ogbl-{args.datasets}')
        data = dataset[0]
        # Keep the raw edge index: ToSparseTensor is still applied as before,
        # but the adjacency it stores is overwritten below.
        edge_index = data.edge_index
        if hasattr(data, "edge_weight") and data.edge_weight is not None:
            data.edge_weight = data.edge_weight.view(-1).to(torch.float)
        data = T.ToSparseTensor()(data)

        split_edge = dataset.get_edge_split()
        data.adj_t = edge_index
        args.metric = 'Hits@50' if args.datasets == 'collab' else 'MRR'
    else:
        if args.datasets in ['cora', 'citeseer', 'pubmed']:
            heart_data = read_data(args.datasets, 'equal')
            data = Data(x=heart_data['x'], adj_t=heart_data['train_pos'].T)
            split_edge = {
                'train': {'edge': heart_data['train_pos']},
                'valid': {'edge': heart_data['valid_pos'],
                          'edge_neg': heart_data['valid_neg']},
                'test': {'edge': heart_data['test_pos'],
                         'edge_neg': heart_data['test_neg']},
            }
        else:
            dataset = get_dataset(args.dataset_dir, args.datasets)
            data = dataset[0]

            split_path = "../data/" + args.datasets + ".pkl"
            if not exists(split_path):
                # Explicit error instead of `assert False`, which carries no
                # message and is stripped under `python -O`.
                raise FileNotFoundError(
                    f"pre-computed edge split not found: {split_path}")
            split_edge = torch.load(split_path)
            data.adj_t = split_edge['train']['edge'].t()
        args.metric = 'Hits@20'

    # One device transfer for all paths (the original repeated it up to three
    # times with the same target).
    data = data.to(device)

    if args.use_valedges_as_input:
        # Build the extended edge index from data.adj_t (defined on every
        # path) and actually store it: previously the concatenation was
        # dropped, leaving data.full_adj_t unset for downstream consumers.
        val_edge_index = split_edge['valid']['edge'].t().to(data.adj_t.device)
        data.full_adj_t = torch.cat([data.adj_t, val_edge_index], dim=-1)
    else:
        data.full_adj_t = data.adj_t

    return data, split_edge

def load_mlp(args, teacher, input_size):
    """Build an MLP encoder + link predictor and load their saved checkpoints.

    The returned modules are freshly constructed; no weights are loaded into
    them here. The caller picks one entry of ``states`` and loads it (each
    entry is indexed with the keys ``'model'`` and ``'predictor'``).

    Args:
        args: Parsed options; reads ``datasets``, ``num_layers``,
            ``hidden_channels`` and ``predictor``.
        teacher: ``'none'`` loads the independently trained MLP checkpoints
            (``saved_students/<dataset>_mlp_model.pth``); any other value
            loads the students distilled from that teacher
            (``saved_students/<dataset>_<teacher>_MLPs.pth``, falling back to
            the ``_MLPs_2.pth`` name if the first file does not exist).
        input_size: Width of the node features fed to the MLP.

    Returns:
        ``(model, predictor, states)``.

    Raises:
        FileNotFoundError: if no matching checkpoint file exists.
    """
    dataset = args.datasets
    num_layers = args.num_layers
    hidden_channels = args.hidden_channels

    if teacher == 'none':
        loaded = torch.load(f'saved_students/{dataset}_mlp_model.pth', map_location='cpu')
        states = []
        for state in loaded:
            # These checkpoints name the encoder layers "lins*", while the
            # keys are expected as "layers*" here: rename the leading prefix
            # only (count=1), leaving the rest of the key untouched.
            state['model'] = {
                (key.replace("lins", "layers", 1) if key.startswith("lins") else key): value
                for key, value in state['model'].items()
            }
            # Predictor keys are deliberately kept as saved (the original
            # rename for them was disabled); only a plain-dict copy is made.
            state['predictor'] = dict(state['predictor'])
            states.append(state)
    else:
        try:
            states = torch.load(f'saved_students/{dataset}_{teacher}_MLPs.pth',
                                map_location='cpu')
        except FileNotFoundError:
            # Only a missing file triggers the alternative name; any other
            # failure (corrupt file, unpickling error, ...) now surfaces
            # instead of being masked by a bare `except`.
            states = torch.load(f'saved_students/{dataset}_{teacher}_MLPs_2.pth',
                                map_location='cpu')

    # NOTE(review): the trailing 0 is presumably the dropout rate of the
    # project-local MLP / LinkPredictor constructors — confirm.
    model = MLP(num_layers, input_size, hidden_channels, hidden_channels, 0)
    predictor = LinkPredictor(args.predictor, hidden_channels, hidden_channels, 1,
                              num_layers, 0)
    return model, predictor, states

def get_teacher_prediction(teacher, data, pos_test_edge, neg_test_edge):
    """Score positive and negative test edges with a heuristic teacher.

    Args:
        teacher: One of ``'CN'``, ``'AA'``, ``'RA'``,
            ``'capped_shortest_path'``; the function of that name must be
            available at module level (here it comes from the
            ``get_heuristic`` star import).
        data: Graph object; reads ``data.full_adj_t`` (edge index whose rows
            0 and 1 are the edge endpoints) and ``data.x`` (for the node
            count).
        pos_test_edge: Positive edges, one per row; passed transposed.
        neg_test_edge: Negative edges, one per row; passed transposed.

    Returns:
        ``(pos_test_pred, neg_test_pred)`` exactly as returned by the
        heuristic.

    Raises:
        NotImplementedError: for any other teacher name.
    """
    if teacher not in ('CN', 'AA', 'RA', 'capped_shortest_path'):
        raise NotImplementedError(f'unsupported heuristic teacher: {teacher!r}')

    # Plain module-scope lookup of the whitelisted name instead of eval().
    heuristic = globals()[teacher]

    # scipy needs host memory: move the edge index to the CPU and hand over
    # NumPy arrays so this also works when `data` lives on a GPU.
    edge_index = data.full_adj_t.cpu()
    edge_weight = torch.ones(edge_index.size(1))
    num_nodes = data.x.size(0)
    A = ssp.csr_matrix(
        (edge_weight.numpy(), (edge_index[0].numpy(), edge_index[1].numpy())),
        shape=(num_nodes, num_nodes))

    pos_test_pred = heuristic(A, pos_test_edge.T)
    neg_test_pred = heuristic(A, neg_test_edge.T)
    return pos_test_pred, neg_test_pred

@torch.no_grad()
def get_mlp_prediction(model, predictor, state, data, pos_edge, neg_edge, device, batch_size = 65536):
    """Score positive and negative edges with one saved MLP checkpoint.

    Args:
        model: Node encoder; called as ``model(x)``.
        predictor: Edge scorer; called as ``predictor(h_src, h_dst)``.
        state: Checkpoint with the keys ``'model'`` and ``'predictor'``; both
            are loaded into the given modules (which are modified in place).
        data: Graph object; only ``data.x`` is read.
        pos_edge: Positive edges, one ``(src, dst)`` pair per row.
        neg_edge: Negative edges, same layout.
        device: Device used for the modules and the node features.
        batch_size: Number of edges scored per predictor call.

    Returns:
        ``(pos_pred, neg_pred)``: flat CPU tensors with one score per edge,
        in input order. An empty edge set yields an empty tensor.
    """
    model.load_state_dict(state['model'])
    predictor.load_state_dict(state['predictor'])
    model = model.to(device)
    predictor = predictor.to(device)
    model.eval()
    predictor.eval()

    # Honour `device` (previously hard-coded to "cuda", which broke the CPU
    # fallback and could mismatch the modules' device).
    h = model(data.x.to(device))

    pos_pred = _predict_edges_in_batches(predictor, h, pos_edge, batch_size)
    neg_pred = _predict_edges_in_batches(predictor, h, neg_edge, batch_size)
    return pos_pred, neg_pred


def _predict_edges_in_batches(predictor, h, edges, batch_size):
    """Score ``edges`` (one node pair per row) in consecutive chunks of ``batch_size`` rows."""
    if edges.size(0) == 0:
        return torch.empty(0)

    preds = []
    for start in range(0, edges.size(0), batch_size):
        edge = edges[start:start + batch_size].t().to(h.device)
        # reshape(-1) rather than squeeze(): a final chunk holding a single
        # edge must stay 1-D, otherwise torch.cat rejects the 0-dim tensor.
        preds.append(predictor(h[edge[0]], h[edge[1]]).reshape(-1).cpu())
    return torch.cat(preds, dim=0)


def compute_ratio(teacher_pred, mlp_pred, k=10):
    """Fraction of the teacher's Hits@k positives that the MLP also hits.

    A positive edge is a "hit" for a model when its score is strictly greater
    than that model's k-th highest negative score. The mean hit rate of the
    teacher and then of the MLP is printed as a side effect.

    Args:
        teacher_pred: ``(pos_scores, neg_scores)`` 1-D tensors of the teacher.
        mlp_pred: ``(pos_scores, neg_scores)`` 1-D tensors of the MLP; the
            positive scores must be aligned with the teacher's.
        k: Rank of the negative score used as threshold.

    Returns:
        ``|teacher hits AND mlp hits| / |teacher hits|`` as a float; NaN when
        the teacher has no hit at all (0 / 0).
    """
    pos_t_pred, neg_t_pred = teacher_pred
    pos_s_pred, neg_s_pred = mlp_pred

    def hits_at_k(pos_pred, neg_pred):
        # With fewer than k negatives there is no k-th negative to beat, so
        # every positive counts as a hit (same convention as the OGB
        # Hits@K evaluator) instead of torch.topk raising.
        if neg_pred.numel() < k:
            return torch.ones_like(pos_pred, dtype=torch.float)
        thres = torch.topk(neg_pred, k).values[-1].item()
        return (pos_pred > thres).float()

    hits_t = hits_at_k(pos_t_pred, neg_t_pred)
    print(hits_t.mean())

    hits_s = hits_at_k(pos_s_pred, neg_s_pred)
    print(hits_s.mean())

    ratio = torch.sum(hits_t * hits_s) / torch.sum(hits_t)
    return ratio.item()

def get_overlap_ratio(pred1, pred2, k=3):
    """Jaccard overlap of the Hits@k positive sets of two predictors.

    A positive edge is a "hit" for a predictor when its score is strictly
    greater than that predictor's k-th highest negative score. For ``k > 3``
    the two mean hit rates are printed as a diagnostic.

    Args:
        pred1: ``(pos_scores, neg_scores)`` 1-D tensors of the first predictor.
        pred2: ``(pos_scores, neg_scores)`` 1-D tensors of the second one; the
            positive scores must be aligned with ``pred1``'s.
        k: Rank of the negative score used as threshold.

    Returns:
        ``|hits1 AND hits2| / |hits1 OR hits2|`` as a float; NaN when neither
        predictor has any hit (0 / 0).
    """
    pos_t_pred, neg_t_pred = pred1
    pos_s_pred, neg_s_pred = pred2

    def hits_at_k(pos_pred, neg_pred):
        # With fewer than k negatives there is no k-th negative to beat, so
        # every positive counts as a hit (same convention as the OGB
        # Hits@K evaluator) instead of torch.topk raising.
        if neg_pred.numel() < k:
            return torch.ones_like(pos_pred, dtype=torch.float)
        thres = torch.topk(neg_pred, k).values[-1].item()
        return (pos_pred > thres).float()

    hits_t = hits_at_k(pos_t_pred, neg_t_pred)
    hits_s = hits_at_k(pos_s_pred, neg_s_pred)
    if k > 3:
        print(hits_t.mean())
        print(hits_s.mean())

    union = (hits_s + hits_t) > 0

    ratio = torch.sum(hits_t * hits_s) / torch.sum(union.float())
    return ratio.item()

   



def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the previous behaviour;
            passing a list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='OGBL-DDI (GNN)')
    parser.add_argument('--device', type=int, default=0,
                        help='CUDA device index used when a GPU is available')
    parser.add_argument('--dataset_dir', type=str, default='../data',
                        help='directory passed to get_dataset')
    parser.add_argument('--teacher', type=str, default='CN',
                        help='teacher whose predictions and distilled students are evaluated')
    parser.add_argument('--datasets', type=str, default='collab',
                        help='dataset name')
    parser.add_argument('--transductive', type=str, default='transductive', choices=['transductive', 'production'])
    parser.add_argument('--minibatch', action='store_true')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers of the loaded MLP and predictor')
    parser.add_argument('--hidden_channels', type=int, default=256,
                        help='hidden width of the loaded MLP and predictor')
    parser.add_argument('--use_valedges_as_input', action='store_true',
                        help='append the validation edges to the input graph')
    parser.add_argument('--predictor', type=str, default='mlp', choices=['inner','mlp'])

    return parser.parse_args(argv)


def main():
    """Compare heuristic-teacher predictions with independent and distilled MLPs.

    Steps:
      1. Load the data and the test edges; score them with ``args.teacher``.
      2. For every saved independent MLP and every MLP distilled from that
         teacher, print mean (std) of ``compute_ratio`` against the teacher.
      3. For the first saved student of each heuristic teacher, print the
         pairwise ``get_overlap_ratio`` (default k) for every unordered pair,
         including each teacher with itself.
    """
    # Local imports: `np` was previously only in scope if the
    # `get_heuristic` star import happened to re-export it.
    import numpy as np
    from itertools import combinations_with_replacement

    args = parse_args()
    data, split_edge = get_data(args)
    print(split_edge.keys())
    input_size = data.x.size(1)

    mlp_model, mlp_predictor, mlp_states = load_mlp(args, 'none', input_size)
    student_model, student_predictor, student_states = load_mlp(args, args.teacher, input_size)

    pos_edge = split_edge['test']['edge']
    neg_edge = split_edge['test']['edge_neg']

    # get teacher prediction
    teacher_pred = get_teacher_prediction(args.teacher, data, pos_edge, neg_edge)
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

    def teacher_ratios(model, predictor, states):
        # One compute_ratio value per saved checkpoint.
        return [
            compute_ratio(
                teacher_pred,
                get_mlp_prediction(model, predictor, state, data, pos_edge, neg_edge, device),
            )
            for state in states
        ]

    ratio_list = teacher_ratios(mlp_model, mlp_predictor, mlp_states)
    print(f'independent mlp ratio: {np.mean(ratio_list):.4f} ({np.std(ratio_list):.4f})')

    ratio_list = teacher_ratios(student_model, student_predictor, student_states)
    print(f'distilled mlp ratio: {np.mean(ratio_list):.4f} ({np.std(ratio_list):.4f})')

    teacher_list = ['CN', 'AA', 'RA', 'capped_shortest_path']
    pred_dict = {}
    for teacher in teacher_list:
        model, predictor, states = load_mlp(args, teacher, input_size)
        # Only the first saved student of each teacher is compared.
        pred_dict[teacher] = get_mlp_prediction(model, predictor, states[0], data,
                                                pos_edge, neg_edge, device)

    # Same pair order as the former nested `range(4)` loops, but derived from
    # teacher_list so the two cannot drift apart.
    for first, second in combinations_with_replacement(teacher_list, 2):
        print(first, second)
        # The k=10 call is kept for the hit rates it prints (it prints when
        # k > 3); as before, its return value is not reported.
        # NOTE(review): confirm that discarding this ratio is intended.
        get_overlap_ratio(pred_dict[first], pred_dict[second], k=10)

        print(get_overlap_ratio(pred_dict[first], pred_dict[second]))


if __name__ == '__main__':
    main()
