import os
import shlex

def train_roberta_ks(
        gpu, exp_name, batch_size=4, eval_batch_size=1,  accum_steps=32, pretrain_file=None, pretrain_steps=1000, label_smooth=False,
        lr=1e-5, warmup_steps=5000, num_epochs=3, num_steps=99999999999, valid_every=500, num_neg=3, schedule='linear',
        gpu_ratio=0.85, n_device=8, home_path=""
):
    """Run ``train_ks_roberta.py`` in a shell, teeing stdout to a log file.

    A ``--name=value`` flag is emitted for every entry of the parameter
    table below, in alphabetical order of the flag name.  The command
    (without the ``tee`` part) is printed before it is executed, and the
    script's standard output is piped to ``reddit/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the tee log file.
        batch_size, eval_batch_size, accum_steps, label_smooth, lr,
        warmup_steps, num_epochs, num_steps, valid_every, schedule,
        gpu_ratio, n_device: Forwarded unchanged under the same flag name.
        pretrain_file, pretrain_steps: Accepted for call compatibility but
            NOT forwarded to the script (they never enter the flag table).
        num_neg: Selects the training file
            ``debug_data/reddit_train_ks_<num_neg>neg.hdf5``.
        home_path: Prefix of the ``--bert_config`` path
            (``<home_path>/Data/pretrain-models/roberta-base``).

    Returns:
        None.  The exit status reported by ``os.system`` is not inspected.
    """
    params = {
        # Data files.
        'train_file': 'debug_data/reddit_train_ks_{}neg.hdf5'.format(num_neg),
        'test_seen_file': 'debug_data/wizard_test_seen_ks.hdf5',
        'test_unseen_file': 'debug_data/wizard_test_unseen_ks.hdf5',
        'test_cmudog_file': 'debug_data/cmudog_test_ks.hdf5',

        # Optimisation settings.
        'batch_size': batch_size,
        'eval_batch_size': eval_batch_size,
        'num_steps': num_steps,
        'accum_steps': accum_steps,
        'lr': lr,
        'clip': 1.0,
        'schedule': schedule,
        'label_smooth': label_smooth,

        'weight_decay': 0.01,
        'warmup_steps': warmup_steps,
        'adam_epsilon': 1e-8,
        'num_epochs': num_epochs,

        # Reporting cadence.
        'print_every': 10,
        'valid_every': valid_every,

        'exp_name': exp_name,
        'log': 'reddit/log',
        'seed': 42,

        # Model location and input sizes.
        'bert_config': '{}/Data/pretrain-models/roberta-base'.format(home_path),
        'block_size': 128,
        'knowledge_trunc': 64,
        'text_trunc': 64,

        # Device settings.
        'gpu_list': gpu,
        'gpu_ratio': gpu_ratio,
        'n_device': n_device,
    }
    # Keys are unique, so sorting the items orders purely by flag name.
    # Each value is shell-quoted so that paths or names containing spaces or
    # shell metacharacters reach the script as a single, intact argument;
    # values made only of "safe" characters are emitted unchanged.
    flags = "".join(
        " --{}={}".format(name, shlex.quote(str(value)))
        for name, value in sorted(params.items())
    )
    command = "python -u train_ks_roberta.py" + flags
    print(command)
    log_target = shlex.quote('reddit/{}.txt'.format(exp_name))
    os.system(command + ' | tee ' + log_target)





