import sys

import torch


class MemoryProfiler:
    """Context manager that records CUDA allocator history and dumps a snapshot.

    On entry, recording is enabled via
    ``torch.cuda.memory._record_memory_history``. On exit, the recorded
    history is written to ``f"{prefix}.pickle"`` via
    ``torch.cuda.memory._dump_snapshot``, and recording is then turned off
    again.

    Args:
        prefix: Path prefix for the snapshot file; ``.pickle`` is appended.
        disable: If true, entering and exiting the context does nothing.
        exit_on_done: If true, terminate the process with ``sys.exit()``
            after the snapshot has been written. This happens only when the
            ``with`` body completed without raising, so a real error is never
            masked by a silent status-0 exit.
        max_entries: Forwarded to ``_record_memory_history`` as the maximum
            number of history entries to keep.
    """

    def __init__(self, prefix="memory", disable=False, exit_on_done=True,
                 max_entries=100000):
        self.prefix = prefix
        self.disable = disable
        self.exit_on_done = exit_on_done
        self.max_entries = max_entries

    def __enter__(self):
        """Start recording allocator history (unless disabled); return self."""
        if not self.disable:
            torch.cuda.memory._record_memory_history(
                max_entries=self.max_entries
            )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Dump the snapshot, stop recording, and optionally exit.

        Returns False so that any exception raised in the ``with`` body
        propagates unchanged.
        """
        if self.disable:
            return False
        try:
            torch.cuda.memory._dump_snapshot(f"{self.prefix}.pickle")
        finally:
            # Stop recording even if writing the snapshot failed, so history
            # collection does not keep running for the rest of the process.
            torch.cuda.memory._record_memory_history(enabled=None)
        # Exiting here while an exception is in flight would replace it with
        # SystemExit(None): exit status 0 and no traceback. Let it propagate.
        if self.exit_on_done and exc_type is None:
            sys.exit()
        return False
