import sys
import json
import torch
import random
import logging

import numpy as np


def pad_graph(G, max_nodes):
    """Add placeholder nodes to ``G`` in place until it has ``max_nodes`` nodes.

    Placeholder nodes carry ``m_idx=0`` and ``t_idx=0``. The attribute meanings
    are defined by callers; presumably index 0 is a padding/null entry, so
    verify against the caller.

    New ids start at the current node count. When that id is already taken
    (the graph's ids are not contiguous ``0..n-1``), the next free integer is
    used. Reusing an existing id would only update that node's attributes and
    never grow the graph, which would loop forever.

    Args:
        G: graph object exposing ``.nodes`` (sized and supporting ``in``) and
            ``add_node(id, **attrs)``. Presumably a networkx graph; confirm.
        max_nodes: target node count. If ``G`` already has at least this
            many nodes, it is left unchanged.

    Returns:
        The same ``G`` object, mutated in place.
    """
    next_id = len(G.nodes)
    while len(G.nodes) < max_nodes:
        # Skip ids that already exist so every add_node call adds a node.
        while next_id in G.nodes:
            next_id += 1
        G.add_node(next_id, m_idx=0, t_idx=0)
        next_id += 1
    return G


def get_ori_task_types(
    merged_list,
    path="data/testdev_balanced_questions.json",
    max_questions=10101,
):
    """Copy the ``types`` field from a GQA question file onto matching items.

    An item in ``merged_list`` matches a question entry when both its
    ``imageId`` and ``question`` are equal to the entry's. Every matching item
    gets ``item["types"] = entry["types"]``. If several entries match the same
    item, the last one read wins.

    Args:
        merged_list: list of dicts, each with ``"imageId"`` and ``"question"``
            keys. The dicts are mutated in place.
        path: JSON file mapping question ids to dicts with ``"imageId"``,
            ``"question"`` and ``"types"`` keys.
        max_questions: only the first ``max_questions`` entries of the file
            are consulted. The default reproduces the original hard-coded
            cutoff as an early exit.

    Returns:
        ``merged_list`` itself.
    """
    from itertools import islice

    with open(path, "r", encoding="utf-8") as file:
        gqa = json.load(file)

    # Index items by (imageId, question) once instead of rescanning the whole
    # list for every question entry (O(N + M) rather than O(N * M)).
    index = {}
    for item in merged_list:
        index.setdefault((item["imageId"], item["question"]), []).append(item)

    for v in islice(gqa.values(), max_questions):
        imageId, question, types = v["imageId"], v["question"], v["types"]
        for item in index.get((imageId, question), ()):
            item["types"] = types
    return merged_list


def print_and_record(output_file, content):
    """Print ``content`` to stdout and append the same line to ``output_file``.

    The file write passes ``file=`` to ``print`` instead of temporarily
    replacing ``sys.stdout``. The old swap was not exception-safe: if
    ``print`` raised, ``sys.stdout`` stayed pointed at a closed file.

    Args:
        output_file: path of the file to append to; it is created if missing.
        content: any object; it is formatted exactly as ``print`` would.
    """
    print(content)
    with open(output_file, "a", encoding="utf-8") as f:
        print(content, file=f)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes formatted records to ``filename`` and stderr.

    ``filename`` is opened in ``"w"`` mode, so any existing content is
    truncated. Calling this again for the same ``name`` replaces the handlers
    a previous call installed; they are closed, not stacked. Without this,
    every message would be emitted once per call, and the old log files
    would stay open. Handlers added by other code are left untouched.

    Args:
        filename: path of the log file.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter("[%(asctime)s][%(levelname)s] %(message)s")
    # Handler names used to recognise handlers installed by this function.
    file_handler_name = "get_logger.file"
    stream_handler_name = "get_logger.stream"

    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    for handler in list(logger.handlers):
        if handler.get_name() in (file_handler_name, stream_handler_name):
            logger.removeHandler(handler)
            handler.close()

    fh = logging.FileHandler(filename, "w")
    fh.set_name(file_handler_name)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.set_name(stream_handler_name)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger
