import json
import os
import pathlib
import time

import numpy as np


class Logger:
  """Buffers named metrics and forwards them to a set of output callables.

  Values are stored as numpy arrays. The accepted ranks are 0 (scalar),
  2-3 (image) and 4-5 (video). `write` hands the buffered
  `(step, name, value)` tuples to every output and then empties the buffer.

  Args:
    step: Step counter; anything accepted by `int()`. It is read on every
      `add` call and on every fps measurement.
    outputs: Collection of callables, iterated on every `write`; each is
      called with a list of `(step, name, value)` tuples.
    multiplier: Factor applied to `int(step)` before it is recorded.
  """

  def __init__(self, step, outputs, multiplier=1):
    self._step = step
    self._outputs = outputs
    self._multiplier = multiplier
    self._last_step = None
    self._last_time = None
    self._metrics = []

  def add(self, mapping, prefix=None):
    """Buffer every `name -> value` pair of `mapping` at the current step.

    Args:
      mapping: Mapping, or iterable of `(name, value)` pairs.
      prefix: Optional string; names become `f'{prefix}_{name}'`.

    Raises:
      ValueError: If a value's rank is not 0, 2, 3, 4 or 5.
    """
    step = self._current_step()
    for name, value in dict(mapping).items():
      name = f'{prefix}_{name}' if prefix else name
      value = np.array(value)
      if value.ndim not in (0, 2, 3, 4, 5):
        raise ValueError(
            f"Shape {value.shape} for name '{name}' cannot be "
            "interpreted as scalar, image, or video.")
      self._metrics.append((step, name, value))

  def scalar(self, name, value):
    """Buffer a single scalar metric."""
    self.add({name: value})

  def image(self, name, value):
    """Buffer a single image metric."""
    self.add({name: value})

  def video(self, name, value):
    """Buffer a single video metric."""
    self.add({name: value})

  def write(self, fps=False):
    """Send buffered metrics to every output, then clear the buffer.

    Args:
      fps: If true, first buffer an 'fps' scalar: steps per second since the
        previous fps measurement (0 for the first measurement).
    """
    if fps:
      self.scalar('fps', self._compute_fps())
    if not self._metrics:
      return
    # Hand outputs a snapshot: the buffer is cleared below, and an output
    # that keeps a reference to its argument must not see it emptied.
    metrics = list(self._metrics)
    for output in self._outputs:
      output(metrics)
    self._metrics.clear()

  def _current_step(self):
    """Return the counter value scaled by the multiplier."""
    return int(self._step) * self._multiplier

  def _compute_fps(self):
    """Return steps per second since the last call (0 on the first call)."""
    step = self._current_step()
    # perf_counter is monotonic, so wall-clock adjustments cannot produce
    # negative or inflated rates.
    now = time.perf_counter()
    if self._last_step is None:
      self._last_time = now
      self._last_step = step
      return 0
    duration = now - self._last_time
    if duration <= 0:
      # The clock has not advanced; keep the baseline instead of dividing
      # by zero.
      return 0
    steps = step - self._last_step
    self._last_time = now
    self._last_step = step
    return steps / duration


class TerminalOutput:
  """Prints scalar metrics on one line: `[step] name value / name value`."""

  def __call__(self, summaries):
    """Print the rank-0 entries of `summaries`; other ranks are ignored.

    Args:
      summaries: Sequence of `(step, name, value)` tuples with numpy values.
        The printed step is the largest step among all entries. An empty
        sequence prints nothing.
    """
    if not summaries:
      return
    step = max(s for s, _, _ in summaries)
    scalars = {k: float(v) for _, k, v in summaries if v.ndim == 0}
    line = ' / '.join(
        f'{k} {self._format_value(v)}' for k, v in scalars.items())
    print(f'[{step}]', line)

  def _format_value(self, value):
    """Return a compact string for a float.

    Zero prints as '0'. Magnitudes strictly between 0.01 and 10000 use two
    decimals with trailing zeros and a dangling '.' removed. Everything else
    uses one-decimal scientific notation with a squeezed exponent
    (e.g. '1.0e+05' -> '1e5', '2.5e-03' -> '2.5e-3').
    """
    if value == 0:
      return '0'
    if 0.01 < abs(value) < 10000:
      return f'{value:.2f}'.rstrip('0').rstrip('.')
    text = f'{value:.1e}'
    text = text.replace('.0e', 'e')
    text = text.replace('+0', '')
    text = text.replace('+', '')
    text = text.replace('-0', '-')
    return text


