from gnn import GGNNtest6,toolSelector,EdgeAwareGraphSAGE0,EdgeAwareGraphSAGE
from utils.utils import evaluate2, evaluate3, GeometricDataset2, GeometricDataset3
from gnn import GGNN, GAT, EGC,GGNNtest,GGNNtest2,GGNNtest3,GGNNtest4,GGNNtest5,GGNNtest6
from utils.utils import ModifiedMarginRankingLoss, train_model, getCorrectProblemTypes, evaluate, GeometricDataset, groupLabels, getWeights,GeometricDataset2
import torch, json, time, argparse
import torch.optim as optim
import numpy as np
from torch_geometric.explain import GNNExplainer
from torch_geometric.nn import AttentionalAggregation
import torch_geometric
from torch_geometric.nn import GCNConv,GATConv

# Binary lookup matrix with 10 rows and 18 columns (every cell is 0 or 1).
# The label-building code in this file indexes it by row, using the position
# of an entry inside a test file's result list, so row i belongs to the i-th
# entry of every test file.  Columns set to 1 in a row are the ones that
# receive that entry's score/penalty when the label vectors are summed.
# NOTE(review): presumably rows are verification tools and columns are the
# components/features each tool relies on — confirm against the training code.
array = np.array([
    [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1]
])

def adjusted_score(entry):
    """Return the score of one result entry.

    ``entry`` is a ``[status, value]`` pair read from the test-set JSON.
    The value is returned unchanged unless the status code is ``1``, in
    which case ``1`` is added to it.
    """
    status, value = entry[0], entry[1]
    return value + 1 if status == 1 else value


def negative_penalty(entry):
    """Return the non-negative penalty of one ``[status, value]`` entry.

    Status ``-16`` yields ``abs(value + 15)``, status ``-32`` yields
    ``abs(value + 30)``, status ``0`` yields ``abs(value)``; every other
    status yields ``0``.
    NOTE(review): the status codes look like SV-COMP style verdict scores —
    confirm against the code that produces the JSON.
    """
    status, value = entry[0], entry[1]
    if status == -16:
        return abs(value + 15)
    if status == -32:
        return abs(value + 30)
    if status == 0:
        return abs(value)
    return 0


def _zero_label(feature_matrix, weights):
    """Return an all-zero label as wide as one row of ``feature_matrix``.

    Used when there is nothing to sum.  ``np.sum([], axis=0)`` would return
    the scalar ``0.0`` instead of a vector, which breaks every later
    ``len(label)`` call.  The dtype follows NumPy promotion of the matrix and
    the weights so the element type stays consistent with regular labels.
    """
    dtype = np.result_type(feature_matrix.dtype, np.asarray(weights).dtype)
    return np.zeros(feature_matrix.shape[1], dtype=dtype).tolist()


def build_positive_labels(entries, feature_matrix):
    """Build the positive label vector of one test file.

    Each entry's adjusted score weights the matching row of
    ``feature_matrix``; only entries whose score is strictly positive AND
    that rank among the three highest scores of the file contribute.  The
    weighted rows are summed column-wise.

    Args:
        entries: list of ``[status, value]`` pairs; position ``i`` selects
            row ``i`` of ``feature_matrix``.
        feature_matrix: 2-D NumPy array with at least ``len(entries)`` rows.

    Returns:
        A plain list with one element per matrix column (all zeros when no
        entry qualifies).
    """
    scores = [adjusted_score(entry) for entry in entries]
    # Rank once per file instead of once per entry.
    top_indices = set(np.argsort(scores)[-3:].tolist())
    weighted_rows = [
        feature_matrix[idx] * score
        for idx, score in enumerate(scores)
        if score > 0 and idx in top_indices
    ]
    if not weighted_rows:
        return _zero_label(feature_matrix, scores)
    return np.sum(weighted_rows, axis=0).tolist()


