import glob
import os
import sys
import importlib

class Registry:
    """Name-based registry of pluggable components, grouped by kind.

    Components are added with the ``register_<kind>(name)`` class decorators
    and retrieved with ``get_<kind>(name)``. All entries live in the
    class-level ``mapping`` dict, so the class and every instance share the
    same registrations.
    """

    # kind -> {registered name -> registered object}. This is a class attribute,
    # so it is shared by Registry, its instances, and any subclasses.
    mapping = {
        "dataset": {},
        "vision_model": {},
        "language_model": {},
        "other_model": {},
        "model_assembler": {},
        "optimizer": {},
        "scheduler": {},
        "optimize_assembler": {},
        "utils": {},
    }

    @classmethod
    def _register(cls, kind, name):
        """Return a decorator that stores its target in ``mapping[kind][name]``.

        The decorated object is returned unchanged. Registering a name that is
        already present overwrites the previous entry.
        """
        def wrap(obj):
            cls.mapping[kind][name] = obj
            return obj

        return wrap

    @classmethod
    def _get(cls, kind, name):
        """Return the object registered as *name* under *kind*.

        Raises:
            KeyError: if nothing is registered under *name*; the message names
                the kind and lists the names that are registered for it.
        """
        entries = cls.mapping[kind]
        try:
            return entries[name]
        except KeyError:
            # str() keeps the listing safe even if non-string names were used.
            available = ", ".join(sorted(map(str, entries))) or "<none>"
            raise KeyError(
                f"No {kind} registered under {name!r}; available: {available}"
            ) from None

    # ---- registration decorators -------------------------------------------

    @classmethod
    def register_dataset(cls, name):
        """Decorator: register the decorated object as dataset *name*."""
        return cls._register("dataset", name)

    @classmethod
    def register_vision_model(cls, name):
        """Decorator: register the decorated object as vision model *name*."""
        return cls._register("vision_model", name)

    @classmethod
    def register_language_model(cls, name):
        """Decorator: register the decorated object as language model *name*."""
        return cls._register("language_model", name)

    @classmethod
    def register_other_model(cls, name):
        """Decorator: register the decorated object as other model *name*."""
        return cls._register("other_model", name)

    @classmethod
    def register_model_assembler(cls, name):
        """Decorator: register the decorated object as model assembler *name*."""
        return cls._register("model_assembler", name)

    @classmethod
    def register_optimizer(cls, name):
        """Decorator: register the decorated object as optimizer *name*."""
        return cls._register("optimizer", name)

    @classmethod
    def register_scheduler(cls, name):
        """Decorator: register the decorated object as scheduler *name*."""
        return cls._register("scheduler", name)

    @classmethod
    def register_optimize_assembler(cls, name):
        """Decorator: register the decorated object as optimize assembler *name*."""
        return cls._register("optimize_assembler", name)

    @classmethod
    def register_utils(cls, name):
        """Decorator: register the decorated object as utility *name*."""
        return cls._register("utils", name)

    # ---- lookups (raise KeyError for unknown names) -------------------------

    @classmethod
    def get_dataset(cls, name):
        """Return the dataset registered as *name*."""
        return cls._get("dataset", name)

    @classmethod
    def get_vision_model(cls, name):
        """Return the vision model registered as *name*."""
        return cls._get("vision_model", name)

    @classmethod
    def get_language_model(cls, name):
        """Return the language model registered as *name*."""
        return cls._get("language_model", name)

    @classmethod
    def get_other_model(cls, name):
        """Return the other model registered as *name*."""
        return cls._get("other_model", name)

    @classmethod
    def get_model_assembler(cls, name):
        """Return the model assembler registered as *name*."""
        return cls._get("model_assembler", name)

    @classmethod
    def get_optimizer(cls, name):
        """Return the optimizer registered as *name*."""
        return cls._get("optimizer", name)

    @classmethod
    def get_scheduler(cls, name):
        """Return the scheduler registered as *name*."""
        return cls._get("scheduler", name)

    @classmethod
    def get_optimize_assembler(cls, name):
        """Return the optimize assembler registered as *name*."""
        return cls._get("optimize_assembler", name)

    @classmethod
    def get_utils(cls, name):
        """Return the utility registered as *name*."""
        return cls._get("utils", name)


# Module-level instance. Because all state is class-level, it sees exactly the
# same entries as the Registry class itself.
registry = Registry()

def setup_imports(base_folder="./"):
    """Put *base_folder* on ``sys.path`` and import every module in its plugin folders.

    Importing a module executes its top-level code, including any
    ``@registry.register_*`` decorators. Every ``*.py`` file found recursively
    under the ``dataset``, ``model``, ``optimization``, ``pipeline`` and
    ``utils`` folders of *base_folder* is imported by its dotted module path
    relative to *base_folder*. A ``pkg/__init__.py`` file is imported as the
    package ``pkg`` itself, and files named exactly ``setup.py`` are skipped.
    Modules are imported in sorted path order so the sequence is deterministic.

    Args:
        base_folder: Project root that contains the plugin folders. It may be
            relative (resolved against the current working directory) or
            absolute.

    Raises:
        Any exception raised while importing a discovered module
        (e.g. ``ModuleNotFoundError``) propagates unchanged.
    """
    abs_base = os.path.abspath(base_folder)

    # Move the project root to the front of sys.path, dropping duplicates, so
    # its top-level packages resolve ahead of other sys.path entries.
    while abs_base in sys.path:
        sys.path.remove(abs_base)
    sys.path.insert(0, abs_base)

    print(f"[setup_imports] abs_base: {abs_base}")
    print(f"[setup_imports] sys.path: {sys.path[:3]}")
    print(f"[setup_imports] cwd: {os.getcwd()}")

    # Source files created after interpreter start-up may be invisible to the
    # import system's directory caches until they are invalidated.
    importlib.invalidate_caches()

    folder_list = ["dataset", "model", "optimization", "pipeline", "utils"]
    # Escape the root so glob metacharacters in the path ("[", "*", "?")
    # match literally.
    glob_root = glob.escape(abs_base)
    files = sorted(
        path
        for folder in folder_list
        for path in glob.glob(os.path.join(glob_root, folder, "**"), recursive=True)
    )

    for path in files:
        file_name = os.path.basename(path)
        # Exact-name comparison: a substring test would also skip files such
        # as "data_setup.py". The isfile check drops directories named "*.py".
        if not file_name.endswith(".py") or file_name == "setup.py":
            continue
        if not os.path.isfile(path):
            continue

        # The dotted name is the path relative to the root just placed on
        # sys.path. relpath also normalises mixed "/" and os.sep separators.
        rel_path = os.path.relpath(path, abs_base)
        parts = os.path.splitext(rel_path)[0].split(os.sep)
        # "pkg/__init__.py" *is* package "pkg". Importing "pkg.__init__" would
        # execute the same file a second time as a distinct module.
        if parts[-1] == "__init__":
            parts = parts[:-1]
        importlib.import_module(".".join(parts))
    

if __name__ == "__main__":
    # When this file is run as a script, import every plugin module found
    # under the current working directory so their registration code runs.
    setup_imports(base_folder="./")
    
