import logging
import os
from pathlib import Path
from typing import Literal, Callable, Optional, Mapping, Any

import numpy as np
import torch
import transformers
from huggingface_hub import hf_hub_download
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from bof4.quantization import Quantizer, load_from_file, patch_bnb
from bof4.util import (
    smart_tokenizer_and_embedding_resize,
    fix_untrained_tokens,
    prepare_model_for_kbit_training,
)

_logger = logging.getLogger(__name__)


def _is_linear_layer(layer_name):
    target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    for target in target_modules:
        if target in layer_name:
            return True
    return False


def linear_layers(model):
    """Lazily yield ``(name, module)`` for every submodule whose name is a linear projection.

    Selection is delegated to ``_is_linear_layer``; iteration order follows
    ``model.named_modules()``.
    """
    yield from (
        (qualified_name, submodule)
        for qualified_name, submodule in model.named_modules()
        if _is_linear_layer(qualified_name)
    )


@torch.inference_mode()
def quantize_linear_layers(
    model,
    quantizer: Quantizer,
    predicate: Callable[[str], bool] = lambda _: True,
    quantization_device: Optional[str] = "cuda",
):
    """Quantize, in place, the weights of the linear projections of ``model``.

    Modules are selected by ``linear_layers``; ``predicate`` further restricts
    which of them are quantized.

    Args:
        model: Module whose linear projection weights are overwritten.
        quantizer: Object whose ``quantize_inplace`` overwrites a tensor with
            its quantized values.
        predicate: Called with each selected module name; the module is
            skipped when it returns a falsy value.
        quantization_device: Device on which the quantization is computed.
            The weight is copied there, quantized, and the result is written
            back into the original parameter, which keeps its device. ``None``
            quantizes the weight wherever it currently lives.
    """
    torch.cuda.empty_cache()
    quantizer_name = getattr(quantizer, "name", "unnamed quantizer")
    for name, module in linear_layers(model):
        if not predicate(name):
            _logger.info(
                "Skipping quantization of %s, name does not meet the provided condition",
                name,
            )
            continue
        _logger.info("Quantizing %s, with %s", name, quantizer_name)

        weight = module.weight
        if quantization_device is None:
            quantizer.quantize_inplace(weight)
            continue

        # Tensor.to() is NOT in-place: it returns a copy on the target device
        # (or ``weight`` itself when it is already there). Quantize that tensor
        # and, if it is a copy, write the result back into the parameter so the
        # module keeps its original device.
        staging = weight.to(quantization_device)
        quantizer.quantize_inplace(staging)
        if staging is not weight:
            weight.copy_(staging)
        del staging  # release the device copy before the next layer


def is_adapter_model(hf_id_or_path: str) -> bool:
    """Return whether ``hf_id_or_path`` points to a PEFT adapter checkpoint.

    A local directory is an adapter iff it contains ``adapter_config.json``.
    Any other value is treated as a Hugging Face Hub repo id, which is an
    adapter if ``adapter_config.json`` can be fetched from it with
    ``hf_hub_download``. Any failure of that lookup yields False, for example
    a missing file or repo, an invalid repo id, or no network.
    """
    if os.path.isdir(hf_id_or_path):
        # A local checkpoint directory is not a repo id; do not query the Hub.
        return (Path(hf_id_or_path) / "adapter_config.json").exists()
    try:
        hf_hub_download(repo_id=hf_id_or_path, filename="adapter_config.json")
    except Exception as err:  # deliberate best effort: any failure -> not an adapter
        _logger.debug(
            "No adapter_config.json found for %s (%s: %s)",
            hf_id_or_path,
            type(err).__name__,
            err,
        )
        return False
    return True


QUANTIZER_CONFIG_FILENAME = "quantizer.yaml"
DEFAULT_PAD_TOKEN = "[PAD]"


def get_quantizer_from_config(model_path):
    """Load the quantizer saved alongside a checkpoint, if one exists.

    Checks the local filesystem for ``QUANTIZER_CONFIG_FILENAME`` inside
    ``model_path``. If the file exists, returns ``load_from_file`` applied to
    it; otherwise returns None.
    """
    config_file = Path(model_path).joinpath(QUANTIZER_CONFIG_FILENAME)
    if not config_file.exists():
        return None
    return load_from_file(config_file)

