# Copyright 2025 The corr_faith Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Methods for logging resource usage and running time."""

import functools
import os
import signal
import subprocess
import sys
import threading
import time
from typing import Any, Callable, Mapping, MutableMapping, Optional

from absl import logging
import accelerate
import numpy as np
import psutil
import torch


# Unit suffixes for successive powers of 1024, from plain bytes to yobibytes.
# The first entry carries a leading space so that e.g. "0 bytes" reads well.
BYTE_UNITS = (
    " bytes",
    "KiB",
    "MiB",
    "GiB",
    "TiB",
    "PiB",
    "EiB",
    "ZiB",
    "YiB",
)
# Separator placed between the individual entries of a resource-usage line.
LOG_DELIMITER = ";  "


def human_readable_size(
    n_bytes: int,
) -> str:
  """Converts a byte count into a human-readable string with binary prefixes.

  The magnitude is divided by 1024 until it is below 1024 or the largest unit
  in `BYTE_UNITS` is reached. Negative inputs (e.g. a decrease relative to a
  baseline) are scaled by their magnitude and keep their sign, so -2048 is
  rendered as "-2.00KiB" rather than "-2048.00 bytes".

  Args:
    n_bytes: The number of bytes to convert. May be negative.

  Returns:
    A string such as "0 bytes", "512.00 bytes" or "1.50MiB".
  """
  if n_bytes == 0:
    return "0" + BYTE_UNITS[0]
  # Scale the absolute value; otherwise the `>= 1024` test below never holds
  # for negative inputs and they would always be reported in plain bytes.
  sign = "-" if n_bytes < 0 else ""
  magnitude = abs(n_bytes)
  unit_index = 0
  while magnitude >= 1024 and unit_index < len(BYTE_UNITS) - 1:
    magnitude /= 1024.0
    unit_index += 1
  return f"{sign}{magnitude:.2f}{BYTE_UNITS[unit_index]}"


# Memoized: the snapshot is taken on the very first call and every later call
# returns that same object.
@functools.cache
def get_initial_disk_usage() -> Any:
  """Returns psutil disk-usage statistics for the current working directory.

  Because of the cache, the value reflects the state at the first call. It is
  used as the baseline for the "Disk usage minus initial" figure.
  """
  working_dir = os.getcwd()
  return psutil.disk_usage(working_dir)


def log_gpu_and_cpu_mem_stats(
    prefix: str = "",
    log_cpu_interval: Optional[float] = None,
    log_level: int = logging.INFO,
):
  """Logs one line summarizing RAM, disk and, optionally, CPU and GPU usage.

  Entries are joined with LOG_DELIMITER.

  - Disk figures come from `psutil.disk_usage` on the current working
    directory.
  - "Disk usage minus initial" is relative to the snapshot cached by
    `get_initial_disk_usage` and may be negative.
  - GPU figures come from `torch.cuda.mem_get_info()` and are included only
    when CUDA is available.

  Args:
    prefix: Text prepended to the log line, e.g. "PERIODIC LOG: ".
    log_cpu_interval: If not None, per-core CPU utilization is sampled over
      this many seconds with `psutil.cpu_percent` and appended to the line.
      The sampling blocks for the whole interval.
    log_level: absl logging level used for the line.
  """
  cpu_usage_message = ""
  if log_cpu_interval is not None:
    # This is a blocking call; best to use it in a separate thread.
    # OK to use threads. psutil uses time.sleep, which does release the GIL:
    # https://github.com/python/cpython/blob/7ba1f75f3f02b4b50ac6d7e17d15e467afa36aac/Modules/timemodule.c#L1880-L1882
    cpu_usage_by_core = psutil.cpu_percent(
        interval=log_cpu_interval, percpu=True
    )
    # No leading or trailing ";": entries are already joined with
    # LOG_DELIMITER, so extra separators would render as ";  ; CPU avg ...".
    cpu_usage_message = (
        f"CPU avg: {np.mean(cpu_usage_by_core):.2f}; by core:"
        f" {cpu_usage_by_core}"
    )
  cpu_mem = psutil.virtual_memory()
  used_cpu_mem = cpu_mem.total - cpu_mem.available
  disk_usage = psutil.disk_usage(os.getcwd())
  # Usage can drop below the initial snapshot (e.g. after files are deleted),
  # so format the sign separately from the magnitude.
  disk_delta = disk_usage.used - get_initial_disk_usage().used
  disk_delta_str = ("-" if disk_delta < 0 else "") + human_readable_size(
      abs(disk_delta)
  )
  usage_strings = [
      f"Total RAM: {human_readable_size(cpu_mem.total)}",
      f"Used RAM: {human_readable_size(used_cpu_mem)}",
      f"RAM usage: {cpu_mem.percent:.1f}%",
      f"Total disk: {human_readable_size(disk_usage.total)}",
      f"Used disk: {human_readable_size(disk_usage.used)}",
      f"Disk usage minus initial: {disk_delta_str}",
      f"Disk usage: {disk_usage.percent:.1f}%",
  ]
  if cpu_usage_message:
    usage_strings.append(cpu_usage_message)
  if torch.cuda.is_available():
    free_gpu_mem, total_gpu_mem = torch.cuda.mem_get_info()
    used_gpu_mem = total_gpu_mem - free_gpu_mem
    usage_strings.extend([
        f"Total GPU: {human_readable_size(total_gpu_mem)}",
        f"Used GPU: {human_readable_size(used_gpu_mem)}",
        f"GPU mem usage: {used_gpu_mem / total_gpu_mem:.1%}",
    ])
  logging.log(log_level, "%s%s", prefix, LOG_DELIMITER.join(usage_strings))


