from gnn import GGNN, GAT, EGC, GGNNtest, GGNNtest2, GGNNtest3, GGNNtest4, GGNNtest5, GGNNtest6,GGNNtest7, toolSelector, StrategySelector,EdgeAwareGraphSAGE,EdgeAwareGraphSAGE0
from utils.utils import GeometricDataset3, train_modelTool, evaluateTool,evaluate2
import torch, json, time, argparse
import torch.optim as optim
import numpy as np
from torch_geometric.nn import AttentionalAggregation
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv

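'''
File - netTrainer.py
Driver script that trains the toolSelector network on top of two frozen,
pre-trained EdgeAwareGraphSAGE label models (a positive and a negative one).
'''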

# Constant 10 x 18 matrix of 0/1 flags.
# In the __main__ section below, row `idx` is paired with the idx-th entry of
# every file's label list. That row is scaled by the entry's score
# (positive labels) or masked with its penalty (negative labels), and the
# selected rows are summed column-wise into an 18-wide label vector.
# NOTE(review): presumably rows are algorithm groups and columns are
# individual tools (row i marks which tools belong to group i) -- TODO confirm.
array = np.array([
    [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1]
])

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="GNN Trainer")
    parser.add_argument("--mp-layers", help="Number of message passing layers (Default=0)", default=0, type=int)
    parser.add_argument("-e", "--epochs", help="Number of training epochs (Default=20)", default=20, type=int)
    parser.add_argument("--edge-sets", help="Which edges sets to include: AST, CFG, Data (Default=All)", nargs='+',
                        default=['AST', 'Data', "ICFG"], choices=['AST', 'Data', "ICFG"])
    parser.add_argument("-p", "--problem-types",
                        help="Which problem types to consider:termination, overflow, reachSafety, memSafety (Default=All)",
                        nargs="+", default=['termination', 'overflow', 'reachSafety', 'memSafety'],
                        choices=['termination', 'overflow', 'reachSafety', 'memSafety'])
    parser.add_argument('-n', '--net', help="GGNN, GAT, EGC", default="EGC", choices=["GGNN", "GAT", "EGC"])
    parser.add_argument("-m", "--mode", help="Mode for jumping (Default LSTM): max, cat, lstm", default="cat",
                        choices=['max', 'cat', 'lstm'])
    parser.add_argument("--pool-type",
                        help="How to pool Nodes (max, mean, add, attention, power, softmax, equilibrium)",
                        default=["attention"],
                        choices=["max", 'min', "mean", "add", "attention", "power", "softmax", "equilibrium"],
                        nargs='+')
    parser.add_argument("--aggregators",
                        help="How to pool Nodes (max, mean, add, attention, power, softmax, equilibrium)",
                        default=["attention"], choices=["max", 'min', "mean", "symnorm", "std"], nargs='+')
    parser.add_argument("-g", "--gpu", help="Which GPU should the model be on", default=0, type=int)
    parser.add_argument("--task", help="Which task are you training for (topK, rank, success)?", default="topk",
                        choices=['topk'])
    parser.add_argument("-k", "--topk", help="k for topk (1-10)", default=3, type=int)
    parser.add_argument("--cache", help="If activated, will cache dataset in memory", action='store_true')
    parser.add_argument("--alg", help="If activate, will look at algorithm groups instead of tools",
                        action="store_true")
    parser.add_argument("--no-jump", help="Whether or not to use jumping knowledge", action="store_false", default=True)
    parser.add_argument("train", nargs=1, help="The training set")
    parser.add_argument("val", nargs=1, help="The validation set")
    parser.add_argument("test", nargs=1, help="The test set")
    parser.add_argument("dataset", nargs=1, help="Directory housing the dataset")
    parser.add_argument("--network1", required=True, help="The network being evaluated")
    parser.add_argument("--network0", required=True, help="The network being evaluated")
    parser.add_argument("--train_expand", required=True, help="The training set")
    parser.add_argument("--val_expand", required=True, help="The validation set")
    parser.add_argument("--test_expand", required=True, help="The test set")

    args = parser.parse_args()

    try:
        trainFiles = json.load(open(args.train[0]))
        with open(args.train_expand, 'r') as f:
            trainFiles_expand = json.load(f)
    except FileNotFoundError:
        print("Error:", args.train[0], "does not exists. Please input a valid file")
        exit(1)
    trainLabels1 = [(key, np.sum(
        [array[idx] * (value[1] + 1) if value[0] == 1 else array[idx] * value[1]
         for idx, value in enumerate(trainFiles[key]) if (value[1] + 1 if value[0] == 1 else value[1]) > 0 and idx in np.argsort([v[1] + 1 if v[0] == 1 else v[1] for v in trainFiles[key]])[-3:]], axis=0).tolist()) for key in trainFiles]
    trainLabels0 = [(key, np.sum(
        [np.where(array[idx] == 1, abs(value[1] + 15) if value[0] == -16 else abs(value[1] + 30) if value[0] == -32 else abs(value[1]) if value[1] <= 0 else 0, array[idx])
         for idx, value in enumerate(trainFiles[key])], axis=0).tolist()) for key in trainFiles]
    trainLabels_true = [(key, [item[1] + 1 if item[0] == 1 else item[1] for item in trainFiles_expand[key]]) for key in trainFiles_expand]
    #	trainLabels = getCorrectProblemTypes(trainLabels, args.problem_types)

    try:
        valFiles = json.load(open(args.val[0]))
        with open(args.val_expand, 'r') as f:
            valFiles_expand = json.load(f)
    except:
        print("Error:", args.val[0], "does not exists. Please input a valid file")
        exit(1)
    valLabels1 = [(key, np.sum(
        [array[idx] * (value[1]+1) if value[0] == 1 else array[idx] * value[1] 
         for idx, value in enumerate(valFiles[key]) if (value[1] + 1 if value[0] == 1 else value[1]) > 0 and idx in np.argsort([v[1] + 1 if v[0] == 1 else v[1] for v in valFiles[key]])[-3:]], axis=0).tolist()) for key in valFiles]
    valLabels0 = [(key, np.sum(
        [np.where(array[idx] == 1, abs(value[1] + 15) if value[0] == -16 else abs(value[1] + 30) if value[0] == -32 else abs(value[1]) if value[1] <= 0 else 0, array[idx])
         for idx, value in enumerate(valFiles[key])], axis=0).tolist()) for key in valFiles]
    valLabels_true = [(key, [item[1] + 1 if item[0] == 1 else item[1] for item in valFiles_expand[key]]) for key in valFiles_expand]
    #	valLabels = getCorrectProblemTypes(valLabels, args.problem_types)

    try:
        testFiles = json.load(open(args.test[0]))
        with open(args.test_expand, 'r') as f:
            testFiles_expand = json.load(f)
    except:
        print("Error:", args.test[0], "does not exists. Please input a valid file")
        exit(1)
    testLabels1 = [(key, np.sum(
        [array[idx] * (value[1] + 1) if value[0]  == 1 else array[idx] * value[1]
         for idx, value in enumerate(testFiles[key]) if (value[1] + 1 if value[0] == 1 else value[1]) > 0 and idx in np.argsort([v[1] + 1 if v[0] == 1 else v[1] for v in testFiles[key]])[-3:]], axis=0).tolist()) for key in testFiles]
    testLabels0 = [(key, np.sum(
        [np.where(array[idx] == 1, abs(value[1] + 15) if value[0] == -16 else abs(value[1] + 30) if value[0] == -32 else abs(value[1]) if value[1] <= 0 else 0, array[idx])
         for idx, value in enumerate(testFiles[key])], axis=0).tolist()) for key in testFiles]
    testLabels_true = [(key, [item[1] + 1 if item[0] == 1 else item[1] for item in testFiles_expand[key]]) for key in testFiles_expand]
    #	testLabels = getCorrectProblemTypes(testLabels, args.problem_types)

    train_set = GeometricDataset3(trainLabels1, trainLabels0, trainLabels_true, args.dataset[0], 4, should_cache=args.cache)
    val_set = GeometricDataset3(valLabels1, valLabels0, valLabels_true, args.dataset[0], 4, should_cache=args.cache)
    test_set = GeometricDataset3(testLabels1, testLabels0, testLabels_true, args.dataset[0], 4, should_cache=args.cache)
    print(test_set[0][3].size(0))
    print(len(testLabels_true[0][1]))

    # 创建模型实例
    algPositiveModel = EdgeAwareGraphSAGE(
        passes=1,
        numEdgeSets=4,
        inputLayerSize=test_set[0][0].x.size(1),  # 确保这里的结构与加载时一致
        outputLayerSize=len(testLabels1[0][1]),
        collate= 'max'
    ).to(device=args.gpu)

    for param in algPositiveModel.parameters():
        param.requires_grad = False
    
    # 加载模型权重
    algPositiveModel.load_state_dict(torch.load(args.network1))
    
    algNegativeModel = EdgeAwareGraphSAGE0(
        passes=1,
        numEdgeSets=4,
        inputLayerSize=test_set[0][0].x.size(1),  # 确~]~Y~G~L~Z~D~S~^~D~N~J| 载~W~@~G
        outputLayerSize=len(testLabels0[0][1]),
        collate= 'max'
    ).to(device=args.gpu)

    for param in algNegativeModel.parameters():
        param.requires_grad = False

    # ~J| 载模~^~K~]~C~G~M
    algNegativeModel.load_state_dict(torch.load(args.network0))
    
    model = toolSelector(inputLayerSize=len(testLabels1[0][1]) * 2,
                         hidden_dim=64,
                         outputLayerSize=len(testLabels_true[0][1])
                         ).to(device=args.gpu)

    if args.task == "topk":
        loss_fn = torch.nn.NLLLoss(reduction = 'sum')
    else:
        raise ValueError("Not a valid task")
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
    report = train_modelTool(model=model, algPositiveModel=algPositiveModel, algNegativeModel=algNegativeModel,loss_fn=loss_fn, batchSize=1, trainset=train_set,
                             valset=val_set,
                             optimizer=optimizer, scheduler=scheduler, num_epochs=args.epochs, gpu=args.gpu,
                             task=args.task,
                             k=args.topk)
    returnString = str(time.time_ns())
    torch.save(model.state_dict(), returnString + "tool-ROBUST-SSS-GraphSAGE-CPG-20.pt")
    train_loss, val_loss = report
    (overallRes, overflowRes, reachSafetyRes, terminationRes, memSafetyRes), (
        overallChoices, overflowChoices, reachSafetyChoices, terminationChoices,
        memSafetyChoices), predicts = evaluate2(
        model, algPositiveModel, algNegativeModel, test_set, files=[x[0] for x in testLabels_true], gpu=args.gpu)

#    returnString = str(time.time_ns())

    np.savez_compressed(returnString + "tool-ROBUST-SSS-GraphSAGE-CPG-20.npz", train_loss=train_loss, val_loss=val_loss,
                        overallRes=overallRes, overflowRes=overflowRes,
                        reachSafetyRes=reachSafetyRes, terminationRes=terminationRes, memSafetyRes=memSafetyRes,
                        overallChoices=overallChoices, overflowChoices=overflowChoices,
                        reachSafetyChoices=reachSafetyChoices, terminationChoices=terminationChoices,
                        memSafetyChoices=memSafetyChoices)
#    torch.save(model.state_dict(), returnString + "tool-ROBUST-SSS-GraphSAGE-expand-experi-15.pt")
#    json.dump(predicts, open(returnString + "tool-ROBUST-SSS-GraphSAGE-experi.json", 'w'))

