try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

import torch
import torch.jit as jit


def draw_tb_graph(net, dir_path, batch_input=None):
    """Write the computation graph of ``net`` to a TensorBoard log directory.

    A fresh ``SummaryWriter`` is opened on ``dir_path``, the graph is added,
    and the writer is always closed afterwards, even if adding the graph fails.

    Args:
        net: Model handed to ``SummaryWriter.add_graph``; presumably a
            ``torch.nn.Module`` (confirm against callers).
        dir_path: Log directory passed to ``SummaryWriter``.
        batch_input: Example input forwarded to ``add_graph`` as its second
            positional argument; ``None`` (the default) is passed through as-is.

    Raises:
        Any exception raised by ``add_graph`` propagates to the caller after
        the writer has been closed.
    """
    log_writer = SummaryWriter(dir_path)
    try:
        log_writer.add_graph(net, batch_input)
    finally:
        # Close even when graph construction fails, so pending events are
        # flushed and the writer's resources are released rather than leaked.
        log_writer.close()