class LogWallTime(object):
  """Context manager that measures and logs the wall-clock time of its body.

  On exit:

  - the elapsed seconds, measured with `time.perf_counter`, are stored in
    `self.wall_time`;
  - they are logged at INFO level as "<name> took <secs>s";
  - if a `stats_dict` was given, they are written to `stats_dict[name]`.

  This happens even when the body raises; the exception is not suppressed.

  Example:
    timings = {}
    with LogWallTime("eval", timings):
      run_eval()
    # timings["eval"] now holds the elapsed seconds.
  """

  def __init__(
      self,
      name: str,
      stats_dict: Optional[MutableMapping[str, float]] = None,
  ):
    """Initializes the timer.

    Args:
      name: Label used in the log message and as the key in `stats_dict`.
      stats_dict: Optional mapping that receives the elapsed time under
        `name`. It is written to, hence `MutableMapping` rather than the
        read-only `Mapping`.
    """
    self.name = name
    self.stats_dict = stats_dict

  def __enter__(self):
    # Kept separate from `wall_time` so that attribute only ever holds an
    # elapsed duration, never a raw perf_counter() reading.
    self._start_time = time.perf_counter()
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.wall_time = time.perf_counter() - self._start_time
    logging.info("%s took %.2fs", self.name, self.wall_time)
    if self.stats_dict is not None:
      self.stats_dict[self.name] = self.wall_time


def run_fn_periodically(
    fn: Callable[[], None],
    period: float,
    stop_event: threading.Event,
):
  """Calls `fn` now, then reschedules itself so calls start `period`s apart.

  Each invocation calls `fn` once. Unless `stop_event` is set, it then starts a
  daemon `threading.Timer` that runs this function again. The function returns
  without waiting for future calls. If `fn` takes longer than `period`, the
  next call starts immediately. If `fn` raises, the exception is logged and
  the schedule continues.

  Args:
    fn: Zero-argument callable to run periodically.
    period: Target number of seconds between the starts of consecutive calls.
    stop_event: Once set, no further calls to `fn` are made. A timer that is
      already pending wakes up and returns without calling `fn`.
  """
  # The stop event may have been set while this call was waiting on its timer.
  if stop_event.is_set():
    return
  # monotonic() is immune to system clock adjustments, unlike time().
  call_time = time.monotonic()
  try:
    fn()
  except Exception:  # pylint: disable=broad-exception-caught
    # One failing call should not silently end the whole periodic chain.
    logging.exception("Periodic call to %s failed.", fn)
  if stop_event.is_set():
    return
  time_to_next_call = max(0.0, period - (time.monotonic() - call_time))
  timer = threading.Timer(
      time_to_next_call, run_fn_periodically, [fn, period, stop_event]
  )
  # Daemon timers never block interpreter shutdown, e.g. while sleeping until
  # the next call after the caller has already finished.
  timer.daemon = True
  timer.start()


