import argparse
import numpy as np
import time
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
import dgl
from gat import GAT
from loss import subgraph_contrastive_loss
from utils import load_network_data
from cluster import cluster_test
import warnings


# Silence every warning category for the whole run.
warnings.filterwarnings("ignore")

# All run settings come from the command line.
parser = argparse.ArgumentParser(description='GAT')

# (flag, type, default) triples, registered in this order.
#   --gpu             device index; a negative value selects the CPU path
#   --epochs          upper bound on training iterations
#   --dataset         name handed to load_network_data and used in the
#                     results file name
#   --tau / --alpha   passed to the contrastive loss / weight of the BCE term
#   remaining flags   forwarded to the GAT constructor or the Adam optimizer
_CLI_OPTIONS = (
    ("--gpu", int, 0),
    ("--epochs", int, 2000),
    ("--dataset", str, "cora"),
    ("--num-hidden", int, 128),
    ("--tau", float, 1),
    ("--seed", int, 1),
    ("--in-drop", float, 0.6),
    ("--attn-drop", float, 0.5),
    ("--lr", float, 0.01),
    ('--weight-decay', float, 1e-4),
    ('--negative-slope', float, 0.2),
    ('--alpha', float, 1),
)
for _flag, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)

args = parser.parse_args()
print(args)

# Seed every random number generator used by the script (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
for _seed_fn in (random.seed,
                 np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

# Results are appended to a per-dataset text file.  The handle stays open
# because the final NMI/ARI line is written to it after training.
result_path = args.dataset + '_c_result.txt'
f = open(result_path, 'a+')

# Record the full argument set first so each result line can be traced back
# to its configuration.
header = '\n\n{}\n'.format(args)
f.write(header)
f.flush()

# Load the dataset.  adj2 and features are used as SciPy-style matrices below
# (dgl.from_scipy / .todense()); Y is presumably a one-hot label matrix since
# labels are taken as its row-wise argmax -- confirm against utils.
adj2, features, Y = load_network_data(args.dataset)
features[features > 0] = 1  # binarise: every positive entry becomes 1

# Build the DGL graph and, when a GPU is requested and available, move it to
# the selected device.
g = dgl.from_scipy(adj2)
cuda = args.gpu >= 0 and torch.cuda.is_available()
if cuda:
    g = g.int().to(args.gpu)

# Dense tensor versions of the inputs.
features = torch.FloatTensor(features.todense())
labels = torch.tensor(np.argmax(Y, 1))
adj = torch.tensor(adj2.todense())

# Wall-clock start and basic dataset statistics.
all_time = time.time()
num_nodes, num_feats = features.shape
n_classes = Y.shape[1]
n_edges = g.number_of_edges()  # counted before self-loops are normalised

# Guarantee exactly one self-loop per node.
g = dgl.add_self_loop(dgl.remove_self_loop(g))

# Build the GAT encoder.  Arguments are positional; the literals 1 and [3]
# are presumably the number of layers and the per-layer head counts --
# TODO confirm against gat.GAT.
model = GAT(g,
            1,
            num_feats,
            args.num_hidden,
            [3],
            F.elu,
            args.in_drop,
            args.attn_drop,
            args.negative_slope)

if cuda:
    # The graph was moved to device index args.gpu, so the tensors and the
    # model must go to that same device.  A bare .cuda() targets the current
    # default device, which is a different GPU whenever --gpu is not 0.
    device = torch.device('cuda', args.gpu)
    features = features.to(device)
    adj = adj.to(device)
    model.to(device)

# Binary cross-entropy on raw logits, used for the auxiliary term of the loss.
b_xent = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)

# Per-epoch wall-clock durations.
dur = []
test_acc = 0

# Early-stopping state: `counter` counts consecutive epochs without a new
# best training loss; training stops once it reaches `early_stop_counter`.
counter = 0
# Start from +inf so the first epoch always registers as an improvement,
# whatever the scale of the loss (a fixed 100 silently breaks tracking when
# the loss never drops below it).
min_train_loss = float('inf')
early_stop_counter = 100

# Best clustering scores seen so far.
max_nmi = -1
max_ari = -1

# Targets for the BCE term: ones for the first num_nodes logits and zeros for
# the next num_nodes.
lbl1 = torch.ones(1, num_nodes)
lbl2 = torch.zeros(1, num_nodes)
# Place the targets on whatever device the features live on instead of
# calling .cuda() unconditionally, which fails on CPU-only runs.
lbl = torch.cat((lbl1, lbl2), 1).to(features.device)

# Inputs of the clustering evaluation that never change between epochs,
# computed once instead of on every cluster_test call.
labels_cpu = labels.detach().cpu()
n_clusters = torch.unique(labels).size()[0]

for epoch in range(args.epochs):
    t0 = time.time()

    # ---- optimisation step ------------------------------------------------
    model.train()
    optimizer.zero_grad()

    heads, logits, sub_emb = model(features, adj)
    loss = (subgraph_contrastive_loss(heads, sub_emb, adj, args.tau)
            + args.alpha * b_xent(logits, lbl))

    loss.backward()
    optimizer.step()

    # ---- evaluation pass (eval mode, no gradients) -------------------------
    model.eval()
    with torch.no_grad():
        heads, logits, sub_emb = model(features, adj)

        # Stored as a plain Python float so the early-stopping bookkeeping
        # does not keep device tensors alive across epochs.
        loss_train = (subgraph_contrastive_loss(heads, sub_emb, adj, args.tau)
                      + args.alpha * b_xent(logits, lbl)).item()

        heads.append(sub_emb)

    # Three candidate embeddings: all outputs concatenated, all but the first,
    # and all but the first two (sub_emb is the last element in each).
    emb_list = [torch.cat(heads[start:], dim=1) for start in range(3)]

    # NOTE(review): the best NMI/ARI are picked with the ground-truth labels
    # over every epoch and every candidate embedding.
    for emb in emb_list:
        # 12345 is presumably a fixed random seed for the clustering routine
        # -- confirm against cluster.cluster_test.
        nmi, ari = cluster_test(emb.detach().cpu(), n_clusters, labels_cpu, 12345)
        max_nmi = max(max_nmi, nmi)
        max_ari = max(max_ari, ari)

    # ---- early stopping on the evaluation-mode training loss ---------------
    if loss_train < min_train_loss:
        counter = 0
        min_train_loss = loss_train
    else:
        counter += 1

    if counter >= early_stop_counter:
        print('early stop')
        break

    dur.append(time.time() - t0)

# Append the best clustering scores of this run to the results file, then
# release the handle.
summary = 'NMI: %.4f | ARI: %.4f\n' % (max_nmi, max_ari)
f.write(summary)
f.flush()
f.close()

