import argparse

def parse_training_args():
    
    parser = argparse.ArgumentParser()
    
    # your dirs
    parser.add_argument('--home_root', type=str, default='')
    parser.add_argument('--ckpt_root', type=str, default='checkpoints')
    parser.add_argument('--latent_dir', type=str, default='data/latents')
    parser.add_argument('--saveto', type=str, default='generation-results')

    # general arguments
    parser.add_argument('--esm_path', type=str, default='esm/esm2_t33_650M_UR50D.pt')
    parser.add_argument('--random_seed', type=int, default=42)

    # dataset arguments
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--recon_quality_file', type=str, default='data/recon_quality.csv')
    parser.add_argument('--latent_stats_file', type=str, default='data/latent_distribution_stats.csv')
    parser.add_argument('--crop_longer_prot', type=bool, default=True)

    # model arguments
    parser.add_argument('--pretrained_model', type=str, default='dplm/dplm_150m')
    parser.add_argument('--yaml_config', type=str, default='utils/config/construct_150m.yaml')
    parser.add_argument('--lora', type=bool, default=True)
    parser.add_argument('--from_scratch', type=bool, default=False)
    parser.add_argument('--ckpt_period', type=int, default=10)

    # training arguments
    parser.add_argument('--date', type=str, default='')
    parser.add_argument('--continuous_training', type=bool, default=False)
    parser.add_argument('--continuous_training_all', type=bool, default=False)
    parser.add_argument('--prev_ckpt_path', type=str, default='')
    parser.add_argument('--parallel', type=bool, default=False)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--batch_size', type=int, default=48)
    parser.add_argument('--total_epochs', type=int, default=120)
    parser.add_argument('--complementary_masking', type=bool, default=True)
    parser.add_argument('--self_mixup', type=bool, default=False)
    parser.add_argument('--cfg_training', type=bool, default=True)
    parser.add_argument('--seq_reweighting', type=str, default='linear')
    parser.add_argument('--struct_reweighting', type=str, default='constant')

    parser.add_argument('--seq_struct_ratio', type=float, default=0.2)
    parser.add_argument('--add_orth_term', type=bool, default=False)
    parser.add_argument('--orth_term_scale', type=float, default=0.01)

    parser.add_argument('--opt_constant_only', type=bool, default=False)
    parser.add_argument('--lr_init', type=float, default=1e-4)
    parser.add_argument('--lr_min', type=float, default=1e-5)
    parser.add_argument('--warmup_epochs', type=int, default=5)

    parser.add_argument('--eval_seq_lens', default=[100, 200, 300, 400, 500], nargs='*', type=int)
    parser.add_argument('--num_seqs', type=int, default=100)
    parser.add_argument('--seq_temp', type=float, default=1.0)
    parser.add_argument('--struct_temp', type=float, default=0.5)
    parser.add_argument('--sampling_strategy', type=str, default='vanilla')
    parser.add_argument('--unmasking_strategy', type=str, default='deterministic')
    parser.add_argument('--seq_cfg', type=str, default=1.0)
    parser.add_argument('--seq_cfg_schedule', type=str, default='constant')
    parser.add_argument('--struct_cfg', type=str, default=1.0)
    parser.add_argument('--struct_cfg_schedule', type=str, default='constant')
    parser.add_argument('--max_iter', default=[100, 200, 300, 400, 500], nargs='*', type=int)

    args = parser.parse_args([])
    
    return args