def train_bart_full(
        gpu, exp_name, batch_size=4, eval_batch_size=1, percentage=0.125, accum_steps=32, schedule='linear', opt='adamw', bleu_percent=1.0,
        max_length=40, min_length=0, num_beams=4, repetition_penalty=1.2, no_repeat_ngram_size=3, seed=42, drop_know=0.0, infilling=False, mlm=0.0, random_know=0.0, test_knowledge_num=999,
        model_size='base', lr=1e-5, lr2=1e-5, warmup_steps=5000, num_epochs=3, num_steps=99999999999, valid_every=500, use_sep=False, mix_know=0.0, test_knowledge_truncate=64, knowledge_truncate=64,
        min_segment_len=6, split_threshold=0.2, know_threshold=0.4, copy_threshold=1.0, stop_words_size=200, merge=False, reverse=False, add_prefix_space=True,
        max_source_len=210, max_target_len=46, full_knowledge_attn=False, gpu_ratio=0.85, n_device=8, home_path=""
):
    """Run ``train_full.py`` in a shell, teeing stdout to a log file.

    A ``--name=value`` flag is emitted for every entry of the parameter
    table below, in alphabetical order of the flag name.  The command
    (without the ``tee`` part) is printed before it is executed, and the
    script's standard output is piped to ``reddit/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the tee log file.
        knowledge_truncate: Forwarded as ``--knowledge_trunc``.
        model_size, home_path: Combined into ``--bart_config``
            (``<home_path>/Data/pretrain-models/facebook/bart-<model_size>``).
        All remaining keyword arguments are forwarded unchanged under a
        flag of the same name.

    Returns:
        None.  The exit status reported by ``os.system`` is not inspected.
    """
    params = {
        # Data files.
        'train_file': 'debug_data/reddit_train.hdf5',
        'test_seen_file': 'debug_data/wizard_valid_seen.hdf5',

        # Optimisation settings.
        'batch_size': batch_size,
        'eval_batch_size': eval_batch_size,
        'num_steps': num_steps,
        'accum_steps': accum_steps,
        'lr': lr,
        'lr2': lr2,
        'clip': 1.0,

        'weight_decay': 0.01,
        'warmup_steps': warmup_steps,
        'adam_epsilon': 1e-8,
        'num_epochs': num_epochs,
        'schedule': schedule,
        'opt': opt,

        # Reporting cadence.
        'print_every': 10,
        'valid_every': valid_every,

        'exp_name': exp_name,
        'log': 'reddit/log',
        'seed': seed,

        'bart_config': '{}/Data/pretrain-models/facebook/bart-{}'.format(home_path, model_size),

        # Data-processing options passed straight through to the script.
        'min_segment_len': min_segment_len,
        'split_threshold': split_threshold,
        'know_threshold': know_threshold,
        'copy_threshold': copy_threshold,
        'stop_words_path': 'debug_data/stop_words.txt',
        'stop_words_size': stop_words_size,
        'merge': merge,
        'percentage': percentage,
        'drop_know': drop_know,
        'infilling': infilling,
        'mlm': mlm,
        'random_know': random_know,
        'bleu_percent': bleu_percent,
        'reverse': reverse,
        'use_sep': use_sep,
        'mix_know': mix_know,
        'add_prefix_space': add_prefix_space,
        'test_knowledge_truncate': test_knowledge_truncate,
        'test_knowledge_num': test_knowledge_num,

        # Sequence-length limits.
        'max_source_len': max_source_len,
        'max_target_len': max_target_len,
        'knowledge_trunc': knowledge_truncate,
        'text_trunc': 128,
        'full_knowledge_attn': full_knowledge_attn,

        # Generation settings.
        'max_length': max_length,
        'min_length': min_length,
        'num_beams': num_beams,
        'repetition_penalty': repetition_penalty,
        'no_repeat_ngram_size': no_repeat_ngram_size,

        # Device settings.
        'gpu_list': gpu,
        'gpu_ratio': gpu_ratio,
        'n_device': n_device,
    }
    # Keys are unique, so sorting the items orders purely by flag name.
    # Each value is shell-quoted so that paths or names containing spaces or
    # shell metacharacters reach the script as a single, intact argument;
    # values made only of "safe" characters are emitted unchanged.
    flags = "".join(
        " --{}={}".format(name, shlex.quote(str(value)))
        for name, value in sorted(params.items())
    )
    command = "python -u train_full.py" + flags
    print(command)
    log_target = shlex.quote('reddit/{}.txt'.format(exp_name))
    os.system(command + ' | tee ' + log_target)

