import inspect
import torch
from dataset import get_dataset

from models.ModifiedEvidentialN import ModifiedEvidentialNet

# Registry mapping a model-type key (as passed to ``load_model``) to the
# class that is called to construct that model.
create_model = dict(menet=ModifiedEvidentialNet)


# Fallback values for constructor arguments that a saved config may not
# contain. They are used only when the key is absent from the checkpoint's
# config dict.
_LEGACY_ARG_DEFAULTS = {
    'kl_c': 0,  # kl_c does not participate in testing
    'lamb1': 1,
    'lamb2': 1,
    'mix': False,
    'mix_inter': False,
    'mix_inter_alpha': 1.0,
    'mix_inter_beta': 1.0,
    'mix_noise': False,
    'noise_mix_alpha': 1.0,
    'noise_mix_beta': 1.0,
    'noise_mix_ratio': 1.0,
    'use_sample_wise_kl_weight': False,
    'kl_start_epoch': 100,
}


def _select_model_kwargs(arg_names, config_dict):
    """Build the constructor keyword arguments for a model.

    For every name in ``arg_names`` the value is resolved in this order:

    1. the value stored under that name in ``config_dict``;
    2. for ``'seed'`` only, the value stored under ``'seed_model'``
       (raises ``KeyError`` if that key is missing as well);
    3. the fallback in ``_LEGACY_ARG_DEFAULTS``;
    4. otherwise the argument is skipped and a diagnostic is printed, so the
       constructor's own default (if it has one) applies.
    """
    selected = {}
    for arg in arg_names:
        if arg in config_dict:
            selected[arg] = config_dict[arg]
        elif arg == 'seed':
            selected['seed'] = config_dict['seed_model']
        elif arg in _LEGACY_ARG_DEFAULTS:
            selected[arg] = _LEGACY_ARG_DEFAULTS[arg]
        else:
            # Best effort: report the missing entry and carry on. If the
            # constructor requires this argument, the call will fail there.
            print(f"Error loading model: {arg!r}")
            print(f"Config dict: {config_dict}")
            print(f"Arg: {arg}")
    return selected


def load_model(directory_model, name_model, model_type, batch_size_eval=1024):
    """Rebuild a model from a saved checkpoint and return it in eval mode.

    The checkpoint at ``directory_model + name_model`` must contain the keys
    ``'model_config_dict'`` and ``'model_state_dict'``. The config dict is
    filtered down to the positional/keyword parameters of the model
    constructor (excluding its first parameter), with fallbacks for entries
    the config lacks, and the resulting model is loaded with the saved
    weights.

    Args:
        directory_model: Path prefix of the checkpoint. It is concatenated
            with ``name_model`` as-is, so it must already end with a path
            separator if it names a directory.
        name_model: File name (path suffix) of the checkpoint.
        model_type: Key into ``create_model`` selecting the model class.
        batch_size_eval: Forwarded to ``get_dataset`` as ``batch_size_eval``.

    Returns:
        The constructed model with weights loaded, after ``model.eval()``.

    Raises:
        KeyError: If ``model_type`` is not in ``create_model``, if the
            checkpoint lacks a required key, or if the config has neither
            ``'seed'`` nor ``'seed_dataset'``.
    """
    model_path = directory_model + name_model
    # Without CUDA, remap all stored tensors to the CPU.
    map_location = None if torch.cuda.is_available() else "cpu"

    # Select arguments for model creation (skip the first parameter).
    model_cls = create_model[model_type]
    arg_names = inspect.getfullargspec(model_cls)[0][1:]

    # Read the checkpoint once; config and weights both come from it.
    # NOTE(review): torch.load unpickles the file (depending on the torch
    # version's ``weights_only`` default) - only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=map_location)
    config_dict = checkpoint['model_config_dict']

    if 'seed' in config_dict:
        seed = config_dict['seed']
    else:
        seed = config_dict['seed_dataset']

    # Only the fourth value returned by get_dataset is needed; it is stored
    # as config_dict['N'] so it can be passed to the constructor when the
    # constructor accepts an ``N`` argument.
    _, _, _, config_dict['N'], _ = get_dataset(
        config_dict['dataset_name'],
        batch_size=config_dict['batch_size'],
        split=config_dict['split'],
        seed=seed,
        test_shuffle_seed=None,
        batch_size_eval=batch_size_eval,
    )

    # Create model
    model = model_cls(**_select_model_kwargs(arg_names, config_dict))

    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model
