import argparse
import ast


# Shared command-line parser; every section below registers its options on it.
parser = argparse.ArgumentParser(description="hyper-parameter for GVLL")

# ====================== Dataset Config ===========================
parser.add_argument('--dataset', metavar='DATASET', default='./data/msrvtt/annotation/meta_msrvtt_v2.json', help='training datasets to use')
parser.add_argument('--base_dir', default='./data/msrvtt', type=str, help='Dataset directory containing image folders.')
parser.add_argument('--savedmodel_path', default='./runs/debug', type=str, help='Output directory for saved models.')
parser.add_argument('--ckpt_file', default=None, type=str, help='Path to a checkpoint file to load (default: None).')
# NOTE(review): this default is a machine-specific absolute path; on any other
# machine pass --delta_file explicitly.
parser.add_argument('--delta_file', default='/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/gvll_A100/runs/v1_s2/checkpoints/checkpoint_epoch10_step2387_val_loss4.483766.pth', type=str, help='Path to a delta checkpoint (.pth) file.')
parser.add_argument('--text_embed', default='clip_embeds_ali', type=str, help='Name of the text embedding to use.')
# Any value other than a case-insensitive "true" parses as False.
parser.add_argument('--use_embed', default=False, type=lambda x: (str(x).lower() == 'true'), help='load video embedding or video')

