import time
from collections import OrderedDict
import random
import io
import pprint

import numpy as np

import absl.flags
from absl import logging

import torch
import tensorflow as tf

from PIL import Image



class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``d.key`` and ``d['key']`` address the same storage, because the
    instance namespace is aliased to the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        # Accept exactly the same constructor arguments as ``dict``.
        super().__init__(*args, **kwargs)
        # Point attribute storage at the dict so both access styles share data.
        self.__dict__ = self


def define_flags_with_default(**kwargs):
    """Register one absl flag per keyword argument, typed by its default value.

    Each ``name=default`` pair becomes a flag called ``name`` whose type is
    chosen from ``type(default)``: bool, int, float or str.

    Returns:
        The ``kwargs`` dict itself, unchanged.

    Raises:
        ValueError: if a default is not a bool, int, float or str. Flags for
            keyword arguments processed earlier have already been defined.
    """
    for flag_name, default in kwargs.items():
        # bool has to be tested before int: True and False are ints as well.
        if isinstance(default, bool):
            define = absl.flags.DEFINE_bool
        elif isinstance(default, int):
            define = absl.flags.DEFINE_integer
        elif isinstance(default, float):
            define = absl.flags.DEFINE_float
        elif isinstance(default, str):
            define = absl.flags.DEFINE_string
        else:
            raise ValueError('Incorrect value type')
        define(flag_name, default, 'automatically defined flag')
    return kwargs


def set_random_seed(seed):
    """Seed every RNG used by the project with the same ``seed``.

    Covers NumPy's global generator, PyTorch on all CUDA devices, PyTorch's
    default generator, and Python's ``random`` module, in that order.
    """
    seeders = (
        np.random.seed,
        torch.cuda.manual_seed_all,
        torch.manual_seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)


class TensorBoardLogger(object):
    """Logging to TensorBoard outside of TensorFlow ops.

    Builds ``tf.compat.v1.Summary`` protos by hand and writes them through a
    ``tf.compat.v1.summary.FileWriter``, so non-TF code can emit scalars and
    images.

    NOTE(review): ``tf.compat.v1.summary.FileWriter`` is a TF1-style API;
    confirm it can be constructed in the TF configuration this runs under
    (e.g. whether eager execution must be disabled).
    """

    def __init__(self, output_dir):
        """Ensure ``output_dir`` exists and open a summary writer on it.

        Args:
            output_dir: directory (any path ``tf.io.gfile`` understands) that
                receives the event files.
        """
        if not tf.io.gfile.exists(output_dir):
            # ``tf.gfile`` is gone in TF2; use the same ``tf.io.gfile`` API as
            # the existence check so this works on every version it supports.
            tf.io.gfile.makedirs(output_dir)
        self.file_writer = tf.compat.v1.summary.FileWriter(output_dir)

    def log_scaler(self, step, name, value):
        """Write a single scalar summary.

        The method name is a historical misspelling of "scalar"; it is kept
        for existing callers, and ``log_scalar`` is provided as an alias.

        Args:
            step: global step recorded with the event.
            name: summary tag.
            value: anything ``float()`` accepts (Python/numpy number, 0-d
                array or tensor).
        """
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=name, simple_value=float(value))]
        )
        self.file_writer.add_summary(summary, step)

    # Correctly spelled alias; ``log_scaler`` stays for backward compatibility.
    log_scalar = log_scaler

    def log_image(self, step, name, image):
        """Write a single image summary.

        Args:
            step: global step recorded with the event.
            name: summary tag.
            image: image array accepted by ``_make_image``.
        """
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(
                tag=name,
                image=self._make_image(image)
            )]
        )
        self.file_writer.add_summary(summary, step)

    def log_images(self, step, data):
        """Write several images as one summary event.

        Args:
            step: global step recorded with the event.
            data: mapping of tag -> image array. Entries whose image is
                ``None`` are skipped; nothing is written for an empty mapping.
        """
        if not data:
            return
        summary = tf.compat.v1.Summary(
            value=[
                tf.compat.v1.Summary.Value(tag=name, image=self._make_image(image))
                for name, image in data.items()
                if image is not None
            ]
        )
        self.file_writer.add_summary(summary, step)

    def _make_image(self, tensor):
        """Encode an image array as PNG and wrap it in a ``Summary.Image``.

        Args:
            tensor: array-like of shape ``(H, W)`` or ``(H, W, C)``. Assumes a
                dtype/channel count that ``PIL.Image.fromarray`` supports
                (e.g. uint8 with C in {1, 3, 4}); TODO confirm with callers.

        Returns:
            ``tf.compat.v1.Summary.Image`` whose ``colorspace`` is the channel
            count (1 for 2-D input).
        """
        tensor = np.asarray(tensor)
        if tensor.ndim == 2:
            # Plain grayscale image without an explicit channel axis.
            height, width = tensor.shape
            channel = 1
        else:
            height, width, channel = tensor.shape
            if channel == 1:
                # PIL's fromarray does not accept (H, W, 1); drop the axis.
                tensor = tensor[:, :, 0]
        image = Image.fromarray(tensor)
        with io.BytesIO() as output:
            image.save(output, format='PNG')
            image_string = output.getvalue()
        return tf.compat.v1.Summary.Image(
            height=height,
            width=width,
            colorspace=channel,
            encoded_image_string=image_string
        )

    def log_dict(self, step, data):
        """Write every scalar in ``data`` as one summary event.

        Args:
            step: global step recorded with the event.
            data: mapping of tag -> scalar; ``None`` values are skipped and
                the rest are converted with ``float()``.
        """
        summary = tf.compat.v1.Summary(
            value=[
                tf.compat.v1.Summary.Value(tag=name, simple_value=float(value))
                for name, value in data.items()
                if value is not None
            ]
        )
        self.file_writer.add_summary(summary, step)

    def flush(self):
        """Flush pending events to disk."""
        self.file_writer.flush()

    def close(self):
        """Flush pending events and close the underlying writer."""
        self.file_writer.close()



def print_flags(flags, flags_def):
    """Log the current value of every flag listed in ``flags_def``.

    Args:
        flags: object exposing flag values as attributes (presumably the
            parsed absl FLAGS object; confirm against callers).
        flags_def: iterable of flag names, reported in iteration order.
    """
    entries = [f'{name}: {getattr(flags, name)}' for name in flags_def]
    message = f'Running training with hyperparameters: \n{pprint.pformat(entries)}'
    logging.info(message)
