import inspect
import warnings
import math
from collections import defaultdict

import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors import safe_open


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict,
    tokenizer,
    model,
):
    """Add special tokens and initialise their embeddings to the mean embedding.

    The tokenizer is extended with ``special_tokens_dict``, the model's token
    embeddings are resized to ``len(tokenizer)``, and the rows belonging to the
    newly added tokens are overwritten with the mean of all previously existing
    rows. This is done for the input embedding matrix and, when the model
    exposes one, the output embedding matrix.

    Args:
        special_tokens_dict: Mapping passed to ``tokenizer.add_special_tokens``.
        tokenizer: Object providing ``add_special_tokens`` and ``__len__``.
        model: Object providing ``resize_token_embeddings``,
            ``get_input_embeddings`` and ``get_output_embeddings``.

    Returns:
        None. ``tokenizer`` and ``model`` are modified in place.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if num_new_tokens <= 0:
        # Nothing was added, so there are no new rows to initialise.
        return

    with torch.no_grad():
        for embeddings in (model.get_input_embeddings(), model.get_output_embeddings()):
            if embeddings is None:
                # The model has no such embedding matrix (e.g. no output head);
                # skip it instead of failing on ``None.weight``.
                continue
            weights = embeddings.weight.data
            # The mean is taken over the old rows only; the new rows sit at the end.
            weights[-num_new_tokens:] = weights[:-num_new_tokens].mean(
                dim=0, keepdim=True
            )

def get_total_param_memory_of_quantized_model(model, block_size):
    """Return the parameter memory of ``model`` in bytes, with a per-block surcharge.

    Every parameter contributes ``nelement() * element_size()`` bytes. Parameters
    stored as ``torch.uint8`` additionally contribute 4 bytes per block, where
    the number of blocks is ``ceil(2 * nelement() / block_size)``.
    NOTE(review): the factor 2 and the 4 extra bytes presumably model two packed
    4-bit values per byte and one fp32 quantization constant per block — confirm.

    Args:
        model: Object exposing ``named_parameters()``.
        block_size: Number of (unpacked) values per block.

    Returns:
        Total size in bytes as an ``int``.
    """

    def _param_bytes(param):
        element_count = param.nelement()
        size = element_count * param.element_size()
        if param.dtype == torch.uint8:
            size += math.ceil(element_count * 2 / block_size) * 4
        return size

    return sum(_param_bytes(param) for _name, param in model.named_parameters())

def print_parameter_info(model):
    """Print per-parameter details and aggregate statistics for ``model``.

    For every named parameter one line with its name, dtype, parameter count and
    ``requires_grad`` flag is printed. Afterwards the total count, the trainable
    count, the count per dtype (each with its fraction of the total), the total
    parameter memory in bytes, and the memory of parameters whose name contains
    ``"lora"`` are printed.

    ``torch.uint8`` parameters are counted twice, following the NF4 convention
    noted below. A model without parameters reports fractions of ``0.0`` instead
    of raising ``ZeroDivisionError``.

    Args:
        model: Object exposing ``named_parameters()`` and ``parameters()``.

    Returns:
        None. All output goes to stdout.
    """
    dtypes = defaultdict(int)
    trainable = 0
    for name, param in model.named_parameters():
        n = param.numel()
        if param.dtype == torch.uint8:
            # NF4 datatype stores 2 parameters per torch.uint8
            n *= 2
        dtypes[param.dtype] += n
        if param.requires_grad:
            trainable += n
        print(
            f"{name}, datatype: {param.dtype}, #params: {n}, trainable: {param.requires_grad}"
        )
    total = sum(dtypes.values())

    def _fraction(count):
        # Guard against an empty model, where ``total`` is 0.
        return count / total if total else 0.0

    print(f"Total number of parameters: {total}")
    print(f"Trainable: {trainable} ({_fraction(trainable)})")
    for k, v in dtypes.items():
        print(f"{k}: {v} ({_fraction(v)})")
    print(
        "Params mem:",
        sum(param.nelement() * param.element_size() for param in model.parameters()),
    )
    print(
        "Adapter params mem:",
        sum(
            param.nelement() * param.element_size()
            for name, param in model.named_parameters()
            if "lora" in name
        ),
    )


@torch.inference_mode()
def fix_untrained_tokens(model, eps=1e-16):
    """
    Adapted from https://github.com/unslothai/unsloth/blob/main/unsloth/tokenizer_utils.py#L556
    Llama-3 for example has untrained vectors in the base model.
    These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
    We reset them to the mean of the rest of the tokens

    A token counts as untrained when the maximum entry of its input embedding
    row is ``<= eps``. The same row indices are rewritten in the output
    embedding (lm head) matrix. Both matrices are modified in place.

    Args:
        model: Object exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``, each returning a module with a
            ``weight`` tensor.
        eps: Threshold on the per-row maximum below which a token is treated
            as untrained.

    Returns:
        Tuple ``(mean_embedding, mean_lm_head)``: the mean rows over the trained
        tokens, cast to the dtype of the respective matrix. If no token is
        trained, both means are all zeros (rather than NaN).
    """
    embedding_matrix = model.get_input_embeddings().weight.data
    lm_head_matrix = model.get_output_embeddings().weight.data

    # Get untrained tokens
    # NOTE(review): ``amax`` (not the absolute maximum) is used, so a row whose
    # entries are all negative would also be flagged; this mirrors the upstream
    # implementation — confirm it is intended.
    indicator_untrained = torch.amax(embedding_matrix, dim=1) <= eps
    where_untrained = torch.where(indicator_untrained)[0]
    n_untrained = where_untrained.shape[0]
    n_trained = embedding_matrix.shape[0] - n_untrained
    if n_untrained != 0:
        print(
            f"The used model has {n_untrained} untrained tokens.\n"
            "We shall set them to the mean of the other trained tokens."
        )

    # First set untrained to all 0s - sometimes it's not! 1e-23 for bfloat16
    embedding_matrix[where_untrained] = 0
    lm_head_matrix[where_untrained] = 0

    # Find sum (accumulated in float32; untrained rows contribute 0 after the reset)
    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, dim=0)
    sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, dim=0)

    # Find correct average by dividing by the number of trained tokens.
    # Clamp the divisor so a model with no trained rows yields zeros, not 0/0 = NaN.
    divisor = max(n_trained, 1)
    mean_embedding = (sum_embedding / divisor).to(embedding_matrix.dtype)
    mean_lm_head = (sum_lm_head / divisor).to(lm_head_matrix.dtype)

    # Set them to the mean
    embedding_matrix[where_untrained] = mean_embedding
    lm_head_matrix[where_untrained] = mean_lm_head

    return mean_embedding, mean_lm_head


def prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=None,
):
    r"""
    Modified function from `peft` to avoid temporary upcasting of lm_head and embed_tokens layers
     during model preparation causing OOM error.

    This method wraps the protocol for preparing a model before running a training. It:
        1- freezes every parameter of the model (``requires_grad = False``);
        2- unless the model reports a gptq/aqlm/eetq/hqq quantization method, casts all fp16/bf16 parameters to
           fp32, except parameters whose class is named ``Params4bit`` and parameters whose name contains
           ``lm_head`` or ``embed_tokens`` (these keep their dtype);
        3- when the model is quantized (8-bit/4-bit flag or one of the methods above) and
           ``use_gradient_checkpointing`` is set, makes the embedding outputs require grads (only needed for
           reentrant checkpointing) and enables gradient checkpointing.

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`
        use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
            Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of
            `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method.
            Note this is only available in the latest transformers versions (> 4.34.1).

    Returns:
        The same ``model`` object, modified in place.
    """
    loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(
        model, "is_loaded_in_4bit", False
    )
    quantization_method = getattr(model, "quantization_method", None)
    is_gptq_quantized = quantization_method == "gptq"
    is_aqlm_quantized = quantization_method == "aqlm"
    is_eetq_quantized = quantization_method == "eetq"
    is_hqq_quantized = quantization_method == "hqq" or getattr(
        model, "hqq_quantized", False
    )
    uses_listed_quantizer = (
        is_gptq_quantized
        or is_aqlm_quantized
        or is_eetq_quantized
        or is_hqq_quantized
    )

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {}

    # freeze base model's layers
    for param in model.parameters():
        param.requires_grad = False

    ####################### Modification start ################################
    # Parameters that must keep their dtype; a set gives O(1) membership tests
    # inside the loop over all parameters below.
    fixed_datatype_params = {
        id(param)
        for name, param in model.named_parameters()
        if "lm_head" in name or "embed_tokens" in name
    }

    if not uses_listed_quantizer:
        # cast all non INT8 parameters to fp32
        for param in model.parameters():
            if (
                param.dtype in (torch.float16, torch.bfloat16)
                and param.__class__.__name__ != "Params4bit"
                and id(param) not in fixed_datatype_params
            ):
                param.data = param.data.to(torch.float32)
    ####################### Modification end ################################

    if (loaded_in_kbit or uses_listed_quantizer) and use_gradient_checkpointing:
        # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
        if gradient_checkpointing_kwargs.get("use_reentrant", True):
            # For backward compatibility
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, inputs, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad
                )

        # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
        _supports_gc_kwargs = (
            "gradient_checkpointing_kwargs"
            in inspect.signature(model.gradient_checkpointing_enable).parameters
        )

        if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
            warnings.warn(
                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
                " if you want to use that feature, please upgrade to the latest version of transformers.",
                FutureWarning,
            )

        gc_enable_kwargs = (
            {}
            if not _supports_gc_kwargs
            else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
        )

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable(**gc_enable_kwargs)
    return model


def load_layer_from_safetensors(repo_id: str, layer_name: str, framework: str = "pt", device="cuda"):
    """
    Fetch a single named tensor from the safetensors files of a Hugging Face repository.

    The repository's file list is scanned for ``.safetensors`` files. Each one is
    downloaded via ``hf_hub_download`` and opened in turn, and the first file whose
    keys contain ``layer_name`` supplies the returned tensor.

    Args:
        repo_id (str): The Hugging Face model repository ID (e.g., "username/model-name").
        layer_name (str): The key/name of the layer/tensor to load.
        framework (str): Framework identifier forwarded to ``safe_open`` ("pt" for PyTorch, "tf" for TensorFlow, etc.)
        device: Device forwarded to ``safe_open``; defaults to "cuda".

    Returns:
        The loaded tensor corresponding to layer_name.

    Raises:
        ValueError: If the specified layer is not found in any of the safetensors files.
    """
    repo_files = HfApi().list_repo_files(repo_id)
    shard_names = [name for name in repo_files if name.endswith(".safetensors")]

    for shard_name in shard_names:
        local_path = hf_hub_download(repo_id=repo_id, filename=shard_name)
        with safe_open(local_path, framework=framework, device=device) as shard:
            if layer_name not in shard.keys():
                # Not in this file; try the next one.
                continue
            return shard.get_tensor(layer_name)

    raise ValueError(f"Layer '{layer_name}' was not found in any of the safetensors files in {repo_id}.")

