"""
PEFT (Parameter-Efficient Fine-Tuning) Wrapper Module

This module provides wrapper functions for applying Parameter-Efficient Fine-Tuning techniques
to large language models, specifically implementing LoRA (Low-Rank Adaptation) for efficient
model fine-tuning with minimal computational overhead.

PEFT enables fine-tuning large pre-trained models by adding small trainable adapter layers
while keeping the original model parameters frozen, significantly reducing memory requirements
and training time while maintaining competitive performance.

Key Features:
    - LoRA (Low-Rank Adaptation) wrapper for LLaMA models
    - Registry-based PEFT technique selection
    - Configurable rank, alpha, and dropout parameters
    - Selective target module fine-tuning
    - Automatic trainable parameter reporting

Supported PEFT Methods:
    - LoRA: Low-rank matrix decomposition for attention layers

Dependencies:
    - peft: Hugging Face PEFT library for parameter-efficient fine-tuning
    - utils.registry: Custom registry system for PEFT wrapper selection

Technical Background:
    LoRA represents each weight update as the product of two small matrices,
    ΔW = BA, with B ∈ R^(d_out × r) and A ∈ R^(r × d_in), where the chosen rank
    r << min(d_in, d_out) (so rank(ΔW) ≤ r). This reduces the trainable parameters
    of each adapted layer from d_in × d_out to r × (d_in + d_out), enabling
    efficient fine-tuning.

Author: AI Model Development Team
License: MIT
"""

import logging

from utils.registry import PEFT_WRAPPER

from peft import LoraConfig, TaskType, get_peft_model


# Per-module logger; its name is this module's import path.
logger = logging.getLogger(name=__name__)


@PEFT_WRAPPER.register("lora")
def lora_wrapper(
    llama_model,
    lora_inference_mode,
    lora_rank,
    lora_alpha,
    lora_dropout=0.1,
    lora_target_modules=("q_proj", "k_proj", "v_proj", "o_proj"),
):
    """
    Apply a LoRA (Low-Rank Adaptation) adapter to a LLaMA model via PEFT.

    Builds a ``LoraConfig`` for causal language modeling and wraps
    ``llama_model`` with ``peft.get_peft_model``. This injects low-rank
    adapter layers into the modules selected by ``lora_target_modules``.
    The function is registered in ``PEFT_WRAPPER`` under the key ``"lora"``.

    The LoRA technique decomposes weight updates as ΔW = BA, where
    B ∈ R^(d×r) and A ∈ R^(r×k) with rank r << min(d, k). Only A and B are
    trained, which keeps fine-tuning computationally cheap and memory-friendly.

    Args:
        llama_model (LlamaForCausalLM): The pre-trained base model (a
            transformers model instance) to which the LoRA adapter is applied.
        lora_inference_mode (bool): Forwarded to
            ``LoraConfig(inference_mode=...)``.
            - True: PEFT marks the adapter parameters as non-trainable
              (frozen), e.g. for evaluating or serving an existing adapter.
              This does NOT merge the adapter into the base weights; use
              PEFT's ``merge_and_unload()`` for that.
            - False: the adapter parameters are trainable.
        lora_rank (int): The rank r of the low-rank decomposition. Higher ranks
            are more expressive but add trainable parameters. Typical values
            are 4-64.
        lora_alpha (float): Scaling factor for the LoRA update. In standard
            LoRA the update is scaled by ``lora_alpha / lora_rank``.
        lora_dropout (float, optional): Dropout probability applied on the LoRA
            branch for regularization. Defaults to 0.1.
        lora_target_modules (iterable of str | str | None, optional): The
            modules to adapt.
            - Iterable of module names (list, tuple, config container):
              copied into a plain list.
            - Single string: passed to PEFT unchanged; PEFT treats it as a
              regex that is full-matched against module names.
            - ``None``: lets PEFT choose default target modules for known
              architectures.
            Defaults to the attention projections
            ``("q_proj", "k_proj", "v_proj", "o_proj")``.

    Returns:
        PeftModel: The PEFT-wrapped model with LoRA adapters on the selected
        modules. It exposes PEFT's adapter-management methods
        (save/load, merge, etc.).

    Technical Details:
        - PEFT Type: LORA (Low-Rank Adaptation)
        - Task Type: CAUSAL_LM (Causal Language Modeling)
        - Default target modules: attention projections (Query, Key, Value,
          Output)
        - Trainable parameters are typically a small fraction of the total.
          The exact counts are printed by ``print_trainable_parameters()``.

    Example:
        >>> from transformers import LlamaForCausalLM
        >>> base_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> lora_model = lora_wrapper(
        ...     llama_model=base_model,
        ...     lora_inference_mode=False,
        ...     lora_rank=16,
        ...     lora_alpha=32,
        ...     lora_dropout=0.1,
        ...     lora_target_modules=["q_proj", "v_proj"]
        ... )
        >>> # Model now has LoRA adapters and can be fine-tuned efficiently

    Note:
        - Trainable-parameter statistics are printed to stdout by PEFT's
          ``print_trainable_parameters()``, not emitted through ``logger``.
        - Only the selected target modules receive LoRA adapters.
        - Base model parameters remain frozen during fine-tuning.
        - The wrapped model should be saved/loaded with PEFT-specific methods.
    """
    # Normalize the target-module spec.
    # - A bare string must not go through ``list()``, which would split it
    #   into single characters; PEFT accepts a string (as a regex pattern).
    # - ``None`` asks PEFT for architecture defaults.
    # Both are passed through unchanged. Any other iterable (the tuple
    # default, a list, a config container) is copied into a fresh plain list.
    if lora_target_modules is None or isinstance(lora_target_modules, str):
        target_modules = lora_target_modules
    else:
        target_modules = list(lora_target_modules)

    # Configure LoRA parameters with the requested settings.
    peft_config = LoraConfig(
        peft_type="LORA",                      # Use LoRA adaptation technique
        task_type="CAUSAL_LM",                 # Equivalent to TaskType.CAUSAL_LM
        inference_mode=lora_inference_mode,    # True freezes adapter params
        r=lora_rank,                           # Rank of low-rank decomposition
        lora_alpha=lora_alpha,                 # Scaling numerator (alpha / r)
        lora_dropout=lora_dropout,             # Dropout on the LoRA branch
        target_modules=target_modules,         # Modules to apply LoRA to
    )

    # Inject LoRA adapters into the base model and wrap it as a PeftModel.
    wrapped_model = get_peft_model(llama_model, peft_config)

    # Print trainable vs. total parameter counts (to stdout) for verification.
    wrapped_model.print_trainable_parameters()

    logger.info('LoRA Training configuration applied successfully')

    return wrapped_model