import argparse
 
# Module-level command-line parser for training / evaluation runs.
# Only the parser is built here; callers are expected to call parse_args() on it.
parser = argparse.ArgumentParser(description='Training settings')

# training args
parser.add_argument('--device', default=0, metavar='torch device', type=int,
                    help='index of the torch device to use')
parser.add_argument('--epochs', default=200, type=int, help='number of training epochs')
parser.add_argument('--batch_size', default=256, type=int, help='training batch size')
parser.add_argument('--loader_workers', default=4, type=int, help='num parallel data loader processes')
parser.add_argument('--samples', default=-1, type=int, help='number of train/val/test samples to use, -1 means all')
parser.add_argument('--train_samples', default=-1, type=int, help='number of train samples to use, -1 means all')
parser.add_argument('--val_interval', default=5, type=int, help='how many train epochs between validation')
parser.add_argument('--proj_interval', default=0, type=int, help='how many batches between projecting the mapping weights to [0,1]')
parser.add_argument('--seed', default=42, type=int, help='random seed')
parser.add_argument('--tqdm', action='store_true', help='use tqdm to show epoch batch progress')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--l1_loss', default=0.1, type=float, help='weight for l1 loss on reason network')
parser.add_argument('--cons_loss', default=0.0, type=float, help='weight for consistency loss on relation predictions')

# model params
parser.add_argument('--model_type', default='neuraltlp', choices=['neuraltlp', 'lstm', 'map', 'mode'],
                    help='model architecture to train')
parser.add_argument('--time_dim_quant', default=301, type=int)
# NOTE: store_false inverts the stored value. args.no_time_mult is True by
# default and becomes False when --no_time_mult is passed. The dest name is
# kept unchanged because callers read args.no_time_mult.
parser.add_argument('--no_time_mult', action='store_false',
                    help='pass to turn off time_mult (sets args.no_time_mult=False; default True)')
parser.add_argument('--agg_mode', default='agg_after', type=str)
parser.add_argument('--agg_type', default='sum', type=str)
parser.add_argument('--conv_fill', default=1.0, type=float)
parser.add_argument('--append_thresh', default=0.01, type=float)

# lstm params
parser.add_argument('--hidden_size', default=512, type=int, help='LSTM hidden size')
parser.add_argument('--bidirectional', action='store_true', help='use a bidirectional LSTM')
parser.add_argument('--num_layers', default=1, type=int, help='number of LSTM layers')
parser.add_argument('--attention_dim', default=32, type=int, help='attention dimension')

# load/save args
parser.add_argument('--exp_name', type=str, help='name of the experiment, used as the model file name', default='model')
parser.add_argument('--load_model', action='store_true', help='load model exp_name.pth')
parser.add_argument('--model_path', type=str, default='/localscratch/sysuser/trained_models/cater/',
                    help='directory for saved/loaded model files')
parser.add_argument('--cater_path', type=str, default='/localscratch/cater_dataset/max2action',
                    help='path to the CATER dataset')
parser.add_argument('--freeze_rela', action='store_true', help='freeze convolution and relation params')
parser.add_argument('--freeze_mapping', action='store_true', help='use the GT mapping and freeze')
parser.add_argument('--gt', action='store_true', help='use the ground truth atomic events')

# ablation args
parser.add_argument('--two_pred', action='store_true', help='evaluate on rules of length two predicates')
parser.add_argument('--test_gen', action='store_true', help='evaluate on generated data, assume data is already generated')
parser.add_argument('--var_len', action='store_true', help='predict variable length rules')
parser.add_argument('--var_len_fixed', action='store_true', help='in the variable length model only use a fixed max rule len, for testing')
parser.add_argument('--var_skip_proj_training', action='store_true', help='assume projection weights are already trained and loaded')
parser.add_argument('--gen_path', type=str, default='/localscratch/cater_dataset/generated_data/',
                    help='directory of the generated data')
parser.add_argument('--gen_len', type=int, default=1, help='max rule length of generated data')
parser.add_argument('--gen_len_beam', type=int, default=10, help='beam length when generating combinatorial rule possibilities')
parser.add_argument('--gen_len_batch_size', type=int, default=64,
                    help='batch size used instead of --batch_size, since the variable-length rule module takes more memory')
parser.add_argument('--gen_epochs', type=int, default=10,
                    help='epoch count used instead of --epochs; usually needs fewer epochs to converge')
parser.add_argument('--gen_rules_beam', type=int, default=100, help='beam length when generating the original rules')
parser.add_argument('--gen_events', type=int, default=3, help='max composite co-occurring events per time series')
parser.add_argument('--gen_samples', type=int, default=10000, help='number of generated training samples')
