import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models.densenet import _densenet

def DenseNets(
    name: str = "densenet121",
    num_classes: int = 1860,
    pretrained: bool = False,
) -> nn.Module:
    """Build a DenseNet whose depth is selected by the numeric suffix of *name*.

    ``name`` must have the form ``"densenet<N>"`` with ``N`` a non-negative
    integer. The network is created with growth rate 32, 64 initial features
    and ``block_config = (6, 12, (N + 3) * 8, (N + 1) * 8)``. For example,
    ``"densenet4"`` gives ``(6, 12, 56, 40)`` and ``"densenet7"`` gives
    ``(6, 12, 80, 64)``.

    NOTE(review): ``N`` is a size index, not a torchvision depth. The default
    ``"densenet121"`` therefore produces ``block_config = (6, 12, 992, 976)``,
    which is presumably not what is intended -- confirm the desired default.

    Args:
        name: Architecture identifier, ``"densenet"`` followed by an integer.
        num_classes: Required number of outputs of the final classifier.
        pretrained: When true, load a state dict from
            ``../models/weights/weights_<name>.pth`` (resolved against the
            current working directory) before the classifier is adjusted.

    Returns:
        The constructed network. If its classifier does not already emit
        ``num_classes`` outputs, the classifier is replaced by a newly
        initialised ``nn.Linear`` with the same number of input features.

    Raises:
        ValueError: If ``name`` does not start with ``"densenet"``, or its
            suffix is not an integer, or the integer is negative.
    """
    prefix = "densenet"
    # Check the prefix explicitly: slicing a fixed number of characters would
    # silently accept any 8-character prefix (e.g. "abcdefgh4").
    if not name.startswith(prefix):
        raise ValueError(
            f"Unsupported network name {name!r}: expected '{prefix}<N>'"
        )

    suffix = name[len(prefix):]
    try:
        size = int(suffix)
    except ValueError:
        raise ValueError(
            f"Unsupported network name {name!r}: "
            f"suffix {suffix!r} is not an integer"
        ) from None
    if size < 0:
        raise ValueError(
            f"Unsupported network name {name!r}: size index must be >= 0"
        )

    network = _densenet(
        growth_rate=32,
        block_config=(6, 12, (size + 3) * 8, (size + 1) * 8),
        num_init_features=64,
        weights=None,
        progress=True,
    )

    if pretrained:
        # map_location="cpu" lets checkpoints saved from a GPU load on hosts
        # without one; load_state_dict copies the values into the model's
        # existing parameters either way.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(
            '../models/weights/weights_' + name + '.pth',
            map_location="cpu",
        )
        network.load_state_dict(state_dict)

    # Swap the head only when its width differs, so a matching (possibly
    # pretrained) classifier is kept as is.
    if network.classifier.out_features != num_classes:
        network.classifier = nn.Linear(
            in_features=network.classifier.in_features,
            out_features=num_classes,
        )

    return network