#! /usr/bin/env python3
# coding=utf-8


import glob
import logging
import os
import re
import shutil

logger = logging.getLogger(__name__)


def sorted_ckpts(ckpts_dir, ckpt_prefix="checkpoint", use_mtime=False):
    """
    List the checkpoints found in a directory in ascending order.

    A checkpoint is any entry of ``ckpts_dir`` whose name matches ``{ckpt_prefix}-*``.

    Args:
        - ckpts_dir: directory to search. Glob metacharacters in it (e.g. ``[``) are
          matched literally.
        - ckpt_prefix: name prefix of checkpoint entries. Optional. Default is "checkpoint".
          Matched literally: glob and regex metacharacters are escaped.
        - use_mtime: if True, order entries by modification time. Otherwise order them by
          the integer that follows ``{ckpt_prefix}-`` and skip entries without one.
          Optional. Default is False.

    Returns:
        list of str paths, ascending by sort key (step number or mtime), ties broken by path.
        The first element is the lowest step / oldest mtime.
    """
    # Escape both parts so directory names such as "run[1]" or prefixes containing glob
    # metacharacters are matched literally instead of as patterns.
    pattern = os.path.join(glob.escape(ckpts_dir), "{}-*".format(glob.escape(ckpt_prefix)))
    # Compile once; re.escape keeps prefixes like "ckpt+" or "ckpt(" from being parsed
    # as regex syntax. Greedy ".*" anchors on the last "<prefix>-<digits>" in the path.
    step_re = re.compile(".*{}-([0-9]+)".format(re.escape(ckpt_prefix)))

    keyed = []
    for path in glob.glob(pattern):
        if use_mtime:
            keyed.append((os.path.getmtime(path), path))
            continue
        match = step_re.match(path)
        if match:
            keyed.append((int(match.group(1)), path))

    return [path for _, path in sorted(keyed)]


def rotate_ckpts(ckpts_dir, save_total_limit, ckpt_prefix="checkpoint", use_mtime=False):
    """
    Delete the earliest checkpoints so that at most ``save_total_limit`` remain.

    Checkpoints are discovered and ordered with ``sorted_ckpts``. Entries at the front
    of that ordering (lowest step number, or oldest mtime when ``use_mtime`` is True)
    are deleted first.

    Args:
        - ckpts_dir: directory that holds the checkpoints.
        - save_total_limit: maximum number of checkpoints to keep. A falsy value
          (None, 0) or a negative value disables rotation and deletes nothing.
        - ckpt_prefix: name prefix of checkpoint entries. Optional. Default is "checkpoint".
        - use_mtime: order by modification time instead of step number. Optional.
          Default is False.
    """
    if not save_total_limit or save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    ckpts_sorted = sorted_ckpts(ckpts_dir, ckpt_prefix, use_mtime)
    num_ckpts_del = len(ckpts_sorted) - save_total_limit
    if num_ckpts_del <= 0:
        return

    for checkpoint in ckpts_sorted[:num_ckpts_del]:
        logger.info("Deleting older checkpoint [%s] due to args.save_total_limit", checkpoint)
        # The "<prefix>-*" glob can also match plain files or symlinks, and
        # shutil.rmtree raises on those. Only real directories get rmtree;
        # everything else is unlinked.
        if os.path.isdir(checkpoint) and not os.path.islink(checkpoint):
            shutil.rmtree(checkpoint)
        else:
            os.remove(checkpoint)


def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: one or more str prefices to match (e.g. ["transformers", "torch"]),
          or a single str. Optional. Default is `("",)`, which matches all active loggers.
          The match is a literal, case-sensitive `module_name.startswith(prefix)`, so
          characters such as "." are not treated as patterns. An empty sequence behaves
          like the default and matches every logger.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(prefices, str):
        prefices = (prefices,)
    # str.startswith needs a tuple. Preserve the historical "empty => match all" behavior.
    prefices = tuple(prefices) or ("",)

    # Iterate over a snapshot so the loop is unaffected if getLogger() updates the
    # manager's registry while levels are being set.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(prefices):
            logging.getLogger(name).setLevel(level)
