from gnn import GGNN, GAT, EGC,GGNNtest,GGNNtest2,GGNNtest3,GGNNtest4,GGNNtest5,GGNNtest6,GGNNtest7,GGNNtest8,EdgeAwareGraphSAGE0
from utils.utils import ModifiedMarginRankingLoss,train_model, train_model2, SignLoss, getCorrectProblemTypes, evaluate, GeometricDataset, groupLabels, getWeights,GeometricDataset2,evaluate4
import torch, json, time, argparse
import torch.optim as optim
import numpy as np
from torch_geometric.explain import GNNExplainer
from torch_geometric.nn import AttentionalAggregation
import torch_geometric
from torch_geometric.nn import GCNConv,GATConv


'''
File - netTrainer.py
This file is a driver used to train networks
'''

# 10 x 18 binary (0/1) integer matrix. Each row is written as a bit string,
# one character per column, and expanded into a row of Python ints below.
# NOTE(review): the meaning of rows/columns is not established in this file
# (presumably one row per entry of a label file's per-key list) — confirm.
array = np.array([
    [int(bit) for bit in row]
    for row in (
        "000110011100000100",
        "001010000100011111",
        "001000000100001101",
        "000100000100001000",
        "111111111111111111",
        "000110110100001000",
        "001010011100001001",
        "110000000101111111",
        "110000000101100000",
        "110000110101111011",
    )
])


# 10 x 6 binary (0/1) integer weight matrix used when building the labels:
# the label code selects row idx for the entry at position idx of a file's
# per-key list and scales it by that entry's score.
# Each row is a bit string, one character per output column.
array1 = np.array([
    [int(bit) for bit in row]
    for row in (
        "110000",
        "010100",
        "000100",
        "100000",
        "111111",
        "110010",
        "010100",
        "001001",
        "001001",
        "001011",
    )
])

def build_parser():
    """Return the command-line argument parser for the GNN trainer."""
    parser = argparse.ArgumentParser(description="GNN Trainer")
    parser.add_argument("--mp-layers", help="Number of message passing layers (Default=0)", default=0, type=int)
    parser.add_argument("-e", "--epochs", help="Number of training epochs (Default=20)", default=20, type=int)
    # Only the GAT network reads this option.
    parser.add_argument(
        "--edge-sets",
        help="Which edge sets to include: AST, Data, ICFG (Default=All)",
        nargs='+',
        default=['AST', 'Data', "ICFG"],
        choices=['AST', 'Data', "ICFG"],
    )
    # NOTE: not read anywhere in this script.
    parser.add_argument(
        "-p", "--problem-types",
        help="Which problem types to consider: termination, overflow, reachSafety, memSafety (Default=All)",
        nargs="+",
        default=['termination', 'overflow', 'reachSafety', 'memSafety'],
        choices=['termination', 'overflow', 'reachSafety', 'memSafety'],
    )
    parser.add_argument('-n', '--net', help="Network to train: GGNN, GAT, EGC (Default=EGC)", default="EGC", choices=["GGNN", "GAT", "EGC"])
    parser.add_argument("-m", "--mode", help="Mode for jumping (Default=cat): max, cat, lstm", default="cat", choices=['max', 'cat', 'lstm'])
    parser.add_argument(
        "--pool-type",
        help="How to pool nodes (max, min, mean, add, attention, power, softmax, equilibrium)",
        default=["attention"],
        choices=["max", 'min', "mean", "add", "attention", "power", "softmax", "equilibrium"],
        nargs='+',
    )
    # Read only by the EGC network; the default matches the list previously hardcoded there.
    parser.add_argument(
        "--aggregators",
        help="Aggregators for the EGC network: sum, max, min, mean, symnorm, std (Default=sum mean max)",
        default=["sum", "mean", "max"],
        choices=["sum", "max", "min", "mean", "symnorm", "std"],
        nargs='+',
    )
    parser.add_argument("-g", "--gpu", help="Which GPU should the model be on", default=0, type=int)
    # argparse does not check defaults against `choices`, so the default must be a valid,
    # implemented task; only "topk" is implemented below.
    parser.add_argument(
        "--task",
        help="Which task are you training for (topk, ranking, success, algorithm)? Only topk is implemented (Default=topk)",
        default="topk",
        choices=['topk', 'ranking', 'success', 'algorithm'],
    )
    parser.add_argument("-k", "--topk", help="k for topk (1-10)", default=3, type=int)
    parser.add_argument("--cache", help="If activated, will cache dataset in memory", action='store_true')
    # store_false: args.no_jump is True (jumping knowledge enabled) unless --no-jump is given.
    parser.add_argument("--no-jump", help="Disable jumping knowledge (enabled by default)", action="store_false", default=True)
    parser.add_argument("train", nargs=1, help="The training set")
    parser.add_argument("val", nargs=1, help="The validation set")
    parser.add_argument("test", nargs=1, help="The test set")
    parser.add_argument("dataset", nargs=1, help="Directory housing the dataset")
    return parser


def load_json(path):
    """Load and return the JSON document at ``path``.

    Exits the process with status 1 (after printing an error) when the file
    does not exist. Other errors, such as malformed JSON, propagate unchanged
    so they are not misreported as a missing file.
    """
    try:
        with open(path, encoding="utf-8") as handle:
            return json.load(handle)
    except FileNotFoundError:
        print("Error:", path, "does not exists. Please input a valid file")
        raise SystemExit(1) from None


