import contextlib
import sys
import time

import torch

if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU.

        Useful as a temporary context manager to find sweet spots of code
        suitable for async implementation.

        Args:
            trace_name: First label printed at the start of the report line.
            name: Second label printed right after ``trace_name``.
            enabled: When falsy, the wrapped code runs without any timing
                or output. Timing is also skipped when CUDA is unavailable.
            stream: CUDA stream on which the start event is recorded.
                Falls back to ``torch.cuda.current_stream()`` when falsy.
            end_stream: CUDA stream on which the end event is recorded.
                Falls back to ``stream`` when falsy.

        Yields:
            None. The report is printed when the ``with`` body exits,
            whether it finished normally or raised.
        """
        if (not enabled) or not torch.cuda.is_available():
            # Nothing to measure: act as a transparent pass-through.
            yield
            return

        stream = stream or torch.cuda.current_stream()
        end_stream = end_stream or stream
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        cpu_start = time.monotonic()
        try:
            yield
        finally:
            # Take the CPU timestamp first so it excludes the blocking
            # event synchronization below.
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            end.synchronize()
            # The two events may live on different streams, so a completed
            # ``end`` does not imply a completed ``start``. ``elapsed_time``
            # requires both events to have completed; wait for ``start``
            # as well.
            start.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000  # seconds -> ms
            gpu_time = start.elapsed_time(end)
            msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
            msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
            print(msg, end_stream)
