import os
import importlib
import torch

# Registry of the model names accepted by ModelLoader, grouped by backend.
# ModelLoader copies the group key into `model_type`: "torch" for PyTorch
# models, "skl" (presumably scikit-learn style models) for the others.
# A name must appear in at most one group.
MODEL_DICT = dict(
    # PyTorch models
    torch=[
        "linear",
        "polynomial",
        "mlp2",
        "mlp2tanh",
        "mlp3",
        "mlp3tanh",
        "cnn",
    ],
    # Non-PyTorch models
    skl=["rdmforest"],
)

class ModelLoader:
    """Load a model class from a per-model Python file and instantiate it.

    The (normalized) model name selects the backend type, looked up in
    ``MODEL_DICT``, and the source file ``<model_dir>/<model_name>.py``.
    That file is imported at load time, and its model class is instantiated
    with ``params``.
    """

    def __init__(self,
                 model_name: str = "linear",
                 params: "dict | None" = None,
                 seed: int = 42,
                 device: "torch.device | str" = torch.device("cpu"),
                 model_dir: str = "src/models/model"):
        """Validate the model name and record the loading configuration.

        Args:
            model_name: Model identifier. It is lower-cased, and everything
                from the first underscore onward is dropped (e.g.
                ``"MLP2_v1"`` -> ``"mlp2"``) before lookup in ``MODEL_DICT``.
            params: Keyword arguments forwarded to the model class
                constructor. Defaults to
                ``{'in_features': 22, 'out_features': 12}``. A fresh dict is
                built for every instance, so the default is never shared.
            seed: Seed passed to ``torch.manual_seed`` (and to
                ``torch.cuda.manual_seed`` when CUDA is available) right
                before the model is instantiated.
            device: Value passed to the model's ``.to()`` method, when the
                model has one.
            model_dir: Directory containing the ``<model_name>.py`` files.
                The default is relative to the current working directory.

        Raises:
            ValueError: If the normalized name is not listed in ``MODEL_DICT``.
        """
        if params is None:
            # Built per call: a dict literal used as the default would be a
            # single mutable object shared by every ModelLoader instance.
            params = {'in_features': 22, 'out_features': 12}

        self.device = device
        self.model_name = model_name.lower().split("_")[0]
        self.params = params
        self.seed = seed
        self.model_dir = model_dir

        # A name registered under both backends would make model_type ambiguous.
        assert not ((self.model_name in MODEL_DICT["torch"])
                    and (self.model_name in MODEL_DICT["skl"]))
        if self.model_name in MODEL_DICT["torch"]:
            self.model_type = "torch"
        elif self.model_name in MODEL_DICT["skl"]:
            self.model_type = "skl"
        else:
            raise ValueError(f"The model {self.model_name} does not exist according to the ModelLoader")

    def load_model(self):
        """Import the model's source file, then instantiate its class and return it.

        The chosen class is the first one *defined* in
        ``<model_dir>/<model_name>.py``. Classes the file merely imports
        (e.g. ``from torch.nn import Linear``) are used only if the file
        defines no class of its own. The RNG is seeded with ``self.seed``
        just before instantiation.

        Returns:
            A ``(model, model_type)`` tuple, where ``model_type`` is
            ``"torch"`` or ``"skl"``.

        Raises:
            FileNotFoundError: If the model source file does not exist.
            ValueError: If the source file contains no class at all.
        """
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded as an attribute, so import it explicitly.
        import importlib.util

        filename = os.path.join(self.model_dir, f"{self.model_name}.py")
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Le fichier {filename} n'existe pas.")

        spec = importlib.util.spec_from_file_location(f"{self.model_name}.py", filename)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        classes = [obj for obj in vars(module).values() if isinstance(obj, type)]
        # Imported names typically precede local definitions in the module
        # namespace, so prefer classes whose __module__ is this module.
        local_classes = [cls for cls in classes if cls.__module__ == module.__name__]
        candidates = local_classes or classes
        if not candidates:
            raise ValueError(f"No class found in {filename}")
        model_class = candidates[0]

        torch.manual_seed(self.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.seed)
        model = model_class(**self.params)
        # Torch modules expose `.to()`; model objects without it are
        # returned as-is instead of raising AttributeError.
        if callable(getattr(model, "to", None)):
            model = model.to(self.device)

        return model, self.model_type
    
    

    
# tests
# import torch
# model_tmp, model_type = ModelLoader(model_name = "linear", params= {'input_dim': 22, 'output_dim': 12}).load_model()
# print(model_tmp.forward(torch.ones(22)))
