#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import sys
sys.path.append('./codes/forgraph')
from config import args
# import tensorflow as tf
import torch
import torch.nn as nn
import time
from models import GCN2 as GCN
from metrics import *
import pickle as pkl
from matplotlib import pyplot as plt
import networkx as nx
import numpy as np

# args.bn = True
# args.concat = True


# In[ ]:


# Load the BA-2motif dataset: a pickled (adjacency, features, labels) triple.
# NOTE: pickle can execute arbitrary code on load — only use trusted files here.
dataset_path = './dataset/BA-2motif.pkl'
with open(dataset_path, 'rb') as dataset_file:
    adjs, feas, labels = pkl.load(dataset_file)

def vis(adj):
    """Draw the graph described by the adjacency matrix *adj* and display it.

    Nodes are placed with a Kamada-Kawai layout and axes are hidden. After
    ``plt.show()`` returns, the current figure is cleared so the next plot
    starts from a blank canvas.

    Args:
        adj: square NumPy adjacency matrix (presumably one entry of ``adjs``).
    """
    # ``nx.from_numpy_matrix`` was removed in NetworkX 3.0; ``from_numpy_array``
    # is the equivalent constructor and is also available in NetworkX 2.x.
    G = nx.from_numpy_array(adj)
    pos = nx.kamada_kawai_layout(G)
    nx.draw_networkx_nodes(G, pos)
    nx.draw_networkx_edges(G, pos)

    plt.axis('off')
    plt.show()
    plt.clf()


# In[ ]:


# Optional visual sanity checks:
# vis(adjs[0])
# vis(adjs[500])

# NOTE(review): `order` is the identity permutation. Despite the `shuffle_`
# prefix, the graphs keep their on-disk order. If the dataset is stored
# grouped by class, the validation/test slices below will not be
# class-balanced. Confirm this is intended (e.g. to keep splits identical
# across scripts) before switching to a seeded random permutation.
num_graphs = adjs.shape[0]
order = np.arange(num_graphs)
shuffle_adjs, shuffle_feas, shuffle_labels = (
    arr[order] for arr in (adjs, feas, labels)
)

# Split boundaries: first 80% train, next 10% validation, last 10% test.
train_split, val_split = (int(num_graphs * frac) for frac in (0.8, 0.9))

train_part = slice(None, train_split)
val_part = slice(train_split, val_split)
test_part = slice(val_split, None)

train_adjs, train_feas, train_labels = (
    arr[train_part] for arr in (shuffle_adjs, shuffle_feas, shuffle_labels)
)
val_adjs, val_feas, val_labels = (
    arr[val_part] for arr in (shuffle_adjs, shuffle_feas, shuffle_labels)
)
test_adjs, test_feas, test_labels = (
    arr[test_part] for arr in (shuffle_adjs, shuffle_feas, shuffle_labels)
)


# In[ ]:


import copy  # used to snapshot the best-validation weights below

print("*************")
print("Training GNN!")
print("*************")

tik = time.time()

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = GCN(input_dim=train_feas.shape[1:], output_dim=train_labels.shape[1], device=device)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# NOTE(review): these tensors are created on the CPU while the model lives on
# `device`. This presumably relies on GCN2 moving its inputs (it is given
# `device`) and on the `metrics` helpers handling the labels. Confirm this
# before running on a GPU.
train_adjs_tensor = torch.tensor(train_adjs).to(torch.float32)
train_features_tensor = torch.tensor(train_feas).to(torch.float32)
train_labels_tensor = torch.tensor(train_labels).to(torch.float32)

val_adjs_tensor = torch.tensor(val_adjs).to(torch.float32)
val_features_tensor = torch.tensor(val_feas).to(torch.float32)
val_labels_tensor = torch.tensor(val_labels).to(torch.float32)

test_adjs_tensor = torch.tensor(test_adjs).to(torch.float32)
test_features_tensor = torch.tensor(test_feas).to(torch.float32)
test_labels_tensor = torch.tensor(test_labels).to(torch.float32)

best_test_acc = 0
best_val_acc = -1
best_state_dict = None  # copy of the weights with the best validation accuracy
# Gradients are clipped element-wise to [clip_value_min, clip_value_max].
# clip_grad_value_ clips symmetrically, so only the max is passed to it.
clip_value_min = -2.0
clip_value_max = 2.0
batch_size = 64


