
import os
import pickle
from torch import load
from ..model import * 
from ..model.config import Config
# from pprint import pprint

def _resolve_model_class(class_name: str, module=None):
    """Return the object named ``class_name`` without using ``exec``/``eval``.

    Args:
        class_name: Name of the model class (optionally dotted, e.g.
            ``"pkg_attr.ClassName"``). Every dotted part must be a valid
            Python identifier.
        module: Optional module path to import the name from. Relative
            paths are resolved against this file's package. When ``None``
            the name is looked up in this module's globals (which include
            everything pulled in by the file-level imports).

    Raises:
        ValueError: If ``class_name`` is not a (dotted) identifier.
        ImportError: If ``module`` cannot be imported or lacks the name.
        NameError: If ``module`` is ``None`` and the name is not a global
            of this module.
    """
    import importlib

    if not isinstance(class_name, str):
        raise ValueError(f"model name must be a string, got {class_name!r}")
    parts = class_name.split('.')
    if not all(part.isidentifier() for part in parts):
        # The name comes from a file on disk; refuse anything that is not a
        # plain attribute path instead of evaluating it as an expression.
        raise ValueError(f"invalid model class name: {class_name!r}")

    head, *rest = parts
    if module is not None:
        namespace = importlib.import_module(module, package=__package__)
        try:
            target = getattr(namespace, head)
        except AttributeError as err:
            # Same exception type `from module import name` would raise.
            raise ImportError(
                f"cannot import name {head!r} from {module!r}"
            ) from err
    else:
        try:
            target = globals()[head]
        except KeyError:
            raise NameError(f"name {head!r} is not defined") from None

    for attr in rest:
        target = getattr(target, attr)
    return target


def load_model(model_file: str, module=None, device='cuda'):
    """Rebuild a saved model from its weights file and sibling results file.

    The results file is expected next to ``model_file``; its name is derived
    by replacing ``model_`` with ``results_`` and ``.pt`` with ``.pkl`` in the
    weights file name. The unpickled results must provide
    ``results['config']``, a mapping whose ``'model'`` entry names the model
    class and whose items are passed as keyword arguments to ``Config``.

    Args:
        model_file: Path to the saved state dict.
        module: Optional module path to import the model class from. When
            ``None`` the class name is resolved among this module's globals.
        device: Device passed as ``map_location`` when loading the weights
            and to which the model is moved afterwards.

    Returns:
        A tuple ``(loaded_model, results, model_config)``: the model in eval
        mode on ``device``, the unpickled results object, and the ``Config``
        built from ``results['config']``.

    Raises:
        FileNotFoundError: If the results file does not exist.
        ValueError, ImportError, NameError: If the model class cannot be
            resolved (see ``_resolve_model_class``).
    """
    dir_name, file_name = os.path.split(model_file)

    results_file_name = file_name.replace('model_', 'results_').replace('.pt', '.pkl')
    results_file = os.path.join(dir_name, results_file_name)
    # NOTE(security): unpickling can execute arbitrary code; only load results
    # (and weights) that come from a trusted source.
    with open(results_file, 'rb') as f:
        results = pickle.load(f)

    # Resolved via importlib/getattr rather than exec + eval: names bound by
    # exec() inside a function are not reliably visible to a later eval()
    # (each locals() call is an independent snapshot on Python 3.13+).
    model_class = _resolve_model_class(results['config']['model'], module)

    model_config = Config(**results['config'])
    loaded_model = model_class(model_config)
    loaded_model.load_state_dict(load(model_file, map_location=device))
    loaded_model = loaded_model.eval()
    loaded_model = loaded_model.to(device)
    return loaded_model, results, model_config
    