# import tensorflow as tf
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np


class Logger(object):
    """Thin wrapper around ``torch.utils.tensorboard.SummaryWriter``.

    Exposes ``*_summary`` logging methods and delegates all writing to a
    PyTorch ``SummaryWriter``. Call :meth:`close` (or use the instance as a
    context manager) when finished so buffered events reach disk.
    """

    def __init__(self, log_dir):
        """Create a summary writer logging to ``log_dir``."""
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar ``value`` under ``tag`` at global step ``step``."""
        self.writer.add_scalar(tag, value, global_step=step)

    def image_summary(self, tag, images, step):
        """Log a batch of images.

        Args:
            tag: Name under which the images appear in TensorBoard.
            images: Array-like (numpy array, tensor, or nested list) of shape
                (Batch x H x W x C), i.e. channels-last. A batch of one is
                logged as a single image; larger batches are tiled into one
                grid image by ``SummaryWriter.add_images``.
            step: Global step associated with the images.

        Pixel values are passed through unchanged (no rescaling is applied).
        NOTE(review): TensorBoard treats float images as lying in [0, 1];
        callers holding [-1, 1] data should rescale first — confirm.

        Raises:
            ValueError: if ``images`` is not 4-dimensional.
        """
        batch = torch.as_tensor(images)
        if batch.dim() != 4:
            raise ValueError(
                "image_summary expects images of shape (Batch, H, W, C), "
                "got shape {}".format(tuple(batch.shape)))
        if batch.shape[0] == 1:
            # Single image: add_image accepts any channel count (e.g. RGBA),
            # whereas add_images' grid tiling expects 1- or 3-channel input.
            self.writer.add_image("{}".format(tag), batch[0],
                                  global_step=step, dataformats="HWC")
        else:
            self.writer.add_images("{}".format(tag), batch,
                                   global_step=step, dataformats="NHWC")

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of ``values`` under ``tag`` at step ``step``.

        ``bins`` is forwarded unchanged to ``SummaryWriter.add_histogram``.
        """
        self.writer.add_histogram("{}".format(tag), values, bins=bins,
                                  global_step=step)

    def config_summary(self, config):
        """Log every attribute of ``config`` as a key/value table at step 0.

        Args:
            config: Any object supporting ``vars()`` (e.g. an
                ``argparse.Namespace``); attributes are listed sorted by name.
        """
        # add_text takes a single string (not a list of rows) and TensorBoard
        # renders it as markdown, so emit a two-column markdown table. Pipes
        # inside values are escaped so they do not split table cells.
        rows = ["| key | value |", "| --- | --- |"]
        rows.extend(
            "| {} | {} |".format(k, str(w).replace("|", "\\|"))
            for k, w in sorted(vars(config).items()))
        self.writer.add_text("config", "\n".join(rows), global_step=0)

    def flush(self):
        """Write any pending events to disk."""
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        self.writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always close the writer; never suppress an in-flight exception.
        self.close()
        return False
