import argparse
import numpy as np
import time
import torch
from dgl import DGLGraph, add_self_loop
from module import GCN, GAT, SGC
from sklearn.metrics import accuracy_score
import pickle as pkl
import os
from scipy import io as sio
from sklearn.model_selection import train_test_split


def main(args, data, i):
    """Train one model on ``data`` with patience-based early stopping, then test it.

    Args:
        args: parsed command-line namespace (model, optim, lr, epochs, patience,
            d_hidden, n_layers, alpha, negative_node, resnet, weight_decay, gpu,
            and optionally dropout).
        data: graph returned by ``load_graph``; must carry ``train_idx``,
            ``val_idx``, ``test_idx``, ``labels``, ``n_feats`` and ``n_class``.
        i: round index, used to name the checkpoint file ``./best_model_{i}.pt``.
    """
    meters = {'acc': 0, 'iter': 0}
    # Read the split from ``data`` itself; the old code read the module-level
    # ``graph`` global, which only worked because __main__ passed that object.
    train_idx = data.train_idx
    test_idx = data.test_idx
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()

    # Any model name other than 'gcn' / 'gat' falls back to SGC.
    model_cls = {'gcn': GCN, 'gat': GAT}.get(args.model, SGC)
    # Honour --dropout (previously ignored); 0.1 matches the old hard-coded value.
    dropout = getattr(args, 'dropout', 0.1)
    model = model_cls(data.n_feats, args.d_hidden, data.n_class, dropout=dropout,
                      n_layers=args.n_layers, alpha=args.alpha,
                      size_neg=args.negative_node, use_resnet=args.resnet).float()
    if use_cuda:
        model.cuda()

    optimizers = {
        'adam': torch.optim.Adam,
        'rmsprop': torch.optim.RMSprop,
        'adadelta': torch.optim.Adadelta,
        'adagrad': torch.optim.Adagrad,
    }
    if args.optim in optimizers:
        optimizer = optimizers[args.optim](model.parameters(), lr=args.lr,
                                           weight_decay=args.weight_decay)
    else:
        # Unknown optimizer names fall back to SGD with momentum.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.95,
                                    weight_decay=args.weight_decay)
        print('use momentum sgd')

    print("start training...")
    model.train()
    t0 = time.time()
    for epoch in range(1, args.epochs + 1):
        optimizer.zero_grad()
        loss, clf_loss, prediction = model(data, train_idx=train_idx)
        pred_labels = torch.max(prediction, dim=1)[1]
        (loss + clf_loss).backward()
        optimizer.step()
        # Evaluates on the validation split, checkpoints when accuracy does not
        # drop, and resets meters['iter'] in that case.
        acc = save_best_model(model, data, meters, i)
        model.train()  # save_best_model switched the model to eval mode
        if epoch % 50 == 0:
            train_acc = accuracy_score(data.labels[train_idx].cpu().numpy(),
                                       pred_labels.cpu().numpy())
            print("Epoch {:05d} | Time(s) {:.4f} | Train Accuracy: {:.4f} | Val Accuracy: {:.4f} | Contra Loss: {:.4f} | Clf Loss: {:.4f}"
                  .format(epoch, time.time() - t0, train_acc, acc, float(loss), float(clf_loss)))
            t0 = time.time()
        # Stop after `patience` consecutive epochs without a new best checkpoint.
        if meters['iter'] >= args.patience:
            break
        meters['iter'] += 1

    # save_best_model pickles the whole module, so reload it and copy its weights.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True), which
    # rejects fully pickled modules; switch to saving state_dicts if upgrading.
    best = torch.load('./best_model_{}.pt'.format(i))
    model.load_state_dict(best.state_dict())
    model.eval()
    with torch.no_grad():
        _, _, prediction = model(data, train_idx=test_idx)
    pred_labels = torch.max(prediction, dim=1)[1]
    test_acc = accuracy_score(data.labels[test_idx].cpu().numpy(), pred_labels.cpu().numpy())
    print("====>Final Test Accuracy: {:.4f}".format(test_acc))


def save_best_model(model, data, meters, i):
    """Evaluate ``model`` on the validation split and checkpoint it if not worse.

    Leaves the model in eval mode; callers must switch back to train mode.

    Args:
        model: network whose forward returns ``(loss, clf_loss, prediction)``.
        data: graph carrying ``val_idx`` and ``labels``.
        meters: dict with ``'acc'`` (best validation accuracy so far) and
            ``'iter'`` (epochs since the last checkpoint); updated in place.
        i: round index used in the checkpoint filename ``./best_model_{i}.pt``.

    Returns:
        The validation accuracy of the current model.
    """
    model.eval()
    # Use the split stored on ``data``; the old code read the module-level
    # ``graph`` global instead of its argument.
    val_idx = data.val_idx
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        _, _, prediction = model(data, train_idx=val_idx)
    pred_labels = torch.max(prediction, dim=1)[1].cpu().numpy()
    val_acc = accuracy_score(data.labels[val_idx].cpu().numpy(), pred_labels)
    # `>=` so ties also refresh the checkpoint and reset the patience counter.
    if val_acc >= meters['acc']:
        meters['acc'] = val_acc
        torch.save(model, './best_model_{}.pt'.format(i))
        meters['iter'] = 0
    return val_acc


