import json
import torch
import torch._tensor

import sys 
from models.sg_model import SceneGraphModel
from utils.dataloader import VQARTaskIndexDataset
import pickle
import os
from models.sg_model import test_SceneGraphModel
from models.idx2word import Idx2Word
from arguments import parser


def _index_by_image_id(features_and_scene_graphs):
    """Return a dict mapping each record's ``"image_id"`` to the record itself.

    When two records share an ``image_id`` the later one wins (plain dict
    comprehension semantics).
    """
    return {record["image_id"]: record for record in features_and_scene_graphs}


def _build_model(args, meta_info):
    """Construct the SceneGraphModel and move it onto the selected device.

    ``args.gpu`` selects an explicit CUDA device. When it is ``None`` the model
    is moved to the default CUDA device if CUDA is available, and otherwise left
    on the CPU instead of crashing on ``.cuda()``.
    """
    sg_model = SceneGraphModel(
        feat_dim=args.feat_dim,
        meta_info=meta_info,
        model_dir=args.exp_dir,
    )
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        sg_model = sg_model.cuda(args.gpu)
    elif torch.cuda.is_available():
        # Assign the result so this works even if .cuda() is not in-place.
        sg_model = sg_model.cuda()
    else:
        print("CUDA is not available; running on CPU")
    return sg_model


def _load_test_data(args):
    """Load the test samples and per-image features, and build the test loader.

    Returns:
        A tuple ``(test_samples, features_by_image_id, test_loader)``.
    """
    # NOTE(security): pickle.load can execute arbitrary code. These paths come
    # from the command line and must point at trusted files only.
    with open(args.test_file, "rb") as samples_file:
        test_samples = pickle.load(samples_file)
    with open(args.test_features_file, "rb") as features_file:
        features_by_image_id = _index_by_image_id(pickle.load(features_file))

    test_dataset = VQARTaskIndexDataset(test_samples)
    # No custom sampler: with shuffle=False the DataLoader iterates in a fixed
    # sequential order, so evaluation runs are reproducible.
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )
    return test_samples, features_by_image_id, test_loader


def main():
    """Evaluate a SceneGraphModel on the test split and print its accuracies.

    All settings come from the command-line ``parser`` (``arguments`` module):
    experiment directory, meta-info JSON, optional GPU id, test sample and
    feature pickles, batch size and number of loader workers.
    """
    args = parser.parse_args()
    print(args)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.exp_dir, exist_ok=True)
    if args.gpu is not None:
        print("Use GPU: {} for testing".format(args.gpu))

    # Create model
    with open(args.meta_f, "r", encoding="utf-8") as meta_file:
        meta_info = json.load(meta_file)
    idx2word = Idx2Word(meta_info)
    sg_model = _build_model(args, meta_info)

    # Load the testing dataset
    test_samples, features_by_image_id, test_loader = _load_test_data(args)

    name_accuracy, attr_accuracy, rela_accuracy, accuracy = test_SceneGraphModel(
        sg_model, test_loader, test_samples, features_by_image_id, idx2word
    )
    print(
        f"Accuracy of model in {args.exp_dir} is {accuracy}: Name Acc {name_accuracy}, Attr Acc {attr_accuracy}, Rela Acc {rela_accuracy}"
    )


if __name__ == "__main__":
    main()
