import os
import gzip
import time
import threading

import jax

from .proto import profile_pb2


class MemTracking():

    def __init__(self, file_name: str = "./mem.tracking", interval: float = 0.2) -> None:
        self.file_name = file_name
        self.interval = interval

        try:
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            self.rank = comm.Get_rank()
        except:
            self.rank = 0

    def start(self):

        if self.rank == 0 and os.path.exists(self.file_name):
            os.remove(self.file_name)

        def inner(event):

            while True:
                if event.is_set():
                    break
                memory_allocated = profile_jax_memory_allocated()
                with open(self.file_name, "a") as f:
                    f.write(f"{time.time()}: {memory_allocated}\n")
                time.sleep(self.interval)

        self.stop_event = threading.Event()
        self.thread = threading.Thread(target=inner, args=(self.stop_event, ), daemon=True)
        self.thread.start()

    def summary(self):
        if self.stop_event:
            self.stop_event.set()
        if self.thread:
            self.thread.join()

        with open(self.file_name, "r") as f:
            mem_record = [float(line.split(":")[-1].strip()) for line in f.readlines()]

        return max(mem_record)


def profile_jax_memory_allocated():

    mem_prof = gzip.decompress(jax.profiler.device_memory_profile())
    profile_proto = profile_pb2.Profile()
    profile_proto.ParseFromString(mem_prof)

    memory_allocated = 0

    if 'gpu:0' in profile_proto.string_table:
        buffer_idx = list(profile_proto.string_table).index('buffer')
        gpu_idx = list(profile_proto.string_table).index('gpu:0')

        for sample in profile_proto.sample:
            flag = 0
            for label in sample.label:
                if label.str == buffer_idx:
                    flag += 1
                if label.str == gpu_idx:
                    flag += 1

            if flag == 2:
                memory_allocated += sample.value[1]

    return memory_allocated