def test_bart_full(
        gpu, exp_name, eval_batch_size=1, percentage=0.125,  scores_file='',
        max_length=40, min_length=0, num_beams=4, repetition_penalty=1.2, no_repeat_ngram_size=3, seed=42, drop_know=0.0, infilling=False, mlm=0.0,
        model_size='base', pretrain_file='1212_latent_base_1', knowledge_truncate=64, text_truncate=128,
        min_segment_len=6, split_threshold=0.2, know_threshold=0.4, copy_threshold=1.0, stop_words_size=200, merge=False,
        max_source_len=210, max_target_len=46, full_knowledge_attn=False, gpu_ratio=0.85, n_device=8, home_path=""
):
    """Run ``test_full.py`` in a shell, teeing stdout to a log file.

    A ``--name=value`` flag is emitted for every entry of the parameter
    table below, in alphabetical order of the flag name.  The command
    (without the ``tee`` part) is printed before it is executed, and the
    script's standard output is piped to ``reddit/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the tee log file.
        scores_file: Suffix inserted into the three test-file names
            (``debug_data/wizard_test_seen_<scores_file>.hdf5`` etc.).
        knowledge_truncate: Forwarded as ``--knowledge_trunc``.
        text_truncate: Forwarded as ``--text_trunc``.
        model_size, home_path: Combined into ``--bart_config``
            (``<home_path>/Data/pretrain-models/facebook/bart-<model_size>``).
        All remaining keyword arguments are forwarded unchanged under a
        flag of the same name.

    Returns:
        None.  The exit status reported by ``os.system`` is not inspected.
    """
    params = {
        # Evaluation data files, all sharing the ``scores_file`` suffix.
        'test_seen_file': 'debug_data/wizard_test_seen_{}.hdf5'.format(scores_file),
        'test_unseen_file': 'debug_data/wizard_test_unseen_{}.hdf5'.format(scores_file),
        'cmudog_file': 'debug_data/cmudog_test_{}.hdf5'.format(scores_file),

        'eval_batch_size': eval_batch_size,

        'exp_name': exp_name,
        'log': 'reddit/log',
        'seed': seed,

        'bart_config': '{}/Data/pretrain-models/facebook/bart-{}'.format(home_path, model_size),
        'pretrain_file': pretrain_file,

        # Data-processing options passed straight through to the script.
        'min_segment_len': min_segment_len,
        'split_threshold': split_threshold,
        'know_threshold': know_threshold,
        'copy_threshold': copy_threshold,
        'stop_words_path': 'debug_data/stop_words.txt',
        'stop_words_size': stop_words_size,
        'merge': merge,
        'percentage': percentage,
        'drop_know': drop_know,
        'infilling': infilling,
        'mlm': mlm,

        # Sequence-length limits.
        'max_source_len': max_source_len,
        'max_target_len': max_target_len,
        'knowledge_trunc': knowledge_truncate,
        'text_trunc': text_truncate,
        'full_knowledge_attn': full_knowledge_attn,

        # Generation settings.
        'max_length': max_length,
        'min_length': min_length,
        'num_beams': num_beams,
        'repetition_penalty': repetition_penalty,
        'no_repeat_ngram_size': no_repeat_ngram_size,

        # Device settings.
        'gpu_list': gpu,
        'gpu_ratio': gpu_ratio,
        'n_device': n_device,
    }
    # Keys are unique, so sorting the items orders purely by flag name.
    # Each value is shell-quoted so that paths or names containing spaces or
    # shell metacharacters reach the script as a single, intact argument;
    # values made only of "safe" characters are emitted unchanged.
    flags = "".join(
        " --{}={}".format(name, shlex.quote(str(value)))
        for name, value in sorted(params.items())
    )
    command = "python -u test_full.py" + flags
    print(command)
    log_target = shlex.quote('reddit/{}.txt'.format(exp_name))
    os.system(command + ' | tee ' + log_target)


