import re

import torch.nn as nn
from transformers import PreTrainedModel
from peft import LoraConfig, get_peft_model, TaskType
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm


def init_transformer_block_weights(block, std=0.02):
    """
    Re-initialize Linear, LayerNorm and Qwen2 RMSNorm layers of ``block`` in place.

    - ``nn.Linear``: weight drawn from N(0, std**2); bias (if present) zeroed.
    - ``nn.LayerNorm``: bias zeroed and weight set to one, each only if present
      (``elementwise_affine=False`` / ``bias=False`` layers store them as ``None``).
    - ``Qwen2RMSNorm``: weight set to one; no bias is touched.

    Leaf modules of any other type are left untouched; their class names are
    collected and printed at the end.

    Args:
        block: The ``nn.Module`` to initialize. ``named_modules()`` is walked,
            so the block itself is included (under the empty name ``""``).
        std: Standard deviation of the normal distribution for Linear weights.
    """
    print(f"--- Initializing block: {block.__class__.__name__} with std={std} ---")
    initialized_modules = 0
    skipped_types = set()

    for name, module in block.named_modules():  # named_modules gives names for logging
        module_type_name = module.__class__.__name__

        if isinstance(module, nn.Linear):
            print(f"  Initializing Linear weights: {name} ({module_type_name})")
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                print(f"  Initializing Linear bias: {name} ({module_type_name})")
                nn.init.zeros_(module.bias)
            initialized_modules += 1
        elif isinstance(module, nn.LayerNorm):
            print(f"  Initializing LayerNorm: {name} ({module_type_name})")
            # Affine parameters are None when elementwise_affine=False (or bias=False),
            # so each one must be checked before being touched.
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            if module.weight is not None:
                nn.init.ones_(module.weight)
            initialized_modules += 1
        elif Qwen2RMSNorm is not None and isinstance(module, Qwen2RMSNorm):
            print(f"  Initializing RMSNorm weight: {name} ({module_type_name})")
            if getattr(module, 'weight', None) is not None:
                nn.init.ones_(module.weight)
            else:
                print(f"    WARNING: RMSNorm module {name} does not have 'weight' attribute or it's None.")
            # RMSNorm usually has no bias, so nothing else is initialized here.
            initialized_modules += 1
        elif name:  # unhandled non-root module; the root block itself ("") is not reported
            # Only record leaf modules; containers are expected to have no own weights.
            is_container = (
                isinstance(module, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
                or next(module.children(), None) is not None
            )
            if not is_container:
                skipped_types.add(module_type_name)

    print(f"--- Finished initializing block. Processed {initialized_modules} Linear/LayerNorm/RMSNorm modules. ---")
    if skipped_types:
        print(f"--- Skipped module types encountered: {', '.join(sorted(skipped_types))} ---")



def setup_lora(model: PreTrainedModel, lora_cfg: dict) -> PreTrainedModel:
    """
    Wrap ``model`` with a LoRA adapter whose target modules are selected by regex.

    ``lora_cfg`` holds ``LoraConfig`` keyword arguments plus an optional
    ``'target_modules_patterns'`` entry: a regex string or a list of regex
    strings. Every name from ``model.named_modules()`` that *fully* matches at
    least one pattern becomes a LoRA target. Invalid patterns are reported and
    skipped. ``task_type`` is always set to ``TaskType.CAUSAL_LM``, and
    ``target_modules`` is always taken from the pattern matches. Both override
    any same-named keys in ``lora_cfg``.

    The caller's ``lora_cfg`` dict is not modified.

    Args:
        model: Model to wrap with ``get_peft_model``.
        lora_cfg: LoRA configuration as described above.

    Returns:
        The model returned by ``get_peft_model``.
    """
    # Work on a shallow copy so popping the pattern key doesn't mutate the
    # caller's config (a second call with the same dict would lose the patterns).
    lora_cfg = dict(lora_cfg)
    target_modules_patterns = lora_cfg.pop('target_modules_patterns', [])
    if isinstance(target_modules_patterns, str):
        # A bare string would otherwise be iterated character by character.
        target_modules_patterns = [target_modules_patterns]

    all_module_names = [name for name, _ in model.named_modules()]
    final_target_modules = set()

    if target_modules_patterns:
        print("\n=== Applying LoRA to modules based on patterns: ===")
        for pattern_str in target_modules_patterns:
            try:
                pattern = re.compile(pattern_str)
            except re.error as e:
                print(f"Warning: Invalid regex pattern '{pattern_str}': {e}. Skipping this pattern.")
                continue

            matched = [module_name for module_name in all_module_names if pattern.fullmatch(module_name)]
            if matched:
                print(f"  Pattern '{pattern_str}' matched:")
                for module_name in matched:
                    print(f"    - {module_name}")
                final_target_modules.update(matched)
            else:
                print(f"  Pattern '{pattern_str}' did not match any module names.")
        print("==================================================")
    else:
        print("Warning: No target_modules_patterns specified in LoRA config.")

    if not final_target_modules:
        print("Warning: No modules were matched by the provided patterns for LoRA. target_modules will be None, so PEFT will fall back to its default target modules for this model type (or raise an error if it has none). Check your patterns.")

    lora_config_params = {
        **lora_cfg,
        'task_type': TaskType.CAUSAL_LM,
        # Sorted for a reproducible adapter config; None (not []) lets PEFT use
        # its architecture defaults instead of failing on an empty target list.
        'target_modules': sorted(final_target_modules) if final_target_modules else None,
    }

    lora_config = LoraConfig(**lora_config_params)

    model = get_peft_model(model, lora_config)
    print("\n=== LoRA Model Trainable Parameters (after get_peft_model) ===")
    model.print_trainable_parameters()
    print("===========================================================")
    return model

def set_trainable_parameters(model: PreTrainedModel) -> PreTrainedModel:
    """
    Unfreeze the custom heat modules of a PEFT-wrapped model.

    Marks every parameter of ``model.base_model.model.heat_embedding`` and then
    of ``model.base_model.model.visual.heat_block`` as trainable, prints the
    trainable-parameter summary via ``model.print_trainable_parameters()``, and
    returns the same ``model`` object.

    NOTE(review): relies on ``model`` exposing ``base_model.model`` and
    ``print_trainable_parameters`` (i.e. the object returned by
    ``get_peft_model``), despite the ``PreTrainedModel`` annotation.
    """
    def _unfreeze(submodule):
        for tensor in submodule.parameters():
            tensor.requires_grad = True

    # Resolve each submodule right before unfreezing it, in the same order as
    # before: heat_embedding first, then visual.heat_block.
    _unfreeze(model.base_model.model.heat_embedding)
    _unfreeze(model.base_model.model.visual.heat_block)

    model.print_trainable_parameters()
    return model
