import argparse
from ast import arg
import os


def parse_args():
    """Build the command-line parser, parse the process arguments and post-process them.

    Options not declared here are tolerated: ``parse_known_args`` returns them
    separately and they are discarded, so a misspelled flag is silently ignored.

    Returns:
        argparse.Namespace: the parsed options after ``postprocess_args`` has
        derived the data/output paths and normalized ``resume_weights``.
    """
    parser = argparse.ArgumentParser(description="")
    add = parser.add_argument

    def flag(name, **extra):
        # Boolean switch: False unless the flag is given on the command line.
        add(name, action='store_true', default=False, **extra)

    flag('--debug', help='debug mode')

    # Dataset / experiment location
    add('--root_dir', type=str, default='../datasets')
    add('--dataset', type=str, default='r2r', choices=['r2r', 'r4r', 'gsa-r2r'])
    add('--output_dir', type=str, default='default', help='experiment id')
    add('--seed', type=int, default=0)

    add('--tokenizer', choices=['bert', 'xlm'], default='bert')

    # Routing
    add('--routing_mode', type=str, default='fixed')
    add('--router_model', type=str, default='Qwen2.5-VL-7B-Instruct')
    add('--routing_weights_type', type=str, default='int')

    # Exploration / action policy
    flag('--act_visited_nodes')
    add('--fusion', choices=['global', 'local', 'avg', 'dynamic'])
    flag('--expl_sample')
    add('--expl_max_ratio', type=float, default=0.6)
    add('--expert_policy', default='spl', choices=['spl', 'ndtw'])

    # Distributed training (single node, multiple GPUs)
    add('--world_size', type=int, default=1, help='number of gpus')
    add('--local_rank', type=int, default=-1)
    add('--node_rank', type=int, default=0, help='Id of the node')

    # vLLM-served models
    flag('--instruction_reorder', help='Enable instruction reordering')
    add('--localizer_gpu_id', type=int, default=0, help='GPU id for localizer model')
    add('--skill_gpu_id', type=int, default=1, help='GPU id for skill routing model')
    add('--localizer_model', type=str, default='Qwen2.5-VL-7B-Instruct', help='Localizer model name')
    add('--skill_model', type=str, default='Qwen2.5-VL-7B-Instruct', help='Skill routing model name')
    add('--gpu_memory_utilization', type=float, default=0.7, help='GPU memory utilization for vLLM')

    # General
    add('--iters', type=int, default=100000, help='training iterations')
    add('--log_every', type=int, default=1000)
    flag('--eval_first')

    # Data preparation
    add('--max_instr_len', type=int, default=80)
    add('--max_action_len', type=int, default=15)
    add('--batch_size', type=int, default=8)
    add('--ignoreid', type=int, default=-100, help='ignoreid for action')
    add('--train_env_names', type=str, nargs='+', default=['train'], help='train env names')
    add('--val_env_names', type=str, nargs='+', default=['val_unseen'], help='val env names')
    add('--partial_dataset', type=float, default=1.0, help='fraction of the dataset to use')

    # Checkpoint(s) to resume from
    add('--resume_files', type=str, nargs='+', default=None, help='path of the trained model')
    add('--resume_weights', type=float, nargs='+', default=[1.0], help='weight of the trained model')
    flag('--resume_optimizer')

    # Augmented paths
    add('--aug', default=None)
    add('--bert_ckpt_file', default=None, help='init vlnbert')

    # Listener model config
    add('--ml_weight', type=float, default=0.2)
    add('--entropy_loss_weight', type=float, default=0.01)

    # NOTE(review): the default 'vitbase' is not one of the backbones
    # ('clip.h14', 'clip.b16') that postprocess_args maps to an augmentation
    # feature file -- confirm the intended default.
    add('--features', type=str, default='vitbase')
    add('--feature_file', type=str, default='clip_vit-b16_mp3d_original.hdf5', help='feature files')
    flag('--env_aug')
    add('--aug_times', type=int, default=19)

    flag('--fix_lang_embedding')
    flag('--fix_pano_embedding')
    flag('--fix_local_branch')

    add('--num_l_layers', type=int, default=9)
    add('--num_pano_layers', type=int, default=2)
    add('--num_x_layers', type=int, default=4)

    flag('--enc_full_graph')
    flag('--graph_sprels')

    # Dropout
    add('--dropout', type=float, default=0.5)
    add('--feat_dropout', type=float, default=0.3)

    # Submission configuration
    flag('--test')
    flag('--zero_shot')
    flag('--submit')
    flag('--no_backtrack')
    flag('--detailed_output')

    # Training configuration
    add('--optim', type=str, default='rms', choices=['rms', 'adam', 'adamW', 'sgd'])
    add('--lr', type=float, default=1e-5, help='the learning rate')
    add('--decay', dest='weight_decay', type=float, default=0.0)
    add('--feedback', type=str, default='sample',
        help='How to choose next position, one of ``teacher``, ``sample`` and ``argmax``')
    add('--epsilon', type=float, default=0.1, help='')

    # Model hyper-parameters
    add('--angle_feat_size', type=int, default=4)
    add('--image_feat_size', type=int, default=2048)
    add('--obj_feat_size', type=int, default=0)
    add('--views', type=int, default=36)

    # A2C
    add('--gamma', default=0.9, type=float, help='reward discount factor')
    add('--normalize', dest='normalize_loss', default='total', type=str, help='batch or total')
    add('--train_alg', choices=['imitation', 'dagger'], default='imitation')

    # LoRA
    flag('--use_lora', help='Whether to use LoRA for fine-tuning')
    add('--lora_r', type=int, default=8, help='Rank dimension for LoRA')
    add('--lora_alpha', type=int, default=16, help='Alpha scaling factor for LoRA')
    add('--lora_dropout', type=float, default=0.05, help='Dropout probability for LoRA layers')
    add('--lora_target_modules', type=str, nargs='+', default=['query', 'value'],
        help='Target modules for LoRA (e.g., q_proj, v_proj)')

    # Unknown extras (second element) are deliberately dropped.
    known = parser.parse_known_args()[0]
    return postprocess_args(known)


