import logging
import os
import sys
import os.path as osp
from datetime import datetime
from functools import wraps

def setup_logger(name, save_dir, if_train):
    """Configure and return the logger called *name*.

    A DEBUG-level handler writing to ``sys.stdout`` is always attached. When
    *save_dir* is truthy, the directory is created if needed. A DEBUG-level
    file handler is then attached that appends to
    ``train_log_<YYYY-MM-DD>.txt`` (if *if_train*) or
    ``test_log_<YYYY-MM-DD>.txt`` (otherwise) inside that directory, and a
    banner marking the start of the run is logged.

    Calling this again with the same *name* is safe. ``logging.getLogger``
    returns the same logger object, so handlers that are already attached are
    not added again. This covers the stdout handler and any file handler for
    the same path. Without this guard, every message would be emitted once
    per call.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        save_dir: Directory for the log file, or a falsy value to log to
            stdout only.
        if_train: Selects the ``train_log_`` / ``test_log_`` file prefix and
            the wording of the start banner.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")

    # FileHandler subclasses StreamHandler, so it must be excluded explicitly
    # when looking for an existing stdout handler.
    has_stdout_handler = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        and h.stream is sys.stdout
        for h in logger.handlers
    )
    if not has_stdout_handler:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if save_dir:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(save_dir, exist_ok=True)
        now = datetime.now().strftime("%Y-%m-%d")
        prefix = "train" if if_train else "test"
        log_path = os.path.join(save_dir, "{}_log_{}.txt".format(prefix, now))
        # FileHandler records the absolute path of its file in baseFilename.
        abs_path = os.path.abspath(log_path)
        has_file_handler = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
            for h in logger.handlers
        )
        if not has_file_handler:
            fh = logging.FileHandler(log_path, mode='a', encoding="utf-8")
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)
        stage = "training" if if_train else "testing"
        logger.info(f"\n\n\n{'='*100}\n{'='*40} Start {stage} {'='*40}\n{'='*100}")

    return logger

def func_info(f):
    """Decorate *f* so every call first prints a banner naming the function.

    The banner is two blank lines followed by the function's name in upper
    case, framed by runs of ``=`` characters. Arguments are forwarded
    unchanged and the wrapped function's return value is passed straight
    through. ``functools.wraps`` keeps the original name and docstring.
    """
    fence = '=' * 30

    @wraps(f)
    def wrapper(*args, **kwargs):
        title = f.__name__.upper()
        print(f"\n\n{fence} {title} {fence}")
        return f(*args, **kwargs)

    return wrapper