#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import sys
sys.path.append('./codes/')
import time
from config import args

from utils import *
from models import GCN2 as GCN
from metrics import *

import torch
import torch.optim

# Override whatever dataset the config selected: this script always uses 'syn4'.
args.dataset = 'syn4'


# In[ ]:


# Load the pickled dataset tuple.
# NOTE(review): `pkl` is not imported in this file; presumably it is
# re-exported by `from utils import *` -- confirm.
with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:
    (adj, features, y_train, y_val, y_test,
     train_mask, val_mask, test_mask, edge_label_matrix) = pkl.load(fin)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Some preprocessing
if args.normfea:
    features = preprocess_features(features)
support = preprocess_adj(adj, args.normadj)

# NOTE(review): only the model is moved to `device`; the feature, label and
# mask tensors built below stay on the CPU. Presumably GCN (which receives
# `device`) and the metric helpers handle the transfer -- confirm before
# relying on CUDA.
model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1], device=device)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.elr)

features_tensor = torch.Tensor(features).type(torch.float32)

# `support` is consumed as (indices, values, shape).
i = torch.LongTensor([*support[0]])
v = torch.FloatTensor([*support[1]])
# NOTE: `i` has to be transposed to build a sparse tensor with PyTorch
# (presumably it is stored one row per non-zero entry, while PyTorch wants
# one column per entry -- confirm against preprocess_adj).
# torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
# constructor; passing dtype makes the separate float32 cast unnecessary.
support_tensor = torch.sparse_coo_tensor(
    i.t(), v, torch.Size([*support[2]]), dtype=torch.float32
)

# Labels and masks for each split, all stored as float32 tensors.
y_train_tensor = torch.Tensor(y_train).type(torch.float32)
train_mask_tensor = torch.Tensor(train_mask)

y_test_tensor = torch.Tensor(y_test).type(torch.float32)
test_mask_tensor = torch.Tensor(test_mask)

y_val_tensor = torch.Tensor(y_val).type(torch.float32)
val_mask_tensor = torch.Tensor(val_mask)

# Running "best so far" trackers, updated by the training loop.
best_test_acc = 0
best_val_acc = 0
best_val_loss = 10000
# Gradient clipping bounds. Only clip_value_max is referenced in the visible
# part of this file; clip_value_min is kept in case other code reads it.
clip_value_min = -2.0
clip_value_max = 2.0


# In[ ]:


# Full-batch training of the GCN with early stopping on validation accuracy.
# Saves the best and (optionally) last weights and writes a summary log.
import copy

print("*************")
print("Training GNN!")
print("*************")

tik = time.time()

# Number of consecutive epochs without a validation-accuracy improvement.
curr_step = 0
# Bound up front so the save step after the loop cannot raise NameError when
# no snapshot was taken (args.save_model is false or val_acc never improved).
best_state_dict = None

# The log is opened before training so a missing directory fails immediately;
# the context manager guarantees it is closed even if training raises.
with open("LISA_TEST_LOGS/GNN_WEIGHTS/GNN_LOG_" + args.dataset + ".txt", "w") as f:
    for epoch in range(args.epochs):
        output = model((features_tensor, support_tensor), training=True)
        cross_loss = masked_softmax_cross_entropy(output, y_train_tensor, train_mask_tensor)

        # L2 penalty over all parameters. torch.stack keeps every term in the
        # autograd graph; building a fresh torch.Tensor from the values would
        # detach them and the penalty would contribute no gradient.
        lossL2 = torch.stack(
            [torch.sum(param ** 2) / 2 for param in model.parameters()]
        ).sum()
        loss = cross_loss + args.weight_decay * lossL2

        optimizer.zero_grad()
        loss.backward()
        # Clamp every gradient element to [-clip_value_max, clip_value_max].
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value_max)
        optimizer.step()

        # NOTE(review): all metrics reuse the `training=True` forward pass made
        # before the optimizer step. If that flag enables dropout in GCN, the
        # validation/test numbers are noisy -- confirm whether a separate
        # evaluation pass is wanted.
        train_acc = masked_accuracy(output, y_train_tensor, train_mask_tensor)
        val_acc = masked_accuracy(output, y_val_tensor, val_mask_tensor)
        val_loss = masked_softmax_cross_entropy(output, y_val_tensor, val_mask_tensor)
        test_acc = masked_accuracy(output, y_test_tensor, test_mask_tensor)

        if val_acc > best_val_acc:
            curr_step = 0
            best_test_acc = test_acc
            best_val_acc = val_acc
            best_val_loss = val_loss
            if args.save_model:
                # state_dict() hands back references to the live parameters,
                # which later optimizer steps overwrite in place; deep-copy to
                # keep a real snapshot of this epoch's weights.
                best_state_dict = copy.deepcopy(model.state_dict())
        else:
            curr_step += 1

        if curr_step > args.early_stop:
            print("Early stopping...")
            break

        if (epoch + 1) % 100 == 0:
            print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cross_loss), "train_acc=",
                  "{:.5f}".format(train_acc), "val_acc=", "{:.5f}".format(val_acc), "test_acc=", "{:.5f}".format(test_acc),
                  "best_test_acc=", "{:.5f}".format(best_test_acc))

    if best_state_dict is not None:
        torch.save(best_state_dict, f'model_weights/GCN_{args.dataset}_BEST.pt')
    else:
        print("No best checkpoint was captured; skipping BEST weights save.")

    if not args.valid:
        torch.save(model.state_dict(), f'model_weights/GCN_{args.dataset}_LAST.pt')

    tok = time.time()

    # Summary of the final epoch as "key,value" lines.
    f.write("Epoch,%04d" % (epoch + 1) + "\n")
    f.write("train_loss,{:.5f}".format(cross_loss) + "\n")
    f.write("train_acc,{:.5f}".format(train_acc) + "\n")
    f.write("val_acc,{:.5f}".format(val_acc) + "\n")
    f.write("test_acc,{:.5f}".format(test_acc) + "\n")
    f.write("best_test_acc,{:.5f}".format(best_test_acc) + "\n")
    f.write("Time,{}".format(tok - tik) + "\n")

print("******************")
print("Done training GNN!")
print("******************")