def compute_labels(files, weights=None, top=3):
    """Build ``(key, label)`` pairs from a loaded label file.

    For every key, each entry ``value`` at position ``idx`` gets a score of
    ``value[1]`` (+1 when ``value[0] == 1``). Only entries whose score is
    positive and which are among the ``top`` highest-scoring ones
    (by ``np.argsort``) are kept. Each kept entry contributes
    ``weights[idx] * (value[1] + value[0] - value[2])``, and the label is the
    element-wise sum of those contributions, returned as a list.

    Args:
        files: mapping of key -> list of entries. Each entry is assumed to be
            a sequence of at least 3 numbers (TODO confirm the schema).
        weights: weight matrix with one row per entry position; defaults to
            the module-level ``array1``.
        top: how many of the highest-scoring entries may contribute.

    Returns:
        List of ``(key, label)`` tuples in the mapping's iteration order. A key
        with no contributing entries gets a zero vector of ``weights.shape[1]``
        elements, so every label has the same length.
    """
    weights = np.asarray(array1 if weights is None else weights)
    labels = []
    for key, entries in files.items():
        scores = [v[1] + 1 if v[0] == 1 else v[1] for v in entries]
        # Compute the top-scoring positions once per key instead of once per entry.
        best = set(np.argsort(scores)[-top:].tolist())
        contributions = [
            weights[idx] * (value[1] + (value[0] - value[2]))
            for idx, value in enumerate(entries)
            if scores[idx] > 0 and idx in best
        ]
        if contributions:
            label = np.sum(contributions, axis=0).tolist()
        else:
            # np.sum([], axis=0) would give the scalar 0.0; keep the label shape consistent.
            label = [0] * weights.shape[1]
        labels.append((key, label))
    return labels


def main():
    """Train, evaluate and save one GNN model configured from the command line.

    Writes ``<time_ns>algRank-stra-test-tool.pt`` (the model state dict) and
    ``<time_ns>algRank-stra-test-tool.npz`` (training curves and test results)
    into the current working directory.
    """
    args = build_parser().parse_args()

    # Fail fast, before any dataset is loaded.
    if args.task != "topk":
        raise ValueError("Not a valid task")
    run_prefix = str(time.time_ns())

    train_labels = compute_labels(load_json(args.train[0]))
    val_labels = compute_labels(load_json(args.val[0]))
    test_labels = compute_labels(load_json(args.test[0]))

    # NOTE(review): the positional 4 presumably is the number of edge sets
    # (it matches numEdgeSets=4 below) — confirm against GeometricDataset2.
    train_set = GeometricDataset2(train_labels, args.dataset[0], 4, should_cache=args.cache)
    val_set = GeometricDataset2(val_labels, args.dataset[0], 4, should_cache=args.cache)
    test_set = GeometricDataset2(test_labels, args.dataset[0], 4, should_cache=args.cache)

    input_size = train_set[0][0].x.size(1)
    output_size = len(train_labels[0][1])

    if args.net == 'GGNN':
        # NOTE(review): the "GGNN" option builds EdgeAwareGraphSAGE0, not GGNN — confirm intended.
        model = EdgeAwareGraphSAGE0(passes=args.mp_layers, numEdgeSets=4, inputLayerSize=input_size,
                                    outputLayerSize=output_size, collate=args.mode)
    elif args.net == "GAT":
        model = GAT(passes=args.mp_layers, numEdgeSets=len(args.edge_sets), numAttentionLayers=5,
                    inputLayerSize=input_size, outputLayerSize=output_size, mode=args.mode, k=20,
                    shouldJump=args.no_jump, pool=args.pool_type)
    else:
        model = EGC(passes=args.mp_layers, inputLayerSize=input_size, outputLayerSize=output_size,
                    aggregators=args.aggregators, shouldJump=args.no_jump, pool=args.pool_type)
    model = model.to(device=args.gpu)

    loss_fn = torch.nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
    train_acc, train_loss, val_acc, val_loss = train_model(
        model=model, loss_fn=loss_fn, batchSize=1, trainset=train_set, valset=val_set,
        optimizer=optimizer, scheduler=scheduler, num_epochs=args.epochs, gpu=args.gpu,
        task=args.task, k=args.topk,
    )
    torch.save(model.state_dict(), run_prefix + "algRank-stra-test-tool.pt")

    results, choices, _predictions = evaluate4(model, test_set, files=[key for key, _ in test_labels], gpu=args.gpu)
    overallRes, overflowRes, reachSafetyRes, terminationRes, memSafetyRes = results
    overallChoices, overflowChoices, reachSafetyChoices, terminationChoices, memSafetyChoices = choices
    np.savez_compressed(
        run_prefix + "algRank-stra-test-tool.npz",
        train_acc=train_acc, train_loss=train_loss, val_acc=val_acc, val_loss=val_loss,
        overallRes=overallRes, overflowRes=overflowRes, reachSafetyRes=reachSafetyRes,
        terminationRes=terminationRes, memSafetyRes=memSafetyRes,
        overallChoices=overallChoices, overflowChoices=overflowChoices,
        reachSafetyChoices=reachSafetyChoices, terminationChoices=terminationChoices,
        memSafetyChoices=memSafetyChoices,
    )


if __name__ == '__main__':
    main()