class JSONLOutput:
  """Appends one JSON object per call to `<logdir>/metrics.jsonl`."""

  def __init__(self, logdir):
    self._logdir = pathlib.Path(logdir).expanduser()

  def __call__(self, summaries):
    """Append `{'step': max_step, name: value, ...}` for the scalar entries.

    Args:
      summaries: Sequence of `(step, name, value)` tuples with numpy values.
        Only rank-0 values are written. A line is appended as long as
        `summaries` is non-empty, even if it holds no scalars. An empty
        sequence writes nothing.
    """
    if not summaries:
      return
    scalars = {k: float(v) for _, k, v in summaries if v.ndim == 0}
    step = max(s for s, _, _ in summaries)
    # Appending to a file in a missing directory would raise, so create it.
    self._logdir.mkdir(parents=True, exist_ok=True)
    with (self._logdir / 'metrics.jsonl').open('a') as f:
      f.write(json.dumps({'step': step, **scalars}) + '\n')

class GIFOutput:
  """Writes each video summary to `<logdir>/<name>_<step>.gif`."""

  def __init__(self, logdir):
    self._logdir = pathlib.Path(logdir).expanduser()

  def __call__(self, summaries):
    """Write every rank-4 or rank-5 entry of `summaries` as a GIF file.

    Args:
      summaries: Sequence of `(step, name, value)` tuples with numpy values.
        Entries of any other rank are skipped. `name` may be `str` or UTF-8
        `bytes`.
    """
    for step, name, video in summaries:
      if video.ndim not in (4, 5):
        continue
      name = name if isinstance(name, str) else name.decode('utf-8')
      frames = self._to_frames(video)
      path = self._logdir / f"{name}_{step}.gif"
      # A name containing '/' maps to a subdirectory. Create it, and the log
      # directory itself, so the write does not fail.
      path.parent.mkdir(parents=True, exist_ok=True)
      import imageio  # Only needed once a video is actually written.
      imageio.mimwrite(path, frames)

  @staticmethod
  def _to_frames(video):
    """Return `video` as a `(T, H, W, C)` frame array.

    Floating-point input is multiplied by 255, clipped to [0, 255] and cast
    to uint8. A rank-5 `(B, T, H, W, C)` input is tiled so that the batch
    elements sit side by side in each frame, giving `(T, H, B * W, C)`.
    """
    if np.issubdtype(video.dtype, np.floating):
      video = np.clip(255 * video, 0, 255).astype(np.uint8)
    if video.ndim == 5:
      B, T, H, W, C = video.shape
      video = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
    return video



