import torch
import numpy as np
import sys, copy, math, time, pdb
import pickle as pickle
import scipy.io as sio
import scipy.sparse as ssp
import os.path
import random
import argparse
from util_functions import *
from torch_geometric.data import DataLoader
from model import Net

from sklearn.metrics import average_precision_score
import torch.optim as optim
from Evaluate import *

def loop_dataset_gem(data_name, result_dir, classifier, loader, optimizer=None,save_plot=False ):
    """Run one pass of ``classifier`` over ``loader`` and compute metrics.

    When ``optimizer`` is given, each batch's loss is back-propagated and an
    optimizer step is taken (training pass). Otherwise the pass is
    evaluation-only and runs with autograd disabled.

    Args:
        data_name: dataset name, used to build output file names.
        result_dir: directory that receives the t-SNE plot / embedding dump.
        classifier: model called as ``classifier(batch)``; must return
            ``(logits, loss, acc, embeddings)``.
        loader: iterable of batches exposing a label tensor ``batch.y``.
        optimizer: optimizer to step, or ``None`` for an evaluation pass.
        save_plot: if true, draw a t-SNE plot of the collected embeddings and
            save them to ``result_dir`` (suffix ``_val`` for evaluation).

    Returns:
        numpy array ``[loss, acc, auc, ap, hits@1, hits@3, hits@10, hits@100, mrr]``,
        where loss and acc are averaged per sample over the whole pass.
    """
    training = optimizer is not None
    total_loss = []
    all_targets = []
    all_scores = []
    all_embeddings = []
    n_samples = 0

    pbar = tqdm(loader, unit='batch')

    # Autograd is only needed when we back-propagate; an evaluation pass would
    # otherwise build (and keep alive) a computation graph for every batch.
    with torch.set_grad_enabled(training):
        for batch in pbar:
            batch_size = len(batch.y)
            all_targets.extend(batch.y.tolist())
            logits, loss, acc, embeddings = classifier(batch)
            # Column 1 of the logits is the score ranked against label 1 below.
            all_scores.append(logits[:, 1].cpu().detach())
            all_embeddings.append(embeddings.cpu().detach())

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_value = loss.item()
            pbar.set_description('loss: %0.5f acc: %0.5f' % (loss_value, acc))
            # Weight by batch size so the final division yields a per-sample mean.
            total_loss.append(np.array([loss_value, acc]) * batch_size)
            n_samples += batch_size

    avg_loss = np.sum(np.array(total_loss), 0) / n_samples
    all_scores = torch.cat(all_scores).cpu().numpy()
    all_embeddings = torch.cat(all_embeddings)
    all_targets = np.array(all_targets)

    avg_precision = average_precision_score(all_targets, all_scores)
    fpr, tpr, _ = metrics.roc_curve(all_targets, all_scores, pos_label=1)
    auc = metrics.auc(fpr, tpr)

    k_list = [1, 3, 10, 100]
    hits_ogb_n = hits_at_n_ogb(all_scores, all_targets, k_list)
    mrr_obg_value = evaluate_mrr_scaled(all_scores, all_targets)

    if save_plot:
        pltval = '' if training else '_val'
        # Make sure the output directory exists before anything is written to it.
        os.makedirs(result_dir, exist_ok=True)
        draw_TSNE_embeding_nodeclass(data_name, all_embeddings, all_targets, save_path=result_dir+'/acc_results'+data_name+pltval+'_tsne.png')
        save_data_Embedding(all_embeddings,all_targets,data_name+pltval, result_dir)

    extra_metrics = [auc, avg_precision]
    extra_metrics += [hits_ogb_n['Hits@%d' % k] for k in k_list]
    extra_metrics.append(mrr_obg_value)
    return np.concatenate((avg_loss, extra_metrics))




