import torch
from peft import prepare_model_for_kbit_training
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
)

from egu.utils.utils import load_yaml


class HFModel:
    """Wrap a Hugging Face causal language model together with its tokenizer.

    The per-model settings are read from ``{config_path}/{model_name}.yaml``.
    Keys read by this class: ``hf_name`` (hub identifier, required unless
    ``model_path`` is given for the weights; always used for the tokenizer),
    ``load_in_4bit`` and ``load_in_8bit`` (optional, default ``False``).

    Attributes set by ``__init__``:
        model_name: the name passed by the caller.
        model_config: the mapping returned by ``load_yaml``.
        model: the loaded ``AutoModelForCausalLM`` instance.
        tokenizer: the loaded tokenizer, right-padded, with a pad token.
        generation_config: the generation config installed on ``model``.
        device: ``model.device`` as reported right after loading.
    """

    # Model-name substrings for which ``trust_remote_code`` is switched off.
    _NO_REMOTE_CODE_TAGS = ("c4ai-command-r-v01", "falcon", "phi-1_5")

    def __init__(
        self,
        model_name,
        model_path=None,
        config_path="./config",
        generation_config=None,
        ref=False,
    ):
        """Load the model and tokenizer described by ``model_name``.

        Args:
            model_name: Base name of the YAML config file; also matched
                (case-insensitively) against a few known model families.
            model_path: Optional location to load the weights from instead of
                the config's ``hf_name``. The tokenizer is still loaded from
                ``hf_name``.
            config_path: Directory containing ``{model_name}.yaml``.
            generation_config: Optional ``GenerationConfig`` to install on the
                model; defaults to greedy decoding with ``use_cache=True``.
            ref: When true the model is always loaded quantized: 4-bit if the
                config requests it, otherwise 8-bit. Presumably marks a
                reference model -- confirm against callers.
        """
        self.model_name = model_name
        self.model_config = load_yaml(f"{config_path}/{model_name}.yaml")
        use_kbit = self.model_config.get(
            "load_in_4bit", False
        ) or self.model_config.get("load_in_8bit", False)

        # Only load_in_4bit / load_in_8bit are passed; every other
        # bitsandbytes option is left at the library default.
        quantization_kwargs = self._quantization_kwargs(self.model_config, ref)
        quantization_config = (
            None
            if quantization_kwargs is None
            else BitsAndBytesConfig(**quantization_kwargs)
        )

        lowered_name = model_name.lower()
        model_args = {
            "quantization_config": quantization_config,
            # No explicit dtype when the config itself asks for k-bit loading.
            "torch_dtype": None if use_kbit else torch.float16,
            "trust_remote_code": not any(
                tag in lowered_name for tag in self._NO_REMOTE_CODE_TAGS
            ),
        }

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path if model_path else self.model_config["hf_name"], **model_args
        )

        if use_kbit:
            self.model = prepare_model_for_kbit_training(
                self.model, use_gradient_checkpointing=True
            )
            # KV caching is switched off on the model config for this path.
            self.model.config.use_cache = False
            if hasattr(self.model, "enable_input_require_grads"):
                self.model.enable_input_require_grads()

        num_parameters = sum(p.numel() for p in self.model.parameters())
        print(f"Number of parameters: {num_parameters}")

        # OpenELM models are paired with the Llama-2 tokenizer here.
        tokenizer_name = (
            "meta-llama/Llama-2-7b-hf"
            if "openelm" in lowered_name
            else self.model_config["hf_name"]
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if self.tokenizer.pad_token_id is None:
            # Fall back to EOS so batched padding works.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "right"

        if generation_config is None:
            generation_config = GenerationConfig(do_sample=False, use_cache=True)
        self.model.generation_config = generation_config
        self.generation_config = self.model.generation_config
        self.device = self.model.device

    @staticmethod
    def _quantization_kwargs(model_config, ref):
        """Return the ``BitsAndBytesConfig`` flags to use, or ``None`` for no quantization.

        For ``ref`` models quantization is always on: the config's 4-bit
        request is honoured, and 8-bit is used only when 4-bit is not
        requested, so the two mutually exclusive flags are never both set.
        """
        load_in_4bit = bool(model_config.get("load_in_4bit", False))
        load_in_8bit = bool(model_config.get("load_in_8bit", False))
        if ref:
            load_in_8bit = not load_in_4bit
        elif not (load_in_4bit or load_in_8bit):
            return None
        return {"load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}

    def __call__(self, *args, **kwargs):
        """Run the wrapped model, dropping the "prompts" and "answers" kwargs."""
        for key in ("prompts", "answers"):
            kwargs.pop(key, None)
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call ``model.generate``, dropping the "prompts" kwarg if present."""
        kwargs.pop("prompts", None)
        return self.model.generate(*args, **kwargs)
