import argparse

def _build_argparser():
    """Build the command-line parser for AE2 training and evaluation.

    Every option is registered through one of two local helpers:
    ``option`` for typed values that carry a default, and ``switch`` for
    boolean ``store_true`` flags (False unless given on the command line).
    Options are registered in the order they appear below, which is also
    the order ``--help`` lists them in.

    Returns:
        argparse.ArgumentParser: the fully populated parser (not yet parsed).
    """
    parser = argparse.ArgumentParser(description='AE2 training')

    def option(flag, kind, fallback, text):
        # Typed option: value converted with `kind`, `fallback` when omitted.
        parser.add_argument(flag, type=kind, default=fallback, help=text)

    def switch(flag, text):
        # Boolean flag: presence on the command line sets it to True.
        parser.add_argument(flag, action='store_true', help=text)

    # --- run control / bookkeeping ---
    option('--num_gpus', int, 1, 'gpus')
    option('--epochs', int, 300, 'Maximum epoch')
    option('--task', str, 'align', 'Tasks')
    option('--output_dir', str, 'debug', 'Path to results')
    option('--viz_path', str, 'ours', 'visualize path')
    switch('--fast_dev_run', 'fast dev run')
    switch('--sweep', 'sweep params')
    switch('--find_lr', 'find lr')
    switch('--freeze_base', 'whether to freeze base model')
    option('--cls_every_n_epoch', int, 25, 'downstream classification every n epochs')
    option('--save_every', int, 25, 'save every n epochs')
    switch('--eval_only', 'eval only')
    option('--train_eval_mode', str, 'val', 'training time downstream')
    option('--eval_mode', str, 'test', 'visualize train or val')
    switch('--imagenet_norm', 'whether to use imagenet_norm')
    switch('--gen_video', 'generate test video')
    switch('--extract_embedding', 'extract embeddings')
    switch('--label_all', 'use all labels (pouring)')
    switch('--label_only', 'only extract label')
    switch('--modify_data', 'for tennis (whether to modify data)')
    switch('--random_test', 'for visualization')
    switch('--vlabel', 'verb stands for one class')

    # --- optimization ---
    option('--lr', float, 1e-4, 'Learning rate')
    option('--min_lr', float, 5e-5, 'min learning rate')
    switch('--use_scheduler', 'use lr scheduler')
    option('--wd', float, 1e-5, 'Weight decay')
    option('--batch_size', int, 4, 'batch size')
    option('--num_workers', int, 0, 'number of workers')
    option('--ckpt', str, '', 'model ckpt')

    # --- TCC config: data views and frame sampling ---
    option('--dataset', str, 'break_egg', 'dataset name')
    option('--view1', str, 'ego', 'view 1')
    option('--view2', str, 'exo', 'view 2 (can be same as view 1)')
    option('--loss', str, 'tcc', 'train loss: tcc/d2tw/...')
    switch('--merge_all', 'train both ego and exo videos together')
    option('--tcc_num_frames', int, 32, 'TCC number of frames')
    option('--tcc_frame_stride', int, 15, 'TCC frame stride')
    option('--tcc_num_context_steps', int, 2, 'TCC context steps')
    option('--tcc_random_offset', int, 1, 'TCC random offset')

    # --- model architecture and loss hyper-parameters ---
    option('--base_model_name', str, 'resnet50', 'Base model name')
    option('--tcc_input_size', int, 224, 'TCC input size: 168 or 224')
    option('--hidden_dim', int, 128, 'transformer hidden dim')
    option('--n_layers', int, 3, 'transformer layer num')
    option('--embedding_size', int, 128, 'TCC output embedding size')
    option('--tcc_temp', float, 0.1, 'TCC temperature')
    option('--dtw_scale_factor', float, 0.01, 'DTW loss scale factor')
    option('--dtw_beta', float, 0.0, 'DTW contrastive beta')
    option('--dtw_ratio', float, 1.0, 'DTW contrastive loss ratio')
    option('--dtw_shuffle_num', int, 4, 'DTW contrastive loss shuffle num')
    option('--num_negatives', int, 2, 'number of negative samples')
    option('--drop_percent', float, 0.3, 'dropdtw drop percentage')
    switch('--drop_l2norm', 'dropdtw l2 normalization')

    # --- downstream evaluation / frame sampling ---
    option('--eval_task', str, '01234', 'downstream evaluation')
    option('--sample_frame_mode', str, 'prob', 'frame sample mode')

    # --- bounding-box related options ---
    option('--bbox_expansion_ratio', float, 1.0, 'Bounding box expansion ratio')
    option('--bbox_threshold', float, 0.0, 'Bounding box threshold')
    switch('--sample_by_bbox', '(tennis) sample frames by bbox')
    switch('--use_mask', 'transformer use mask')
    switch('--one_object_bbox', 'use one object not two objects')
    switch('--use_bbox_pe', 'use bounding box positional embedding')
    switch('--weigh_token_by_bbox', 'whether to weigh local tokens by bbox confidence')

    return parser


argparser = _build_argparser()
