import torch.nn as nn
import json
import math

from nets.Baseline.DenseNet.densenet import Bottleneck, Transition, DenseNet


# Load the SDenseNet40 lookup table once, at import time. The parsed JSON is
# kept in a module-level global and later handed to filter_table() by
# SDenseNet40.get_keep_factor.
# NOTE: the path is relative, so it is resolved against the current working
# directory of the process, not against the location of this file.
with open('nets/SubsetNets/DenseNet/tables/table__Subset_arm_SDenseNet40.json', 'r') as fd:
    _g_config_list_sdensenet40 = json.load(fd)


class SBottleneck(Bottleneck):
    """Bottleneck layer whose channel widths are scaled by ``keep_factor``.

    Only the constructor is overridden; everything else (including the
    forward pass) is inherited from ``Bottleneck``.

    Args:
        in_planes: Unscaled number of input channels.
        expansion: Multiplier applied to ``growthRate`` for the width of the
            intermediate 1x1 convolution.
        growthRate: Unscaled number of channels this layer contributes.
        keep_factor: Fraction of channels to keep; every width is truncated
            to an integer after scaling.
    """

    def __init__(self, in_planes, expansion=4, growthRate=12, keep_factor=1.0):
        # Start the MRO lookup *after* Bottleneck on purpose: Bottleneck's own
        # constructor is skipped and the layers are built here instead.
        super(Bottleneck, self).__init__()

        kept_in = int(in_planes * keep_factor)
        mid_width = int(expansion * growthRate * keep_factor)
        # Output width is the difference between the kept width after this
        # layer and the kept width before it, so truncation does not drift
        # as layers are stacked.
        kept_growth = int((in_planes + growthRate) * keep_factor - kept_in)

        # Registration order is kept as in the parent layout (bn1, relu1,
        # conv1, bn2, relu2, conv2) so parameter ordering is unchanged.
        self.bn1 = nn.BatchNorm2d(kept_in, track_running_stats=False)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(kept_in, mid_width, kernel_size=1, bias=False)

        self.bn2 = nn.BatchNorm2d(mid_width, track_running_stats=False)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(
            mid_width,
            kept_growth,
            kernel_size=3,
            padding=1,
            bias=False,
        )


class STransition(Transition):
    """Transition layer whose channel widths are scaled by ``keep_factor``.

    Only the constructor is overridden; the rest is inherited from
    ``Transition``.

    Args:
        in_planes: Unscaled number of input channels.
        out_planes: Unscaled number of output channels.
        keep_factor: Fraction of channels to keep; both widths are truncated
            to an integer after scaling.
    """

    def __init__(self, in_planes, out_planes, keep_factor=1.0):
        # Start the MRO lookup *after* Transition on purpose: Transition's own
        # constructor is skipped and the layers are built here instead.
        super(Transition, self).__init__()

        kept_in = int(in_planes * keep_factor)
        kept_out = int(out_planes * keep_factor)

        self.bn1 = nn.BatchNorm2d(kept_in, track_running_stats=False)
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(
            kept_in,
            kept_out,
            kernel_size=1,
            bias=False,
        )