def build_negative_labels(entries, feature_matrix):
    """Build the negative label vector of one test file.

    For every entry, the cells of its matrix row that equal ``1`` are
    replaced by the entry's penalty (see ``negative_penalty``); all other
    cells keep their matrix value.  The resulting rows are summed
    column-wise.

    Args:
        entries: list of ``[status, value]`` pairs; position ``i`` selects
            row ``i`` of ``feature_matrix``.
        feature_matrix: 2-D NumPy array with at least ``len(entries)`` rows.

    Returns:
        A plain list with one element per matrix column (all zeros when
        ``entries`` is empty).
    """
    penalties = [negative_penalty(entry) for entry in entries]
    penalised_rows = [
        np.where(feature_matrix[idx] == 1, penalty, feature_matrix[idx])
        for idx, penalty in enumerate(penalties)
    ]
    if not penalised_rows:
        return _zero_label(feature_matrix, penalties)
    return np.sum(penalised_rows, axis=0).tolist()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="GNN Evaluator")
    parser.add_argument("--networkT", required=True,
                        help="Saved state_dict of the toolSelector model")
    parser.add_argument("--network1", required=True,
                        help="Saved state_dict of the EdgeAwareGraphSAGE model")
    parser.add_argument("--network0", required=True,
                        help="Saved state_dict of the EdgeAwareGraphSAGE0 model")
    parser.add_argument("--test", required=True, nargs=1, help="The test set")
    # The dataset directory is always dereferenced below, so it must be given.
    parser.add_argument("--dataset", required=True, nargs=1, help="Directory housing the dataset")
    parser.add_argument(
        "-p", "--problem-types",
        help="Which problem types to consider:termination, overflow, reachSafety, memSafety (Default=All)",
        nargs="+",
        default=['termination', 'overflow', 'reachSafety', 'memSafety'],
        choices=['termination', 'overflow', 'reachSafety', 'memSafety'],
    )
    parser.add_argument("--output", default="Alg-Robust-change.npz",
                        help="Path of the compressed .npz result file (Default=Alg-Robust-change.npz)")

    args = parser.parse_args()

    # Load the test file: a JSON object mapping each file key to its list of
    # [status, value] result entries.
    try:
        with open(args.test[0], encoding="utf-8") as test_file:
            testFiles = json.load(test_file)
    except FileNotFoundError:
        print("Error:", args.test[0], "does not exist. Please input a valid file.")
        raise SystemExit(1)

    if not testFiles:
        # The model sizes below are derived from the first label.
        print("Error:", args.test[0], "contains no test entries.")
        raise SystemExit(1)

    testLabels1 = [(key, build_positive_labels(entries, array)) for key, entries in testFiles.items()]
    testLabels0 = [(key, build_negative_labels(entries, array)) for key, entries in testFiles.items()]
    testLabels = [(key, [adjusted_score(entry) for entry in entries]) for key, entries in testFiles.items()]

    # Create the dataset.
    test_set = GeometricDataset3(testLabels1, testLabels0, testLabels, args.dataset[0], 4, should_cache=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: torch.load unpickles its input; only load checkpoints you trust.
    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    Tmodel = toolSelector(inputLayerSize=len(testLabels1[0][1]) * 2,
                          hidden_dim=64,
                          outputLayerSize=len(testLabels[0][1])
                          ).to(device)
    Tmodel.load_state_dict(torch.load(args.networkT, map_location=device))
    Tmodel.eval()

    # Both graph models take the node feature width of the first sample; the
    # architecture arguments must match the ones used when the weights were saved.
    node_feature_dim = test_set[0][0].x.size(1)

    Pmodel = EdgeAwareGraphSAGE(
        passes=1,
        numEdgeSets=4,
        inputLayerSize=node_feature_dim,
        outputLayerSize=len(testLabels1[0][1]),
        collate='max'
    ).to(device)
    Pmodel.load_state_dict(torch.load(args.network1, map_location=device))
    Pmodel.eval()  # switch the model to evaluation mode

    Nmodel = EdgeAwareGraphSAGE0(
        passes=1,
        numEdgeSets=4,
        inputLayerSize=node_feature_dim,
        outputLayerSize=len(testLabels0[0][1]),
        collate='max'
    ).to(device)
    Nmodel.load_state_dict(torch.load(args.network0, map_location=device))
    Nmodel.eval()

    # Run the evaluation.
    (overallRes, overflowRes, reachSafetyRes, terminationRes, memSafetyRes), \
        (overallChoices, overflowChoices, reachSafetyChoices, terminationChoices, memSafetyChoices), \
        predicts = evaluate2(Tmodel, Pmodel, Nmodel, test_set, files=[x[0] for x in testLabels])

    # Persist the results.
    np.savez_compressed(args.output, overallRes=overallRes, overflowRes=overflowRes,
                        reachSafetyRes=reachSafetyRes, terminationRes=terminationRes, memSafetyRes=memSafetyRes,
                        overallChoices=overallChoices, overflowChoices=overflowChoices,
                        reachSafetyChoices=reachSafetyChoices, terminationChoices=terminationChoices,
                        memSafetyChoices=memSafetyChoices)

