#https://github.com/aimagelab/mammoth
import logging
import os
import random
import sys
import warnings
from functools import partial
from typing import List

import numpy as np
import torch
from torch.utils.data import DataLoader

from utilities.Logger import Logger

def _get_gpu_memory_pynvml_all_processes(device_id: int = 0) -> int:
    """
    Use pynvml to get the memory used on a GPU by all running compute processes.

    The NVML device handle is cached as a ``handle_<device_id>`` attribute on this
    function, so NVML is initialized and the handle looked up only the first time a
    given device is queried.

    Args:
        device_id: index of the GPU to query.

    Returns:
        The memory used on the GPU by all compute processes, in Bytes. Processes whose
        usage pynvml cannot report (``usedGpuMemory`` is None) are counted as 0.
    """
    handle_attr = f'handle_{device_id}'
    handle = getattr(_get_gpu_memory_pynvml_all_processes, handle_attr, None)
    if handle is None:
        # Called once per device; NVML reference-counts init calls, so repeating it is harmless.
        torch.cuda.pynvml.nvmlInit()
        handle = torch.cuda.pynvml.nvmlDeviceGetHandleByIndex(device_id)
        setattr(_get_gpu_memory_pynvml_all_processes, handle_attr, handle)

    procs = torch.cuda.pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
    # `usedGpuMemory` may be None when NVML cannot attribute memory to a process;
    # treat it as 0 instead of letting `sum` fail with a TypeError.
    return sum(proc.usedGpuMemory or 0 for proc in procs)

def get_alloc_memory_all_devices(return_all=False) -> list[int]:
    """
    Returns the memory used on all the available CUDA devices.

    By default, returns the per-device memory read from pynvml (memory used by all
    compute processes) if pynvml produced a positive reading for at least one device.
    Otherwise, it falls back to the peak memory *allocated* by torch
    (``torch.cuda.max_memory_allocated``).

    Args:
        return_all: if True, return a tuple ``(reserved, allocated, pynvml)`` of
            per-device lists instead of a single list. ``reserved`` and ``allocated``
            are torch's peak values; ``pynvml`` holds -1 for devices whose pynvml
            reading failed.

    Returns:
        A list with one value per device, or the tuple described above.

    Values are in Bytes.
    """
    gpu_memory_reserved = []
    gpu_memory_allocated = []
    gpu_memory_nvidiasmi = []
    for i in range(torch.cuda.device_count()):
        _ = torch.tensor([1]).to(i)  # allocate memory to get more accurate reading from torch
        gpu_memory_reserved.append(torch.cuda.max_memory_reserved(i))
        gpu_memory_allocated.append(torch.cuda.max_memory_allocated(i))

        try:
            gpu_memory_nvidiasmi.append(_get_gpu_memory_pynvml_all_processes(i))
        except Exception as e:
            # Best effort: pynvml may be missing or unsupported here. With Python's
            # default warning filters an identical message is shown only once, which
            # matters because this function can be called on every tracking step.
            warnings.warn(f"Error while reading memory from pynvml: {e}")
            gpu_memory_nvidiasmi.append(-1)

    if return_all:
        return gpu_memory_reserved, gpu_memory_allocated, gpu_memory_nvidiasmi
    if any(g > 0 for g in gpu_memory_nvidiasmi):
        return gpu_memory_nvidiasmi
    return gpu_memory_allocated
import platform
try:
    if platform.system() == 'Windows':
        import psutil
    else:
        from resource import getrusage, RUSAGE_CHILDREN, RUSAGE_SELF

    def get_memory_mb():
        """
        Get the memory usage of the current process and its children, in MB.

        On Windows this is the *current* resident set size (via psutil), summed over
        the process and all its (recursive) children that are still alive.

        On other platforms this is the *peak* resident set size reported by
        ``getrusage``. For ``RUSAGE_CHILDREN`` the OS reports the peak of the largest
        child that has terminated and been waited for, not a sum over live children.

        Returns:
            dict: A dictionary containing the memory usage of the current process and its children.

            The dictionary has the following keys:
                - self: The memory usage of the current process.
                - children: The memory usage of the children of the current process.
                - total: The sum of ``self`` and ``children``.
        """
        if platform.system() == 'Windows':
            process = psutil.Process()
            mem_info_self = process.memory_info().rss / 1024 / 1024  # in MB
            mem_info_children = 0
            for child in process.children(recursive=True):
                try:
                    mem_info_children += child.memory_info().rss / 1024 / 1024
                except psutil.Error:
                    # The child exited (or became inaccessible) after being listed; skip it.
                    continue
        else:
            # `ru_maxrss` is in kilobytes on Linux but in bytes on macOS.
            to_mb = 1024 * 1024 if platform.system() == 'Darwin' else 1024
            mem_info_self = getrusage(RUSAGE_SELF).ru_maxrss / to_mb
            mem_info_children = getrusage(RUSAGE_CHILDREN).ru_maxrss / to_mb

        return {
            "self": mem_info_self,
            "children": mem_info_children,
            "total": mem_info_self + mem_info_children,
        }
except Exception as e:
    print(f"Error while reading memory: {e}")
    get_memory_mb = None

try:
    import torch

    if torch.cuda.is_available():

        def get_memory_gpu_mb():
            """
            Get the memory usage of all GPUs in MB.

            Returns:
                list: one value per visible CUDA device, in MB, obtained by converting
                the Byte values reported by ``get_alloc_memory_all_devices``.
            """

            return [d / 1024 / 1024 for d in get_alloc_memory_all_devices()]
    else:
        # No CUDA device available: expose None so callers skip GPU tracking.
        get_memory_gpu_mb = None
except Exception as e:
    # Best effort: GPU tracking is simply disabled if torch/CUDA cannot be queried.
    print(f"Error while reading memory from torch: {e}")
    get_memory_gpu_mb = None



