from gnn import GGNN, GAT, EGC,GGNNtest,GGNNtest2,GGNNtest3,GGNNtest4,GGNNtest5,GGNNtest6,GGNNtest7,GGNNtest8,EdgeAwareGraphSAGE
from utils.utils import ModifiedMarginRankingLoss, train_model2, SignLoss, getCorrectProblemTypes, evaluate, GeometricDataset, groupLabels, getWeights,GeometricDataset2
import torch, json, time, argparse
import torch.optim as optim
import numpy as np
from torch_geometric.explain import GNNExplainer
from torch_geometric.nn import AttentionalAggregation
import torch_geometric
from torch_geometric.nn import GCNConv,GATConv


'''
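File - netTrainer.py
Driver script that trains a graph neural network on the given train/validation
splits, evaluates it on the test split, and saves the metrics (.npz) and the
trained model weights (.pt).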
'''

# Binary 0/1 mask with 10 rows and 18 columns.
# The label-building code in the __main__ block selects row `array[idx]` by an
# entry's position `idx` in a file's result list. It adds `row * score` into
# that file's label vector, so each label vector has 18 elements.
# Each row therefore marks which label columns that entry contributes to.
# Row 4 is all ones, so that entry contributes to every column.
# NOTE(review): what the rows (entry positions) and columns (label slots)
# stand for is not visible in this file - confirm against the dataset docs.
array=np.array([
    [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1]
])

def load_split(path):
    """Load one data split (a JSON document) from *path*.

    Prints an error and exits with status 1 if the file does not exist.
    Other failures (e.g. malformed JSON) propagate unchanged so they are not
    misreported as a missing file. The file handle is always closed.
    """
    try:
        with open(path) as fh:
            return json.load(fh)
    except FileNotFoundError:
        print("Error:", path, "does not exists. Please input a valid file")
        raise SystemExit(1)


def compute_labels(files, matrix, top_n=3):
    """Turn a loaded split into a list of ``(key, label_vector)`` pairs.

    ``files[key]`` is a list of two-element entries ``v``. The position ``idx``
    of an entry selects ``matrix[idx]``. Its score is ``v[1] + 1`` when
    ``v[0] == 1`` and ``v[1]`` otherwise.
    NOTE(review): the meaning of ``v[0]``/``v[1]`` is not visible here -
    confirm against the dataset format.

    Only entries that rank among the ``top_n`` highest scores (ranked with
    ``np.argsort``) *and* have a positive score contribute
    ``matrix[idx] * score`` to the label vector. If no entry qualifies, the
    label is a zero vector of length ``matrix.shape[1]``. Summing an empty
    list would yield a scalar ``0.0`` and break the label shape.

    Returns the pairs in the iteration order of ``files``.
    """
    labels = []
    for key, entries in files.items():
        scores = [v[1] + 1 if v[0] == 1 else v[1] for v in entries]
        # Rank once per key, not once per entry. The explicit start index
        # keeps top_n <= 0 from selecting everything (as `[-0:]` would).
        order = np.argsort(scores)
        top = set(order[max(len(order) - top_n, 0):].tolist())
        rows = [matrix[idx] * score
                for idx, score in enumerate(scores)
                if score > 0 and idx in top]
        if rows:
            label = np.sum(rows, axis=0).tolist()
        else:
            label = np.zeros(matrix.shape[1]).tolist()
        labels.append((key, label))
    return labels


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="GNN Trainer")
    parser.add_argument("--mp-layers", help="Number of message passing layers (Default=0)", default=0, type=int)
    parser.add_argument("-e", "--epochs", help="Number of training epochs (Default=20)", default=20, type=int)
    parser.add_argument("--edge-sets", help="Which edges sets to include: AST, Data, ICFG (Default=All)", nargs='+', default=['AST', 'Data', "ICFG"], choices=['AST', 'Data', "ICFG"])
    # NOTE(review): --problem-types is parsed but not used anywhere in this script.
    parser.add_argument("-p", "--problem-types", help="Which problem types to consider:termination, overflow, reachSafety, memSafety (Default=All)", nargs="+", default=['termination', 'overflow', 'reachSafety', 'memSafety'], choices=['termination', 'overflow', 'reachSafety', 'memSafety'])
    parser.add_argument('-n', '--net', help="GGNN, GAT, EGC", default="EGC", choices=["GGNN", "GAT", "EGC"])
    parser.add_argument("-m", "--mode", help="Mode for jumping (Default cat): max, cat, lstm", default="cat", choices=['max', 'cat', 'lstm'])
    parser.add_argument("--pool-type", help="How to pool Nodes (max, min, mean, add, attention, power, softmax, equilibrium)", default=["attention"], choices=["max", 'min', "mean", "add", "attention", "power", "softmax", "equilibrium"], nargs='+')
    # NOTE(review): --aggregators is unused (EGC below is always built with
    # ['sum', 'mean', 'max']), and its default is not one of its choices.
    parser.add_argument("--aggregators", help="Neighbourhood aggregators (max, min, mean, symnorm, std); currently unused", default=["attention"], choices=["max", 'min', "mean", "symnorm", "std"], nargs='+')
    parser.add_argument("-g", "--gpu", help="Which GPU should the model be on", default=0, type=int)
    # Only "algorithm" is implemented. The default used to be "rank", which is
    # not a valid choice and made every default run fail.
    parser.add_argument("--task", help="Which task are you training for (only 'algorithm' is currently supported)", default="algorithm", choices=['topk', 'ranking', 'success', 'algorithm'])
    parser.add_argument("-k", "--topk", help="k for topk (1-10)", default=3, type=int)
    parser.add_argument("--cache", help="If activated, will cache dataset in memory", action='store_true')
    # store_false: passing --no-jump sets args.no_jump to False (jumping disabled).
    parser.add_argument("--no-jump", help="Whether or not to use jumping knowledge", action="store_false", default=True)
    parser.add_argument("train", nargs=1, help="The training set")
    parser.add_argument("val", nargs=1, help="The validation set")
    parser.add_argument("test", nargs=1, help="The test set")
    parser.add_argument("dataset", nargs=1, help="Directory housing the dataset")

    args = parser.parse_args()

    # Fail fast, before any dataset or model is built.
    if args.task != "algorithm":
        raise ValueError("Not a valid task")

    train_labels = compute_labels(load_split(args.train[0]), array)
    val_labels = compute_labels(load_split(args.val[0]), array)
    test_labels = compute_labels(load_split(args.test[0]), array)

    train_set = GeometricDataset2(train_labels, args.dataset[0], 4, should_cache=args.cache)
    val_set = GeometricDataset2(val_labels, args.dataset[0], 4, should_cache=args.cache)
    test_set = GeometricDataset2(test_labels, args.dataset[0], 4, should_cache=args.cache)

    input_size = train_set[0][0].x.size(1)
    # Every label vector has one entry per column of `array`.
    output_size = array.shape[1]

    if args.net == 'GGNN':
        # NOTE(review): the "GGNN" choice builds EdgeAwareGraphSAGE with a fixed
        # 4 edge sets (ignoring --edge-sets) - confirm this is intended.
        model = EdgeAwareGraphSAGE(passes=args.mp_layers, numEdgeSets=4, inputLayerSize=input_size, outputLayerSize=output_size, collate=args.mode).to(device=args.gpu)
    elif args.net == "GAT":
        model = GAT(passes=args.mp_layers, numEdgeSets=len(args.edge_sets), numAttentionLayers=5, inputLayerSize=input_size, outputLayerSize=output_size, mode=args.mode, k=20, shouldJump=args.no_jump, pool=args.pool_type).to(device=args.gpu)
    else:
        model = EGC(passes=args.mp_layers, inputLayerSize=input_size, outputLayerSize=output_size, aggregators=['sum', 'mean', 'max'], shouldJump=args.no_jump, pool=args.pool_type).to(device=args.gpu)

    loss_fn = torch.nn.NLLLoss()
    # Timestamp prefix keeps output files from different runs distinct.
    output_stem = str(time.time_ns()) + "algRank-P-ROBUST-SSS-SAGE-101-ALL-14_17"

    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
    report = train_model2(model=model, loss_fn=loss_fn, batchSize=1, trainset=train_set, valset=val_set, optimizer=optimizer, scheduler=scheduler, num_epochs=args.epochs, gpu=args.gpu, task=args.task, k=args.topk)
    train_acc, train_loss, val_acc, val_loss = report
    (overallRes, overflowRes, reachSafetyRes, terminationRes, memSafetyRes), (overallChoices, overflowChoices, reachSafetyChoices, terminationChoices, memSafetyChoices), predicts = evaluate(model, test_set, files=[x[0] for x in test_labels], gpu=args.gpu)
    np.savez_compressed(output_stem + ".npz", train_acc=train_acc, train_loss=train_loss, val_acc=val_acc, val_loss=val_loss, overallRes=overallRes, overflowRes=overflowRes, reachSafetyRes=reachSafetyRes, terminationRes=terminationRes, memSafetyRes=memSafetyRes, overallChoices=overallChoices, overflowChoices=overflowChoices, reachSafetyChoices=reachSafetyChoices, terminationChoices=terminationChoices, memSafetyChoices=memSafetyChoices)
    torch.save(model.state_dict(), output_stem + ".pt")