cmd_opt = argparse.ArgumentParser(description='Argparser for graph_classification')
cmd_opt.add_argument('-mode', default='cpu', help='cpu/gpu')
cmd_opt.add_argument('-gm', default='DGCNN', help='gnn model to use')
cmd_opt.add_argument('-data', default=None, help='data folder name')
cmd_opt.add_argument('-batch_size', type=int, default=50, help='minibatch size')
cmd_opt.add_argument('-seed', type=int, default=1, help='seed')
cmd_opt.add_argument('-feat_dim', type=int, default=0, help='dimension of discrete node feature (maximum node tag)')
cmd_opt.add_argument('-edge_feat_dim', type=int, default=0, help='dimension of edge features')
cmd_opt.add_argument('-num_class', type=int, default=0, help='#classes')
cmd_opt.add_argument('-fold', type=int, default=1, help='fold (1..10)')
cmd_opt.add_argument('-test_number', type=int, default=0, help='if specified, will overwrite -fold and use the last -test_number graphs as testing data')
cmd_opt.add_argument('-num_epochs', type=int, default=1000, help='number of epochs')
cmd_opt.add_argument('-latent_dim', type=str, default='64', help='dimension(s) of latent layers')
cmd_opt.add_argument('-sortpooling_k', type=float, default=30, help='number of nodes kept after SortPooling')
cmd_opt.add_argument('-conv1d_activation', type=str, default='ReLU', help='which nn activation layer to use')
cmd_opt.add_argument('-out_dim', type=int, default=1024, help='graph embedding output size')
cmd_opt.add_argument('-hidden', type=int, default=100, help='dimension of mlp hidden layer')
cmd_opt.add_argument('-max_lv', type=int, default=4, help='max rounds of message passing')
cmd_opt.add_argument('-learning_rate', type=float, default=0.0001, help='init learning_rate')
cmd_opt.add_argument('-dropout', type=bool, default=False, help='whether add dropout after dense layer')
cmd_opt.add_argument('-printAUC', type=bool, default=False, help='whether to print AUC (for binary classification only)')
cmd_opt.add_argument('-extract_features', type=bool, default=False, help='whether to extract final graph features')

cmd_args, _ = cmd_opt.parse_known_args()

cmd_args.latent_dim = [int(x) for x in cmd_args.latent_dim.split('-')]
if len(cmd_args.latent_dim) == 1:
    cmd_args.latent_dim = cmd_args.latent_dim[0]

parser = argparse.ArgumentParser(description='Link Prediction')
# general settings
parser.add_argument('--mask',action='store_true', default=False, help='mask test data')
parser.add_argument('--result-dir', default='../results/MSLGLP', help='network name')
parser.add_argument('--data-name', default='BUP', help='network name')
parser.add_argument('--train-name', default=None, help='train name')
parser.add_argument('--test-name', default=None, help='test name')
parser.add_argument('--max-train-num', type=int, default=10000, 
                    help='set maximum number of train links (to fit into memory)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--test-ratio', type=float, default=0.5,
                    help='ratio of test links')
# model settings
parser.add_argument('--hop', default=2, metavar='S', 
                    help='enclosing subgraph hop number, \
                    options: 1, 2,..., "auto"')
parser.add_argument('--max-nodes-per-hop', default=100, 
                    help='if > 0, upper bound the # nodes per hop by subsampling')



args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)

# Seed every RNG from this parser's documented ``--seed`` flag. A single-dash
# ``-seed`` (cmd_opt's spelling) is not an option of this parser, so
# ``parser.parse_args()`` above rejects it; ``cmd_args.seed`` therefore
# cannot be changed from the command line and is not used here.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# --hop is either an integer or the literal string 'auto'.
if args.hop != 'auto':
    args.hop = int(args.hop)
if args.max_nodes_per_hop is not None:
    args.max_nodes_per_hop = int(args.max_nodes_per_hop)


'''Prepare data'''
# Resolve data paths relative to this script's directory. ``__file__`` is
# undefined in interactive sessions; fall back to the working directory there.
try:
    args.file_dir = os.path.dirname(os.path.realpath(__file__))
except NameError:
    args.file_dir = os.path.realpath(os.getcwd())
args.res_dir = os.path.join(args.file_dir, '../results/MSLGLP{}'.format(args.data_name))

if args.train_name is None:
    # Single .mat file: train/test positive and negative links are sampled from its 'net'.
    args.data_dir = os.path.join(args.file_dir, '../data/{}.mat'.format(args.data_name))
    data = sio.loadmat(args.data_dir)
    net = data['net']
    attributes = None
    # check whether net is symmetric (for small nets only); disabled by default
    if False:
        net_ = net.toarray()
        assert(np.allclose(net_, net_.T, atol=1e-8))
    # Sample train and test links
    train_pos, train_neg, test_pos, test_neg = sample_neg(net, args.test_ratio, max_train_num=args.max_train_num)
