import torch

from godin.nets.deconfnet import CosineDeconf, InnerDeconf, EuclideanDeconf, \
    DeconfNetOOD
from godin.nets.densenet import DenseNet
from godin.nets.resnet import ResNet34
from godin.nets.wideresnet import WideResNet

# Similarity name -> class used to build the "h" head in get_model.
# The special name 'baseline' is handled separately in get_model.
h_dict = dict(
    cosine=CosineDeconf,
    inner=InnerDeconf,
    euclid=EuclideanDeconf,
)


def get_model(architecture='resnet', similarity='cosine', num_classes=10,
              model_dir='./models', device=0, data_in='CIFAR10', val_set=None,
              percentile_threshold=None, noise_magnitude=0.005):
    """Build a DeconfNetOOD, load its checkpoint, and return it in eval mode.

    Args:
        architecture: Backbone name; one of 'densenet', 'resnet' (ResNet34)
            or 'wideresnet'.
        similarity: Key of ``h_dict`` ('cosine', 'inner', 'euclid') selecting
            the ``h`` head, or 'baseline', which uses ``InnerDeconf`` and
            passes ``baseline=True`` to ``DeconfNetOOD``.
        num_classes: Number of output classes for the backbone and head.
        model_dir: Directory containing the checkpoint file
            ``model-{similarity}-{architecture}-{data_in}.pth``.
        device: Target device passed to ``.to()``. An int is treated as a
            CUDA device index when loading the checkpoint (the original
            behavior); a str or ``torch.device`` such as 'cpu' or 'cuda:1'
            is passed to ``torch.load`` unchanged.
        data_in: In-distribution dataset name; only used to build the
            checkpoint file name.
        val_set: Forwarded to ``deconf_net.set_threshold`` when
            ``percentile_threshold`` is given.
        percentile_threshold: If not None, ``deconf_net.set_threshold`` is
            called with this value after loading.
        noise_magnitude: Forwarded to ``DeconfNetOOD``; defaults to the
            previously hard-coded 0.005.

    Returns:
        The loaded ``DeconfNetOOD`` after ``.eval()`` has been called.

    Raises:
        ValueError: If ``architecture`` is not supported.
        KeyError: If ``similarity`` is neither 'baseline' nor a key of
            ``h_dict``.
    """
    backbones = {
        'densenet': DenseNet,
        'resnet': ResNet34,
        'wideresnet': WideResNet,
    }
    if architecture not in backbones:
        # ValueError subclasses Exception, so existing `except Exception`
        # callers keep working.
        raise ValueError(f"Unsupported architecture: {architecture}.")
    underlying_net = backbones[architecture](num_classes=num_classes)
    underlying_net.to(device)

    # Construct h and the composed deconf net. 'baseline' reuses the
    # inner-product head but flags the composed net as a baseline.
    baseline = similarity == 'baseline'
    if baseline:
        h_cls = InnerDeconf
    elif similarity in h_dict:
        h_cls = h_dict[similarity]
    else:
        raise KeyError(
            f"Unsupported similarity: {similarity!r}; expected 'baseline' "
            f"or one of {sorted(h_dict)}.")
    h = h_cls(in_features=underlying_net.output_size, num_classes=num_classes)
    h.to(device)

    deconf_net = DeconfNetOOD(
        underlying_model=underlying_net, in_features=underlying_net.output_size,
        num_classes=num_classes, h=h, baseline=baseline,
        noise_magnitude=noise_magnitude)
    deconf_net.to(device)

    print('Test the model.')

    suffix = f"-{similarity}-{architecture}-{data_in}"
    file_name = f'{model_dir}/model{suffix}.pth'
    print(f'Loading model: {file_name}.')
    # An int device keeps the original 'cuda:<idx>' mapping; anything else
    # (e.g. 'cpu', 'cuda:1', torch.device) goes to torch.load as-is instead
    # of being turned into an invalid 'cuda:cpu'-style string.
    map_location = f'cuda:{device}' if isinstance(device, int) else device
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    deconf_net.load_state_dict(
        torch.load(file_name, map_location=map_location))

    deconf_net.eval()

    if percentile_threshold is not None:
        deconf_net.set_threshold(percentile_threshold=percentile_threshold,
                                 val_set=val_set, device=device)
    print(f'deconf_net.threshold: {deconf_net.threshold}')

    return deconf_net
