from models.resnet_cifar import *
from models.densenet_cifar import *
from models.mobilenet_v2_cifar import *
from models.wide_resnet_cifar import *
from models.densenet_cifar_real import *
from models.imagenet_models import *
from models.efficentnet import *
from models.vgg19_cifar import vgg19_bn
from models.densenet_40_slim import densenet_40_slim
import torch

class model_stats:
    """Summary statistics of a torch model, focused on its prunable layers.

    Attributes set by ``__init__``:
        num_params: total number of scalar entries over ``model.parameters()``.
        layer_names: names (keys of ``model.named_parameters()``) of the
            parameters treated as prunable.
        num_layers: ``len(layer_names)``.
        total_num_layers: number of named parameter tensors in the model.
        num_strucs: sum of the first dimension of every prunable parameter
            (for a PyTorch conv weight this is the number of output channels).
        prunable_parameters: the prunable parameter tensors, in
            ``layer_names`` order.
        nn_shapes_prunable_layers: ``.shape`` of each prunable parameter.
    """

    def __init__(self, model):
        """Collect the statistics of ``model``.

        :param model: a torch ``nn.Module``. An optional string attribute
            ``model_type`` selects the prunable-layer rule ('mobilenetv2_I',
            'vgg19'); models without it use the generic name-based rule.
        """
        # Build the name -> parameter mapping once instead of once per lookup.
        named_params = dict(model.named_parameters())

        # Total number of scalar parameters in the model.
        self.num_params = sum(p.numel() for p in model.parameters())

        self.layer_names = self._select_prunable_names(
            list(named_params), getattr(model, 'model_type', None))

        # Number of prunable layers.
        self.num_layers = len(self.layer_names)

        # Total number of named parameter tensors.
        self.total_num_layers = len(named_params)

        # Prunable parameters and their shapes.
        self.prunable_parameters = [named_params[name] for name in self.layer_names]
        self.nn_shapes_prunable_layers = [p.shape for p in self.prunable_parameters]

        # Number of structures: the leading dimension of each prunable parameter.
        self.num_strucs = sum(p.shape[0] for p in self.prunable_parameters)

    @staticmethod
    def _select_prunable_names(param_names, model_type):
        """Return the subset of ``param_names`` treated as prunable for ``model_type``."""
        if model_type == 'mobilenetv2_I':
            # Every third parameter, dropping the last one.
            # NOTE(review): presumably relies on a repeating 3-parameter pattern
            # per layer in the model definition — confirm.
            return param_names[::3][:-1]
        if model_type == 'vgg19':
            # TODO: the original author marked this trim of the last two
            # entries as temporary ("change this back, do not forget").
            return param_names[::3][:-2]

        # Generic rule: keep parameters whose owning module's name starts with
        # 'c' (e.g. 'conv1.weight'). Note this also selects conv biases when
        # present. This heuristic may need adapting for other networks.
        selected = []
        for name in param_names:
            parts = name.split('.')
            # Parameters registered directly on the root module have no owning
            # module name; skip them instead of raising IndexError.
            if len(parts) < 2:
                continue
            if parts[-2].startswith('c'):
                selected.append(name)
        return selected

    def print_model_stats(self):
        """Print a short summary of the collected statistics to stdout."""
        print('----------------------------------------------------------')
        #print('The model prunable layers are: ', str(self.layer_names))
        print('The model has parameters: ', str(self.num_params))
        print('The model has prunable layers: ', str(self.num_layers))
        print('The model has in total layers: ', str(self.total_num_layers))
        print('The model has structures: ', str(self.num_strucs))
        #print('The number of channels per layer are: ', str(self.nn_shapes_prunable_layers))
        print('----------------------------------------------------------')


def create_model_by_name(model_name, device, num_classes=10, width_factor=1, pretrained_model_path=None, load_pretrained=False, mult_arr=None):
    """Instantiate a model by name, optionally load weights, and move it to ``device``.

    :param model_name: name of a model class/factory visible in this module's
        globals (e.g. brought in by the ``models.*`` star imports).
    :param device: target passed to ``net.to``.
    :param num_classes: number of output classes passed to the constructor.
    :param width_factor: forwarded to the constructor only when it differs from 1.
    :param pretrained_model_path: path to a saved state dict; in the
        ``num_classes == 1000`` case it is forwarded to the constructor instead.
    :param load_pretrained: whether to load pretrained weights.
    :param mult_arr: optional first positional argument for the constructor.
    :return: the constructed torch ``nn.Module``, moved to ``device``.
    :raises KeyError: if ``model_name`` is not defined in this module.
    :raises ValueError: if a state dict must be loaded but no path was given.
    """
    try:
        model_class = globals()[model_name]
    except KeyError:
        raise KeyError(
            f'Unknown model name {model_name!r}: it is not defined in this module') from None

    # TODO: make loading pretrained CIFAR models through the constructor possible too.
    if load_pretrained and num_classes == 1000:
        # 1000-class models receive the pretrained settings via their constructor.
        net = model_class(load_pretrained=load_pretrained, pretrained_model_path=pretrained_model_path)
    else:
        # mult_arr (when given) is the first positional argument; width_factor is
        # only passed when it differs from 1, exactly as before.
        positional = () if mult_arr is None else (mult_arr,)
        kwargs = {'num_classes': num_classes}
        if width_factor != 1:
            kwargs['width_factor'] = width_factor
        net = model_class(*positional, **kwargs)

        if load_pretrained:
            if pretrained_model_path is None:
                raise ValueError('load_pretrained=True requires pretrained_model_path')
            print('Loading pretrained model!')
            # net has not been moved to `device` yet; mapping the checkpoint to
            # CPU lets GPU-saved weights load on CPU-only machines, and
            # load_state_dict copies the values into the existing parameters.
            state_dict = torch.load(pretrained_model_path, map_location='cpu')
            net.load_state_dict(state_dict)

    net.to(device)
    return net