def load_model_and_tokenizer(
    model_id_or_path: str,
    quantizer: Optional[Quantizer] = None,
    model_dtype=torch.bfloat16,
    overwrite_hf_model_kwargs: Optional[Mapping[str, Any]] = None,
    overwrite_bnb_kwargs: Optional[Mapping[str, Any]] = None,
):
    """Load a causal LM (plain or PEFT adapter) together with its tokenizer.

    The model class is chosen with ``is_adapter_model``. A quantizer can be
    passed explicitly or found in ``QUANTIZER_CONFIG_FILENAME`` next to the
    checkpoint; the explicit argument wins if both exist. When a quantizer is
    used, it is moved to the available device, installed with ``patch_bnb``,
    and the model is loaded in 4-bit through ``BitsAndBytesConfig``.

    Args:
        model_id_or_path: Hub repo id or local checkpoint directory.
        quantizer: Optional quantizer overriding the one stored with the model.
        model_dtype: dtype used for ``torch_dtype``, the bnb compute dtype and
            the LoRA layers.
        overwrite_hf_model_kwargs: Entries overriding the default
            ``from_pretrained`` keyword arguments.
        overwrite_bnb_kwargs: Entries overriding the default
            ``BitsAndBytesConfig`` keyword arguments.

    Returns:
        ``(model, tokenizer)``.
    """
    model_cls = (
        AutoPeftModelForCausalLM
        if is_adapter_model(model_id_or_path)
        else AutoModelForCausalLM
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Read the stored quantizer config once and reuse the result.
    quantizer_from_config = get_quantizer_from_config(model_id_or_path)
    if quantizer_from_config is not None:
        if quantizer is not None:
            _logger.warning(
                f"Quantizer was specified twice (once by {QUANTIZER_CONFIG_FILENAME} and once as an argument). "
                "Make sure that the same quantizer is used for finetuning and evaluation. "
                "Using quantizer from function argument."
            )
        else:
            quantizer = quantizer_from_config
            _logger.info(f"Using quantizer from model config: {quantizer}")

    if quantizer is not None:
        quantizer.to(device)
        patch_bnb(quantizer)

        bnb_args = dict(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=model_dtype,
            llm_int8_enable_fp32_cpu_offload=False,
            bnb_4bit_use_double_quant=False,
        )
        # Dict unpacking accepts any Mapping, unlike ``dict | mapping``.
        bnb_config = BitsAndBytesConfig(**{**bnb_args, **(overwrite_bnb_kwargs or {})})
    else:
        bnb_config = None

    model_args = dict(
        quantization_config=bnb_config,
        device_map="auto",
        attn_implementation="flash_attention_2",
        torch_dtype=model_dtype,
    )

    model = model_cls.from_pretrained(
        model_id_or_path,
        **{**model_args, **(overwrite_hf_model_kwargs or {})},
    )
    model.config.torch_dtype = model_dtype

    # Make sure LoRA layers use the requested dtype.
    for module in model.modules():
        if isinstance(module, LoraLayer):
            module.to(model_dtype)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_id_or_path, use_fast=False
    )

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    # Fall back to the model config for missing special token ids.
    if tokenizer.eos_token_id is None:
        _logger.warning(f"tokenizer eos_token is not set, setting to eos token from model config {model.config.eos_token_id}")
        tokenizer.eos_token_id = model.config.eos_token_id
    if tokenizer.bos_token_id is None:
        _logger.warning(f"tokenizer bos_token is not set, setting to bos token from model config {model.config.bos_token_id}")
        tokenizer.bos_token_id = model.config.bos_token_id

    return model, tokenizer


def prepare_model_for_finetuning(
    model,
    quantizer: Optional[Quantizer] = None,
    lora_config: Optional[LoraConfig] = None,
    model_dtype: torch.dtype = torch.bfloat16,
):
    """Prepare ``model`` for (optionally quantized) LoRA fine-tuning.

    Steps:
        1. ``fix_untrained_tokens`` is applied to the model.
        2. If ``quantizer`` is given, ``prepare_model_for_kbit_training`` is applied.
        3. If ``lora_config`` is given, LoRA adapters are added with ``get_peft_model``.
        4. dtypes are normalised:
           - LoRA layers are cast to ``model_dtype``.
           - Modules whose name contains "norm" are cast to float32.
           - ``lm_head`` / ``embed_tokens`` modules with float32 weights are cast
             to ``model_dtype``.

    Args:
        model: The model to prepare.
        quantizer: Only its presence matters here; it enables k-bit preparation.
        lora_config: LoRA configuration; when None, no adapters are added.
        model_dtype: Target dtype for LoRA layers and float32 lm_head/embeddings.
            Defaults to bfloat16, which was previously hard-coded.

    Returns:
        The prepared (possibly PEFT-wrapped) model.
    """
    fix_untrained_tokens(model)

    if quantizer is not None:
        model = prepare_model_for_kbit_training(model)
        # Check the argument, not the imported ``LoraConfig`` class (never None).
        if lora_config is None:
            _logger.warning(
                "Model is prepared for quantized fine-tuning but no trainable LoRA-parameters are added. Make sure that the model is trainable"
            )

    if lora_config is not None:
        model = get_peft_model(model, lora_config)

    # Ensure everything has the correct datatype
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.to(model_dtype)
        # Substring match: any module with "norm" in its qualified name is upcast.
        if "norm" in name:
            module.to(torch.float32)
        if "lm_head" in name or "embed_tokens" in name:
            weight = getattr(module, "weight", None)
            if weight is not None and weight.dtype == torch.float32:
                module.to(model_dtype)

    return model
