from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
import re
import torch
import os

def infer_target_modules(model_name):
    """Return the LoRA target module names for the model family in *model_name*.

    The family is detected by case-insensitive substring search, and the
    families are tried in a fixed order (llama, phi, qwen/deepseek, olmo), so a
    name containing several keywords resolves to the first family that matches.

    NOTE(review): the returned names are only useful if they match module names
    in the loaded architecture -- confirm the "phi" and "olmo" lists against the
    checkpoints actually used.

    Args:
        model_name: Model identifier or path, e.g. a Hugging Face hub id.

    Returns:
        A new list of module-name strings.

    Raises:
        ValueError: If no known family keyword occurs in *model_name*.
    """
    lowered = model_name.lower()

    # Ordered (keywords, modules) pairs; the order defines match precedence.
    families = (
        (
            ("llama",),
            ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        ),
        (
            ("phi",),
            ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        ),
        (
            ("qwen", "deepseek"),
            ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        ),
        (
            ("olmo",),
            ["mlp.fc1", "mlp.fc2", "self_attn.q_proj", "self_attn.out_proj"],
        ),
    )

    for keywords, modules in families:
        if any(keyword in lowered for keyword in keywords):
            return list(modules)

    raise ValueError(f"Unsupported or unknown model type for: {model_name}")

def load_model_tokenizer_with_lora(model_name: str, quantize: bool = True):
    """Load a causal LM and its tokenizer, then wrap the model with LoRA adapters.

    Steps, in order: load the tokenizer (falling back to EOS as the pad token
    and padding on the left), load the model in bfloat16 with ``device_map="auto"``
    (optionally 4-bit NF4 quantized via bitsandbytes), enable gradient
    checkpointing, prepare the model for adapter training, and apply a LoRA
    config whose target modules come from ``infer_target_modules``.

    Side effect: overwrites the ``PYTORCH_CUDA_ALLOC_CONF`` environment
    variable for the current process.

    Args:
        model_name: Model identifier or path passed to ``from_pretrained``; also
            used to pick the LoRA target modules.
        quantize: When true, load the weights 4-bit quantized and run PEFT's
            k-bit training preparation. When false, the model stays in bfloat16.

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is the PEFT-wrapped model.

    Raises:
        ValueError: From ``infer_target_modules`` when the model family is not
            recognized. Note this happens after the model has been loaded.
    """
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS so padded batches can be built.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Configure model loading parameters
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "max_memory": {0: "75GiB"},  # Cap what is placed on device 0
        "offload_folder": "offload",  # Folder used if weights get offloaded to disk
    }

    # Add quantization config if needed
    if quantize:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
        )
        model_kwargs["quantization_config"] = bnb_config

    # Set environment variable for memory allocation (replaces any existing value)
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

    # Load model with proper initialization
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        **model_kwargs,
        low_cpu_mem_usage=True,
    )

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    if quantize:
        # Prepare for k-bit training (freezes base weights, upcasts the
        # remaining half-precision params, hooks up input grads).
        model = prepare_model_for_kbit_training(model)
    else:
        # Do NOT run the k-bit preparation on a non-quantized model: it would
        # upcast every bf16 weight to fp32, undoing torch_dtype above and
        # roughly doubling weight memory. With frozen base weights and gradient
        # checkpointing, the inputs still need requires_grad so gradients can
        # reach the adapter weights.
        model.enable_input_require_grads()

    # Get target modules
    target_modules = infer_target_modules(model_name)

    # Define LoRA config with smaller parameters
    lora_config = LoraConfig(
        r=8,             # Reduced from 16
        lora_alpha=16,   # Reduced from 32
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
        inference_mode=False,
    )

    # Apply LoRA
    model = get_peft_model(model, lora_config)

    return model, tokenizer
