import torch


def get_tokenizer(args):
    from transformers import AutoTokenizer
    if args.dataset == 'rxr' or args.tokenizer == 'xlm':
        cfg_name = 'bert_config/xlm-roberta-base'
    else:
        cfg_name = 'bert_config/bert-base-uncased'
    tokenizer = AutoTokenizer.from_pretrained(cfg_name)
    return tokenizer

def get_vlnbert_models(config=None):
    
    from transformers import PretrainedConfig, BertTokenizer, BertModel
    from vlnce_baselines.models.etp.vilmodel_cmt import GlocalTextPathNavCMT

    model_class = GlocalTextPathNavCMT

    model_name_or_path = config.pretrained_path
    new_ckpt_weights = {}
    if model_name_or_path is not None:
        ckpt_weights = torch.load(model_name_or_path, map_location='cpu')
        for k, v in ckpt_weights.items():
            if k.startswith('module'):
                new_ckpt_weights[k[7:]] = v
            if 'sap_head' in k:
                new_ckpt_weights['bert.' + k] = v
            else:
                 new_ckpt_weights[k] = v
    
    if config.task_type == 'r2r':
        #cfg_name = 'bert_config/bert-base-uncased'
        cfg_name = 'vlnce_baselines/bert-base-uncased'
    elif config.task_type == 'rxr':
        cfg_name = 'bert_config/xlm-roberta-base'
    vis_config = PretrainedConfig.from_pretrained(cfg_name)

    if config.task_type == 'rxr':
        vis_config.type_vocab_size = 2

    vis_config.max_action_steps = 100
    vis_config.image_feat_size = 512
    vis_config.use_depth_embedding = config.use_depth_embedding
    vis_config.depth_feat_size = 128
    vis_config.angle_feat_size = 4

    vis_config.num_l_layers = 9
    vis_config.num_pano_layers = 2
    vis_config.num_x_layers = 4
    vis_config.graph_sprels = config.use_sprels
    vis_config.glocal_fuse = 'global'

    vis_config.fix_lang_embedding = config.fix_lang_embedding
    vis_config.fix_pano_embedding = config.fix_pano_embedding

    vis_config.update_lang_bert = not vis_config.fix_lang_embedding
    vis_config.output_attentions = True
    vis_config.pred_head_dropout_prob = 0.1
    vis_config.use_lang2visn_attn = False

    visual_model = model_class.from_pretrained(
        pretrained_model_name_or_path=None, 
        config=vis_config, 
        state_dict=new_ckpt_weights)
    
    # Re-apply zero-init after from_pretrained to prevent override
    if hasattr(visual_model.global_encoder, 'nav_correction'):
        import torch.nn as nn
        nn.init.zeros_(visual_model.global_encoder.nav_correction[0].weight)
        nn.init.zeros_(visual_model.global_encoder.nav_correction[0].bias)
    
    # Zero-init for world model feedback layers
    if hasattr(visual_model.global_encoder, 'feedback_gate'):
        import torch.nn as nn
        nn.init.xavier_uniform_(visual_model.global_encoder.feedback_gate.weight, gain=0.01)
        nn.init.constant_(visual_model.global_encoder.feedback_gate.bias, 0.0)
        nn.init.zeros_(visual_model.global_encoder.feedback_delta.weight)
        nn.init.zeros_(visual_model.global_encoder.feedback_delta.bias)
        
    return visual_model
