import torch
from vgg import vgg16, vgg16_bn, vgg19
from torchvision.models.vgg import vgg16 as cleanVGG16
from torchvision.models.vgg import vgg19 as cleanVGG19
from mobilenet import mobilenet_v2
from parameters import get_matrix

def pad_zero(mixture):
    """Append zeros along dimension 0, growing it by half its length.

    The zero block has ``mixture.size(0) // 2`` entries along dimension 0
    and matches ``mixture`` in every trailing dimension, as well as in
    device and dtype.  Works for tensors of any rank >= 1.

    Args:
        mixture: Tensor with at least one dimension.

    Returns:
        A new tensor whose leading dimension is
        ``mixture.size(0) + mixture.size(0) // 2``.

    Raises:
        IndexError: If ``mixture`` is a 0-dimensional tensor.
    """
    # Derive the padding shape from the input itself rather than
    # special-casing ranks, so 3-D / 5-D inputs are handled as well.
    pad_shape = (mixture.size(0) // 2, *mixture.shape[1:])
    padding = torch.zeros(pad_shape, device=mixture.device, dtype=mixture.dtype)
    return torch.cat((mixture, padding), dim=0)


def copy_model(from_m, to_m):
    """Load the state dict of ``from_m`` into ``to_m``.

    Args:
        from_m: Object providing ``state_dict()`` (the source).
        to_m: Object providing ``load_state_dict()`` (the destination).
    """
    source_state = from_m.state_dict()
    to_m.load_state_dict(source_state)

def _build_blind_vgg(factory, weight_key, num_classes, device, dtype, pretrained):
    """Build a blinded VGG via ``factory``: add noise, initialise, optionally load weights."""
    # The blinded variants are always built without bias terms.
    model = factory(num_classes=num_classes, device=device, dtype=dtype, use_bias=False)
    model.addNoise()
    model.initialize_weights()
    if pretrained:
        model.load_imageNet_weight(weight_key)
    return model


def get_model(name, is_blind, device, dtype, num_classes=1000, pretrained=False):
    """Construct the model identified by ``name``.

    Args:
        name: One of ``'vgg16'``/``'VGG16'``, ``'vgg19'``/``'VGG19'`` or
            ``'mobilenet'``/``'MobileNet'``.
        is_blind: When truthy, build the project's noise-augmented variant;
            otherwise build the plain model.
        device: Forwarded to the blinded constructors and ``get_matrix``;
            not used for the plain models.
        dtype: Forwarded alongside ``device``; not used for the plain models.
        num_classes: Number of output classes passed to the constructor.
        pretrained: For blinded VGGs, load ImageNet weights after
            initialisation; for MobileNet, forwarded to ``mobilenet_v2``.
            The plain torchvision VGG constructors are called without it.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``name`` is not one of the recognised model names.
    """
    if name in ('vgg16', 'VGG16'):
        if is_blind:
            return _build_blind_vgg(vgg16, 'vgg16', num_classes, device, dtype, pretrained)
        return cleanVGG16(num_classes=num_classes)

    if name in ('vgg19', 'VGG19'):
        if is_blind:
            return _build_blind_vgg(vgg19, 'vgg19', num_classes, device, dtype, pretrained)
        return cleanVGG19(num_classes=num_classes)

    if name in ('mobilenet', 'MobileNet'):
        if is_blind:
            model = mobilenet_v2(num_classes=num_classes,
                                 device=device, dtype=dtype,
                                 pretrained=pretrained)
            bm, um, gm, igm = get_matrix(device, dtype)
            model.addNoise(bm, um, gm, igm)
            return model
        return mobilenet_v2(num_classes=num_classes, pretrained=pretrained)

    # Raise explicitly: an ``assert`` is stripped under ``python -O``, which
    # would let execution fall through to an unbound ``model``.
    raise ValueError(f"model name {name!r} not recognized")
def getdata_set_name(name):
    """Map a dataset name to its integer code.

    Args:
        name: ``"cifar10"`` or ``"cifar100"``.

    Returns:
        ``1`` for ``"cifar10"``, ``3`` for ``"cifar100"``.

    Raises:
        ValueError: If ``name`` is any other value.
    """
    if name == "cifar10":
        return 1
    if name == "cifar100":
        return 3
    # Raise explicitly: an ``assert`` is stripped under ``python -O`` and the
    # function would then silently return None.
    raise ValueError(f"unknown dataset name: {name!r}")

def num_class(num):
    """Map an integer dataset code to its number of classes.

    Args:
        num: ``1`` or ``3``.

    Returns:
        ``10`` for code ``1``, ``100`` for code ``3``.

    Raises:
        ValueError: If ``num`` is any other value.
    """
    if num == 1:
        return 10
    if num == 3:
        return 100
    # Raise explicitly: an ``assert`` is stripped under ``python -O`` and the
    # function would then silently return None.
    raise ValueError(f"unknown dataset code: {num!r}")