def postprocess_args(args):
    """Derive dataset/output paths on ``args`` and normalize ``resume_weights``.

    ``args`` is mutated in place and also returned. Attributes written:

    * ``aug_ft_file`` -- augmentation feature file under
      ``<root_dir>/R2R/features`` for ``args.features`` ``'clip.h14'`` or
      ``'clip.b16'``; ``None`` for any other backbone name.
    * ``mp3d_ft_files`` / ``val_ft_file`` -- ``args.feature_file`` under
      ``<root_dir>/R2R/features``. With ``env_aug`` the training list is
      replaced by four fixed CLIP ViT-H/14 feature files.
    * ``connectivity_dir`` -- ``R2R/connectivity`` when ``aug`` is set or the
      dataset is not ``'r2r'``, otherwise ``R2R/connectivity_mp3d``.
    * ``scan_data_dir`` and ``anno_dir`` (``<DATASET upper-cased>/annotations``).
    * ``ckpt_dir``, ``log_dir``, ``pred_dir`` -- under ``output_dir``;
      ``log_dir`` is ``zero_shot_logs`` in zero-shot mode, else ``logs``.
    * ``resume_weights`` -- rescaled so the weights sum to 1.

    Side effects: always creates ``log_dir``; ``output_dir``, ``ckpt_dir`` and
    ``pred_dir`` are created only when ``zero_shot`` is false.

    Raises:
        ValueError: if ``resume_weights`` sums to zero and cannot be normalized.
    """
    root_dir = args.root_dir
    feat_dir = os.path.join(root_dir, 'R2R', 'features')

    # Augmentation feature files, keyed by the --features backbone name.
    ft_file_map = {
        'clip.h14': 'clip_vit-h14_mp3d_hm3d_gibson.hdf5',
        'clip.b16': 'clip_vit-b16_mp3d_hm3d_gibson.hdf5'
    }
    # A backbone outside the map (e.g. the CLI default 'vitbase') has no
    # augmentation file; use None instead of failing with a bare KeyError.
    # NOTE(review): confirm downstream code tolerates aug_ft_file=None.
    aug_name = ft_file_map.get(args.features)
    args.aug_ft_file = os.path.join(feat_dir, aug_name) if aug_name is not None else None

    # Main/validation features always come from --feature_file (both known
    # backbones resolved to exactly these paths), so the attributes are set
    # for every backbone name.
    args.mp3d_ft_files = [os.path.join(feat_dir, args.feature_file)]
    args.val_ft_file = os.path.join(feat_dir, args.feature_file)

    if args.env_aug:  # only h14
        args.mp3d_ft_files = [
            os.path.join(feat_dir, 'clip_vit-h14_mp3d_img_image_synthesis.hdf5'),
            os.path.join(feat_dir, 'clip_vit-h14_mp3d_img_mask_image_synthesis.hdf5'),
            os.path.join(feat_dir, 'clip_vit-h14_mp3d_img_style_transfer.hdf5'),
            os.path.join(feat_dir, 'clip_vit-h14_mp3d_original.hdf5'),
        ]

    if args.aug or args.dataset != 'r2r':
        args.connectivity_dir = os.path.join(root_dir, 'R2R', 'connectivity')
    else:
        args.connectivity_dir = os.path.join(root_dir, 'R2R', 'connectivity_mp3d')

    args.scan_data_dir = os.path.join(root_dir, 'Matterport3D', 'v1_unzip_scans')
    args.anno_dir = os.path.join(root_dir, args.dataset.upper(), 'annotations')

    # Output paths
    args.ckpt_dir = os.path.join(args.output_dir, 'ckpts')
    log_subdir = 'zero_shot_logs' if args.zero_shot else 'logs'
    args.log_dir = os.path.join(args.output_dir, log_subdir)
    args.pred_dir = os.path.join(args.output_dir, 'preds')

    if not args.zero_shot:
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(args.ckpt_dir, exist_ok=True)
        os.makedirs(args.pred_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)

    if args.resume_weights:
        # Normalize the resume weights so they sum to 1; the sum is computed
        # once rather than per element.
        weights = [float(w) for w in args.resume_weights]
        total = sum(weights)
        if total == 0:
            raise ValueError(
                f"--resume_weights must not sum to zero, got {args.resume_weights}"
            )
        args.resume_weights = [w / total for w in weights]

    return args

