"""
Handy quality-of-life helpers for experiments. Copy and reuse.
"""

import os
import pprint
import logging
import dataclasses
from ml_collections import ConfigDict, config_flags
from functools import partial
from pathlib import Path
from datetime import datetime

##### Renaming starts #####

# Shorter alias for functools.partial: bind(f, *args, **kwargs) returns a
# callable with those arguments pre-applied to f.
bind = partial

##### Renaming ends #####


##### IO starts #####


def timestamp() -> str:
  """Return the current local time formatted as ``YYYY-MM-DD-HH-MM-SS``.

  Every field is zero-padded and ordered from most to least significant, so
  these strings sort chronologically as plain text.
  """
  now = datetime.now()
  return f"{now:%Y-%m-%d-%H-%M-%S}"


def make_unique_output_dir(base_folder: str, slurm_job_id: str) -> Path:
  """Create and return a new output directory ``<timestamp>-<slurm_job_id>``.

  The directory is created under ``base_folder``, and missing parents are
  created too. Two runs can start within the same second with the same job
  id, for example local runs that share a placeholder id. In that case a
  numeric suffix (``-1``, ``-2``, ...) is appended until an unused name is
  found. The directory returned is therefore always one this call created.

  Args:
    base_folder: parent directory for all run outputs.
    slurm_job_id: job identifier appended to the timestamp.

  Returns:
    Path of the freshly created directory.
  """
  base = Path(base_folder)
  stem = f"{timestamp()}-{slurm_job_id}"
  output_dir = base / stem
  attempt = 0
  while True:
    try:
      # No exist_ok: mkdir of the final component is atomic, so a concurrent
      # run that grabbed the same name makes us fail here and retry.
      output_dir.mkdir(parents=True)
      return output_dir
    except FileExistsError:
      attempt += 1
      output_dir = base / f"{stem}-{attempt}"


def read_configs(CDataClass):
  """Expose ``CDataClass`` defaults as the ``C`` config flag and build a config.

  A default-constructed ``CDataClass()`` is converted with
  ``dataclasses.asdict``, wrapped in a ``ConfigDict`` and registered through
  ``config_flags.DEFINE_config_dict`` under the name ``"C"`` with
  ``lock_config=True``. The flag's value is logged, unpacked back into
  ``CDataClass`` and logged again.

  NOTE(review): the flag's ``.value`` is read immediately after the flag is
  defined. Confirm that flags are parsed at that point in the caller's setup.
  NOTE(review): ``dataclasses.asdict`` recurses into nested dataclasses, so
  nested fields reach ``CDataClass(...)`` as ConfigDicts, not as dataclass
  instances.

  Args:
    CDataClass: dataclass type constructible with no arguments.

  Returns:
    An instance of ``CDataClass`` built from the flag value.
  """
  defaults = ConfigDict(dataclasses.asdict(CDataClass()))
  config_flag = config_flags.DEFINE_config_dict(
    "C", defaults, "Config dict", lock_config=True
  )
  raw_config = config_flag.value
  logging.info(f"raw config:\n{pprint.pformat(raw_config)}")
  final_config = CDataClass(**raw_config)
  logging.info(f"final config:\n{pprint.pformat(final_config)}")
  # TODO: save streams of config like Hydra
  #  - default
  #  - read from file
  #  - overwrite from flags
  #  - final version
  return final_config


def safe_write(file_path, content) -> None:
  """Write text ``content`` to ``file_path``, skipping the write if unchanged.

  - If the file already holds exactly ``content``, it logs a warning and
    leaves the file untouched.
  - If the file holds different content, it logs the old and new content as
    warnings, then overwrites the file.
  - If the file does not exist, it creates the file.

  Both the read and the write use UTF-8 explicitly. This keeps the comparison
  and the written bytes independent of the platform's locale encoding.

  Args:
    file_path: path of the file to write (str or os.PathLike).
    content: text to write.
  """
  try:
    with open(file_path, "r", encoding="utf-8") as f:
      old_content = f.read()
  except FileNotFoundError:
    pass  # nothing to compare against; create the file below
  else:
    if old_content == content:
      logging.warning(f"{file_path} exists with the same content, skipping")
      return
    logging.warning(f"different old content: {old_content}")
    logging.warning(f"replacing with new content: {content}")
  with open(file_path, "w", encoding="utf-8") as f:
    f.write(content)


def prefix_with_key(d, prefix):
  """Return a new dict with ``prefix`` prepended to every key of ``d``.

  Keys are rendered with f-string formatting, so non-string keys are turned
  into strings. Values are passed through unchanged, and ``d`` is not
  modified.
  """
  prefixed = {}
  for key, value in d.items():
    prefixed[f"{prefix}{key}"] = value
  return prefixed


_USAGE_PROCESS = None  # cached psutil.Process handle, see _usage_process()


def _usage_process():
  """Return a cached ``psutil.Process`` for the current PID.

  ``Process.cpu_percent()`` measures CPU time since the previous call on the
  same object, so a freshly created object always reports 0.0. Reusing one
  handle lets consecutive ``collect_usage`` calls report real usage. The
  handle is rebuilt if the PID changes, for example after a fork.
  """
  global _USAGE_PROCESS
  import psutil

  if _USAGE_PROCESS is None or _USAGE_PROCESS.pid != os.getpid():
    _USAGE_PROCESS = psutil.Process()
  return _USAGE_PROCESS


def collect_usage(cpu=True, nvsmi=False):
  """Collect resource-usage stats keyed ``usage/<stat>``.

  Args:
    cpu: if True, include process-level and system-wide CPU/RAM stats from
      psutil.
    nvsmi: GPU stats; not implemented yet, only logs a warning.

  Returns:
    A flat dict whose keys are all prefixed with ``"usage/"``. It is empty
    when ``cpu`` is False.
    - ``*_gb`` values are in GiB (1024**3 bytes).
    - ``proc_cpu_usage`` is the psutil percentage divided by 100, so it can
      exceed 1 when the process keeps several cores busy.

  CPU percentages are measured relative to the previous call. The first call
  in a process therefore reports 0.0 for the CPU usage fields.
  """
  out = {}
  if cpu:
    import psutil

    gb = 1024**3
    proc = _usage_process()
    # Read RSS once so the fraction and the GB figure describe the same sample.
    rss = proc.memory_info().rss
    cpus = psutil.cpu_count()
    memory = psutil.virtual_memory()
    # cpu_affinity() is only available on some platforms (not macOS).
    # Elsewhere, assume the process may run on every CPU.
    if hasattr(proc, "cpu_affinity"):
      affinity = len(proc.cpu_affinity())
    else:
      affinity = cpus
    stats = {
      "proc_cpu_affinity": affinity,
      "proc_cpu_usage": proc.cpu_percent() / 100,
      "proc_ram_frac": rss / memory.total,
      "proc_ram_gb": rss / gb,
      "total_cpu_count": cpus,
      "total_cpu_frac": psutil.cpu_percent() / 100,
      "total_ram_frac": memory.percent / 100,
      "total_ram_total_gb": memory.total / gb,
      "total_ram_used_gb": memory.used / gb,
      "total_ram_avail_gb": memory.available / gb,
    }
    out.update(stats)
  if nvsmi:
    logging.warning("nvsmi not implemented yet")
  return prefix_with_key(out, "usage/")
