import argparse
import os
import time
import dgl
import torch
from dgl import DGLGraph
from dgl.data import register_data_args
import scipy.sparse as sp
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

from src.dataloader import load_homo_data, load_heter_data, load_heter_data_new
from torch.utils.tensorboard import SummaryWriter

# from src.models import Graphite, SUBVGAE, VGAE, DGI, Classifier
from src.models.dgi import DGI, Classifier
from src.models import VGAE
from src.models.subgi_models_v2 import SubGI
from src.utils import mask_test_edges, generate_neg_edges
import torch.utils.data as tdata
import torch.nn as nn
import networkx as nx
# from src.models.subgi_models_task import SubGI

from IPython import embed

# Directory that contains this file (symlinks resolved), and the directory one level above it.
dir_path = os.path.split(os.path.realpath(__file__))[0]
# dir_path is already absolute, so normalising the joined path is enough to collapse the '..'.
parent_path = os.path.normpath(os.path.join(dir_path, os.pardir))
# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def args_parser():
    """Build the command-line parser for this script and parse the process arguments.

    ``register_data_args`` (from ``dgl.data``) is given the parser first, then the
    script's own options are added in the order listed below.

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = argparse.ArgumentParser()
    register_data_args(cli)

    # (flag, add_argument keyword arguments), kept in declaration order so --help is unchanged.
    options = (
        # general args
        ('--model', dict(type=str, default='gcn_vae', help="[dgi, gcn_vae, graphite]")),
        ('--data-set', dict(type=str, default='yago_sub_g', help="[yago_sub_g, cora, yago, yago_ko_0.8]")),
        ('--pretrain', dict(type=str, default=None, help="pretraining model path")),
        ('--feature', dict(type=str, default=None, help="feature path")),
        ('--gpu', dict(type=int, default=0, help='GPU ID.')),
        ('--seed', dict(type=int, default=42, help='Random seed.')),
        ('--tensorboard', dict(type=int, default=1, help='whether use tensorboard.')),
        # model args
        ('--hidden1', dict(type=int, default=256, help='Number of units in hidden layer 1.')),
        ('--hidden2', dict(type=int, default=128, help='Number of units in hidden layer 2.')),
        ('--hidden3', dict(type=int, default=128, help='Number of units in hidden layer 3.')),
        ('--vae', dict(type=int, default=1, help='1 for variational objective')),
        ('--autoregressive_scalar', dict(type=float, default=0.5, help='Scalar for Graphite.')),
        # training args
        ('--lr', dict(type=float, default=0.01, help='Initial learning rate.')),
        ('--dropout', dict(type=float, default=0., help='Dropout rate (1 - keep probability).')),
        ('--epochs', dict(type=int, default=500, help='Number of epochs to train.')),
        ('--test_interval', dict(type=int, default=1, help='test interval.')),
        # sampling args
        ('--n-layers', dict(type=int, default=1, help="number of hidden gcn layers")),
        ('--sampling_method', dict(type=str, default='node', help='sampling method, [node,edge]')),
        ('--batch_size', dict(type=int, default=128, help="batch size")),
        ('--self-loop', dict(action='store_true', help="graph self-loop (default=False)")),
        ('--n-neighbors', dict(type=int, default=10, help="number of neighbors to be sampled")),
        ('--decoder_batch_size', dict(type=int, default=512, help="decoder graph sampler, batch size.")),
        ('--edge_loader_batch_size', dict(type=int, default=512, help="batch size for edge loader.")),
    )
    for flag, spec in options:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


# Model names for which a construction branch exists in _build_model().
_SUPPORTED_MODELS = ('graphite', 'gcn_vae', 'subgi', 'dgi', 'regression')


def _load_graph_data(args):
    """Load the adjacency matrix and node features for ``args.data_set``.

    Returns:
        tuple: ``(adj_orig, orig_features)`` as produced by the loaders in ``src.dataloader``.
    """
    if args.data_set in ['cora']:
        adj_orig, orig_features, _ = load_homo_data(args.data_set, feature=args.feature, n_hidden=args.hidden1)
    elif args.data_set in ['top_2013']:
        # TODO: load data & use degree as feature
        # NOTE(review): this loader also returns a training adjacency (third value); it has always been
        # discarded because train() trains on the full graph -- confirm that is intended.
        adj_orig, orig_features, _, _, _ = load_heter_data_new(args.data_set, n_hidden=256)
    else:
        adj_orig, orig_features, _ = load_heter_data(args.data_set, feature=args.feature,
                                                     n_hidden=args.hidden1, pretrain=args.pretrain is not None)
    return adj_orig, orig_features


def _build_dgl_graph(adj_train):
    """Turn a scipy sparse adjacency matrix into a directed, read-only DGLGraph placed on ``device``."""
    nx_graph = nx.from_scipy_sparse_matrix(adj_train)
    # TODO: use directed?
    nx_graph = nx_graph.to_directed()
    g = DGLGraph(nx_graph).to(device)
    g.readonly()
    return g


def _new_train_sampler(args, g):
    """Create the neighbour sampler used for one training epoch."""
    return dgl.contrib.sampling.NeighborSampler(g, args.decoder_batch_size, args.n_neighbors,
                                                neighbor_type='in', num_workers=4,
                                                add_self_loop=False,
                                                num_hops=args.n_layers + 1, shuffle=True)


def _build_model(args, g, feat_dim, orig_features, adj_train):
    """Instantiate the model selected by ``args.model`` and attach the data it trains on."""
    if args.model == 'graphite':
        # NOTE(review): ``Graphite`` is not imported in this file (its import is commented out at the top),
        # so this branch raises NameError until that import is restored.
        model = Graphite(args, feat_dim, args.hidden1, args.hidden2, args.hidden3, args.dropout).to(device)
    elif args.model == 'gcn_vae':
        model = VGAE(args, feat_dim, args.hidden1, args.hidden2, args.dropout, pretrain=args.pretrain)
        model.prepare()
        model = model.to(device)
        model.g = g
        model.features = orig_features
        model.adj_train = adj_train
        # This initial sampler is replaced at the start of every epoch by the training loop.
        model.train_sampler = dgl.contrib.sampling.NeighborSampler(g, args.decoder_batch_size, args.n_neighbors,
                                                                   neighbor_type='in', num_workers=4,
                                                                   add_self_loop=True,
                                                                   num_hops=args.n_layers, shuffle=True)
    elif args.model == 'subgi':
        model = SubGI(g, feat_dim, args.hidden1, args.n_layers, torch.nn.PReLU(args.hidden1), args.dropout, 2).to(
            device)
        model.features = orig_features
    elif args.model == 'dgi':
        model = DGI(args, feat_dim, args.hidden1, n_layers=args.n_layers, activation=nn.PReLU(), dropout=args.dropout,
                    pretrain=args.pretrain)
        # wrap fine-tune specs in prepare function
        model.prepare()
        model = model.to(device)
        g.ndata['features'] = orig_features
        g.ndata['nfeatures'] = orig_features
        model.g = g
        model.output_emb = torch.zeros((orig_features.shape[0], args.hidden1))
        model.classifier = Classifier(args.hidden1).to(device)
        model.test_sampler = dgl.contrib.sampling.NeighborSampler(g, args.batch_size, args.n_neighbors,
                                                                  neighbor_type='in', num_workers=4,
                                                                  add_self_loop=True,
                                                                  num_hops=args.n_layers, shuffle=False)
    elif args.model == 'regression':
        # NOTE(review): ``LogisticRegressionModel`` is not imported in this file either; this branch
        # raises NameError until an import is added.
        model = LogisticRegressionModel(args, args.hidden1, node_num=adj_train.shape[0], feature=orig_features).to(
            device)
    else:
        # train() validates the name up front; this only guards direct calls to the helper.
        raise SystemExit("Unsupported model '{}'".format(args.model))
    return model


def _run_epochs(args, model, g, writer):
    """Run ``args.epochs`` training epochs, logging the mean loss and checkpointing when pre-training."""
    save_checkpoints = args.pretrain is None
    if save_checkpoints:
        # torch.save does not create missing parent directories.
        os.makedirs('output', exist_ok=True)

    # Validation via model.test_model is not run here: train() holds out no test edges.
    for epoch in range(args.epochs):
        t = time.time()
        model.train_sampler = _new_train_sampler(args, g)
        cur_loss = model.train_model()
        # train_model() may hand back several loss values; report and log their mean.
        mean_loss = np.array(cur_loss).mean()
        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(mean_loss), "time=",
              "{:.5f}".format(time.time() - t))
        if writer is not None:
            writer.add_scalar('Train_Loss', float(mean_loss), epoch)
        if save_checkpoints:
            print('Pre-train stops at epoch: {}'.format(epoch))
            torch.save(model.state_dict(), 'output/pretrain_new_{}_{}.pkl'.format(args.data_set, args.model))


def train(args):
    """Load a graph dataset, build the requested model and train it.

    The full graph is used for training (no edges are held out). When ``args.pretrain`` is
    ``None`` the model's state dict is saved under ``output/`` after every epoch.

    Args:
        args: parsed options as produced by ``args_parser()``; reads ``model``, ``data_set``,
            ``tensorboard``, ``pretrain``, ``feature``, ``lr``, ``epochs``, ``dropout``,
            ``hidden1``-``hidden3``, ``n_layers``, ``n_neighbors``, ``batch_size`` and
            ``decoder_batch_size``.

    Raises:
        SystemExit: if ``args.model`` has no construction branch (process exit status 1).
    """
    # Fail fast, before any data is loaded or a TensorBoard run directory is created.
    if args.model not in _SUPPORTED_MODELS:
        raise SystemExit("Unsupported model '{}'; expected one of: {}".format(
            args.model, ', '.join(_SUPPORTED_MODELS)))

    writer = None
    if args.tensorboard > 0:
        write_id = "{}_{}_lr:{}_{}".format(args.data_set, args.model, args.lr,
                                           'finetune' if args.pretrain is not None else 'pretrain')
        writer = SummaryWriter(comment=write_id)

    try:
        print("Using {}".format(args.data_set))
        # 0.load data
        t = time.time()
        adj_orig, orig_features = _load_graph_data(args)
        feat_dim = orig_features.shape[1]
        print("Load data: {}s".format(time.time() - t))

        # 1.adj remove self loop
        adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
        adj_orig.eliminate_zeros()
        print('node num: {}'.format(adj_orig.shape[0]))
        print('edge num: {}'.format(adj_orig.sum()))

        # 2.no pos/neg edges are sampled for testing; the whole graph is the training graph
        t = time.time()
        adj_train = adj_orig
        print("split data: {}s".format(time.time() - t))

        g = _build_dgl_graph(adj_train)
        orig_features = orig_features.to(device)
        print("process data & create DGL graph: {}s".format(time.time() - t))

        t = time.time()
        model = _build_model(args, g, feat_dim, orig_features, adj_train)
        print("create model: {}s".format(time.time() - t))

        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        print(model)
        _run_epochs(args, model, g, writer)
    finally:
        # Close the TensorBoard writer even when setup or training raises.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    # Parse the options first so that --seed takes effect; the seed used to be hard-coded to 0,
    # which silently ignored the command-line value.
    args = args_parser()
    torch.manual_seed(args.seed)
    train(args)