import os
import json
from configs.utils import populate_defaults

from utils import get_expt_name, get_model_aug_str


# Weight sources that force an ImageNet ResNet-50 backbone regardless of --arch.
_RESNET50_WEIGHT_MODELS = ('clip', 'simsiam-in', 'simclr-in', 'barlow-in', 'erm-in')


def _resolve_arg_overrides(args):
    """Apply architecture / algorithm / eval-mode overrides to ``args`` in place."""
    weights_model = getattr(args, 'weights_model', None)

    if weights_model in _RESNET50_WEIGHT_MODELS:
        args.arch = 'resnet50'
    elif args.dataset == 'spur_cifar10':
        if args.arch is None:
            args.arch = 'resnet18'
        # CIFAR runs use the 'cifar_'-prefixed variant of the architecture.
        if 'cifar_' not in args.arch:
            args.arch = 'cifar_' + args.arch

    if args.arch is not None:
        args.model = args.arch

    if weights_model == 'simclr-in' and not args.adapt:
        args.algorithm = 'simclr'

    if getattr(args, 'eval_fullgrid', False):
        args.eval_only = True


def _normalize_prune_and_reinit(config, args):
    """Canonicalise the "no pruning" / "no reinit" settings on ``config``."""
    # Either a zero percentage or a '-' method means "no pruning"; store both consistently.
    if config.prune_perc == 0 or config.prune_method == '-':
        config.prune_method = '-'
        config.prune_perc = 0

    no_prune_or_reinit = config.reinit_method == '-' and config.prune_method == '-'
    if no_prune_or_reinit and hasattr(args, 'head_weight_decay') and config.head_weight_decay is None:
        config.model_aug = False

    if config.reinit_method == '-':
        config.reinit_layer_th = -1


def _setup_log_dirs(config, args, seed_str):
    """Derive and create the log / checkpoint / eval directories for this run.

    Sets ``config.log_dir`` (rewritten to the run-specific directory),
    ``config.log_ckpts_dir``, ``config.ckpts_dir`` and ``config.eval_dir``, and
    ``config.hparam_str`` when hyperparameters are encoded in the path.
    """
    hparam_str = f'lr:{args.lr}_bs:{args.batch_size}_wd:{args.weight_decay}/'
    if config.no_hparam or config.model_aug or config.weights_model != '':
        hparam_str = ''
    else:
        config.hparam_str = hparam_str
    dir_str = f'{config.name}/{hparam_str}'
    print('dir str: ', dir_str)

    os.makedirs(f'{config.log_dir}/{config.name}', exist_ok=True)
    config.log_dir = os.path.join(config.log_dir, dir_str)
    if not os.path.exists(config.log_dir):
        if config.eval_only:
            # Evaluating a run that was never logged here; warn but still create it.
            print(f'log directory does not exist: {config.log_dir}')
        os.makedirs(config.log_dir, exist_ok=True)
    print(config.log_dir)

    # NOTE(review): this reads args.use_pretrained while get_cfg's pretrained_dir
    # logic reads config.use_pretrained; confirm the two cannot diverge.
    ft_str = '_ft' if args.use_pretrained else ''
    config.log_ckpts_dir = os.path.join(config.log_dir, f'ckpts{seed_str}{ft_str}/')
    os.makedirs(config.log_ckpts_dir, exist_ok=True)

    config.ckpts_dir = config.output_dir
    os.makedirs(config.ckpts_dir, exist_ok=True)

    config.eval_dir = os.path.join(config.log_dir, 'eval')
    os.makedirs(config.eval_dir, exist_ok=True)


