import sys
from collections import OrderedDict

import torch
from torchvision import models

# Module-wide target device: CUDA when torch reports it available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(arch, checkpoint=None, dropout_rate=0):
    """Build a ResNet-50 variant, load checkpoint weights into it, and return it.

    Args:
        arch: Architecture name. ``'resnet50'`` builds ``models.resnet50()``;
            ``'spectral'`` builds the project's ``ResNet50(use_sn=True,
            num_classes=1000)``. ``'ddu'`` is built the same way as
            ``'spectral'`` because its default checkpoint is the spectral one
            (NOTE(review): confirm this is the intended network for 'ddu').
            ``'dropout'`` is recognised but not implemented yet.
        checkpoint: Path passed to ``torch.load``; the loaded object must have
            a ``"state_dict"`` entry. When ``None``, defaults to
            ``checkpoint/model_best_<name>.pth.tar`` where ``<name>`` is
            ``resnet50`` or ``spectral`` depending on ``arch``.
        dropout_rate: Currently unused; kept for the unfinished ``'dropout'``
            architecture.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode
        and moved to the module-level ``device``.

    Raises:
        ValueError: If ``arch`` is not a known architecture name.
        NotImplementedError: If ``arch`` is ``'dropout'``.
    """
    # Which default checkpoint file each architecture name uses.
    default_checkpoint_arch = {
        'resnet50': 'resnet50',
        'dropout': 'spectral',
        'spectral': 'spectral',
        'ddu': 'spectral',
    }

    # Validate up front so a bad name fails before any model is built or any
    # checkpoint is read, instead of surfacing later as an unbound `model`.
    if arch not in default_checkpoint_arch:
        raise ValueError(
            f"unknown arch {arch!r}; expected one of "
            f"{sorted(default_checkpoint_arch)}"
        )
    if arch == 'dropout':
        # TODO: construct the dropout ResNet-50 (the intended constructor and
        # its arguments are still undecided).
        raise NotImplementedError("arch 'dropout' is not implemented yet")

    if arch == 'resnet50':
        model = models.resnet50()
    else:
        # 'spectral' and 'ddu' share the spectral-normalized network.
        # The package is expected one directory up; avoid growing sys.path
        # with a duplicate entry on every call.
        if ".." not in sys.path:
            sys.path.append("..")
        from spectral_normalized_models.resnet import ResNet50
        model = ResNet50(use_sn=True, num_classes=1000)

    if checkpoint is None:
        checkpoint = (
            f'checkpoint/model_best_{default_checkpoint_arch[arch]}.pth.tar'
        )

    print(f"=> loading checkpoint '{checkpoint}'")
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    loaded = torch.load(checkpoint, map_location=device)
    state_dict = loaded["state_dict"]

    # Drop a leading 'module.' from parameter names so they match the bare
    # model (presumably the checkpoint was saved from a wrapped model, e.g.
    # DataParallel -- confirm against the training script).
    prefix = 'module.'
    state_dict_clean = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.load_state_dict(state_dict_clean)
    model.eval()
    model = model.to(device)
    return model