else:
    # Explicit edge-list files: each row is read as an integer node pair (u, v).
    args.train_dir = os.path.join(args.file_dir, '../data/{}'.format(args.train_name))
    args.test_dir = os.path.join(args.file_dir, '../data/{}'.format(args.test_name))
    # ndmin=2 keeps a one-edge file as a (1, k) array so column indexing still works.
    train_idx = np.loadtxt(args.train_dir, dtype=int, ndmin=2)
    test_idx = np.loadtxt(args.test_dir, dtype=int, ndmin=2)
    max_idx = max(np.max(train_idx), np.max(test_idx))
    net = ssp.csc_matrix((np.ones(len(train_idx)), (train_idx[:, 0], train_idx[:, 1])), shape=(max_idx+1, max_idx+1))
    net[train_idx[:, 1], train_idx[:, 0]] = 1  # add symmetric edges
    net[np.arange(max_idx+1), np.arange(max_idx+1)] = 0  # remove self-loops
    # Sample negative train and test links
    train_pos = (train_idx[:, 0], train_idx[:, 1])
    test_pos = (test_idx[:, 0], test_idx[:, 1])
    train_pos, train_neg, test_pos, test_neg = sample_neg(net, train_pos=train_pos, test_pos=test_pos, max_train_num=args.max_train_num)


'''Train and apply classifier'''
A = net.copy()  # the observed network; test links may be hidden from it below
print("args.mask",args.mask)
if not args.mask:
    print("UnMasked Test Data*******************")
else:
    print("Masked Test Data......................")
    # Drop both directions of every test link so the model cannot observe it.
    A[test_pos[0], test_pos[1]] = 0
    A[test_pos[1], test_pos[0]] = 0
A.eliminate_zeros()
print("start links2subgraphs function...")
train_graphs, test_graphs, max_n_label = links2subgraphs(A, train_pos, train_neg, test_pos, test_neg, args.hop, args.max_nodes_per_hop, None)
print(('# train: %d, # test: %d' % (len(train_graphs), len(test_graphs))))
print("start subgraphs2multiscalessubgraphs function...")

train_multiscalegraphs, test_multiscalegraphs= subgraphs2multiscalessubgraphs(train_graphs, test_graphs)

_NUM_SCALES = 3


def _split_into_scale_graphs(ms_samples, label_bound):
    """Convert multi-scale samples into one GNNGraph list per scale.

    Each scale graph is relabelled with ``node_label`` on its adjacency (built
    with nodes 0 and 1 first) and inherits the label and node features of the
    sample's original graph.

    Returns:
        ``(per_scale_lists, label_bound)`` where ``label_bound`` has been raised
        to the largest node label seen.
    """
    per_scale = [[] for _ in range(_NUM_SCALES)]
    for sample in ms_samples:
        for scale in range(_NUM_SCALES):
            scale_graph = sample.multiscalegraphs[scale]
            adjacency = ConvertGraphToAdjMatrixFirstTargetNodes(scale_graph , 0,1)
            labels = node_label(adjacency)
            label_bound = max(max(labels), label_bound)
            per_scale[scale].append(GNNGraph(scale_graph , sample.orginalGNNgraph.label, labels, sample.orginalGNNgraph.node_features))
    return per_scale, label_bound


def _combine_scale_lines(per_scale_lines):
    """Pair the k-th line graph of every scale into one multi-scale line graph."""
    first, second, third = per_scale_lines
    return [linegraphsToMSlinegraphs(line0, second[k], third[k])
            for k, line0 in enumerate(first)]


newtrain, max_n_label = _split_into_scale_graphs(train_multiscalegraphs, max_n_label)
newtest, max_n_label = _split_into_scale_graphs(test_multiscalegraphs, max_n_label)

# Line graphs are built only after both splits so max_n_label covers train AND test.
train_lines_temp = [to_linegraphs(scale_graphs, max_n_label) for scale_graphs in newtrain]
test_lines_temp = [to_linegraphs(scale_graphs, max_n_label) for scale_graphs in newtest]

train_lines = _combine_scale_lines(train_lines_temp)
test_lines = _combine_scale_lines(test_lines_temp)


print("okk")
# Model configurations (these override whatever was parsed into cmd_args).

