from typing import Dict, Any, Tuple
from unsloth import FastLanguageModel, is_bfloat16_supported


def build_model_and_tokenizer(paths: Dict[str, Any], max_seq_length: int, chat_template_cfg: Dict[str, Any],
                              unsloth_cfg: Dict[str, Any], lora_cfg: Dict[str, Any]):
    """Load a base model through Unsloth and attach LoRA adapters to it.

    Args:
        paths: Mapping that must contain ``"base_model"``; its value is passed
            as ``model_name`` to ``FastLanguageModel.from_pretrained``.
        max_seq_length: Maximum sequence length forwarded to the loader.
        chat_template_cfg: Accepted for interface compatibility; not read here.
        unsloth_cfg: Optional keys ``load_in_4bit`` (default ``True``),
            ``dtype`` (default ``None``) and ``use_gradient_checkpointing``
            (default ``"unsloth"``).
        lora_cfg: Optional keys ``r`` (default 16), ``lora_alpha`` (default 16),
            ``lora_dropout`` (default 0.0), ``use_rslora`` (default ``True``) and
            ``target_modules`` (default: q/k/v/o/gate/up/down ``*_proj`` names).

    Returns:
        ``(model, tokenizer)``: ``model`` is the object returned by
        ``FastLanguageModel.get_peft_model``, and ``tokenizer`` is the one
        returned by ``from_pretrained``.
    """
    # NOTE(review): chat_template_cfg is never used in this function; confirm
    # whether a chat template is supposed to be applied to the tokenizer here.
    load_kwargs = {
        "model_name": paths["base_model"],
        "max_seq_length": max_seq_length,
        "load_in_4bit": unsloth_cfg.get("load_in_4bit", True),
        "dtype": unsloth_cfg.get("dtype", None),
    }
    base_model, tokenizer = FastLanguageModel.from_pretrained(**load_kwargs)

    # The adapter config is read only after the base model has loaded, which
    # matches the original call order.
    default_targets = ["q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "o_proj", "gate_proj"]
    peft_kwargs = {
        "r": lora_cfg.get("r", 16),
        "lora_alpha": lora_cfg.get("lora_alpha", 16),
        "lora_dropout": lora_cfg.get("lora_dropout", 0.0),
        "target_modules": lora_cfg.get("target_modules", default_targets),
        "use_rslora": lora_cfg.get("use_rslora", True),
        "use_gradient_checkpointing": unsloth_cfg.get("use_gradient_checkpointing", "unsloth"),
    }
    peft_model = FastLanguageModel.get_peft_model(base_model, **peft_kwargs)

    return peft_model, tokenizer


def fp16_bf16_flags() -> Tuple[bool, bool]:
    """Return the ``(fp16, bf16)`` mixed-precision flags.

    The bf16 flag is the result of ``is_bfloat16_supported()``. The fp16 flag
    is its logical negation, so fp16 is the fallback when bf16 is unavailable.
    """
    use_bf16 = is_bfloat16_supported()
    use_fp16 = not use_bf16
    return use_fp16, use_bf16


def save_model_and_tokenizer(model, tokenizer, save_dir: str):
    """Persist ``model`` and then ``tokenizer`` into the same directory.

    Each object's own ``save_pretrained(save_dir)`` is called, model first.

    Args:
        model: Any object exposing ``save_pretrained``.
        tokenizer: Any object exposing ``save_pretrained``.
        save_dir: Destination directory passed unchanged to both calls.
    """
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_dir)