class RunPeriodically(object):
  """Context manager that calls `fn` periodically on background timer threads.

  On entry, a timer thread starts `run_fn_periodically`, which invokes `fn`
  and keeps rescheduling it roughly every `period` seconds. On exit, the
  shared stop event is set so the chain stops rescheduling. A call already in
  progress, or a timer already pending, may still run.

  Example:
    with RunPeriodically(lambda: logging.info("alive"), period=60.0):
      train()
  """

  def __init__(self, fn: Callable[[], None], period: float):
    """Initializes the runner without starting it.

    Args:
      fn: Zero-argument callable to invoke periodically.
      period: Target number of seconds between the starts of successive calls.
    """
    self._fn = fn
    self._period = period
    self._stop_event = threading.Event()

  def __enter__(self):
    first_call = threading.Timer(
        0.0, run_fn_periodically, [self._fn, self._period, self._stop_event]
    )
    # A daemon thread never blocks interpreter shutdown, even if __exit__ is
    # never reached.
    first_call.daemon = True
    first_call.start()
    # Returning self (instead of None) supports `with ... as runner:`.
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    logging.info("PERIODIC LOG: finishing run.")
    self._stop_event.set()


class LogMemoryPeriodically(RunPeriodically):
  """Periodically logs resource usage with `log_gpu_and_cpu_mem_stats`.

  Log lines are prefixed with "PERIODIC LOG: ". Each call samples CPU usage
  for 99% of `period`, so the CPU figures cover almost the whole time between
  consecutive lines.

  Example:
    with LogMemoryPeriodically(period=60.0):
      train()
  """

  def __init__(self, period: float = 30.0):
    """Initializes the periodic logger.

    Args:
      period: Seconds between log lines. Must be positive.

    Raises:
      ValueError: If `period` is not positive.
    """
    if period <= 0:
      # With period == 0 the CPU sample would not block and calls would be
      # rescheduled with no delay, flooding the logs; a negative value is not
      # a valid sampling interval for psutil.cpu_percent.
      raise ValueError(f"period must be positive, got {period!r}.")
    super().__init__(
        functools.partial(
            log_gpu_and_cpu_mem_stats,
            prefix="PERIODIC LOG: ",
            log_cpu_interval=period * 0.99,
        ),
        period,
    )


def init_logging(accelerator: accelerate.Accelerator) -> None:
  """Configures absl logging for a (possibly multi-process) accelerate run.

  - Takes the initial disk-usage snapshot used as the baseline for later
    resource logs (via the cached `get_initial_disk_usage`).
  - Redirects the absl Python handler's stream to stdout.
  - Logs the process index, PID and UID of this process.
  - On every process other than the main one, sets the absl verbosity to
    "error" so lower-severity logs are not duplicated across processes.

  Args:
    accelerator: The run's `accelerate.Accelerator`; its `process_index` and
      `is_main_process` decide which process keeps logging.
  """
  # Called first so that, when this runs at startup, the cached snapshot is
  # taken before any other work in this process.
  initial_disk_usage = get_initial_disk_usage()
  absl_handler = logging.get_absl_handler()
  logging.info("Initial absl stream: %s", absl_handler.python_handler.stream)
  absl_handler.python_handler.stream = sys.stdout
  logging.info("Updated absl stream: %s", absl_handler.python_handler.stream)
  process_index = accelerator.process_index
  logging.info(
      "PROCESS_INDEX: %d; PID: %d; UID: %d",
      process_index,
      os.getpid(),
      os.getuid(),
  )
  if accelerator.is_main_process:
    logging.info("Process %d will write logs.", process_index)
  else:
    logging.info(
        "Suppressing logs for process %d to prevent duplication.", process_index
    )
    logging.set_verbosity("error")
    # Sanity check: with verbosity at "error" this INFO line must be dropped.
    logging.info("Process %d, this should not appear.", process_index)

  logging.info(
      "Initial disk usage: %s", human_readable_size(initial_disk_usage.used)
  )


def log_nvidia_smi_if_cuda_available(log_level: int = logging.INFO) -> None:
  """Logs the output of `nvidia-smi` when CUDA is available.

  Args:
    log_level: Logging level used for the nvidia-smi output, or for the note
      that CUDA is unavailable.

  Raises:
    RuntimeError: If `nvidia-smi` exits with a non-zero status.
    OSError: If `nvidia-smi` cannot be executed (e.g. FileNotFoundError when
      it is not on PATH); propagated unchanged from `subprocess`.
  """
  if not torch.cuda.is_available():
    logging.log(log_level, "CUDA is not available, not logging nvidia-smi.")
    return
  try:
    nvidia_smi_out = subprocess.check_output("nvidia-smi").decode(
        "utf-8", errors="replace"
    )
  except subprocess.CalledProcessError as e:
    # check_output captures stdout as bytes; decode it so the message shows
    # readable text instead of a bytes literal with escaped newlines.
    output = e.output
    if isinstance(output, bytes):
      output = output.decode("utf-8", errors="replace")
    raise RuntimeError(
        f"command '{e.cmd}' returned with error (code {e.returncode}):"
        f" {output}"
    ) from e
  logging.log(log_level, "nvidia-smi output:\n%s", nvidia_smi_out)


