import json
import os
import pathlib
import time

import numpy as np


class Logger:
    """Buffers metrics tagged with the current step and hands them to outputs.

    Metrics are accumulated by `add()` (or the `scalar`/`image`/`video`
    shortcuts) and delivered to every output callable when `write()` is
    called.
    """

    def __init__(self, step, outputs, multiplier=1):
        """Create a logger.

        Args:
            step: Object that is converted with `int()` every time a metric
                is added or the FPS is computed.
            outputs: Iterable of callables; each one is called with the list
                of buffered `(step, name, value)` tuples on `write()`.
            multiplier: Factor applied to `int(step)` to obtain the logged
                step.
        """
        self._step = step
        self._outputs = outputs
        self._multiplier = multiplier
        self._last_step = None
        self._last_time = None
        self._metrics = []

    def add(self, mapping, prefix=None):
        """Buffer every entry of `mapping` under the current step.

        Args:
            mapping: Anything accepted by `dict()`, mapping names to values
                that `np.array()` can convert.
            prefix: Optional string prepended to each name as
                `'{prefix}_{name}'`.

        Raises:
            ValueError: If a converted value has a rank other than 0, 2, 3,
                or 4.
        """
        step = int(self._step) * self._multiplier
        for name, value in dict(mapping).items():
            name = f'{prefix}_{name}' if prefix else name
            value = np.array(value)
            if len(value.shape) not in (0, 2, 3, 4):
                raise ValueError(
                    f"Shape {value.shape} for name '{name}' cannot be "
                    "interpreted as scalar, image, or video.")
            self._metrics.append((step, name, value))

    def scalar(self, name, value):
        """Buffer a single value under `name`; shorthand for `add()`."""
        self.add({name: value})

    def image(self, name, value):
        """Buffer a single value under `name`; shorthand for `add()`."""
        self.add({name: value})

    def video(self, name, value):
        """Buffer a single value under `name`; shorthand for `add()`."""
        self.add({name: value})

    def write(self, fps=False):
        """Send all buffered metrics to every output and reset the buffer.

        Args:
            fps: If truthy, an `'fps'` scalar is added to the batch before it
                is written.
        """
        if fps:
            self.scalar('fps', self._compute_fps())
        if not self._metrics:
            return
        batch = self._metrics
        for output in self._outputs:
            output(batch)
        # Rebind instead of clearing in place: an output that keeps a
        # reference to the batch must not see it emptied afterwards.
        self._metrics = []

    def _compute_fps(self):
        """Return steps per second since the previous call (0 on the first)."""
        step = int(self._step) * self._multiplier
        now = time.time()
        if self._last_step is None:
            self._last_time = now
            self._last_step = step
            return 0
        steps = step - self._last_step
        duration = now - self._last_time
        self._last_time = now
        self._last_step = step
        if duration <= 0:
            # The wall clock did not advance (or went backwards); there is no
            # meaningful rate to report and dividing would fail for zero.
            return 0
        return steps / duration


class TerminalOutput:
    """Output that prints the scalar metrics of a batch on a single line."""

    def __call__(self, summaries):
        """Print `[step] name value / name value ...` for rank-0 entries.

        Args:
            summaries: Collection of `(step, name, value)` tuples whose values
                expose a `shape`. The largest step labels the line.
        """
        latest_step = max(step for step, _, _ in summaries)
        rendered = {}
        for _, name, value in summaries:
            if len(value.shape) != 0:
                continue
            rendered[name] = self._format_value(float(value))
        line = ' / '.join(f'{name} {text}' for name, text in rendered.items())
        print(f'[{latest_step}]', line)

    def _format_value(self, value):
        """Return a compact string for `value`.

        Zero becomes `'0'`; magnitudes strictly between 0.01 and 10000 use
        two decimals with trailing zeros removed; everything else uses a
        shortened exponent notation.
        """
        if value == 0:
            return '0'
        if 0.01 < abs(value) < 10000:
            # Drop trailing zeros first, then a dangling decimal point.
            return f'{value:.2f}'.rstrip('0').rstrip('.')
        text = f'{value:.1e}'
        # Applied in order: shorten the mantissa, then the exponent sign and
        # its leading zero.
        for old, new in (('.0e', 'e'), ('+0', ''), ('+', ''), ('-0', '-')):
            text = text.replace(old, new)
        return text


class JSONLOutput:
    """Output that appends scalar metrics to `metrics.jsonl` in a directory."""

    def __init__(self, logdir):
        """Remember the target directory; `~` is expanded.

        Args:
            logdir: Directory (str or path-like) that will hold
                `metrics.jsonl`.
        """
        self._logdir = pathlib.Path(logdir).expanduser()

    def __call__(self, summaries):
        """Append one JSON line with the rank-0 entries of `summaries`.

        The line holds a `'step'` key set to the largest step in the batch,
        followed by one key per scalar name.

        Args:
            summaries: Collection of `(step, name, value)` tuples whose values
                expose a `shape`.
        """
        scalars = {k: float(v) for _, k, v in summaries if len(v.shape) == 0}
        step = max(s for s, _, _, in summaries)
        # Create the directory on demand so the first write does not fail
        # when the log directory has not been created yet.
        self._logdir.mkdir(parents=True, exist_ok=True)
        with (self._logdir / 'metrics.jsonl').open('a') as f:
            f.write(json.dumps({'step': step, **scalars}) + '\n')


