import os
import shutil

import yaml
from easydict import EasyDict

from .logger import print_log


def log_args_to_file(args, pre='args', logger=None):
    """Emit one ``<pre>.<name> : <value>`` line per entry of ``args.__dict__``.

    Args:
        args: any object with a ``__dict__`` (presumably an
            ``argparse.Namespace`` from the training script — confirm).
        pre: prefix written before each attribute name.
        logger: forwarded unchanged to ``print_log``.
    """
    entries = args.__dict__
    for name in entries:
        print_log(f'{pre}.{name} : {entries[name]}', logger=logger)


def log_config_to_file(cfg, pre='cfg', logger=None):
    """Recursively log a config tree, one line per key.

    Each nested ``EasyDict`` value produces a ``<pre>.<key> = edict()`` header
    line and is then walked with ``pre`` extended by ``.<key>``. Every other
    value is logged as ``<pre>.<key> : <value>``.

    Args:
        cfg: mapping to walk (an ``EasyDict`` as produced by this module).
        pre: dotted path prefix for the keys of ``cfg``.
        logger: forwarded unchanged to ``print_log``.
    """
    for name, value in cfg.items():
        if not isinstance(value, EasyDict):
            print_log(f'{pre}.{name} : {value}', logger=logger)
            continue
        print_log(f'{pre}.{name} = edict()', logger=logger)
        log_config_to_file(value, pre=pre + '.' + name, logger=logger)


def _load_yaml_file(path):
    """Parse the YAML file at ``path`` and return its top-level object.

    Uses ``yaml.FullLoader`` when the installed PyYAML provides it. It falls
    back to a Loader-less ``yaml.load`` only on old PyYAML releases that lack
    ``FullLoader``. Parse errors propagate to the caller; they are never
    retried against an already-consumed stream. An empty file yields ``{}``
    instead of ``None``, so callers can always iterate the result.

    NOTE: ``FullLoader`` is not the safe loader; only use this on trusted
    config files.
    """
    try:
        loader = yaml.FullLoader
    except AttributeError:  # PyYAML < 5.1 has no FullLoader
        loader = None
    with open(path, 'r') as f:
        data = yaml.load(f) if loader is None else yaml.load(f, Loader=loader)
    return {} if data is None else data


def merge_new_config(config, new_config):
    """Recursively merge ``new_config`` into ``config`` in place and return it.

    Merge rules, applied per key of ``new_config``:
      * A mapping value is merged recursively into ``config[key]``.
        ``config[key]`` is first set to an empty ``EasyDict`` when it is
        missing or holds a non-mapping value.
      * A non-mapping ``_base_`` value is treated as the path of another YAML
        file. The path is opened as given, so a relative path resolves against
        the current working directory. That file is loaded and merged into a
        fresh ``config['_base_']`` ``EasyDict``, which replaces any previous
        value. The base is NOT merged into ``config`` itself.
      * Any other value overwrites ``config[key]``.

    Args:
        config: destination mapping (an ``EasyDict`` in this module); mutated.
        new_config: mapping to merge in; ``None``/empty is a no-op.

    Returns:
        ``config``.
    """
    if not new_config:
        return config
    for key, val in new_config.items():
        if isinstance(val, dict):
            # Replace a missing or scalar/list value so the recursive merge
            # has a mapping to write into.
            if not isinstance(config.get(key), dict):
                config[key] = EasyDict()
            merge_new_config(config[key], val)
        elif key == '_base_':
            config[key] = EasyDict()
            merge_new_config(config[key], _load_yaml_file(val))
        else:
            config[key] = val
    return config


def cfg_from_yaml_file(cfg_file):
    """Load ``cfg_file`` into a new ``EasyDict``, expanding ``_base_`` includes.

    Args:
        cfg_file: path of the YAML config file.

    Returns:
        An ``EasyDict`` built by ``merge_new_config``. An empty file gives an
        empty ``EasyDict``.

    Raises:
        OSError: if the file (or a referenced ``_base_`` file) can't be opened.
        yaml.YAMLError: if a file is not valid YAML.
    """
    config = EasyDict()
    merge_new_config(config=config, new_config=_load_yaml_file(cfg_file))
    return config


