import collections
import contextlib
import time

import numpy as np


class Timer:
  """Accumulates durations of named code sections and reports statistics.

  Durations are collected with the `scope` context manager (or `wrap`, which
  applies `scope` to methods of an object) and summarized by `stats`.

  Args:
    columns: Which per-name statistics `stats` reports and `_log` prints.
      Any subset of 'frac', 'avg', 'min', 'max', 'count', 'total'.
  """

  def __init__(self, columns=('frac', 'min', 'avg', 'max', 'count', 'total')):
    available = ('frac', 'avg', 'min', 'max', 'count', 'total')
    # Materialize first so a one-shot iterable is not consumed by the check.
    columns = tuple(columns)
    assert all(x in available for x in columns), columns
    self._columns = columns
    self._durations = collections.defaultdict(list)
    # perf_counter is monotonic, so measured intervals are not affected by
    # system clock adjustments (unlike time.time).
    self._start = time.perf_counter()

  def reset(self):
    """Discard recorded durations and restart the reference period.

    Names seen so far are kept (with empty lists), so they still appear in
    later reports.
    """
    for timings in self._durations.values():
      timings.clear()
    self._start = time.perf_counter()

  @contextlib.contextmanager
  def scope(self, name):
    """Time the enclosed block and record its duration under `name`.

    The duration is recorded even if the block raises.
    """
    start = time.perf_counter()
    try:
      yield
    finally:
      self._durations[name].append(time.perf_counter() - start)

  def wrap(self, name, obj, methods):
    """Replace each of `obj`'s `methods` with a version timed as 'name.method'.

    The timed callables are set as attributes on `obj` itself.
    """
    for method in methods:
      # Objects produced by a contextlib.contextmanager function can also be
      # used as function decorators (ContextDecorator).
      decorator = self.scope(f'{name}.{method}')
      setattr(obj, method, decorator(getattr(obj, method)))

  def stats(self, reset=True, log=False):
    """Summarize recorded durations.

    Args:
      reset: Whether to call `reset` after computing the summary.
      log: Whether to print the summary as a table.

    Returns:
      A dict with 'duration' (seconds since construction or the last reset)
      and, for each recorded name and each selected column, an entry
      '{name}_{column}'. 'avg', 'min' and 'max' are omitted for names that
      have no recordings in the current period.
    """
    metrics = {}
    duration = time.perf_counter() - self._start
    metrics['duration'] = duration
    for name, durs in self._durations.items():
      total = np.sum(durs)
      available = {
          'count': len(durs),
          'total': total,
          # Guard against a zero-length period on coarse clocks.
          'frac': total / duration if duration > 0 else 0.0,
      }
      if durs:
        available['avg'] = np.mean(durs)
        available['min'] = np.min(durs)
        available['max'] = np.max(durs)
      for key, value in available.items():
        if key in self._columns:
          metrics[f'{name}_{key}'] = value
    if log:
      self._log(metrics)
    if reset:
      self.reset()
    return metrics

  def _log(self, metrics):
    """Print `metrics` as a table, one row per name, largest total first."""
    # Sort by recorded total time rather than metrics['{name}_frac'], which is
    # absent when 'frac' is not among the selected columns.
    names = sorted(self._durations, key=lambda k: -sum(self._durations[k]))
    print('Timer:'.ljust(20), ' '.join(x.rjust(8) for x in self._columns))
    for name in names:
      # avg/min/max are missing for names without recordings; show NaN.
      values = [metrics.get(f'{name}_{col}', np.nan) for col in self._columns]
      print(f'{name.ljust(20)}', ' '.join(f'{x:8.4f}' for x in values))