def train_sst(
        gpu, exp_name, batch_size=4, eval_batch_size=1,  accum_steps=32, block_size=32,
        lr=1e-5, warmup_steps=5000, num_epochs=3, num_steps=99999999999, valid_every=500, schedule='linear',
        gpu_ratio=0.85, n_device=8, home_path=""
):
    """Run ``train_sst2.py`` in a shell, teeing stdout to a log file.

    A ``--name=value`` flag is emitted for every entry of the parameter
    table below, in alphabetical order of the flag name.  The command
    (without the ``tee`` part) is printed before it is executed, and the
    script's standard output is piped to ``sst2/<exp_name>.txt``.

    Args:
        gpu: Forwarded as ``--gpu_list``.
        exp_name: Forwarded as ``--exp_name``; also names the tee log file.
        batch_size, eval_batch_size, accum_steps, block_size, lr,
        warmup_steps, num_epochs, num_steps, valid_every, schedule,
        gpu_ratio, n_device: Forwarded unchanged under the same flag name.
        home_path: Prefix of the ``--bert_config`` path
            (``<home_path>/Data/pretrain-models/roberta-base``).

    Returns:
        None.  The exit status reported by ``os.system`` is not inspected.
    """
    params = {
        # Optimisation settings.
        'batch_size': batch_size,
        'eval_batch_size': eval_batch_size,
        'num_steps': num_steps,
        'accum_steps': accum_steps,
        'lr': lr,
        'clip': 1.0,
        'schedule': schedule,

        'weight_decay': 0.01,
        'warmup_steps': warmup_steps,
        'adam_epsilon': 1e-8,
        'num_epochs': num_epochs,

        # Reporting cadence.
        'print_every': 10,
        'valid_every': valid_every,

        'exp_name': exp_name,
        'log': 'sst2/log',
        'seed': 42,

        # Model location and input size.
        'bert_config': '{}/Data/pretrain-models/roberta-base'.format(home_path),
        'block_size': block_size,

        # Device settings.
        'gpu_list': gpu,
        'gpu_ratio': gpu_ratio,
        'n_device': n_device,
    }
    # Keys are unique, so sorting the items orders purely by flag name.
    # Each value is shell-quoted so that paths or names containing spaces or
    # shell metacharacters reach the script as a single, intact argument;
    # values made only of "safe" characters are emitted unchanged.
    flags = "".join(
        " --{}={}".format(name, shlex.quote(str(value)))
        for name, value in sorted(params.items())
    )
    command = "python -u train_sst2.py" + flags
    print(command)
    log_target = shlex.quote('sst2/{}.txt'.format(exp_name))
    os.system(command + ' | tee ' + log_target)




if __name__ == '__main__':
    # The three runs below execute one after another; each call blocks until
    # its shell command has finished.

    # Run 1: train_ks_roberta.py on the reddit_train_ks_7neg training file.
    train_roberta_ks(
        '0', 'reddit_ks',
        num_neg=7,
        batch_size=32, eval_batch_size=32, accum_steps=2,
        lr=1e-5, warmup_steps=200,
        num_epochs=1, num_steps=36000000, valid_every=300,
        gpu_ratio=0.85, n_device=8, home_path="",
    )

    # Run 2: train_sst2.py with block_size 32.
    train_sst(
        '0', 'sst_block32',
        block_size=32, schedule='linear',
        batch_size=128, eval_batch_size=128, accum_steps=1,
        lr=1e-5, warmup_steps=100,
        num_epochs=10, num_steps=99999999999, valid_every=300,
        gpu_ratio=0.85, n_device=8, home_path="",
    )

    # Run 3: train_full.py with the bart-base configuration.
    train_bart_full(
        '0', 'latent_base',
        test_knowledge_truncate=40, schedule='linear',
        batch_size=32, eval_batch_size=16, percentage=0.5, accum_steps=2,
        max_length=46, num_beams=1, repetition_penalty=2.0, no_repeat_ngram_size=0,
        model_size='base', lr=5e-6, lr2=1e-4, warmup_steps=100,
        num_epochs=3, num_steps=1000000, valid_every=300,
        min_segment_len=12, split_threshold=0.9, know_threshold=0.5, copy_threshold=2.0,
        stop_words_size=50, merge=False, full_knowledge_attn=False,
        max_source_len=210, max_target_len=46,
        gpu_ratio=0.85, n_device=8, home_path="",
    )
