import time
from collections import deque
from statistics import mean

import tensorflow as tf


class Timer:
    """Manual wall-clock timer for repeated runs.

    Call :meth:`before_run` immediately before and :meth:`after_run`
    immediately after each run. Every measured duration is appended to
    ``time_in_run`` and printed; :meth:`end` prints (and returns) the mean
    duration of the most recent runs.
    """

    def __init__(self) -> None:
        self.before_run_time = None  # perf_counter() timestamp of the pending run, or None
        self.time_in_run = []  # duration in seconds of every completed run, oldest first
        self.iteration_count = 0  # number of completed runs

    def before_run(self):
        """Record the start time of a run."""
        # perf_counter is the highest-resolution monotonic clock available,
        # which matters for short runs.
        self.before_run_time = time.perf_counter()

    def after_run(self):
        """Record and print the duration of the run started by :meth:`before_run`.

        Raises:
            RuntimeError: if :meth:`before_run` was not called first.
        """
        if self.before_run_time is None:
            raise RuntimeError("after_run() called without a matching before_run()")
        time_elapsed = time.perf_counter() - self.before_run_time
        self.time_in_run.append(time_elapsed)
        self.before_run_time = None
        self.iteration_count += 1

        print("**** Elapsed time in one run: {time_elapsed:.4f}s *******".format(time_elapsed=time_elapsed))

    def end(self, window=100):
        """Print and return the mean duration of the last ``window`` runs.

        If fewer than ``window`` runs were recorded, all of them are averaged.

        Args:
            window: maximum number of most recent runs to average (>= 1).

        Returns:
            The mean duration in seconds, or None if no run was recorded.

        Raises:
            ValueError: if ``window`` is smaller than 1.
        """
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")
        # A negative-start slice naturally clamps to the available runs and,
        # unlike the old ``[len - 100:-1]``, keeps the most recent one.
        recent = self.time_in_run[-window:]
        if not recent:
            print("**** No runs recorded ****")
            return None
        average = mean(recent)
        print("**** Elapsed time average over last {n} steps: {average:.4f}s ****".format(n=len(recent), average=average))
        return average


class TimerHook(tf.estimator.SessionRunHook):
    """SessionRunHook that reports per-batch wall-clock time and throughput.

    The first ``warmup_steps`` session runs are not timed (presumably to
    exclude start-up overhead — confirm intent). After that, the duration of
    each run is kept in a sliding window of the last ``window`` values. Every
    ``report_every`` steps, and again when the hook ends, the mean batch time
    and the resulting frames-per-second (``batch_size / mean``) are printed.
    """

    def __init__(self, batch_size=1, warmup_steps=100, report_every=500, window=500) -> None:
        """
        Args:
            batch_size: items per session run; used only for the FPS figure.
            warmup_steps: number of initial steps excluded from timing.
            report_every: print a report every this many steps (0 disables
                periodic reports; the final report in :meth:`end` still runs).
            window: maximum number of recent step durations averaged.
        """
        self.before_run_time = None  # perf_counter() timestamp of the pending step, or None
        self.time_in_run = None  # deque of recent step durations; created in begin()
        self.steps = None  # number of completed steps; reset in begin()
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.report_every = report_every
        self.window = window

    def begin(self):
        """Reset the timing state at the start of a session."""
        self.time_in_run = deque(maxlen=self.window)
        self.steps = 0

    def before_run(self, run_context):
        """Record the start time of a session run."""
        self.before_run_time = time.perf_counter()

    def after_run(self, run_context, run_values):
        """Record the duration of the finished run and report periodically."""
        time_elapsed = time.perf_counter() - self.before_run_time
        self.before_run_time = None

        # ``steps`` counts completed steps *before* this one, so ``>=`` skips
        # exactly ``warmup_steps`` steps (the old ``> 100`` skipped 101).
        if self.steps >= self.warmup_steps:
            self.time_in_run.append(time_elapsed)
        self.steps += 1
        if self.report_every and self.steps % self.report_every == 0:
            self._report()

    def end(self, session):
        """Print the final timing report."""
        self._report()

    def _report(self):
        """Print mean batch time and FPS over the current window, if any."""
        if not self.time_in_run:
            # Covers begin() never having run (None) and runs that ended
            # inside the warm-up period (empty deque); mean() would raise.
            print(f"Elapsed time: no batches timed (first {self.warmup_steps} steps are ignored)\n")
            return
        average = mean(self.time_in_run)
        # Guard against a zero mean from a coarse clock.
        fps = self.batch_size / average if average > 0 else float("inf")
        s = f"Elapsed time over last {len(self.time_in_run)} batches:\n"
        s += "\ttime: {average:.4f}s\n".format(average=average)
        s += "\tFPS: {fps:.4f}\n".format(fps=fps)
        print(s)
