from __future__ import annotations
from collections import defaultdict
import pickle
import sys
import torch
import json
from typing import Any
from tabulate import tabulate
import os
import time
import psutil
import pandas as pd


def get_mem_info(pid: int) -> dict[str, int]:
    """Sum the per-mapping memory counters of process ``pid``.

    Iterates over ``psutil.Process(pid).memory_maps()`` and accumulates:

    * ``rss``: sum of each mapping's ``rss``.
    * ``pss``: sum of each mapping's ``pss``.
    * ``uss``: sum of ``private_clean + private_dirty``.
    * ``shared``: sum of ``shared_clean + shared_dirty``.
    * ``shared_file``: the ``shared`` contribution of mappings whose path
      starts with ``"/"``.

    Args:
        pid: Id of the process to inspect.

    Returns:
        A ``defaultdict(int)`` that always contains the five keys above, in
        that order, even when the process has no matching mappings.

    Raises:
        Whatever ``psutil.Process`` / ``memory_maps`` raise for a missing or
        inaccessible process (not handled here).
    """
    # Seed every key so all processes report the same columns; otherwise
    # "shared_file" (or every key, for a process without mappings) would be
    # absent and tables/CSV rows built from these dicts would not line up.
    res = defaultdict(
        int, dict.fromkeys(("rss", "pss", "uss", "shared", "shared_file"), 0)
    )
    for region in psutil.Process(pid).memory_maps():
        shared = region.shared_clean + region.shared_dirty
        res["rss"] += region.rss
        res["pss"] += region.pss
        res["uss"] += region.private_clean + region.private_dirty
        res["shared"] += shared
        if region.path.startswith("/"):
            res["shared_file"] += shared
    return res


class MemoryMonitor:
    """Collect, print and log memory statistics for a set of processes.

    Each monitored pid is paired with a display label of the form
    ``"<role>: <pid>"``. Statistics are obtained from ``get_mem_info`` and
    stored in ``self.data`` keyed by that label. ``table()`` additionally
    appends every sample to a CSV file.
    """

    def __init__(self, csv_file_name="mem_out.csv", pids: list[int] | None = None):
        """Create a monitor.

        Args:
            csv_file_name: Path of the CSV file that ``save_csv`` appends to.
                An existing regular file at this path is deleted.
            pids: Processes to monitor. Defaults to the current process.
                The list is copied, so the caller's list is never modified.
                The current process is labelled ``Main_process``, any other
                pid ``Worker_process``.
        """
        current_pid = os.getpid()
        if pids is None:
            pids = [current_pid]
        # Copy so that add_pid() does not mutate the caller's list.
        self.pids = list(pids)
        self.names = [
            self._label(pid, "Main_process" if pid == current_pid else "Worker_process")
            for pid in self.pids
        ]
        # Populated by _refresh(); defined here so save_csv() is safe to call early.
        self.data = {}
        self.csv_file_path = csv_file_name
        # Start every run with a fresh CSV so the header is written exactly once.
        if os.path.isfile(self.csv_file_path):
            os.remove(self.csv_file_path)

    @staticmethod
    def _label(pid: int, role: str) -> str:
        """Build the display label used as key in ``self.data``."""
        # The pid is part of the label so labels stay unique per process.
        return f"{role}: {pid}"

    def add_pid(self, pid: int, name: str = None):
        """Start monitoring another process.

        Args:
            pid: Process id; must not already be monitored.
            name: Optional role used in the label. Defaults to
                ``"Worker_process"``.
        """
        assert pid not in self.pids, f"pid {pid} is already monitored"
        self.pids.append(pid)
        self.names.append(self._label(pid, "Worker_process" if name is None else name))

    def _refresh(self):
        """Re-sample all processes and return ``{label: stats}``."""
        self.data = {
            name: get_mem_info(pid) for name, pid in zip(self.names, self.pids)
        }
        return self.data

    def _columns(self) -> list[str]:
        """Return the statistic names present in ``self.data``.

        Uses the ordered union over all processes so that a statistic missing
        from the first process is still reported for the others.
        """
        return list(dict.fromkeys(k for stats in self.data.values() for k in stats))

    def table(self) -> str:
        """Sample all processes, append the sample to the CSV, return a table.

        Returns:
            A ``tabulate``-formatted string with one row per process; the
            ``PID`` column holds the process label.
        """
        self._refresh()
        keys = self._columns()
        # Short timestamp: perf_counter seconds wrapped at 1e5.
        now = str(int(time.perf_counter() % 1e5))
        self.save_csv()
        rows = [
            (now, str(name)) + tuple(self.format(stats.get(k, 0)) for k in keys)
            for name, stats in self.data.items()
        ]
        return tabulate(rows, headers=["time", "PID"] + keys)

    def save_csv(self):
        """Append the most recent sample in ``self.data`` to the CSV file.

        Does nothing when no sample has been taken yet.
        """
        if not self.data:
            return

        df = (
            pd.DataFrame.from_dict(self.data, orient="index")
            .reset_index()
            .rename(columns={"index": "PID"})
        )
        csv_file_path = self.csv_file_path

        # Check if file exists to determine whether to add headers
        file_exists = os.path.isfile(csv_file_path)

        # Append to the CSV, omitting headers if the file already exists
        df.to_csv(csv_file_path, mode="a", index=False, header=not file_exists)

    def str(self):
        """Sample all processes and return one ``PID=..., key=value`` line each."""
        self._refresh()
        keys = self._columns()
        lines = []
        # self.data is keyed by label, so walk pids and labels in lockstep.
        for pid, name in zip(self.pids, self.names):
            stats = self.data[name]
            fields = [f"PID={pid}"]
            fields.extend(f"{k}={self.format(stats.get(k, 0))}" for k in keys)
            lines.append(", ".join(fields))
        return "\n".join(lines)

    @staticmethod
    def format(size: int) -> str:
        """Render a byte count with one decimal and a ``K``/``M``/``G`` suffix.

        Values below 1024 get no suffix. ``G`` is the largest unit, so larger
        values are expressed as (possibly >= 1024) gigabytes.
        """
        units = ("", "K", "M", "G")
        value = float(size)
        unit = units[0]
        for unit in units:
            # Stop dividing at the last unit so the number matches its suffix.
            if value < 1024 or unit == units[-1]:
                break
            value /= 1024.0
        return f"{value:.1f}{unit}"


def read_sample(x):
    """Serialize ``x`` to bytes as a stand-in for a dataloader reading it.

    The point of the call is to touch the object (bumping its refcount),
    which is what a real dataloader would do when reading a sample.
    """
    # Older interpreters: pickle does not increment the refcount (a bug fixed
    # in https://github.com/python/cpython/pull/92931), so msgpack is used.
    if sys.version_info < (3, 10, 6):
        import msgpack

        return msgpack.dumps(x)

    return pickle.dumps(x)


class DatasetFromList(torch.utils.data.Dataset):
    """Dataset that serves items directly from an indexable container."""

    def __init__(self, lst):
        # Store a reference to the container as given; no copy is made.
        self.lst = lst

    def __len__(self):
        """Return the number of items in the wrapped container."""
        items = self.lst
        return len(items)

    def __getitem__(self, idx: int):
        """Return the item stored at position ``idx``."""
        items = self.lst
        return items[idx]


if __name__ == "__main__":
    # Project-local import; NumpySerializedList is not referenced in the lines
    # visible here (NOTE(review): presumably used further down or left over).
    from serialize import NumpySerializedList

    # With no arguments the monitor tracks only the current process; print its
    # memory usage once as a baseline.
    monitor = MemoryMonitor()
    print("Initial", monitor.str())
