import os
import sys
import argparse
from omegaconf import OmegaConf
import torch
import torch.nn as nn

# Add the current directory to sys.path to allow imports
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)

from models.rvfm import VisionFMMoE


def _parse_args():
    """Parse the command-line options of the loader script.

    Returns:
        argparse.Namespace: has ``model_size`` (one of ``'tiny'``, ``'small'``,
        ``'base'``; defaults to ``'tiny'``) and ``ckpt_path`` (required).
    """
    parser = argparse.ArgumentParser(description='Load VFMoE model with specified configuration')
    parser.add_argument('--model_size', type=str, choices=['tiny', 'small', 'base'],
                        default='tiny', help='Model size (tiny, small, or base)')
    parser.add_argument('--ckpt_path', type=str, required=True,
                        help='Path to the checkpoint file')
    return parser.parse_args()


def _resolve_config_dir():
    """Return the directory that holds the VFMoE YAML configs.

    ``./configs`` relative to the working directory is used when it exists.
    Otherwise the ``configs`` folder next to this script is used, so the
    script also works when launched from another directory (imports are
    already resolved relative to the script location).
    """
    cwd_configs = "./configs"
    if os.path.isdir(cwd_configs):
        return cwd_configs

    script_configs = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs")
    if os.path.isdir(script_configs):
        return script_configs

    # Neither exists: keep the working-directory path so the subsequent
    # load error points at the conventional location.
    return cwd_configs


def _load_state_dict_without_module(model, state_dict, prefix='module.'):
    """Load ``state_dict`` into ``model`` after stripping a leading ``prefix``.

    Keys starting with ``prefix`` have it removed; all other keys are kept
    as they are. Loading is non-strict, and the missing / unexpected keys
    reported by ``load_state_dict`` are printed for manual inspection.

    Args:
        model: Module whose ``load_state_dict`` is called.
        state_dict: Mapping from parameter name to value.
        prefix: Key prefix to strip (presumably left by a DataParallel-style
            wrapper when the checkpoint was saved -- confirm with the
            training code).
    """
    from collections import OrderedDict

    cleaned = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    missing_keys, unexpected_keys = model.load_state_dict(cleaned, strict=False)
    print('.'*20)
    print('check missing_keys and unexpected_keys.')
    print('missing_keys', missing_keys)
    print('unexpected_keys', unexpected_keys)


def main():
    """Build a VisionFMMoE encoder, load a checkpoint and smoke-test it.

    Steps: parse CLI arguments, load and merge the base and MoE configs,
    construct the encoder, drop the last three backbone encoder layers,
    load the checkpoint weights (non-strict), register a ``'robot'``
    downstream task, and run one forward pass on a zero image to print the
    output shape.

    Returns:
        The constructed ``VisionFMMoE`` encoder with checkpoint weights loaded
        and the ``'robot'`` task added.

    Raises:
        FileNotFoundError: if ``--ckpt_path`` does not point to an existing
            regular file.
    """
    args = _parse_args()
    model_size = args.model_size
    ckpt_path = args.ckpt_path

    # Fail early with a clear message. isfile (not exists) so that a directory
    # is rejected here instead of failing later inside torch.load.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")

    vfmoe_config_path = _resolve_config_dir()
    base_config_path = os.path.join(vfmoe_config_path, "train_vfmoe.yaml")
    moe_config_path = os.path.join(vfmoe_config_path, "model", "moe", f"{model_size}_ffn_ap.yaml")

    # Load both configs
    base_cfg = OmegaConf.load(base_config_path)
    moe_cfg = OmegaConf.load(moe_config_path)

    # Access the model configuration from moe_cfg
    moe_model_cfg = moe_cfg['model']

    # Merge configurations: the overrides below take precedence over base_cfg.
    cfg = OmegaConf.merge(
        base_cfg,
        {
            'model_size': model_size,
            'model': {
                # Model section of the "<model_size>_ffn_ap.yaml" MoE config.
                'moe': moe_model_cfg
            },
            'training': 'frame_level_vfmoe_reconstruction',
            'training_stage': 'reconstruction'
        }
    )

    # Backbone name is derived from the selected size, e.g. "...deit-tiny-...".
    backbone = 'facebook/deit-size-patch16-224'.replace('size', cfg.model_size)

    # Feature size triple for each known target model, keyed by model name.
    feature_sizes = {'facebook/dinov2-large': (1024, 16, 16),
                     'google/vit-huge-patch14-224-in21k': (1280, 16, 16),
                     'openai/clip-vit-large-patch14': (1024, 16, 16),
                     'LiheYoung/depth-anything-large-hf': (32, 64, 64),
                     'facebook/sam-vit-huge': (256, 64, 64),
                     'llava-hf/llava-1.5-7b-hf': (1024, 24, 24)}

    # Only these target models are used for this encoder.
    model_names = ['facebook/dinov2-large',
                   'google/vit-huge-patch14-224-in21k',
                   'openai/clip-vit-large-patch14'
                   ]
    target_feature_sizes = {t: feature_sizes[t] for t in model_names}

    translator = "lconv"
    translator_kwargs = OmegaConf.create({"hidden_size_factor": 1.0})
    target_loss_weights = None

    encoder = VisionFMMoE(
        translator=translator,
        translator_kwargs=translator_kwargs,
        target_feature_sizes=target_feature_sizes,
        target_loss_weights=target_loss_weights,
        backbone=backbone,
        moe_cfg=moe_cfg,
    )
    # Drop the last three backbone encoder layers before loading weights
    # (presumably matches how the checkpoint was trained -- TODO confirm).
    encoder.backbone.model.encoder.layer = nn.ModuleList(list(encoder.backbone.model.encoder.layer)[:-3])

    print(f"Loading checkpoint from: {ckpt_path}")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources (consider weights_only=True where supported).
    ckpt = torch.load(ckpt_path, map_location='cpu')

    _load_state_dict_without_module(encoder, ckpt)

    # The downstream task is added after loading, so the checkpoint's keys are
    # matched against the encoder as originally constructed.
    encoder.add_downstream_tasks(n_new_tasks=1, n_new_experts=0, new_task_names=['robot'], noising_gating=True, unfreeze_norm=True,topk=-1)

    # Smoke test: one forward pass on an all-zero image. No gradients are
    # needed here, so skip building the autograd graph.
    fake_input = torch.zeros((1, 3, 224 ,224), dtype=torch.uint8)
    with torch.no_grad():
        x = encoder.forward_feature(fake_input, task_name="robot", do_rescale=False)
    print(f"Output shape: {x.shape}")

    return encoder


if __name__ == "__main__":
    # Script entry point: build and load the encoder, and bind the result to a
    # module-level name.
    encoder = main()