import torch.nn as nn
import json

from nets.Baseline.MobileNet.mobilenet import MobileNetV2, IOLayer, Block
from nets.SubsetNets.utils.utils import filter_table

def _load_table(path):
    """Parse the JSON lookup table stored at ``path`` and return its content."""
    with open(path, 'r') as table_file:
        return json.load(table_file)


# Lookup tables handed to ``filter_table`` by the ``get_keep_factor`` methods
# below. The paths are relative, so they resolve against the current working
# directory at import time.
_g_config_list_sd_mobilenet = _load_table(
    'nets/SubsetNets/MobileNet/tables/table__Subset_x64_SMobileNet.json')
_g_config_list_sd_mobilenetlarge = _load_table(
    'nets/SubsetNets/MobileNet/tables/table__Subset_x64_SMobileNetLarge.json')


class SMobileNet(MobileNetV2):
    """MobileNetV2 variant whose channel widths are scaled by ``keep_factor``.

    Every channel count is computed as ``int(base_width * keep_factor)``. A
    load-state-dict pre-hook (``sd_hook``) crops incoming tensors to this
    model's shapes, so a state dict with larger tensors can be loaded into a
    narrower instance.

    Args:
        num_classes: Output size of the final linear layer.
        keep_factor: Multiplier applied to every base channel width.
    """

    def __init__(self, num_classes=10, keep_factor=1.0):
        # Deliberately start the MRO lookup *after* MobileNetV2: its __init__
        # is skipped and the layers are built below with the scaled widths.
        super(MobileNetV2, self).__init__()
        self._register_load_state_dict_pre_hook(self.sd_hook)
        self._keep_factor = keep_factor

        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.layer_in = IOLayer(3, int(32*keep_factor), 3, 1, 1)
        self.layers = self._make_layers(in_planes=int(32*keep_factor))
        self.layer_out = IOLayer(int(320*keep_factor), int(1280*keep_factor), 1, 1, 0)
        self.linear = nn.Linear(int(1280*keep_factor), num_classes)

    def _make_layers(self, in_planes):
        """Build the sequential stack of ``Block`` modules described by ``self.cfg``.

        Each ``cfg`` entry is ``(expansion, out_planes, num_blocks, stride)``.
        Only the first block of a stage receives ``stride``; the remaining
        blocks of that stage receive 1.

        Args:
            in_planes: Channel count fed into the first block.

        Returns:
            An ``nn.Sequential`` holding all blocks in order.
        """
        layers = []
        for expansion, out_planes, num_blocks, stride in self.cfg:
            scaled_planes = int(out_planes*self._keep_factor)
            # Distinct name for the per-block stride so the stage-level
            # ``stride`` from the cfg tuple is not shadowed.
            block_strides = [stride] + [1]*(num_blocks-1)
            for block_stride in block_strides:
                layers.append(Block(in_planes, scaled_planes, expansion, block_stride))
                in_planes = scaled_planes
        return nn.Sequential(*layers)

    def sd_hook(self, state_dict, prefix='', *_):
        """Crop tensors in ``state_dict`` to this model's shapes, in place.

        Registered as a load-state-dict pre-hook in ``__init__``. For every
        parameter of this model that has an entry in ``state_dict``, the entry
        is replaced by its leading slice matching the parameter's shape.

        Args:
            state_dict: Mapping being loaded; entries are replaced in place.
            prefix: Key prefix of this module inside ``state_dict``. PyTorch
                passes it as the second hook argument; it is empty when this
                model is the root module being loaded.
            *_: Remaining pre-hook arguments, unused.

        Raises:
            NotImplementedError: If a parameter is not 1-, 2- or 4-dimensional.
        """
        for name, param in self.named_parameters():
            # Keys in ``state_dict`` carry the module prefix (e.g. when this
            # model is nested inside a wrapper), so it must be prepended.
            key = prefix + name
            if key not in state_dict:
                continue
            incoming = state_dict[key]
            ndim = param.dim()
            if ndim == 4:
                # Crop the first two axes only; the trailing two are kept whole.
                state_dict[key] = incoming[0:param.shape[0], 0:param.shape[1], :, :]
            elif ndim == 2:
                state_dict[key] = incoming[0:param.shape[0], 0:param.shape[1]]
            elif ndim == 1:
                state_dict[key] = incoming[:param.shape[0]]
            else:
                raise NotImplementedError(
                    f"cannot crop {ndim}-dimensional parameter '{name}'")

        # Buffers are not returned by named_parameters(). Crop 1-D ones as well
        # (e.g. BatchNorm running statistics, if the imported layers keep any);
        # buffers of any other rank are left untouched.
        for name, buf in self.named_buffers():
            key = prefix + name
            if key not in state_dict or buf.dim() != 1:
                continue
            incoming = state_dict[key]
            if incoming.dim() == 1:
                state_dict[key] = incoming[:buf.shape[0]]

    @staticmethod
    def get_keep_factor(relative_resources):
        """Return ``filter_table``'s result for ``relative_resources``.

        The lookup uses the table loaded for SMobileNet at import time.
        NOTE(review): presumably the result is the keep factor that fits the
        given resource budget; confirm against ``filter_table``.
        """
        return filter_table(relative_resources, _g_config_list_sd_mobilenet)
        


class SMobileNetLarge(SMobileNet):
    """SMobileNet configured for larger inputs (XCHEST, per the notes below).

    Compared to ``SMobileNet`` it sets its own ``cfg``, rebuilds ``layer_in``
    with 2 instead of 1 as the fourth ``IOLayer`` argument, and average-pools
    with kernel size 7 before the linear layer.
    """

    def __init__(self, num_classes=10, keep_factor=1):
        # ``cfg`` has to exist before the parent constructor runs, because
        # SMobileNet.__init__ calls _make_layers, which iterates self.cfg.
        # Entries are (expansion, out_planes, num_blocks, stride).
        stage_table = [
            (1,  16, 1, 1),
            (6,  24, 2, 2),  # NOTE: stride 2 for XCHEST
            (6,  32, 3, 2),
            (6,  64, 4, 2),
            (6,  96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        ]
        self.cfg = stage_table
        super().__init__(num_classes, keep_factor)

        # Replace the input layer built by the parent.
        # NOTE: change conv1 stride 2 XCHEST
        stem_width = int(32*keep_factor)
        self.layer_in = IOLayer(3, stem_width, 3, 2, 1)

    def forward(self, x):
        """Run input layer, block stack, output layer, pooling and classifier."""
        features = self.layer_out(self.layers(self.layer_in(x)))
        # NOTE: change pooling kernel_size 7 for XCHEST
        pooled = nn.functional.avg_pool2d(features, 7)
        # Collapse everything after the batch axis into one vector per sample.
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

    @staticmethod
    def get_keep_factor(relative_resources):
        """Return ``filter_table``'s result using the SMobileNetLarge table."""
        table = _g_config_list_sd_mobilenetlarge
        return filter_table(relative_resources, table)