import os
import importlib

from LLMProxy import utils
from LLMProxy.option import ModelArg
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
)
from .model import Transformer
from .model_utils import register_model_forward_function

import torch


# Module-level logger, obtained from utils.get_logger under the name "MODEL".
logger = utils.get_logger("MODEL")


# Registry of model initialization functions, filled in by the
# ``register_model_initialization`` decorator below.
# Layout: {model_name: {initialization: func}}.
MODEL_INITIALIZATION_CONFIGURATION = {}


def register_model_initialization(model_name: str, initialization: str):
    """
    Register model initialization for using different LLMs.

    When using HF checkpoint, please guarantee model_name is equal to HF checkpoint name

    Args:
        model_name: Key of the model the initialization belongs to.
        initialization: Name of the initialization scheme for that model.

    Returns:
        A decorator that stores the decorated function in
        ``MODEL_INITIALIZATION_CONFIGURATION[model_name][initialization]``
        and returns that function unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the same
            ``initialization`` is already registered for ``model_name``.
    """
    def register_model_initialization_func(func):
        registered = MODEL_INITIALIZATION_CONFIGURATION.setdefault(model_name, {})
        if initialization in registered:
            raise ValueError(f"Cannot register duplicate {initialization} initialization for {model_name}")
        registered[initialization] = func
        # Hand the function back so the decorated name stays bound to it
        # instead of being rebound to None.
        return func

    return register_model_initialization_func


def build_model_from_hf_checkpoint(args: ModelArg):
    """Load a causal LM from a pretrained Hugging Face checkpoint.

    Args:
        args: Model arguments; ``args.model_name_or_path`` selects the
            checkpoint and ``args.auth`` is passed as the ``token`` argument.

    Returns:
        The object returned by ``AutoModelForCausalLM.from_pretrained``,
        requested with ``device_map='auto'`` and ``torch.bfloat16`` as
        ``torch_dtype``.
    """
    checkpoint = args.model_name_or_path
    load_options = {
        "token": args.auth,
        "device_map": 'auto',
        "torch_dtype": torch.bfloat16,
    }
    return AutoModelForCausalLM.from_pretrained(checkpoint, **load_options)


def build_model_from_hf_config(args: ModelArg):
    """Build a causal LM from a Hugging Face config with overridden sizes.

    The config is fetched for ``args.model_name_or_path`` (with ``args.auth``
    passed as ``token``); four of its fields are then replaced by the values
    of the same name on ``args`` before the model is instantiated through
    ``AutoModelForCausalLM.from_config``.

    Args:
        args: Model arguments providing ``model_name_or_path``, ``auth``,
            ``hidden_size``, ``intermediate_size``, ``num_hidden_layers``
            and ``num_attention_heads``.

    Returns:
        The model object created from the modified config.
    """
    hf_config = AutoConfig.from_pretrained(args.model_name_or_path, token=args.auth)

    # Copy the architecture overrides from args onto the config, one by one.
    overridden_fields = (
        "hidden_size",
        "intermediate_size",
        "num_hidden_layers",
        "num_attention_heads",
    )
    for field_name in overridden_fields:
        setattr(hf_config, field_name, getattr(args, field_name))

    return AutoModelForCausalLM.from_config(hf_config)


def build_model(args: ModelArg, **kwargs):
    """Build the model from a HF checkpoint and optionally re-initialize it.

    The model is loaded with ``build_model_from_hf_checkpoint``. When
    ``args.model_init`` is not ``None``, the initialization function
    registered under the model's own name (``model.config._name_or_path``)
    and ``args.model_init`` is applied to the loaded model, and its return
    value becomes the model.

    Args:
        args: Model arguments; ``args.model_init`` selects the registered
            initialization, or ``None`` to keep the loaded model as is.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The loaded (and possibly re-initialized) model.

    Raises:
        ValueError: If ``args.model_init`` is set but no initialization is
            registered for the model name, or that particular initialization
            is not registered for it.
    """
    model = build_model_from_hf_checkpoint(args)
    # Look up initializers under the name of the checkpoint that was actually
    # loaded, rather than a fixed model name.
    model_name = model.config._name_or_path

    if args.model_init is not None:
        if model_name not in MODEL_INITIALIZATION_CONFIGURATION:
            raise ValueError(f"Does not support this {model_name} initialization")
        initializations = MODEL_INITIALIZATION_CONFIGURATION[model_name]
        if args.model_init not in initializations:
            # Report an unknown initialization explicitly instead of letting
            # a bare KeyError escape from the registry lookup.
            raise ValueError(
                f"Does not support {args.model_init} initialization for {model_name}"
            )
        model = initializations[args.model_init](model)

    logger.info(f"{model}")
    logger.info(f"The total number of model parameters is {model.num_parameters()}")
    logger.info(f"The trainable parameter of model is {utils.count_trainable_parameters(model)}")
    return model


def import_model_configuration():
    """Import the ``LLMProxy.models.model_initialization`` module.

    The import is done for its side effects only; the module object is not
    returned or stored by this function.
    """
    importlib.import_module("LLMProxy.models.model_initialization")


# Runs at import time, so importing this module also imports
# LLMProxy.models.model_initialization.
# NOTE(review): presumably that module registers its initializers through
# register_model_initialization, which would explain why this call sits at
# the bottom of the file, after the definitions above — confirm.
import_model_configuration()
