from .config import MODELS_CONFIG
from .train_state import TrainState
import numpy as np
import pickle

def load_model_by_name(MODEL_NAME,DATASET,MODELS_DIR,NUM_CLASSES,device):
    """Build matching JAX and PyTorch versions of a configured model with trained weights.

    Args:
        MODEL_NAME: key into ``MODELS_CONFIG``. The entry's ``'model'`` field picks
            the architecture: ``'resnetV2'`` uses the ResNetV2 implementations; any
            other value falls back to the plain ResNet implementations. The entry
            must also provide ``'jax_kwargs'``, ``'torch_kwargs'`` and a
            ``'saved_params'`` mapping keyed by dataset.
        DATASET: key into the entry's ``'saved_params'`` mapping, selecting which
            checkpoint file name to load.
        MODELS_DIR: string prefixed to the checkpoint file name by plain string
            concatenation, so it must already end with any needed path separator.
        NUM_CLASSES: number of output classes, passed to both model constructors.
        device: target passed to the torch model's ``.to()``.

    Returns:
        dict with ``'jax'``: a ``TrainState`` built from the JAX model and its loaded
        parameters, and ``'torch'``: the PyTorch model moved to ``device`` and put in
        eval mode.

    Lookups into ``MODELS_CONFIG`` are plain subscripts, so a missing
    ``MODEL_NAME``/``DATASET`` raises whatever that mapping raises (``KeyError``
    for a dict).
    """
    model_cfg = MODELS_CONFIG[MODEL_NAME]
    # Plain concatenation (not os.path.join) is kept on purpose so existing
    # MODELS_DIR values resolve to exactly the same paths as before.
    params_path = MODELS_DIR + model_cfg['saved_params'][DATASET]

    if model_cfg['model'] == 'resnetV2':
        from models.resnetV2_jax import ResNetV2 as ResNetV2jax
        from models.resnetV2_torch import ResNetV2 as ResNetV2torch

        jax_model = ResNetV2jax(num_classes=NUM_CLASSES, **model_cfg['jax_kwargs'])
        jax_params = jax_model.load_params(params_path)
        jax_model = TrainState.create(jax_model, jax_params)

        torch_model = ResNetV2torch(head_size=NUM_CLASSES, **model_cfg['torch_kwargs'])
        # The same checkpoint file is read a second time with numpy for the torch model.
        torch_params = np.load(params_path)
        try:
            torch_model.load_params(torch_params)
        finally:
            # For .npz archives np.load returns an NpzFile holding an open file
            # handle; release it explicitly instead of waiting for GC. A plain
            # ndarray (.npy) has no close() and needs no cleanup.
            # NOTE(review): assumes load_params copies the arrays eagerly rather
            # than keeping the archive for lazy reads — confirm.
            close = getattr(torch_params, 'close', None)
            if close is not None:
                close()
        torch_model.to(device)
        torch_model = torch_model.eval()
    else:
        from models.resnet_jax import ResNet as ResNetJax
        from models.resnet_torch import ResNet as ResNetTorch

        jax_model = ResNetJax(num_classes=NUM_CLASSES, **model_cfg['jax_kwargs'])
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load checkpoints from trusted sources.
        with open(params_path, 'rb') as f:
            jax_params = pickle.load(f)
        jax_model = TrainState.create(jax_model, jax_params)

        # The torch weights are derived from the same unpickled JAX parameters.
        torch_model = ResNetTorch(num_classes=NUM_CLASSES, **model_cfg['torch_kwargs'])
        torch_model = torch_model.load_from_jax_params(jax_params)
        torch_model = torch_model.to(device)
        torch_model = torch_model.eval()

    return {'jax': jax_model,
            'torch': torch_model}

