# ===========================================================================
# Project:      Sparse Model Soups
# File:         models/imagenet.py
# Description:  ImageNet Models
# ===========================================================================

import torchvision

from utilities.utilities import Utilities as Utils


def ResNet50():
    """Build a randomly initialised torchvision ResNet-50 with a permutation-spec hook.

    Returns:
        The module returned by ``torchvision.models.resnet50()`` (no pretrained weights are
        requested), with an extra ``get_permutation_spec`` attribute. That attribute is a
        zero-argument callable returning the result of
        ``Utils.permutation_spec_from_axes_to_perm`` applied to the axes dictionary built below.
    """

    def get_permutation_spec():
        """Return the permutation spec built from a ``{parameter name: axes tuple}`` dictionary.

        Each axes tuple names, position by position, the permutation group attached to that
        position (``None`` where no group is given). Presumably one position per tensor axis --
        confirm against ``Utils.permutation_spec_from_axes_to_perm``.

        NOTE(review): the key names used here (``.../Conv_0/kernel``, ``.../BatchNorm_0/scale``)
        and the ``(None, None, p_in, p_out)`` kernel layout do not look like torchvision's
        parameter names (e.g. ``layer1.0.conv1.weight``); verify that the consumer of this spec
        maps them onto the torchvision model correctly.
        """

        def conv(name, p_in, p_out):
            # Four-position kernel entry: only the last two positions carry permutation groups.
            return {f"{name}/kernel": (None, None, p_in, p_out)}

        def norm(name, p):
            return {f"{name}/scale": (p,), f"{name}/bias": (p,)}

        def dense(name, p_in, p_out):
            return {f"{name}/kernel": (p_in, p_out), f"{name}/bias": (p_out,)}

        def bottleneck(name, p_in, p_out):
            # Three conv+norm stages shared by both block kinds; the two inner groups are
            # private to the block, so their names are derived from the block name.
            inner1 = f"P_{name}_inner1"
            inner2 = f"P_{name}_inner2"
            return {
                **conv(f"{name}/ConvBlock_0/Conv_0", p_in, inner1),
                **norm(f"{name}/ConvBlock_0/BatchNorm_0", inner1),
                #
                **conv(f"{name}/ConvBlock_1/Conv_0", inner1, inner2),
                **norm(f"{name}/ConvBlock_1/BatchNorm_0", inner2),
                #
                **conv(f"{name}/ConvBlock_2/Conv_0", inner2, p_out),
                **norm(f"{name}/ConvBlock_2/BatchNorm_0", p_out),
            }

        def easyblock(name, p):
            # Residual block without any change in the number of channels: the same group is
            # used on the way in and on the way out.
            return bottleneck(name, p, p)

        def shortcutblock(name, p_in, p_out):
            # Residual block that changes the number of channels via a Conv on the skip path,
            # so the skip connection's conv+norm must map p_in to p_out as well.
            return {
                **bottleneck(name, p_in, p_out),
                #
                **conv(f"{name}/ResNetSkipConnection_0/ConvBlock_0/Conv_0", p_in, p_out),
                **norm(f"{name}/ResNetSkipConnection_0/ConvBlock_0/BatchNorm_0", p_out),
            }

        # Stem.
        axes_to_perm = {
            **conv("layers_0/ConvBlock_0/Conv_0", None, "P_bg0"),
            **norm("layers_0/ConvBlock_0/BatchNorm_0", "P_bg0"),
        }

        # Four block groups; each starts with a shortcut block (P_bg{g} -> P_bg{g+1}) followed
        # by easy blocks that stay in P_bg{g+1}.
        blocks_per_group = (3, 4, 6, 3)
        for group, n_blocks in enumerate(blocks_per_group):
            p_in, p_out = f"P_bg{group}", f"P_bg{group + 1}"
            axes_to_perm.update(shortcutblock(f"blockgroups_{group}/blocks_0", p_in, p_out))
            for block in range(1, n_blocks):
                axes_to_perm.update(easyblock(f"blockgroups_{group}/blocks_{block}", p_out))

        # Final classifier: its input follows the last group, its output gets no group.
        axes_to_perm.update(dense("dense", "P_bg4", None))

        return Utils.permutation_spec_from_axes_to_perm(axes_to_perm)

    # No arguments: both the legacy (`pretrained=False`) and the current (`weights=None`)
    # torchvision signatures default to an untrained model, and this avoids passing the
    # deprecated `pretrained` keyword.
    m = torchvision.models.resnet50()
    # Attached as a plain attribute (not a bound method), so it is called without `self`.
    m.get_permutation_spec = get_permutation_spec

    return m
