
import torch.nn as nn
import json

from nets.Baseline.ResNet.resnet import InputLayer, BasicBlock, Bottleneck, ResNet

from nets.SubsetNets.utils.utils import filter_table

# Configuration tables handed to ``filter_table`` by ``get_keep_factor`` of
# SResNet18 / SResNet50 (schema is defined by the JSON files; presumably
# resource-budget -> keep_factor entries — confirm against filter_table).
# JSON is UTF-8 by spec, so decode explicitly instead of relying on the
# platform's locale-dependent default encoding.
# NOTE(review): these paths are relative to the current working directory,
# so importing this module only works when run from the directory that
# contains ``nets/``.
with open('nets/SubsetNets/ResNet/tables/table__Subset_arm_SResNet18.json', 'r', encoding='utf-8') as fd:
    _g_config_list_sd_resnet18 = json.load(fd)
with open('nets/SubsetNets/ResNet/tables/table__Subset_x64_SResNet50.json', 'r', encoding='utf-8') as fd:
    _g_config_list_sd_resnet50 = json.load(fd)


class SInputLayer(InputLayer):
    """Input stem used by the subset ResNets.

    Builds a 3x3, stride-1, padding-1 convolution from 3 input channels to
    ``planes`` channels (no bias), a BatchNorm over ``planes`` channels that
    does not track running statistics, and a ReLU. The forward pass is
    presumably inherited from ``InputLayer`` (not visible here).

    Args:
        planes: number of output channels of the stem.
    """

    def __init__(self, planes=64):
        # Start the MRO lookup *after* InputLayer on purpose, so that
        # InputLayer.__init__ (and the layers it would build) is skipped and
        # only the layers below are created.
        super(InputLayer, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=planes, track_running_stats=False)
        self.relu1 = nn.ReLU()


class SResNet(ResNet):
    """ResNet whose layer widths are scaled by ``keep_factor``.

    The input stem, the four residual stages and the final linear layer all
    use ``int(width * keep_factor)`` channels. A load-state-dict pre-hook
    crops incoming tensors (e.g. from a full-width model) down to this
    model's sizes before they are copied in.

    Args:
        block: residual block class passed to the inherited ``_make_layer``;
            its ``expansion`` attribute scales the linear layer's input width.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        keep_factor: width multiplier applied to every layer.
    """

    def __init__(self, block, num_blocks, num_classes=10, keep_factor=1.0):
        # Bypass ResNet.__init__ on purpose: the width-scaled layers are built
        # here; only ResNet's other methods (e.g. _make_layer) are reused.
        super(ResNet, self).__init__()
        self._register_load_state_dict_pre_hook(self.sd_hook)

        # Presumably read/updated by the inherited _make_layer — confirm in ResNet.
        self.in_planes = int(64*keep_factor)

        self.input_layer = SInputLayer(self.in_planes)

        self.layer1 = self._make_layer(block, int(64*keep_factor), num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, int(128*keep_factor), num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, int(256*keep_factor), num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, int(512*keep_factor), num_blocks[3], stride=2)
        self.linear = nn.Linear(int(512*keep_factor)*block.expansion, num_classes)

    def sd_hook(self, state_dict, prefix='', *_):
        """Crop the tensors in ``state_dict`` in place to this model's shapes.

        Registered as a load-state-dict pre-hook. For every parameter and
        buffer of this module whose key is present in ``state_dict``, the
        leading (at most two) dimensions are truncated to the local size,
        keeping the first entries. Trailing dimensions, such as conv kernels,
        are left whole. 0-dim tensors pass through unchanged.

        Args:
            state_dict: state dict being loaded; modified in place.
            prefix: key prefix of this module inside ``state_dict``. It is
                non-empty when the network is nested inside another module.
            *_: remaining pre-hook arguments (metadata, strict, error lists);
                unused.
        """
        # Buffers are included so BatchNorm running statistics (if tracked)
        # are cropped like the weights they belong to.
        local_tensors = [*self.named_parameters(), *self.named_buffers()]
        for name, local in local_tensors:
            key = prefix + name
            if key not in state_dict:
                continue
            # Keep the first out/in channels (or features); for 4-D conv
            # weights this equals [:out, :in, :, :].
            crop = tuple(slice(0, size) for size in local.shape[:2])
            state_dict[key] = state_dict[key][crop]


class SResNet18(SResNet):
    """Width-scaled ResNet-18: BasicBlock with stage depths 2-2-2-2.

    Args:
        num_classes: output size of the final linear layer.
        keep_factor: width multiplier forwarded to ``SResNet``.
    """

    def __init__(self, num_classes=10, keep_factor=1.0):
        stage_depths = [2, 2, 2, 2]
        super().__init__(
            BasicBlock,
            stage_depths,
            num_classes=num_classes,
            keep_factor=keep_factor,
        )

    @staticmethod
    def get_keep_factor(relative_resources):
        """Return ``filter_table`` applied to the ResNet-18 config table.

        Presumably selects the keep factor matching ``relative_resources``;
        the exact semantics live in ``filter_table``.
        """
        table = _g_config_list_sd_resnet18
        return filter_table(relative_resources, table)


class SResNet50(SResNet):
    """Width-scaled ResNet-50: Bottleneck with stage depths 3-4-6-3.

    Args:
        num_classes: output size of the final linear layer.
        keep_factor: width multiplier forwarded to ``SResNet``.
    """

    def __init__(self, num_classes=10, keep_factor=1.0):
        stage_depths = [3, 4, 6, 3]
        super().__init__(
            Bottleneck,
            stage_depths,
            num_classes=num_classes,
            keep_factor=keep_factor,
        )

    @staticmethod
    def get_keep_factor(relative_resources):
        """Return ``filter_table`` applied to the ResNet-50 config table.

        Presumably selects the keep factor matching ``relative_resources``;
        the exact semantics live in ``filter_table``.
        """
        table = _g_config_list_sd_resnet50
        return filter_table(relative_resources, table)