def log_system_debugging_info() -> None:
  """Logs Python, CUDA and torch.distributed details for debugging.

  Always logs:

  - the Python version;
  - CUDA_VISIBLE_DEVICES;
  - CUDA availability and device count;
  - whether torch.distributed is initialized.

  When CUDA is available, it also logs nvidia-smi output, the current device,
  its name and memory statistics. When torch.distributed is initialized, it
  also logs rank, world size and which backends/launchers are available.
  """
  logging.info("Python version: %s", sys.version)
  logging.info(
      "cuda_visible_devices: %s", os.environ.get("CUDA_VISIBLE_DEVICES")
  )
  cuda_is_available = torch.cuda.is_available()
  logging.info("cuda.is_available: %s", cuda_is_available)
  logging.info("cuda.device_count: %s", torch.cuda.device_count())
  if cuda_is_available:
    log_nvidia_smi_if_cuda_available()
    current_device = torch.cuda.current_device()
    logging.info("cuda.current_device: %s", current_device)
    # NOTE(review): this logs the repr of a torch.cuda.device context-manager
    # object; torch.cuda.get_device_properties(current_device) may be more
    # informative.
    logging.info("cuda.device: %s", torch.cuda.device(current_device))
    logging.info("device name: %s", torch.cuda.get_device_name(current_device))
    log_gpu_and_cpu_mem_stats()
  # torch builds without distributed support may not define
  # torch.distributed.is_initialized, so check is_available() first.
  distributed_is_initialized = (
      torch.distributed.is_available() and torch.distributed.is_initialized()
  )
  logging.info(
      "torch.distributed.is_initialized(): %s", distributed_is_initialized
  )
  if distributed_is_initialized:
    dist = torch.distributed
    queries = (
        ("get_rank", dist.get_rank),
        ("get_world_size", dist.get_world_size),
        ("is_mpi_available", dist.is_mpi_available),
        ("is_nccl_available", dist.is_nccl_available),
        ("is_gloo_available", dist.is_gloo_available),
        ("is_torchelastic_launched", dist.is_torchelastic_launched),
    )
    for query_name, query in queries:
      # Same "expr=value" text an f"{expr=}" debug string would produce.
      logging.info("torch.distributed.%s()=%r", query_name, query())


def handle_stop_signals(signum, frame):
  """Signal handler that logs debugging stats, then exits with 128 + signum.

  The default disposition for `signum` is restored before any logging. A
  second identical signal (e.g. pressing Ctrl-C again while nvidia-smi is
  slow) then takes the default action instead of re-entering this handler.
  The process exits even if the debug logging itself raises.

  Args:
    signum: The received signal number.
    frame: The interrupted stack frame (or None), as passed by `signal`.

  Raises:
    SystemExit: Always, with exit code 128 + signum.
  """
  # Restore the default handler first so a repeated signal is not handled
  # re-entrantly while the (potentially slow) logging below is running.
  signal.signal(signum, signal.SIG_DFL)
  try:
    logging.error("Received stop signal %s.", signum)
    logging.error("Frame: %s", frame)
    log_gpu_and_cpu_mem_stats(log_level=logging.ERROR)
    log_nvidia_smi_if_cuda_available(log_level=logging.ERROR)
    logging.error("Completed error logging, finishing run now.")
  except Exception:  # pylint: disable=broad-exception-caught
    # Best-effort diagnostics: a logging failure must not prevent the exit.
    logging.exception("Failed to log debugging info on stop signal.")
  sys.exit(128 + signum)


def add_stop_signal_handlers() -> None:
  """Installs `handle_stop_signals` for SIGINT and SIGTERM.

  Must be called from the main thread; `signal.signal` raises ValueError
  otherwise. SIGKILL is deliberately not included: trying to register a
  handler for it fails with "OSError: [Errno 22] Invalid argument."
  """
  for stop_signal in (signal.SIGINT, signal.SIGTERM):
    signal.signal(stop_signal, handle_stop_signals)
