import torchvision.transforms as transforms
import numpy as np

from models.densenet import DenseNet3
from models.mobilenetv2 import MobileNetV2
from models.vgg16 import vgg16
from models.vit import vit_b_16
from models.densenet_dice import DenseNet3_Dice
from models.mobilenetv2_dice import MobileNetV2_Dice
from models.densenet_knn import DenseNet3_KNN


# Per-channel normalization used by the 224x224 ImageNet-style pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
# Per-channel normalization used by the 32x32 DenseNet-100 pipelines
# (these models load CIFAR-10/CIFAR-100 feature statistics below).
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)


def _imagenet_transform():
    """Return a resize-256 / center-crop-224 pipeline with ImageNet normalization."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


def _cifar_transform():
    """Return a resize-32 / center-crop-32 pipeline with the CIFAR normalization above."""
    return transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ])


def build_model(config, num_classes, p=None):
    """Build the model selected by ``config`` together with its input transform.

    Args:
        config: Mapping with a ``'model_name'`` key. ``'coeff_file'`` is only
            required (and only read) when ``model_name == 'densenet100'``.
        num_classes: Number of output classes, converted with ``int()`` by
            every model that uses it. Ignored by ``'vgg16'`` and ``'vit'``,
            which are constructed with ``pretrained=True``.
        p: Forwarded to the ``*_dice`` and ``densenet100_knn`` models;
            unused by the others.

    Returns:
        A ``(model, transform)`` tuple, where ``transform`` is a
        ``torchvision.transforms.Compose`` preprocessing pipeline.

    Raises:
        KeyError: If ``config`` lacks ``'model_name'`` (or lacks
            ``'coeff_file'`` when building ``'densenet100'``).
        SystemExit: If ``model_name`` is not a supported model.
    """
    model_name = config['model_name']

    if model_name == 'mobilenetv2':
        return MobileNetV2(int(num_classes)), _imagenet_transform()

    if model_name == 'mobilenetv2_dice':
        info = np.load("./cache/imagenet_mobilenet_feat_stat.npy")
        model = MobileNetV2_Dice(int(num_classes), info=info, p=p)
        return model, _imagenet_transform()

    if model_name == 'densenet100':
        model = DenseNet3(100, int(num_classes), coeff_file=config['coeff_file'])
        return model, _cifar_transform()

    if model_name == 'vgg16':
        return vgg16(pretrained=True), _imagenet_transform()

    if model_name == 'vit':
        return vit_b_16(pretrained=True), _imagenet_transform()

    if model_name == 'densenet100_dice':
        n_classes = int(num_classes)
        # Feature statistics are cached per dataset; anything other than
        # 10 classes falls back to the CIFAR-100 statistics.
        dataset = 'CIFAR-10' if n_classes == 10 else 'CIFAR-100'
        info = np.load(f"./cache/{dataset}_densenet_feat_stat.npy")
        model = DenseNet3_Dice(100, n_classes, info=info, p=p)
        return model, _cifar_transform()

    if model_name == 'densenet100_knn':
        model = DenseNet3_KNN(100, int(num_classes), 12, reduction=0.5, bottleneck=True,
                              dropRate=0, normalizer=None, method='', p=p)
        return model, _cifar_transform()

    # Raise SystemExit directly instead of calling the site-module `exit()`
    # helper: that helper is meant for interactive use, is missing under
    # `python -S`, and closes sys.stdin as a side effect.
    raise SystemExit('{} model is not supported'.format(model_name))
