from gnn import GGNNtest6, EdgeAwareGraphSAGE,EdgeAwareGraphSAGE0
from utils.utils import evaluateAlg, GeometricDataset3,evaluate, GeometricDataset2,evaluate4
import torch
import json
import argparse
import numpy as np
from gnn import GGNN, GAT, EGC,GGNNtest,GGNNtest2,GGNNtest3,GGNNtest4,GGNNtest5,GGNNtest6
from utils.utils import ModifiedMarginRankingLoss, train_model, getCorrectProblemTypes, evaluate, GeometricDataset, groupLabels, getWeights,GeometricDataset2
import torch, json, time, argparse
import torch.optim as optim
import numpy as np
from torch_geometric.explain import GNNExplainer
from torch_geometric.nn import AttentionalAggregation
import torch_geometric
from torch_geometric.nn import GCNConv,GATConv

# Binary mask matrix: 10 rows x 18 columns of 0/1 integers.
# NOTE(review): the meaning of rows/columns is not stated anywhere visible.
# Presumably row i corresponds to the i-th entry of a test file's per-key
# result list, and each column to one slot of an aggregated label vector.
# Confirm before relying on it.
# Row 4 is all ones, i.e. it selects every column.
array=np.array([
    [0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1]
])

# Binary mask matrix: 10 rows x 6 columns of 0/1 integers.
# It has the same row count as `array`, but the column count is smaller.
# NOTE(review): presumably an alternative, coarser label grouping (6 output
# slots instead of 18). No uncommented code visible here references it.
# Confirm whether it is still needed.
# Row 4 is all ones, i.e. it selects every column.
array1=np.array([
    [1, 1, 0, 0, 0, 0],
    [0, 1, 0, 1, 0, 0],
    [0, 0, 0, 1, 0, 0],
    [1, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1],
    [1, 1, 0, 0, 1, 0],
    [0, 1, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 1],
    [0, 0, 1, 0, 0, 1],
    [0, 0, 1, 0, 1, 1]
])

def build_parser():
    """Return the command-line parser for this GNN evaluation script."""
    parser = argparse.ArgumentParser(description="GNN Evaluator")
    parser.add_argument("--network1", required=True, help="The network being evaluated")
    parser.add_argument("--test", required=True, nargs=1, help="The test set")
    # The dataset directory is always dereferenced (args.dataset[0]). It must
    # therefore be required; otherwise omitting it crashes later with an
    # opaque TypeError on None.
    parser.add_argument("--dataset", required=True, nargs=1, help="Directory housing the dataset")
    # NOTE(review): accepted for CLI compatibility, but the selection is not
    # currently forwarded to the evaluation.
    parser.add_argument(
        "-p", "--problem-types",
        help="Which problem types to consider:termination, overflow, reachSafety, memSafety (Default=All)",
        nargs="+",
        default=['termination', 'overflow', 'reachSafety', 'memSafety'],
        choices=['termination', 'overflow', 'reachSafety', 'memSafety'],
    )
    return parser


def load_test_files(path):
    """Load and return the JSON test-set description stored at *path*.

    The rest of the script treats the result as a mapping from a test-file
    key to a list of per-entry label records.

    Raises:
        FileNotFoundError: if *path* does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    # `with` guarantees the handle is closed; JSON is UTF-8 by specification.
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def build_test_labels(test_files):
    """Convert raw test records into ``(key, scores)`` pairs.

    For each key, every record contributes ``record[1]``. The value is
    incremented by one when ``record[0] == 1``. Records must have at least
    two elements.

    Args:
        test_files: mapping of key -> list of records, as returned by
            ``load_test_files``.

    Returns:
        A list of ``(key, scores)`` tuples in the mapping's iteration order.
    """
    return [
        (key, [record[1] + 1 if record[0] == 1 else record[1] for record in records])
        for key, records in test_files.items()
    ]


def main(argv=None):
    """Evaluate an EdgeAwareGraphSAGE0 checkpoint on a test set and save the results.

    The results returned by ``evaluate4`` are written to ``testAlg-CPG.npz``
    in the current working directory.

    Args:
        argv: optional argument list. When None, argparse reads ``sys.argv``.

    Raises:
        SystemExit: on invalid arguments, or on a missing, malformed or empty
            test file.
    """
    args = build_parser().parse_args(argv)
    test_path = args.test[0]

    # Load the test-set description and report a clean error instead of a traceback.
    try:
        test_files = load_test_files(test_path)
    except FileNotFoundError:
        raise SystemExit(f"Error: {test_path} does not exist. Please input a valid file.") from None
    except json.JSONDecodeError as err:
        raise SystemExit(f"Error: {test_path} is not valid JSON ({err}).") from None

    test_labels = build_test_labels(test_files)
    if not test_labels:
        # The model's input size is read from test_set[0] below, so an empty set cannot work.
        raise SystemExit(f"Error: {test_path} contains no test entries.")

    # NOTE(review): the literal 4 presumably must agree with numEdgeSets=4 below — confirm.
    test_set = GeometricDataset2(test_labels, args.dataset[0], 4, should_cache=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The architecture here must match the one used when the checkpoint was saved.
    model = EdgeAwareGraphSAGE0(
        passes=1,
        numEdgeSets=4,
        inputLayerSize=test_set[0][0].x.size(1),
        outputLayerSize=10,
        collate='max',
    ).to(device)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # host, matching the CPU fallback chosen for `device` above.
    model.load_state_dict(torch.load(args.network1, map_location=device))
    model.eval()

    (
        (overallRes, overflowRes, reachSafetyRes, terminationRes, memSafetyRes),
        (overallChoices, overflowChoices, reachSafetyChoices, terminationChoices, memSafetyChoices),
        _predicts,  # per-file predictions; currently not persisted
    ) = evaluate4(model, test_set, files=[key for key, _ in test_labels])

    np.savez_compressed(
        "testAlg-CPG.npz",
        overallRes=overallRes,
        overflowRes=overflowRes,
        reachSafetyRes=reachSafetyRes,
        terminationRes=terminationRes,
        memSafetyRes=memSafetyRes,
        overallChoices=overallChoices,
        overflowChoices=overflowChoices,
        reachSafetyChoices=reachSafetyChoices,
        terminationChoices=terminationChoices,
        memSafetyChoices=memSafetyChoices,
    )


if __name__ == '__main__':
    main()

