"""
Custom PEFT utilities to support trainable_token_indices and proper embedding handling.
"""
from typing import Optional, Union, List, Dict
from peft import LoraConfig
from trl import ModelConfig
import torch
import torch.nn as nn
import logging


def get_peft_config_with_trainable_tokens(
    model_config: ModelConfig,
    trainable_token_indices: Optional[Union[List[int], Dict[str, List[int]]]] = None,
    train_lm_head: bool = False,
) -> "Optional[LoraConfig]":
    """
    Build a LoraConfig from a TRL ModelConfig, optionally with trainable_token_indices.

    The LoRA hyper-parameters (rank, alpha, dropout, target modules, ...) are read
    from ``model_config``. On top of that, ``trainable_token_indices`` is forwarded
    to ``LoraConfig`` so that selected token rows of the embedding layer(s) can be
    fine-tuned alongside the regular LoRA adapters.

    Args:
        model_config: TRL ModelConfig object. If ``model_config.use_peft`` is
            ``False`` no config is built.
        trainable_token_indices: Token indices to selectively fine-tune.
            - List[int]: applied to the ``embed_tokens`` layer (and to ``lm_head``
              when ``train_lm_head`` is set).
            - Dict[str, List[int]]: mapping of layer name to token indices; it is
              forwarded as given (shallow-copied so the caller's dict is not mutated).
            - None or empty: no token training is requested; regular LoRA only.
        train_lm_head: When True, also request token training on ``lm_head``.
            With a list, the same indices are used. With a dict that has no
            ``lm_head`` entry, the ``embed_tokens`` entry is mirrored; if there is
            no ``embed_tokens`` entry either, a warning is logged and the dict is
            left unchanged.

    Returns:
        The LoraConfig, or None when ``model_config.use_peft`` is False.
    """
    if model_config.use_peft is False:
        return None

    token_spec: Optional[Dict[str, List[int]]]
    if trainable_token_indices is None or len(trainable_token_indices) == 0:
        # Nothing to train token-wise. Pass None rather than {'embed_tokens': None}:
        # a non-None value asks PEFT to set up token training, which needs indices.
        token_spec = None
    elif isinstance(trainable_token_indices, dict):
        # Already in the {layer_name: indices} form; do not wrap it a second time.
        token_spec = dict(trainable_token_indices)
        if train_lm_head and 'lm_head' not in token_spec:
            if 'embed_tokens' in token_spec:
                token_spec['lm_head'] = token_spec['embed_tokens']
                logging.info(
                    "Training lm_head with trainable_token_indices: %s",
                    token_spec['lm_head'],
                )
            else:
                logging.warning(
                    "train_lm_head=True but trainable_token_indices has neither "
                    "'lm_head' nor 'embed_tokens'; lm_head token training not added."
                )
    else:
        token_spec = {'embed_tokens': trainable_token_indices}
        if train_lm_head:
            logging.info(
                "Training lm_head with trainable_token_indices: %s",
                trainable_token_indices,
            )
            token_spec['lm_head'] = trainable_token_indices

    peft_config = LoraConfig(
        task_type=getattr(model_config, 'lora_task_type', 'CAUSAL_LM'),
        r=model_config.lora_r,
        target_modules=model_config.lora_target_modules,  # Keep all modules
        lora_alpha=model_config.lora_alpha,
        lora_dropout=getattr(model_config, 'lora_dropout', 0.0),
        bias="none",
        use_rslora=getattr(model_config, 'use_rslora', False),
        use_dora=getattr(model_config, 'use_dora', False),
        modules_to_save=getattr(model_config, 'lora_modules_to_save', None),
        trainable_token_indices=token_spec,
    )

    return peft_config


def embeddings_are_tied(model) -> bool:
    """Report whether the model's input and output embeddings share weights.

    Two checks are made, in order:

    1. Declarative: ``model.config.tie_word_embeddings`` is truthy AND the model
       has a non-None ``_tied_weights_keys`` attribute -> True.
    2. Identity: the ``weight`` attributes of ``get_output_embeddings()`` and
       ``get_input_embeddings()`` are the very same object.

    The function is best-effort: any failure while probing the model yields
    False (or falls through to the next check) instead of raising.
    """
    # Check 1: trust the config flag only when the model also lists tied keys.
    try:
        config = getattr(model, 'config', None)
        if config is not None:
            declared_tied = getattr(config, 'tie_word_embeddings', False)
            tied_keys = getattr(model, '_tied_weights_keys', None)
            keys_present = tied_keys is not None
            if declared_tied and keys_present:
                return True
    except Exception:
        # Deliberately ignored; the identity check below is the fallback.
        pass

    # Check 2: compare the actual weight objects.
    try:
        out_layer = model.get_output_embeddings()
        in_layer = model.get_input_embeddings()
    except Exception:
        return False

    if any(layer is None for layer in (out_layer, in_layer)):
        return False

    try:
        out_weight = out_layer.weight
        in_weight = in_layer.weight
    except AttributeError:
        # One of the layers does not expose a ``weight`` attribute.
        return False
    return out_weight is in_weight
