import copy
import json
import os

# Per-dataset settings, one row per dataset key:
#   (dataset key, task_name value, max_seq_length value)
# Note that "mrp" and "mrs" both map to the same task_name "mr".
_DATASET_ROWS = (
    ("agn", "agn", 128),
    ("alphanli", "alphanli", 80),
    ("codah", "codah", 80),
    ("csqa", "csqa", 64),
    ("hellaswag", "hellaswag", 136),
    ("mrp", "mr", 80),
    ("mrs", "mr", 80),
    ("piqa", "piqa", 80),
    ("swag", "swag", 80),
)

# Each value is a fresh dict, so per-dataset entries can be extended
# independently.
DATASET_SETTINGS = {
    key: {"task_name": task_name, "max_seq_length": max_len}
    for key, task_name, max_len in _DATASET_ROWS
}

# Active-learning score methods; each is written into
# run_config.al_config.score_method of a generated config.
AL_TYPES = 'bald batchbald entropy greedy_coreset random'.split()

# Per-model settings merged into each generated run config, keyed by the
# model name that is also written as model_name_or_path.
# In every entry per_gpu_train_batch_size * gradient_accumulation_steps == 16,
# presumably so the effective batch size stays the same while the larger
# model uses a smaller per-step batch -- confirm with the training script.
MODEL_SETTINGS = {
    'roberta-large': dict(
        model_type='roberta',
        per_gpu_train_batch_size=2,
        gradient_accumulation_steps=8,
    ),
    'roberta-base': dict(
        model_type='roberta',
        per_gpu_train_batch_size=8,
        gradient_accumulation_steps=2,
    ),
    'bert-base-uncased': dict(
        model_type='bert',
        per_gpu_train_batch_size=8,
        gradient_accumulation_steps=2,
        do_lower_case=True,
    ),
}

def make_config(base_config, dataset_settings, model_settings, model_name,
                al_type, num_train_retries=None):
    """Build one run config by layering settings over a copy of base_config.

    Args:
        base_config: Template config. It is deep-copied and never mutated.
            It must contain ``config['run_config']['al_config']``, or a
            KeyError is raised when the score method is written there.
        dataset_settings: Dict merged into the top level of the config.
        model_settings: Dict merged after ``dataset_settings``, so it wins
            on key collisions.
        model_name: Stored as ``model_name_or_path``.
        al_type: Stored as ``run_config.al_config.score_method``.
        num_train_retries: If not None, stored as ``num_train_retries``.

    Returns:
        A new dict. Key insertion order is deterministic, so the dumped JSON
        is stable across runs.
    """
    config = copy.deepcopy(base_config)
    # Shallow merges: settings values are copied by reference (they are
    # scalars in this file).
    config.update(dataset_settings)
    config.update(model_settings)
    if num_train_retries is not None:
        config['num_train_retries'] = num_train_retries
    config['model_name_or_path'] = model_name
    config['run_config']['al_config']['score_method'] = al_type
    return config


def iter_configs(base_config, dataset_settings, model_settings, al_types,
                 num_train_retries=None):
    """Yield one config per (dataset, model, AL type) combination.

    The order is dataset-major, then model, then AL type. It determines the
    ``config_v{i}.json`` numbering when the results are written out.
    """
    for ds_settings in dataset_settings.values():
        for model_name, m_settings in model_settings.items():
            for al_type in al_types:
                yield make_config(base_config, ds_settings, m_settings,
                                  model_name, al_type, num_train_retries)


def resolve_datasets(dataset_settings, data_paths, names):
    """Return settings for the selected datasets with an absolute data_dir.

    Only the datasets listed in ``names`` need an entry in ``data_paths``.
    The input dicts are not mutated; each result is a new dict whose keys are
    the original settings followed by ``data_dir``.
    """
    return {
        name: dict(dataset_settings[name],
                   data_dir=os.path.abspath(data_paths[name]))
        for name in names
    }


if __name__ == '__main__':
    with open('runconfigs/base.json', encoding='utf-8') as f:
        base_config = json.load(f)

    with open('runconfigs/data_paths.json', encoding='utf-8') as f:
        data_paths = json.load(f)

    # The sweep is restricted to a subset of the full grid defined at module
    # level. Extend these name lists to generate more combinations.
    datasets = resolve_datasets(DATASET_SETTINGS, data_paths, ['mrp'])
    models = {name: MODEL_SETTINGS[name] for name in ['roberta-base']}
    al_types = ['random']

    configs = list(iter_configs(base_config, datasets, models, al_types))
    # Multi-training sweep: to also retrain each model several times, append
    # e.g.
    #   configs += iter_configs(base_config, datasets, models,
    #                           ['random', 'batchbald'], num_train_retries=5)

    for i, config in enumerate(configs, start=1):
        with open('runconfigs/config_v{}.json'.format(i), 'w',
                  encoding='utf-8') as f:
            json.dump(config, f)

    n = len(configs)
    print('num configs: {}'.format(n))
    # Only print the commands when there is something to run: bash expands
    # {1..0} to "1 0", which would reference configs that do not exist.
    if n > 0:
        print('python run_experiment.py --experiments_dir ../data/experiments/ run_al_mc.py {{1..{}}}'.format(n))
        print('python summarize_experiments.py --experiments_dir ../data/experiments/ experiment_v{{1..{}}}'.format(n))