def _l2_penalty(module):
    """Return ``sum(p**2) / 2`` over all parameters of *module*, as a differentiable tensor.

    The per-parameter sums are added directly instead of being wrapped in
    ``torch.Tensor([...])``. Wrapping them would copy the values into a new
    leaf tensor and drop their gradients, so weight decay would never affect
    training.
    """
    return sum(torch.sum(p ** 2) for p in module.parameters()) / 2


def _train_step(features, adjs_batch, labels_batch):
    """Run one optimiser step on a batch; return ``(model_output, cross_entropy_loss)``."""
    output = model.forward((features, adjs_batch), training=True)
    cross_loss = softmax_cross_entropy(output, labels_batch)
    loss = cross_loss + args.weight_decay * _l2_penalty(model)
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(model.parameters(), clip_value_max)
    optimizer.step()
    return output, cross_loss


curr_step = 0  # epochs since the last validation-accuracy improvement
for epoch in range(args.epochs):
    if args.batch:
        batch_outputs = []
        for begin in range(0, train_adjs.shape[0], batch_size):
            end = begin + batch_size
            output, cross_loss = _train_step(
                train_features_tensor[begin:end],
                train_adjs_tensor[begin:end],
                train_labels_tensor[begin:end],
            )
            # Detach so each batch's autograd graph can be freed after its step.
            batch_outputs.append(output.detach())
        output = torch.cat(batch_outputs, dim=0)
    else:
        # Full-batch training on the whole training split.
        output, cross_loss = _train_step(
            train_features_tensor, train_adjs_tensor, train_labels_tensor
        )

    train_acc = accuracy(output, train_labels_tensor)

    # Evaluation only needs values, so skip building autograd graphs.
    with torch.no_grad():
        val_output = model.forward((val_features_tensor, val_adjs_tensor), training=False)
        val_acc = accuracy(val_output, val_labels_tensor)
        val_loss = softmax_cross_entropy(val_output, val_labels_tensor)

        test_output = model.forward((test_features_tensor, test_adjs_tensor), training=False)
        test_acc = accuracy(test_output, test_labels_tensor)
        test_loss = softmax_cross_entropy(test_output, test_labels_tensor)

    # Track the best model by validation accuracy.
    if val_acc > best_val_acc:
        curr_step = 0
        best_test_acc = test_acc
        best_val_acc = val_acc
        if args.save_model:
            # state_dict() holds live references to the parameters, so copy it.
            # Otherwise later optimizer steps would overwrite the "best" weights.
            best_state_dict = copy.deepcopy(model.state_dict())
    else:
        curr_step += 1

    if curr_step > args.early_stop:
        print("Early stopping...")
        break

    if (epoch + 1) % 100 == 0:
        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cross_loss), "train_acc=",
              "{:.5f}".format(train_acc), "val_acc=", "{:.5f}".format(val_acc),
              "test_acc=", "{:.5f}".format(test_acc),
              "test_best_acc=", "{:.5f}".format(best_test_acc))


# A best snapshot exists only when args.save_model is set.
if best_state_dict is not None:
    torch.save(best_state_dict, 'model_weights/GCN_BA2motif_BEST.pt')
torch.save(model.state_dict(), 'model_weights/GCN_BA2motif_LAST.pt')

tok = time.time()

# Summary of the last epoch run plus wall-clock time, one "key,value" per line.
log_lines = [
    "Epoch,%04d" % (epoch + 1),
    "train_loss,{:.5f}".format(cross_loss),
    "train_acc,{:.5f}".format(train_acc),
    "val_acc,{:.5f}".format(val_acc),
    "test_acc,{:.5f}".format(test_acc),
    "best_test_acc,{:.5f}".format(best_test_acc),
    "Time,{}".format(tok - tik),
]
with open("LISA_TEST_LOGS/GNN_WEIGHTS/GNN_LOG_" + args.dataset + ".txt", "w") as log_file:
    log_file.write("\n".join(log_lines) + "\n")

print("******************")
print("Done training GNN!")
print("******************")

