from argparse import ArgumentParser
import os


def parse_train_args():
    parser = ArgumentParser()
    ## Trainer settings
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--validate", action='store_true', default=False)
    parser.add_argument("--num_workers", type=int, default=4)
    
    ## Epoch settings
    group = parser.add_argument_group("Epoch settings")
    group.add_argument("--epochs", type=int, default=100)
    group.add_argument("--overfit", action='store_true')
    group.add_argument("--overfit_frame", action='store_true')
    group.add_argument("--train_batches", type=int, default=None)
    group.add_argument("--val_batches", type=int, default=None)
    group.add_argument("--val_repeat", type=int, default=1)
    group.add_argument("--inference_batches", type=int, default=0)
    group.add_argument("--batch_size", type=int, default=8)
    group.add_argument("--val_freq", type=int, default=None)
    group.add_argument("--val_epoch_freq", type=int, default=1)
    group.add_argument("--no_validate", action='store_true')
    group.add_argument("--teaching_force", action='store_true')
    group.add_argument("--load_from_ckpt", action='store_true')
    group.add_argument('--pretrained_ckpt', type=str, default=None)
    group.add_argument("--auxiliary", action='store_true')

    ## Logging args
    group = parser.add_argument_group("Logging settings")
    group.add_argument("--print_freq", type=int, default=100)
    group.add_argument("--ckpt_freq", type=int, default=1)
    group.add_argument("--wandb", action="store_true")
    group.add_argument("--run_name", type=str, default="default")
    

    ## Optimization settings
    group = parser.add_argument_group("Optimization settings")
    group.add_argument("--accumulate_grad", type=int, default=1)
    group.add_argument("--grad_clip", type=float, default=1.)
    group.add_argument("--check_grad", action='store_true')
    group.add_argument('--grad_checkpointing', action='store_true')
    group.add_argument('--adamW', action='store_true')
    group.add_argument("--lr", type=float, default=1e-4)
    group.add_argument('--precision', type=str, default='32-true')
    group.add_argument("--threshold", type=int, default=2)
    
    
    ## Training data 
    group = parser.add_argument_group("Training data settings")
    group.add_argument('--train_split', type=str, default=None, required=True)
    group.add_argument('--val_split', type=str, default=None, required=True)
    group.add_argument('--data_dir', type=str, default=None, required=True)
    group.add_argument('--num_frames', type=int, default=50)
    group.add_argument('--crop', type=int, default=256)
    group.add_argument('--atlas', action='store_true')
    group.add_argument('--copy_frames', action='store_true')
    group.add_argument('--no_pad', action='store_true')
    group.add_argument('--short_md', action='store_true')
    group.add_argument('--sample', action='store_true')
    group.add_argument('--mode', choices=['gradient', 'data'], default='data')  
    group.add_argument('--sample_ratio', type=float, default=0.5)  
    group.add_argument('--suffix', type=str, default='')

    ### Masking settings
    group = parser.add_argument_group("Masking settings")
    group.add_argument('--no_aa_emb', action='store_true')
    group.add_argument("--no_torsion", action='store_true')
    group.add_argument("--no_design_torsion", action='store_true')
    group.add_argument("--supervise_no_torsions", action='store_true')
    group.add_argument("--supervise_all_torsions", action='store_true')

    ## Ablations settings
    group = parser.add_argument_group("Ablations settings")
    group.add_argument('--no_offsets', action='store_true')
    group.add_argument('--no_frames', action='store_true')
    
    ## Model settings
    group = parser.add_argument_group("Model settings")
    group.add_argument('--hyena', action='store_true')
    group.add_argument('--no_rope', action='store_true')
    group.add_argument('--dropout', type=float, default=0.0)
    group.add_argument('--scale_factor', type=float, default=1.0)
    group.add_argument('--interleave_ipa', action='store_true')
    group.add_argument('--prepend_ipa', action='store_true')
    group.add_argument('--oracle', action='store_true')
    group.add_argument('--num_layers', type=int, default=5)
    group.add_argument('--embed_dim', type=int, default=384)
    group.add_argument('--mha_heads', type=int, default=16)
    group.add_argument('--ipa_heads', type=int, default=4)
    group.add_argument('--ipa_head_dim', type=int, default=32)
    group.add_argument('--ipa_qk', type=int, default=8)
    group.add_argument('--ipa_v', type=int, default=8)
    group.add_argument('--add_auxiliary_loss', action='store_true')
    group.add_argument('--time_multiplier', type=float, default=100.)
    group.add_argument('--abs_pos_emb', action='store_true')
    group.add_argument('--abs_time_emb', action='store_true')
    group.add_argument('--latent_dim', type=int, default=13)
    group.add_argument('--time_model', choices=['rnn', 'xlstm', 'fno', 'transformer'])
    group.add_argument('--seq_emb', action='store_true')

    ## video settings
    group = parser.add_argument_group("Video settings")
    group.add_argument('--sim_condition', action='store_true')
    group.add_argument('--frame_interval', type=int, default=None)
    group.add_argument('--cond_interval', type=int, default=None) # for superresolution
    
    parser.add_argument('--seed', type=int, default=137,
                        help='Random seed for reproducibility')
    
    parser.add_argument('--local_rank', type=int, default=-1, metavar='N',
                        help='Local process rank.')

    args = parser.parse_args()
    os.environ["MODEL_DIR"] = os.path.join("workdir", args.run_name)
    
    return args


