import ipdb

class ModelFactory():
    """Factory for layer specifications ("net lists") of the supported models.

    A net list is a plain list of ``(layer_name, layer_args, extra)`` tuples.
    This class only describes the architecture; turning the list into an
    actual network is done by whatever consumes it (not visible here).
    Every call builds fresh lists, so callers may mutate the result freely.
    """

    # Number of channels in every conv layer of the ``pc_cnn*`` backbones.
    _CHANNELS = 160

    # Width of the flattened backbone output fed to every ``pc_cnn*`` head.
    # NOTE(review): 16 presumably corresponds to a 4x4 spatial map left after
    # the stride-2 convs (32 -> 4 for CIFAR, 64 -> 4 for TinyImageNet) --
    # TODO confirm against the consumer's conv2d semantics.
    _FLAT_FEATURES = 16 * _CHANNELS

    # CIFAR heads of the form ``linear -> relu -> <layer> -> linear`` differ
    # only in the name of the hidden-to-hidden layer; model_type -> that name.
    _CIFAR_HIDDEN_LAYER = {
        'pc_cnn_extention': 'non-linear',
        'pc_cnn_extention[nn-doub]': 'non-linear-doub',
        'pc_cnn_extention_nlt[soft]': 'non-linear-soft',
        'pc_cnn_extention_v0': 'non-linear-v0',
        'pc_cnn_extention-nl-full': 'non-linear-full',
        'pc_cnn_extention_v1': 'non-linear-v1',
        'pc_cnn_extention_v2': 'non-linear-v2',
        'pc_cnn_extention_v3': 'non-linear-relu',
        'pc_cnn_extention[relu]': 'non-linear-relu',
        'pc_cnn_extention[sparse]': 'non-linear-sparse',
        'pc_cnn_extention-non-linear-[e-no-copy]': 'non-linear-[e-no-copy]',
        'pc_cnn_extention[random]': 'non-linear-random',
        'pc_cnn_extention[norm_const]': 'non-linear-norm-const',
        'pc_cnn_extention[linear-simple]': 'linear-simple',
        'pc_cnn_extention[linear-[v0]]': 'linear-[v0]',
    }

    @staticmethod
    def get_model(model_type, sizes, dataset='mnist', args=None):
        """Return the net list for ``model_type`` on ``dataset``.

        Args:
            model_type: Architecture name: ``"linear"`` for MNIST, or
                ``"pc_cnn"`` / one of the ``"pc_cnn_extention..."`` variants
                for TinyImageNet and CIFAR.
            sizes: Layer widths. ``"linear"`` uses every entry (input width
                first). The conv models only read ``sizes[-1]`` as the output
                width of the final linear layer.
            dataset: Any name containing ``"mnist"``, or exactly
                ``"tinyimagenet"``, ``"cifar10"`` or ``"cifar100"``.
            args: Config object. Only ``pc_cnn_extention[linear-simple-n]``
                reads it (``args.parametric_normalization`` is used as a layer
                name). Every other model works with ``args=None``.

        Returns:
            A new list of ``(layer_name, layer_args, extra)`` tuples.

        Raises:
            ValueError: if the (dataset, model_type) pair is not supported.
        """
        if "mnist" in dataset:
            if model_type == "linear":
                return ModelFactory._mlp(sizes)
        elif dataset == "tinyimagenet":
            head = ModelFactory._tinyimagenet_head(model_type, sizes)
            if head is not None:
                return ModelFactory._conv_backbone(4) + head
        elif dataset in ("cifar100", "cifar10"):
            head = ModelFactory._cifar_head(model_type, sizes, args)
            if head is not None:
                return ModelFactory._conv_backbone(3) + head
        # Unsupported combinations must not fall through to an implicit None,
        # and `assert` cannot be used here because `python -O` strips it.
        raise ValueError(
            "Unsupported model; either implement the model in "
            "model/ModelFactory or choose a different model "
            "(model_type=%r, dataset=%r)" % (model_type, dataset))

    @staticmethod
    def _mlp(sizes):
        """Return linear layers between consecutive ``sizes``, ReLU between them and ``rep`` last."""
        layers = []
        pairs = list(zip(sizes, sizes[1:]))
        for idx, (fan_in, fan_out) in enumerate(pairs):
            layers.append(('linear', [fan_out, fan_in], ''))
            # The final linear layer is followed by the 'rep' marker instead of a ReLU.
            if idx == len(pairs) - 1:
                layers.append(('rep', [], ''))
            else:
                layers.append(('relu', [True], ''))
        return layers

    @staticmethod
    def _conv_backbone(num_convs):
        """Return ``num_convs`` conv2d+ReLU pairs followed by ``flatten`` and ``rep``.

        The first conv takes 3 input channels. All later ones map
        ``_CHANNELS -> _CHANNELS``. The conv args are presumably
        ``[out_ch, in_ch, kh, kw, stride, padding]``; TODO confirm with the
        consumer of the net list.
        """
        channels = ModelFactory._CHANNELS
        layers = []
        in_channels = 3
        for _ in range(num_convs):
            layers.append(('conv2d', [channels, in_channels, 3, 3, 2, 1], ''))
            layers.append(('relu', [True], ''))
            in_channels = channels
        layers.append(('flatten', [], ''))
        layers.append(('rep', [], ''))
        return layers

    @staticmethod
    def _tinyimagenet_head(model_type, sizes):
        """Return the layers after the TinyImageNet backbone, or None for an unknown model_type."""
        hidden = 640
        flat = ModelFactory._FLAT_FEATURES
        if model_type == 'pc_cnn':
            body = [('linear', [hidden, flat], ''),
                    ('relu', [True], ''),
                    ('linear', [hidden, hidden], ''),
                    ('relu', [True], '')]
        elif model_type == 'pc_cnn_extention':
            body = [('linear', [hidden, flat], ''),
                    ('relu', [True], ''),
                    ('linear', [hidden, hidden], ''),
                    ('non-linear', [hidden, hidden], '')]
        elif model_type == 'pc_cnn_extention[1]':
            body = [('linear', [hidden, flat], ''),
                    ('non-linear', [hidden, hidden], ''),
                    ('linear', [hidden, hidden], ''),
                    ('relu', [True], '')]
        elif model_type == 'pc_cnn_extention[2]':
            body = [('linear', [hidden, flat], ''),
                    ('relu', [True], ''),
                    ('non-linear', [hidden, hidden], '')]
        else:
            return None
        return body + [('linear', [sizes[-1], hidden], '')]

    @staticmethod
    def _cifar_head(model_type, sizes, args):
        """Return the layers after the CIFAR backbone, or None for an unknown model_type."""
        hidden = 320
        flat = ModelFactory._FLAT_FEATURES
        hidden_layer = ModelFactory._CIFAR_HIDDEN_LAYER.get(model_type)
        if hidden_layer is not None:
            body = [('linear', [hidden, flat], ''),
                    ('relu', [True], ''),
                    (hidden_layer, [hidden, hidden], '')]
        elif model_type == 'pc_cnn':
            body = [('linear', [hidden, flat], ''),
                    ('relu', [True], ''),
                    ('linear', [hidden, hidden], ''),
                    ('relu', [True], '')]
        elif model_type == 'pc_cnn_extention[nn-l5-l6]':
            body = [('non-linear', [hidden, flat], ''),
                    ('non-linear', [hidden, hidden], '')]
        elif model_type == 'pc_cnn_extention[on-2nd-back]':
            body = [('non-linear', [hidden, flat], ''),
                    ('linear', [hidden, hidden], ''),
                    ('relu', [True], '')]
        elif model_type == 'pc_cnn_extention[linear-simple-n]':
            # The only model that needs `args`. The layer name comes from the
            # config, so a missing `args` raises AttributeError here.
            body = [('linear', [hidden, flat], ''),
                    ('relu', [True], ''),
                    ('linear', [hidden, hidden], ''),
                    (args.parametric_normalization, [hidden, hidden], ''),
                    ('relu', [True], '')]
        else:
            return None
        return body + [('linear', [sizes[-1], hidden], '')]
