import argparse

def _str2bool(value):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--flag False`` would silently become True. This
    converter accepts the usual spellings instead.

    Args:
        value: The raw argument string (an actual bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized spelling;
            argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def get_config(argv=None):
    """Build the argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace holding every option defined below.

    Boolean options that take a value (e.g. ``--gpu False``) are parsed with
    ``_str2bool``; options declared with ``action='store_true'`` are plain
    switches that take no value.
    """
    # parser
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # settings
    parser.add_argument('--seed', help='random seed', default=123, type=int)
    parser.add_argument('--data', help='dataset folder', default='ALFRED/alfred/data/json_feat_2.1.0')
    parser.add_argument('--processed_data', help='preprocessed dataset folder', default='ALFRED/alfred/data/processed_json_feat_2.1.0')
    parser.add_argument('--splits', help='json file containing train/dev/test splits', default='ALFRED/alfred/data/splits/oct21.json')
    parser.add_argument('--preprocess', help='store preprocessed data to json files', action='store_true')
    parser.add_argument('--pp_folder', help='folder name for preprocessed data', default='pp')
    parser.add_argument('--save_every_epoch', help='save model after every epoch (warning: consumes a lot of space)', action='store_true')
    parser.add_argument('--model', help='model to use', default='seq2seq_im_mask')
    parser.add_argument('--gpu', type=_str2bool, default=True)
    # NOTE(review): '{model}' looks like a placeholder meant to be filled by the caller — confirm.
    parser.add_argument('--dout', help='where to save model', default='exp/model:{model}')
    parser.add_argument('--use_templated_goals', help='use templated goals instead of human-annotated goal descriptions (only available for train set)', action='store_true')
    parser.add_argument('--resume', help='load a checkpoint')

    # network size
    parser.add_argument('--goal_size', help='goal size', default=512, type=int)
    parser.add_argument('--action_hidden_size', help='action hidden size', default=128, type=int)
    parser.add_argument('--action_size', help='action size', default=15, type=int)
    parser.add_argument('--hidden_size', help='hidden size', default=512, type=int)

    # hyper parameters
    parser.add_argument('--history_frame', help='', default=1, type=int)
    parser.add_argument('--batch_size', help='batch size', default=64, type=int)
    parser.add_argument('--batch_size_clip', help='batch size clip', default=64, type=int)
    parser.add_argument('--epoch', help='number of epochs', default=20, type=int)
    # NOTE(review): both --lr and --learning_rate (below) exist with the same
    # default; check which one the training code actually reads.
    parser.add_argument('--lr', help='optimizer learning rate', default=3e-4, type=float)
    parser.add_argument('--decay_epoch', help='num epoch to adjust learning rate', default=10, type=int)
    parser.add_argument('--dhid', help='hidden layer size', default=256, type=int)
    parser.add_argument('--dframe', help='image feature vec size', default=2500, type=int)
    parser.add_argument('--demb', help='language embedding size', default=100, type=int)
    parser.add_argument('--pframe', help='image pixel size (assuming square shape eg: 300x300)', default=300, type=int)
    parser.add_argument('--mask_loss_wt', help='weight of mask loss', default=1., type=float)
    parser.add_argument('--action_loss_wt', help='weight of action loss', default=1., type=float)
    parser.add_argument('--subgoal_aux_loss_wt', help='weight of subgoal completion predictor', default=0., type=float)
    parser.add_argument('--pm_aux_loss_wt', help='weight of progress monitor', default=0., type=float)

    # dropouts
    parser.add_argument('--zero_goal', help='zero out goal language', action='store_true')
    parser.add_argument('--zero_instr', help='zero out step-by-step instr language', action='store_true')
    parser.add_argument('--lang_dropout', help='dropout rate for language (goal + instr)', default=0., type=float)
    parser.add_argument('--input_dropout', help='dropout rate for concatted input feats', default=0., type=float)
    parser.add_argument('--vis_dropout', help='dropout rate for Resnet feats', default=0.3, type=float)
    parser.add_argument('--hstate_dropout', help='dropout rate for LSTM hidden states during unrolling', default=0.3, type=float)
    parser.add_argument('--attn_dropout', help='dropout rate for attention', default=0., type=float)
    parser.add_argument('--actor_dropout', help='dropout rate for actor fc', default=0., type=float)

    # other settings
    parser.add_argument('--dec_teacher_forcing', help='use teacher forcing in the decoder', action='store_true')
    parser.add_argument('--temp_no_history', help='temporary switch: run without history', action='store_true')

    # debugging
    parser.add_argument('--fast_epoch', type=_str2bool, default=False)
    parser.add_argument('--dataset_fraction', help='use fraction of the dataset for debugging (0 indicates full size)', default=0, type=int)

    # RL / Q-learning settings
    parser.add_argument("--gamma", type=float, default=0.99, help="")
    parser.add_argument("--tau", type=float, default=1e-3, help="")
    parser.add_argument("--n_atoms", type=int, default=51, help="")
    parser.add_argument("--feature_size", type=int, default=512, help="")
    parser.add_argument("--feature_extract", type=str, default='resnet', help="")
    parser.add_argument("--goal_format", type=str, default='multi-hot', help="")

    # Float default so args.alpha is a float even when not given on the CLI
    # (argparse does not apply `type` to non-string defaults).
    parser.add_argument("--alpha", type=float, default=1.0, help="the CQL hyper-parameter")
    parser.add_argument("--learning_rate", type=float, default=3e-4, help="")
    parser.add_argument("--q_learning_rate", type=float, default=3e-4, help="")
    parser.add_argument("--clip_learning_rate", type=float, default=3e-5, help="")

    parser.add_argument("--train_goal_q", type=_str2bool, default=True, help="")
    parser.add_argument("--train_goal_clip", type=_str2bool, default=True, help="")
    parser.add_argument("--train_state_q", type=_str2bool, default=True, help="")
    parser.add_argument("--train_state_clip", type=_str2bool, default=True, help="")

    parser.add_argument("--save_every", type=int, default=1, help="")
    parser.add_argument("--eval_every", type=int, default=2000, help="")
    parser.add_argument("--switch_every", type=int, default=1000, help="")
    parser.add_argument("--load_model", type=_str2bool, default=False, help="")
    parser.add_argument("--eval_data", type=int, default=-1, help="the number of data evaluated during training, -1 for all data to be evaluated")

    # == eval ==
    parser.add_argument('--eval_split', type=str, default='valid_seen', choices=['train', 'valid_seen', 'valid_unseen'])
    parser.add_argument('--shuffle', type=_str2bool, default=True)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--subgoals', type=str, help="subgoals to evaluate independently, eg:all or GotoLocation,PickupObject...", default="all")
    parser.add_argument('--reward_config', type=str, default='ALFRED/alfred/models/config/rewards.json')
    parser.add_argument('--num_eval_file', type=int, default=50, help='the number of files tested')
    parser.add_argument('--model_path', type=str, default='data/', help='path to the model')
    parser.add_argument('--test_mode', type=_str2bool, default=True, help='whether in test mode')

    # eval params
    parser.add_argument('--max_steps', type=int, default=1000, help='max steps before episode termination')
    parser.add_argument('--max_steps_taken', type=int, default=40, help='max steps for each subgoal')
    parser.add_argument('--max_fails', type=int, default=5, help='max API execution failures before episode termination')

    # eval settings
    parser.add_argument('--smooth_nav', dest='smooth_nav', action='store_true', help='smooth nav actions (might be required based on training data)')
    parser.add_argument('--skip_model_unroll_with_expert', type=_str2bool, default=True)
    parser.add_argument('--no_teacher_force_unroll_with_expert', action='store_true', help='no teacher forcing with expert')

    parser.add_argument('--debug', dest='debug', action='store_true')

    parser.add_argument('--if_clip', type=_str2bool, default=False)
    parser.add_argument('--if_regularize', type=_str2bool, default=False)

    parser.add_argument('--if_q', type=_str2bool, default=True)
    parser.add_argument('--if_actor', type=_str2bool, default=True)
    parser.add_argument('--if_actor_q', type=_str2bool, default=False)
    parser.add_argument('--generate_dataset', type=_str2bool, default=False)
    parser.add_argument('--model_type', type=str, default='C51')
    parser.add_argument('--update_frequency', type=int, default=1000)
    parser.add_argument('--LSTM', type=_str2bool, default=True)

    parser.add_argument('--device', type=str, default='cuda:1')

    # args and init
    return parser.parse_args(argv)