# 2019/01/21~2018/07/12
# This code is taken almost verbatim from https://github.com/amaiasalvador,
# and all credit should go to Amaia Salvador.

import os
import glob
import torchvision.utils as vutils
from operator import itemgetter
from tensorboardX import SummaryWriter

class Visualizer:
    """Thin wrapper around a tensorboardX ``SummaryWriter``.

    On construction, any ``events*`` files already present directly in
    ``checkpoints_dir`` are deleted, then a new writer is opened on that
    directory, so every run starts from a fresh event log.

    Args:
        checkpoints_dir: directory that receives the event files and the
            scalar JSON export written by ``scalar_summary``.
        name: stored as ``self.name``; not otherwise used by this class.
    """

    def __init__(self, checkpoints_dir, name):
        # win_size, name and ncols are not read by any method of this class;
        # they are kept as attributes for external callers.
        self.win_size = 256
        self.name = name
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4

        # Remove event files left over from a previous run. glob.escape stops
        # a directory name containing '*', '?' or '[' from being interpreted
        # as a wildcard (which could miss files or match unrelated ones).
        pattern = os.path.join(glob.escape(checkpoints_dir), 'events*')
        for filename in glob.glob(pattern):
            os.remove(filename)
        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        """Set ``self.saved`` back to ``False``."""
        self.saved = False

    def image_summary(self, mode, epoch, images):
        """Log a batch of images as a single grid under ``'<mode>/Image'``.

        ``images`` is handed to ``torchvision.utils.make_grid``, which expects
        a (B, C, H, W) tensor or a list of equally sized (C, H, W) images.
        With ``normalize=True, scale_each=True`` each image is min-max scaled
        independently before the grid is built.
        """
        grid = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), grid, epoch)

    def figure_summary(self, mode, epoch, fig):
        """Log a matplotlib figure under ``'<mode>/Figure'``."""
        self.writer.add_figure('{}/Figure'.format(mode), fig, epoch)

    @staticmethod
    def _ids_to_words(el, vocabulary, gt):
        """Return the vocabulary entries encoded by one element of ``text``.

        If ``gt`` is true, ``el`` is a sequence (list or tensor) of vocabulary
        indices. Otherwise ``el`` must be a tensor: the positions of its
        nonzero entries, shifted by +1, are used as indices (presumably
        position ``k`` stands for vocabulary entry ``k + 1`` -- confirm
        against the caller).

        Always returns a list, including for zero or one index.
        """
        if gt:
            idx = el
        else:
            # reshape(-1) instead of squeeze(): squeeze() turns a single hit
            # into a 0-d tensor, which cannot be iterated.
            idx = el.nonzero().reshape(-1) + 1
        # Plain Python ints index both lists and int-keyed dicts; tensor
        # elements would fail as dict keys.
        ids = idx.tolist() if hasattr(idx, 'tolist') else list(idx)
        return [vocabulary[i] for i in ids]

    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        """Log one text entry per element of ``text``.

        Each element is decoded to a list of words with ``_ids_to_words``.
        ``'<pad>'`` words are dropped and the rest are joined with ``', '``.
        The result is logged under ``'<mode>/<type>_<i>_gt'`` (when ``gt``)
        or ``'<mode>/<type>_<i>_prediction'``. If the decoded list, pads
        included, is longer than ``max_length``, a message reporting its
        length is logged instead.
        """
        suffix = 'gt' if gt else 'prediction'
        for i, el in enumerate(text):
            words = self._ids_to_words(el, vocabulary, gt)
            tag = '{}/{}_{}_{}'.format(mode, type, i, suffix)
            if len(words) <= max_length:
                body = ', '.join(w for w in words if w != '<pad>')
            else:
                body = 'Number of sampled ingredients is too big: {}'.format(len(words))
            self.writer.add_text(tag, body, epoch)

    def scalar_summary(self, mode, epoch, **scalars):
        """Log each keyword argument as a scalar under ``'<mode>/<key>'``.

        Afterwards, every scalar recorded so far by the writer is exported to
        ``<checkpoints_dir>/tensorboard_all_scalars.json``. The export is
        regenerated on every call.
        """
        for key, value in scalars.items():
            self.writer.add_scalar('{}/{}'.format(mode, key), value, epoch)

        self.writer.export_scalars_to_json(
            "{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))

    def histo_summary(self, model, step):
        """Log one histogram per parameter of ``model``.

        Every entry of ``model.named_parameters()`` is logged under its
        parameter name at ``step``.
        """
        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        """Close the underlying ``SummaryWriter``."""
        self.writer.close()