class TensorBoardOutput:
    """Output that writes scalars, images, and videos as TensorBoard summaries.

    TensorFlow is imported lazily inside the methods that need it.
    """

    def __init__(self, logdir, fps=20):
        """Remember the log directory and video frame rate.

        Args:
            logdir: Directory passed to the TensorFlow summary writer.
            fps: Frame rate handed to `encode_gif` for video summaries.
        """
        # The TensorFlow summary writer supports file protocols like gs://. We use
        # os.path over pathlib here to preserve those prefixes.
        self._logdir = os.path.expanduser(logdir)
        self._writer = None
        self._fps = fps

    def __call__(self, summaries):
        """Write every `(step, name, value)` entry and flush the writer.

        Entries are dispatched on the rank of `value`: 0 is a scalar (tagged
        `'scalars/' + name`), 2 and 3 are images, 4 is a video. Other ranks
        are silently skipped.
        """
        import tensorflow as tf
        self._ensure_writer()
        self._writer.set_as_default()
        for step, name, value in summaries:
            rank = len(value.shape)
            if rank == 0:
                tf.summary.scalar('scalars/' + name, value, step)
            elif rank == 2:
                # tf.summary.image is documented to take [k, h, w, c]; add the
                # batch and channel axes.
                # NOTE(review): assumes rank-2 values are (H, W) — confirm.
                tf.summary.image(name, value[None, :, :, None], step)
            elif rank == 3:
                # Add the missing batch axis.
                # NOTE(review): assumes rank-3 values are (H, W, C) — confirm.
                tf.summary.image(name, value[None], step)
            elif rank == 4:
                self._video_summary(name, value, step)
        self._writer.flush()

    def _ensure_writer(self):
        """Create the summary file writer on first use."""
        if self._writer is None:
            import tensorflow as tf
            self._writer = tf.summary.create_file_writer(
                self._logdir, max_queue=1000)

    def _video_summary(self, name, video, step):
        """Write a rank-4 `video` as a GIF summary, or as images on failure.

        Floating-point videos are scaled by 255 and clipped to uint8 first.
        If encoding or writing raises `IOError`/`OSError`, a message is
        printed and the frames are written with `tf.summary.image` instead.
        """
        import tensorflow as tf
        import tensorflow.compat.v1 as tf1
        name = name if isinstance(name, str) else name.decode('utf-8')
        if np.issubdtype(video.dtype, np.floating):
            video = np.clip(255 * video, 0, 255).astype(np.uint8)
        try:
            T, H, W, C = video.shape
            summary = tf1.Summary()
            image = tf1.Summary.Image(height=H, width=W, colorspace=C)
            image.encoded_image_string = encode_gif(video, self._fps)
            summary.value.add(tag=name, image=image)
            tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
        except (IOError, OSError) as e:
            print('GIF summaries require ffmpeg in $PATH.', e)
            tf.summary.image(name, video, step)


def encode_gif(frames, fps):
    """Encode a sequence of frames as an animated GIF using ffmpeg.

    Args:
        frames: Non-empty sequence of arrays; the first frame's shape is
            unpacked as `(h, w, c)` and `c` must be 1 or 3. Each frame is
            sent to ffmpeg as raw bytes via `tobytes()`.
        fps: Frame rate used for both the input and the output stream.

    Returns:
        The bytes ffmpeg wrote to stdout.

    Raises:
        KeyError: If the channel count is neither 1 nor 3.
        IOError: If ffmpeg exits with a non-zero return code; the message
            holds the command line and ffmpeg's stderr.
        OSError: If the ffmpeg process cannot be started.
    """
    from subprocess import Popen, PIPE
    h, w, c = frames[0].shape
    pxfmt = {1: 'gray', 3: 'rgb24'}[c]
    # Built as an argument list so no string splitting is needed and the
    # command can be joined verbatim for the error message.
    cmd = [
        'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
        '-r', f'{fps:.02f}', '-s', f'{w}x{h}', '-pix_fmt', pxfmt,
        '-i', '-', '-filter_complex',
        '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
        '-r', f'{fps:.02f}', '-f', 'gif', '-']
    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE)
    payload = b''.join(image.tobytes() for image in frames)
    # Let communicate() feed stdin while draining stdout/stderr; writing to
    # stdin manually can block once one of the other pipes fills up.
    out, err = proc.communicate(input=payload)
    if proc.returncode:
        raise IOError('\n'.join(
            [' '.join(cmd), err.decode('utf8', errors='replace')]))
    return out