# ====================== Model Config ===========================
# --llm_model is a local checkpoint directory; the help text suggests a hub id
# such as meta-llama/Llama-2-7b-chat-hf may also work (confirm at load site).
parser.add_argument('--llm_model', default='./Checkpoints/Llama-2-7b-chat-hf', help='LLM to use, meta-llama/Llama-2-7b-chat-hf')
parser.add_argument('--visual_model', default='./Checkpoints/xclip-large-patch14', type=str, help="Visual encoder to use. microsoft/xclip-base-patch32")
# Boolean options: any value other than a case-insensitive "true" parses as False.
parser.add_argument('--freeze_vm', default=True, type=lambda x: (str(x).lower() == 'true'), help="freeze visual model or not")
parser.add_argument('--freeze_llm', default=True, type=lambda x: (str(x).lower() == 'true'), help="freeze llm model or not")
parser.add_argument('--llm_use_lora', default=True, type=lambda x: (str(x).lower() == 'true'), help="apply LoRA to the llm or not")
parser.add_argument('--lora_inference', default=True, type=lambda x: (str(x).lower() == 'true'), help="use LoRA in inference mode or not")
parser.add_argument('--llm_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
parser.add_argument('--llm_alpha', default=16, type=int, help='LoRA scaling factor (alpha).')
parser.add_argument('--lora_dropout', default=0.1, type=float, help='lora dropout')
parser.add_argument('--proj_type', default='linear', type=str, help="the way of projecting visual features to llm")

# ====================== Training Config ===========================
parser.add_argument('--batch_size', default=4, type=int, help='mini-batch size for training')
parser.add_argument('--val_batch_size', default=24, type=int, help='mini-batch size for validation')
parser.add_argument('--num_workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--prefetch_factor', default=2, type=int, metavar='N', help='Number of batches loaded in advance by each worker')
parser.add_argument('--learning_rate', default=1e-4, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--lr_warmup_steps', default=2000, type=int, metavar='N', help='Number of steps to warm up lr.')
parser.add_argument('--lr_schedule_step_size', default=5, type=int, metavar='N', help='Number of steps before decaying lr.')
parser.add_argument('--lr_schedule_gamma', default=0.1, type=float, metavar='N', help='Decay parameter for learning rate scheduler.')
parser.add_argument('--grad_accumulation_steps', default=1, type=int, metavar='N', help='number of gradient accumulation steps')
parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping amount')

parser.add_argument('--num_frames', default=8, type=int, metavar='N', help='Number of frames to use.')
parser.add_argument('--clip_len', default=8, type=int, metavar='N', help='Number of video frames.')
parser.add_argument('--projection_dim', default=768, type=int, metavar='N', help='Dimension of the projection.')
# `type=list` would split the raw CLI string into characters ("-1" -> ['-', '1']).
# Collect one or more integers instead; the default is still the list [-1].
# Negative values like -1 are accepted because no option string looks like a
# negative number.
parser.add_argument('--text_emb_layers', default=[-1], type=int, nargs='+', help='Layer indices to use for text embeddings, e.g. --text_emb_layers -1 -2.')

# ====================== Loss & Prompt Config ===========================
# Loss weights, caption concatenation, prompt text and input image size.
parser.add_argument(
    "--cap_loss_scale",
    default=1.0,
    type=float,
    help="Scale on captioning loss.",
)
parser.add_argument(
    "--gen_loss_scale",
    default=1.0,
    type=float,
    help="Scale on retrieval loss.",
)
parser.add_argument(
    "--concat_captions_prob",
    default=0.0,
    type=float,
    help="Probability of concatenating two examples sequentially for captioning.",
)
parser.add_argument(
    "--input_prompt",
    type=str,
    default="A video shows",
    help="Input prompt for the language model, if any.",
)
parser.add_argument(
    "--image_size",
    type=int,
    metavar="N",
    default=224,
    help="Size of images.",
)

# ====================== Decoding ===========================
parser.add_argument('--max_length', default=40, type=int, metavar='N', help='Maximum length to truncate captions / generations to.')
# Float options get float defaults so the parsed type does not depend on
# whether the value came from the command line or the default.
parser.add_argument('--repetition_penalty', type=float, default=1.0)
parser.add_argument('--length_penalty', type=float, default=1.0)
parser.add_argument('--diversity_penalty', type=float, default=0.0)
parser.add_argument('--temperature', type=float, default=0.1)
parser.add_argument('--beam_size', type=int, default=3)
# `type=bool` turns any non-empty string (including "False") into True, so
# parse this flag like the other boolean options in this file.
parser.add_argument('--do_sample', default=False, type=lambda x: (str(x).lower() == 'true'), help='sample during decoding or not')
parser.add_argument('--n_visual_tokens', default=1, type=int, metavar='N', help='Number of visual tokens to use for the Frozen model.')

# ====================== Mapper ===========================
parser.add_argument('--hidden_dim', default=512, type=int, help='Hidden dimension of the mapper.')
parser.add_argument('--num_decoder_layers', type=int, default=4)
parser.add_argument('--num_encoder_layers', type=int, default=4)
# NOTE(review): presumably the number of attention heads in the mapper; confirm.
parser.add_argument('--head', type=int, default=8)
parser.add_argument('--gen_emb_dim', default=768, type=int, help='Embedding dimension for generation.')
parser.add_argument('--num_tokens', default=8, type=int, help='Number of [IMG] tokens to use.')
parser.add_argument('--num_clip_tokens', default=77, type=int, help='Number of CLIP token to use for generation.')

# ====================== PyTorch Lightning ===========================
# These names mirror Lightning Trainer arguments; presumably they are forwarded
# to the Trainer (confirm at the call site).
parser.add_argument(
    "--devices",
    default=1,
    type=int,
    help="how many gpus to use",
)
parser.add_argument(
    "--num_nodes",
    default=1,
    type=int,
    help="Number of GPU nodes for distributed training.",
)
parser.add_argument(
    "--accelerator",
    default="cpu",
    type=str,
    choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps"],
    help="accelerator types",
)
parser.add_argument(
    "--strategy",
    default="ddp",
    type=str,
    help="default ddp for multi-gpus",
)
parser.add_argument(
    "--precision",
    default="bf16-mixed",
    type=str,
    help="16 or 32 bf16-mixed, using for original pytorch amp auto cast",
)
# NOTE(review): both limit_* options are always parsed as float, even when an
# integer is typed on the command line.
parser.add_argument(
    "--limit_val_batches",
    default=1.0,
    type=float,
    help="How much of validation dataset to check (float = fraction, int = num_batches).",
)
parser.add_argument(
    "--limit_train_batches",
    default=1.0,
    type=float,
    help="How much of training dataset to check (float = fraction, int = num_batches)",
)
parser.add_argument(
    "--max_steps",
    default=1500000,
    type=int,
    metavar="N",
    help="Stop training after this number of steps. ",
)
parser.add_argument(
    "--max_epochs",
    default=60,
    type=int,
    help="Stop training once this number of epochs is reached",
)
parser.add_argument(
    "--every_n_train_steps",
    default=0,
    type=int,
    help="How many training steps to save a checkpoint",
)
parser.add_argument(
    "--val_check_interval",
    default=1.0,
    type=float,
    help="How often to check the validation set",
)
# NOTE(review): overlaps with --grad_accumulation_steps in Training Config;
# confirm which of the two the training code actually reads.
parser.add_argument(
    "--accumulate_grad_batches",
    default=1,
    type=int,
    help="Accumulates gradients over k batches before stepping the optimizer",
)
parser.add_argument(
    "--log_every_n_steps",
    default=20,
    type=int,
    help="How often to log within steps",
)
parser.add_argument(
    "--num_sanity_val_steps",
    default=2,
    type=int,
    help="Sanity check runs n validation batches before starting the training routine",
)
