import torch
import torch.distributed as dist
import os, sys
import json
import logging
from rich.logging import RichHandler


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    # Short-circuit: is_initialized() is only consulted when the package
    # reports itself as available.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when no process group is active."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1

def get_rank():
    """Return this process's distributed rank, or 0 when no process group is active."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    # Not running distributed: behave as the single rank-0 process.
    return 0


def is_main_process():
    """Return True when the current process has rank 0."""
    current_rank = get_rank()
    return current_rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the rank-0 process.

    On every other rank the call does nothing and returns ``None``.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def paser_config_save(args, PATH):
    """Write a configuration to ``config.json`` inside directory ``PATH``.

    Args:
        args: Either a ``dict``, which is dumped as-is, or any object with a
            ``__dict__`` attribute (for example an ``argparse.Namespace``),
            whose attribute dictionary is dumped instead.
        PATH: Directory that receives ``config.json``. May be a ``str`` or
            any ``os.PathLike`` object such as ``pathlib.Path``.

    Raises:
        AttributeError: If ``args`` is not a dict and has no ``__dict__``.
        TypeError: If the configuration contains values JSON cannot encode.
        OSError: If the file cannot be opened for writing.
    """
    payload = args if isinstance(args, dict) else args.__dict__
    # os.path.join accepts both str and PathLike directories, unlike string
    # concatenation, and avoids doubling the separator on a trailing slash.
    config_path = os.path.join(PATH, 'config.json')
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2)


def set_logging_defaults(logdir: str, name: str = 'main'):
    """Set up console logging via Rich and a file handler for logger ``name``.

    The root logger is configured through ``logging.basicConfig`` with a
    ``RichHandler`` at INFO level. The logger called ``name`` additionally
    gets a UTF-8, append-mode ``FileHandler`` writing to
    ``<logdir>/<name>_log.txt``. On the main process a start message is
    logged afterwards.

    Calling this function again with the same ``logdir`` and ``name`` does
    not attach a second file handler, so records are not written twice.

    Args:
        logdir: Directory for the log file. It is created if it does not
            exist yet.
        name: Name of the logger that receives the file handler; also used
            as the log file's prefix.
    """
    RICH_FORMAT = "[%(filename)s:%(lineno)s] >> %(message)s"
    FILE_HANDLER_FORMAT = (
        "[%(asctime)s]\t%(levelname)s\t"
        "[%(filename)s:%(funcName)s:%(lineno)s]\t>> %(message)s"
    )

    logging.basicConfig(
        format=RICH_FORMAT,
        level=logging.INFO,
        handlers=[RichHandler(rich_tracebacks=True)])

    logger = logging.getLogger(name)
    log_path = os.path.join(logdir, f'{name}_log.txt')

    # FileHandler opens the file immediately, so the directory must exist.
    # exist_ok=True keeps this safe when several ranks get here together.
    if logdir:
        os.makedirs(logdir, exist_ok=True)

    # FileHandler stores its target as an absolute path in baseFilename;
    # compare against that to detect a handler added by an earlier call.
    target = os.path.abspath(log_path)
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == target
        for handler in logger.handlers
    )
    if not already_attached:
        file_handler = logging.FileHandler(log_path,
                                           mode="a",
                                           encoding="utf-8")
        file_handler.setFormatter(logging.Formatter(FILE_HANDLER_FORMAT))
        logger.addHandler(file_handler)

    if is_main_process():
        logger.info(f"{name} logging start!")


def handle_exception(exc_type, exc_value, exc_traceback):
    """Log an exception at ERROR level on the 'main' logger.

    The three parameters form the usual ``(type, value, traceback)`` triple
    and are passed to the logger as ``exc_info`` so the traceback is
    included in the record.
    NOTE(review): the signature matches ``sys.excepthook``; whether it is
    installed as such is not visible in this file.
    """
    exc_details = (exc_type, exc_value, exc_traceback)
    logging.getLogger('main').error("Unexpected exception",
                                    exc_info=exc_details)