class track_system_stats:
    """
    A context manager that tracks the memory usage of the system.
    Tracks both CPU and GPU memory usage if available.

    Usage:

    .. code-block:: python

        with track_system_stats() as t:
            for i in range(100):
                ... # Do something
                t()

        t.print_stats()
        avg_cpu, max_cpu = t.avg_cpu_res, t.max_cpu_res

    Once entered (and not disabled), the tracker exposes:

    - ``initial_cpu_res``, ``avg_cpu_res``, ``max_cpu_res``: CPU memory in MB, or
      None if CPU memory cannot be read.
    - ``initial_gpu_res``, ``avg_gpu_res``, ``max_gpu_res``: dicts mapping GPU index
      to memory in MB, or None if GPU memory cannot be read.

    The average is a running mean over the readings taken by ``__call__`` and
    ``__exit__``. If neither CPU nor GPU memory can be read, tracking is disabled
    automatically on enter.

    Args:
        logger (Logger): external logger. If given, every reading is passed to
            ``logger.log_system_stats(cpu_res, gpu_res)``.
        disabled (bool): If True, the context manager will not track the memory usage.
    """

    def get_stats(self):
        """
        Get the memory usage of the system.

        Returns:
            tuple: (cpu_res, gpu_res) where cpu_res is the total CPU memory usage in MB
            (None if unavailable) and gpu_res is a list with the memory usage of each
            GPU in MB (None if unavailable).
        """
        cpu_res = None
        if get_memory_mb is not None:
            cpu_res = get_memory_mb()['total']

        gpu_res = None
        if get_memory_gpu_mb is not None:
            gpu_res = get_memory_gpu_mb()

        return cpu_res, gpu_res

    def __init__(self, logger: 'Logger | None' = None, disabled=False):
        self.logger = logger
        self.disabled = disabled
        self._it = 0  # number of readings folded into the running averages

    def __enter__(self):
        if self.disabled:
            return self
        self.initial_cpu_res, self.initial_gpu_res = self.get_stats()
        if self.initial_cpu_res is None and self.initial_gpu_res is None:
            # Nothing can be measured on this system: turn tracking off.
            self.disabled = True
        else:
            if self.initial_gpu_res is not None:
                # Key GPU readings by device index.
                self.initial_gpu_res = dict(enumerate(self.initial_gpu_res))

            self.avg_gpu_res = self.initial_gpu_res
            self.avg_cpu_res = self.initial_cpu_res

            self.max_cpu_res = self.initial_cpu_res
            self.max_gpu_res = self.initial_gpu_res

            if self.logger is not None:
                self.logger.log_system_stats(self.initial_cpu_res, self.initial_gpu_res)

        return self

    def __call__(self):
        if self.disabled:
            return

        cpu_res, gpu_res = self.get_stats()
        self.update_stats(cpu_res, gpu_res)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.disabled:
            return

        if torch.cuda.is_available():
            # Surfaces errors from previously launched asynchronous GPU work. Guarded
            # because synchronize() raises when CUDA is unavailable (CPU-only runs).
            torch.cuda.synchronize()

        cpu_res, gpu_res = self.get_stats()
        self.update_stats(cpu_res, gpu_res)

    def update_stats(self, cpu_res, gpu_res):
        """
        Fold a new reading into the running average/max and forward it to the logger.

        Args:
            cpu_res (float): The memory usage of the CPU, in MB.
            gpu_res (list): The memory usage of each GPU, in MB.
        """
        if self.disabled:
            return

        self._it += 1

        # Incremental running mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
        alpha = 1 / self._it
        if self.initial_cpu_res is not None:
            self.avg_cpu_res = self.avg_cpu_res + alpha * (cpu_res - self.avg_cpu_res)
            self.max_cpu_res = max(self.max_cpu_res, cpu_res)

        if self.initial_gpu_res is not None:
            self.avg_gpu_res = {g: self.avg_gpu_res[g] + alpha * (g_res - self.avg_gpu_res[g])
                                for g, g_res in enumerate(gpu_res)}
            self.max_gpu_res = {g: max(self.max_gpu_res[g], g_res) for g, g_res in enumerate(gpu_res)}
            gpu_res = dict(enumerate(gpu_res))

        if self.logger is not None:
            self.logger.log_system_stats(cpu_res, gpu_res)

    def print_stats(self):
        """
        Print the initial, average, final and max memory usage.

        The "final" values come from a fresh reading taken when this method is called.
        Does nothing if tracking is disabled.
        """
        if self.disabled:
            return

        cpu_res, gpu_res = self.get_stats()

        # Print initial, average, final, and max memory usage
        print("System stats:")
        if cpu_res is not None:
            print(f"\tInitial CPU memory usage: {self.initial_cpu_res:.2f} MB", flush=True)
            print(f"\tAverage CPU memory usage: {self.avg_cpu_res:.2f} MB", flush=True)
            print(f"\tFinal CPU memory usage: {cpu_res:.2f} MB", flush=True)
            print(f"\tMax CPU memory usage: {self.max_cpu_res:.2f} MB", flush=True)

        if gpu_res is not None:
            for gpu_id, g_res in enumerate(gpu_res):
                print(f"\tInitial GPU {gpu_id} memory usage: {self.initial_gpu_res[gpu_id]:.2f} MB", flush=True)
                print(f"\tAverage GPU {gpu_id} memory usage: {self.avg_gpu_res[gpu_id]:.2f} MB", flush=True)
                print(f"\tFinal GPU {gpu_id} memory usage: {g_res:.2f} MB", flush=True)
                print(f"\tMax GPU {gpu_id} memory usage: {self.max_gpu_res[gpu_id]:.2f} MB", flush=True)
