




import time
from collections import defaultdict



# Accumulated elapsed seconds per timer name; names never seen read as 0.
_total_times = defaultdict(int)
# perf_counter() reading at each timer's last (re)start; -1 means "not started".
_start_times = defaultdict(lambda: -1)
# Names excluded from print_stats() / return_stats() / total_time() (still timed).
_disabled_names = set()
# Stack-mode timers paused by a nested start(); most recently paused on top.
_timer_stack = []
# Name of the innermost active stack-mode timer, or None when none is running.
_running_timer = None
# Module-wide kill switch: when True, start() and stop() return immediately.
_disable_all = False


def _set_disable_all(flag):
    """Set the module-wide kill switch consulted by start() and stop()."""
    global _disable_all
    _disable_all = flag


def disable_all():
    """Make every start()/stop() call a no-op until enable_all() is called."""
    _set_disable_all(True)


def enable_all():
    """Re-activate timing after a previous disable_all()."""
    _set_disable_all(False)


def disable(fn_name):
    """Hide *fn_name* from reporting.

    The timer itself keeps running and accumulating. It is only skipped by
    print_stats(), return_stats() and total_time().
    """
    hidden = _disabled_names
    hidden.add(fn_name)


def enable(fn_name):
    """Show *fn_name* in reporting again after a previous disable().

    Enabling a name that is not currently disabled is a no-op.
    """
    # discard() rather than remove(): re-enabling a name that was never
    # disabled should not raise KeyError.
    _disabled_names.discard(fn_name)


def reset():
    """Discard all accumulated totals and any in-flight timer state.

    The set of disabled names and the disable_all() flag are left untouched.
    """
    global _running_timer
    for table in (_total_times, _start_times, _timer_stack):
        table.clear()
    _running_timer = None


def start(fn_name, use_stack=True):
    """Start timing *fn_name*.

    With ``use_stack=True`` (the default), timers nest:
    - The currently running stack timer, if any, is stopped, which banks
      its elapsed time so far.
    - It is pushed onto the timer stack, and *fn_name* becomes the running
      timer.
    - The matching stop() resumes the paused timer.

    Each interval is therefore charged only to the innermost timer.

    With ``use_stack=False`` only *fn_name*'s start time is recorded. The
    stack is not touched.

    Does nothing while disable_all() is in effect.
    """
    global _running_timer

    if _disable_all:
        return

    if not use_stack:
        _start_times[fn_name] = time.perf_counter()
        return

    outer = _running_timer
    if outer is not None:
        # Pause the enclosing timer and remember it for the matching stop().
        stop(outer, use_stack=False)
        _timer_stack.append(outer)
    _start_times[fn_name] = time.perf_counter()
    _running_timer = fn_name


def stop(fn_name=None, use_stack=True):
    """Stop a timer and add its elapsed time to its running total.

    With ``use_stack=True`` (the default), *fn_name* is ignored:
    - The innermost running stack timer is stopped.
    - The timer it paused, if any, is popped off the stack and restarted.
    - A warning is printed if no stack timer is running.

    With ``use_stack=False``:
    - *fn_name*'s interval since its last start() is added to its total.
    - Its start time is cleared, so a repeated stop() cannot count the same
      interval twice.
    - A warning is printed if the timer was not started.

    Does nothing while disable_all() is in effect.
    """
    global _running_timer

    if _disable_all:
        return

    if use_stack:
        if _running_timer is None:
            print('Warning: timer stopped with no timer running!')
            return
        stop(_running_timer, use_stack=False)
        if _timer_stack:
            # Resume the timer that the just-stopped one had paused.
            _running_timer = _timer_stack.pop()
            start(_running_timer, use_stack=False)
        else:
            _running_timer = None
        return

    started_at = _start_times[fn_name]
    if started_at > -1:
        _total_times[fn_name] += time.perf_counter() - started_at
        # Disarm the timer. Otherwise a second stop() would re-add the time
        # since the original start.
        _start_times[fn_name] = -1
    else:
        # Wrap in a 1-tuple so tuple-valued names format correctly.
        print('Warning: timer for %s stopped before starting!' % (fn_name,))


def print_stats():
    """Print a table of per-timer times in milliseconds; return the total in ms.

    Names hidden via disable() are left out of both the rows and the total.
    """
    print()

    names = [name for name in _total_times if name not in _disabled_names]

    # The name column is at least as wide as the 'Name' header and is
    # rounded up to an even number of characters.
    width = max([4] + [len(name) for name in names])
    width += width % 2
    row_fmt = ' {:>%d} | {:>10.4f} ' % width

    header = (' {:^%d} | {:^10} ' % width).format('Name', 'Time (ms)')
    print(header)

    # The separator line puts a '+' under the column divider.
    bar_at = header.find('|')
    rule = '-' * bar_at + '+' + '-' * (len(header) - bar_at - 1)
    print(rule)

    for name in names:
        print(row_fmt.format(name, _total_times[name] * 1000))

    print(rule)
    print(row_fmt.format('Total', total_time() * 1000))
    print()
    return total_time() * 1000

def return_stats():
    """Return the combined time of all non-disabled timers, in milliseconds.

    Nothing is printed. Only the total is returned; per-timer figures are
    not included.
    """
    # The table-formatting locals formerly computed here were never used.
    return total_time() * 1000


def total_time():
    """Return the summed elapsed seconds of every timer not hidden via disable()."""
    hidden = _disabled_names
    return sum(
        seconds for label, seconds in _total_times.items() if label not in hidden
    )


class env:
    """Context manager that times the body of a ``with`` block under *fn_name*.

    start() is called on entry and stop() on exit, both with the configured
    ``use_stack``. Exceptions raised inside the block are not suppressed.

    Example::

        with env('forward'):
            ...
    """

    def __init__(self, fn_name, use_stack=True):
        self.use_stack = use_stack
        self.fn_name = fn_name

    def _apply(self, action):
        # Invoke start/stop with this manager's name and stacking mode.
        action(self.fn_name, use_stack=self.use_stack)

    def __enter__(self):
        self._apply(start)

    def __exit__(self, e, ev, t):
        self._apply(stop)