def _resolve_model_weights(config):
    """Translate ``config.weights_model`` into ``config.model_kwargs['weights']``.

    Raises:
        KeyError: if ``weights_model == 'erm-in'`` and ``config.model`` has no
            registered torchvision weights.
    """
    config.model_kwargs.pop('pretrained', None)

    if config.weights_model == 'erm-in':
        # Validate before importing torchvision so the error surfaces immediately.
        if config.model not in ('resnet50', 'resnet18'):
            raise KeyError(f'no weights for model {config.model}')
        from torchvision.models import ResNet18_Weights, ResNet50_Weights
        weights_by_model = {
            'resnet50': ResNet50_Weights.DEFAULT,
            'resnet18': ResNet18_Weights.DEFAULT,
        }
        config.model_kwargs['weights'] = weights_by_model[config.model]
    elif config.weights_model == '':
        config.model_kwargs['weights'] = None


def _resolve_eval_sets(config, args):
    """Apply dataset-specific train/eval split rules and prune ``extra_vals``."""
    if config.dataset == 'bgchallenge':
        if config.train_set not in ['train', 'val', 'mixed_rand', 'only_fg']:
            config.train_set = 'train'  # change default
        assert config.eval_set in ['test', 'val', 'mixed_rand', 'only_fg', 'no_fg'], \
            f'unsupported eval_set for bgchallenge: {config.eval_set}'
        assert config.train_set != config.eval_set, \
            'train_set and eval_set must differ for bgchallenge'

        config.eval_mi = False
        config.eval_spur = False
        config.eval_grouping = False

    if config.dataset == 'cifar10':
        config.eval_set = 'test'

    # Do not re-evaluate the train/eval splits as "extra" validation sets.
    if hasattr(args, 'extra_vals'):
        for split in (config.train_set, config.eval_set):
            if split in config.extra_vals:
                config.extra_vals.remove(split)

    # Extra validation sets are only used in eval-only runs.
    if not config.eval_only:
        config.extra_vals = []


def get_cfg(args, is_train=True):
    """Build the run configuration from parsed command-line arguments.

    Mutates ``args`` in place (arch/model/algorithm/eval-only overrides) and
    prints every argument. It then expands ``args`` into a config via
    ``populate_defaults`` and derives the experiment name. It creates the
    logging, checkpoint and eval directories on disk. For non-eval runs it also
    writes the config to ``<log_ckpts_dir>/config.json``.

    Args:
        args: argparse-style namespace; mutated in place.
        is_train: forwarded to ``get_expt_name``.

    Returns:
        The tuple ``(config, args)``.

    Raises:
        KeyError: if ``weights_model == 'erm-in'`` and the model is neither
            ``resnet50`` nor ``resnet18``.
        AssertionError: on an invalid train/eval split for ``bgchallenge``.
    """
    _resolve_arg_overrides(args)

    for name, val in vars(args).items():
        print(f'{name.replace("_"," ").capitalize()}: {val}')
    config = populate_defaults(args)

    config.name = get_expt_name(config, with_rep=False, is_train=is_train)
    _normalize_prune_and_reinit(config, args)

    seed_str = f'_S{args.seed}' if args.seed else ''
    if config.model_aug or config.use_pretrained:
        # The un-augmented run's last checkpoint is the starting point.
        no_aug_log_dir = os.path.join(config.log_dir, config.name)
        config.pretrained_dir = os.path.join(no_aug_log_dir, f'ckpts{seed_str}/ckpt_last.pth.tar')
        config.name += get_model_aug_str(config)
    print(config.name)

    _setup_log_dirs(config, args, seed_str)

    if hasattr(args, 'head_weight_decay') and config.head_weight_decay is not None:
        config.model_aug = False

    if not config.eval_only:
        config_path = f'{config.log_ckpts_dir}/config.json'
        with open(config_path, 'w', encoding='utf-8') as fp:
            json.dump(vars(config), fp, indent=4)
        print('saving json file to ', config_path)

    _resolve_model_weights(config)

    if hasattr(args, 'eval_ckpt') and config.eval_ckpt == '':
        config.eval_ckpt = 'last' if config.weights_model == '' else 'init'

    _resolve_eval_sets(config, args)

    return config, args