def get_config(args, logger=None):
    """Load the experiment config described by ``args``.

    When ``args.resume`` is set, the config is read from
    ``<args.experiment_path>/config.yaml``, and ``args.config`` is rewritten
    to that path. Otherwise ``args.config`` is loaded as-is. In that case, if
    ``args.local_rank == 0`` and ``args.exp_name`` is truthy, the source file
    is also copied into the experiment directory via
    ``save_experiment_config``.

    Args:
        args: namespace providing ``resume``, ``experiment_path``, ``config``,
            ``local_rank`` and ``exp_name``. ``config`` is mutated on resume.
        logger: forwarded to ``print_log``.

    Returns:
        The loaded ``EasyDict`` config.

    Raises:
        FileNotFoundError: when resuming and the saved config is missing.
    """
    if args.resume:
        cfg_path = os.path.join(args.experiment_path, 'config.yaml')
        if not os.path.exists(cfg_path):
            print_log("Failed to resume", logger=logger)
            # Include the path so the failure is diagnosable from the traceback.
            raise FileNotFoundError(f'Cannot resume: config file not found at {cfg_path}')
        print_log(f'Resume yaml from {cfg_path}', logger=logger)
        args.config = cfg_path
    config = cfg_from_yaml_file(args.config)
    if not args.resume and args.local_rank == 0 and args.exp_name:
        save_experiment_config(args, config, logger)
    return config


def save_experiment_config(args, config, logger=None):
    """Copy the source config file into ``<args.experiment_path>/config.yaml``.

    ``config`` is unused; it is kept for interface compatibility. If source and
    destination are already the same file, nothing is copied.

    Args:
        args: namespace providing ``config`` (source path) and
            ``experiment_path`` (destination directory, expected to exist).
        config: unused.
        logger: forwarded to ``print_log``.

    Raises:
        OSError: if the copy fails (e.g. missing source or destination dir).
            The previous ``os.system('cp ...')`` ignored failures silently.
    """
    config_path = os.path.join(args.experiment_path, 'config.yaml')
    try:
        # shutil avoids interpolating user-supplied paths into a shell command
        # (spaces / metacharacters) and also works on non-POSIX platforms.
        shutil.copyfile(args.config, config_path)
    except shutil.SameFileError:
        pass  # config already lives at the destination; nothing to copy
    print_log(f'Copy the Config file from {args.config} to {config_path}', logger=logger)


def set_batch_size(args, config):
    """Write per-process batch sizes into ``config.dataset.<split>.others.bs``.

    Does nothing if ``config`` has no ``dataset`` key. Otherwise the
    per-process batch size is ``config.total_bs``, divided by
    ``args.world_size`` when ``args.distributed`` is set. Each present
    (truthy) split gets that size times a fixed multiplier: 1 for
    ``train``/``test`` and 2 for ``extra_train``/``val``.

    Args:
        args: namespace providing ``distributed`` and, when distributed,
            ``world_size``.
        config: config mapping supporting attribute access (``EasyDict``);
            mutated in place.

    Raises:
        ValueError: if distributed and ``total_bs`` is not divisible by
            ``world_size``. An explicit raise is used so the check survives
            ``python -O``.
    """
    if 'dataset' not in config:
        return
    if args.distributed:
        if config.total_bs % args.world_size != 0:
            raise ValueError(
                f'total_bs ({config.total_bs}) must be divisible by '
                f'world_size ({args.world_size})'
            )
        per_process_bs = config.total_bs // args.world_size
    else:
        per_process_bs = config.total_bs
    dataset = config.dataset
    for split, multiplier in (('train', 1), ('extra_train', 2), ('val', 2), ('test', 1)):
        if dataset.get(split):
            dataset[split].others.bs = per_process_bs * multiplier
