from typing import Optional, Tuple, Union
import os
import torch
from accelerate import PartialState
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    GenerationConfig
)
from trl import AutoModelForCausalLMWithValueHead


def load_model_and_tokenizer(model_name,
                             peft_config=None,
                             wrap_value_head=False):
    """
    Load a 4-bit (NF4) quantized causal LM together with its tokenizer.

    Args:
        model_name (str): Model name or path passed to every ``from_pretrained`` call.
        peft_config (optional): PEFT configuration. When truthy, the loaded
            model is wrapped with ``get_peft_model``.
        wrap_value_head (bool): When true, the model is loaded through
            ``AutoModelForCausalLMWithValueHead`` instead of
            ``AutoModelForCausalLM``, and a ``GenerationConfig`` loaded from
            ``model_name`` is attached as ``model.generation_config``.

    Returns:
        tuple: ``(model, tokenizer)`` where ``model`` is the result of
        ``model.eval()``.
    """
    # The "" key addresses the whole model: everything is mapped to the device
    # index that equals this process's index.
    placement = {"": PartialState().process_index}
    quant_settings = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Both loader classes are called with the same arguments; only the class differs.
    loader_cls = AutoModelForCausalLMWithValueHead if wrap_value_head else AutoModelForCausalLM
    model = loader_cls.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map=placement,
        quantization_config=quant_settings,
    )

    # NOTE(review): when wrap_value_head is also set, the PEFT wrapper ends up
    # around the value-head wrapper. TRL's from_pretrained also accepts a
    # `peft_config` argument — confirm this ordering is what callers expect.
    if peft_config:
        model = get_peft_model(model, peft_config)

    if wrap_value_head:
        # Assigned after the optional PEFT wrapping, so the attribute lands on
        # whichever object is returned to the caller.
        model.generation_config = GenerationConfig.from_pretrained(model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    return model.eval(), tokenizer


def load_model_and_tokenizer_eval(model_name, peft_checkpoint=None, quantization="4bit"):
    """
    Load a causal LM and tokenizer for evaluation, optionally merging a LoRA checkpoint.

    When ``peft_checkpoint`` is given, the adapter is loaded on top of the base
    model, merged into it with ``merge_and_unload()``, and the merged model and
    tokenizer are saved to ``<peft_checkpoint>/merged_model``.

    Args:
        model_name (str): The base model name or path.
        peft_checkpoint (str, optional): Path to the LoRA checkpoint. When
            falsy, no adapter is loaded and nothing is written to disk.
        quantization (str, optional): ``"4bit"`` loads with a 4-bit NF4
            ``BitsAndBytesConfig`` and ``torch.bfloat16``. Any other value
            loads without a quantization config, using ``torch.float16``.

    Returns:
        tuple: ``(model, tokenizer, merged_model_path)``. ``model`` is the
        result of ``model.eval()``; ``merged_model_path`` is the directory the
        merged model was saved to, or ``None`` when no ``peft_checkpoint`` was
        given.
    """
    use_4bit = quantization == "4bit"
    bnb_config = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        if use_4bit
        else None
    )
    torch_dtype = torch.bfloat16 if use_4bit else torch.float16

    # The "" key addresses the whole model: everything is mapped to the device
    # index that equals this process's index.
    placement = {"": PartialState().process_index}
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map=placement,
        quantization_config=bnb_config,
        torch_dtype=torch_dtype,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Without an adapter there is nothing to merge or save.
    if not peft_checkpoint:
        return model.eval(), tokenizer, None

    print(f"Loading LoRA checkpoint from {peft_checkpoint}...")
    adapted = PeftModel.from_pretrained(model, peft_checkpoint)

    # Fold the LoRA weights into the base model for efficient inference.
    print("Merging LoRA weights for evaluation...")
    # NOTE(review): the output directory depends only on peft_checkpoint, so
    # every process calling this function writes to the same location —
    # confirm that is acceptable for multi-process runs.
    merged_model_path = os.path.join(peft_checkpoint, "merged_model")
    model = adapted.merge_and_unload()
    model.save_pretrained(merged_model_path)
    tokenizer.save_pretrained(merged_model_path)

    return model.eval(), tokenizer, merged_model_path

