import sys
sys.path.append('./codes/forgraph/')
from config import args
# import tensorflow as tf
import torch
import numpy as np
from models import GCN2 as GCN
from metrics import *

####

# Evaluate on the Mutagenicity dataset.
args.dataset = 'Mutagenicity'
import pickle as pkl

# The .pkl file holds an (adjs, features, labels) triple whose arrays share
# their first axis (one entry per sample, presumably one graph each).
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
print('Opening dataset')
dataset_path = './dataset/' + args.dataset + '.pkl'
with open(dataset_path, 'rb') as fin:
    adjs, features, labels = pkl.load(fin)
print('Opened dataset')

n_samples = adjs.shape[0]
# NOTE(review): no seed is set in this file, so the shuffle (and therefore the
# split) changes between runs and need not match the split the loaded
# checkpoint was trained on; test_acc may then include training samples.
# Confirm against the training script.
order = np.random.permutation(n_samples)
shuffle_adjs, shuffle_features, shuffle_labels = (
    arr[order] for arr in (adjs, features, labels)
)

# Roughly 80% / 10% / 10% train / validation / test boundaries (int truncates).
train_split, val_split = (int(n_samples * frac) for frac in (0.8, 0.9))

# Each split gets the shuffled arrays plus the original dataset indices (*_ids).
train_adjs, train_features, train_labels, train_ids = (
    arr[:train_split]
    for arr in (shuffle_adjs, shuffle_features, shuffle_labels, order)
)
val_adjs, val_features, val_labels, val_ids = (
    arr[train_split:val_split]
    for arr in (shuffle_adjs, shuffle_features, shuffle_labels, order)
)
test_adjs, test_features, test_labels, test_ids = (
    arr[val_split:]
    for arr in (shuffle_adjs, shuffle_features, shuffle_labels, order)
)

####

# Build the GCN and restore the trained weights for evaluation.
print(GCN)
# input_dim: width of the last axis of the feature array.
# output_dim: labels.shape[1] (labels look one-hot encoded -- TODO confirm).
model = GCN(input_dim=train_features.shape[-1], output_dim=train_labels.shape[1])
# The model is never moved to a GPU in this script, so map tensors onto the CPU;
# otherwise a checkpoint saved from a CUDA run fails to load on CPU-only hosts.
model.load_state_dict(
    torch.load(f'model_weights/GCN_{args.dataset}_LAST.pt', map_location='cpu')
)
# Put any dropout/batch-norm submodules into inference mode; harmless if GCN2
# relies solely on the ``training=`` keyword passed to forward during evaluation.
model.eval()

# Adam takes the parameters to optimize as its first (required) argument;
# omitting them raises TypeError and aborted the script before evaluation.
# NOTE(review): the optimizer is not used by the evaluation in this script.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# Convert every numpy split into float32 tensors for the model.
train_adjs_tensor = torch.tensor(train_adjs, dtype=torch.float32)
train_features_tensor = torch.tensor(train_features, dtype=torch.float32)
train_labels_tensor = torch.tensor(train_labels, dtype=torch.float32)

val_adjs_tensor = torch.tensor(val_adjs, dtype=torch.float32)
val_features_tensor = torch.tensor(val_features, dtype=torch.float32)
val_labels_tensor = torch.tensor(val_labels, dtype=torch.float32)

test_adjs_tensor = torch.tensor(test_adjs, dtype=torch.float32)
test_features_tensor = torch.tensor(test_features, dtype=torch.float32)
test_labels_tensor = torch.tensor(test_labels, dtype=torch.float32)

# Leftover training-loop state; not read by the evaluation in this script.
best_val_acc = 0
best_val_loss = 10000
clip_value_min = -2.0
clip_value_max = 2.0

# Evaluate the restored model on each split. Nothing here calls backward(), so
# disable autograd: otherwise a graph covering the whole training set is kept
# alive in memory for no purpose.
with torch.no_grad():
    output = model.forward((train_features_tensor, train_adjs_tensor), training=False)
    train_acc = accuracy(output, train_labels_tensor)

    val_output = model.forward((val_features_tensor, val_adjs_tensor), training=False)
    val_acc = accuracy(val_output, val_labels_tensor)
    val_loss = softmax_cross_entropy(val_output, val_labels_tensor)

    test_output = model.forward((test_features_tensor, test_adjs_tensor), training=False)
    test_acc = accuracy(test_output, test_labels_tensor)
    test_loss = softmax_cross_entropy(test_output, test_labels_tensor)

print("train_acc={:.5f}".format(train_acc), "val_acc=", "{:.5f}".format(val_acc),
      "test_acc=", "{:.5f}".format(test_acc))
