from typing import List, Callable, Dict
import collections
import re
import numpy as np
from elements import printing, timer

class Logger:
    """Buffer metrics and fan them out to output handles on flush.

    Each buffered entry is a ``(step, key, value)`` tuple where ``value`` is a
    numeric ``np.ndarray`` of rank 0 to 4. ``flush`` passes the whole buffer,
    as one tuple, to every handle in order and then clears it.
    """

    def __init__(self, output_handles: List[Callable]):
        # Callables that each receive a tuple of (step, key, value) entries.
        self.output_handles = output_handles
        self._metrics = []

    @timer.section('logger_add')
    def add(self, step: int, metrics: Dict, prefix: str | None = None):
        """Validate and buffer ``metrics`` for ``step``.

        Args:
            step: Step recorded with every metric of this call.
            metrics: Mapping from key to an array-like value.
            prefix: If truthy, keys are stored as ``f"{prefix}/{key}"``.

        Raises:
            ValueError: If any value is non-numeric or has rank above 4. All
                entries are validated before any is buffered, so a failing
                call leaves the buffer unchanged.
        """
        entries = []
        for key, value in metrics.items():
            key = f"{prefix}/{key}" if prefix else key
            value = np.asarray(value)
            self._key_value_check(key, value)
            entries.append((step, key, value))
        # Commit only after the whole batch passed validation.
        self._metrics.extend(entries)

    def _key_value_check(self, key: str, value: np.ndarray):
        """Raise ValueError unless ``value`` is numeric with rank 0 to 4."""
        if not np.issubdtype(value.dtype, np.number):
            raise ValueError(
                f"Type {value.dtype} for key '{key}' is not a number."
            )
        if len(value.shape) not in (0, 1, 2, 3, 4):
            raise ValueError(
                f"Shape {value.shape} for name '{key}' cannot be "
                "interpreted as scalar, vector, image, or video."
            )

    @timer.section('logger_flush')
    def flush(self):
        """Send the buffered entries to every handle, then clear the buffer.

        Does nothing when the buffer is empty. If a handle raises, the buffer
        is kept, so the next flush resends it.
        """
        if not self._metrics:
            return
        batch = tuple(self._metrics)
        for handle in self.output_handles:
            handle(batch)
        self._metrics.clear()

    def close(self):
        """Flush any remaining buffered metrics."""
        self.flush()

class TerminalOutput:
  """Print the scalar metrics of each flushed batch to the terminal.

  Args:
    pattern: Regex used to select keys (matched with ``re.search``). The
      match-all default ``r'.*'`` disables filtering. In that case at most
      ``limit`` scalars are printed and the rest are reported as truncated.
    name: Optional label shown in the header before the step.
    limit: Maximum number of scalars printed when no pattern is set.
  """

  def __init__(self, pattern=r'.*', name=None, limit=50):
    # None (rather than False) marks "no filter". The match-all default is
    # treated as no filter so that the truncation limit applies instead.
    self._pattern = re.compile(pattern) if pattern != r'.*' else None
    self._name = name
    self._limit = limit

  @timer.section('terminal')
  def __call__(self, summaries):
    """Print one block for a sequence of ``(step, key, value)`` summaries.

    Only 0-d ``np.ndarray`` values are shown; if a key repeats, its last
    value wins. The header shows the largest step in the batch. An empty
    batch prints nothing.
    """
    if not summaries:
      # max() below would raise on an empty batch; there is nothing to show.
      return
    step = max(s for s, _, _ in summaries)
    scalars = {
        k: float(v) for _, k, v in summaries
        if isinstance(v, np.ndarray) and v.ndim == 0}
    truncated = 0
    if self._pattern:
      scalars = {k: v for k, v in scalars.items() if self._pattern.search(k)}
    elif len(scalars) > self._limit:
      truncated = len(scalars) - self._limit
      # Dicts keep insertion order, so this keeps the first `limit` keys.
      scalars = dict(list(scalars.items())[:self._limit])
    formatted = {k: self._format_value(v) for k, v in scalars.items()}
    label = f'{self._name} Step' if self._name else 'Step'
    header = f'{"-" * 20}[{label} {step:_}]{"-" * 20}'
    if self._pattern:
      notice = f"Metrics filtered by: '{self._pattern.pattern}'"
    elif truncated:
      notice = f'{truncated} metrics truncated, filter to see specific keys.'
    else:
      notice = ''
    if formatted:
      body = ' / '.join(f'{k} {v}' for k, v in formatted.items())
    else:
      body = 'No metrics.'
    printing.print_(f'\n{header}\n{notice}\n{body}\n', flush=True)

  def _format_value(self, value):
    """Format a scalar compactly for the terminal.

    Returns '0' for zero. Values with 0.01 < |x| < 10000 use fixed point with
    up to two decimals and no trailing zeros. Everything else uses short
    scientific notation, e.g. 12345 -> '1.2e4' and 0.001 -> '1e-3'.
    """
    value = float(value)
    if value == 0:
      return '0'
    if 0.01 < abs(value) < 10000:
      return f'{value:.2f}'.rstrip('0').rstrip('.')
    text = f'{value:.1e}'
    # Drop a redundant '.0' mantissa, the '+' sign, and exponent zero padding.
    text = text.replace('.0e', 'e')
    text = text.replace('+0', '')
    text = text.replace('+', '')
    return text.replace('-0', '-')

class WandBOutput:
  """Forward summaries to Weights & Biases.

  Each value is dispatched on type and rank:
  - strings are logged as-is;
  - 0-d values become floats;
  - 1-d values become histograms;
  - 2-d/3-d values become images, (H, W) or (H, W, C);
  - 4-d values become videos, (T, H, W, C).
  Keys that ``pattern`` does not match (``re.search``) are skipped.
  """

  def __init__(self, name, pattern=r'.*', **kwargs):
    self._pattern = re.compile(pattern)
    # Imported lazily so wandb is only required when this output is used.
    import wandb
    wandb.init(name=name, **kwargs)
    self._wandb = wandb

  @timer.section('wandb')
  def __call__(self, summaries):
    """Convert ``(step, key, value)`` summaries and log them grouped by step."""
    bystep = collections.defaultdict(dict)
    wandb = self._wandb
    for step, name, value in summaries:
      if not self._pattern.search(name):
        continue
      if isinstance(value, str):
        bystep[step][name] = value
      elif len(value.shape) == 0:
        bystep[step][name] = float(value)
      elif len(value.shape) == 1:
        bystep[step][name] = wandb.Histogram(value)
      elif len(value.shape) in (2, 3):
        # Give grayscale (H, W) an explicit channel axis -> (H, W, C).
        value = value[..., None] if len(value.shape) == 2 else value
        # The channel axis of an (H, W, C) image is index 2.
        assert value.shape[2] in [1, 3, 4], value.shape
        if value.dtype != np.uint8:
          value = (255 * np.clip(value, 0, 1)).astype(np.uint8)
        # wandb.Image reads NumPy arrays channels-last, so no transpose here.
        bystep[step][name] = wandb.Image(value)
      elif len(value.shape) == 4:
        assert value.shape[3] in [1, 3, 4], value.shape
        # wandb.Video expects (T, C, H, W).
        value = np.transpose(value, [0, 3, 1, 2])
        if value.dtype != np.uint8:
          value = (255 * np.clip(value, 0, 1)).astype(np.uint8)
        bystep[step][name] = wandb.Video(value)

    # wandb ignores data logged to a step below the current one, so log in
    # ascending step order regardless of the order the summaries arrived in.
    for step in sorted(bystep):
      wandb.log(bystep[step], step=step)
