import torch
import os
import shutil
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
from model.configuration_hybrid_rwkv_qwen import HybridRWKVQwenConfig
from model.modeling_hybrid_rwkv_qwen import HybridQwenRWKVForConditionalGeneration


def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    ``total`` sums ``numel()`` of every parameter; ``trainable`` sums only those
    whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


def format_params(num_params):
    """Render a parameter count as a short human-readable string.

    Counts below one million are printed with thousands separators
    (``"12,345"``); below one billion they get an ``M`` suffix, otherwise a
    ``B`` suffix, both with two decimals (``"1.50M"``, ``"7.62B"``).
    """
    if num_params < 1e6:
        return f"{num_params:,}"
    scale, suffix = (1e6, "M") if num_params < 1e9 else (1e9, "B")
    return f"{num_params / scale:.2f}{suffix}"


def _fit_tensor(src, target_shape):
    """Return ``src`` truncated and/or cyclically tiled to exactly ``target_shape``.

    Each dimension is adapted independently:
    - A dimension that is too large keeps its leading ``target`` entries.
    - A dimension that is too small is repeated along that axis and then trimmed,
      i.e. ``target // size`` full copies followed by the leading
      ``target % size`` entries.

    Args:
        src (torch.Tensor): Source tensor.
        target_shape (Sequence[int]): Desired shape; must have the same rank as ``src``.

    Returns:
        torch.Tensor: Tensor of shape ``target_shape`` (may share storage with ``src``).

    Raises:
        ValueError: If the ranks differ, or an empty source dimension would have
            to be tiled up to a non-zero size.
    """
    target_shape = tuple(target_shape)
    if src.dim() != len(target_shape):
        raise ValueError(
            f"rank mismatch: source {tuple(src.shape)} vs target {target_shape}"
        )
    out = src
    for dim, want in enumerate(target_shape):
        have = out.shape[dim]
        if have == want:
            continue
        if have > want:
            out = out.narrow(dim, 0, want)
        elif have == 0:
            raise ValueError(f"cannot tile empty dimension {dim} up to size {want}")
        else:
            reps = [1] * out.dim()
            reps[dim] = -(-want // have)  # ceil(want / have) copies, then trim
            out = out.repeat(*reps).narrow(dim, 0, want)
    return out


@torch.no_grad()
def _copy_into(dst, sources, label):
    """Copy ``sources`` into ``dst`` in place.

    Several sources are concatenated along dim 0 first. They are moved to
    ``dst``'s device beforehand; ``copy_`` performs any dtype cast.

    A shape mismatch is reported and resolved with ``_fit_tensor``. Errors are
    printed rather than raised, so one bad tensor does not abort the whole
    conversion.

    Returns:
        bool: True if ``dst`` was written, False otherwise.
    """
    try:
        parts = [t.to(dst.device) for t in sources]
        src = parts[0] if len(parts) == 1 else torch.cat(parts, dim=0)
        if src.shape != dst.shape:
            print(f"  ⚠️ {label} shape mismatch: Qwen {src.shape} -> RWKV {dst.shape}")
            src = _fit_tensor(src, dst.shape)
        dst.copy_(src)
        return True
    except Exception as e:  # best-effort: report and continue with other tensors
        print(f"  ❌ Error initializing {label}: {e}")
        return False


def _default_hybrid_config(qwen_model_path):
    """Build a ``HybridRWKVQwenConfig`` from the Qwen2.5-VL config at ``qwen_model_path``.

    Sets ``rwkv_layer_offset=3`` and ``rwkv_layer_period=4``. The loader below
    treats a layer ``i`` as RWKV when ``(i - offset) % period == 0``, i.e.
    layers 3, 7, 11, ...

    Also sets ``auto_map`` so HF auto classes resolve to the hybrid modules.
    """
    qwen_config = Qwen2_5_VLConfig.from_pretrained(qwen_model_path)
    config_dict = (
        qwen_config.to_dict() if hasattr(qwen_config, "to_dict") else dict(vars(qwen_config))
    )
    vision_config = config_dict.pop("vision_config", None)
    config_dict.update({
        "rwkv_layer_offset": 3,
        "rwkv_layer_period": 4,
        "auto_map": {
            "AutoModel": "modeling_hybrid_rwkv_qwen.HybridQwenRWKVForConditionalGeneration",
            "AutoModelForCausalLM": "modeling_hybrid_rwkv_qwen.HybridQwenRWKVForConditionalGeneration",
            "AutoConfig": "configuration_hybrid_rwkv_qwen.HybridRWKVQwenConfig",
        },
    })
    return HybridRWKVQwenConfig(vision_config=vision_config, **config_dict)


@torch.no_grad()
def _copy_matching_weights(qwen_state_dict, hybrid_state_dict):
    """Copy every hybrid tensor whose key exists in Qwen with an identical shape.

    Shape mismatches are skipped explicitly, so ``copy_`` can never silently
    broadcast a smaller tensor into a weight.

    Returns:
        tuple[set[str], set[str]]: ``(loaded, skipped)`` sets of hybrid keys.
    """
    loaded, skipped = set(), set()
    for key, dst in hybrid_state_dict.items():
        src = qwen_state_dict.get(key)
        if src is None:
            skipped.add(key)
            continue
        if src.shape != dst.shape:
            print(f"⚠️ Shape mismatch for {key}: Qwen {tuple(src.shape)} vs hybrid {tuple(dst.shape)}")
            skipped.add(key)
            continue
        try:
            dst.copy_(src.to(dst.device))
            loaded.add(key)
        except Exception as e:
            print(f"❌ Error copying weight {key}: {e}")
            skipped.add(key)
    return loaded, skipped


@torch.no_grad()
def _copy_layer_mlp(layer_idx, qwen_state_dict, hybrid_state_dict, loaded, skipped):
    """(Re)copy the Qwen2MLP gate/up/down projections of one layer.

    Requires an exact shape match. Each written key is moved from ``skipped``
    to ``loaded``; using sets means a key already copied by
    ``_copy_matching_weights`` is not counted twice.
    """
    for proj in ("gate_proj", "up_proj", "down_proj"):
        key = f"model.language_model.layers.{layer_idx}.mlp.{proj}.weight"
        if key not in qwen_state_dict or key not in hybrid_state_dict:
            print(f"  ⚠️ Weight not found: {key} or {key}")
            continue
        src, dst = qwen_state_dict[key], hybrid_state_dict[key]
        if src.shape != dst.shape:
            print(f"  ⚠️ Shape mismatch: {key} {src.shape} -> {key} {dst.shape}")
            continue
        try:
            dst.copy_(src.to(dst.device))
        except Exception as e:
            print(f"  ❌ Error copying MLP weight {key}: {e}")
            continue
        loaded.add(key)
        skipped.discard(key)
        print(f"  ✅ Loaded {key}")


def _init_rwkv_layer(layer_idx, qwen_state_dict, hybrid_state_dict, loaded, skipped):
    """Seed one layer's RWKV / cross-attention weights from its Qwen self-attention.

    Mapping (Qwen ``self_attn.*`` -> hybrid ``rwkv.*``):
        q_proj.weight            -> att.receptance.weight, cross_attn_q_proj.weight
        k_proj.weight            -> att.key.weight
        v_proj.weight            -> att.value.weight
        o_proj.weight            -> att.output.weight
        [k_proj; v_proj].weight  -> cross_attn_kv_proj.weight (rows concatenated)
        q_proj.bias              -> cross_attn_q_proj.bias
        [k_proj; v_proj].bias    -> cross_attn_kv_proj.bias

    A pair is skipped silently when any of its keys is missing on either side.
    Successfully written keys move from ``skipped`` to ``loaded``.
    """
    prefix = f"model.language_model.layers.{layer_idx}"
    attn = f"{prefix}.self_attn"
    rwkv = f"{prefix}.rwkv"
    plan = [
        ("Receptance", f"{rwkv}.att.receptance.weight", [f"{attn}.q_proj.weight"]),
        ("Key", f"{rwkv}.att.key.weight", [f"{attn}.k_proj.weight"]),
        ("Value", f"{rwkv}.att.value.weight", [f"{attn}.v_proj.weight"]),
        ("Output", f"{rwkv}.att.output.weight", [f"{attn}.o_proj.weight"]),
        ("q_proj", f"{rwkv}.cross_attn_q_proj.weight", [f"{attn}.q_proj.weight"]),
        ("cross_attn_kv_proj", f"{rwkv}.cross_attn_kv_proj.weight",
         [f"{attn}.k_proj.weight", f"{attn}.v_proj.weight"]),
        ("q_proj bias", f"{rwkv}.cross_attn_q_proj.bias", [f"{attn}.q_proj.bias"]),
        ("cross_attn_kv_proj bias", f"{rwkv}.cross_attn_kv_proj.bias",
         [f"{attn}.k_proj.bias", f"{attn}.v_proj.bias"]),
    ]
    for label, dst_key, src_keys in plan:
        if dst_key not in hybrid_state_dict or any(k not in qwen_state_dict for k in src_keys):
            continue
        sources = [qwen_state_dict[k] for k in src_keys]
        if _copy_into(hybrid_state_dict[dst_key], sources, label):
            loaded.add(dst_key)
            skipped.discard(dst_key)
            print(f"  ✅ Initialized {label}: {dst_key}")


def _print_param_counts(title, model):
    """Print total/trainable parameter counts of ``model`` under ``title``; return the total."""
    total, trainable = count_parameters(model)
    print(f"\n📊 {title}:")
    print(f"  Total: {format_params(total)}")
    print(f"  Trainable: {format_params(trainable)}")
    return total


def _print_param_comparison(qwen_total, hybrid_total):
    """Print the absolute and relative parameter difference between the two models."""
    print(f"\n📊 Parameter Comparison:")
    print(f"  Original Qwen Model: {format_params(qwen_total)}")
    print(f"  Hybrid Model: {format_params(hybrid_total)}")
    diff = abs(hybrid_total - qwen_total)
    percent = diff / qwen_total * 100 if qwen_total else 0.0  # guard empty source model
    if hybrid_total > qwen_total:
        print(f"  Increase: {format_params(diff)} (+{percent:.2f}%)")
    else:
        print(f"  Decrease: {format_params(diff)} (-{percent:.2f}%)")


def _copy_support_files(save_path, source_dir="./model"):
    """Copy every file/subdirectory of ``source_dir`` into ``save_path``.

    ``source_dir`` is resolved relative to the current working directory.
    Existing destination directories with the same name are replaced.
    """
    if not os.path.exists(source_dir):
        print(f"⚠️ Warning: {source_dir} directory does not exist")
        return
    print(f"📋 Copying files from {source_dir} to {save_path}")
    for item in os.listdir(source_dir):
        src_path = os.path.join(source_dir, item)
        dst_path = os.path.join(save_path, item)
        if os.path.isfile(src_path):
            shutil.copy2(src_path, dst_path)
            print(f"  ✅ Copied file: {item}")
        elif os.path.isdir(src_path):
            if os.path.exists(dst_path):
                shutil.rmtree(dst_path)
            shutil.copytree(src_path, dst_path)
            print(f"  ✅ Copied directory: {item}")
    print("✅ File copying completed")


def load_qwen_weights_to_hybrid_model(qwen_model_path, hybrid_config=None, device="cuda", save_path=None):
    """
    Load Qwen2.5-VL weights into a HybridRWKVQwen model.

    Steps:
    1. Copy every tensor whose key and shape match.
    2. Re-copy the Qwen2MLP weights of the RWKV layers.
    3. Seed each RWKV layer's time-mix and cross-attention projections from
       that layer's Qwen self-attention weights. Mismatched shapes are
       truncated or cyclically tiled.

    Args:
        qwen_model_path (str): Path to pretrained Qwen2.5-VL model (local or HF hub).
        hybrid_config (HybridRWKVQwenConfig, optional): Custom hybrid model config.
            Defaults to one derived from the Qwen config
            (``rwkv_layer_offset=3``, ``rwkv_layer_period=4``).
        device (str): Device to load model on ("cuda" or "cpu"); falls back to
            "cpu" when CUDA is unavailable.
        save_path (str, optional): If given, save the model, its config, and the
            contents of ``./model`` into this directory.

    Returns:
        HybridQwenRWKVForConditionalGeneration: Initialized hybrid model.

    Raises:
        Exception: Any error while loading Qwen, building the hybrid model, or
            copying/saving weights is printed and re-raised. Per-tensor copy
            failures are only reported.
    """
    # Fallback to CPU if CUDA is unavailable
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        device = "cpu"

    try:
        # 1. Load original Qwen2.5-VL model
        print(f"Loading Qwen2.5-VL model from: {qwen_model_path}")
        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            qwen_model_path,
            torch_dtype=torch.float16 if device == "cuda" else "auto",
            device_map=device if device != "cuda" else "auto",
        )
        print("✅ Qwen2.5-VL model loaded successfully")
        qwen_total_params = _print_param_counts("Original Qwen2.5-VL Model Parameters", qwen_model)
    except Exception as e:
        print(f"❌ Error loading Qwen2.5-VL model: {e}")
        raise

    try:
        # 2. Create the hybrid config (unless supplied) and 3. the model on `device`.
        # NOTE(review): the hybrid model is built in the default dtype, so on CUDA it
        # may be float32 next to the float16 Qwen copy; watch GPU memory for 7B+.
        if hybrid_config is None:
            hybrid_config = _default_hybrid_config(qwen_model_path)
        hybrid_model = HybridQwenRWKVForConditionalGeneration(hybrid_config).to(device)
        print("✅ Hybrid model initialized")
        _print_param_counts("Hybrid Model Initial Parameters", hybrid_model)
    except Exception as e:
        print(f"❌ Error creating hybrid model: {e}")
        raise

    try:
        # 4. state_dict() tensors share storage with the parameters, so in-place
        #    copy_ into hybrid_state_dict writes the model's weights.
        qwen_state_dict = qwen_model.state_dict()
        hybrid_state_dict = hybrid_model.state_dict()

        # 5. Copy all tensors with matching key and shape.
        loaded_keys, skipped_keys = _copy_matching_weights(qwen_state_dict, hybrid_state_dict)
        print(f"\n✅ Successfully loaded {len(loaded_keys)} parameters")
        print(f"⚠️  Skipped {len(skipped_keys)} parameters (RWKV layers or mismatched keys)")

        # Layer i is an RWKV layer when (i - offset) % period == 0.
        num_layers = hybrid_config.text_config.num_hidden_layers
        offset = hybrid_config.rwkv_layer_offset
        period = hybrid_config.rwkv_layer_period
        rwkv_layer_indices = [i for i in range(num_layers) if (i - offset) % period == 0]

        # 6. Qwen2MLP weights in RWKV layers
        print(f"🔧 RWKV Layer Indices: {rwkv_layer_indices}")
        for layer_idx in rwkv_layer_indices:
            print(f"🔧 Processing Qwen2MLP weights for RWKV layer {layer_idx}...")
            _copy_layer_mlp(layer_idx, qwen_state_dict, hybrid_state_dict, loaded_keys, skipped_keys)

        # 7. Initialize RWKV_Tmix_x060b weights from Qwen attention weights
        print(f"\n🔧 Initializing RWKV_Tmix_x060b weights...")
        print(f"RWKV Layer Indices: {rwkv_layer_indices}")
        for layer_idx in rwkv_layer_indices:
            print(f"🔧 Initializing RWKV_Tmix_x060b weights for layer {layer_idx}...")
            _init_rwkv_layer(layer_idx, qwen_state_dict, hybrid_state_dict, loaded_keys, skipped_keys)

        print(f"\n📊 Final Statistics: Successfully loaded {len(loaded_keys)} parameters")
        print(f"📊 Final Statistics: Skipped {len(skipped_keys)} parameters (RWKV layers or mismatched keys)")

        # All needed Qwen weights are copied; drop the source model to lower peak
        # memory before the parameter report and save.
        del qwen_state_dict, qwen_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # 8. Ensure model is on correct device
        hybrid_model = hybrid_model.to(device)
        final_total = _print_param_counts("Final Hybrid Model Parameters", hybrid_model)
        _print_param_comparison(qwen_total_params, final_total)

        # 9. Save model if save_path is specified
        if save_path is not None:
            print(f"\n💾 Saving model to {save_path}")
            os.makedirs(save_path, exist_ok=True)
            hybrid_model.save_pretrained(save_path)
            hybrid_config.save_pretrained(save_path)
            # Ship the custom modeling/config sources next to the weights so the
            # auto_map entries resolve when loading from save_path.
            _copy_support_files(save_path, "./model")
            print(f"✅ Model saved to {save_path}")

        return hybrid_model

    except Exception as e:
        print(f"❌ Error during weight loading: {e}")
        raise


# Script entry: convert the Qwen2.5-VL-7B-Instruct checkpoint on GPU and save it.
# NOTE(review): this executes at import time too; replace the placeholder
# save_path before running, and consider an `if __name__ == "__main__":` guard.
hybrid_model = load_qwen_weights_to_hybrid_model(
    qwen_model_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="cuda",
    save_path="/path/to/your/output/hybrid_model",
)