import os
import importlib
import torch

from transformers import AutoModelForCausalLM

from CoLM import utils
from CoLM.configs import TrainArg

# Registries filled in by the decorators defined below.
# Configuration name -> configuration class; build_model_configuration looks
# entries up by the *model* name and instantiates them with no arguments.
MODEL_CONFIGURATION_REGISTRY = {}
# Model name -> model class; build_model calls ``build_model(args)`` on it.
MODEL_REGISTRY = {}
# Architecture name -> function applied to a configuration instance.
MODEL_ARCHITECTURE_REGISTRY = {}
# Architecture name -> name of the model it belongs to.
# NOTE: "REIGSTRY" is misspelled, but the name is referenced throughout this
# module, so the spelling has to stay consistent.
ARCHITECTURE_TO_MODEL_REIGSTRY = {}

def register_model_configuration(model_config_name: str):
    """Create a class decorator that registers a model configuration class.

    Args:
        model_config_name: Key under which the decorated class is stored in
            ``MODEL_CONFIGURATION_REGISTRY``.

    Returns:
        A decorator that stores the class and hands it back unchanged.

    Raises:
        ValueError: Raised when the decorator is applied and
            ``model_config_name`` is already present in the registry.
    """

    def _record_configuration(config_cls):
        name_is_free = model_config_name not in MODEL_CONFIGURATION_REGISTRY
        if name_is_free:
            MODEL_CONFIGURATION_REGISTRY[model_config_name] = config_cls
            return config_cls

        # Registration is strictly first-come; a second class may not reuse the key.
        raise ValueError(f"Cannot register duplicate {model_config_name} model configuration.")

    return _record_configuration

def register_model(model_name: str):
    """Create a class decorator that registers a model class.

    Args:
        model_name: Key under which the decorated class is stored in
            ``MODEL_REGISTRY``.

    Returns:
        A decorator that stores the class and returns it unchanged.

    Raises:
        ValueError: Raised when the decorator is applied and ``model_name``
            is already present in ``MODEL_REGISTRY``.
    """

    def register_model_cls(cls):
        if model_name in MODEL_REGISTRY:
            # This registry holds models, not configurations; the message
            # names the right kind of object so duplicates are easy to trace.
            raise ValueError(f"Cannot register duplicate {model_name} model.")

        MODEL_REGISTRY[model_name] = cls
        return cls

    return register_model_cls


def register_model_architecture(model_name: str, model_architecture: str):
    """Create a function decorator that registers a named model architecture.

    The decorated function is stored in ``MODEL_ARCHITECTURE_REGISTRY`` under
    ``model_architecture``, and the architecture is linked to ``model_name``
    in ``ARCHITECTURE_TO_MODEL_REIGSTRY``.

    Args:
        model_name: Name of a model that is already in ``MODEL_REGISTRY``.
        model_architecture: Unique name of the architecture being registered.

    Returns:
        A decorator that registers the function and returns it unchanged.

    Raises:
        ValueError: Raised when the decorator is applied and either
            ``model_name`` is not in ``MODEL_REGISTRY`` or
            ``model_architecture`` is already registered.
    """

    def register_model_arch_fn(fn):
        if model_name not in MODEL_REGISTRY:
            raise ValueError(f"{model_name} is an invalid model architecture, please re-register model.")
        if model_architecture in MODEL_ARCHITECTURE_REGISTRY:
            raise ValueError(f"Cannot register duplicate {model_architecture} model architecture.")
        MODEL_ARCHITECTURE_REGISTRY[model_architecture] = fn
        ARCHITECTURE_TO_MODEL_REIGSTRY[model_architecture] = model_name
        # A decorator's return value replaces the decorated name; returning
        # the function keeps it usable (not None) in its defining module.
        return fn

    return register_model_arch_fn


def build_model_configuration(architecture_name: str, vocab_size: int = -1):
    """Build the configuration object for a registered architecture.

    Args:
        architecture_name: Key present in the architecture registries.
        vocab_size: Value assigned to ``vocab_size`` on the resulting
            configuration; defaults to ``-1``.

    Returns:
        Whatever the registered architecture function returns for a freshly
        instantiated configuration, with ``vocab_size`` set on it.

    Raises:
        KeyError: If ``architecture_name`` (or the model name it maps to) is
            missing from the corresponding registry.
    """
    owning_model = ARCHITECTURE_TO_MODEL_REIGSTRY[architecture_name]
    config_cls = MODEL_CONFIGURATION_REGISTRY[owning_model]
    base_config = config_cls()

    # The architecture function's *return value* becomes the configuration,
    # so registered functions must return the object they configure.
    apply_architecture = MODEL_ARCHITECTURE_REGISTRY[architecture_name]
    config = apply_architecture(base_config)

    config.vocab_size = vocab_size
    return config


def build_model(architecture_name: str, vocab_size: int = -1, token: str = None) -> torch.nn.Module:
    """Build a model from a registered architecture or a pretrained checkpoint.

    Args:
        architecture_name: A registered architecture name; any other value is
            passed to ``AutoModelForCausalLM.from_pretrained``.
        vocab_size: Forwarded to ``build_model_configuration``; unused on the
            pretrained path.
        token: Forwarded as ``token`` to ``from_pretrained``; unused for
            registered architectures.

    Returns:
        The constructed model.
    """
    # Unregistered names fall through to the Hugging Face loader.
    if architecture_name not in MODEL_ARCHITECTURE_REGISTRY:
        return AutoModelForCausalLM.from_pretrained(architecture_name, token=token)

    config = build_model_configuration(
        architecture_name=architecture_name,
        vocab_size=vocab_size,
    )
    model_cls = MODEL_REGISTRY[ARCHITECTURE_TO_MODEL_REIGSTRY[architecture_name]]
    return model_cls.build_model(config)


def build_model_and_tokenizer(args: TrainArg):
    """Create the tokenizer and the model described by ``args``.

    The tokenizer is obtained from ``utils.get_tokenizer`` using
    ``args.tokenizer`` and ``args.auth``. The model is then built by
    ``build_model`` from ``args.model_name_or_path``, with ``len(tokenizer)``
    as the vocabulary size and ``args.auth`` as the token.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = utils.get_tokenizer(args.tokenizer, args.auth)
    model_kwargs = {
        "architecture_name": args.model_name_or_path,
        "vocab_size": len(tokenizer),
        "token": args.auth,
    }
    return build_model(**model_kwargs), tokenizer


def import_models(model_dir: str, namespace: str = "CoLM.models"):
    """Import every public Python module found directly inside ``model_dir``.

    A directory entry is imported when its name ends with ``.py`` and does
    not start with ``_`` or ``.``. Each module is imported as
    ``<namespace>.<file name without the .py extension>``.

    Args:
        model_dir: Directory whose entries are scanned (not recursive).
        namespace: Dotted package prefix used to build the module names.
            Defaults to ``"CoLM.models"``.

    Raises:
        OSError: If ``model_dir`` cannot be listed.
        ImportError: If one of the modules cannot be imported.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the import
    # (and therefore registration) order the same on every filesystem.
    for file_name in sorted(os.listdir(model_dir)):
        if file_name.startswith(("_", ".")):
            continue

        # splitext strips only the final extension, unlike searching for the
        # first ".py" occurrence, which truncates names such as "a.py_b.py".
        module_name, extension = os.path.splitext(file_name)
        if extension != ".py":
            continue

        importlib.import_module(namespace + "." + module_name)


# Eagerly import the modules that sit in this file's directory as soon as this
# module is imported. Presumably this is what triggers the register_*
# decorators in those modules — confirm against the model modules.
import_models(os.path.dirname(__file__))