import datetime
import threading
import time
from typing import Optional

import torch


class MemTracker:
    """Log CUDA memory usage reported by ``torch.cuda`` from a background thread.

    A worker thread polls ``torch.cuda.memory_allocated()`` and
    ``torch.cuda.max_memory_allocated()`` every ``loop_time`` seconds and
    appends a line to ``output_file`` whenever the formatted values change.

    Typical use::

        tracker = MemTracker(output_file='mem.txt')
        tracker.create_track_thread()
        tracker.record_epoch(epoch, it)
        tracker.end_track()
    """

    def __init__(self, loop_time: float = 0.001, output_file: str = 'cumemory_log.txt'):
        """Store the polling configuration; no thread is started here.

        Args:
            loop_time: pause between two memory checks, in seconds.
            output_file: path of the log file.
        """
        self.output_file = output_file
        self.loop_time = loop_time
        self.thread = None
        # Set by end_track() to ask the worker loop to exit.
        self._stop_event = threading.Event()

    def create_track_thread(self, loop_time: Optional[float] = None,
                            output_file: Optional[str] = None) -> None:
        """Truncate the log file and start the GPU memory tracking thread.

        File format of each memory line:
            time gap | time | current memory usage | max memory usage

        If a tracking thread started by this instance is still running, it
        is stopped before the new one is started.

        Args:
            loop_time: check time spacing (s). ``None`` keeps the value
                given to the constructor; any other value replaces it.
            output_file: output file name. ``None`` keeps the value given to
                the constructor; any other value replaces it, so that
                ``record_epoch`` writes to the same file as the worker.
        """
        if loop_time is not None:
            self.loop_time = loop_time
        if output_file is not None:
            self.output_file = output_file

        # Never leave an older worker running next to the new one.
        self.end_track()

        # Opening in 'w' mode creates the file or empties an existing one.
        with open(self.output_file, 'w'):
            pass

        # A fresh event per worker: a stale "stop" request from a previous
        # run can never leak into the new thread.
        self._stop_event = threading.Event()
        self.thread = threading.Thread(
            target=self._record_memory_usage,
            args=(self.loop_time, self.output_file, self._stop_event),
            # Daemon, so a forgotten end_track() cannot keep the interpreter
            # alive forever at exit.
            daemon=True,
        )
        self.thread.start()

    @staticmethod
    def _record_memory_usage(loop_time: float, output_file: str,
                             stop_event: threading.Event) -> None:
        """Worker loop: append a line each time the memory report changes.

        At least one sample is taken before the stop request is checked.
        The time gap written at the start of a line is measured from the
        previous written line (or from the worker start for the first one).
        """
        last_log = ''
        last_time = datetime.datetime.now()
        while True:
            max_memory = torch.cuda.max_memory_allocated()
            current_memory = torch.cuda.memory_allocated()
            log = (f"Current memory usage: {current_memory/1024**2:.2f} MB; "
                   f"Max memory usage: {max_memory/1024**2:.2f} MB\n")
            if log != last_log:
                now = datetime.datetime.now()
                with open(output_file, 'a') as f:
                    f.write(f"++{now - last_time} "
                            + now.strftime('%Y-%m-%d %H:%M:%S.%f')
                            + ": " + log)
                last_log = log
                last_time = now
            # wait() doubles as the sleep between polls and returns True as
            # soon as a stop is requested, so shutdown is prompt.
            if stop_event.wait(loop_time):
                break

    def record_epoch(self, epoch: int, iter: int, train: bool = True) -> None:
        """Append a separator line marking the current epoch/iteration.

        Args:
            epoch: epoch number written to the marker.
            iter: iteration number written to the marker (the name shadows
                the builtin but is kept for keyword-argument callers).
            train: ``True`` labels the step 'train', ``False`` 'validate'.
        """
        step = 'train' if train else 'validate'
        with open(self.output_file, 'a') as f:
            f.write(f'===========epoch: {epoch} iter: {iter} step: {step}============\n')

    def end_track(self) -> None:
        """Stop the tracking thread and wait for it to finish.

        Does nothing when no tracking thread has been started.
        """
        worker = self.thread
        if worker is None:
            return
        self._stop_event.set()
        worker.join()
        self.thread = None