cmd_args.latent_dim = [32, 32, 32, 1]
cmd_args.hidden = 128
cmd_args.out_dim = 0
cmd_args.dropout = True
cmd_args.num_class = 2
# Honour --no-cuda and actual CUDA availability instead of forcing the GPU,
# which crashes on CPU-only machines.
cmd_args.mode = 'gpu' if args.cuda else 'cpu'
cmd_args.num_epochs = 50
cmd_args.learning_rate = 5e-3
cmd_args.batch_size = 50
cmd_args.printAUC = True
# Twice (max label + 1): presumably one label block per endpoint — TODO confirm against Net.
cmd_args.feat_dim = (max_n_label + 1)*2
cmd_args.attr_dim = 0

train_loader = DataLoader(train_lines, batch_size=cmd_args.batch_size, shuffle=True)
test_loader = DataLoader(test_lines, batch_size=cmd_args.batch_size, shuffle=False)


classifier = Net(cmd_args.feat_dim, cmd_args.hidden, cmd_args.latent_dim, cmd_args.dropout)
if cmd_args.mode == 'gpu':
    classifier = classifier.to("cuda")

optimizer = optim.Adam(classifier.parameters(), lr=cmd_args.learning_rate)



best_auc = 0
best_auc_acc = 0
best_acc = 0
best_acc_auc = 0


# Metric vector layout returned by loop_dataset_gem:
# 0 loss, 1 acc, 2 auc, 3 ap, 4 hit1, 5 hit3, 6 hit10, 7 hit100, 8 mrr.
_TRAIN_MSG_FMT = '\033[92maverage training of epoch %d: loss %.5f acc %.5f auc %.5f ap %.5f hit1 %.5f hit3 %.5f hit10 %.5f hit100 %.5f mmr %.5f\033[0m'
_TEST_MSG_FMT = '\033[94maverage test of epoch %d: loss %.5f acc %.5f auc %.5f ap %.5f hit1 %.5f hit3 %.5f hit10 %.5f hit100 %.5f mmr %.5f\033[0m'

# Per-epoch result files are appended below; make sure their directory exists.
os.makedirs(args.result_dir, exist_ok=True)
_result_prefix = args.result_dir + '/' + args.data_name

save_plot=False
for epoch in range(cmd_args.num_epochs):
    # Embedding plots are only produced on the final epoch.
    save_plot = epoch == cmd_args.num_epochs - 1

    classifier.train()
    avg_loss = loop_dataset_gem(args.data_name, args.result_dir,classifier, train_loader, optimizer=optimizer,save_plot=save_plot)
    if not cmd_args.printAUC:
        avg_loss[2] = 0.0
    train_msg = _TRAIN_MSG_FMT % (epoch, *avg_loss[:9])
    print(train_msg)

    classifier.eval()
    with torch.no_grad():
        test_loss = loop_dataset_gem(args.data_name, args.result_dir,classifier, test_loader, None, save_plot)

    if not cmd_args.printAUC:
        test_loss[2] = 0.0
    test_msg = _TEST_MSG_FMT % (epoch, *test_loss[:9])
    print(test_msg)

    for suffix, idx in (('train_loss_results.txt', 0),
                        ('train_acc_results.txt', 1),
                        ('train_auc_results.txt', 2),
                        ('train_ap_results.txt', 3)):
        with open(_result_prefix + suffix, 'a+') as f:
            f.write(str(avg_loss[idx]) + '\n')
    with open(_result_prefix + 'all_results.txt', 'a+') as f:
        f.write(train_msg + '\n')
        f.write(test_msg + '\n')

    # Accuracy is index 1 (index 3 is average precision).
    if best_auc < test_loss[2]:
        best_auc = test_loss[2]
        best_auc_acc = test_loss[1]

    if best_acc < test_loss[1]:
        best_acc = test_loss[1]
        best_acc_auc = test_loss[2]

print('best test auc %.5f (acc %.5f); best test acc %.5f (auc %.5f)' % (best_auc, best_auc_acc, best_acc, best_acc_auc))

# Append the last epoch's test metrics, one text file per metric, in a fixed order.
# Each entry is (file-name stem, index into the metric vector).
_final_metric_files = (
    ('acc_results', 1),
    ('loss_results', 0),
    ('auc_results', 2),
    ('ap_results', 3),
    ('hit1_results', 4),
    ('hit3_results', 5),
    ('hit10_results', 6),
    ('hit100_results', 7),
    ('mrr_results', 8),
)
for _metric_stem, _metric_idx in _final_metric_files:
    with open(args.result_dir+'/'+args.data_name+_metric_stem+'.txt', 'a+') as f:
        f.write(str(test_loss[_metric_idx]) + '\n')