class SDenseNet(DenseNet):
    """DenseNet whose layer widths are scaled down by ``keep_factor``.

    ``self.inplanes`` always tracks the *unscaled* channel count; the scaling
    by ``keep_factor`` is applied (with integer truncation) at the point where
    each layer is constructed. A load-state-dict pre-hook crops incoming
    tensors to this model's sizes so that a state dict with larger tensors
    can be loaded into the narrower model.

    Args:
        depth: Network depth; each of the three dense blocks gets
            ``(depth - 4) // 6`` bottleneck layers.
        num_classes: Output size of the final linear layer.
        growthRate: Unscaled number of channels added per bottleneck layer.
        compressionRate: Divisor applied to the channel count at each
            transition.
        keep_factor: Fraction of channels to keep in every layer.
    """

    def __init__(self, depth=22, num_classes=10, growthRate=12, compressionRate=2, keep_factor=1.0):
        # Start the MRO lookup *after* DenseNet on purpose: DenseNet's own
        # constructor is skipped and the whole network is built here.
        super(DenseNet, self).__init__()

        self._register_load_state_dict_pre_hook(self.sd_hook)
        self._keep_factor = keep_factor

        # NOTE(review): the check accepts any 3n+4 depth, but the block size
        # below divides by 6, so depths with an odd n silently round down.
        assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
        n = (depth - 4) // 6

        self.growthRate = growthRate
        self.inplanes = growthRate * 2
        self.conv1 = nn.Conv2d(3, int(self.inplanes*keep_factor), kernel_size=3, padding=1, bias=False)

        # Each _make_* call advances self.inplanes, so the order of these
        # statements determines every layer's width.
        self.dense1 = self._make_denseblock(SBottleneck, n)
        self.trans1 = self._make_transition(compressionRate)
        self.dense2 = self._make_denseblock(SBottleneck, n)
        self.trans2 = self._make_transition(compressionRate)
        self.dense3 = self._make_denseblock(SBottleneck, n)

        self.bn = nn.BatchNorm2d(int(self.inplanes*keep_factor), track_running_stats=False)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(int(self.inplanes*keep_factor), num_classes)

    #!overrides
    def _make_denseblock(self, block, blocks):
        """Build ``blocks`` consecutive ``block`` layers as one ``nn.Sequential``.

        Each layer is given the current unscaled ``self.inplanes``, after which
        ``self.inplanes`` grows by ``self.growthRate``.
        """
        layers = []
        for _ in range(blocks):
            layers.append(block(self.inplanes, growthRate=self.growthRate, keep_factor=self._keep_factor))
            self.inplanes += self.growthRate
        return nn.Sequential(*layers)

    @staticmethod
    def _crop_to_shape(tensor, reference):
        """Return ``tensor`` sliced to the leading sizes of ``reference``.

        The rank of ``reference`` selects the slicing: rank 4 and rank 2 crop
        the first two dimensions, rank 1 crops the only dimension. Slicing
        never pads, so a ``tensor`` that is already smaller is returned at its
        own size.

        Raises:
            NotImplementedError: If ``reference`` has any other rank.
        """
        rank = len(reference.size())
        if rank == 4:
            return tensor[0:reference.shape[0], 0:reference.shape[1], :, :]
        if rank == 2:
            return tensor[0:reference.shape[0], 0:reference.shape[1]]
        if rank == 1:
            return tensor[:reference.shape[0]]
        raise NotImplementedError

    def sd_hook(self, state_dict, *_):
        """Crop the incoming state dict, in place, to this model's tensor sizes.

        Registered as a load-state-dict pre-hook in ``__init__``. Entries of
        ``state_dict`` whose key matches one of this model's parameters or
        buffers are replaced by a slice matching the local shape; keys that
        are absent are skipped, as are ``num_batches_tracked`` buffers.

        NOTE(review): the remaining hook arguments (including the key prefix)
        are ignored, so keys are matched against un-prefixed names only --
        presumably this model is always the root of the load; confirm.

        Raises:
            NotImplementedError: If a matched parameter or buffer is not of
                rank 1, 2 or 4.
        """
        for name, param in self.named_parameters():
            if name not in state_dict:
                continue
            state_dict[name] = self._crop_to_shape(state_dict[name], param)

        for name, buffer in self.named_buffers():
            if name not in state_dict:
                continue
            if 'num_batches_tracked' in name:
                continue
            # Crop by the buffer's own rank. The previous code inspected the
            # stale `param` variable left over from the loop above.
            state_dict[name] = self._crop_to_shape(state_dict[name], buffer)

    #!overrides
    def _make_transition(self, compressionRate):
        """Build the transition layer and shrink ``self.inplanes`` accordingly.

        The unscaled output width is ``self.inplanes // compressionRate``
        (floored), which also becomes the new ``self.inplanes``.
        """
        inplanes = self.inplanes
        outplanes = math.floor(self.inplanes // compressionRate)
        self.inplanes = outplanes
        return STransition(inplanes, outplanes, keep_factor=self._keep_factor)

from nets.SubsetNets.utils.utils import filter_table


class SDenseNet40(SDenseNet):
    """SDenseNet preset with depth 40, growth rate 12 and compression rate 2.

    Args:
        num_classes: Output size of the final linear layer.
        keep_factor: Fraction of channels to keep in every layer.
    """

    def __init__(self, num_classes=10, keep_factor=1.0):
        super(SDenseNet40, self).__init__(
            depth=40,
            num_classes=num_classes,
            growthRate=12,
            compressionRate=2,
            keep_factor=keep_factor,
        )

    @staticmethod
    def get_keep_factor(relative_resources):
        """Look up ``relative_resources`` in the SDenseNet40 table.

        Delegates to ``filter_table`` with the table loaded from JSON at
        import time and returns its result unchanged.
        """
        table = _g_config_list_sdensenet40
        return filter_table(relative_resources, table)