import argparse
import os
import random
import time
import warnings

import dgl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

from gat import GAT
from loss import subgraph_contrastive_loss
from utils import load_network_data, get_train_data, random_planetoid_splits


# Suppress every warning emitted by the libraries used below.
warnings.filterwarnings("ignore")

# Command-line options as (flag, type, default) rows; they are registered in
# this order so the generated --help output keeps the same layout.
_CLI_OPTIONS = (
    ("--gpu", int, 0),                 # negative value (or no CUDA) selects the CPU path
    ("--epochs", int, 2000),           # upper bound on training epochs
    ("--dataset", str, "cora"),        # name handed to load_network_data
    ("--num-hidden", int, 128),
    ("--tau", float, 1),               # forwarded to subgraph_contrastive_loss
    ("--seed", int, 1),
    ("--in-drop", float, 0.6),
    ("--attn-drop", float, 0.5),
    ("--lr", float, 0.01),
    ('--weight-decay', float, 1e-4),
    ('--negative-slope', float, 0.2),
    ('--alpha', float, 1),             # weight of the BCE term in the training loss
)

parser = argparse.ArgumentParser(description='GAT')
for _flag, _type, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default)


args = parser.parse_args()
print(args)


# Seed every random number generator the script relies on (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same user-supplied seed.
for _seed_fn in (random.seed,
                 np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(args.seed)

# Load the dataset: adj2 is a scipy sparse adjacency (it is fed to
# dgl.from_scipy), features a sparse feature matrix, and Y a per-class score
# matrix (presumably one-hot labels -- it is only ever arg-maxed below).
adj2, features, Y = load_network_data(args.dataset)
# Binarise the features: every strictly positive entry becomes 1.
features[features > 0] = 1
g = dgl.from_scipy(adj2)

# Use the GPU only when one was requested (non-negative index) and CUDA works.
cuda = args.gpu >= 0 and torch.cuda.is_available()
if cuda:
    # Make the requested GPU the *current* device. The rest of the script moves
    # tensors and the model with bare ``.cuda()`` calls, which target the
    # current device; without this they would land on GPU 0 while the graph
    # sits on ``args.gpu``, breaking any run with ``--gpu`` != 0.
    torch.cuda.set_device(args.gpu)
    g = g.int().to(args.gpu)

# Dense float32 copy of the feature matrix for the model.
features = torch.FloatTensor(features.todense())
# Append-mode result log for this dataset; the run configuration is written
# immediately and the handle stays open until the final f.close() at the end
# of the script.
f = open(args.dataset + '_result.txt', 'a+')
f.write('\n\n{}\n'.format(args))
f.flush()

# Checkpoint file for the best model, keyed by the hyper-parameters of the run.
save_path = './checkpoint_{}/{}_{}_{}_{}_{}.pkl'.format(args.dataset, args.dataset, args.num_hidden, args.lr, args.tau, args.seed)
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before training starts writing to it.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Integer class id per node: arg-max over the columns of Y.
labels = np.argmax(Y, axis=1)
# Dense torch copy of the adjacency matrix (passed to the model and the loss).
adj = torch.tensor(adj2.todense())

all_time = time.time()

# Problem sizes.
num_nodes, num_feats = features.shape
n_classes = Y.shape[1]
n_edges = g.number_of_edges()  # counted before the self-loop rewrite below

# Drop any existing self-loops, then add them back so the graph has a
# uniform set of self-loops.
g = dgl.add_self_loop(dgl.remove_self_loop(g))

# Build the GAT encoder. Arguments are positional because GAT's signature is
# defined in gat.py; NOTE(review): the literal 1 and [3] look like the layer
# count and per-layer head counts -- confirm against gat.GAT.
model = GAT(
    g, 1, num_feats, args.num_hidden, [3], F.elu,
    args.in_drop, args.attn_drop, args.negative_slope,
)

# Move inputs and parameters to the GPU when the CUDA path was selected.
if cuda:
    features, adj = features.cuda(), adj.cuda()
    model.cuda()

# Binary cross-entropy on raw logits; used as the auxiliary loss term.
b_xent = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Per-epoch wall-clock durations (averaged in the progress line).
dur = []
test_acc = 0

# Early-stopping bookkeeping.
counter = 0                      # epochs since the last improvement
# Start from +inf so the first epoch always counts as an improvement and a
# checkpoint is guaranteed to exist, whatever the scale of the loss.
min_train_loss = float('inf')
early_stop_counter = 100         # patience, in epochs
best_t = -1                      # epoch index of the best checkpoint (-1: none yet)

# Targets for the BCE term: num_nodes ones followed by num_nodes zeros
# (presumably positive vs. negative samples -- confirm against gat.GAT).
lbl1 = torch.ones(1, num_nodes)
lbl2 = torch.zeros(1, num_nodes)
# Follow the device of `features` instead of calling .cuda() unconditionally,
# so CPU-only runs (no CUDA, or --gpu < 0) no longer crash here.
lbl = torch.cat((lbl1, lbl2), 1).to(features.device)

# Training loop: one optimisation step per epoch, followed by an eval-mode
# re-scoring of the training objective that drives checkpointing and early
# stopping.
for epoch in range(args.epochs):
    epoch_start = time.time()

    # --- optimisation step (model in training mode) ---
    model.train()
    optimizer.zero_grad()
    heads, logits, sub_emb = model(features, adj)
    contrastive_term = subgraph_contrastive_loss(heads, sub_emb, adj, args.tau)
    loss = contrastive_term + args.alpha * b_xent(logits, lbl)
    loss.backward()
    optimizer.step()

    # --- re-evaluate the same objective in eval mode, without gradients ---
    model.eval()
    with torch.no_grad():
        heads, logits, sub_emb = model(features, adj)
        contrastive_term = subgraph_contrastive_loss(heads, sub_emb, adj, args.tau)
        loss_train = contrastive_term + args.alpha * b_xent(logits, lbl)

    # --- keep the best checkpoint / count epochs without improvement ---
    if loss_train < min_train_loss:
        min_train_loss = loss_train
        best_t = epoch
        counter = 0
        torch.save(model.state_dict(), save_path)
    else:
        counter += 1

    # Stop once the patience budget is exhausted.
    if counter >= early_stop_counter:
        print('early stop')
        break

    dur.append(time.time() - epoch_start)

    print("Epoch {:04d} | Time(s) {:.4f} | TrainLoss {:.4f} | Eearlystop {:03d} ".
          format(epoch + 1, np.mean(dur), loss_train.item(), counter))

# Restore the best checkpoint and compute the final node embeddings.
if best_t < 0:
    # No epoch ever improved on the initial loss (e.g. zero epochs or a NaN
    # loss), so this run wrote no checkpoint. Loading save_path would either
    # fail or silently pick up a stale file from an earlier run.
    raise RuntimeError(
        'Training produced no checkpoint (best_t == -1); refusing to load '
        '{}'.format(save_path))

print('Loading {}th epoch'.format(best_t))
model.load_state_dict(torch.load(save_path))
model.eval()
with torch.no_grad():
    heads, logits, sub_emb = model(features, adj)
    heads.append(sub_emb)
# Every model output except the first entry of `heads`, plus the sub-graph
# embedding appended above, forms the representation.
emb_res = heads[1:]

# Concatenate along the feature dimension and move to the CPU for scikit-learn.
emb_res = torch.cat(emb_res, dim=1)
emb_res = emb_res.detach().cpu()

# Linear-probe evaluation: fit a logistic regression on the frozen embeddings
# over several random splits and collect the test accuracy of each split.
numRandom = 20   # number of random splits
train_num = 20   # forwarded to the split helpers (presumably labelled nodes per class -- confirm)

Accuaracy_test_allK = []
AccuaracyAll = []

for i in range(numRandom):

    # The three citation datasets use the planetoid-style splitter; everything
    # else goes through get_train_data with a small validation set.
    if args.dataset in ['cora', 'citeseer', 'pubmed']:
        val_num = 500
        idx_train, idx_val, idx_test = random_planetoid_splits(n_classes, torch.tensor(labels), train_num, i)
    else:
        val_num = 30
        idx_train, idx_val, idx_test = get_train_data(Y, train_num, val_num, i)

    train_embs = emb_res[idx_train, :]
    val_embs = emb_res[idx_val, :]
    test_embs = emb_res[idx_test, :]

    # Select the regularisation strength C on the validation split.
    # Start below any reachable accuracy so the first candidate is always
    # recorded; starting at 0.0 left `best_parameters` unbound (or stale from
    # the previous split) whenever every validation score was exactly 0.
    best_val_score = -1.0
    for param in [1e-4, 1e-3, 1e-2, 0.1, 1, 10, 100]:
        LR = LogisticRegression(solver='liblinear', multi_class='ovr', random_state=0, C=param)
        LR.fit(train_embs, labels[idx_train])
        val_score = LR.score(val_embs, labels[idx_val])
        if val_score > best_val_score:
            best_val_score = val_score
            best_parameters = {'C': param}

    # Refit with the selected C and score on the held-out test split.
    LR_best = LogisticRegression(solver='liblinear', multi_class='ovr', random_state=0, **best_parameters)

    LR_best.fit(train_embs, labels[idx_train])
    y_pred_test = LR_best.predict(test_embs)

    test_acc = accuracy_score(labels[idx_test], y_pred_test)
    AccuaracyAll.append(test_acc)

# Summarise the per-split test accuracies as mean and standard deviation, in
# percent, then print the summary and append it to the result log.
_split_accs = np.asarray(AccuaracyAll)
average_acc = _split_accs.mean() * 100
std_acc = _split_accs.std() * 100

print('Acc: %.1f +/- %.1f' % (average_acc, std_acc))
f.write('Acc: %.1f +/- %.1f\n' % (average_acc, std_acc))
f.flush()

Accuaracy_test_allK.append(average_acc)

# Done with the result log opened at the start of the script.
f.close()