class TensorBoardOutput:
  """Writes summaries to TensorBoard event files under `logdir`.

  Rank-0 values become scalars tagged 'scalars/<name>'. Rank-2 and rank-3
  values become images. Rank-4 `(T, H, W, C)` values become animated GIFs,
  and rank-5 `(B, T, H, W, C)` values become GIFs with the batch tiled
  horizontally. TensorFlow is imported inside the methods, not at module
  level.

  Args:
    logdir: Output directory; may be a remote path such as gs://.
    fps: Frame rate used for rank-4 video GIFs.
  """

  def __init__(self, logdir, fps=20):
    # The TensorFlow summary writer supports file protocols like gs://. We use
    # os.path over pathlib here to preserve those prefixes.
    self._logdir = os.path.expanduser(logdir)
    self._writer = None
    self._fps = fps

  def __call__(self, summaries):
    """Write `(step, name, value)` summaries, dispatching on value rank."""
    import tensorflow as tf
    self._ensure_writer()
    self._writer.set_as_default()
    for step, name, value in summaries:
      rank = len(value.shape)
      if rank == 0:
        tf.summary.scalar('scalars/' + name, value, step)
      elif rank in (2, 3):
        # tf.summary.image only accepts a rank-4 [k, h, w, c] batch.
        tf.summary.image(name, self._image_batch(value), step)
      elif rank == 4:
        self._video_summary(name, value, step)
      elif rank == 5:
        self._lexa_video_summary(name, value, step)
    self._writer.flush()

  @staticmethod
  def _image_batch(image):
    """Reshape an `(H, W)` or `(H, W, C)` image into a `(1, H, W, C)` batch."""
    if image.ndim == 2:
      image = image[..., None]
    return image[None]

  def _ensure_writer(self):
    """Create the summary file writer on first use."""
    if not self._writer:
      import tensorflow as tf
      self._writer = tf.summary.create_file_writer(
          self._logdir, max_queue=1000)

  def _video_summary(self, name, video, step):
    """Write a `(T, H, W, C)` video as an animated GIF image summary.

    If GIF encoding raises IOError/OSError, it falls back to a plain image
    summary of the frames.
    """
    import tensorflow as tf
    import tensorflow.compat.v1 as tf1
    name = name if isinstance(name, str) else name.decode('utf-8')
    if np.issubdtype(video.dtype, np.floating):
      video = np.clip(255 * video, 0, 255).astype(np.uint8)
    try:
      T, H, W, C = video.shape
      summary = tf1.Summary()
      image = tf1.Summary.Image(height=H, width=W, colorspace=C)
      image.encoded_image_string = encode_gif(video, self._fps)
      summary.value.add(tag=name, image=image)
      tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
    except (IOError, OSError) as e:
      print('GIF summaries require ffmpeg in $PATH.', e)
      tf.summary.image(name, video, step)

  def _lexa_video_summary(self, name, video, step, fps=40):
    """Write a `(B, T, H, W, C)` batch of videos as a single GIF summary.

    Each GIF frame places the B videos side by side, giving a
    `(H, B * W, C)` frame. If encoding fails, it falls back to one static
    image with the batch along rows and time along columns.
    """
    import tensorflow as tf
    import tensorflow.compat.v1 as tf1
    name = name if isinstance(name, str) else name.decode('utf-8')
    if np.issubdtype(video.dtype, np.floating):
      video = np.clip(255 * video, 0, 255).astype(np.uint8)
    B, T, H, W, C = video.shape
    try:
      frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
      summary = tf1.Summary()
      # Metadata must describe the encoded frames, which are H x (B * W).
      image = tf1.Summary.Image(height=H, width=B * W, colorspace=C)
      image.encoded_image_string = encode_gif(frames, fps)
      summary.value.add(tag=name, image=image)
      tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
    except (IOError, OSError) as e:
      print('GIF summaries require ffmpeg in $PATH.', e)
      grid = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
      tf.summary.image(name, grid, step)


def encode_gif(frames, fps):
  """Encode video frames as an animated GIF by piping them through ffmpeg.

  Args:
    frames: Non-empty sequence (or array) of equally shaped `(H, W, C)`
      frames with C == 1 (gray) or C == 3 (rgb24). Both pixel formats are
      8 bits per channel, so frames are expected to be uint8.
    fps: Frame rate of the animation.

  Returns:
    The GIF file contents as bytes.

  Raises:
    IOError: If ffmpeg exits with a nonzero status; the message holds the
      command line and ffmpeg's stderr. Popen itself raises an OSError
      subclass when ffmpeg is not found on $PATH.
    KeyError: If C is neither 1 nor 3.
  """
  from subprocess import Popen, PIPE
  h, w, c = frames[0].shape
  pxfmt = {1: 'gray', 3: 'rgb24'}[c]
  args = [
      'ffmpeg', '-y', '-f', 'rawvideo', '-vcodec', 'rawvideo',
      '-r', f'{fps:.02f}', '-s', f'{w}x{h}', '-pix_fmt', pxfmt,
      '-i', '-', '-filter_complex',
      '[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
      '-r', f'{fps:.02f}', '-f', 'gif', '-']
  data = b''.join(frame.tobytes() for frame in frames)
  proc = Popen(args, stdin=PIPE, stdout=PIPE, stderr=PIPE)
  # communicate() writes stdin while draining stdout/stderr, so a full pipe
  # buffer cannot deadlock the exchange with ffmpeg.
  out, err = proc.communicate(data)
  if proc.returncode:
    raise IOError('\n'.join([' '.join(args), err.decode('utf8', 'replace')]))
  return out