def load_graph(args, index=0, p=0):
    graph = DGLGraph()
    if os.path.exists('./data/{}/preprocessed_data.mat'.format(args.graph)):
        data = sio.loadmat('./data/{}/preprocessed_data.mat'.format(args.graph))
        feats = data['feats']
        labels = data['labels'].reshape(-1,)
        n_nodes = feats.shape[0]
        edge_src = data['edge_list'][:, 0]
        edge_dst = data['edge_list'][:, 1]
        train_idx = data['train_idx'][index]
        test_idx = data['test_idx'][index]
        val_idx = data['val_idx'][index]
    else:
        feats = pkl.load(open('./data/{}/'.format(args.graph) + args.graph + ".x.pkl", 'rb'))
        labels = pkl.load(open('./data/{}/'.format(args.graph) + args.graph + ".y.pkl", 'rb'))
        n_nodes = feats.shape[0]
        with open('./data/{}/'.format(args.graph) + args.graph + '.edgelist', 'r') as f:
            next(f)
            edge_list = np.array(list(map(lambda x: x.strip().split(' '), f.readlines())), dtype=np.int)
            edge_src = edge_list[:, 0]
            edge_dst = edge_list[:, 1]
        train, test, val = [], [], []
        for i in range(5):
            train_index, test_index, _, _ = train_test_split(range(feats.shape[0]), labels, test_size=0.87, random_state=8)
            test.append(test_index)
            train_index, val_index, _, _ = train_test_split(range(len(train_index)), labels[train_index], test_size=0.7692, random_state=8)
            train.append(train_index)
            val.append(val_index)
        data = {'feats': feats, 'labels': labels, 'edge_list': edge_list, 'train_idx': train, 'test_idx': test,
                'val_idx': val}
        sio.savemat('./data/{}/preprocessed_data.mat'.format(args.graph), data)
        train_idx = data['train_idx'][index]
        test_idx = data['test_idx'][index]
        val_idx = data['val_idx'][index]
    if p != 1:
        if not os.path.exists('./data/{}/mask_idx_{}.mat'.format(args.graph, p*100)):
            m = feats.shape[1]
            mask_idx = np.random.permutation(np.arange(m))[:int(m*p)]
            mask_data = {'mask_idx': mask_idx}
            sio.savemat('./data/{}/mask_idx_{}.mat'.format(args.graph, p*100), mask_data)
        else:
            mask_idx = sio.loadmat('./data/{}/mask_idx_{}.mat'.format(args.graph, p*100))['mask_idx']
    else:
        mask_idx = []
    graph.add_nodes(n_nodes)
    graph.add_edges(edge_src, edge_dst)
    graph.n_class = np.unique(labels).shape[0]
    graph = add_self_loop(graph)
    graph.train_idx = train_idx
    graph.test_idx = test_idx
    if len(mask_idx) != 0:
        feats[:, mask_idx[0]] = 1
    graph.n_feats = feats.shape[1]
    graph.feats = torch.FloatTensor(feats)
    graph.labels = torch.LongTensor(labels)
    graph.adj = graph.adjacency_matrix()
    graph.val_idx = val_idx
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        graph = graph.to(torch.device('cuda:{}'.format(args.gpu)))
        graph.feats = graph.feats.cuda().float()
        graph.labels = graph.labels.cuda()
        graph.adj = graph.adj.cuda()
    return graph


if __name__ == '__main__':
    # Entry point: parse CLI options, then run 5 rounds; round i loads split i
    # and trains/tests a fresh model on it.
    parser = argparse.ArgumentParser(description='TGCL')
    parser.add_argument("--dropout", type=float, default=0.1, help="dropout probability")
    parser.add_argument("-graph", type=str, default='cora', help="graph in the target domain")
    parser.add_argument("-model", type=str, default='gcn', help="the base model (gcn, gat, sgc), default is gcn")
    parser.add_argument("-d_hidden", type=int, default=50, help="number of hidden units")
    parser.add_argument("-gpu", type=int, default=0, help="gpu")
    parser.add_argument("-resnet", action='store_true', help="use resnet or not")
    parser.add_argument("-lr", type=float, default=5e-4, help="learning rate")
    parser.add_argument("-n_layers", type=int, default=10, help="number of layers") # 2
    parser.add_argument("-patience", type=int, default=500, help="number of iterations more after getting the best model")
    parser.add_argument("-e", "--epochs", type=int, default=5000, help="number of training epochs")
    parser.add_argument("-weight_decay", type=float, default=0, help="weight decay rate")
    parser.add_argument("-alpha", type=float, default=0.1, help="contrastive loss coef")
    # Passed to the model as `size_neg`; the old help text duplicated -alpha's.
    parser.add_argument("-negative_node", type=int, default=250, help="number of negative nodes for the contrastive loss")
    parser.add_argument("-optim", type=str, default='adam', help="optimizer")
    args = parser.parse_args()
    print(args)

    for i in range(5):
        print('Round {}'.format(i+1))
        graph = load_graph(args, i)
        main(args, graph, i)

