from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from torch import float16
import numpy as np
from src.logger import logger


def count_trainable_parameters(model):
    """Return the number of trainable scalar weights in *model*.

    Only parameters yielded by ``model.parameters()`` whose ``requires_grad``
    flag is truthy are counted.

    Args:
        model: Object exposing ``parameters()`` that yields tensors providing
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: Total element count over all trainable parameters; ``0`` when
        there are none.
    """
    # numel() is an exact Python int for every shape. np.prod(p.size()) gives
    # the float 1.0 for 0-dim parameters (empty shape), which silently turned
    # the whole total into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)



def get_model(
    model_name,
    quant=None
):
    """Load a causal language model and its tokenizer.

    The tokenizer's pad token is set to its EOS token and padding is applied
    on the right.

    Args:
        model_name: Identifier or path passed to ``from_pretrained`` for both
            the tokenizer and the model.
        quant: ``"4bit"`` loads with NF4 double quantization and fp16 compute,
            ``"8bit"`` loads with 8-bit quantization; any other value
            (including ``None``) loads plain fp16 weights and moves them to
            the GPU.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    if quant == "4bit":
        logger.warning("Loading model in 4 bit")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=float16,
        )
    elif quant == "8bit":
        logger.warning("Loading model in 8 bit")
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=0.0,
        )
    else:
        logger.warning("Loading model in fp16")
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # checkpoint; only use with model sources you trust.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=float16,
        ).cuda()
        return model, tokenizer

    # Do not call .cuda() on a bitsandbytes-quantized model: Transformers
    # raises ValueError for it (always for 8-bit, and for 4-bit with older
    # bitsandbytes) because the quantized loader has already placed the
    # weights on the GPU and they must not be moved or re-cast afterwards.
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint; only use with model sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    return model, tokenizer
        
