import os
import sys  
import json
import numpy as np
import argparse
from argparse import Namespace

# Make the project-local modules under ./src (resolved relative to the current
# working directory) importable; `dataset` and `model` below are loaded from there.
sys.path[0:0] = ['./src']
# Set CUDA_VISIBLE_DEVICES to "-1" before the torch-based imports below, so
# that no CUDA device is visible to this process.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

import sklearn

from dataset import TUDataset
from torch_geometric.data import DataLoader
import pytorch_lightning as pl

from model import *


def load_test_split(dataset_name, n_folds, fold, folds_dir="./data/folds"):
    """Return the test indices for one cross-validation fold.

    Reads ``<folds_dir>/<dataset_name>_folds_<n_folds>.txt``, a JSON document
    indexed by ``fold`` whose entries are ``(train, val, test)`` triples, and
    returns the ``test`` component of entry ``fold``.

    Args:
        dataset_name: dataset name used in the folds file name.
        n_folds: total number of folds, also part of the file name.
        fold: index of the fold whose test split is wanted.
        folds_dir: directory containing the folds files.

    Returns:
        The test-index collection exactly as stored in the JSON file.

    Raises:
        OSError: if the folds file cannot be opened.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``).
    """
    path = os.path.join(folds_dir, "%s_folds_%d.txt" % (dataset_name, n_folds))
    with open(path, "r", encoding="utf-8") as f:
        folds = json.load(f)
    # Only the test split is needed for evaluation; train/val are discarded.
    _train_idx, _val_idx, test_idx = folds[fold]
    return test_idx


if __name__ == "__main__":
    # Evaluate a trained checkpoint on the held-out test split of the fold it
    # was trained on; dataset/fold settings come from the checkpoint's hparams.
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", default='./checkpoints/GKNN_MUTAG_fold0/last.ckpt')
    params = parser.parse_args()

    model = Model.load_from_checkpoint(checkpoint_path=params.path, map_location=None)
    model.eval()

    run_params = model.hparams

    dataset = TUDataset('./data', run_params.dataset)
    # `folds` (total count) selects the file; `fold` selects the entry in it.
    test_idx = load_test_split(run_params.dataset, run_params.folds, run_params.fold)

    test_loader = DataLoader(dataset[test_idx], batch_size=1, shuffle=False)

    # NOTE(review): from_argparse_args calls vars() on its argument. If
    # model.hparams is a dict-like object rather than an argparse.Namespace,
    # the saved trainer settings may be silently ignored. Confirm against the
    # installed pytorch_lightning version.
    trainer = pl.Trainer.from_argparse_args(run_params)

    res_test = trainer.test(model, test_loader)